Merge pull request #640 from powdr-labs/generalize_definitions

Generalize definitions.
This commit is contained in:
Georg Wiese
2023-09-29 11:49:45 +00:00
committed by GitHub
9 changed files with 224 additions and 145 deletions

View File

@@ -33,18 +33,29 @@ impl<T: Display> Display for Analyzed<T> {
for statement in &self.source_order {
match statement {
StatementIdentifier::Definition(name) => {
let (poly, definition) = &self.definitions[name];
let name = update_namespace(name, poly.degree, f)?;
let kind = match &poly.poly_type {
PolynomialType::Committed => "witness ",
PolynomialType::Constant => "fixed ",
PolynomialType::Intermediate => "",
};
write!(f, " col {kind}{name}")?;
if let Some(value) = definition {
writeln!(f, "{value};")?
} else {
writeln!(f, ";")?
let (symbol, definition) = &self.definitions[name];
let name = update_namespace(name, symbol.degree, f)?;
match symbol.kind {
SymbolKind::Poly(poly_type) => {
let kind = match &poly_type {
PolynomialType::Committed => "witness ",
PolynomialType::Constant => "fixed ",
PolynomialType::Intermediate => "",
};
write!(f, " col {kind}{name}")?;
if let Some(value) = definition {
writeln!(f, "{value};")?
} else {
writeln!(f, ";")?
}
}
SymbolKind::Other() => {
write!(f, " let {name}")?;
if let Some(value) = definition {
write!(f, "{value}")?
}
writeln!(f, ";")?
}
}
}
StatementIdentifier::PublicDeclaration(name) => {
@@ -126,14 +137,8 @@ impl<T: Display> Display for SelectedExpressions<T> {
impl Display for Reference {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Reference::LocalVar(index) => {
// TODO this is not really reproducing the input, but
// if we want to do that, we would need the names of the local variables somehow.
if *index == 0 {
write!(f, "i")
} else {
write!(f, "${index}")
}
Reference::LocalVar(_index, name) => {
write!(f, "{name}")
}
Reference::Poly(r) => write!(f, "{r}"),
}

View File

@@ -27,7 +27,7 @@ pub enum StatementIdentifier {
pub struct Analyzed<T> {
/// Constants are not namespaced!
pub constants: HashMap<String, T>,
pub definitions: HashMap<String, (Polynomial, Option<FunctionValueDefinition<T>>)>,
pub definitions: HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
pub public_declarations: HashMap<String, PublicDeclaration>,
pub identities: Vec<Identity<T>>,
/// The order in which definitions and identities
@@ -51,33 +51,36 @@ impl<T> Analyzed<T> {
pub fn constant_polys_in_source_order(
&self,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)> {
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
self.definitions_in_source_order(PolynomialType::Constant)
}
pub fn committed_polys_in_source_order(
&self,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)> {
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
self.definitions_in_source_order(PolynomialType::Committed)
}
pub fn intermediate_polys_in_source_order(
&self,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)> {
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
self.definitions_in_source_order(PolynomialType::Intermediate)
}
pub fn definitions_in_source_order(
&self,
poly_type: PolynomialType,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)> {
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
self.source_order
.iter()
.filter_map(move |statement| {
if let StatementIdentifier::Definition(name) = statement {
let definition = &self.definitions[name];
if definition.0.poly_type == poly_type {
return Some(definition);
match definition.0.kind {
SymbolKind::Poly(ptype) if ptype == poly_type => {
return Some(definition);
}
_ => {}
}
}
None
@@ -88,12 +91,11 @@ impl<T> Analyzed<T> {
fn declaration_type_count(&self, poly_type: PolynomialType) -> usize {
self.definitions
.iter()
.filter_map(move |(_name, (poly, _))| {
if poly.poly_type == poly_type {
Some(poly.length.unwrap_or(1) as usize)
} else {
None
.filter_map(move |(_name, (symbol, _))| match symbol.kind {
SymbolKind::Poly(ptype) if ptype == poly_type => {
Some(symbol.length.unwrap_or(1) as usize)
}
_ => None,
})
.sum()
}
@@ -140,7 +142,7 @@ impl<T> Analyzed<T> {
let mut names_to_remove: HashSet<String> = Default::default();
self.definitions.retain(|name, (poly, _def)| {
if to_remove.contains(&(poly as &Polynomial).into()) {
if to_remove.contains(&(poly as &Symbol).into()) {
names_to_remove.insert(name.clone());
false
} else {
@@ -156,7 +158,7 @@ impl<T> Analyzed<T> {
true
});
self.definitions.values_mut().for_each(|(poly, _def)| {
let poly_id = PolyID::from(poly as &Polynomial);
let poly_id = PolyID::from(poly as &Symbol);
assert!(!to_remove.contains(&poly_id));
poly.id = replacements[&poly_id].id;
});
@@ -225,21 +227,29 @@ impl<T> Analyzed<T> {
}
#[derive(Debug, Clone)]
pub struct Polynomial {
pub struct Symbol {
pub id: u64,
pub source: SourceRef,
pub absolute_name: String,
pub poly_type: PolynomialType,
pub kind: SymbolKind,
pub degree: DegreeType,
pub length: Option<DegreeType>,
}
impl Polynomial {
impl Symbol {
pub fn is_array(&self) -> bool {
self.length.is_some()
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SymbolKind {
/// Fixed, witness or intermediate polynomial
Poly(PolynomialType),
/// Other symbol, could be a constant, depends on the type.
Other(),
}
#[derive(Debug)]
pub enum FunctionValueDefinition<T> {
Mapping(Expression<T>),
@@ -375,7 +385,7 @@ impl<T> Expression<T> {
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Reference {
LocalVar(u64),
LocalVar(u64, String),
Poly(PolynomialReference),
}
@@ -398,11 +408,14 @@ pub struct PolyID {
pub ptype: PolynomialType,
}
impl From<&Polynomial> for PolyID {
fn from(poly: &Polynomial) -> Self {
impl From<&Symbol> for PolyID {
fn from(symbol: &Symbol) -> Self {
let SymbolKind::Poly(ptype) = symbol.kind else {
panic!()
};
PolyID {
id: poly.id,
ptype: poly.poly_type,
id: symbol.id,
ptype,
}
}
}

View File

@@ -1,8 +1,8 @@
use std::collections::HashMap;
use ast::analyzed::{
Analyzed, Expression, Identity, Polynomial, PolynomialType, PublicDeclaration,
SelectedExpressions, StatementIdentifier,
Analyzed, Expression, Identity, PolynomialType, PublicDeclaration, SelectedExpressions,
StatementIdentifier, Symbol, SymbolKind,
};
/// Computes expression IDs for each intermediate polynomial.
@@ -13,7 +13,7 @@ pub fn compute_intermediate_expression_ids<T>(analyzed: &Analyzed<T>) -> HashMap
expression_counter += match item {
StatementIdentifier::Definition(name) => {
let poly = &analyzed.definitions[name].0;
if poly.poly_type == PolynomialType::Intermediate {
if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) {
ids.insert(poly.id, expression_counter as u64);
}
poly.expression_count()
@@ -38,9 +38,9 @@ impl<T> ExpressionCounter for Identity<T> {
}
}
impl ExpressionCounter for Polynomial {
impl ExpressionCounter for Symbol {
fn expression_count(&self) -> usize {
(self.poly_type == PolynomialType::Intermediate).into()
(self.kind == SymbolKind::Poly(PolynomialType::Intermediate)).into()
}
}

View File

@@ -4,7 +4,7 @@ use std::collections::HashMap;
use ast::analyzed::{
self, Analyzed, BinaryOperator, Expression, FunctionValueDefinition, IdentityKind, PolyID,
PolynomialReference, PolynomialType, StatementIdentifier, UnaryOperator,
PolynomialReference, PolynomialType, StatementIdentifier, SymbolKind, UnaryOperator,
};
use starky::types::{
ConnectionIdentity, Expression as StarkyExpr, PermutationIdentity, PlookupIdentity,
@@ -47,7 +47,7 @@ pub fn export<T: FieldElement>(analyzed: &Analyzed<T>) -> PIL {
match item {
StatementIdentifier::Definition(name) => {
if let (poly, Some(value)) = &analyzed.definitions[name] {
if poly.poly_type == PolynomialType::Intermediate {
if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) {
if let FunctionValueDefinition::Expression(value) = value {
let expression_id = exporter.extract_expression(value, 1);
assert_eq!(
@@ -141,6 +141,13 @@ pub fn export<T: FieldElement>(analyzed: &Analyzed<T>) -> PIL {
}
}
fn symbol_kind_to_json_string(k: SymbolKind) -> &'static str {
match k {
SymbolKind::Poly(poly_type) => polynomial_type_to_json_string(poly_type),
SymbolKind::Other() => panic!("Cannot translate \"other\" symbol to json."),
}
}
fn polynomial_type_to_json_string(t: PolynomialType) -> &'static str {
polynomial_reference_type_to_type(polynomial_reference_type_to_json_string(t))
}
@@ -176,20 +183,20 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
self.analyzed
.definitions
.iter()
.map(|(name, (poly, _value))| {
let id = if poly.poly_type == PolynomialType::Intermediate {
self.intermediate_poly_expression_ids[&poly.id]
.map(|(name, (symbol, _value))| {
let id = if symbol.kind == SymbolKind::Poly(PolynomialType::Intermediate) {
self.intermediate_poly_expression_ids[&symbol.id]
} else {
poly.id
symbol.id
};
let out = Reference {
polType: None,
type_: polynomial_type_to_json_string(poly.poly_type).to_string(),
type_: symbol_kind_to_json_string(symbol.kind).to_string(),
id: id as usize,
polDeg: poly.degree as usize,
isArray: poly.is_array(),
polDeg: symbol.degree as usize,
isArray: symbol.is_array(),
elementType: None,
len: poly.length.map(|l| l as usize),
len: symbol.length.map(|l| l as usize),
};
(name.clone(), out)
})
@@ -240,7 +247,7 @@ impl<'a, T: FieldElement> Exporter<'a, T> {
Expression::Reference(analyzed::Reference::Poly(reference)) => {
self.polynomial_reference_to_json(reference)
}
Expression::Reference(analyzed::Reference::LocalVar(_)) => {
Expression::Reference(analyzed::Reference::LocalVar(_, _)) => {
panic!("No local variable references allowed here.")
}
Expression::PublicReference(name) => (

View File

@@ -1,4 +1,4 @@
use ast::analyzed::{Analyzed, FunctionValueDefinition, Polynomial};
use ast::analyzed::{Analyzed, FunctionValueDefinition, Symbol};
use number::{read_polys_file, DegreeType, FieldElement};
use std::{fs::File, io::BufReader, path::Path};
@@ -6,7 +6,7 @@ pub trait PolySet {
const FILE_NAME: &'static str;
fn get_polys<T: FieldElement>(
pil: &Analyzed<T>,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)>;
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)>;
}
pub struct FixedPolySet;
@@ -15,7 +15,7 @@ impl PolySet for FixedPolySet {
fn get_polys<T: FieldElement>(
pil: &Analyzed<T>,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)> {
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
pil.constant_polys_in_source_order()
}
}
@@ -26,7 +26,7 @@ impl PolySet for WitnessPolySet {
fn get_polys<T: FieldElement>(
pil: &Analyzed<T>,
) -> Vec<&(Polynomial, Option<FunctionValueDefinition<T>>)> {
) -> Vec<&(Symbol, Option<FunctionValueDefinition<T>>)> {
pil.committed_polys_in_source_order()
}
}

View File

@@ -91,7 +91,7 @@ impl<'a, T: FieldElement> Evaluator<'a, T> {
fn evaluate(&self, expr: &Expression<T>) -> T {
match expr {
Expression::Constant(name) => self.analyzed.constants[name],
Expression::Reference(Reference::LocalVar(i)) => self.variables[*i as usize],
Expression::Reference(Reference::LocalVar(i, _name)) => self.variables[*i as usize],
Expression::Reference(Reference::Poly(_)) => todo!(),
Expression::PublicReference(_) => todo!(),
Expression::Number(n) => *n,

View File

@@ -69,7 +69,7 @@ fn interpolate_query<'b, T: FieldElement>(
.map(|i| interpolate_query(i, rows))
.collect::<Result<Vec<_>, _>>()?
.join(", ")),
Expression::Reference(Reference::LocalVar(i)) => {
Expression::Reference(Reference::LocalVar(i, _name)) => {
assert!(*i == 0);
Ok(format!("{}", rows.current_row_index))
}

View File

@@ -14,3 +14,4 @@ analysis = { version = "0.1", path = "../analysis" }
[dev-dependencies]
test-log = "0.2.12"
env_logger = "0.10.0"
pretty_assertions = "1.3.0"

View File

@@ -14,9 +14,9 @@ use ast::analyzed::util::{
postvisit_expressions_in_identity_mut, previsit_expressions_in_pil_file_mut,
};
use ast::analyzed::{
Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, Polynomial,
PolynomialReference, PolynomialType, PublicDeclaration, Reference, RepeatedArray,
SelectedExpressions, SourceRef, StatementIdentifier,
Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, PolynomialReference,
PolynomialType, PublicDeclaration, Reference, RepeatedArray, SelectedExpressions, SourceRef,
StatementIdentifier, Symbol, SymbolKind,
};
pub fn process_pil_file<T: FieldElement>(path: &Path) -> Analyzed<T> {
@@ -37,7 +37,7 @@ struct PILAnalyzer<T> {
polynomial_degree: DegreeType,
/// Constants are not namespaced!
constants: HashMap<String, T>,
definitions: HashMap<String, (Polynomial, Option<FunctionValueDefinition<T>>)>,
definitions: HashMap<String, (Symbol, Option<FunctionValueDefinition<T>>)>,
public_declarations: HashMap<String, PublicDeclaration>,
identities: Vec<Identity<T>>,
/// The order in which definitions and identities
@@ -49,6 +49,7 @@ struct PILAnalyzer<T> {
commit_poly_counter: u64,
constant_poly_counter: u64,
intermediate_poly_counter: u64,
other_symbol_counter: u64,
identity_counter: HashMap<IdentityKind, u64>,
local_variables: HashMap<String, u64>,
macro_expander: MacroExpander<T>,
@@ -149,7 +150,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
self.to_source_ref(start),
name,
None,
PolynomialType::Intermediate,
SymbolKind::Poly(PolynomialType::Intermediate),
Some(FunctionDefinition::Expression(value)),
);
}
@@ -167,7 +168,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
self.to_source_ref(start),
name,
None,
PolynomialType::Constant,
SymbolKind::Poly(PolynomialType::Constant),
Some(definition),
);
}
@@ -184,43 +185,15 @@ impl<T: FieldElement> PILAnalyzer<T> {
self.to_source_ref(start),
name.name,
name.array_size,
PolynomialType::Committed,
SymbolKind::Poly(PolynomialType::Committed),
Some(definition),
);
}
PilStatement::ConstantDefinition(_, name, value) => {
self.handle_constant_definition(name, value)
PilStatement::ConstantDefinition(_start, name, value) => {
self.handle_constant_definition(name, self.evaluate_expression(&value).unwrap())
}
PilStatement::LetStatement(start, name, None) => {
// Handle all let statements without assignment as witness column declarations for now.
self.handle_polynomial_definition(
self.to_source_ref(start),
name,
None,
PolynomialType::Committed,
None,
);
}
PilStatement::LetStatement(start, name, Some(value)) => {
// Determine this is a fixed column or a constant depending on the structure
// of the value.
// Later, this should depend on the type.
if let parsed::Expression::LambdaExpression(parsed::LambdaExpression {
params,
body,
}) = value
{
self.handle_polynomial_definition(
self.to_source_ref(start),
name,
None,
PolynomialType::Constant,
Some(FunctionDefinition::Mapping(params, *body)),
);
} else {
self.handle_constant_definition(name, value);
}
PilStatement::LetStatement(start, name, value) => {
self.handle_generic_definition(start, name, value)
}
PilStatement::MacroDefinition(_, _, _, _, _) => {
panic!("Macros should have been eliminated.");
@@ -239,6 +212,62 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
}
fn handle_generic_definition(
&mut self,
start: usize,
name: String,
value: Option<::ast::parsed::Expression<T>>,
) {
// Determine whether this is a fixed column, a constant or something else
// depending on the structure of the value and if we can evaluate
// it to a single number.
// Later, this should depend on the type.
match value {
None => {
// No value provided => treat it as a witness column.
self.handle_polynomial_definition(
self.to_source_ref(start),
name,
None,
SymbolKind::Poly(PolynomialType::Committed),
None,
);
}
Some(value) => {
match value {
parsed::Expression::LambdaExpression(parsed::LambdaExpression {
params,
body,
}) if params.len() == 1 => {
// Assigned value is a lambda expression with a single parameter => treat it as a fixed column.
self.handle_polynomial_definition(
self.to_source_ref(start),
name,
None,
SymbolKind::Poly(PolynomialType::Constant),
Some(FunctionDefinition::Mapping(params, *body)),
);
}
_ => {
if let Some(constant) = self.evaluate_expression(&value) {
// Value evaluates to a constant number => treat it as a constant
self.handle_constant_definition(name.to_string(), constant);
} else {
// Otherwise, treat it as "generic definition"
self.handle_polynomial_definition(
self.to_source_ref(start),
name,
None,
SymbolKind::Other(),
Some(FunctionDefinition::Expression(value)),
);
}
}
}
}
}
}
fn handle_identity_statement(&mut self, statement: PilStatement<T>) {
let (start, kind, left, right) = match statement {
PilStatement::PolynomialIdentity(start, expression) => (
@@ -315,7 +344,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
source.clone(),
name,
array_size,
polynomial_type,
SymbolKind::Poly(polynomial_type),
None,
);
}
@@ -326,7 +355,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
source: SourceRef,
name: String,
array_size: Option<::ast::parsed::Expression<T>>,
polynomial_type: PolynomialType,
symbol_kind: SymbolKind,
value: Option<FunctionDefinition<T>>,
) -> u64 {
let have_array_size = array_size.is_some();
@@ -336,39 +365,43 @@ impl<T: FieldElement> PILAnalyzer<T> {
if length.is_some() {
assert!(value.is_none());
}
let counter = match polynomial_type {
PolynomialType::Committed => &mut self.commit_poly_counter,
PolynomialType::Constant => &mut self.constant_poly_counter,
PolynomialType::Intermediate => &mut self.intermediate_poly_counter,
let counter = match symbol_kind {
SymbolKind::Poly(PolynomialType::Committed) => &mut self.commit_poly_counter,
SymbolKind::Poly(PolynomialType::Constant) => &mut self.constant_poly_counter,
SymbolKind::Poly(PolynomialType::Intermediate) => &mut self.intermediate_poly_counter,
SymbolKind::Other() => &mut self.other_symbol_counter,
};
let id = *counter;
*counter += length.unwrap_or(1);
let absolute_name = self.namespaced(&name);
let poly = Polynomial {
let symbol = Symbol {
id,
source,
absolute_name,
degree: self.polynomial_degree,
poly_type: polynomial_type,
kind: symbol_kind,
length,
};
let name = poly.absolute_name.clone();
let name = symbol.absolute_name.clone();
let value = value.map(|v| match v {
FunctionDefinition::Expression(expr) => {
assert!(!have_array_size);
assert!(poly.poly_type == PolynomialType::Intermediate);
assert!(
symbol_kind == SymbolKind::Other()
|| symbol_kind == SymbolKind::Poly(PolynomialType::Intermediate)
);
FunctionValueDefinition::Expression(self.process_expression(expr))
}
FunctionDefinition::Mapping(params, expr) => {
assert!(!have_array_size);
assert!(poly.poly_type == PolynomialType::Constant);
FunctionValueDefinition::Mapping(self.process_function(params, expr))
assert!(symbol_kind == SymbolKind::Poly(PolynomialType::Constant));
FunctionValueDefinition::Mapping(self.process_function(&params, expr))
}
FunctionDefinition::Query(params, expr) => {
assert!(!have_array_size);
assert_eq!(poly.poly_type, PolynomialType::Committed);
FunctionValueDefinition::Query(self.process_function(params, expr))
assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed));
FunctionValueDefinition::Query(self.process_function(&params, expr))
}
FunctionDefinition::Array(value) => {
let size = value.solve(self.polynomial_degree);
@@ -382,7 +415,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
});
let is_new = self
.definitions
.insert(name.clone(), (poly, value))
.insert(name.clone(), (symbol, value))
.is_none();
assert!(is_new, "{name} already defined.");
self.source_order
@@ -392,17 +425,28 @@ impl<T: FieldElement> PILAnalyzer<T> {
fn process_function(
&mut self,
params: Vec<String>,
params: &[String],
expression: ::ast::parsed::Expression<T>,
) -> Expression<T> {
let previous_local_vars = std::mem::take(&mut self.local_variables);
assert!(self.local_variables.is_empty());
self.local_variables = params
.iter()
.enumerate()
.map(|(i, p)| (p.clone(), i as u64))
.collect();
// Re-add the outer local variables if we do not overwrite them
// and increase their index by the number of parameters.
// TODO re-evaluate if this mechanism makes sense as soon as we properly
// support nested functions and closures.
for (name, index) in &previous_local_vars {
self.local_variables
.entry(name.clone())
.or_insert(index + params.len() as u64);
}
let processed_value = self.process_expression(expression);
self.local_variables.clear();
self.local_variables = previous_local_vars;
processed_value
}
@@ -428,12 +472,8 @@ impl<T: FieldElement> PILAnalyzer<T> {
.push(StatementIdentifier::PublicDeclaration(name));
}
fn handle_constant_definition(&mut self, name: String, value: ::ast::parsed::Expression<T>) {
// TODO does the order matter here?
let is_new = self
.constants
.insert(name.to_string(), self.evaluate_expression(&value).unwrap())
.is_none();
fn handle_constant_definition(&mut self, name: String, value: T) {
let is_new = self.constants.insert(name.to_string(), value).is_none();
assert!(is_new, "Constant {name} was defined twice.");
}
@@ -502,7 +542,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
fn process_expression(&mut self, expr: ::ast::parsed::Expression<T>) -> Expression<T> {
use ::ast::parsed::Expression as PExpression;
use ast::parsed::Expression as PExpression;
match expr {
PExpression::Constant(name) => Expression::Constant(name),
PExpression::Reference(poly) => {
@@ -510,7 +550,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
let id = self.local_variables[poly.name()];
assert!(!poly.shift());
assert!(poly.index().is_none());
Expression::Reference(Reference::LocalVar(id))
Expression::Reference(Reference::LocalVar(id, poly.name().to_string()))
} else {
Expression::Reference(Reference::Poly(
self.process_shifted_polynomial_reference(poly),
@@ -527,10 +567,8 @@ impl<T: FieldElement> PILAnalyzer<T> {
})
}
PExpression::LambdaExpression(LambdaExpression { params, body }) => {
Expression::LambdaExpression(LambdaExpression {
params,
body: Box::new(self.process_expression(*body)),
})
let body = Box::new(self.process_function(&params, *body));
Expression::LambdaExpression(LambdaExpression { params, body })
}
PExpression::BinaryOperation(left, op, right) => {
if let Some(value) = self.evaluate_binary_operation(&left, op, &right) {
@@ -601,7 +639,7 @@ impl<T: FieldElement> PILAnalyzer<T> {
}
fn evaluate_expression(&self, expr: &::ast::parsed::Expression<T>) -> Option<T> {
use ::ast::parsed::Expression::*;
use ast::parsed::Expression::*;
match expr {
Constant(name) => Some(
*self
@@ -662,18 +700,16 @@ pub fn inline_intermediate_polynomials<T: Copy>(analyzed: &Analyzed<T>) -> Vec<I
substitute_intermediate(
analyzed.identities.clone(),
&analyzed
.definitions
.definitions_in_source_order(PolynomialType::Intermediate)
.iter()
.filter_map(|(_, (pol, def))| match pol.poly_type {
PolynomialType::Committed => None,
PolynomialType::Constant => None,
PolynomialType::Intermediate => Some((
pol.id,
.map(|(symbol, def)| {
(
symbol.id,
match def.as_ref().unwrap() {
FunctionValueDefinition::Expression(e) => e.clone(),
_ => unreachable!(),
},
)),
)
})
.collect(),
)
@@ -748,6 +784,8 @@ mod test {
use number::GoldilocksField;
use test_log::test;
use pretty_assertions::assert_eq;
use super::*;
#[test]
@@ -797,11 +835,6 @@ namespace T(65536);
{ T.pc, T.reg_write_X_A, T.reg_write_X_CNT } in (1 - T.first_step) { T.line, T.p_reg_write_X_A, T.p_reg_write_X_CNT };
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
if input != formatted {
for (i, f) in input.split('\n').zip(formatted.split('\n')) {
assert_eq!(i, f);
}
}
assert_eq!(input, formatted);
}
@@ -836,6 +869,26 @@ namespace T(65536);
col int2 = N.intermediate;
col int3 = (N.int2 + N.intermediate);
N.int3 = (2 * N.x);
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
fn let_definitions() {
let input = r#"namespace N(65536);
let x;
let t = |i| i + 1;
let z = 7;
let other = [1, 2];
let other_fun = |i, j| (i + 7, (|k| k - i));
"#;
let expected = r#"constant z = 7;
namespace N(65536);
col witness x;
col fixed t(i) { (i + 1) };
let other = [1, 2];
let other_fun = |i, j| ((i + 7), |k| (k - i));
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);