diff --git a/analysis/src/macro_expansion.rs b/analysis/src/macro_expansion.rs index f6ac089c8..a1f724871 100644 --- a/analysis/src/macro_expansion.rs +++ b/analysis/src/macro_expansion.rs @@ -7,7 +7,7 @@ use ast::parsed::{ asm::{ASMProgram, Instruction, InstructionBody, Machine, MachineStatement}, folder::Folder, visitor::ExpressionVisitable, - Expression, FunctionDefinition, PilStatement, + Expression, FunctionCall, FunctionDefinition, PilStatement, }; use number::FieldElement; @@ -105,19 +105,25 @@ where } } - statement.post_visit_expressions_mut(&mut |e| self.process_expression(e)); - match &mut statement { - PilStatement::FunctionCall(_start, name, arguments) => { - if !self.macros.contains_key(name) { - panic!( - "Macro {name} not found - only macros allowed at this point, no fixed columns." - ); + PilStatement::Expression(_start, e) => match e { + Expression::FunctionCall(FunctionCall { id, arguments }) => { + if !self.macros.contains_key(id) { + panic!("Macro {id} not found - only macros allowed at this point, no fixed columns."); + } + let arguments = std::mem::take(arguments) + .into_iter() + .map(|mut a| { + self.process_expression(&mut a); + a + }) + .collect(); + if self.expand_macro(id, arguments).is_some() { + panic!("Invoked a macro in statement context with non-empty expression."); + } } - if self.expand_macro(name, std::mem::take(arguments)).is_some() { - panic!("Invoked a macro in statement context with non-empty expression."); - } - } + _ => panic!("Only function calls or identities allowed at PIL statement level."), + }, PilStatement::MacroDefinition(_start, name, parameters, statements, expression) => { // We expand lazily. Is that a mistake? let is_new = self @@ -133,7 +139,10 @@ where .is_none(); assert!(is_new); } - _ => self.statements.push(statement), + _ => { + statement.post_visit_expressions_mut(&mut |e| self.process_expression(e)); + self.statements.push(statement); + } }; if added_locals { @@ -176,9 +185,10 @@ where *e = self.arguments[self.parameter_names[&poly.name]].clone() } } else if let Expression::FunctionCall(call) = e { - if self.macros.contains_key(call.id.as_str()) { + let name = call.id.as_str(); + if !self.shadowing_locals.contains(name) && self.macros.contains_key(name) { *e = self - .expand_macro(call.id.as_str(), std::mem::take(&mut call.arguments)) + .expand_macro(name, std::mem::take(&mut call.arguments)) .expect("Invoked a macro in expression context with empty expression.") } } diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 095d8a358..eb4a8b20b 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -392,8 +392,8 @@ impl Display for PilStatement { }; write!(f, "macro {name}({}) {{{body}}};", params.join(", ")) } - PilStatement::FunctionCall(_, name, args) => { - write!(f, "{name}({});", format_expressions(args)) + PilStatement::Expression(_, e) => { + write!(f, "{e};") } } } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 06461d0aa..6baf06b5e 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -54,7 +54,7 @@ pub enum PilStatement { Vec>, Option>, ), - FunctionCall(usize, String, Vec>), + Expression(usize, Expression), } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index 43a9555f7..c6071e800 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -189,9 +189,7 @@ impl ExpressionVisitable> for Pi F: FnMut(&mut Expression) -> ControlFlow, { match self { - PilStatement::FunctionCall(_, _, arguments) => arguments - .iter_mut() - .try_for_each(|e| e.visit_expressions_mut(f, o)), + PilStatement::Expression(_, e) => e.visit_expressions_mut(f, o), PilStatement::PlookupIdentity(_, left, right) | PilStatement::PermutationIdentity(_, left, right) => [left, right] .into_iter() @@ -229,9 +227,7 @@ impl ExpressionVisitable> for Pi F: FnMut(&Expression) -> ControlFlow, { match self { - PilStatement::FunctionCall(_, _, arguments) => { - arguments.iter().try_for_each(|e| e.visit_expressions(f, o)) - } + PilStatement::Expression(_, e) => e.visit_expressions(f, o), PilStatement::PlookupIdentity(_, left, right) | PilStatement::PermutationIdentity(_, left, right) => [left, right] .into_iter() diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 23de7922f..dde3e7160 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -62,7 +62,7 @@ pub PilStatement = { PermutationIdentity, ConnectIdentity, MacroDefinition, - FunctionCallStatement, + ExpressionStatement, }; Include: PilStatement = { @@ -161,8 +161,8 @@ MacroDefinition: PilStatement = { => PilStatement::MacroDefinition(<>) } -FunctionCallStatement: PilStatement = { - <@L> "(" ")" => PilStatement::FunctionCall(<>) +ExpressionStatement: PilStatement = { + <@L> => PilStatement::Expression(<>) } PolCol = { @@ -257,8 +257,7 @@ InstructionBodyElement: PilStatement = { PolynomialIdentity, PlookupIdentity, PermutationIdentity, - // We could use FunctionCallStatement here, but it makes lalrpop fail to build - <@L> "(" ")" => PilStatement::FunctionCall(<>) + ExpressionStatement, } Params: Params = {