diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index 88f0d1558..68cf46dbc 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -595,6 +595,7 @@ impl ASMPILConverter { Expression::Tuple(_) => panic!(), Expression::ArrayLiteral(_) => panic!(), Expression::MatchExpression(_, _) => panic!(), + Expression::IfExpression(_) => panic!(), Expression::FreeInput(expr) => { vec![(1.into(), AffineExpressionComponent::FreeInput(*expr))] } diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 8e1ba411c..ac23386dc 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -315,6 +315,16 @@ impl Display for MatchPattern { } } +impl Display for IfExpression { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!( + f, + "if {} {{ {} }} else {{ {} }}", + self.condition, self.body, self.else_body + ) + } +} + impl Display for Param { fn fmt(&self, f: &mut Formatter<'_>) -> Result { write!( @@ -464,6 +474,7 @@ impl Display for Expression { Expression::MatchExpression(scrutinee, arms) => { write!(f, "match {scrutinee} {{ {} }}", arms.iter().format(" ")) } + Expression::IfExpression(e) => write!(f, "{e}"), } } } diff --git a/ast/src/parsed/folder.rs b/ast/src/parsed/folder.rs index 8dfc759f3..614eddb3a 100644 --- a/ast/src/parsed/folder.rs +++ b/ast/src/parsed/folder.rs @@ -3,7 +3,8 @@ use super::{ ASMModule, ASMProgram, Import, Machine, Module, ModuleStatement, SymbolDefinition, SymbolValue, }, - ArrayLiteral, Expression, FunctionCall, IndexAccess, LambdaExpression, MatchArm, MatchPattern, + ArrayLiteral, Expression, FunctionCall, IfExpression, IndexAccess, LambdaExpression, MatchArm, + MatchPattern, }; pub trait Folder { @@ -94,6 +95,9 @@ pub trait ExpressionFolder { .map(|a| self.fold_match_arm(a)) .collect::>()?, ), + Expression::IfExpression(if_expr) => { + Expression::IfExpression(self.fold_if_expression(if_expr)?) + } }) } @@ -154,6 +158,21 @@ pub trait ExpressionFolder { }) } + fn fold_if_expression( + &mut self, + IfExpression { + condition, + body, + else_body, + }: IfExpression, + ) -> Result, Self::Error> { + Ok(IfExpression { + condition: self.fold_boxed_expression(*condition)?, + body: self.fold_boxed_expression(*body)?, + else_body: self.fold_boxed_expression(*else_body)?, + }) + } + fn fold_boxed_expression( &mut self, e: Expression, diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 402ac75dd..26657af85 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -84,6 +84,7 @@ pub enum Expression { FunctionCall(FunctionCall), FreeInput(Box>), MatchExpression(Box>, Vec>), + IfExpression(IfExpression), } impl Expression { @@ -241,6 +242,13 @@ pub enum MatchPattern { Pattern(Expression), } +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct IfExpression { + pub condition: Box>, + pub body: Box>, + pub else_body: Box>, +} + /// The definition of a function (excluding its name): #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub enum FunctionDefinition { diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index 242565413..8fa9e7bfa 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -1,9 +1,9 @@ use std::{iter::once, ops::ControlFlow}; use super::{ - ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IndexAccess, - LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference, PilStatement, - SelectedExpressions, + ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IfExpression, + IndexAccess, LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference, + PilStatement, SelectedExpressions, }; #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -137,6 +137,7 @@ impl ExpressionVisitable> for Expression { arms.iter_mut() .try_for_each(|arm| arm.visit_expressions_mut(f, o))?; } + Expression::IfExpression(if_expr) => if_expr.visit_expressions_mut(f, o)?, }; if o == VisitOrder::Post { f(self)?; @@ -175,6 +176,7 @@ impl ExpressionVisitable> for Expression { arms.iter() .try_for_each(|arm| arm.visit_expressions(f, o))?; } + Expression::IfExpression(if_expr) => if_expr.visit_expressions(f, o)?, }; if o == VisitOrder::Post { f(self)?; @@ -454,3 +456,23 @@ impl ExpressionVisitable> for MatchPattern { } } } + +impl ExpressionVisitable> for IfExpression { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut Expression) -> ControlFlow, + { + [&mut self.condition, &mut self.body, &mut self.else_body] + .into_iter() + .try_for_each(|e| e.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&Expression) -> ControlFlow, + { + [&self.condition, &self.body, &self.else_body] + .into_iter() + .try_for_each(|e| e.visit_expressions(f, o)) + } +} diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index e58845ffb..aa8045515 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -227,6 +227,22 @@ mod test { ); } + #[test] + pub fn test_if() { + let src = r#" + constant %N = 8; + namespace F(%N); + let X = |i| if i < 3 { 7 } else { 9 }; + "#; + let analyzed = analyze_string(src); + assert_eq!(analyzed.degree(), 8); + let constants = generate(&analyzed); + assert_eq!( + constants, + vec![("F.X", convert(vec![7, 7, 7, 9, 9, 9, 9, 9]))] + ); + } + #[test] pub fn test_macro() { let src = r#" diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 5295dc10b..4682c1966 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -477,6 +477,7 @@ Term: Box> = { FieldElement => Box::new(Expression::Number(<>)), StringLiteral => Box::new(Expression::String(<>)), MatchExpression, + IfExpression, "[" "]" => Box::new(Expression::ArrayLiteral(ArrayLiteral{items})), "(" "," ")" => { let mut list = vec![head]; list.extend(tail); Box::new(Expression::Tuple(list)) }, "(" ")", @@ -517,6 +518,13 @@ MatchPattern: MatchPattern = { Expression => MatchPattern::Pattern(<>), } +IfExpression: Box> = { + "if" + "{" "}" + "else" + "{" "}" => Box::new(Expression::IfExpression(IfExpression{<>})) +} + // ---------------------------- Terminals ----------------------------- diff --git a/pil_analyzer/src/evaluator.rs b/pil_analyzer/src/evaluator.rs index 62726fd34..3924ccaa8 100644 --- a/pil_analyzer/src/evaluator.rs +++ b/pil_analyzer/src/evaluator.rs @@ -322,6 +322,15 @@ mod internal { .ok_or_else(EvalError::NoMatch)?; evaluate(body, locals, symbols)? } + Expression::IfExpression(if_expr) => { + let v = evaluate(&if_expr.condition, locals, symbols)?.try_to_number()?; + let body = if !v.is_zero() { + &if_expr.body + } else { + &if_expr.else_body + }; + evaluate(body.as_ref(), locals, symbols)? + } Expression::FreeInput(_) => Err(EvalError::Unsupported( "Cannot evaluate free input.".to_string(), ))?, diff --git a/pil_analyzer/src/expression_processor.rs b/pil_analyzer/src/expression_processor.rs index ef0f665b5..7c9754971 100644 --- a/pil_analyzer/src/expression_processor.rs +++ b/pil_analyzer/src/expression_processor.rs @@ -3,8 +3,8 @@ use std::collections::HashMap; use ast::{ analyzed::{Expression, PolynomialReference, Reference, RepeatedArray}, parsed::{ - self, ArrayExpression, ArrayLiteral, LambdaExpression, MatchArm, MatchPattern, - NamespacedPolynomialReference, SelectedExpressions, + self, ArrayExpression, ArrayLiteral, IfExpression, LambdaExpression, MatchArm, + MatchPattern, NamespacedPolynomialReference, SelectedExpressions, }, }; use number::DegreeType; @@ -129,6 +129,15 @@ impl ExpressionProcessor { }) .collect(), ), + PExpression::IfExpression(IfExpression { + condition, + body, + else_body, + }) => Expression::IfExpression(IfExpression { + condition: Box::new(self.process_expression(*condition)), + body: Box::new(self.process_expression(*body)), + else_body: Box::new(self.process_expression(*else_body)), + }), PExpression::FreeInput(_) => panic!(), } } diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 64668e0fc..bdd1e2550 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -630,6 +630,17 @@ namespace N(65536); assert_eq!(formatted, input); } + #[test] + fn if_expr() { + let input = r#"namespace Assembly(2); + col fixed A = [0]*; + col fixed C(i) { if (i < 3) { Assembly.A(i) } else { (i + 9) } }; + col fixed D(i) { if Assembly.C(i) { 3 } else { 2 } }; +"#; + let formatted = process_pil_file_contents::(input).to_string(); + assert_eq!(formatted, input); + } + #[test] fn symbolic_functions() { let input = r#"namespace N(16); diff --git a/riscv_executor/src/lib.rs b/riscv_executor/src/lib.rs index 4501628b8..a14387832 100644 --- a/riscv_executor/src/lib.rs +++ b/riscv_executor/src/lib.rs @@ -703,6 +703,7 @@ impl<'a, 'b, F: FieldElement> Executor<'a, 'b, F> { panic!("does not matched IO pattern") } Expression::MatchExpression(_, _) => todo!(), + Expression::IfExpression(_) => panic!(), Expression::IndexAccess(_) => todo!(), } }