diff --git a/analysis/src/macro_expansion.rs b/analysis/src/macro_expansion.rs index a1f724871..51d6faae2 100644 --- a/analysis/src/macro_expansion.rs +++ b/analysis/src/macro_expansion.rs @@ -108,7 +108,8 @@ where match &mut statement { PilStatement::Expression(_start, e) => match e { Expression::FunctionCall(FunctionCall { id, arguments }) => { - if !self.macros.contains_key(id) { + assert!(id.namespace.is_none()); + if !self.macros.contains_key(&id.name) { panic!("Macro {id} not found - only macros allowed at this point, no fixed columns."); } let arguments = std::mem::take(arguments) @@ -118,7 +119,7 @@ where a }) .collect(); - if self.expand_macro(id, arguments).is_some() { + if self.expand_macro(&id.name, arguments).is_some() { panic!("Invoked a macro in statement context with non-empty expression."); } } @@ -185,10 +186,12 @@ where *e = self.arguments[self.parameter_names[&poly.name]].clone() } } else if let Expression::FunctionCall(call) = e { - let name = call.id.as_str(); - if !self.shadowing_locals.contains(name) && self.macros.contains_key(name) { + if call.id.namespace.is_none() + && !self.shadowing_locals.contains(&call.id.name) + && self.macros.contains_key(&call.id.name) + { *e = self - .expand_macro(name, std::mem::take(&mut call.arguments)) + .expand_macro(&call.id.name, std::mem::take(&mut call.arguments)) .expect("Invoked a macro in expression context with empty expression.") } } diff --git a/analysis/src/vm/inference.rs b/analysis/src/vm/inference.rs index 6343baea0..907c7240b 100644 --- a/analysis/src/vm/inference.rs +++ b/analysis/src/vm/inference.rs @@ -40,7 +40,7 @@ fn infer_machine(mut machine: Machine) -> Result, let def = machine .instructions .iter() - .find(|i| i.name == c.id) + .find(|i| i.name == c.id.to_string()) .unwrap(); let outputs = def.instruction.params.outputs.clone().unwrap_or_default(); diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index 5313fb012..ab6801c98 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -236,7 +236,8 @@ impl ASMPILConverter { match *rhs { Expression::FunctionCall(c) => { - self.handle_functional_instruction(lhs_with_reg, c.id, c.arguments) + assert!(c.id.namespace.is_none()); + self.handle_functional_instruction(lhs_with_reg, c.id.name, c.arguments) } _ => self.handle_non_functional_assignment(start, lhs_with_reg, *rhs), } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 6baf06b5e..8671d3a69 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -231,7 +231,7 @@ pub struct IndexAccess { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub struct FunctionCall { - pub id: String, + pub id: Ref, pub arguments: Vec>, } diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index dde3e7160..aeaaa691f 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -492,7 +492,7 @@ IndexAccess: IndexAccess = { } FunctionCall: FunctionCall = { - "(" ")" => FunctionCall {<>}, + "(" ")" => FunctionCall {<>}, } NamespacedPolynomialReference: NamespacedPolynomialReference = { diff --git a/pil_analyzer/src/evaluator.rs b/pil_analyzer/src/evaluator.rs index a82a1a78d..87eb93340 100644 --- a/pil_analyzer/src/evaluator.rs +++ b/pil_analyzer/src/evaluator.rs @@ -84,7 +84,7 @@ impl<'a, T: FieldElement> Evaluator<'a, T> { .map(|a| self.evaluate(a)) .collect::, _>>()?; assert!(arg_values.len() == 1); - let values = &self.function_cache[id.as_str()]; + let values = &self.function_cache[id.to_string().as_str()]; Ok(values[arg_values[0].to_degree() as usize % values.len()]) } Expression::MatchExpression(scrutinee, arms) => { diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 0c4824af4..25c9ae595 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -7,7 +7,7 @@ use analysis::MacroExpander; use ast::parsed::visitor::ExpressionVisitable; use ast::parsed::{ self, ArrayExpression, ArrayLiteral, FunctionDefinition, LambdaExpression, MatchArm, - MatchPattern, PilStatement, PolynomialName, SelectedExpressions, + MatchPattern, NamespacedPolynomialReference, PilStatement, PolynomialName, SelectedExpressions, }; use number::{DegreeType, FieldElement}; @@ -558,16 +558,7 @@ impl<'a, T: FieldElement> ExpressionProcessor<'a, T> { pub fn process_expression(&mut self, expr: parsed::Expression) -> Expression { use parsed::Expression as PExpression; match expr { - PExpression::Reference(poly) => { - if poly.namespace.is_none() && self.local_variables.contains_key(&poly.name) { - let id = self.local_variables[&poly.name]; - Expression::Reference(Reference::LocalVar(id, poly.name.to_string())) - } else { - Expression::Reference(Reference::Poly( - self.process_namespaced_polynomial_reference(poly), - )) - } - } + PExpression::Reference(poly) => Expression::Reference(self.process_reference(poly)), PExpression::PublicReference(name) => Expression::PublicReference(name), PExpression::Number(n) => Expression::Number(n), PExpression::String(value) => Expression::String(value), @@ -596,7 +587,7 @@ impl<'a, T: FieldElement> ExpressionProcessor<'a, T> { }) } PExpression::FunctionCall(c) => Expression::FunctionCall(parsed::FunctionCall { - id: self.analyzer.namespaced_ref_to_absolute(&None, &c.id), + id: self.process_reference(c.id), arguments: self.process_expressions(c.arguments), }), PExpression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( @@ -617,6 +608,15 @@ impl<'a, T: FieldElement> ExpressionProcessor<'a, T> { } } + fn process_reference(&mut self, reference: NamespacedPolynomialReference) -> Reference { + if reference.namespace.is_none() && self.local_variables.contains_key(&reference.name) { + let id = self.local_variables[&reference.name]; + Reference::LocalVar(id, reference.name.to_string()) + } else { + Reference::Poly(self.process_namespaced_polynomial_reference(reference)) + } + } + fn process_function( &mut self, params: &[String], @@ -879,6 +879,17 @@ namespace N(65536); let input = r#"namespace N(16); col witness y[3]; (N.y[3] - 2) = 0; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, input); + } + + #[test] + fn namespaced_call() { + let input = r#"namespace Assembly(2); + col fixed A = [0]*; + col fixed C(i) { (Assembly.A((i + 2)) + 3) }; + col fixed D(i) { Assembly.C((i + 3)) }; "#; let formatted = process_pil_file_contents::(input).to_string(); assert_eq!(formatted, input);