From 161f4d8181797eece731cb8746bd3acf4e2cab39 Mon Sep 17 00:00:00 2001 From: chriseth Date: Thu, 4 Apr 2024 18:31:11 +0200 Subject: [PATCH] Parse turbofish (#1219) --- ast/src/parsed/display.rs | 6 +- ast/src/parsed/mod.rs | 12 +++- parser/src/lib.rs | 30 ++++++++++ parser/src/powdr.lalrpop | 20 ++++++- pil-analyzer/src/expression_processor.rs | 32 ++++++----- pil-analyzer/src/pil_analyzer.rs | 3 +- pil-analyzer/src/statement_processor.rs | 58 +++++++++++-------- pil-analyzer/src/type_inference.rs | 27 ++++++++- pil-analyzer/src/untyped_evaluator.rs | 2 +- pil-analyzer/tests/parse_display.rs | 4 +- pil-analyzer/tests/types.rs | 72 ++++++++++++++++++++++++ 11 files changed, 217 insertions(+), 49 deletions(-) diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index ca42d2727..f71e9255b 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -601,7 +601,11 @@ impl Display for PolynomialName { impl Display for NamespacedPolynomialReference { fn fmt(&self, f: &mut Formatter<'_>) -> Result { - write!(f, "{}", self.path.to_dotted_string()) + if let Some(type_args) = &self.type_args { + write!(f, "{}::<{}>", self.path, type_args.iter().format(", ")) + } else { + write!(f, "{}", self.path.to_dotted_string()) + } } } diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index 2fd6ce450..96264084a 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -477,11 +477,15 @@ pub struct PolynomialName { /// This is different from SymbolPath mainly due to different formatting. pub struct NamespacedPolynomialReference { pub path: SymbolPath, + pub type_args: Option>>, } impl From for NamespacedPolynomialReference { fn from(value: SymbolPath) -> Self { - Self { path: value } + Self { + path: value, + type_args: Default::default(), + } } } @@ -491,7 +495,11 @@ impl NamespacedPolynomialReference { } pub fn try_to_identifier(&self) -> Option<&String> { - self.path.try_to_identifier() + if self.type_args.is_none() { + self.path.try_to_identifier() + } else { + None + } } } diff --git a/parser/src/lib.rs b/parser/src/lib.rs index c2f63b9f0..3c09728b3 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -508,4 +508,34 @@ namespace N(2); let printed = format!("{}", parse(Some("input"), input).unwrap_err_to_stderr()); assert_eq!(input.trim(), printed.trim()); } + + #[test] + fn type_args() { + let input = r#" +namespace N(2); + let max: T, T -> T = (|a, b| if (a < b) { b } else { a }); + let left: T1, T2 -> T1 = (|a, b| a); + let seven = max::(3, 7); + let five = left::(5, [7]); + let also_five = five::<>; +"#; + let printed = format!("{}", parse(Some("input"), input).unwrap_err_to_stderr()); + assert_eq!(input.trim(), printed.trim()); + } + + #[test] + fn type_args_with_space() { + let input = r#" +namespace N(2); + let max: T, T -> T = (|a, b| if (a < b) { b } else { a }); + let seven = max :: (3, 7); +"#; + let expected = r#" +namespace N(2); + let max: T, T -> T = (|a, b| if (a < b) { b } else { a }); + let seven = max::(3, 7); +"#; + let printed = format!("{}", parse(Some("input"), input).unwrap_err_to_stderr()); + assert_eq!(expected.trim(), printed.trim()); + } } diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 294f86c17..e414c7660 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -60,6 +60,19 @@ Part: Part = { => Part::Named(name), } +// Same as SymbolPath plus we allow "::<...>" at the end. +GenericSymbolPath: (SymbolPath, Option>>) = { + // If we "inline" SymbolPath here, we get an ambiguity error. + "::" )*> ">")?> => ( + SymbolPath::from_parts([ + abs.map(|_| vec![Part::Named(String::new())]).unwrap_or_default(), + parts, + vec![end], + ].concat()), + types + ), +} + /// Same as SymbolPath except that we do not allow 'int' and 'fe' to be parsed as identifiers. TypeSymbolPath: SymbolPath = { "::" )*> => { @@ -558,7 +571,7 @@ Term: Box = { IndexAccess => Box::new(Expression::IndexAccess(<>)), FunctionCall => Box::new(Expression::FunctionCall(<>)), ConstantIdentifier => Box::new(Expression::Reference(NamespacedPolynomialReference::from_identifier(<>))), - NamespacedPolynomialReference => Box::new(Expression::Reference(<>)), + GenericReference => Box::new(Expression::Reference(<>)), PublicIdentifier => Box::new(Expression::PublicReference(<>)), Number => Box::new(Expression::Number(<>.into(), None)), StringLiteral => Box::new(Expression::String(<>)), @@ -585,6 +598,11 @@ NamespacedPolynomialReference: NamespacedPolynomialReference = { "." => SymbolPath::from_parts([namespace, name].into_iter().map(Part::Named)).into(), } +GenericReference: NamespacedPolynomialReference = { + "." => SymbolPath::from_parts([namespace, name].into_iter().map(Part::Named)).into(), + => NamespacedPolynomialReference{path: path.0, type_args: path.1}, +} + MatchExpression: Box = { "match" "{" "}" => Box::new(Expression::MatchExpression(<>)) } diff --git a/pil-analyzer/src/expression_processor.rs b/pil-analyzer/src/expression_processor.rs index 8c48ba13a..160e8d4d6 100644 --- a/pil-analyzer/src/expression_processor.rs +++ b/pil-analyzer/src/expression_processor.rs @@ -1,30 +1,33 @@ -use std::collections::HashMap; +use core::panic; +use std::collections::{HashMap, HashSet}; use powdr_ast::{ analyzed::{Expression, PolynomialReference, Reference, RepeatedArray}, parsed::{ - self, asm::SymbolPath, ArrayExpression, ArrayLiteral, IfExpression, LambdaExpression, + self, ArrayExpression, ArrayLiteral, IfExpression, LambdaExpression, LetStatementInsideBlock, MatchArm, NamespacedPolynomialReference, Pattern, SelectedExpressions, StatementInsideBlock, }, }; use powdr_number::DegreeType; -use crate::AnalysisDriver; +use crate::{type_processor::TypeProcessor, AnalysisDriver}; /// The ExpressionProcessor turns parsed expressions into analyzed expressions. /// Its main job is to resolve references: /// It turns simple references into fully namespaced references and resolves local function variables. -pub struct ExpressionProcessor { +pub struct ExpressionProcessor<'a, D: AnalysisDriver> { driver: D, + type_vars: &'a HashSet<&'a String>, local_variables: HashMap, local_variable_counter: u64, } -impl ExpressionProcessor { - pub fn new(driver: D) -> Self { +impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> { + pub fn new(driver: D, type_vars: &'a HashSet<&'a String>) -> Self { Self { driver, + type_vars, local_variables: Default::default(), local_variable_counter: 0, } @@ -168,7 +171,7 @@ impl ExpressionProcessor { let id = self.local_variables[name]; Reference::LocalVar(id, name.to_string()) } - _ => Reference::Poly(self.process_namespaced_polynomial_reference(&reference.path)), + _ => Reference::Poly(self.process_namespaced_polynomial_reference(reference)), } } @@ -225,15 +228,18 @@ impl ExpressionProcessor { pub fn process_namespaced_polynomial_reference( &mut self, - path: &SymbolPath, + reference: NamespacedPolynomialReference, ) -> PolynomialReference { + let type_processor = TypeProcessor::new(self.driver, self.type_vars); + let type_args = reference.type_args.map(|args| { + args.into_iter() + .map(|t| type_processor.process_type(t)) + .collect() + }); PolynomialReference { - name: self.driver.resolve_value_ref(path), + name: self.driver.resolve_value_ref(&reference.path), poly_id: None, - // These will be filled by the type checker. - // TODO at some point we should support the turbofish operator - // in the parser. - type_args: Default::default(), + type_args, } } diff --git a/pil-analyzer/src/pil_analyzer.rs b/pil-analyzer/src/pil_analyzer.rs index ef6225d03..057988c2b 100644 --- a/pil-analyzer/src/pil_analyzer.rs +++ b/pil-analyzer/src/pil_analyzer.rs @@ -349,7 +349,8 @@ impl PILAnalyzer { fn handle_namespace(&mut self, name: SymbolPath, degree: Option) { if let Some(degree) = degree { - let degree = ExpressionProcessor::new(self.driver()).process_expression(degree); + let degree = ExpressionProcessor::new(self.driver(), &Default::default()) + .process_expression(degree); // TODO we should maybe implement a separate evaluator that is able to run before type checking // and is field-independent (only uses integers)? let namespace_degree: u64 = u64::try_from( diff --git a/pil-analyzer/src/statement_processor.rs b/pil-analyzer/src/statement_processor.rs index 6695257d2..42f7ed376 100644 --- a/pil-analyzer/src/statement_processor.rs +++ b/pil-analyzer/src/statement_processor.rs @@ -312,7 +312,10 @@ where source, IdentityKind::Polynomial, SelectedExpressions { - selector: Some(self.process_expression(expression)), + selector: Some( + self.expression_processor(&Default::default()) + .process_expression(expression), + ), expressions: vec![], }, SelectedExpressions::default(), @@ -320,25 +323,33 @@ where PilStatement::PlookupIdentity(source, key, haystack) => ( source, IdentityKind::Plookup, - self.process_selected_expressions(key), - self.process_selected_expressions(haystack), + self.expression_processor(&Default::default()) + .process_selected_expressions(key), + self.expression_processor(&Default::default()) + .process_selected_expressions(haystack), ), PilStatement::PermutationIdentity(source, left, right) => ( source, IdentityKind::Permutation, - self.process_selected_expressions(left), - self.process_selected_expressions(right), + self.expression_processor(&Default::default()) + .process_selected_expressions(left), + self.expression_processor(&Default::default()) + .process_selected_expressions(right), ), PilStatement::ConnectIdentity(source, left, right) => ( source, IdentityKind::Connect, SelectedExpressions { selector: None, - expressions: self.expression_processor().process_expressions(left), + expressions: self + .expression_processor(&Default::default()) + .process_expressions(left), }, SelectedExpressions { selector: None, - expressions: self.expression_processor().process_expressions(right), + expressions: self + .expression_processor(&Default::default()) + .process_expressions(right), }, ), // TODO at some point, these should all be caught by the type checker. @@ -457,15 +468,21 @@ where )); assert!(type_scheme.is_none() || type_scheme == Some(Type::Col.into())); } + let type_vars = type_scheme + .as_ref() + .map(|ts| ts.vars.vars().collect()) + .unwrap_or_default(); FunctionValueDefinition::Expression(TypedExpression { - e: self.process_expression(expr), + e: self + .expression_processor(&type_vars) + .process_expression(expr), type_scheme, }) } FunctionDefinition::Array(value) => { let size = value.solve(self.degree.unwrap()); let expression = self - .expression_processor() + .expression_processor(&Default::default()) .process_array_expression(value, size); assert_eq!( expression.iter().map(|e| e.size()).sum::(), @@ -489,8 +506,8 @@ where ) -> Vec { let id = self.counters.dispense_public_id(); let polynomial = self - .expression_processor() - .process_namespaced_polynomial_reference(&poly.path); + .expression_processor(&Default::default()) + .process_namespaced_polynomial_reference(poly); let array_index = array_index.map(|i| { let index: u64 = untyped_evaluator::evaluate_expression_to_int(self.driver, i) .unwrap() @@ -512,20 +529,11 @@ where })] } - fn expression_processor(&self) -> ExpressionProcessor { - ExpressionProcessor::new(self.driver) - } - - fn process_expression(&self, expr: parsed::Expression) -> Expression { - self.expression_processor().process_expression(expr) - } - - fn process_selected_expressions( - &self, - expr: parsed::SelectedExpressions, - ) -> SelectedExpressions { - self.expression_processor() - .process_selected_expressions(expr) + fn expression_processor<'b>( + &'b self, + type_vars: &'b HashSet<&'b String>, + ) -> ExpressionProcessor<'b, D> { + ExpressionProcessor::new(self.driver, type_vars) } fn type_processor<'b>(&'b self, type_vars: &'b HashSet<&'b String>) -> TypeProcessor<'b, D> { diff --git a/pil-analyzer/src/type_inference.rs b/pil-analyzer/src/type_inference.rs index eeaa34aeb..faba170f2 100644 --- a/pil-analyzer/src/type_inference.rs +++ b/pil-analyzer/src/type_inference.rs @@ -57,6 +57,8 @@ struct TypeChecker<'a> { /// Declared types for all symbols. Contains the unmodified type scheme for symbols /// with generic types and newly created type variables for symbols without declared type. declared_types: HashMap, + /// Current mapping of declared type vars to type. Reset before checking each definition. + declared_type_vars: HashMap, unifier: Unifier, /// Last used type variable index. last_type_var: usize, @@ -68,6 +70,7 @@ impl<'a> TypeChecker<'a> { statement_type, local_var_types: Default::default(), declared_types: Default::default(), + declared_type_vars: Default::default(), unifier: Default::default(), last_type_var: Default::default(), } @@ -133,8 +136,14 @@ impl<'a> TypeChecker<'a> { let declared_type = self.declared_types[&name].clone(); let result = if declared_type.vars.is_empty() { + self.declared_type_vars.clear(); self.process_concrete_symbol(&name, declared_type.ty.clone(), value) } else { + self.declared_type_vars = declared_type + .vars + .vars() + .map(|v| (v.clone(), self.new_type_var())) + .collect(); self.infer_type_of_expression(value).map(|ty| { inferred_types.insert(name.to_string(), ty); }) @@ -145,6 +154,7 @@ impl<'a> TypeChecker<'a> { )); } } + self.declared_type_vars.clear(); self.check_expressions(expressions)?; @@ -454,9 +464,22 @@ impl<'a> TypeChecker<'a> { poly_id: _, type_args, })) => { - // The generic args (some of them) could be pre-filled by the parser, but we do not yet support that. - assert!(type_args.is_none()); let (ty, args) = self.instantiate_scheme(self.declared_types[name].clone()); + if let Some(requested_type_args) = type_args { + if requested_type_args.len() != args.len() { + return Err(format!( + "Expected {} type arguments for symbol {name}, but got {}: {}", + args.len(), + requested_type_args.len(), + requested_type_args.iter().join(", ") + )); + } + for (requested, inferred) in requested_type_args.iter_mut().zip(&args) { + requested.substitute_type_vars(&self.declared_type_vars); + self.unifier + .unify_types(requested.clone(), inferred.clone())?; + } + } *type_args = Some(args); type_for_reference(&ty) } diff --git a/pil-analyzer/src/untyped_evaluator.rs b/pil-analyzer/src/untyped_evaluator.rs index 462a72e45..536b31259 100644 --- a/pil-analyzer/src/untyped_evaluator.rs +++ b/pil-analyzer/src/untyped_evaluator.rs @@ -17,7 +17,7 @@ pub fn evaluate_expression_to_int( expr: parsed::Expression, ) -> Result { evaluator::evaluate_expression::( - &ExpressionProcessor::new(driver).process_expression(expr), + &ExpressionProcessor::new(driver, &Default::default()).process_expression(expr), driver.definitions(), )? .try_to_integer() diff --git a/pil-analyzer/tests/parse_display.rs b/pil-analyzer/tests/parse_display.rs index f177f1add..3f95b170f 100644 --- a/pil-analyzer/tests/parse_display.rs +++ b/pil-analyzer/tests/parse_display.rs @@ -6,9 +6,6 @@ use pretty_assertions::assert_eq; #[test] fn parse_print_analyzed() { - // Re-add this line once we can parse the turbofish operator. - // col witness X_free_value(__i) query match std::prover::eval(T.pc) { 0 => std::prover::Query::Input(1), 3 => std::prover::Query::Input(std::convert::int::(std::prover::eval(T.CNT) + 1)), 7 => std::prover::Query::Input(0), }; - // This is rather a test for the Display trait than for the analyzer. let input = r#"constant %N = 65536; public P = T.pc(2); @@ -52,6 +49,7 @@ namespace T(65536); T.A' = (((T.first_step' * 0) + (T.reg_write_X_A * T.X)) + ((1 - (T.first_step' + T.reg_write_X_A)) * T.A)); col witness X_free_value(__i) query match std::prover::eval(T.pc) { 0 => std::prover::Query::Input(1), + 3 => std::prover::Query::Input(std::convert::int::((std::prover::eval(T.CNT) + 1))), 7 => std::prover::Query::Input(0), }; col fixed p_X_const = [0, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*; diff --git a/pil-analyzer/tests/types.rs b/pil-analyzer/tests/types.rs index 0f5235db6..632520ef0 100644 --- a/pil-analyzer/tests/types.rs +++ b/pil-analyzer/tests/types.rs @@ -333,6 +333,78 @@ fn query_with_wrong_type() { type_check(input, &[]); } +#[test] +#[should_panic = "Type int[] does not satisfy trait"] +fn wrong_type_args() { + let input = " + let bn: T, T -> T = |a, b| a * 0x100000000 + b; + let t: int = bn::(5, 6); + "; + type_check(input, &[]); +} + +#[test] +#[should_panic = "Symbol not found: T"] +fn specialization_non_declared_type_var() { + let input = " + let x: T = 1; + let t: int = x::; + "; + type_check(input, &[]); +} + +#[test] +#[should_panic = "Expected 0 type arguments for symbol x, but got 1: int[]"] +fn specialization_of_non_generic_symbol() { + let input = " + let x: int = 1; + let t: int = x::; + "; + type_check(input, &[]); +} + +#[test] +fn specialization_of_non_generic_symbol2() { + let input = " + let x: int = 1; + let t: int = x::<>; + "; + type_check(input, &[]); +} + +#[test] +fn partial_specialization() { + let input = " + let fold: int, (int -> T1), T2, (T2, T1 -> T2) -> T2 = |length, f, initial, folder| + if length <= 0 { + initial + } else { + folder(fold((length - 1), f, initial, folder), f((length - 1))) + }; + let fold_to_int_arr: int, (int -> T), int[], (int[], T -> int[]) -> int[] = fold::; + let fold_int: int, (int -> int), T, (T, int -> T) -> T = fold::; + let y = fold_to_int_arr(4, |i| i, [], |acc, x| acc + [x]); + let z = fold_int(4, |i| i, 0, |acc, x| acc + x); + "; + type_check(input, &[("y", "", "int[]"), ("z", "", "int")]); +} + +#[test] +fn partial_specialization2() { + let input = " + let fold: int, (int -> T1), T2, (T2, T1 -> T2) -> T2 = |length, f, initial, folder| + if length <= 0 { + initial + } else { + folder(fold((length - 1), f, initial, folder), f((length - 1))) + }; + // This just forces the two type vars to be the same. + let fold_to_same: int, (int -> T), T, (T, T -> T) -> T = fold::; + let y = fold_to_same(4, |i| i, 0, |acc, x| acc + x); + "; + type_check(input, &[("y", "", "int")]); +} + #[test] fn type_from_pattern() { let input = "