Merge pull request #1022 from powdr-labs/fix_type_of_intermeditae

Type of intermediates is expr.
This commit is contained in:
Leo
2024-02-06 13:57:25 +00:00
committed by GitHub
2 changed files with 42 additions and 20 deletions

View File

@@ -6,7 +6,7 @@ use std::{collections::HashMap, fmt::Display, rc::Rc};
use itertools::Itertools;
use powdr_ast::{
analyzed::{
types::{Type, TypedExpression},
types::{ArrayType, Type, TypedExpression},
AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition,
Identity, IdentityKind, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
StatementIdentifier, Symbol, SymbolKind,
@@ -59,7 +59,12 @@ pub fn condense<T: FieldElement>(
let Some(FunctionValueDefinition::Expression(e)) = definition else {
panic!("Expected expression")
};
assert!(e.ty.is_none() || e.ty == Some(Type::col()));
assert!(
e.ty.is_none() ||
e.ty == Some(Type::Expr) ||
matches!(&e.ty, Some(Type::Array(ArrayType{base, ..})) if base.as_ref() == &Type::Expr),
"Intermediate column type has to be expr or expr[], but got: {}", e.ty.as_ref().map(|t| t.to_string()).unwrap_or_default()
);
Some((
name.clone(),
(symbol.clone(), condenser.condense_expression(&e.e)),

View File

@@ -105,7 +105,7 @@ where
source,
name,
SymbolKind::Poly(PolynomialType::Intermediate),
Some(Type::col()),
Some(Type::Expr),
Some(FunctionDefinition::Expression(value)),
),
PilStatement::PublicDeclaration(source, name, polynomial, array_index, index) => {
@@ -232,23 +232,23 @@ where
}
Some(value) => {
// TODO if we have proper type deduction here in the future, we can rely only on the type.
let (ty, symbol_kind) = if ty == Some(Type::col())
|| (ty.is_none()
&& matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1))
{
(
Some(Type::col()),
SymbolKind::Poly(PolynomialType::Constant),
)
} else if ty == Some(Type::Fe)
|| (ty.is_none() && self.evaluate_expression(value.clone()).is_ok())
{
// Value evaluates to a constant number => treat it as a constant
(Some(Type::Fe), SymbolKind::Constant())
} else {
// Otherwise, treat it as "generic definition"
(ty, SymbolKind::Other())
};
let ty = ty.or_else(|| {
if matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1) {
Some(Type::col())
} else if self.evaluate_expression(value.clone()).is_ok() {
// Value evaluates to a constant number => treat it as a constant
Some(Type::Fe)
} else {
// Otherwise, treat it as "generic definition"
None
}
});
let symbol_kind = ty
.as_ref()
.map(Self::symbol_kind_from_type)
.unwrap_or(SymbolKind::Other());
self.handle_symbol_definition(
source,
name,
@@ -260,6 +260,23 @@ where
}
}
fn symbol_kind_from_type(ty: &Type) -> SymbolKind {
match ty {
Type::Expr => SymbolKind::Poly(PolynomialType::Intermediate),
Type::Fe => SymbolKind::Constant(),
t if *t == Type::col() => SymbolKind::Poly(PolynomialType::Constant),
Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::col() => {
// Array of fixed columns
SymbolKind::Poly(PolynomialType::Constant)
}
Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Expr => {
SymbolKind::Poly(PolynomialType::Intermediate)
}
// Otherwise, treat it as "generic definition"
_ => SymbolKind::Other(),
}
}
fn handle_identity_statement(&mut self, statement: PilStatement<T>) -> Vec<PILItem<T>> {
let (source, kind, left, right) = match statement {
PilStatement::PolynomialIdentity(source, expression)