Merge pull request #665 from powdr-labs/extract_evaluator_processor

Extract evaluator and processor.
This commit is contained in:
chriseth
2023-10-09 15:24:19 +00:00
committed by GitHub
4 changed files with 201 additions and 171 deletions

View File

@@ -1,10 +1,9 @@
use std::collections::HashMap;
use ast::analyzed::{Analyzed, Expression, FunctionValueDefinition, Reference};
use ast::parsed::{FunctionCall, MatchArm, MatchPattern};
use ast::{evaluate_binary_operation, evaluate_unary_operation};
use ast::analyzed::{Analyzed, FunctionValueDefinition};
use itertools::Itertools;
use number::{DegreeType, FieldElement};
use pil_analyzer::evaluator::Evaluator;
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
/// Generates the constant polynomial values for all constant polynomials
@@ -42,18 +41,21 @@ fn generate_values<T: FieldElement>(
.into_par_iter()
.map(|i| {
Evaluator {
analyzed,
constants: &analyzed.constants,
definitions: &analyzed.definitions,
variables: &[i.into()],
other_constants,
function_cache: other_constants,
}
.evaluate(body)
.unwrap()
})
.collect(),
FunctionValueDefinition::Array(values) => {
let evaluator = Evaluator {
analyzed,
constants: &analyzed.constants,
definitions: &analyzed.definitions,
variables: &[],
other_constants,
function_cache: other_constants,
};
let values: Vec<_> = values
.iter()
@@ -61,7 +63,7 @@ fn generate_values<T: FieldElement>(
let items = elements
.pattern()
.iter()
.map(|v| evaluator.evaluate(v))
.map(|v| evaluator.evaluate(v).unwrap())
.collect::<Vec<_>>();
items
@@ -81,55 +83,6 @@ fn generate_values<T: FieldElement>(
}
}
/// Evaluates an `Expression` to a single value of type `T` (see `evaluate`).
struct Evaluator<'a, T> {
/// The analyzed PIL; `evaluate` reads the values of named constants from
/// its `constants` map.
analyzed: &'a Analyzed<T>,
/// Value tables keyed by name; `evaluate` indexes into them to resolve
/// function calls.
other_constants: &'a HashMap<&'a str, Vec<T>>,
/// Values of local variables, indexed by the index stored in
/// `Reference::LocalVar`.
variables: &'a [T],
}
impl<'a, T: FieldElement> Evaluator<'a, T> {
    /// Evaluates `expr` to a single value.
    ///
    /// Function calls are resolved by looking up the function's value table in
    /// `other_constants` and indexing it with the (single) argument, wrapping
    /// around at the end of the table.
    ///
    /// # Panics
    /// Panics on strings, tuples, array literals, lambda expressions and free
    /// inputs, on polynomial and public references (not implemented), on
    /// unknown constants / functions, on function calls that do not have
    /// exactly one argument, and if no arm of a match expression matches.
    fn evaluate(&self, expr: &Expression<T>) -> T {
        match expr {
            Expression::Constant(name) => self.analyzed.constants[name],
            Expression::Reference(Reference::LocalVar(i, _name)) => self.variables[*i as usize],
            Expression::Reference(Reference::Poly(_)) => todo!(),
            Expression::PublicReference(_) => todo!(),
            Expression::Number(n) => *n,
            Expression::String(_) => panic!(),
            Expression::Tuple(_) => panic!(),
            Expression::ArrayLiteral(_) => panic!(),
            Expression::BinaryOperation(left, op, right) => {
                evaluate_binary_operation(self.evaluate(left), *op, self.evaluate(right))
            }
            Expression::UnaryOperation(op, expr) => {
                evaluate_unary_operation(*op, self.evaluate(expr))
            }
            Expression::LambdaExpression(_) => panic!(),
            Expression::FunctionCall(FunctionCall { id, arguments }) => {
                let arg_values = arguments
                    .iter()
                    .map(|a| self.evaluate(a))
                    .collect::<Vec<_>>();
                assert!(arg_values.len() == 1);
                let values = &self.other_constants[id.as_str()];
                // The row index wraps around the length of the value table.
                values[arg_values[0].to_degree() as usize % values.len()]
            }
            Expression::MatchExpression(scrutinee, arms) => {
                let v = self.evaluate(scrutinee);
                arms.iter()
                    .find_map(|MatchArm { pattern, value }| match pattern {
                        MatchPattern::Pattern(p) => {
                            (self.evaluate(p) == v).then(|| self.evaluate(value))
                        }
                        MatchPattern::CatchAll => Some(self.evaluate(value)),
                    })
                    // `expect` takes a plain string and would print "{v}" literally,
                    // so format the message explicitly.
                    .unwrap_or_else(|| panic!("No arm matched the value {v}"))
            }
            Expression::FreeInput(_) => panic!(),
        }
    }
}
#[cfg(test)]
mod test {
use number::GoldilocksField;

View File

@@ -0,0 +1,94 @@
use std::collections::HashMap;
use ast::{
analyzed::{Analyzed, Expression, FunctionValueDefinition, Reference, Symbol},
evaluate_binary_operation, evaluate_unary_operation,
parsed::{FunctionCall, MatchArm, MatchPattern},
};
use number::FieldElement;
/// Evaluates an expression to a single value.
pub fn evaluate_expression<T: FieldElement>(
analyzed: &Analyzed<T>,
expression: &Expression<T>,
) -> Result<T, String> {
Evaluator {
constants: &analyzed.constants,
definitions: &analyzed.definitions,
function_cache: &Default::default(),
variables: &[],
}
.evaluate(expression)
}
/// Evaluates analyzed expressions to single values (see `evaluate`).
pub struct Evaluator<'a, T> {
/// Values of named constants; consulted for `Expression::Constant` and for
/// plain polynomial references.
pub constants: &'a HashMap<String, T>,
/// Symbol definitions by name; a plain polynomial reference that is not a
/// constant is looked up here and evaluated if it is defined by an expression.
pub definitions: &'a HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
/// Contains full value tables of functions (columns) we already evaluated.
pub function_cache: &'a HashMap<&'a str, Vec<T>>,
/// Values of local variables, indexed by the index stored in
/// `Reference::LocalVar`.
pub variables: &'a [T],
}
impl<'a, T: FieldElement> Evaluator<'a, T> {
    /// Evaluates `expr` to a single value.
    ///
    /// Plain polynomial references are resolved through `constants` first and
    /// then through `definitions` (only expression definitions can be
    /// evaluated). Function calls are resolved by indexing the function's
    /// value table in `function_cache` with the single argument, wrapping
    /// around at the end of the table.
    ///
    /// # Errors
    /// Returns a description of the problem if the expression cannot be
    /// evaluated to a single value: unsupported expression kinds (strings,
    /// tuples, array literals, lambdas, free inputs, public references, array
    /// or next references), unknown symbols, functions missing from the
    /// cache or called with a number of arguments other than one, empty value
    /// tables, and match expressions without a matching arm. Errors of
    /// sub-expressions are propagated.
    ///
    /// # Panics
    /// Panics if an `Expression::Constant` is not contained in `constants` or
    /// if a local variable index is out of range of `variables`.
    pub fn evaluate(&self, expr: &Expression<T>) -> Result<T, String> {
        match expr {
            Expression::Constant(name) => Ok(self.constants[name]),
            Expression::Reference(Reference::LocalVar(i, _name)) => Ok(self.variables[*i as usize]),
            Expression::Reference(Reference::Poly(poly)) => {
                if poly.next || poly.index.is_some() {
                    return Err("Cannot evaluate arrays or next references.".to_string());
                }
                if let Some(value) = self.constants.get(&poly.name) {
                    return Ok(*value);
                }
                // Report unknown symbols as an error instead of panicking, so that
                // callers probing whether an expression is constant get an `Err`.
                match self.definitions.get(&poly.name) {
                    Some((_, Some(FunctionValueDefinition::Expression(value)))) => {
                        self.evaluate(value)
                    }
                    Some(_) => Err("Cannot evaluate function values".to_string()),
                    None => Err(format!("Symbol not found: {}", poly.name)),
                }
            }
            Expression::PublicReference(r) => Err(format!("Cannot evaluate public reference: {r}")),
            Expression::Number(n) => Ok(*n),
            Expression::String(_) => Err("Cannot evaluate string literal.".to_string()),
            Expression::Tuple(_) => Err("Cannot evaluate tuple.".to_string()),
            Expression::ArrayLiteral(_) => Err("Cannot evaluate array literal.".to_string()),
            Expression::BinaryOperation(left, op, right) => Ok(evaluate_binary_operation(
                self.evaluate(left)?,
                *op,
                self.evaluate(right)?,
            )),
            Expression::UnaryOperation(op, expr) => {
                Ok(evaluate_unary_operation(*op, self.evaluate(expr)?))
            }
            Expression::LambdaExpression(_) => {
                Err("Cannot evaluate lambda expression.".to_string())
            }
            Expression::FunctionCall(FunctionCall { id, arguments }) => {
                let arg_values = arguments
                    .iter()
                    .map(|a| self.evaluate(a))
                    .collect::<Result<Vec<_>, _>>()?;
                if arg_values.len() != 1 {
                    return Err(format!(
                        "Function {id} called with {} arguments, expected exactly one.",
                        arg_values.len()
                    ));
                }
                let values = self
                    .function_cache
                    .get(id.as_str())
                    .ok_or_else(|| format!("Function {id} not found in the function cache."))?;
                if values.is_empty() {
                    return Err(format!("Function {id} has an empty value table."));
                }
                // The row index wraps around the length of the value table.
                Ok(values[arg_values[0].to_degree() as usize % values.len()])
            }
            Expression::MatchExpression(scrutinee, arms) => {
                // Unwrap the results before comparing: comparing `Result`s would
                // silently skip failing patterns and treat equal errors as a match.
                let v = self.evaluate(scrutinee)?;
                for MatchArm { pattern, value } in arms {
                    let arm_matches = match pattern {
                        MatchPattern::Pattern(p) => self.evaluate(p)? == v,
                        MatchPattern::CatchAll => true,
                    };
                    if arm_matches {
                        return self.evaluate(value);
                    }
                }
                Err(format!("No arm matched the value {v}"))
            }
            Expression::FreeInput(_) => Err("Cannot evaluate free input.".to_string()),
        }
    }
}

View File

@@ -1,5 +1,6 @@
#![deny(clippy::print_stdout)]
pub mod evaluator;
pub mod pil_analyzer;
use std::path::Path;

View File

@@ -6,8 +6,8 @@ use analysis::MacroExpander;
use ast::parsed::visitor::ExpressionVisitable;
use ast::parsed::{
self, ArrayExpression, ArrayLiteral, BinaryOperator, FunctionDefinition, LambdaExpression,
MatchArm, MatchPattern, PilStatement, PolynomialName, UnaryOperator,
self, ArrayExpression, ArrayLiteral, FunctionDefinition, LambdaExpression, MatchArm,
MatchPattern, PilStatement, PolynomialName,
};
use number::{DegreeType, FieldElement};
@@ -17,6 +17,8 @@ use ast::analyzed::{
StatementIdentifier, Symbol, SymbolKind,
};
use crate::evaluator::Evaluator;
pub fn process_pil_file<T: FieldElement>(path: &Path) -> Analyzed<T> {
let mut analyzer = PILAnalyzer::new();
analyzer.process_file(path);
@@ -46,7 +48,6 @@ struct PILAnalyzer<T> {
current_file: PathBuf,
symbol_counters: BTreeMap<SymbolKind, u64>,
identity_counter: HashMap<IdentityKind, u64>,
local_variables: HashMap<String, u64>,
macro_expander: MacroExpander<T>,
}
@@ -196,7 +197,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
);
}
PilStatement::ConstantDefinition(_start, name, value) => {
self.handle_constant_definition(name, self.evaluate_expression(&value).unwrap())
self.handle_constant_definition(name, self.evaluate_expression(value).unwrap())
}
PilStatement::LetStatement(start, name, value) => {
self.handle_generic_definition(start, name, value)
@@ -255,7 +256,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
);
}
_ => {
if let Some(constant) = self.evaluate_expression(&value) {
if let Ok(constant) = self.evaluate_expression(value.clone()) {
// Value evaluates to a constant number => treat it as a constant
self.handle_constant_definition(name.to_string(), constant);
} else {
@@ -288,25 +289,25 @@ impl<T: FieldElement> PILAnalyzer<T> {
PilStatement::PlookupIdentity(start, key, haystack) => (
start,
IdentityKind::Plookup,
self.process_selected_expression(key),
self.process_selected_expression(haystack),
ExpressionProcessor::new(self).process_selected_expression(key),
ExpressionProcessor::new(self).process_selected_expression(haystack),
),
PilStatement::PermutationIdentity(start, left, right) => (
start,
IdentityKind::Permutation,
self.process_selected_expression(left),
self.process_selected_expression(right),
ExpressionProcessor::new(self).process_selected_expression(left),
ExpressionProcessor::new(self).process_selected_expression(right),
),
PilStatement::ConnectIdentity(start, left, right) => (
start,
IdentityKind::Connect,
SelectedExpressions {
selector: None,
expressions: self.process_expressions(left),
expressions: ExpressionProcessor::new(self).process_expressions(left),
},
SelectedExpressions {
selector: None,
expressions: self.process_expressions(right),
expressions: ExpressionProcessor::new(self).process_expressions(right),
},
),
// TODO at some point, these should all be caught by the type checker.
@@ -335,7 +336,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
fn handle_namespace(&mut self, name: String, degree: ::ast::parsed::Expression<T>) {
// TODO: the polynomial degree should be handled without going through a field element. This requires having types in Expression
self.polynomial_degree = self.evaluate_expression(&degree).unwrap().to_degree();
self.polynomial_degree = self.evaluate_expression(degree).unwrap().to_degree();
self.namespace = name;
}
@@ -366,7 +367,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
) -> u64 {
let have_array_size = array_size.is_some();
let length = array_size
.map(|l| self.evaluate_expression(&l).unwrap())
.map(|l| self.evaluate_expression(l).unwrap())
.map(|l| l.to_degree());
if length.is_some() {
assert!(value.is_none());
@@ -397,16 +398,21 @@ impl<T: FieldElement> PILAnalyzer<T> {
FunctionDefinition::Mapping(params, expr) => {
assert!(!have_array_size);
assert!(symbol_kind == SymbolKind::Poly(PolynomialType::Constant));
FunctionValueDefinition::Mapping(self.process_function(&params, expr))
FunctionValueDefinition::Mapping(
ExpressionProcessor::new(self).process_function(&params, expr),
)
}
FunctionDefinition::Query(params, expr) => {
assert!(!have_array_size);
assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed));
FunctionValueDefinition::Query(self.process_function(&params, expr))
FunctionValueDefinition::Query(
ExpressionProcessor::new(self).process_function(&params, expr),
)
}
FunctionDefinition::Array(value) => {
let size = value.solve(self.polynomial_degree);
let expression = self.process_array_expression(value, size);
let expression =
ExpressionProcessor::new(self).process_array_expression(value, size);
assert_eq!(
expression.iter().map(|e| e.size()).sum::<DegreeType>(),
self.polynomial_degree
@@ -424,33 +430,6 @@ impl<T: FieldElement> PILAnalyzer<T> {
id
}
/// Processes the body of a function with the given parameter names.
///
/// While the body is processed, `local_variables` maps each parameter to its
/// position in `params`. Local variables of the enclosing scope stay visible
/// unless a parameter has the same name; their indices are shifted up by the
/// number of parameters. The previous `local_variables` are restored before
/// returning.
fn process_function(
    &mut self,
    params: &[String],
    expression: ::ast::parsed::Expression<T>,
) -> Expression<T> {
    let shift = params.len() as u64;
    let outer_scope = std::mem::take(&mut self.local_variables);
    assert!(self.local_variables.is_empty());

    for (position, param) in params.iter().enumerate() {
        self.local_variables.insert(param.clone(), position as u64);
    }
    // Re-add the outer local variables if we do not overwrite them
    // and increase their index by the number of parameters.
    // TODO re-evaluate if this mechanism makes sense as soon as we properly
    // support nested functions and closures.
    for (outer_name, outer_index) in outer_scope.iter() {
        if !self.local_variables.contains_key(outer_name) {
            self.local_variables
                .insert(outer_name.clone(), outer_index + shift);
        }
    }

    let result = self.process_expression(expression);
    self.local_variables = outer_scope;
    result
}
fn handle_public_declaration(
&mut self,
source: SourceRef,
@@ -465,8 +444,9 @@ impl<T: FieldElement> PILAnalyzer<T> {
id,
source,
name: name.to_string(),
polynomial: self.process_namespaced_polynomial_reference(poly),
index: self.evaluate_expression(&index).unwrap().to_degree(),
polynomial: ExpressionProcessor::new(self)
.process_namespaced_polynomial_reference(poly),
index: self.evaluate_expression(index).unwrap().to_degree(),
},
);
self.source_order
@@ -485,7 +465,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
id
}
fn namespaced(&self, name: &str) -> String {
pub fn namespaced(&self, name: &str) -> String {
self.namespaced_ref(&None, name)
}
@@ -499,7 +479,39 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
}
fn process_selected_expression(
/// Processes the parsed expression and evaluates the result to a single
/// value, using the constants and definitions collected so far.
///
/// The evaluation runs with an empty function cache and without any local
/// variables.
fn evaluate_expression(&self, expr: ::ast::parsed::Expression<T>) -> Result<T, String> {
    let processed = self.process_expression(expr);
    let empty_cache = Default::default();
    let evaluator = Evaluator {
        constants: &self.constants,
        definitions: &self.definitions,
        function_cache: &empty_cache,
        variables: &[],
    };
    evaluator.evaluate(&processed)
}
/// Turns a parsed expression into an analyzed expression using a freshly
/// created `ExpressionProcessor`, i.e. one that starts without local variables.
fn process_expression(&self, expr: ::ast::parsed::Expression<T>) -> Expression<T> {
ExpressionProcessor::new(self).process_expression(expr)
}
}
/// The ExpressionProcessor turns parsed expressions into analyzed expressions.
/// Its main job is to resolve references:
/// It turns simple references into fully namespaced references and resolves local function variables.
/// It also evaluates expressions that are required to be compile-time constant.
struct ExpressionProcessor<'a, T> {
/// The analyzer that is used to namespace names and to evaluate
/// constant expressions.
analyzer: &'a PILAnalyzer<T>,
/// Maps the names of the local variables currently in scope to their index.
local_variables: HashMap<String, u64>,
}
impl<'a, T: FieldElement> ExpressionProcessor<'a, T> {
fn new(analyzer: &'a PILAnalyzer<T>) -> Self {
Self {
analyzer,
local_variables: Default::default(),
}
}
pub fn process_selected_expression(
&mut self,
expr: ::ast::parsed::SelectedExpressions<T>,
) -> SelectedExpressions<T> {
@@ -509,7 +521,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
}
fn process_array_expression(
pub fn process_array_expression(
&mut self,
array_expression: ::ast::parsed::ArrayExpression<T>,
size: DegreeType,
@@ -538,7 +550,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
}
fn process_expressions(
pub fn process_expressions(
&mut self,
exprs: Vec<::ast::parsed::Expression<T>>,
) -> Vec<Expression<T>> {
@@ -548,7 +560,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
.collect()
}
fn process_expression(&mut self, expr: ::ast::parsed::Expression<T>) -> Expression<T> {
pub fn process_expression(&mut self, expr: ::ast::parsed::Expression<T>) -> Expression<T> {
use ast::parsed::Expression as PExpression;
match expr {
PExpression::Constant(name) => Expression::Constant(name),
@@ -586,7 +598,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
Expression::UnaryOperation(op, Box::new(self.process_expression(*value)))
}
PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall {
id: self.namespaced(&c.id),
id: self.analyzer.namespaced(&c.id),
arguments: self.process_expressions(c.arguments),
}),
PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
@@ -607,16 +619,43 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
}
fn process_namespaced_polynomial_reference(
&self,
/// Processes the body of a function with the given parameter names.
///
/// While the body is processed, `local_variables` maps each parameter to its
/// position in `params`. Local variables of the enclosing scope stay visible
/// unless a parameter has the same name; their indices are shifted up by the
/// number of parameters. The previous `local_variables` are restored before
/// returning.
fn process_function(
    &mut self,
    params: &[String],
    expression: ::ast::parsed::Expression<T>,
) -> Expression<T> {
    let shift = params.len() as u64;
    let outer_scope = std::mem::take(&mut self.local_variables);
    assert!(self.local_variables.is_empty());

    for (position, param) in params.iter().enumerate() {
        self.local_variables.insert(param.clone(), position as u64);
    }
    // Re-add the outer local variables if we do not overwrite them
    // and increase their index by the number of parameters.
    // TODO re-evaluate if this mechanism makes sense as soon as we properly
    // support nested functions and closures.
    for (outer_name, outer_index) in outer_scope.iter() {
        if !self.local_variables.contains_key(outer_name) {
            self.local_variables
                .insert(outer_name.clone(), outer_index + shift);
        }
    }

    let result = self.process_expression(expression);
    self.local_variables = outer_scope;
    result
}
pub fn process_namespaced_polynomial_reference(
&mut self,
poly: ::ast::parsed::NamespacedPolynomialReference<T>,
) -> PolynomialReference {
let index = poly
.index()
.as_ref()
.map(|i| self.evaluate_expression(i).unwrap())
.map(|i| self.analyzer.evaluate_expression(*i.clone()).unwrap())
.map(|i| i.to_degree());
let name = self.namespaced_ref(poly.namespace(), poly.name());
let name = self.analyzer.namespaced_ref(poly.namespace(), poly.name());
PolynomialReference {
name,
poly_id: None,
@@ -625,8 +664,8 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
}
fn process_shifted_polynomial_reference(
&self,
pub fn process_shifted_polynomial_reference(
&mut self,
poly: ::ast::parsed::ShiftedPolynomialReference<T>,
) -> PolynomialReference {
PolynomialReference {
@@ -634,63 +673,6 @@ impl<T: FieldElement> PILAnalyzer<T> {
..self.process_namespaced_polynomial_reference(poly.into_namespaced())
}
}
/// Tries to evaluate a parsed expression to a single value.
///
/// Returns `None` for expressions that cannot be evaluated here (public
/// references, strings, tuples, array literals, lambdas, function calls,
/// match expressions, and references that are shifted, namespaced or not
/// known constants).
///
/// Panics if a `Constant` is not known or if the expression is a free input.
fn evaluate_expression(&self, expr: &::ast::parsed::Expression<T>) -> Option<T> {
    use ast::parsed::Expression as PExpression;
    match expr {
        PExpression::Number(n) => Some(*n),
        PExpression::Constant(name) => match self.constants.get(name) {
            Some(value) => Some(*value),
            None => panic!("Constant {name} not found."),
        },
        PExpression::Reference(name) => {
            // TODO this whole mechanism should be replaced by a generic "reference"
            // type plus operators.
            let is_plain_reference = !name.shift() && name.namespace().is_none();
            if is_plain_reference {
                // See if it might be a constant
                self.constants.get(&name.name().to_owned()).copied()
            } else {
                None
            }
        }
        PExpression::BinaryOperation(left, op, right) => {
            self.evaluate_binary_operation(left, *op, right)
        }
        PExpression::UnaryOperation(op, value) => self.evaluate_unary_operation(*op, value),
        PExpression::FreeInput(_) => panic!(),
        PExpression::PublicReference(_)
        | PExpression::String(_)
        | PExpression::Tuple(_)
        | PExpression::ArrayLiteral(_)
        | PExpression::LambdaExpression(_)
        | PExpression::FunctionCall(_)
        | PExpression::MatchExpression(_, _) => None,
    }
}
/// Evaluates both operands and applies `op` to them.
/// Returns `None` if either operand cannot be evaluated.
fn evaluate_binary_operation(
    &self,
    left: &::ast::parsed::Expression<T>,
    op: BinaryOperator,
    right: &::ast::parsed::Expression<T>,
) -> Option<T> {
    let lhs = self.evaluate_expression(left)?;
    let rhs = self.evaluate_expression(right)?;
    Some(ast::evaluate_binary_operation(lhs, op, rhs))
}
/// Evaluates the operand and applies `op` to it.
/// Returns `None` if the operand cannot be evaluated.
fn evaluate_unary_operation(
    &self,
    op: UnaryOperator,
    value: &::ast::parsed::Expression<T>,
) -> Option<T> {
    self.evaluate_expression(value)
        .map(|operand| ast::evaluate_unary_operation(op, operand))
}
}
pub fn inline_intermediate_polynomials<T: Copy>(analyzed: &Analyzed<T>) -> Vec<Identity<T>> {