mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-04-20 03:03:25 -04:00
Merge pull request #794 from powdr-labs/extract_expr_processor
Extract expression processor.
This commit is contained in:
181
pil_analyzer/src/expression_processor.rs
Normal file
181
pil_analyzer/src/expression_processor.rs
Normal file
@@ -0,0 +1,181 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use ast::{
|
||||
analyzed::{Expression, PolynomialReference, Reference, RepeatedArray},
|
||||
parsed::{
|
||||
self, ArrayExpression, ArrayLiteral, LambdaExpression, MatchArm, MatchPattern,
|
||||
NamespacedPolynomialReference, SelectedExpressions,
|
||||
},
|
||||
};
|
||||
use number::DegreeType;
|
||||
|
||||
/// The ExpressionProcessor turns parsed expressions into analyzed expressions.
|
||||
/// Its main job is to resolve references:
|
||||
/// It turns simple references into fully namespaced references and resolves local function variables.
|
||||
/// It also evaluates expressions that are required to be compile-time constant.
|
||||
pub struct ExpressionProcessor<R: ReferenceResolver> {
|
||||
resolver: R,
|
||||
local_variables: HashMap<String, u64>,
|
||||
}
|
||||
|
||||
pub trait ReferenceResolver {
|
||||
/// Turns a reference to a name with an optional namespace to an absolute name.
|
||||
fn resolve(&self, namespace: &Option<String>, name: &str) -> String;
|
||||
}
|
||||
|
||||
impl<R: ReferenceResolver> ExpressionProcessor<R> {
|
||||
pub fn new(resolver: R) -> Self {
|
||||
Self {
|
||||
resolver,
|
||||
local_variables: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_selected_expression<T>(
|
||||
&mut self,
|
||||
expr: SelectedExpressions<parsed::Expression<T>>,
|
||||
) -> SelectedExpressions<Expression<T>> {
|
||||
SelectedExpressions {
|
||||
selector: expr.selector.map(|e| self.process_expression(e)),
|
||||
expressions: self.process_expressions(expr.expressions),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_array_expression<T>(
|
||||
&mut self,
|
||||
array_expression: ::ast::parsed::ArrayExpression<T>,
|
||||
size: DegreeType,
|
||||
) -> Vec<RepeatedArray<T>> {
|
||||
match array_expression {
|
||||
ArrayExpression::Value(expressions) => {
|
||||
let values = self.process_expressions(expressions);
|
||||
let size = values.len() as DegreeType;
|
||||
vec![RepeatedArray::new(values, size)]
|
||||
}
|
||||
ArrayExpression::RepeatedValue(expressions) => {
|
||||
if size == 0 {
|
||||
vec![]
|
||||
} else {
|
||||
vec![RepeatedArray::new(
|
||||
self.process_expressions(expressions),
|
||||
size,
|
||||
)]
|
||||
}
|
||||
}
|
||||
ArrayExpression::Concat(left, right) => self
|
||||
.process_array_expression(*left, size)
|
||||
.into_iter()
|
||||
.chain(self.process_array_expression(*right, size))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_expressions<T>(
|
||||
&mut self,
|
||||
exprs: Vec<parsed::Expression<T>>,
|
||||
) -> Vec<Expression<T>> {
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(|e| self.process_expression(e))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn process_expression<T>(&mut self, expr: parsed::Expression<T>) -> Expression<T> {
|
||||
use parsed::Expression as PExpression;
|
||||
match expr {
|
||||
PExpression::Reference(poly) => Expression::Reference(self.process_reference(poly)),
|
||||
PExpression::PublicReference(name) => Expression::PublicReference(name),
|
||||
PExpression::Number(n) => Expression::Number(n),
|
||||
PExpression::String(value) => Expression::String(value),
|
||||
PExpression::Tuple(items) => Expression::Tuple(self.process_expressions(items)),
|
||||
PExpression::ArrayLiteral(ArrayLiteral { items }) => {
|
||||
Expression::ArrayLiteral(ArrayLiteral {
|
||||
items: self.process_expressions(items),
|
||||
})
|
||||
}
|
||||
PExpression::LambdaExpression(LambdaExpression { params, body }) => {
|
||||
let body = Box::new(self.process_function(¶ms, *body));
|
||||
Expression::LambdaExpression(LambdaExpression { params, body })
|
||||
}
|
||||
PExpression::BinaryOperation(left, op, right) => Expression::BinaryOperation(
|
||||
Box::new(self.process_expression(*left)),
|
||||
op,
|
||||
Box::new(self.process_expression(*right)),
|
||||
),
|
||||
PExpression::UnaryOperation(op, value) => {
|
||||
Expression::UnaryOperation(op, Box::new(self.process_expression(*value)))
|
||||
}
|
||||
PExpression::IndexAccess(index_access) => {
|
||||
Expression::IndexAccess(parsed::IndexAccess {
|
||||
array: Box::new(self.process_expression(*index_access.array)),
|
||||
index: Box::new(self.process_expression(*index_access.index)),
|
||||
})
|
||||
}
|
||||
PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall {
|
||||
function: Box::new(self.process_expression(*c.function)),
|
||||
arguments: self.process_expressions(c.arguments),
|
||||
}),
|
||||
PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
|
||||
Box::new(self.process_expression(*scrutinee)),
|
||||
arms.into_iter()
|
||||
.map(|MatchArm { pattern, value }| MatchArm {
|
||||
pattern: match pattern {
|
||||
MatchPattern::CatchAll => MatchPattern::CatchAll,
|
||||
MatchPattern::Pattern(e) => {
|
||||
MatchPattern::Pattern(self.process_expression(e))
|
||||
}
|
||||
},
|
||||
value: self.process_expression(value),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
PExpression::FreeInput(_) => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_reference(&mut self, reference: NamespacedPolynomialReference) -> Reference {
|
||||
if reference.namespace.is_none() && self.local_variables.contains_key(&reference.name) {
|
||||
let id = self.local_variables[&reference.name];
|
||||
Reference::LocalVar(id, reference.name.to_string())
|
||||
} else {
|
||||
Reference::Poly(self.process_namespaced_polynomial_reference(reference))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_function<T>(
|
||||
&mut self,
|
||||
params: &[String],
|
||||
expression: ::ast::parsed::Expression<T>,
|
||||
) -> Expression<T> {
|
||||
let previous_local_vars = std::mem::take(&mut self.local_variables);
|
||||
|
||||
assert!(self.local_variables.is_empty());
|
||||
self.local_variables = params
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, p)| (p.clone(), i as u64))
|
||||
.collect();
|
||||
// Re-add the outer local variables if we do not overwrite them
|
||||
// and increase their index by the number of parameters.
|
||||
// TODO re-evaluate if this mechanism makes sense as soon as we properly
|
||||
// support nested functions and closures.
|
||||
for (name, index) in &previous_local_vars {
|
||||
self.local_variables
|
||||
.entry(name.clone())
|
||||
.or_insert(index + params.len() as u64);
|
||||
}
|
||||
let processed_value = self.process_expression(expression);
|
||||
self.local_variables = previous_local_vars;
|
||||
processed_value
|
||||
}
|
||||
|
||||
pub fn process_namespaced_polynomial_reference(
|
||||
&mut self,
|
||||
poly: ::ast::parsed::NamespacedPolynomialReference,
|
||||
) -> PolynomialReference {
|
||||
PolynomialReference {
|
||||
name: self.resolver.resolve(&poly.namespace, &poly.name),
|
||||
poly_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
mod condenser;
|
||||
pub mod evaluator;
|
||||
pub mod expression_processor;
|
||||
pub mod pil_analyzer;
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
@@ -6,19 +6,18 @@ use analysis::MacroExpander;
|
||||
|
||||
use ast::parsed::visitor::ExpressionVisitable;
|
||||
use ast::parsed::{
|
||||
self, ArrayExpression, ArrayLiteral, FunctionDefinition, LambdaExpression, MatchArm,
|
||||
MatchPattern, NamespacedPolynomialReference, PilStatement, PolynomialName, SelectedExpressions,
|
||||
self, FunctionDefinition, LambdaExpression, PilStatement, PolynomialName, SelectedExpressions,
|
||||
};
|
||||
use number::{DegreeType, FieldElement};
|
||||
|
||||
use ast::analyzed::{
|
||||
AlgebraicExpression, Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind,
|
||||
PolynomialReference, PolynomialType, PublicDeclaration, Reference, RepeatedArray, SourceRef,
|
||||
StatementIdentifier, Symbol, SymbolKind,
|
||||
PolynomialType, PublicDeclaration, SourceRef, StatementIdentifier, Symbol, SymbolKind,
|
||||
};
|
||||
|
||||
use crate::evaluator::EvalError;
|
||||
use crate::{condenser, evaluator};
|
||||
use crate::expression_processor::ReferenceResolver;
|
||||
use crate::{condenser, evaluator, expression_processor::ExpressionProcessor};
|
||||
|
||||
pub fn process_pil_file<T: FieldElement>(path: &Path) -> Analyzed<T> {
|
||||
let mut analyzer = PILAnalyzer::new();
|
||||
@@ -258,25 +257,28 @@ impl<T: FieldElement> PILAnalyzer<T> {
|
||||
PilStatement::PlookupIdentity(start, key, haystack) => (
|
||||
start,
|
||||
IdentityKind::Plookup,
|
||||
ExpressionProcessor::new(self).process_selected_expression(key),
|
||||
ExpressionProcessor::new(self).process_selected_expression(haystack),
|
||||
self.expression_processor().process_selected_expression(key),
|
||||
self.expression_processor()
|
||||
.process_selected_expression(haystack),
|
||||
),
|
||||
PilStatement::PermutationIdentity(start, left, right) => (
|
||||
start,
|
||||
IdentityKind::Permutation,
|
||||
ExpressionProcessor::new(self).process_selected_expression(left),
|
||||
ExpressionProcessor::new(self).process_selected_expression(right),
|
||||
self.expression_processor()
|
||||
.process_selected_expression(left),
|
||||
self.expression_processor()
|
||||
.process_selected_expression(right),
|
||||
),
|
||||
PilStatement::ConnectIdentity(start, left, right) => (
|
||||
start,
|
||||
IdentityKind::Connect,
|
||||
SelectedExpressions {
|
||||
selector: None,
|
||||
expressions: ExpressionProcessor::new(self).process_expressions(left),
|
||||
expressions: self.expression_processor().process_expressions(left),
|
||||
},
|
||||
SelectedExpressions {
|
||||
selector: None,
|
||||
expressions: ExpressionProcessor::new(self).process_expressions(right),
|
||||
expressions: self.expression_processor().process_expressions(right),
|
||||
},
|
||||
),
|
||||
// TODO at some point, these should all be caught by the type checker.
|
||||
@@ -376,7 +378,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
|
||||
FunctionDefinition::Query(params, expr) => {
|
||||
assert!(!have_array_size);
|
||||
assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed));
|
||||
let body = Box::new(ExpressionProcessor::new(self).process_function(¶ms, expr));
|
||||
let body = Box::new(self.expression_processor().process_function(¶ms, expr));
|
||||
FunctionValueDefinition::Query(Expression::LambdaExpression(LambdaExpression {
|
||||
params,
|
||||
body,
|
||||
@@ -384,8 +386,9 @@ impl<T: FieldElement> PILAnalyzer<T> {
|
||||
}
|
||||
FunctionDefinition::Array(value) => {
|
||||
let size = value.solve(self.polynomial_degree.unwrap());
|
||||
let expression =
|
||||
ExpressionProcessor::new(self).process_array_expression(value, size);
|
||||
let expression = self
|
||||
.expression_processor()
|
||||
.process_array_expression(value, size);
|
||||
assert_eq!(
|
||||
expression.iter().map(|e| e.size()).sum::<DegreeType>(),
|
||||
self.polynomial_degree.unwrap()
|
||||
@@ -412,8 +415,9 @@ impl<T: FieldElement> PILAnalyzer<T> {
|
||||
index: parsed::Expression<T>,
|
||||
) {
|
||||
let id = self.public_declarations.len() as u64;
|
||||
let polynomial =
|
||||
ExpressionProcessor::new(self).process_namespaced_polynomial_reference(poly);
|
||||
let polynomial = self
|
||||
.expression_processor()
|
||||
.process_namespaced_polynomial_reference(poly);
|
||||
let array_index = array_index.map(|i| {
|
||||
let index = self.evaluate_expression(i).unwrap().to_degree();
|
||||
assert!(index <= usize::MAX as u64);
|
||||
@@ -445,186 +449,30 @@ impl<T: FieldElement> PILAnalyzer<T> {
|
||||
format!("{}.{name}", self.namespace)
|
||||
}
|
||||
|
||||
pub fn namespaced_ref_to_absolute(&self, namespace: &Option<String>, name: &str) -> String {
|
||||
if name.starts_with('%') || self.definitions.contains_key(&name.to_string()) {
|
||||
assert!(namespace.is_none());
|
||||
// Constants are not namespaced
|
||||
name.to_string()
|
||||
} else {
|
||||
format!("{}.{name}", namespace.as_ref().unwrap_or(&self.namespace))
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate_expression(&self, expr: ::ast::parsed::Expression<T>) -> Result<T, EvalError> {
|
||||
evaluator::evaluate_expression(&self.process_expression(expr), &self.definitions)?
|
||||
.try_to_number()
|
||||
}
|
||||
|
||||
fn expression_processor(&self) -> ExpressionProcessor<PILResolver<T>> {
|
||||
ExpressionProcessor::new(PILResolver(self))
|
||||
}
|
||||
|
||||
fn process_expression(&self, expr: ::ast::parsed::Expression<T>) -> Expression<T> {
|
||||
ExpressionProcessor::new(self).process_expression(expr)
|
||||
self.expression_processor().process_expression(expr)
|
||||
}
|
||||
}
|
||||
|
||||
/// The ExpressionProcessor turns parsed expressions into analyzed expressions.
|
||||
/// Its main job is to resolve references:
|
||||
/// It turns simple references into fully namespaced references and resolves local function variables.
|
||||
/// It also evaluates expressions that are required to be compile-time constant.
|
||||
struct ExpressionProcessor<'a, T> {
|
||||
analyzer: &'a PILAnalyzer<T>,
|
||||
local_variables: HashMap<String, u64>,
|
||||
}
|
||||
struct PILResolver<'a, T>(&'a PILAnalyzer<T>);
|
||||
|
||||
impl<'a, T: FieldElement> ExpressionProcessor<'a, T> {
|
||||
fn new(analyzer: &'a PILAnalyzer<T>) -> Self {
|
||||
Self {
|
||||
analyzer,
|
||||
local_variables: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_selected_expression(
|
||||
&mut self,
|
||||
expr: SelectedExpressions<parsed::Expression<T>>,
|
||||
) -> SelectedExpressions<Expression<T>> {
|
||||
SelectedExpressions {
|
||||
selector: expr.selector.map(|e| self.process_expression(e)),
|
||||
expressions: self.process_expressions(expr.expressions),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_array_expression(
|
||||
&mut self,
|
||||
array_expression: ::ast::parsed::ArrayExpression<T>,
|
||||
size: DegreeType,
|
||||
) -> Vec<RepeatedArray<T>> {
|
||||
match array_expression {
|
||||
ArrayExpression::Value(expressions) => {
|
||||
let values = self.process_expressions(expressions);
|
||||
let size = values.len() as DegreeType;
|
||||
vec![RepeatedArray::new(values, size)]
|
||||
}
|
||||
ArrayExpression::RepeatedValue(expressions) => {
|
||||
if size == 0 {
|
||||
vec![]
|
||||
} else {
|
||||
vec![RepeatedArray::new(
|
||||
self.process_expressions(expressions),
|
||||
size,
|
||||
)]
|
||||
}
|
||||
}
|
||||
ArrayExpression::Concat(left, right) => self
|
||||
.process_array_expression(*left, size)
|
||||
.into_iter()
|
||||
.chain(self.process_array_expression(*right, size))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn process_expressions(&mut self, exprs: Vec<parsed::Expression<T>>) -> Vec<Expression<T>> {
|
||||
exprs
|
||||
.into_iter()
|
||||
.map(|e| self.process_expression(e))
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn process_expression(&mut self, expr: parsed::Expression<T>) -> Expression<T> {
|
||||
use parsed::Expression as PExpression;
|
||||
match expr {
|
||||
PExpression::Reference(poly) => Expression::Reference(self.process_reference(poly)),
|
||||
PExpression::PublicReference(name) => Expression::PublicReference(name),
|
||||
PExpression::Number(n) => Expression::Number(n),
|
||||
PExpression::String(value) => Expression::String(value),
|
||||
PExpression::Tuple(items) => Expression::Tuple(self.process_expressions(items)),
|
||||
PExpression::ArrayLiteral(ArrayLiteral { items }) => {
|
||||
Expression::ArrayLiteral(ArrayLiteral {
|
||||
items: self.process_expressions(items),
|
||||
})
|
||||
}
|
||||
PExpression::LambdaExpression(LambdaExpression { params, body }) => {
|
||||
let body = Box::new(self.process_function(¶ms, *body));
|
||||
Expression::LambdaExpression(LambdaExpression { params, body })
|
||||
}
|
||||
PExpression::BinaryOperation(left, op, right) => Expression::BinaryOperation(
|
||||
Box::new(self.process_expression(*left)),
|
||||
op,
|
||||
Box::new(self.process_expression(*right)),
|
||||
),
|
||||
PExpression::UnaryOperation(op, value) => {
|
||||
Expression::UnaryOperation(op, Box::new(self.process_expression(*value)))
|
||||
}
|
||||
PExpression::IndexAccess(index_access) => {
|
||||
Expression::IndexAccess(parsed::IndexAccess {
|
||||
array: Box::new(self.process_expression(*index_access.array)),
|
||||
index: Box::new(self.process_expression(*index_access.index)),
|
||||
})
|
||||
}
|
||||
PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall {
|
||||
function: Box::new(self.process_expression(*c.function)),
|
||||
arguments: self.process_expressions(c.arguments),
|
||||
}),
|
||||
PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
|
||||
Box::new(self.process_expression(*scrutinee)),
|
||||
arms.into_iter()
|
||||
.map(|MatchArm { pattern, value }| MatchArm {
|
||||
pattern: match pattern {
|
||||
MatchPattern::CatchAll => MatchPattern::CatchAll,
|
||||
MatchPattern::Pattern(e) => {
|
||||
MatchPattern::Pattern(self.process_expression(e))
|
||||
}
|
||||
},
|
||||
value: self.process_expression(value),
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
PExpression::FreeInput(_) => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_reference(&mut self, reference: NamespacedPolynomialReference) -> Reference {
|
||||
if reference.namespace.is_none() && self.local_variables.contains_key(&reference.name) {
|
||||
let id = self.local_variables[&reference.name];
|
||||
Reference::LocalVar(id, reference.name.to_string())
|
||||
impl<'a, T: FieldElement> ReferenceResolver for PILResolver<'a, T> {
|
||||
fn resolve(&self, namespace: &Option<String>, name: &str) -> String {
|
||||
if name.starts_with('%') || self.0.definitions.contains_key(&name.to_string()) {
|
||||
assert!(namespace.is_none());
|
||||
// Constants are not namespaced
|
||||
name.to_string()
|
||||
} else {
|
||||
Reference::Poly(self.process_namespaced_polynomial_reference(reference))
|
||||
}
|
||||
}
|
||||
|
||||
fn process_function(
|
||||
&mut self,
|
||||
params: &[String],
|
||||
expression: ::ast::parsed::Expression<T>,
|
||||
) -> Expression<T> {
|
||||
let previous_local_vars = std::mem::take(&mut self.local_variables);
|
||||
|
||||
assert!(self.local_variables.is_empty());
|
||||
self.local_variables = params
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, p)| (p.clone(), i as u64))
|
||||
.collect();
|
||||
// Re-add the outer local variables if we do not overwrite them
|
||||
// and increase their index by the number of parameters.
|
||||
for (name, index) in &previous_local_vars {
|
||||
self.local_variables
|
||||
.entry(name.clone())
|
||||
.or_insert(index + params.len() as u64);
|
||||
}
|
||||
let processed_value = self.process_expression(expression);
|
||||
self.local_variables = previous_local_vars;
|
||||
processed_value
|
||||
}
|
||||
|
||||
pub fn process_namespaced_polynomial_reference(
|
||||
&mut self,
|
||||
poly: ::ast::parsed::NamespacedPolynomialReference,
|
||||
) -> PolynomialReference {
|
||||
let name = self
|
||||
.analyzer
|
||||
.namespaced_ref_to_absolute(&poly.namespace, &poly.name);
|
||||
PolynomialReference {
|
||||
name,
|
||||
poly_id: None,
|
||||
format!("{}.{name}", namespace.as_ref().unwrap_or(&self.0.namespace))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user