Flexible array pattern (#1208)

Depends on #1205 

Will document in https://github.com/powdr-labs/powdr/pull/1214
This commit is contained in:
chriseth
2024-04-03 13:48:07 +02:00
committed by GitHub
parent f46d59dfe1
commit 4e5c464df4
8 changed files with 149 additions and 25 deletions

View File

@@ -332,6 +332,7 @@ impl Display for Pattern {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
Pattern::CatchAll => write!(f, "_"),
Pattern::Ellipsis => write!(f, ".."),
Pattern::Number(n) => write!(f, "{n}"),
Pattern::String(s) => write!(f, "{}", quote(s)),
Pattern::Tuple(t) => write!(f, "({})", t.iter().format(", ")),

View File

@@ -170,9 +170,11 @@ pub trait ExpressionFolder<Ref> {
fn fold_pattern(&mut self, pattern: Pattern) -> Result<Pattern, Self::Error> {
Ok(match pattern {
Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) | Pattern::Variable(_) => {
pattern
}
Pattern::CatchAll
| Pattern::Ellipsis
| Pattern::Number(_)
| Pattern::String(_)
| Pattern::Variable(_) => pattern,
Pattern::Tuple(t) => Pattern::Tuple(
t.into_iter()
.map(|p| self.fold_pattern(p))

View File

@@ -824,7 +824,8 @@ impl Children<Expression> for ArrayExpression {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Serialize, Deserialize, JsonSchema)]
pub enum Pattern {
CatchAll,
CatchAll, // "_", matches a single value
Ellipsis, // "..", matches a series of values, only valid inside array patterns
#[schemars(skip)]
Number(BigInt),
String(String),
@@ -846,18 +847,22 @@ impl Pattern {
impl Children<Pattern> for Pattern {
fn children(&self) -> Box<dyn Iterator<Item = &Pattern> + '_> {
match self {
Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) | Pattern::Variable(_) => {
Box::new(empty())
}
Pattern::CatchAll
| Pattern::Ellipsis
| Pattern::Number(_)
| Pattern::String(_)
| Pattern::Variable(_) => Box::new(empty()),
Pattern::Tuple(p) | Pattern::Array(p) => Box::new(p.iter()),
}
}
fn children_mut(&mut self) -> Box<dyn Iterator<Item = &mut Pattern> + '_> {
match self {
Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) | Pattern::Variable(_) => {
Box::new(empty())
}
Pattern::CatchAll
| Pattern::Ellipsis
| Pattern::Number(_)
| Pattern::String(_)
| Pattern::Variable(_) => Box::new(empty()),
Pattern::Tuple(p) | Pattern::Array(p) => Box::new(p.iter_mut()),
}
}

View File

@@ -628,10 +628,15 @@ Pattern: Pattern = {
StringLiteral => Pattern::String(<>),
TuplePattern,
ArrayPattern,
VariablePattern,
Identifier => Pattern::Variable(<>),
//EnumPattern,
}
PatternIncludingEllipsis: Pattern = {
Pattern => <>,
".." => Pattern::Ellipsis,
}
TuplePattern: Pattern = {
"(" ")" => Pattern::Tuple(vec![]),
"(" <mut items:(<Pattern> ",")+> <last:Pattern> ")" => Pattern::Tuple({items.push(last); items})
@@ -639,11 +644,7 @@ TuplePattern: Pattern = {
ArrayPattern: Pattern = {
"[" "]" => Pattern::Array(vec![]),
"[" <mut items:(<Pattern> ",")*> <last:Pattern> "]" => Pattern::Array({items.push(last); items})
}
VariablePattern: Pattern = {
<Identifier> => Pattern::Variable(<>),
"[" <mut items:(<PatternIncludingEllipsis> ",")*> <last:PatternIncludingEllipsis> "]" => Pattern::Array({items.push(last); items})
}

View File

@@ -251,6 +251,7 @@ impl<'a, T: FieldElement> Value<'a, T> {
pattern: &Pattern,
) -> Option<Vec<Arc<Value<'b, T>>>> {
match pattern {
Pattern::Ellipsis => unreachable!("Should be handled higher up"),
Pattern::CatchAll => Some(vec![]),
Pattern::Number(n) => match v.as_ref() {
Value::Integer(x) if x == n => Some(vec![]),
@@ -278,18 +279,38 @@ impl<'a, T: FieldElement> Value<'a, T> {
}
_ => unreachable!(),
},
Pattern::Array(items) => match v.as_ref() {
Value::Array(values) if values.len() == items.len() => values
.iter()
.zip(items)
Pattern::Array(items) => {
let Value::Array(values) = v.as_ref() else {
panic!("Type error")
};
// Index of ".."
let ellipsis_pos = items.iter().position(|i| *i == Pattern::Ellipsis);
// Check if the value is too short.
let length_matches = match ellipsis_pos {
Some(_) => values.len() >= items.len() - 1,
None => values.len() == items.len(),
};
if !length_matches {
return None;
}
// Split value into "left" and "right" part.
let left_len = ellipsis_pos.unwrap_or(values.len());
let right_len = ellipsis_pos.map(|p| items.len() - p - 1).unwrap_or(0);
let left = values.iter().take(left_len);
let right = values.iter().skip(values.len() - right_len);
assert_eq!(
left.len() + right.len(),
items.len() - ellipsis_pos.map(|_| 1).unwrap_or_default()
);
left.chain(right)
.zip(items.iter().filter(|&i| *i != Pattern::Ellipsis))
.try_fold(vec![], |mut vars, (e, p)| {
Value::try_match_pattern(e, p).map(|v| {
vars.extend(v);
vars
})
}),
_ => None,
},
})
}
Pattern::Variable(_) => Some(vec![v.clone()]),
}
}
@@ -1188,4 +1209,65 @@ mod test {
"[21, 5, 307, 9, 90, 99, -1]".to_string()
);
}
#[test]
pub fn match_skip_array() {
let src = r#"
let f: int[] -> int = |arr| match arr {
[x, .., y] => x + y,
[] => 19,
_ => 99,
};
let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])];
"#;
assert_eq!(
parse_and_evaluate_symbol(src, "t"),
"[19, 99, 3, 4, 5]".to_string()
);
}
#[test]
pub fn match_skip_array_2() {
let src = r#"
let f: int[] -> int = |arr| match arr {
[.., y] => y,
_ => 99,
};
let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])];
"#;
assert_eq!(
parse_and_evaluate_symbol(src, "t"),
"[99, 1, 2, 3, 4]".to_string()
);
}
#[test]
pub fn match_skip_array_3() {
let src = r#"
let f: int[] -> int = |arr| match arr {
[.., x, y] => x,
[..] => 99,
};
let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])];
"#;
assert_eq!(
parse_and_evaluate_symbol(src, "t"),
"[99, 99, 1, 2, 3]".to_string()
);
}
#[test]
pub fn match_skip_array_4() {
let src = r#"
let f: int[] -> int = |arr| match arr {
[x, y, ..] => y,
[..] => 99,
};
let t = [f([]), f([1]), f([1, 2]), f([1, 2, 3]), f([1, 2, 3, 4])];
"#;
assert_eq!(
parse_and_evaluate_symbol(src, "t"),
"[99, 99, 2, 2, 2]".to_string()
);
}
}

View File

@@ -142,8 +142,14 @@ impl<D: AnalysisDriver> ExpressionProcessor<D> {
/// Processes a pattern, registering all variables bound in there.
fn process_pattern(&mut self, pattern: &Pattern) {
match pattern {
Pattern::CatchAll | Pattern::Number(_) | Pattern::String(_) => {}
Pattern::CatchAll | Pattern::Ellipsis | Pattern::Number(_) | Pattern::String(_) => {}
Pattern::Tuple(items) | Pattern::Array(items) => {
if matches!(pattern, Pattern::Array(_)) {
// If there is more than one Pattern::Ellipsis in items, it is an error
if items.iter().filter(|p| *p == &Pattern::Ellipsis).count() > 1 {
panic!("Only one \"..\"-item allowed in array pattern");
}
}
items.iter().for_each(|p| self.process_pattern(p));
}
Pattern::Variable(name) => {

View File

@@ -680,6 +680,7 @@ impl TypeChecker {
/// Type-checks a pattern and adds local variables.
fn infer_type_of_pattern(&mut self, pattern: &Pattern) -> Result<Type, String> {
Ok(match pattern {
Pattern::Ellipsis => unreachable!("Should be handled higher up."),
Pattern::CatchAll => self.new_type_var(),
Pattern::Number(_) => {
let ty = self.new_type_var();
@@ -696,7 +697,9 @@ impl TypeChecker {
Pattern::Array(items) => {
let item_type = self.new_type_var();
for item in items {
self.expect_type_of_pattern(&item_type, item)?;
if item != &Pattern::Ellipsis {
self.expect_type_of_pattern(&item_type, item)?;
}
}
Type::Array(ArrayType {
base: Box::new(item_type),

View File

@@ -642,3 +642,27 @@ fn match_shadowing() {
";
assert_eq!(input, analyze_string::<GoldilocksField>(input).to_string());
}
#[test]
fn single_ellipsis() {
let input = " let t: int[] -> int = (|i| match i {
[1, .., 3] => 2,
[..] => 3,
[.., 1] => 9,
[7, 8, ..] => 2,
_ => -1,
});
";
assert_eq!(input, analyze_string::<GoldilocksField>(input).to_string());
}
#[test]
#[should_panic = "Only one \"..\"-item allowed in array pattern"]
fn multi_ellipsis() {
let input = " let t: int[] -> int = (|i| match i {
[1, .., 3, ..] => 2,
_ => -1,
});
";
assert_eq!(input, analyze_string::<GoldilocksField>(input).to_string());
}