Implement if expression.

This commit is contained in:
chriseth
2023-11-22 19:41:07 +01:00
parent 13a62df359
commit 53301117bc
11 changed files with 121 additions and 6 deletions

View File

@@ -595,6 +595,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
Expression::Tuple(_) => panic!(),
Expression::ArrayLiteral(_) => panic!(),
Expression::MatchExpression(_, _) => panic!(),
Expression::IfExpression(_) => panic!(),
Expression::FreeInput(expr) => {
vec![(1.into(), AffineExpressionComponent::FreeInput(*expr))]
}

View File

@@ -315,6 +315,16 @@ impl<T: Display, Ref: Display> Display for MatchPattern<T, Ref> {
}
}
impl<T: Display, Ref: Display> Display for IfExpression<T, Ref> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(
f,
"if {} {{ {} }} else {{ {} }}",
self.condition, self.body, self.else_body
)
}
}
impl<T: Display> Display for Param<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(
@@ -464,6 +474,7 @@ impl<T: Display, Ref: Display> Display for Expression<T, Ref> {
Expression::MatchExpression(scrutinee, arms) => {
write!(f, "match {scrutinee} {{ {} }}", arms.iter().format(" "))
}
Expression::IfExpression(e) => write!(f, "{e}"),
}
}
}

View File

@@ -3,7 +3,8 @@ use super::{
ASMModule, ASMProgram, Import, Machine, Module, ModuleStatement, SymbolDefinition,
SymbolValue,
},
ArrayLiteral, Expression, FunctionCall, IndexAccess, LambdaExpression, MatchArm, MatchPattern,
ArrayLiteral, Expression, FunctionCall, IfExpression, IndexAccess, LambdaExpression, MatchArm,
MatchPattern,
};
pub trait Folder<T> {
@@ -94,6 +95,9 @@ pub trait ExpressionFolder<T, Ref> {
.map(|a| self.fold_match_arm(a))
.collect::<Result<_, _>>()?,
),
Expression::IfExpression(if_expr) => {
Expression::IfExpression(self.fold_if_expression(if_expr)?)
}
})
}
@@ -154,6 +158,21 @@ pub trait ExpressionFolder<T, Ref> {
})
}
fn fold_if_expression(
&mut self,
IfExpression {
condition,
body,
else_body,
}: IfExpression<T, Ref>,
) -> Result<IfExpression<T, Ref>, Self::Error> {
Ok(IfExpression {
condition: self.fold_boxed_expression(*condition)?,
body: self.fold_boxed_expression(*body)?,
else_body: self.fold_boxed_expression(*else_body)?,
})
}
fn fold_boxed_expression(
&mut self,
e: Expression<T, Ref>,

View File

@@ -84,6 +84,7 @@ pub enum Expression<T, Ref = NamespacedPolynomialReference> {
FunctionCall(FunctionCall<T, Ref>),
FreeInput(Box<Expression<T, Ref>>),
MatchExpression(Box<Expression<T, Ref>>, Vec<MatchArm<T, Ref>>),
IfExpression(IfExpression<T, Ref>),
}
impl<T, Ref> Expression<T, Ref> {
@@ -241,6 +242,13 @@ pub enum MatchPattern<T, Ref = NamespacedPolynomialReference> {
Pattern(Expression<T, Ref>),
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct IfExpression<T, Ref = NamespacedPolynomialReference> {
pub condition: Box<Expression<T, Ref>>,
pub body: Box<Expression<T, Ref>>,
pub else_body: Box<Expression<T, Ref>>,
}
/// The definition of a function (excluding its name):
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum FunctionDefinition<T> {

View File

@@ -1,9 +1,9 @@
use std::{iter::once, ops::ControlFlow};
use super::{
ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IndexAccess,
LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference, PilStatement,
SelectedExpressions,
ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IfExpression,
IndexAccess, LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference,
PilStatement, SelectedExpressions,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -137,6 +137,7 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for Expression<T, Ref> {
arms.iter_mut()
.try_for_each(|arm| arm.visit_expressions_mut(f, o))?;
}
Expression::IfExpression(if_expr) => if_expr.visit_expressions_mut(f, o)?,
};
if o == VisitOrder::Post {
f(self)?;
@@ -175,6 +176,7 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for Expression<T, Ref> {
arms.iter()
.try_for_each(|arm| arm.visit_expressions(f, o))?;
}
Expression::IfExpression(if_expr) => if_expr.visit_expressions(f, o)?,
};
if o == VisitOrder::Post {
f(self)?;
@@ -454,3 +456,23 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for MatchPattern<T, Ref> {
}
}
}
impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for IfExpression<T, Ref> {
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&mut Expression<T, Ref>) -> ControlFlow<B>,
{
[&mut self.condition, &mut self.body, &mut self.else_body]
.into_iter()
.try_for_each(|e| e.visit_expressions_mut(f, o))
}
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&Expression<T, Ref>) -> ControlFlow<B>,
{
[&self.condition, &self.body, &self.else_body]
.into_iter()
.try_for_each(|e| e.visit_expressions(f, o))
}
}

View File

@@ -227,6 +227,22 @@ mod test {
);
}
#[test]
pub fn test_if() {
let src = r#"
constant %N = 8;
namespace F(%N);
let X = |i| if i < 3 { 7 } else { 9 };
"#;
let analyzed = analyze_string(src);
assert_eq!(analyzed.degree(), 8);
let constants = generate(&analyzed);
assert_eq!(
constants,
vec![("F.X", convert(vec![7, 7, 7, 9, 9, 9, 9, 9]))]
);
}
#[test]
pub fn test_macro() {
let src = r#"

View File

@@ -477,6 +477,7 @@ Term: Box<Expression<T>> = {
FieldElement => Box::new(Expression::Number(<>)),
StringLiteral => Box::new(Expression::String(<>)),
MatchExpression,
IfExpression,
"[" <items:ExpressionList> "]" => Box::new(Expression::ArrayLiteral(ArrayLiteral{items})),
"(" <head:Expression> "," <tail:ExpressionList> ")" => { let mut list = vec![head]; list.extend(tail); Box::new(Expression::Tuple(list)) },
"(" <BoxedExpression> ")",
@@ -517,6 +518,13 @@ MatchPattern: MatchPattern<T> = {
Expression => MatchPattern::Pattern(<>),
}
IfExpression: Box<Expression<T>> = {
"if" <condition:BoxedExpression>
"{" <body:BoxedExpression> "}"
"else"
"{" <else_body:BoxedExpression> "}" => Box::new(Expression::IfExpression(IfExpression{<>}))
}
// ---------------------------- Terminals -----------------------------

View File

@@ -322,6 +322,15 @@ mod internal {
.ok_or_else(EvalError::NoMatch)?;
evaluate(body, locals, symbols)?
}
Expression::IfExpression(if_expr) => {
let v = evaluate(&if_expr.condition, locals, symbols)?.try_to_number()?;
let body = if !v.is_zero() {
&if_expr.body
} else {
&if_expr.else_body
};
evaluate(body.as_ref(), locals, symbols)?
}
Expression::FreeInput(_) => Err(EvalError::Unsupported(
"Cannot evaluate free input.".to_string(),
))?,

View File

@@ -3,8 +3,8 @@ use std::collections::HashMap;
use ast::{
analyzed::{Expression, PolynomialReference, Reference, RepeatedArray},
parsed::{
self, ArrayExpression, ArrayLiteral, LambdaExpression, MatchArm, MatchPattern,
NamespacedPolynomialReference, SelectedExpressions,
self, ArrayExpression, ArrayLiteral, IfExpression, LambdaExpression, MatchArm,
MatchPattern, NamespacedPolynomialReference, SelectedExpressions,
},
};
use number::DegreeType;
@@ -129,6 +129,15 @@ impl<R: ReferenceResolver> ExpressionProcessor<R> {
})
.collect(),
),
PExpression::IfExpression(IfExpression {
condition,
body,
else_body,
}) => Expression::IfExpression(IfExpression {
condition: Box::new(self.process_expression(*condition)),
body: Box::new(self.process_expression(*body)),
else_body: Box::new(self.process_expression(*else_body)),
}),
PExpression::FreeInput(_) => panic!(),
}
}

View File

@@ -630,6 +630,17 @@ namespace N(65536);
assert_eq!(formatted, input);
}
#[test]
fn if_expr() {
let input = r#"namespace Assembly(2);
col fixed A = [0]*;
col fixed C(i) { if (i < 3) { Assembly.A(i) } else { (i + 9) } };
col fixed D(i) { if Assembly.C(i) { 3 } else { 2 } };
"#;
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
assert_eq!(formatted, input);
}
#[test]
fn symbolic_functions() {
let input = r#"namespace N(16);

View File

@@ -703,6 +703,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> {
panic!("does not matched IO pattern")
}
Expression::MatchExpression(_, _) => todo!(),
Expression::IfExpression(_) => panic!(),
Expression::IndexAccess(_) => todo!(),
}
}