diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 58527ae80..c702d3351 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -332,6 +332,7 @@ impl Display for Pattern { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { Pattern::CatchAll => write!(f, "_"), + Pattern::Ellipsis => write!(f, ".."), Pattern::Number(n) => write!(f, "{n}"), Pattern::String(s) => write!(f, "{}", quote(s)), Pattern::Tuple(t) => write!(f, "({})", t.iter().format(", ")), diff --git a/ast/src/parsed/folder.rs b/ast/src/parsed/folder.rs index f6771d6ea..ef965685b 100644 --- a/ast/src/parsed/folder.rs +++ b/ast/src/parsed/folder.rs @@ -170,9 +170,11 @@ pub trait ExpressionFolder { fn fold_pattern(&mut self, pattern: Pattern) -> Result { Ok(match pattern { - Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) | Pattern::Variable(_) => { - pattern - } + Pattern::CatchAll + | Pattern::Ellipsis + | Pattern::Number(_) + | Pattern::String(_) + | Pattern::Variable(_) => pattern, Pattern::Tuple(t) => Pattern::Tuple( t.into_iter() .map(|p| self.fold_pattern(p)) diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 4ec7f40b4..c862a96e9 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -824,7 +824,8 @@ impl Children for ArrayExpression { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)] pub enum Pattern { - CatchAll, + CatchAll, // "_", matches a single value + Ellipsis, // "..", matches a series of values, only valid inside array patterns #[schemars(skip)] Number(BigInt), String(String), @@ -846,18 +847,22 @@ impl Pattern { impl Children for Pattern { fn children(&self) -> Box + '_> { match self { - Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) | Pattern::Variable(_) => { - Box::new(empty()) - } + Pattern::CatchAll + | Pattern::Ellipsis + | Pattern::Number(_) + | Pattern::String(_) + | Pattern::Variable(_) => Box::new(empty()), Pattern::Tuple(p) | Pattern::Array(p) => Box::new(p.iter()), } } fn children_mut(&mut self) -> Box + '_> { match self { - Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) | Pattern::Variable(_) => { - Box::new(empty()) - } + Pattern::CatchAll + | Pattern::Ellipsis + | Pattern::Number(_) + | Pattern::String(_) + | Pattern::Variable(_) => Box::new(empty()), Pattern::Tuple(p) | Pattern::Array(p) => Box::new(p.iter_mut()), } } diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 82cd2fddc..08d87eec0 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -628,10 +628,15 @@ Pattern: Pattern = { StringLiteral => Pattern::String(<>), TuplePattern, ArrayPattern, - VariablePattern, + Identifier => Pattern::Variable(<>), //EnumPattern, } +PatternIncludingEllipsis: Pattern = { + Pattern => <>, + ".." => Pattern::Ellipsis, +} + TuplePattern: Pattern = { "(" ")" => Pattern::Tuple(vec![]), "(" ",")+> ")" => Pattern::Tuple({items.push(last); items}) @@ -639,11 +644,7 @@ TuplePattern: Pattern = { ArrayPattern: Pattern = { "[" "]" => Pattern::Array(vec![]), - "[" ",")*> "]" => Pattern::Array({items.push(last); items}) -} - -VariablePattern: Pattern = { - => Pattern::Variable(<>), + "[" ",")*> "]" => Pattern::Array({items.push(last); items}) } diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index 781f817fc..9fa9bf1ee 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -251,6 +251,7 @@ impl<'a, T: FieldElement> Value<'a, T> { pattern: &Pattern, ) -> Option>>> { match pattern { + Pattern::Ellipsis => unreachable!("Should be handled higher up"), Pattern::CatchAll => Some(vec![]), Pattern::Number(n) => match v.as_ref() { Value::Integer(x) if x == n => Some(vec![]), @@ -278,18 +279,38 @@ impl<'a, T: FieldElement> Value<'a, T> { } _ => unreachable!(), }, - Pattern::Array(items) => match v.as_ref() { - Value::Array(values) if values.len() == items.len() => values - .iter() - .zip(items) + Pattern::Array(items) => { + let Value::Array(values) = v.as_ref() else { + panic!("Type error") + }; + // Index of ".." + let ellipsis_pos = items.iter().position(|i| *i == Pattern::Ellipsis); + // Check if the value is too short. + let length_matches = match ellipsis_pos { + Some(_) => values.len() >= items.len() - 1, + None => values.len() == items.len(), + }; + if !length_matches { + return None; + } + // Split value into "left" and "right" part. + let left_len = ellipsis_pos.unwrap_or(values.len()); + let right_len = ellipsis_pos.map(|p| items.len() - p - 1).unwrap_or(0); + let left = values.iter().take(left_len); + let right = values.iter().skip(values.len() - right_len); + assert_eq!( + left.len() + right.len(), + items.len() - ellipsis_pos.map(|_| 1).unwrap_or_default() + ); + left.chain(right) + .zip(items.iter().filter(|&i| *i != Pattern::Ellipsis)) .try_fold(vec![], |mut vars, (e, p)| { Value::try_match_pattern(e, p).map(|v| { vars.extend(v); vars }) - }), - _ => None, - }, + }) + } Pattern::Variable(_) => Some(vec![v.clone()]), } } @@ -1188,4 +1209,65 @@ mod test { "[21, 5, 307, 9, 90, 99, -1]".to_string() ); } + + #[test] + pub fn match_skip_array() { + let src = r#" + let f: int[] -> int = |arr| match arr { + [x, .., y] => x + y, + [] => 19, + _ => 99, + }; + let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])]; + "#; + assert_eq!( + parse_and_evaluate_symbol(src, "t"), + "[19, 99, 3, 4, 5]".to_string() + ); + } + + #[test] + pub fn match_skip_array_2() { + let src = r#" + let f: int[] -> int = |arr| match arr { + [.., y] => y, + _ => 99, + }; + let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])]; + "#; + assert_eq!( + parse_and_evaluate_symbol(src, "t"), + "[99, 1, 2, 3, 4]".to_string() + ); + } + + #[test] + pub fn match_skip_array_3() { + let src = r#" + let f: int[] -> int = |arr| match arr { + [.., x, y] => x, + [..] => 99, + }; + let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])]; + "#; + assert_eq!( + parse_and_evaluate_symbol(src, "t"), + "[99, 99, 1, 2, 3]".to_string() + ); + } + + #[test] + pub fn match_skip_array_4() { + let src = r#" + let f: int[] -> int = |arr| match arr { + [x, y, ..] => y, + [..] => 99, + }; + let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])]; + "#; + assert_eq!( + parse_and_evaluate_symbol(src, "t"), + "[99, 99, 2, 2, 2]".to_string() + ); + } } diff --git a/pil-analyzer/src/expression_processor.rs b/pil-analyzer/src/expression_processor.rs index 830d9d390..b9cb7044e 100644 --- a/pil-analyzer/src/expression_processor.rs +++ b/pil-analyzer/src/expression_processor.rs @@ -142,8 +142,14 @@ impl ExpressionProcessor { /// Processes a pattern, registering all variables bound in there. fn process_pattern(&mut self, pattern: &Pattern) { match pattern { - Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) => {} + Pattern::CatchAll | Pattern::Ellipsis | Pattern::Number(_) | Pattern::String(_) => {} Pattern::Tuple(items) | Pattern::Array(items) => { + if matches!(pattern, Pattern::Array(_)) { + // If there is more than one Pattern::Ellipsis in items, it is an error + if items.iter().filter(|p| *p == &Pattern::Ellipsis).count() > 1 { + panic!("Only one \"..\"-item allowed in array pattern"); + } + } items.iter().for_each(|p| self.process_pattern(p)); } Pattern::Variable(name) => { diff --git a/pil-analyzer/src/type_inference.rs b/pil-analyzer/src/type_inference.rs index 4bcf87a83..b7cb7b13b 100644 --- a/pil-analyzer/src/type_inference.rs +++ b/pil-analyzer/src/type_inference.rs @@ -680,6 +680,7 @@ impl TypeChecker { /// Type-checks a pattern and adds local variables. fn infer_type_of_pattern(&mut self, pattern: &Pattern) -> Result { Ok(match pattern { + Pattern::Ellipsis => unreachable!("Should be handled higher up."), Pattern::CatchAll => self.new_type_var(), Pattern::Number(_) => { let ty = self.new_type_var(); @@ -696,7 +697,9 @@ impl TypeChecker { Pattern::Array(items) => { let item_type = self.new_type_var(); for item in items { - self.expect_type_of_pattern(&item_type, item)?; + if item != &Pattern::Ellipsis { + self.expect_type_of_pattern(&item_type, item)?; + } } Type::Array(ArrayType { base: Box::new(item_type), diff --git a/pil-analyzer/tests/parse_display.rs b/pil-analyzer/tests/parse_display.rs index a614f4873..c0e6ae715 100644 --- a/pil-analyzer/tests/parse_display.rs +++ b/pil-analyzer/tests/parse_display.rs @@ -642,3 +642,27 @@ fn match_shadowing() { "; assert_eq!(input, analyze_string::(input).to_string()); } + +#[test] +fn single_ellipsis() { + let input = " let t: int[] -> int = (|i| match i { + [1, .., 3] => 2, + [..] => 3, + [.., 1] => 9, + [7, 8, ..] => 2, + _ => -1, + }); +"; + assert_eq!(input, analyze_string::(input).to_string()); +} + +#[test] +#[should_panic = "Only one \"..\"-item allowed in array pattern"] +fn multi_ellipsis() { + let input = " let t: int[] -> int = (|i| match i { + [1, .., 3, ..] => 2, + _ => -1, + }); +"; + assert_eq!(input, analyze_string::(input).to_string()); +}