From 12e3e5d0217e82b08817873ccc2c1db47d906bbd Mon Sep 17 00:00:00 2001 From: chriseth Date: Wed, 5 Apr 2023 19:43:39 +0200 Subject: [PATCH] Match expressions. --- src/analyzer/display.rs | 13 +++++++++ src/analyzer/mod.rs | 4 +++ src/analyzer/pil_analyzer.rs | 7 +++++ src/asm_compiler/mod.rs | 7 +++++ src/constant_evaluator/mod.rs | 28 +++++++++++++++++++ src/json_exporter/mod.rs | 3 ++ src/parser/ast.rs | 4 +++ src/parser/display.rs | 13 +++++++++ src/parser/powdr.lalrpop | 12 +++++++- src/witness_generator/expression_evaluator.rs | 3 ++ .../machines/machine_extractor.rs | 8 ++++++ src/witness_generator/util.rs | 5 ++++ 12 files changed, 106 insertions(+), 1 deletion(-) diff --git a/src/analyzer/display.rs b/src/analyzer/display.rs index 0d7e7f385..6f7a88e1b 100644 --- a/src/analyzer/display.rs +++ b/src/analyzer/display.rs @@ -47,6 +47,19 @@ impl Display for Expression { Expression::UnaryOperation(op, exp) => write!(f, "{op}{exp}"), Expression::FunctionCall(fun, args) => write!(f, "{fun}({})", format_expressions(args)), Expression::LocalVariableReference(index) => write!(f, "${index}"), + Expression::MatchExpression(scrutinee, arms) => write!( + f, + "match {scrutinee} {{ {} }}", + arms.iter() + .map(|(n, e)| format!( + "{} => {e},", + n.as_ref() + .map(|n| n.to_string()) + .unwrap_or_else(|| "_".to_string()) + )) + .collect::>() + .join(" ") + ), } } } diff --git a/src/analyzer/mod.rs b/src/analyzer/mod.rs index 9f26b1a64..279032c5e 100644 --- a/src/analyzer/mod.rs +++ b/src/analyzer/mod.rs @@ -159,6 +159,10 @@ pub enum Expression { UnaryOperation(UnaryOperator, Box), /// Call to a non-macro function (like a constant polynomial) FunctionCall(String, Vec), + MatchExpression( + Box, + Vec<(Option, Expression)>, + ), } #[derive(Debug, PartialEq, Eq, Default, Clone)] diff --git a/src/analyzer/pil_analyzer.rs b/src/analyzer/pil_analyzer.rs index 26a7ee8ba..8435859ff 100644 --- a/src/analyzer/pil_analyzer.rs +++ b/src/analyzer/pil_analyzer.rs @@ -508,6 +508,12 @@ impl PILContext { ast::Expression::FunctionCall(name, arguments) => { Expression::FunctionCall(self.namespaced(name), self.process_expressions(arguments)) } + ast::Expression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( + Box::new(self.process_expression(scrutinee)), + arms.iter() + .map(|(n, e)| (n.clone(), self.process_expression(e))) + .collect(), + ), ast::Expression::FreeInput(_) => panic!(), } } @@ -576,6 +582,7 @@ impl PILContext { ast::Expression::UnaryOperation(op, value) => self.evaluate_unary_operation(op, value), ast::Expression::FunctionCall(_, _) => None, ast::Expression::FreeInput(_) => panic!(), + ast::Expression::MatchExpression(_, _) => None, } } diff --git a/src/asm_compiler/mod.rs b/src/asm_compiler/mod.rs index e97cd1e77..7770ffa76 100644 --- a/src/asm_compiler/mod.rs +++ b/src/asm_compiler/mod.rs @@ -376,6 +376,7 @@ impl ASMPILConverter { Expression::Number(value) => vec![(value.clone(), AffineExpressionComponent::Constant)], Expression::String(_) => panic!(), Expression::Tuple(_) => panic!(), + Expression::MatchExpression(_, _) => panic!(), Expression::FreeInput(expr) => { vec![( 1.into(), @@ -792,6 +793,12 @@ fn substitute(input: &Expression, substitution: &HashMap) -> Exp | Expression::Number(_) | Expression::String(_) | Expression::FreeInput(_) => input.clone(), + Expression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( + Box::new(substitute(scrutinee, substitution)), + arms.iter() + .map(|(n, e)| (n.clone(), substitute(e, substitution))) + .collect(), + ), } } diff --git a/src/constant_evaluator/mod.rs b/src/constant_evaluator/mod.rs index 641a1e931..93acba581 100644 --- a/src/constant_evaluator/mod.rs +++ b/src/constant_evaluator/mod.rs @@ -91,6 +91,13 @@ impl<'a> Evaluator<'a> { let values = &self.other_constants[name.as_str()]; values[abstract_to_degree(&arg_values[0]) as usize % values.len()].clone() } + Expression::MatchExpression(scrutinee, arms) => { + let v = self.evaluate(scrutinee); + arms.iter() + .find(|(n, _)| n.is_none() || n.as_ref() == Some(&v)) + .map(|(_, e)| self.evaluate(e)) + .expect("No arm matched the value {v}") + } } } @@ -194,6 +201,27 @@ mod test { ); } + #[test] + pub fn test_match() { + let src = r#" + constant %N = 8; + namespace F(%N); + pol constant X(i) { match i { + 0 => 7, + 3 => 9, + 5 => 2, + _ => 4, + } + 1 }; + "#; + let analyzed = analyze_string(src); + let (constants, degree) = generate(&analyzed); + assert_eq!(degree, 8); + assert_eq!( + constants, + vec![("F.X", convert(vec![8, 5, 5, 10, 5, 3, 5, 5]))] + ); + } + #[test] pub fn test_macro() { let src = r#" diff --git a/src/json_exporter/mod.rs b/src/json_exporter/mod.rs index d385a6553..50ec01c65 100644 --- a/src/json_exporter/mod.rs +++ b/src/json_exporter/mod.rs @@ -306,6 +306,9 @@ impl<'a> Exporter<'a> { } Expression::String(_) => panic!("Strings not allowed here."), Expression::Tuple(_) => panic!("Tuples not allowed here"), + Expression::MatchExpression(_, _) => { + panic!("No match expressions allowed here.") + } } } diff --git a/src/parser/ast.rs b/src/parser/ast.rs index f6d5048d1..59d1edae1 100644 --- a/src/parser/ast.rs +++ b/src/parser/ast.rs @@ -48,6 +48,10 @@ pub enum Expression { UnaryOperation(UnaryOperator, Box), FunctionCall(String, Vec), FreeInput(Box), + MatchExpression( + Box, + Vec<(Option, Expression)>, + ), } #[derive(Debug, PartialEq, Eq, Default, Clone)] diff --git a/src/parser/display.rs b/src/parser/display.rs index 62bf51ff6..0c61e0733 100644 --- a/src/parser/display.rs +++ b/src/parser/display.rs @@ -153,6 +153,19 @@ impl Display for Expression { Expression::UnaryOperation(op, exp) => write!(f, "{op}{exp}"), Expression::FunctionCall(fun, args) => write!(f, "{fun}({})", format_expressions(args)), Expression::FreeInput(input) => write!(f, "${{ {input} }}"), + Expression::MatchExpression(scrutinee, arms) => write!( + f, + "match {scrutinee} {{ {} }}", + arms.iter() + .map(|(n, e)| format!( + "{} => {e},", + n.as_ref() + .map(|n| n.to_string()) + .unwrap_or_else(|| "_".to_string()) + )) + .collect::>() + .join(" ") + ), } } } diff --git a/src/parser/powdr.lalrpop b/src/parser/powdr.lalrpop index 5b4c4b7ea..605e5e433 100644 --- a/src/parser/powdr.lalrpop +++ b/src/parser/powdr.lalrpop @@ -241,7 +241,7 @@ Expression: Expression = { } BoxedExpression: Box = { - BinaryOr + BinaryOr, } BinaryOr: Box = { @@ -328,6 +328,7 @@ Term: Box = { PublicReference => Box::new(Expression::PublicReference(<>)), Number => Box::new(Expression::Number(<>)), StringLiteral => Box::new(Expression::String(<>)), + MatchExpression, "(" "," ")" => { let mut list = vec![head]; list.extend(tail); Box::new(Expression::Tuple(list)) }, "(" ")", "${" "}" => Box::new(Expression::FreeInput(<>)) @@ -348,6 +349,15 @@ PublicReference: String = { ":" } +MatchExpression: Box = { + "match" "{" <( ",")*> "}" => Box::new(Expression::MatchExpression(<>)) +} + +MatchArm: (Option, Expression) = { + "=>" => (Some(n), e), + "=>" => (None, e), +} + // ---------------------------- Terminals ----------------------------- diff --git a/src/witness_generator/expression_evaluator.rs b/src/witness_generator/expression_evaluator.rs index 400642288..2c3382ac5 100644 --- a/src/witness_generator/expression_evaluator.rs +++ b/src/witness_generator/expression_evaluator.rs @@ -47,6 +47,9 @@ impl ExpressionEvaluator { Expression::FunctionCall(_, _) => { Err("Function calls not implemented.".to_string().into()) } + Expression::MatchExpression(_, _) => { + Err("Match expressions not implemented.".to_string().into()) + } } } diff --git a/src/witness_generator/machines/machine_extractor.rs b/src/witness_generator/machines/machine_extractor.rs index dc61bfc1d..2cedf47af 100644 --- a/src/witness_generator/machines/machine_extractor.rs +++ b/src/witness_generator/machines/machine_extractor.rs @@ -141,6 +141,14 @@ impl<'a> ReferenceExtractor<'a> { Expression::BinaryOperation(l, _, r) => &self.in_expression(l) | &self.in_expression(r), Expression::UnaryOperation(_, e) => self.in_expression(e), Expression::FunctionCall(_, args) => self.in_expressions(args), + Expression::MatchExpression(scrutinee, arms) => { + &self.in_expression(scrutinee) + | &arms + .iter() + .map(|(_, e)| self.in_expression(e)) + .reduce(|a, b| &a | &b) + .unwrap_or_default() + } Expression::LocalVariableReference(_) | Expression::PublicReference(_) | Expression::Number(_) diff --git a/src/witness_generator/util.rs b/src/witness_generator/util.rs index c2942f27d..76fe6476e 100644 --- a/src/witness_generator/util.rs +++ b/src/witness_generator/util.rs @@ -1,3 +1,5 @@ +use std::iter::once; + use crate::analyzer::{Expression, PolynomialReference}; use super::FixedData; @@ -62,6 +64,9 @@ pub fn expr_any(expr: &Expression, f: &mut impl FnMut(&Expression) -> bool) -> b Expression::BinaryOperation(l, _, r) => expr_any(l, f) || expr_any(r, f), Expression::UnaryOperation(_, e) => expr_any(e, f), Expression::FunctionCall(_, args) => args.iter().any(|e| expr_any(e, f)), + Expression::MatchExpression(scrutinee, arms) => once(scrutinee.as_ref()) + .chain(arms.iter().map(|(_n, e)| e)) + .any(|e| expr_any(e, f)), Expression::Constant(_) | Expression::PolynomialReference(_) | Expression::LocalVariableReference(_)