mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-04-20 03:03:25 -04:00
Implement if expression.
This commit is contained in:
@@ -595,6 +595,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
|
||||
Expression::Tuple(_) => panic!(),
|
||||
Expression::ArrayLiteral(_) => panic!(),
|
||||
Expression::MatchExpression(_, _) => panic!(),
|
||||
Expression::IfExpression(_) => panic!(),
|
||||
Expression::FreeInput(expr) => {
|
||||
vec![(1.into(), AffineExpressionComponent::FreeInput(*expr))]
|
||||
}
|
||||
|
||||
@@ -315,6 +315,16 @@ impl<T: Display, Ref: Display> Display for MatchPattern<T, Ref> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Display, Ref: Display> Display for IfExpression<T, Ref> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
write!(
|
||||
f,
|
||||
"if {} {{ {} }} else {{ {} }}",
|
||||
self.condition, self.body, self.else_body
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Display> Display for Param<T> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
write!(
|
||||
@@ -464,6 +474,7 @@ impl<T: Display, Ref: Display> Display for Expression<T, Ref> {
|
||||
Expression::MatchExpression(scrutinee, arms) => {
|
||||
write!(f, "match {scrutinee} {{ {} }}", arms.iter().format(" "))
|
||||
}
|
||||
Expression::IfExpression(e) => write!(f, "{e}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ use super::{
|
||||
ASMModule, ASMProgram, Import, Machine, Module, ModuleStatement, SymbolDefinition,
|
||||
SymbolValue,
|
||||
},
|
||||
ArrayLiteral, Expression, FunctionCall, IndexAccess, LambdaExpression, MatchArm, MatchPattern,
|
||||
ArrayLiteral, Expression, FunctionCall, IfExpression, IndexAccess, LambdaExpression, MatchArm,
|
||||
MatchPattern,
|
||||
};
|
||||
|
||||
pub trait Folder<T> {
|
||||
@@ -94,6 +95,9 @@ pub trait ExpressionFolder<T, Ref> {
|
||||
.map(|a| self.fold_match_arm(a))
|
||||
.collect::<Result<_, _>>()?,
|
||||
),
|
||||
Expression::IfExpression(if_expr) => {
|
||||
Expression::IfExpression(self.fold_if_expression(if_expr)?)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -154,6 +158,21 @@ pub trait ExpressionFolder<T, Ref> {
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_if_expression(
|
||||
&mut self,
|
||||
IfExpression {
|
||||
condition,
|
||||
body,
|
||||
else_body,
|
||||
}: IfExpression<T, Ref>,
|
||||
) -> Result<IfExpression<T, Ref>, Self::Error> {
|
||||
Ok(IfExpression {
|
||||
condition: self.fold_boxed_expression(*condition)?,
|
||||
body: self.fold_boxed_expression(*body)?,
|
||||
else_body: self.fold_boxed_expression(*else_body)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn fold_boxed_expression(
|
||||
&mut self,
|
||||
e: Expression<T, Ref>,
|
||||
|
||||
@@ -84,6 +84,7 @@ pub enum Expression<T, Ref = NamespacedPolynomialReference> {
|
||||
FunctionCall(FunctionCall<T, Ref>),
|
||||
FreeInput(Box<Expression<T, Ref>>),
|
||||
MatchExpression(Box<Expression<T, Ref>>, Vec<MatchArm<T, Ref>>),
|
||||
IfExpression(IfExpression<T, Ref>),
|
||||
}
|
||||
|
||||
impl<T, Ref> Expression<T, Ref> {
|
||||
@@ -241,6 +242,13 @@ pub enum MatchPattern<T, Ref = NamespacedPolynomialReference> {
|
||||
Pattern(Expression<T, Ref>),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct IfExpression<T, Ref = NamespacedPolynomialReference> {
|
||||
pub condition: Box<Expression<T, Ref>>,
|
||||
pub body: Box<Expression<T, Ref>>,
|
||||
pub else_body: Box<Expression<T, Ref>>,
|
||||
}
|
||||
|
||||
/// The definition of a function (excluding its name):
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub enum FunctionDefinition<T> {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::{iter::once, ops::ControlFlow};
|
||||
|
||||
use super::{
|
||||
ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IndexAccess,
|
||||
LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference, PilStatement,
|
||||
SelectedExpressions,
|
||||
ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IfExpression,
|
||||
IndexAccess, LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference,
|
||||
PilStatement, SelectedExpressions,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -137,6 +137,7 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for Expression<T, Ref> {
|
||||
arms.iter_mut()
|
||||
.try_for_each(|arm| arm.visit_expressions_mut(f, o))?;
|
||||
}
|
||||
Expression::IfExpression(if_expr) => if_expr.visit_expressions_mut(f, o)?,
|
||||
};
|
||||
if o == VisitOrder::Post {
|
||||
f(self)?;
|
||||
@@ -175,6 +176,7 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for Expression<T, Ref> {
|
||||
arms.iter()
|
||||
.try_for_each(|arm| arm.visit_expressions(f, o))?;
|
||||
}
|
||||
Expression::IfExpression(if_expr) => if_expr.visit_expressions(f, o)?,
|
||||
};
|
||||
if o == VisitOrder::Post {
|
||||
f(self)?;
|
||||
@@ -454,3 +456,23 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for MatchPattern<T, Ref> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for IfExpression<T, Ref> {
|
||||
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut Expression<T, Ref>) -> ControlFlow<B>,
|
||||
{
|
||||
[&mut self.condition, &mut self.body, &mut self.else_body]
|
||||
.into_iter()
|
||||
.try_for_each(|e| e.visit_expressions_mut(f, o))
|
||||
}
|
||||
|
||||
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&Expression<T, Ref>) -> ControlFlow<B>,
|
||||
{
|
||||
[&self.condition, &self.body, &self.else_body]
|
||||
.into_iter()
|
||||
.try_for_each(|e| e.visit_expressions(f, o))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,6 +227,22 @@ mod test {
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_if() {
|
||||
let src = r#"
|
||||
constant %N = 8;
|
||||
namespace F(%N);
|
||||
let X = |i| if i < 3 { 7 } else { 9 };
|
||||
"#;
|
||||
let analyzed = analyze_string(src);
|
||||
assert_eq!(analyzed.degree(), 8);
|
||||
let constants = generate(&analyzed);
|
||||
assert_eq!(
|
||||
constants,
|
||||
vec![("F.X", convert(vec![7, 7, 7, 9, 9, 9, 9, 9]))]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_macro() {
|
||||
let src = r#"
|
||||
|
||||
@@ -477,6 +477,7 @@ Term: Box<Expression<T>> = {
|
||||
FieldElement => Box::new(Expression::Number(<>)),
|
||||
StringLiteral => Box::new(Expression::String(<>)),
|
||||
MatchExpression,
|
||||
IfExpression,
|
||||
"[" <items:ExpressionList> "]" => Box::new(Expression::ArrayLiteral(ArrayLiteral{items})),
|
||||
"(" <head:Expression> "," <tail:ExpressionList> ")" => { let mut list = vec![head]; list.extend(tail); Box::new(Expression::Tuple(list)) },
|
||||
"(" <BoxedExpression> ")",
|
||||
@@ -517,6 +518,13 @@ MatchPattern: MatchPattern<T> = {
|
||||
Expression => MatchPattern::Pattern(<>),
|
||||
}
|
||||
|
||||
IfExpression: Box<Expression<T>> = {
|
||||
"if" <condition:BoxedExpression>
|
||||
"{" <body:BoxedExpression> "}"
|
||||
"else"
|
||||
"{" <else_body:BoxedExpression> "}" => Box::new(Expression::IfExpression(IfExpression{<>}))
|
||||
}
|
||||
|
||||
// ---------------------------- Terminals -----------------------------
|
||||
|
||||
|
||||
|
||||
@@ -322,6 +322,15 @@ mod internal {
|
||||
.ok_or_else(EvalError::NoMatch)?;
|
||||
evaluate(body, locals, symbols)?
|
||||
}
|
||||
Expression::IfExpression(if_expr) => {
|
||||
let v = evaluate(&if_expr.condition, locals, symbols)?.try_to_number()?;
|
||||
let body = if !v.is_zero() {
|
||||
&if_expr.body
|
||||
} else {
|
||||
&if_expr.else_body
|
||||
};
|
||||
evaluate(body.as_ref(), locals, symbols)?
|
||||
}
|
||||
Expression::FreeInput(_) => Err(EvalError::Unsupported(
|
||||
"Cannot evaluate free input.".to_string(),
|
||||
))?,
|
||||
|
||||
@@ -3,8 +3,8 @@ use std::collections::HashMap;
|
||||
use ast::{
|
||||
analyzed::{Expression, PolynomialReference, Reference, RepeatedArray},
|
||||
parsed::{
|
||||
self, ArrayExpression, ArrayLiteral, LambdaExpression, MatchArm, MatchPattern,
|
||||
NamespacedPolynomialReference, SelectedExpressions,
|
||||
self, ArrayExpression, ArrayLiteral, IfExpression, LambdaExpression, MatchArm,
|
||||
MatchPattern, NamespacedPolynomialReference, SelectedExpressions,
|
||||
},
|
||||
};
|
||||
use number::DegreeType;
|
||||
@@ -129,6 +129,15 @@ impl<R: ReferenceResolver> ExpressionProcessor<R> {
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
PExpression::IfExpression(IfExpression {
|
||||
condition,
|
||||
body,
|
||||
else_body,
|
||||
}) => Expression::IfExpression(IfExpression {
|
||||
condition: Box::new(self.process_expression(*condition)),
|
||||
body: Box::new(self.process_expression(*body)),
|
||||
else_body: Box::new(self.process_expression(*else_body)),
|
||||
}),
|
||||
PExpression::FreeInput(_) => panic!(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -630,6 +630,17 @@ namespace N(65536);
|
||||
assert_eq!(formatted, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn if_expr() {
|
||||
let input = r#"namespace Assembly(2);
|
||||
col fixed A = [0]*;
|
||||
col fixed C(i) { if (i < 3) { Assembly.A(i) } else { (i + 9) } };
|
||||
col fixed D(i) { if Assembly.C(i) { 3 } else { 2 } };
|
||||
"#;
|
||||
let formatted = process_pil_file_contents::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn symbolic_functions() {
|
||||
let input = r#"namespace N(16);
|
||||
|
||||
@@ -703,6 +703,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> {
|
||||
panic!("does not matched IO pattern")
|
||||
}
|
||||
Expression::MatchExpression(_, _) => todo!(),
|
||||
Expression::IfExpression(_) => panic!(),
|
||||
Expression::IndexAccess(_) => todo!(),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user