From ac0ccff780bc06bdd1451ec1e497f11dba16cf8c Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 19 Sep 2023 13:39:57 +0200 Subject: [PATCH] Array literals in expressions. --- asm_to_pil/src/vm_to_constrained.rs | 1 + ast/src/parsed/display.rs | 7 ++++++ ast/src/parsed/mod.rs | 6 ++++++ ast/src/parsed/utils.rs | 7 ++++-- backend/src/pilcom_cli/json_exporter/mod.rs | 1 + executor/src/constant_evaluator/mod.rs | 1 + parser/src/lib.rs | 24 ++++++++++++--------- parser/src/powdr.lalrpop | 1 + parser_util/src/lib.rs | 21 ++++++++++++++++++ pil_analyzer/src/pil_analyzer.rs | 10 +++++++-- 10 files changed, 65 insertions(+), 14 deletions(-) diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index 610f4f9cf..aa8961565 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -589,6 +589,7 @@ impl ASMPILConverter { Expression::Number(value) => vec![(value, AffineExpressionComponent::Constant)], Expression::String(_) => panic!(), Expression::Tuple(_) => panic!(), + Expression::ArrayLiteral(_) => panic!(), Expression::MatchExpression(_, _) => panic!(), Expression::FreeInput(expr) => { vec![(1.into(), AffineExpressionComponent::FreeInput(*expr))] diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 15d7329a2..de1f38a34 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -433,6 +433,7 @@ impl Display for Expression { Expression::String(value) => write!(f, "\"{value}\""), // TODO quote? Expression::Tuple(items) => write!(f, "({})", format_expressions(items)), Expression::LambdaExpression(lambda) => write!(f, "{}", lambda), + Expression::ArrayLiteral(array) => write!(f, "{array}"), Expression::BinaryOperation(left, op, right) => write!(f, "({left} {op} {right})"), Expression::UnaryOperation(op, exp) => write!(f, "{op}{exp}"), Expression::FunctionCall(fun_call) => write!(f, "{fun_call}"), @@ -504,6 +505,12 @@ impl Display for LambdaExpression { } } +impl Display for ArrayLiteral { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "[{}]", self.items.iter().format(", ")) + } +} + impl Display for BinaryOperator { fn fmt(&self, f: &mut Formatter<'_>) -> Result { write!( diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index c1b8f81d3..27ea94a25 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -66,6 +66,7 @@ pub enum Expression> { String(String), Tuple(Vec>), LambdaExpression(LambdaExpression), + ArrayLiteral(ArrayLiteral), BinaryOperation( Box>, BinaryOperator, @@ -271,6 +272,11 @@ pub struct LambdaExpression> { pub body: Box>, } +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub struct ArrayLiteral> { + pub items: Vec>, +} + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub enum UnaryOperator { Plus, diff --git a/ast/src/parsed/utils.rs b/ast/src/parsed/utils.rs index 6204de226..36225ca25 100644 --- a/ast/src/parsed/utils.rs +++ b/ast/src/parsed/utils.rs @@ -1,8 +1,8 @@ use std::{iter::once, ops::ControlFlow}; use super::{ - ArrayExpression, Expression, FunctionCall, FunctionDefinition, LambdaExpression, MatchArm, - PilStatement, + ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, LambdaExpression, + MatchArm, PilStatement, }; /// Visits `expr` and all of its sub-expressions and returns true if `f` returns true on any of them. @@ -111,6 +111,7 @@ where previsit_expression(e, f)? } Expression::Tuple(items) + | Expression::ArrayLiteral(ArrayLiteral { items }) | Expression::FunctionCall(FunctionCall { id: _, arguments: items, @@ -152,6 +153,7 @@ where previsit_expression_mut(e.as_mut(), f)? } Expression::Tuple(items) + | Expression::ArrayLiteral(ArrayLiteral { items }) | Expression::FunctionCall(FunctionCall { arguments: items, .. }) => items @@ -190,6 +192,7 @@ where postvisit_expression_mut(e.as_mut(), f)? } Expression::Tuple(items) + | Expression::ArrayLiteral(ArrayLiteral { items }) | Expression::FunctionCall(FunctionCall { arguments: items, .. }) => items diff --git a/backend/src/pilcom_cli/json_exporter/mod.rs b/backend/src/pilcom_cli/json_exporter/mod.rs index 8cb9049d5..8647f67f5 100644 --- a/backend/src/pilcom_cli/json_exporter/mod.rs +++ b/backend/src/pilcom_cli/json_exporter/mod.rs @@ -311,6 +311,7 @@ impl<'a, T: FieldElement> Exporter<'a, T> { Expression::FunctionCall(_) => panic!("No function calls allowed here."), Expression::String(_) => panic!("Strings not allowed here."), Expression::Tuple(_) => panic!("Tuples not allowed here"), + Expression::ArrayLiteral(_) => panic!("Array literals not allowed here"), Expression::MatchExpression(_, _) => { panic!("No match expressions allowed here.") } diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index d982b85f9..f6e641552 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -97,6 +97,7 @@ impl<'a, T: FieldElement> Evaluator<'a, T> { Expression::Number(n) => *n, Expression::String(_) => panic!(), Expression::Tuple(_) => panic!(), + Expression::ArrayLiteral(_) => panic!(), Expression::BinaryOperation(left, op, right) => { evaluate_binary_operation(self.evaluate(left), *op, self.evaluate(right)) } diff --git a/parser/src/lib.rs b/parser/src/lib.rs index 138bc3a0a..a5dcf11a8 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -45,6 +45,7 @@ mod test { SelectedExpressions, }; use number::GoldilocksField; + use parser_util::UnwrapErrToStderr; use std::fs; use test_log::test; @@ -115,11 +116,7 @@ mod test { )); let input = fs::read_to_string(file).unwrap(); - parse(Some(name), &input).unwrap_or_else(|err| { - eprintln!("Parse error during test:"); - err.output_to_stderr(); - panic!(); - }) + parse(Some(name), &input).unwrap_err_to_stderr() } fn parse_asm_file(name: &str) -> ASMProgram { @@ -129,11 +126,7 @@ mod test { )); let input = fs::read_to_string(file).unwrap(); - parse_asm(Some(name), &input).unwrap_or_else(|err| { - eprintln!("Parse error during test:"); - err.output_to_stderr(); - panic!(); - }) + parse_asm(Some(name), &input).unwrap_err_to_stderr() } #[test] @@ -195,6 +188,7 @@ mod test { mod display { use number::GoldilocksField; + use parser_util::UnwrapErrToStderr; use pretty_assertions::assert_eq; use crate::parse; @@ -248,5 +242,15 @@ public out = y(%last_row);"#; ); assert_eq!(input.trim(), printed.trim()); } + + #[test] + fn array_literals() { + let input = r#"let x = [[1], [2], [(3 + 7)]];"#; + let printed = format!( + "{}", + parse::(Some("input"), input).unwrap_err_to_stderr() + ); + assert_eq!(input.trim(), printed.trim()); + } } } diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 1fd68a57a..fb744c120 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -457,6 +457,7 @@ Term: Box> = { FieldElement => Box::new(Expression::Number(<>)), StringLiteral => Box::new(Expression::String(<>)), MatchExpression, + "[" "]" => Box::new(Expression::ArrayLiteral(ArrayLiteral{items})), "(" "," ")" => { let mut list = vec![head]; list.extend(tail); Box::new(Expression::Tuple(list)) }, "(" ")", "${" "}" => Box::new(Expression::FreeInput(<>)) diff --git a/parser_util/src/lib.rs b/parser_util/src/lib.rs index d22d77848..f8a9eec9b 100644 --- a/parser_util/src/lib.rs +++ b/parser_util/src/lib.rs @@ -57,3 +57,24 @@ pub fn handle_parse_error<'a>( message: format!("{err}"), } } + +/// Convenience trait that outputs parser errors to stderr and panics. +/// Should be used mostly in tests. +pub trait UnwrapErrToStderr { + type Inner; + fn unwrap_err_to_stderr(self) -> Self::Inner; +} + +impl<'a, T> UnwrapErrToStderr for Result> { + type Inner = T; + + fn unwrap_err_to_stderr(self) -> Self::Inner { + match self { + Ok(r) => r, + Err(err) => { + err.output_to_stderr(); + panic!("Parse error."); + } + } + } +} diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index dd19c0b25..5d83d5acb 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -5,8 +5,8 @@ use std::path::{Path, PathBuf}; use analysis::MacroExpander; use ast::parsed::utils::postvisit_expression_mut; use ast::parsed::{ - self, ArrayExpression, BinaryOperator, FunctionDefinition, LambdaExpression, MatchArm, - MatchPattern, PilStatement, PolynomialName, UnaryOperator, + self, ArrayExpression, ArrayLiteral, BinaryOperator, FunctionDefinition, LambdaExpression, + MatchArm, MatchPattern, PilStatement, PolynomialName, UnaryOperator, }; use number::{DegreeType, FieldElement}; @@ -521,6 +521,11 @@ impl PILContext { PExpression::Number(n) => Expression::Number(n), PExpression::String(value) => Expression::String(value), PExpression::Tuple(items) => Expression::Tuple(self.process_expressions(items)), + PExpression::ArrayLiteral(ArrayLiteral { items }) => { + Expression::ArrayLiteral(ArrayLiteral { + items: self.process_expressions(items), + }) + } PExpression::LambdaExpression(LambdaExpression { params, body }) => { Expression::LambdaExpression(LambdaExpression { params, @@ -618,6 +623,7 @@ impl PILContext { Number(n) => Some(*n), String(_) => None, Tuple(_) => None, + ArrayLiteral(_) => None, LambdaExpression(_) => None, BinaryOperation(left, op, right) => self.evaluate_binary_operation(left, *op, right), UnaryOperation(op, value) => self.evaluate_unary_operation(*op, value),