diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 6f4a438f8..e7ac24fc9 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -33,18 +33,29 @@ impl Display for Analyzed { for statement in &self.source_order { match statement { StatementIdentifier::Definition(name) => { - let (poly, definition) = &self.definitions[name]; - let name = update_namespace(name, poly.degree, f)?; - let kind = match &poly.poly_type { - PolynomialType::Committed => "witness ", - PolynomialType::Constant => "fixed ", - PolynomialType::Intermediate => "", - }; - write!(f, " col {kind}{name}")?; - if let Some(value) = definition { - writeln!(f, "{value};")? - } else { - writeln!(f, ";")? + let (symbol, definition) = &self.definitions[name]; + let name = update_namespace(name, symbol.degree, f)?; + match symbol.kind { + SymbolKind::Poly(poly_type) => { + let kind = match &poly_type { + PolynomialType::Committed => "witness ", + PolynomialType::Constant => "fixed ", + PolynomialType::Intermediate => "", + }; + write!(f, " col {kind}{name}")?; + if let Some(value) = definition { + writeln!(f, "{value};")? + } else { + writeln!(f, ";")? + } + } + SymbolKind::Other() => { + write!(f, " let {name}")?; + if let Some(value) = definition { + write!(f, "{value}")? + } + writeln!(f, ";")? + } } } StatementIdentifier::PublicDeclaration(name) => { @@ -126,14 +137,8 @@ impl Display for SelectedExpressions { impl Display for Reference { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Reference::LocalVar(index) => { - // TODO this is not really reproducing the input, but - // if we want to do that, we would need the names of the local variables somehow. - if *index == 0 { - write!(f, "i") - } else { - write!(f, "${index}") - } + Reference::LocalVar(_index, name) => { + write!(f, "{name}") } Reference::Poly(r) => write!(f, "{r}"), } diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index eafb3a2d9..f9fcad935 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -27,7 +27,7 @@ pub enum StatementIdentifier { pub struct Analyzed { /// Constants are not namespaced! pub constants: HashMap, - pub definitions: HashMap>)>, + pub definitions: HashMap>)>, pub public_declarations: HashMap, pub identities: Vec>, /// The order in which definitions and identities @@ -51,33 +51,36 @@ impl Analyzed { pub fn constant_polys_in_source_order( &self, - ) -> Vec<&(Polynomial, Option>)> { + ) -> Vec<&(Symbol, Option>)> { self.definitions_in_source_order(PolynomialType::Constant) } pub fn committed_polys_in_source_order( &self, - ) -> Vec<&(Polynomial, Option>)> { + ) -> Vec<&(Symbol, Option>)> { self.definitions_in_source_order(PolynomialType::Committed) } pub fn intermediate_polys_in_source_order( &self, - ) -> Vec<&(Polynomial, Option>)> { + ) -> Vec<&(Symbol, Option>)> { self.definitions_in_source_order(PolynomialType::Intermediate) } pub fn definitions_in_source_order( &self, poly_type: PolynomialType, - ) -> Vec<&(Polynomial, Option>)> { + ) -> Vec<&(Symbol, Option>)> { self.source_order .iter() .filter_map(move |statement| { if let StatementIdentifier::Definition(name) = statement { let definition = &self.definitions[name]; - if definition.0.poly_type == poly_type { - return Some(definition); + match definition.0.kind { + SymbolKind::Poly(ptype) if ptype == poly_type => { + return Some(definition); + } + _ => {} } } None @@ -88,12 +91,11 @@ impl Analyzed { fn declaration_type_count(&self, poly_type: PolynomialType) -> usize { self.definitions .iter() - .filter_map(move |(_name, (poly, _))| { - if poly.poly_type == poly_type { - Some(poly.length.unwrap_or(1) as usize) - } else { - None + .filter_map(move |(_name, (symbol, _))| match symbol.kind { + SymbolKind::Poly(ptype) if ptype == poly_type => { + Some(symbol.length.unwrap_or(1) as usize) } + _ => None, }) .sum() } @@ -140,7 +142,7 @@ impl Analyzed { let mut names_to_remove: HashSet = Default::default(); self.definitions.retain(|name, (poly, _def)| { - if to_remove.contains(&(poly as &Polynomial).into()) { + if to_remove.contains(&(poly as &Symbol).into()) { names_to_remove.insert(name.clone()); false } else { @@ -156,7 +158,7 @@ impl Analyzed { true }); self.definitions.values_mut().for_each(|(poly, _def)| { - let poly_id = PolyID::from(poly as &Polynomial); + let poly_id = PolyID::from(poly as &Symbol); assert!(!to_remove.contains(&poly_id)); poly.id = replacements[&poly_id].id; }); @@ -225,21 +227,29 @@ impl Analyzed { } #[derive(Debug, Clone)] -pub struct Polynomial { +pub struct Symbol { pub id: u64, pub source: SourceRef, pub absolute_name: String, - pub poly_type: PolynomialType, + pub kind: SymbolKind, pub degree: DegreeType, pub length: Option, } -impl Polynomial { +impl Symbol { pub fn is_array(&self) -> bool { self.length.is_some() } } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SymbolKind { + /// Fixed, witness or intermediate polynomial + Poly(PolynomialType), + /// Other symbol, could be a constant, depends on the type. + Other(), +} + #[derive(Debug)] pub enum FunctionValueDefinition { Mapping(Expression), @@ -375,7 +385,7 @@ impl Expression { #[derive(Debug, Clone, PartialEq, Eq)] pub enum Reference { - LocalVar(u64), + LocalVar(u64, String), Poly(PolynomialReference), } @@ -398,11 +408,14 @@ pub struct PolyID { pub ptype: PolynomialType, } -impl From<&Polynomial> for PolyID { - fn from(poly: &Polynomial) -> Self { +impl From<&Symbol> for PolyID { + fn from(symbol: &Symbol) -> Self { + let SymbolKind::Poly(ptype) = symbol.kind else { + panic!() + }; PolyID { - id: poly.id, - ptype: poly.poly_type, + id: symbol.id, + ptype, } } } diff --git a/backend/src/pilstark/json_exporter/expression_counter.rs b/backend/src/pilstark/json_exporter/expression_counter.rs index addfe3c78..3b4c4d772 100644 --- a/backend/src/pilstark/json_exporter/expression_counter.rs +++ b/backend/src/pilstark/json_exporter/expression_counter.rs @@ -1,8 +1,8 @@ use std::collections::HashMap; use ast::analyzed::{ - Analyzed, Expression, Identity, Polynomial, PolynomialType, PublicDeclaration, - SelectedExpressions, StatementIdentifier, + Analyzed, Expression, Identity, PolynomialType, PublicDeclaration, SelectedExpressions, + StatementIdentifier, Symbol, SymbolKind, }; /// Computes expression IDs for each intermediate polynomial. @@ -13,7 +13,7 @@ pub fn compute_intermediate_expression_ids(analyzed: &Analyzed) -> HashMap expression_counter += match item { StatementIdentifier::Definition(name) => { let poly = &analyzed.definitions[name].0; - if poly.poly_type == PolynomialType::Intermediate { + if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) { ids.insert(poly.id, expression_counter as u64); } poly.expression_count() @@ -38,9 +38,9 @@ impl ExpressionCounter for Identity { } } -impl ExpressionCounter for Polynomial { +impl ExpressionCounter for Symbol { fn expression_count(&self) -> usize { - (self.poly_type == PolynomialType::Intermediate).into() + (self.kind == SymbolKind::Poly(PolynomialType::Intermediate)).into() } } diff --git a/backend/src/pilstark/json_exporter/mod.rs b/backend/src/pilstark/json_exporter/mod.rs index ed8973501..5c0047b10 100644 --- a/backend/src/pilstark/json_exporter/mod.rs +++ b/backend/src/pilstark/json_exporter/mod.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use ast::analyzed::{ self, Analyzed, BinaryOperator, Expression, FunctionValueDefinition, IdentityKind, PolyID, - PolynomialReference, PolynomialType, StatementIdentifier, UnaryOperator, + PolynomialReference, PolynomialType, StatementIdentifier, SymbolKind, UnaryOperator, }; use starky::types::{ ConnectionIdentity, Expression as StarkyExpr, PermutationIdentity, PlookupIdentity, @@ -47,7 +47,7 @@ pub fn export(analyzed: &Analyzed) -> PIL { match item { StatementIdentifier::Definition(name) => { if let (poly, Some(value)) = &analyzed.definitions[name] { - if poly.poly_type == PolynomialType::Intermediate { + if poly.kind == SymbolKind::Poly(PolynomialType::Intermediate) { if let FunctionValueDefinition::Expression(value) = value { let expression_id = exporter.extract_expression(value, 1); assert_eq!( @@ -141,6 +141,13 @@ pub fn export(analyzed: &Analyzed) -> PIL { } } +fn symbol_kind_to_json_string(k: SymbolKind) -> &'static str { + match k { + SymbolKind::Poly(poly_type) => polynomial_type_to_json_string(poly_type), + SymbolKind::Other() => panic!("Cannot translate \"other\" symbol to json."), + } +} + fn polynomial_type_to_json_string(t: PolynomialType) -> &'static str { polynomial_reference_type_to_type(polynomial_reference_type_to_json_string(t)) } @@ -176,20 +183,20 @@ impl<'a, T: FieldElement> Exporter<'a, T> { self.analyzed .definitions .iter() - .map(|(name, (poly, _value))| { - let id = if poly.poly_type == PolynomialType::Intermediate { - self.intermediate_poly_expression_ids[&poly.id] + .map(|(name, (symbol, _value))| { + let id = if symbol.kind == SymbolKind::Poly(PolynomialType::Intermediate) { + self.intermediate_poly_expression_ids[&symbol.id] } else { - poly.id + symbol.id }; let out = Reference { polType: None, - type_: polynomial_type_to_json_string(poly.poly_type).to_string(), + type_: symbol_kind_to_json_string(symbol.kind).to_string(), id: id as usize, - polDeg: poly.degree as usize, - isArray: poly.is_array(), + polDeg: symbol.degree as usize, + isArray: symbol.is_array(), elementType: None, - len: poly.length.map(|l| l as usize), + len: symbol.length.map(|l| l as usize), }; (name.clone(), out) }) @@ -240,7 +247,7 @@ impl<'a, T: FieldElement> Exporter<'a, T> { Expression::Reference(analyzed::Reference::Poly(reference)) => { self.polynomial_reference_to_json(reference) } - Expression::Reference(analyzed::Reference::LocalVar(_)) => { + Expression::Reference(analyzed::Reference::LocalVar(_, _)) => { panic!("No local variable references allowed here.") } Expression::PublicReference(name) => ( diff --git a/compiler/src/util.rs b/compiler/src/util.rs index c691a1716..a392cfa25 100644 --- a/compiler/src/util.rs +++ b/compiler/src/util.rs @@ -1,4 +1,4 @@ -use ast::analyzed::{Analyzed, FunctionValueDefinition, Polynomial}; +use ast::analyzed::{Analyzed, FunctionValueDefinition, Symbol}; use number::{read_polys_file, DegreeType, FieldElement}; use std::{fs::File, io::BufReader, path::Path}; @@ -6,7 +6,7 @@ pub trait PolySet { const FILE_NAME: &'static str; fn get_polys( pil: &Analyzed, - ) -> Vec<&(Polynomial, Option>)>; + ) -> Vec<&(Symbol, Option>)>; } pub struct FixedPolySet; @@ -15,7 +15,7 @@ impl PolySet for FixedPolySet { fn get_polys( pil: &Analyzed, - ) -> Vec<&(Polynomial, Option>)> { + ) -> Vec<&(Symbol, Option>)> { pil.constant_polys_in_source_order() } } @@ -26,7 +26,7 @@ impl PolySet for WitnessPolySet { fn get_polys( pil: &Analyzed, - ) -> Vec<&(Polynomial, Option>)> { + ) -> Vec<&(Symbol, Option>)> { pil.committed_polys_in_source_order() } } diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index f6e641552..ee3700aeb 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -91,7 +91,7 @@ impl<'a, T: FieldElement> Evaluator<'a, T> { fn evaluate(&self, expr: &Expression) -> T { match expr { Expression::Constant(name) => self.analyzed.constants[name], - Expression::Reference(Reference::LocalVar(i)) => self.variables[*i as usize], + Expression::Reference(Reference::LocalVar(i, _name)) => self.variables[*i as usize], Expression::Reference(Reference::Poly(_)) => todo!(), Expression::PublicReference(_) => todo!(), Expression::Number(n) => *n, diff --git a/executor/src/witgen/query_processor.rs b/executor/src/witgen/query_processor.rs index ed54c7059..ba3c10ab5 100644 --- a/executor/src/witgen/query_processor.rs +++ b/executor/src/witgen/query_processor.rs @@ -69,7 +69,7 @@ fn interpolate_query<'b, T: FieldElement>( .map(|i| interpolate_query(i, rows)) .collect::, _>>()? .join(", ")), - Expression::Reference(Reference::LocalVar(i)) => { + Expression::Reference(Reference::LocalVar(i, _name)) => { assert!(*i == 0); Ok(format!("{}", rows.current_row_index)) } diff --git a/pil_analyzer/Cargo.toml b/pil_analyzer/Cargo.toml index 5941a1f47..4bc14f7ed 100644 --- a/pil_analyzer/Cargo.toml +++ b/pil_analyzer/Cargo.toml @@ -14,3 +14,4 @@ analysis = { version = "0.1", path = "../analysis" } [dev-dependencies] test-log = "0.2.12" env_logger = "0.10.0" +pretty_assertions = "1.3.0" diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 2c4775c93..042dfc103 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -14,9 +14,9 @@ use ast::analyzed::util::{ postvisit_expressions_in_identity_mut, previsit_expressions_in_pil_file_mut, }; use ast::analyzed::{ - Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, Polynomial, - PolynomialReference, PolynomialType, PublicDeclaration, Reference, RepeatedArray, - SelectedExpressions, SourceRef, StatementIdentifier, + Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, PolynomialReference, + PolynomialType, PublicDeclaration, Reference, RepeatedArray, SelectedExpressions, SourceRef, + StatementIdentifier, Symbol, SymbolKind, }; pub fn process_pil_file(path: &Path) -> Analyzed { @@ -37,7 +37,7 @@ struct PILAnalyzer { polynomial_degree: DegreeType, /// Constants are not namespaced! constants: HashMap, - definitions: HashMap>)>, + definitions: HashMap>)>, public_declarations: HashMap, identities: Vec>, /// The order in which definitions and identities @@ -49,6 +49,7 @@ struct PILAnalyzer { commit_poly_counter: u64, constant_poly_counter: u64, intermediate_poly_counter: u64, + other_symbol_counter: u64, identity_counter: HashMap, local_variables: HashMap, macro_expander: MacroExpander, @@ -149,7 +150,7 @@ impl PILAnalyzer { self.to_source_ref(start), name, None, - PolynomialType::Intermediate, + SymbolKind::Poly(PolynomialType::Intermediate), Some(FunctionDefinition::Expression(value)), ); } @@ -167,7 +168,7 @@ impl PILAnalyzer { self.to_source_ref(start), name, None, - PolynomialType::Constant, + SymbolKind::Poly(PolynomialType::Constant), Some(definition), ); } @@ -184,43 +185,15 @@ impl PILAnalyzer { self.to_source_ref(start), name.name, name.array_size, - PolynomialType::Committed, + SymbolKind::Poly(PolynomialType::Committed), Some(definition), ); } - PilStatement::ConstantDefinition(_, name, value) => { - self.handle_constant_definition(name, value) + PilStatement::ConstantDefinition(_start, name, value) => { + self.handle_constant_definition(name, self.evaluate_expression(&value).unwrap()) } - PilStatement::LetStatement(start, name, None) => { - // Handle all let statements without assignment as witness column declarations for now. - self.handle_polynomial_definition( - self.to_source_ref(start), - name, - None, - PolynomialType::Committed, - None, - ); - } - PilStatement::LetStatement(start, name, Some(value)) => { - // Determine this is a fixed column or a constant depending on the structure - // of the value. - // Later, this should depend on the type. - - if let parsed::Expression::LambdaExpression(parsed::LambdaExpression { - params, - body, - }) = value - { - self.handle_polynomial_definition( - self.to_source_ref(start), - name, - None, - PolynomialType::Constant, - Some(FunctionDefinition::Mapping(params, *body)), - ); - } else { - self.handle_constant_definition(name, value); - } + PilStatement::LetStatement(start, name, value) => { + self.handle_generic_definition(start, name, value) } PilStatement::MacroDefinition(_, _, _, _, _) => { panic!("Macros should have been eliminated."); @@ -239,6 +212,62 @@ impl PILAnalyzer { } } + fn handle_generic_definition( + &mut self, + start: usize, + name: String, + value: Option<::ast::parsed::Expression>, + ) { + // Determine whether this is a fixed column, a constant or something else + // depending on the structure of the value and if we can evaluate + // it to a single number. + // Later, this should depend on the type. + match value { + None => { + // No value provided => treat it as a witness column. + self.handle_polynomial_definition( + self.to_source_ref(start), + name, + None, + SymbolKind::Poly(PolynomialType::Committed), + None, + ); + } + Some(value) => { + match value { + parsed::Expression::LambdaExpression(parsed::LambdaExpression { + params, + body, + }) if params.len() == 1 => { + // Assigned value is a lambda expression with a single parameter => treat it as a fixed column. + self.handle_polynomial_definition( + self.to_source_ref(start), + name, + None, + SymbolKind::Poly(PolynomialType::Constant), + Some(FunctionDefinition::Mapping(params, *body)), + ); + } + _ => { + if let Some(constant) = self.evaluate_expression(&value) { + // Value evaluates to a constant number => treat it as a constant + self.handle_constant_definition(name.to_string(), constant); + } else { + // Otherwise, treat it as "generic definition" + self.handle_polynomial_definition( + self.to_source_ref(start), + name, + None, + SymbolKind::Other(), + Some(FunctionDefinition::Expression(value)), + ); + } + } + } + } + } + } + fn handle_identity_statement(&mut self, statement: PilStatement) { let (start, kind, left, right) = match statement { PilStatement::PolynomialIdentity(start, expression) => ( @@ -315,7 +344,7 @@ impl PILAnalyzer { source.clone(), name, array_size, - polynomial_type, + SymbolKind::Poly(polynomial_type), None, ); } @@ -326,7 +355,7 @@ impl PILAnalyzer { source: SourceRef, name: String, array_size: Option<::ast::parsed::Expression>, - polynomial_type: PolynomialType, + symbol_kind: SymbolKind, value: Option>, ) -> u64 { let have_array_size = array_size.is_some(); @@ -336,39 +365,43 @@ impl PILAnalyzer { if length.is_some() { assert!(value.is_none()); } - let counter = match polynomial_type { - PolynomialType::Committed => &mut self.commit_poly_counter, - PolynomialType::Constant => &mut self.constant_poly_counter, - PolynomialType::Intermediate => &mut self.intermediate_poly_counter, + let counter = match symbol_kind { + SymbolKind::Poly(PolynomialType::Committed) => &mut self.commit_poly_counter, + SymbolKind::Poly(PolynomialType::Constant) => &mut self.constant_poly_counter, + SymbolKind::Poly(PolynomialType::Intermediate) => &mut self.intermediate_poly_counter, + SymbolKind::Other() => &mut self.other_symbol_counter, }; let id = *counter; *counter += length.unwrap_or(1); let absolute_name = self.namespaced(&name); - let poly = Polynomial { + let symbol = Symbol { id, source, absolute_name, degree: self.polynomial_degree, - poly_type: polynomial_type, + kind: symbol_kind, length, }; - let name = poly.absolute_name.clone(); + let name = symbol.absolute_name.clone(); let value = value.map(|v| match v { FunctionDefinition::Expression(expr) => { assert!(!have_array_size); - assert!(poly.poly_type == PolynomialType::Intermediate); + assert!( + symbol_kind == SymbolKind::Other() + || symbol_kind == SymbolKind::Poly(PolynomialType::Intermediate) + ); FunctionValueDefinition::Expression(self.process_expression(expr)) } FunctionDefinition::Mapping(params, expr) => { assert!(!have_array_size); - assert!(poly.poly_type == PolynomialType::Constant); - FunctionValueDefinition::Mapping(self.process_function(params, expr)) + assert!(symbol_kind == SymbolKind::Poly(PolynomialType::Constant)); + FunctionValueDefinition::Mapping(self.process_function(¶ms, expr)) } FunctionDefinition::Query(params, expr) => { assert!(!have_array_size); - assert_eq!(poly.poly_type, PolynomialType::Committed); - FunctionValueDefinition::Query(self.process_function(params, expr)) + assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed)); + FunctionValueDefinition::Query(self.process_function(¶ms, expr)) } FunctionDefinition::Array(value) => { let size = value.solve(self.polynomial_degree); @@ -382,7 +415,7 @@ impl PILAnalyzer { }); let is_new = self .definitions - .insert(name.clone(), (poly, value)) + .insert(name.clone(), (symbol, value)) .is_none(); assert!(is_new, "{name} already defined."); self.source_order @@ -392,17 +425,28 @@ impl PILAnalyzer { fn process_function( &mut self, - params: Vec, + params: &[String], expression: ::ast::parsed::Expression, ) -> Expression { + let previous_local_vars = std::mem::take(&mut self.local_variables); + assert!(self.local_variables.is_empty()); self.local_variables = params .iter() .enumerate() .map(|(i, p)| (p.clone(), i as u64)) .collect(); + // Re-add the outer local variables if we do not overwrite them + // and increase their index by the number of parameters. + // TODO re-evaluate if this mechanism makes sense as soon as we properly + // support nested functions and closures. + for (name, index) in &previous_local_vars { + self.local_variables + .entry(name.clone()) + .or_insert(index + params.len() as u64); + } let processed_value = self.process_expression(expression); - self.local_variables.clear(); + self.local_variables = previous_local_vars; processed_value } @@ -428,12 +472,8 @@ impl PILAnalyzer { .push(StatementIdentifier::PublicDeclaration(name)); } - fn handle_constant_definition(&mut self, name: String, value: ::ast::parsed::Expression) { - // TODO does the order matter here? - let is_new = self - .constants - .insert(name.to_string(), self.evaluate_expression(&value).unwrap()) - .is_none(); + fn handle_constant_definition(&mut self, name: String, value: T) { + let is_new = self.constants.insert(name.to_string(), value).is_none(); assert!(is_new, "Constant {name} was defined twice."); } @@ -502,7 +542,7 @@ impl PILAnalyzer { } fn process_expression(&mut self, expr: ::ast::parsed::Expression) -> Expression { - use ::ast::parsed::Expression as PExpression; + use ast::parsed::Expression as PExpression; match expr { PExpression::Constant(name) => Expression::Constant(name), PExpression::Reference(poly) => { @@ -510,7 +550,7 @@ impl PILAnalyzer { let id = self.local_variables[poly.name()]; assert!(!poly.shift()); assert!(poly.index().is_none()); - Expression::Reference(Reference::LocalVar(id)) + Expression::Reference(Reference::LocalVar(id, poly.name().to_string())) } else { Expression::Reference(Reference::Poly( self.process_shifted_polynomial_reference(poly), @@ -527,10 +567,8 @@ impl PILAnalyzer { }) } PExpression::LambdaExpression(LambdaExpression { params, body }) => { - Expression::LambdaExpression(LambdaExpression { - params, - body: Box::new(self.process_expression(*body)), - }) + let body = Box::new(self.process_function(¶ms, *body)); + Expression::LambdaExpression(LambdaExpression { params, body }) } PExpression::BinaryOperation(left, op, right) => { if let Some(value) = self.evaluate_binary_operation(&left, op, &right) { @@ -601,7 +639,7 @@ impl PILAnalyzer { } fn evaluate_expression(&self, expr: &::ast::parsed::Expression) -> Option { - use ::ast::parsed::Expression::*; + use ast::parsed::Expression::*; match expr { Constant(name) => Some( *self @@ -662,18 +700,16 @@ pub fn inline_intermediate_polynomials(analyzed: &Analyzed) -> Vec None, - PolynomialType::Constant => None, - PolynomialType::Intermediate => Some(( - pol.id, + .map(|(symbol, def)| { + ( + symbol.id, match def.as_ref().unwrap() { FunctionValueDefinition::Expression(e) => e.clone(), _ => unreachable!(), }, - )), + ) }) .collect(), ) @@ -748,6 +784,8 @@ mod test { use number::GoldilocksField; use test_log::test; + use pretty_assertions::assert_eq; + use super::*; #[test] @@ -797,11 +835,6 @@ namespace T(65536); { T.pc, T.reg_write_X_A, T.reg_write_X_CNT } in (1 - T.first_step) { T.line, T.p_reg_write_X_A, T.p_reg_write_X_CNT }; "#; let formatted = process_pil_file_contents::(input).to_string(); - if input != formatted { - for (i, f) in input.split('\n').zip(formatted.split('\n')) { - assert_eq!(i, f); - } - } assert_eq!(input, formatted); } @@ -836,6 +869,26 @@ namespace T(65536); col int2 = N.intermediate; col int3 = (N.int2 + N.intermediate); N.int3 = (2 * N.x); +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, expected); + } + + #[test] + fn let_definitions() { + let input = r#"namespace N(65536); + let x; + let t = |i| i + 1; + let z = 7; + let other = [1, 2]; + let other_fun = |i, j| (i + 7, (|k| k - i)); +"#; + let expected = r#"constant z = 7; +namespace N(65536); + col witness x; + col fixed t(i) { (i + 1) }; + let other = [1, 2]; + let other_fun = |i, j| ((i + 7), |k| (k - i)); "#; let formatted = process_pil_file_contents::(input).to_string(); assert_eq!(formatted, expected);