Match expressions.

This commit is contained in:
chriseth
2023-04-05 19:43:39 +02:00
parent 604695c753
commit 12e3e5d021
12 changed files with 106 additions and 1 deletions

View File

@@ -47,6 +47,19 @@ impl Display for Expression {
Expression::UnaryOperation(op, exp) => write!(f, "{op}{exp}"),
Expression::FunctionCall(fun, args) => write!(f, "{fun}({})", format_expressions(args)),
Expression::LocalVariableReference(index) => write!(f, "${index}"),
Expression::MatchExpression(scrutinee, arms) => write!(
f,
"match {scrutinee} {{ {} }}",
arms.iter()
.map(|(n, e)| format!(
"{} => {e},",
n.as_ref()
.map(|n| n.to_string())
.unwrap_or_else(|| "_".to_string())
))
.collect::<Vec<_>>()
.join(" ")
),
}
}
}

View File

@@ -159,6 +159,10 @@ pub enum Expression {
UnaryOperation(UnaryOperator, Box<Expression>),
/// Call to a non-macro function (like a constant polynomial)
FunctionCall(String, Vec<Expression>),
MatchExpression(
Box<Expression>,
Vec<(Option<AbstractNumberType>, Expression)>,
),
}
#[derive(Debug, PartialEq, Eq, Default, Clone)]

View File

@@ -508,6 +508,12 @@ impl PILContext {
ast::Expression::FunctionCall(name, arguments) => {
Expression::FunctionCall(self.namespaced(name), self.process_expressions(arguments))
}
ast::Expression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
Box::new(self.process_expression(scrutinee)),
arms.iter()
.map(|(n, e)| (n.clone(), self.process_expression(e)))
.collect(),
),
ast::Expression::FreeInput(_) => panic!(),
}
}
@@ -576,6 +582,7 @@ impl PILContext {
ast::Expression::UnaryOperation(op, value) => self.evaluate_unary_operation(op, value),
ast::Expression::FunctionCall(_, _) => None,
ast::Expression::FreeInput(_) => panic!(),
ast::Expression::MatchExpression(_, _) => None,
}
}

View File

@@ -376,6 +376,7 @@ impl ASMPILConverter {
Expression::Number(value) => vec![(value.clone(), AffineExpressionComponent::Constant)],
Expression::String(_) => panic!(),
Expression::Tuple(_) => panic!(),
Expression::MatchExpression(_, _) => panic!(),
Expression::FreeInput(expr) => {
vec![(
1.into(),
@@ -792,6 +793,12 @@ fn substitute(input: &Expression, substitution: &HashMap<String, String>) -> Exp
| Expression::Number(_)
| Expression::String(_)
| Expression::FreeInput(_) => input.clone(),
Expression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
Box::new(substitute(scrutinee, substitution)),
arms.iter()
.map(|(n, e)| (n.clone(), substitute(e, substitution)))
.collect(),
),
}
}

View File

@@ -91,6 +91,13 @@ impl<'a> Evaluator<'a> {
let values = &self.other_constants[name.as_str()];
values[abstract_to_degree(&arg_values[0]) as usize % values.len()].clone()
}
Expression::MatchExpression(scrutinee, arms) => {
let v = self.evaluate(scrutinee);
arms.iter()
.find(|(n, _)| n.is_none() || n.as_ref() == Some(&v))
.map(|(_, e)| self.evaluate(e))
.expect("No arm matched the value {v}")
}
}
}
@@ -194,6 +201,27 @@ mod test {
);
}
#[test]
pub fn test_match() {
let src = r#"
constant %N = 8;
namespace F(%N);
pol constant X(i) { match i {
0 => 7,
3 => 9,
5 => 2,
_ => 4,
} + 1 };
"#;
let analyzed = analyze_string(src);
let (constants, degree) = generate(&analyzed);
assert_eq!(degree, 8);
assert_eq!(
constants,
vec![("F.X", convert(vec![8, 5, 5, 10, 5, 3, 5, 5]))]
);
}
#[test]
pub fn test_macro() {
let src = r#"

View File

@@ -306,6 +306,9 @@ impl<'a> Exporter<'a> {
}
Expression::String(_) => panic!("Strings not allowed here."),
Expression::Tuple(_) => panic!("Tuples not allowed here"),
Expression::MatchExpression(_, _) => {
panic!("No match expressions allowed here.")
}
}
}

View File

@@ -48,6 +48,10 @@ pub enum Expression {
UnaryOperation(UnaryOperator, Box<Expression>),
FunctionCall(String, Vec<Expression>),
FreeInput(Box<Expression>),
MatchExpression(
Box<Expression>,
Vec<(Option<AbstractNumberType>, Expression)>,
),
}
#[derive(Debug, PartialEq, Eq, Default, Clone)]

View File

@@ -153,6 +153,19 @@ impl Display for Expression {
Expression::UnaryOperation(op, exp) => write!(f, "{op}{exp}"),
Expression::FunctionCall(fun, args) => write!(f, "{fun}({})", format_expressions(args)),
Expression::FreeInput(input) => write!(f, "${{ {input} }}"),
Expression::MatchExpression(scrutinee, arms) => write!(
f,
"match {scrutinee} {{ {} }}",
arms.iter()
.map(|(n, e)| format!(
"{} => {e},",
n.as_ref()
.map(|n| n.to_string())
.unwrap_or_else(|| "_".to_string())
))
.collect::<Vec<_>>()
.join(" ")
),
}
}
}

View File

@@ -241,7 +241,7 @@ Expression: Expression = {
}
BoxedExpression: Box<Expression> = {
BinaryOr
BinaryOr,
}
BinaryOr: Box<Expression> = {
@@ -328,6 +328,7 @@ Term: Box<Expression> = {
PublicReference => Box::new(Expression::PublicReference(<>)),
Number => Box::new(Expression::Number(<>)),
StringLiteral => Box::new(Expression::String(<>)),
MatchExpression,
"(" <head:Expression> "," <tail:ExpressionList> ")" => { let mut list = vec![head]; list.extend(tail); Box::new(Expression::Tuple(list)) },
"(" <BoxedExpression> ")",
"${" <BoxedExpression> "}" => Box::new(Expression::FreeInput(<>))
@@ -348,6 +349,15 @@ PublicReference: String = {
":" <Identifier>
}
MatchExpression: Box<Expression> = {
"match" <BoxedExpression> "{" <(<MatchArm> ",")*> "}" => Box::new(Expression::MatchExpression(<>))
}
MatchArm: (Option<AbstractNumberType>, Expression) = {
<n:Number> "=>" <e:Expression> => (Some(n), e),
<n:"_"> "=>" <e:Expression> => (None, e),
}
// ---------------------------- Terminals -----------------------------

View File

@@ -47,6 +47,9 @@ impl<SV: SymbolicVariables> ExpressionEvaluator<SV> {
Expression::FunctionCall(_, _) => {
Err("Function calls not implemented.".to_string().into())
}
Expression::MatchExpression(_, _) => {
Err("Match expressions not implemented.".to_string().into())
}
}
}

View File

@@ -141,6 +141,14 @@ impl<'a> ReferenceExtractor<'a> {
Expression::BinaryOperation(l, _, r) => &self.in_expression(l) | &self.in_expression(r),
Expression::UnaryOperation(_, e) => self.in_expression(e),
Expression::FunctionCall(_, args) => self.in_expressions(args),
Expression::MatchExpression(scrutinee, arms) => {
&self.in_expression(scrutinee)
| &arms
.iter()
.map(|(_, e)| self.in_expression(e))
.reduce(|a, b| &a | &b)
.unwrap_or_default()
}
Expression::LocalVariableReference(_)
| Expression::PublicReference(_)
| Expression::Number(_)

View File

@@ -1,3 +1,5 @@
use std::iter::once;
use crate::analyzer::{Expression, PolynomialReference};
use super::FixedData;
@@ -62,6 +64,9 @@ pub fn expr_any(expr: &Expression, f: &mut impl FnMut(&Expression) -> bool) -> b
Expression::BinaryOperation(l, _, r) => expr_any(l, f) || expr_any(r, f),
Expression::UnaryOperation(_, e) => expr_any(e, f),
Expression::FunctionCall(_, args) => args.iter().any(|e| expr_any(e, f)),
Expression::MatchExpression(scrutinee, arms) => once(scrutinee.as_ref())
.chain(arms.iter().map(|(_n, e)| e))
.any(|e| expr_any(e, f)),
Expression::Constant(_)
| Expression::PolynomialReference(_)
| Expression::LocalVariableReference(_)