diff --git a/src/analyzer/mod.rs b/src/analyzer/mod.rs index 1beb786b1..69ff3e09a 100644 --- a/src/analyzer/mod.rs +++ b/src/analyzer/mod.rs @@ -194,6 +194,8 @@ pub enum Expression { Number(ConstantNumberType), BinaryOperation(Box, BinaryOperator, Box), UnaryOperation(UnaryOperator, Box), + /// Call to a non-macro function (like a constant polynomial) + FunctionCall(String, Vec), } #[derive(Debug, PartialEq, Eq, Default, Clone)] @@ -330,6 +332,11 @@ impl Context { fn handle_identity_statement(&mut self, statement: &ast::Statement) { if let ast::Statement::FunctionCall(_start, name, arguments) = statement { + if !self.macros.contains_key(name) { + panic!( + "Macro {name} not found - only macros allowed at this point, no fixed columns." + ); + } // TODO check that it does not contain local variable references. // But we also need to do some other well-formedness checks. if self.process_macro_call(name, arguments).is_some() { @@ -595,9 +602,13 @@ impl Context { Expression::UnaryOperation(*op, Box::new(self.process_expression(value))) } } - ast::Expression::FunctionCall(name, arguments) => self - .process_macro_call(name, arguments) - .expect("Invoked a macro in expression context with empty expression."), + ast::Expression::FunctionCall(name, arguments) if self.macros.contains_key(name) => { + self.process_macro_call(name, arguments) + .expect("Invoked a macro in expression context with empty expression.") + } + ast::Expression::FunctionCall(name, arguments) => { + Expression::FunctionCall(self.namespaced(name), self.process_expressions(arguments)) + } } } diff --git a/src/commit_evaluator/mod.rs b/src/commit_evaluator/mod.rs index 0d2fbdeca..0fbf6be53 100644 --- a/src/commit_evaluator/mod.rs +++ b/src/commit_evaluator/mod.rs @@ -157,6 +157,7 @@ impl<'a> Evaluator<'a> { Expression::UnaryOperation(op, expr) => self.evaluate_unary_operation(op, expr), Expression::LocalVariableReference(_) => panic!(), Expression::PublicReference(_) => panic!(), + Expression::FunctionCall(_, _) => panic!(), } } fn evaluate_binary_operation( diff --git a/src/constant_evaluator/mod.rs b/src/constant_evaluator/mod.rs index d112b884e..8563d2e7b 100644 --- a/src/constant_evaluator/mod.rs +++ b/src/constant_evaluator/mod.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use crate::analyzer::{Analyzed, BinaryOperator, ConstantNumberType, Expression, UnaryOperator}; /// Generates the constant polynomial values for all constant polynomials @@ -7,24 +9,24 @@ pub fn generate( analyzed: &Analyzed, ) -> (Vec<(&String, Vec)>, ConstantNumberType) { let mut degree = None; - let values = analyzed - .constant_polys_in_source_order() - .iter() - .filter_map(|(poly, value)| { - if let Some(value) = value { - if let Some(degree) = degree { - assert!(degree == poly.degree); - } else { - degree = Some(poly.degree); - } - return Some(( - &poly.absolute_name, - generate_values(analyzed, poly.degree, value), - )); + let mut other_constants = HashMap::new(); + for (poly, value) in analyzed.constant_polys_in_source_order() { + if let Some(value) = value { + if let Some(degree) = degree { + assert!(degree == poly.degree); + } else { + degree = Some(poly.degree); } - None - }) - .collect(); + let values = generate_values(analyzed, poly.degree, value, &other_constants); + other_constants.insert(&poly.absolute_name, values); + } + } + let mut values = Vec::new(); + for (poly, _) in analyzed.constant_polys_in_source_order() { + if let Some(v) = other_constants.get_mut(&poly.absolute_name) { + values.push((&poly.absolute_name, std::mem::take(v))); + }; + } (values, degree.unwrap_or_default()) } @@ -32,12 +34,14 @@ fn generate_values( analyzed: &Analyzed, degree: ConstantNumberType, body: &Expression, + other_constants: &HashMap<&String, Vec>, ) -> Vec { (0..degree) .map(|i| { Evaluator { analyzed, variables: &[i], + other_constants, } .evaluate(body) }) @@ -46,6 +50,7 @@ fn generate_values( struct Evaluator<'a> { analyzed: &'a Analyzed, + other_constants: &'a HashMap<&'a String, Vec>, variables: &'a [ConstantNumberType], } @@ -61,6 +66,11 @@ impl<'a> Evaluator<'a> { self.evaluate_binary_operation(left, op, right) } Expression::UnaryOperation(op, expr) => self.evaluate_unary_operation(op, expr), + Expression::FunctionCall(name, args) => { + let arg_values = args.iter().map(|a| self.evaluate(a)).collect::>(); + assert!(arg_values.len() == 1); + self.other_constants[name][arg_values[0] as usize] + } } } @@ -186,4 +196,45 @@ mod test { )] ); } + + #[test] + pub fn test_poly_call() { + let src = r#" + constant %N = 10; + namespace F(%N); + col fixed seq(i) { i }; + col fixed doub(i) { seq((2 * i) % %N) + 1 }; + col fixed half_nibble(i) { i & 0x7 }; + col fixed doubled_half_nibble(i) { half_nibble(i / 2) }; + "#; + let analyzed = analyze_string(src); + let (constants, degree) = generate(&analyzed); + assert_eq!(degree, 10); + assert_eq!(constants.len(), 4); + assert_eq!( + constants[0], + (&"F.seq".to_string(), (0..=9i128).collect::>()) + ); + assert_eq!( + constants[1], + ( + &"F.doub".to_string(), + [1i128, 3, 5, 7, 9, 1, 3, 5, 7, 9].to_vec() + ) + ); + assert_eq!( + constants[2], + ( + &"F.half_nibble".to_string(), + [0i128, 1, 2, 3, 4, 5, 6, 7, 0, 1].to_vec() + ) + ); + assert_eq!( + constants[3], + ( + &"F.doubled_half_nibble".to_string(), + [0i128, 0, 1, 1, 2, 2, 3, 3, 4, 4].to_vec() + ) + ); + } } diff --git a/src/json_exporter/mod.rs b/src/json_exporter/mod.rs index 291c11619..0de0ff01f 100644 --- a/src/json_exporter/mod.rs +++ b/src/json_exporter/mod.rs @@ -296,6 +296,9 @@ impl<'a> Exporter<'a> { ), } } + Expression::FunctionCall(_, _) => { + panic!("No function calls allowed here.") + } } } diff --git a/tests/global.pil b/tests/global.pil index 16c4758ac..d7838b594 100644 --- a/tests/global.pil +++ b/tests/global.pil @@ -20,8 +20,7 @@ namespace Global(%N); col fixed LLAST(i) { one_hot(i, %N - 1) }; col fixed BYTE(i) { i & 0xff }; col fixed BYTE2(i) { i & 0xffff }; -// col fixed BYTE_2A(i) { BYTE2(i) >> 8 }; - col fixed BYTE_2A(i) { (i & 0xffff) >> 8 }; + col fixed BYTE_2A(i) { BYTE2(i) >> 8 }; // TODO it might be confusing to remember which one is the array index // and which one is the polynomial parameter. // Here, k is the array index and i is the polynomial parameter.