diff --git a/pil-analyzer/src/statement_processor.rs b/pil-analyzer/src/statement_processor.rs index 429754998..143514bfb 100644 --- a/pil-analyzer/src/statement_processor.rs +++ b/pil-analyzer/src/statement_processor.rs @@ -13,7 +13,7 @@ use powdr_ast::parsed::{ ArrayLiteral, EnumDeclaration, EnumVariant, FunctionDefinition, FunctionKind, LambdaExpression, PilStatement, PolynomialName, SelectedExpressions, TraitDeclaration, TraitFunction, }; -use powdr_ast::parsed::{NamedExpression, SymbolCategory, TraitImplementation}; +use powdr_ast::parsed::{ArrayExpression, NamedExpression, SymbolCategory, TraitImplementation}; use powdr_parser_util::SourceRef; use std::str::FromStr; @@ -426,110 +426,145 @@ where degree: self.degree, }; - if let Some(FunctionDefinition::TypeDeclaration(enum_decl)) = value { - // For enums, we add PILItems both for the enum itself and also for all - // its type constructors. - assert_eq!(symbol_kind, SymbolKind::Other()); - let enum_decl = self.process_enum_declaration(enum_decl); - let shared_enum_decl = Arc::new(enum_decl.clone()); - let var_items = enum_decl.variants.iter().map(|variant| { - let var_symbol = Symbol { - id: self.counters.dispense_symbol_id(SymbolKind::Other(), None), - source: source.clone(), - absolute_name: self - .driver - .resolve_namespaced_decl(&[&name, &variant.name]) - .relative_to(&Default::default()) - .to_string(), - stage: None, - kind: SymbolKind::Other(), - length: None, - degree: None, - }; - let value = FunctionValueDefinition::TypeConstructor( - shared_enum_decl.clone(), - variant.clone(), - ); - PILItem::Definition(var_symbol, Some(value)) - }); - return iter::once(PILItem::Definition( - symbol, - Some(FunctionValueDefinition::TypeDeclaration(enum_decl.clone())), - )) - .chain(var_items) + match value { + Some(FunctionDefinition::TypeDeclaration(enum_decl)) => { + assert_eq!(symbol_kind, SymbolKind::Other()); + self.process_enum_declaration(source, name, symbol, enum_decl) + } + Some(FunctionDefinition::TraitDeclaration(trait_decl)) => { + self.process_trait_declaration(source, name, symbol, trait_decl) + } + Some(FunctionDefinition::Expression(expr)) => { + self.process_expression_symbol(symbol_kind, symbol, type_scheme, expr) + } + Some(FunctionDefinition::Array(value)) => { + self.process_array_symbol(symbol, type_scheme, value) + } + None => vec![PILItem::Definition(symbol, None)], + } + } + + fn process_trait_declaration( + &mut self, + source: SourceRef, + name: String, + symbol: Symbol, + trait_decl: TraitDeclaration, + ) -> Vec { + let type_vars = trait_decl.type_vars.iter().collect(); + let functions = trait_decl + .functions + .into_iter() + .map(|f| TraitFunction { + name: f.name, + ty: self.type_processor(&type_vars).process_type(f.ty), + }) .collect(); - } else if let Some(FunctionDefinition::TraitDeclaration(trait_decl)) = value { - let trait_decl = self.process_trait_declaration(trait_decl); - let shared_trait_decl = Arc::new(trait_decl.clone()); - let trait_functions = trait_decl.functions.iter().map(|function| { - let f_symbol = Symbol { - id: self.counters.dispense_symbol_id(SymbolKind::Other(), None), - source: source.clone(), - absolute_name: self - .driver + let trait_decl = TraitDeclaration { + name: self.driver.resolve_decl(&trait_decl.name), + type_vars: trait_decl.type_vars, + functions, + }; + + let inner_items = trait_decl + .functions + .iter() + .map(|function| { + ( + self.driver .resolve_namespaced_decl(&[&name, &function.name]) .relative_to(&Default::default()) .to_string(), + FunctionValueDefinition::TraitFunction( + Arc::new(trait_decl.clone()), + function.clone(), + ), + ) + }) + .collect(); + let trait_functions = self.process_inner_definitions(source, inner_items); + + iter::once(PILItem::Definition( + symbol, + Some(FunctionValueDefinition::TraitDeclaration( + trait_decl.clone(), + )), + )) + .chain(trait_functions) + .collect() + } + + fn process_expression_symbol( + &mut self, + symbol_kind: SymbolKind, + symbol: Symbol, + type_scheme: Option, + expr: parsed::Expression, + ) -> Vec { + if symbol_kind == SymbolKind::Poly(PolynomialType::Committed) { + // The only allowed value for a witness column is a query function. + assert!(matches!( + expr, + parsed::Expression::LambdaExpression( + _, + LambdaExpression { + kind: FunctionKind::Query, + .. + } + ) + )); + assert!(type_scheme.is_none() || type_scheme == Some(Type::Col.into())); + } + let type_vars = type_scheme + .as_ref() + .map(|ts| ts.vars.vars().collect()) + .unwrap_or_default(); + let value = FunctionValueDefinition::Expression(TypedExpression { + e: self + .expression_processor(&type_vars) + .process_expression(expr), + type_scheme, + }); + + vec![PILItem::Definition(symbol, Some(value))] + } + + fn process_array_symbol( + &mut self, + symbol: Symbol, + type_scheme: Option, + value: ArrayExpression, + ) -> Vec { + let expression = self + .expression_processor(&Default::default()) + .process_array_expression(value); + assert!(type_scheme.is_none() || type_scheme == Some(Type::Col.into())); + let value = FunctionValueDefinition::Array(expression); + + vec![PILItem::Definition(symbol, Some(value))] + } + + /// Given a list of (absolute_name, value) pairs, create PIL items for each of them. + fn process_inner_definitions( + &mut self, + source: SourceRef, + inner_items: Vec<(String, FunctionValueDefinition)>, + ) -> Vec { + inner_items + .into_iter() + .map(|(absolute_name, value)| { + let symbol = Symbol { + id: self.counters.dispense_symbol_id(SymbolKind::Other(), None), + source: source.clone(), + absolute_name, stage: None, kind: SymbolKind::Other(), length: None, degree: None, }; - let value = FunctionValueDefinition::TraitFunction( - shared_trait_decl.clone(), - function.clone(), - ); - PILItem::Definition(f_symbol, Some(value)) - }); - return iter::once(PILItem::Definition( - symbol, - Some(FunctionValueDefinition::TraitDeclaration( - trait_decl.clone(), - )), - )) - .chain(trait_functions) - .collect(); - } - - let value = value.map(|v| match v { - FunctionDefinition::Expression(expr) => { - if symbol_kind == SymbolKind::Poly(PolynomialType::Committed) { - // The only allowed value for a witness column is a query function. - assert!(matches!( - expr, - parsed::Expression::LambdaExpression( - _, - LambdaExpression { - kind: FunctionKind::Query, - .. - } - ) - )); - assert!(type_scheme.is_none() || type_scheme == Some(Type::Col.into())); - } - let type_vars = type_scheme - .as_ref() - .map(|ts| ts.vars.vars().collect()) - .unwrap_or_default(); - FunctionValueDefinition::Expression(TypedExpression { - e: self - .expression_processor(&type_vars) - .process_expression(expr), - type_scheme, - }) - } - FunctionDefinition::Array(value) => { - let expression = self - .expression_processor(&Default::default()) - .process_array_expression(value); - assert!(type_scheme.is_none() || type_scheme == Some(Type::Col.into())); - FunctionValueDefinition::Array(expression) - } - FunctionDefinition::TypeDeclaration(_) | FunctionDefinition::TraitDeclaration(_) => { - unreachable!() - } - }); - vec![PILItem::Definition(symbol, value)] + PILItem::Definition(symbol, Some(value)) + }) + .collect() } fn handle_public_declaration( @@ -577,20 +612,48 @@ where } fn process_enum_declaration( - &self, + &mut self, + source: SourceRef, + name: String, + symbol: Symbol, enum_decl: EnumDeclaration, - ) -> EnumDeclaration { + ) -> Vec { let type_vars = enum_decl.type_vars.vars().collect(); let variants = enum_decl .variants .into_iter() .map(|v| self.process_enum_variant(v, &type_vars)) .collect(); - EnumDeclaration { + let enum_decl = EnumDeclaration { name: self.driver.resolve_decl(&enum_decl.name), type_vars: enum_decl.type_vars, variants, - } + }; + + let inner_items: Vec<_> = enum_decl + .variants + .iter() + .map(|variant| { + ( + self.driver + .resolve_namespaced_decl(&[&name, &variant.name]) + .relative_to(&Default::default()) + .to_string(), + FunctionValueDefinition::TypeConstructor( + Arc::new(enum_decl.clone()), + variant.clone(), + ), + ) + }) + .collect(); + let var_items = self.process_inner_definitions(source, inner_items); + + iter::once(PILItem::Definition( + symbol, + Some(FunctionValueDefinition::TypeDeclaration(enum_decl.clone())), + )) + .chain(var_items) + .collect() } fn process_enum_variant( @@ -608,26 +671,6 @@ where } } - fn process_trait_declaration( - &self, - trait_decl: parsed::TraitDeclaration, - ) -> TraitDeclaration { - let type_vars = trait_decl.type_vars.iter().collect(); - let functions = trait_decl - .functions - .into_iter() - .map(|f| TraitFunction { - name: f.name, - ty: self.type_processor(&type_vars).process_type(f.ty), - }) - .collect(); - TraitDeclaration { - name: self.driver.resolve_decl(&trait_decl.name), - type_vars: trait_decl.type_vars, - functions, - } - } - fn process_trait_implementation( &self, trait_impl: parsed::TraitImplementation,