diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 0ca92fbbf..44bd6f1c1 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -6,7 +6,7 @@ use std::{collections::HashMap, fmt::Display, rc::Rc}; use itertools::Itertools; use powdr_ast::{ analyzed::{ - types::{Type, TypedExpression}, + types::{ArrayType, Type, TypedExpression}, AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, PolynomialReference, PolynomialType, PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind, @@ -59,7 +59,12 @@ pub fn condense( let Some(FunctionValueDefinition::Expression(e)) = definition else { panic!("Expected expression") }; - assert!(e.ty.is_none() || e.ty == Some(Type::col())); + assert!( + e.ty.is_none() || + e.ty == Some(Type::Expr) || + matches!(&e.ty, Some(Type::Array(ArrayType{base, ..})) if base.as_ref() == &Type::Expr), + "Intermediate column type has to be expr or expr[], but got: {}", e.ty.as_ref().map(|t| t.to_string()).unwrap_or_default() + ); Some(( name.clone(), (symbol.clone(), condenser.condense_expression(&e.e)), diff --git a/pil-analyzer/src/statement_processor.rs b/pil-analyzer/src/statement_processor.rs index 61c242a41..8ecf73087 100644 --- a/pil-analyzer/src/statement_processor.rs +++ b/pil-analyzer/src/statement_processor.rs @@ -105,7 +105,7 @@ where source, name, SymbolKind::Poly(PolynomialType::Intermediate), - Some(Type::col()), + Some(Type::Expr), Some(FunctionDefinition::Expression(value)), ), PilStatement::PublicDeclaration(source, name, polynomial, array_index, index) => { @@ -232,23 +232,23 @@ where } Some(value) => { // TODO if we have proper type deduction here in the future, we can rely only on the type. - let (ty, symbol_kind) = if ty == Some(Type::col()) - || (ty.is_none() - && matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1)) - { - ( - Some(Type::col()), - SymbolKind::Poly(PolynomialType::Constant), - ) - } else if ty == Some(Type::Fe) - || (ty.is_none() && self.evaluate_expression(value.clone()).is_ok()) - { - // Value evaluates to a constant number => treat it as a constant - (Some(Type::Fe), SymbolKind::Constant()) - } else { - // Otherwise, treat it as "generic definition" - (ty, SymbolKind::Other()) - }; + + let ty = ty.or_else(|| { + if matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1) { + Some(Type::col()) + } else if self.evaluate_expression(value.clone()).is_ok() { + // Value evaluates to a constant number => treat it as a constant + Some(Type::Fe) + } else { + // Otherwise, treat it as "generic definition" + None + } + }); + let symbol_kind = ty + .as_ref() + .map(Self::symbol_kind_from_type) + .unwrap_or(SymbolKind::Other()); + self.handle_symbol_definition( source, name, @@ -260,6 +260,23 @@ where } } + fn symbol_kind_from_type(ty: &Type) -> SymbolKind { + match ty { + Type::Expr => SymbolKind::Poly(PolynomialType::Intermediate), + Type::Fe => SymbolKind::Constant(), + t if *t == Type::col() => SymbolKind::Poly(PolynomialType::Constant), + Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::col() => { + // Array of fixed columns + SymbolKind::Poly(PolynomialType::Constant) + } + Type::Array(ArrayType { base, length: _ }) if base.as_ref() == &Type::Expr => { + SymbolKind::Poly(PolynomialType::Intermediate) + } + // Otherwise, treat it as "generic definition" + _ => SymbolKind::Other(), + } + } + fn handle_identity_statement(&mut self, statement: PilStatement) -> Vec> { let (source, kind, left, right) = match statement { PilStatement::PolynomialIdentity(source, expression)