Merge pull request #794 from powdr-labs/extract_expr_processor

Extract expression processor.
This commit is contained in:
Leo
2023-11-27 17:11:22 +00:00
committed by GitHub
3 changed files with 215 additions and 185 deletions

View File

@@ -0,0 +1,181 @@
use std::collections::HashMap;
use ast::{
analyzed::{Expression, PolynomialReference, Reference, RepeatedArray},
parsed::{
self, ArrayExpression, ArrayLiteral, LambdaExpression, MatchArm, MatchPattern,
NamespacedPolynomialReference, SelectedExpressions,
},
};
use number::DegreeType;
/// The ExpressionProcessor turns parsed expressions into analyzed expressions.
///
/// Its main job is to resolve references: simple references are turned into
/// absolute names through the resolver `R`, and names that are currently bound
/// as function parameters are turned into local-variable references.
///
/// The type parameter is intentionally unbounded on the struct itself; the
/// `ReferenceResolver` requirement is stated on the `impl` block that uses it.
pub struct ExpressionProcessor<R> {
    /// Turns a name with an optional namespace into an absolute name.
    resolver: R,
    /// Ids of the local variables (function parameters) currently in scope,
    /// keyed by their name.
    local_variables: HashMap<String, u64>,
}
/// Abstraction over how references are turned into absolute names.
/// The `ExpressionProcessor` calls `resolve` for every reference that is not a local variable.
pub trait ReferenceResolver {
/// Turns a reference to a name with an optional namespace to an absolute name.
///
/// `namespace` is `None` if the reference did not specify a namespace;
/// how a missing namespace is handled is up to the implementor.
fn resolve(&self, namespace: &Option<String>, name: &str) -> String;
}
impl<R: ReferenceResolver> ExpressionProcessor<R> {
pub fn new(resolver: R) -> Self {
Self {
resolver,
local_variables: Default::default(),
}
}
/// Processes a selected expression: first the optional selector,
/// then all of the expressions, in order.
pub fn process_selected_expression<T>(
    &mut self,
    expr: SelectedExpressions<parsed::Expression<T>>,
) -> SelectedExpressions<Expression<T>> {
    let SelectedExpressions {
        selector,
        expressions,
    } = expr;
    // Keep the original order: the selector is processed before the expressions.
    let selector = match selector {
        Some(parsed_selector) => Some(self.process_expression(parsed_selector)),
        None => None,
    };
    let expressions = self.process_expressions(expressions);
    SelectedExpressions {
        selector,
        expressions,
    }
}
/// Turns a parsed array expression into a list of repeated arrays.
///
/// - A plain value list becomes a single array whose size is the number of its items.
/// - A repeated value list becomes a single array of size `size`,
///   or nothing at all if `size` is zero (its items are not processed in that case).
/// - A concatenation yields the arrays of the left operand followed by those of the
///   right operand; both operands receive the same `size`.
pub fn process_array_expression<T>(
    &mut self,
    array_expression: ::ast::parsed::ArrayExpression<T>,
    size: DegreeType,
) -> Vec<RepeatedArray<T>> {
    match array_expression {
        ArrayExpression::Value(items) => {
            let processed = self.process_expressions(items);
            let length = processed.len() as DegreeType;
            vec![RepeatedArray::new(processed, length)]
        }
        ArrayExpression::RepeatedValue(_) if size == 0 => Vec::new(),
        ArrayExpression::RepeatedValue(pattern) => {
            let processed = self.process_expressions(pattern);
            vec![RepeatedArray::new(processed, size)]
        }
        ArrayExpression::Concat(left, right) => {
            // The left operand is fully processed before the right one.
            let mut parts = self.process_array_expression(*left, size);
            parts.extend(self.process_array_expression(*right, size));
            parts
        }
    }
}
/// Processes each expression in `exprs`, front to back,
/// and returns the results in the same order.
pub fn process_expressions<T>(
    &mut self,
    exprs: Vec<parsed::Expression<T>>,
) -> Vec<Expression<T>> {
    let mut processed = Vec::with_capacity(exprs.len());
    for parsed_expr in exprs {
        processed.push(self.process_expression(parsed_expr));
    }
    processed
}
/// Turns a single parsed expression into an analyzed expression.
///
/// References are resolved (to local variables or absolute names), lambda
/// bodies are processed with their parameters in scope, and all other
/// composite expressions are rebuilt from their recursively processed parts.
/// Literals and public references are passed through unchanged.
///
/// # Panics
///
/// Panics if the expression contains a `FreeInput`, which has no analyzed
/// counterpart here.
pub fn process_expression<T>(&mut self, expr: parsed::Expression<T>) -> Expression<T> {
    use parsed::Expression as PExpression;
    match expr {
        PExpression::Reference(poly) => Expression::Reference(self.process_reference(poly)),
        PExpression::PublicReference(name) => Expression::PublicReference(name),
        PExpression::Number(n) => Expression::Number(n),
        PExpression::String(value) => Expression::String(value),
        PExpression::Tuple(items) => Expression::Tuple(self.process_expressions(items)),
        PExpression::ArrayLiteral(ArrayLiteral { items }) => {
            Expression::ArrayLiteral(ArrayLiteral {
                items: self.process_expressions(items),
            })
        }
        PExpression::LambdaExpression(LambdaExpression { params, body }) => {
            // The body is processed with the lambda's parameters as local variables.
            let body = Box::new(self.process_function(&params, *body));
            Expression::LambdaExpression(LambdaExpression { params, body })
        }
        PExpression::BinaryOperation(left, op, right) => Expression::BinaryOperation(
            Box::new(self.process_expression(*left)),
            op,
            Box::new(self.process_expression(*right)),
        ),
        PExpression::UnaryOperation(op, value) => {
            Expression::UnaryOperation(op, Box::new(self.process_expression(*value)))
        }
        PExpression::IndexAccess(index_access) => {
            Expression::IndexAccess(parsed::IndexAccess {
                array: Box::new(self.process_expression(*index_access.array)),
                index: Box::new(self.process_expression(*index_access.index)),
            })
        }
        PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall {
            function: Box::new(self.process_expression(*c.function)),
            arguments: self.process_expressions(c.arguments),
        }),
        PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
            Box::new(self.process_expression(*scrutinee)),
            arms.into_iter()
                .map(|MatchArm { pattern, value }| MatchArm {
                    // Within an arm, the pattern is processed before the value.
                    pattern: match pattern {
                        MatchPattern::CatchAll => MatchPattern::CatchAll,
                        MatchPattern::Pattern(e) => {
                            MatchPattern::Pattern(self.process_expression(e))
                        }
                    },
                    value: self.process_expression(value),
                })
                .collect(),
        ),
        PExpression::FreeInput(_) => {
            panic!("Free input expressions cannot be turned into analyzed expressions.")
        }
    }
}
/// Resolves a parsed reference.
///
/// A reference without a namespace whose name is a local variable currently
/// in scope becomes a `Reference::LocalVar` carrying that variable's id.
/// Every other reference is resolved to a polynomial reference.
fn process_reference(&mut self, reference: NamespacedPolynomialReference) -> Reference {
    // Only references without a namespace can refer to local variables.
    // A single map lookup replaces the former `contains_key` + index pair.
    let local_id = if reference.namespace.is_none() {
        self.local_variables.get(&reference.name).copied()
    } else {
        None
    };
    match local_id {
        // The name is owned, so it can be moved instead of cloned.
        Some(id) => Reference::LocalVar(id, reference.name),
        None => Reference::Poly(self.process_namespaced_polynomial_reference(reference)),
    }
}
pub fn process_function<T>(
&mut self,
params: &[String],
expression: ::ast::parsed::Expression<T>,
) -> Expression<T> {
let previous_local_vars = std::mem::take(&mut self.local_variables);
assert!(self.local_variables.is_empty());
self.local_variables = params
.iter()
.enumerate()
.map(|(i, p)| (p.clone(), i as u64))
.collect();
// Re-add the outer local variables if we do not overwrite them
// and increase their index by the number of parameters.
// TODO re-evaluate if this mechanism makes sense as soon as we properly
// support nested functions and closures.
for (name, index) in &previous_local_vars {
self.local_variables
.entry(name.clone())
.or_insert(index + params.len() as u64);
}
let processed_value = self.process_expression(expression);
self.local_variables = previous_local_vars;
processed_value
}
/// Resolves the reference to its absolute name via the resolver and wraps it
/// in a `PolynomialReference` whose `poly_id` is not yet set.
pub fn process_namespaced_polynomial_reference(
    &mut self,
    poly: ::ast::parsed::NamespacedPolynomialReference,
) -> PolynomialReference {
    let absolute_name = self.resolver.resolve(&poly.namespace, &poly.name);
    PolynomialReference {
        name: absolute_name,
        poly_id: None,
    }
}
}

View File

@@ -2,6 +2,7 @@
mod condenser;
pub mod evaluator;
pub mod expression_processor;
pub mod pil_analyzer;
use std::path::Path;

View File

@@ -6,19 +6,18 @@ use analysis::MacroExpander;
use ast::parsed::visitor::ExpressionVisitable;
use ast::parsed::{
self, ArrayExpression, ArrayLiteral, FunctionDefinition, LambdaExpression, MatchArm,
MatchPattern, NamespacedPolynomialReference, PilStatement, PolynomialName, SelectedExpressions,
self, FunctionDefinition, LambdaExpression, PilStatement, PolynomialName, SelectedExpressions,
};
use number::{DegreeType, FieldElement};
use ast::analyzed::{
AlgebraicExpression, Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind,
PolynomialReference, PolynomialType, PublicDeclaration, Reference, RepeatedArray, SourceRef,
StatementIdentifier, Symbol, SymbolKind,
PolynomialType, PublicDeclaration, SourceRef, StatementIdentifier, Symbol, SymbolKind,
};
use crate::evaluator::EvalError;
use crate::{condenser, evaluator};
use crate::expression_processor::ReferenceResolver;
use crate::{condenser, evaluator, expression_processor::ExpressionProcessor};
pub fn process_pil_file<T: FieldElement>(path: &Path) -> Analyzed<T> {
let mut analyzer = PILAnalyzer::new();
@@ -258,25 +257,28 @@ impl<T: FieldElement> PILAnalyzer<T> {
PilStatement::PlookupIdentity(start, key, haystack) => (
start,
IdentityKind::Plookup,
ExpressionProcessor::new(self).process_selected_expression(key),
ExpressionProcessor::new(self).process_selected_expression(haystack),
self.expression_processor().process_selected_expression(key),
self.expression_processor()
.process_selected_expression(haystack),
),
PilStatement::PermutationIdentity(start, left, right) => (
start,
IdentityKind::Permutation,
ExpressionProcessor::new(self).process_selected_expression(left),
ExpressionProcessor::new(self).process_selected_expression(right),
self.expression_processor()
.process_selected_expression(left),
self.expression_processor()
.process_selected_expression(right),
),
PilStatement::ConnectIdentity(start, left, right) => (
start,
IdentityKind::Connect,
SelectedExpressions {
selector: None,
expressions: ExpressionProcessor::new(self).process_expressions(left),
expressions: self.expression_processor().process_expressions(left),
},
SelectedExpressions {
selector: None,
expressions: ExpressionProcessor::new(self).process_expressions(right),
expressions: self.expression_processor().process_expressions(right),
},
),
// TODO at some point, these should all be caught by the type checker.
@@ -376,7 +378,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
FunctionDefinition::Query(params, expr) => {
assert!(!have_array_size);
assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed));
let body = Box::new(ExpressionProcessor::new(self).process_function(&params, expr));
let body = Box::new(self.expression_processor().process_function(&params, expr));
FunctionValueDefinition::Query(Expression::LambdaExpression(LambdaExpression {
params,
body,
@@ -384,8 +386,9 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
FunctionDefinition::Array(value) => {
let size = value.solve(self.polynomial_degree.unwrap());
let expression =
ExpressionProcessor::new(self).process_array_expression(value, size);
let expression = self
.expression_processor()
.process_array_expression(value, size);
assert_eq!(
expression.iter().map(|e| e.size()).sum::<DegreeType>(),
self.polynomial_degree.unwrap()
@@ -412,8 +415,9 @@ impl<T: FieldElement> PILAnalyzer<T> {
index: parsed::Expression<T>,
) {
let id = self.public_declarations.len() as u64;
let polynomial =
ExpressionProcessor::new(self).process_namespaced_polynomial_reference(poly);
let polynomial = self
.expression_processor()
.process_namespaced_polynomial_reference(poly);
let array_index = array_index.map(|i| {
let index = self.evaluate_expression(i).unwrap().to_degree();
assert!(index <= usize::MAX as u64);
@@ -445,186 +449,30 @@ impl<T: FieldElement> PILAnalyzer<T> {
format!("{}.{name}", self.namespace)
}
pub fn namespaced_ref_to_absolute(&self, namespace: &Option<String>, name: &str) -> String {
if name.starts_with('%') || self.definitions.contains_key(&name.to_string()) {
assert!(namespace.is_none());
// Constants are not namespaced
name.to_string()
} else {
format!("{}.{name}", namespace.as_ref().unwrap_or(&self.namespace))
}
}
/// Processes `expr` into an analyzed expression, evaluates it against
/// `self.definitions` and converts the result via `try_to_number`.
/// An error from the evaluation or from the conversion is returned to the caller.
fn evaluate_expression(&self, expr: ::ast::parsed::Expression<T>) -> Result<T, EvalError> {
evaluator::evaluate_expression(&self.process_expression(expr), &self.definitions)?
.try_to_number()
}
/// Returns a new `ExpressionProcessor` whose resolver is a `PILResolver`
/// borrowing this analyzer. A fresh processor is constructed on every call.
fn expression_processor(&self) -> ExpressionProcessor<PILResolver<T>> {
ExpressionProcessor::new(PILResolver(self))
}
fn process_expression(&self, expr: ::ast::parsed::Expression<T>) -> Expression<T> {
ExpressionProcessor::new(self).process_expression(expr)
self.expression_processor().process_expression(expr)
}
}
/// The ExpressionProcessor turns parsed expressions into analyzed expressions.
/// Its main job is to resolve references:
/// It turns simple references into fully namespaced references and resolves local function variables.
/// It also evaluates expressions that are required to be compile-time constant.
struct ExpressionProcessor<'a, T> {
analyzer: &'a PILAnalyzer<T>,
local_variables: HashMap<String, u64>,
}
/// Wrapper around a borrowed `PILAnalyzer` that implements `ReferenceResolver`.
struct PILResolver<'a, T>(&'a PILAnalyzer<T>);
impl<'a, T: FieldElement> ExpressionProcessor<'a, T> {
fn new(analyzer: &'a PILAnalyzer<T>) -> Self {
Self {
analyzer,
local_variables: Default::default(),
}
}
pub fn process_selected_expression(
&mut self,
expr: SelectedExpressions<parsed::Expression<T>>,
) -> SelectedExpressions<Expression<T>> {
SelectedExpressions {
selector: expr.selector.map(|e| self.process_expression(e)),
expressions: self.process_expressions(expr.expressions),
}
}
pub fn process_array_expression(
&mut self,
array_expression: ::ast::parsed::ArrayExpression<T>,
size: DegreeType,
) -> Vec<RepeatedArray<T>> {
match array_expression {
ArrayExpression::Value(expressions) => {
let values = self.process_expressions(expressions);
let size = values.len() as DegreeType;
vec![RepeatedArray::new(values, size)]
}
ArrayExpression::RepeatedValue(expressions) => {
if size == 0 {
vec![]
} else {
vec![RepeatedArray::new(
self.process_expressions(expressions),
size,
)]
}
}
ArrayExpression::Concat(left, right) => self
.process_array_expression(*left, size)
.into_iter()
.chain(self.process_array_expression(*right, size))
.collect(),
}
}
pub fn process_expressions(&mut self, exprs: Vec<parsed::Expression<T>>) -> Vec<Expression<T>> {
exprs
.into_iter()
.map(|e| self.process_expression(e))
.collect()
}
pub fn process_expression(&mut self, expr: parsed::Expression<T>) -> Expression<T> {
use parsed::Expression as PExpression;
match expr {
PExpression::Reference(poly) => Expression::Reference(self.process_reference(poly)),
PExpression::PublicReference(name) => Expression::PublicReference(name),
PExpression::Number(n) => Expression::Number(n),
PExpression::String(value) => Expression::String(value),
PExpression::Tuple(items) => Expression::Tuple(self.process_expressions(items)),
PExpression::ArrayLiteral(ArrayLiteral { items }) => {
Expression::ArrayLiteral(ArrayLiteral {
items: self.process_expressions(items),
})
}
PExpression::LambdaExpression(LambdaExpression { params, body }) => {
let body = Box::new(self.process_function(&params, *body));
Expression::LambdaExpression(LambdaExpression { params, body })
}
PExpression::BinaryOperation(left, op, right) => Expression::BinaryOperation(
Box::new(self.process_expression(*left)),
op,
Box::new(self.process_expression(*right)),
),
PExpression::UnaryOperation(op, value) => {
Expression::UnaryOperation(op, Box::new(self.process_expression(*value)))
}
PExpression::IndexAccess(index_access) => {
Expression::IndexAccess(parsed::IndexAccess {
array: Box::new(self.process_expression(*index_access.array)),
index: Box::new(self.process_expression(*index_access.index)),
})
}
PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall {
function: Box::new(self.process_expression(*c.function)),
arguments: self.process_expressions(c.arguments),
}),
PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
Box::new(self.process_expression(*scrutinee)),
arms.into_iter()
.map(|MatchArm { pattern, value }| MatchArm {
pattern: match pattern {
MatchPattern::CatchAll => MatchPattern::CatchAll,
MatchPattern::Pattern(e) => {
MatchPattern::Pattern(self.process_expression(e))
}
},
value: self.process_expression(value),
})
.collect(),
),
PExpression::FreeInput(_) => panic!(),
}
}
fn process_reference(&mut self, reference: NamespacedPolynomialReference) -> Reference {
if reference.namespace.is_none() && self.local_variables.contains_key(&reference.name) {
let id = self.local_variables[&reference.name];
Reference::LocalVar(id, reference.name.to_string())
impl<'a, T: FieldElement> ReferenceResolver for PILResolver<'a, T> {
fn resolve(&self, namespace: &Option<String>, name: &str) -> String {
if name.starts_with('%') || self.0.definitions.contains_key(&name.to_string()) {
assert!(namespace.is_none());
// Constants are not namespaced
name.to_string()
} else {
Reference::Poly(self.process_namespaced_polynomial_reference(reference))
}
}
fn process_function(
&mut self,
params: &[String],
expression: ::ast::parsed::Expression<T>,
) -> Expression<T> {
let previous_local_vars = std::mem::take(&mut self.local_variables);
assert!(self.local_variables.is_empty());
self.local_variables = params
.iter()
.enumerate()
.map(|(i, p)| (p.clone(), i as u64))
.collect();
// Re-add the outer local variables if we do not overwrite them
// and increase their index by the number of parameters.
for (name, index) in &previous_local_vars {
self.local_variables
.entry(name.clone())
.or_insert(index + params.len() as u64);
}
let processed_value = self.process_expression(expression);
self.local_variables = previous_local_vars;
processed_value
}
pub fn process_namespaced_polynomial_reference(
&mut self,
poly: ::ast::parsed::NamespacedPolynomialReference,
) -> PolynomialReference {
let name = self
.analyzer
.namespaced_ref_to_absolute(&poly.namespace, &poly.name);
PolynomialReference {
name,
poly_id: None,
format!("{}.{name}", namespace.as_ref().unwrap_or(&self.0.namespace))
}
}
}