Parse turbofish (#1219)

This commit is contained in:
chriseth
2024-04-04 18:31:11 +02:00
committed by GitHub
parent 467d9c8e49
commit 161f4d8181
11 changed files with 217 additions and 49 deletions

View File

@@ -601,7 +601,11 @@ impl Display for PolynomialName {
impl Display for NamespacedPolynomialReference {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "{}", self.path.to_dotted_string())
if let Some(type_args) = &self.type_args {
write!(f, "{}::<{}>", self.path, type_args.iter().format(", "))
} else {
write!(f, "{}", self.path.to_dotted_string())
}
}
}

View File

@@ -477,11 +477,15 @@ pub struct PolynomialName {
/// This is different from SymbolPath mainly due to different formatting.
pub struct NamespacedPolynomialReference {
/// Path of the referenced symbol.
pub path: SymbolPath,
/// Explicit type arguments attached to the reference (the `::<...>` suffix).
/// `None` when the reference carries no such suffix; the `Display` impl only
/// prints the `::<...>` part when this is `Some`.
pub type_args: Option<Vec<Type<Expression>>>,
}
impl From<SymbolPath> for NamespacedPolynomialReference {
fn from(value: SymbolPath) -> Self {
Self { path: value }
Self {
path: value,
type_args: Default::default(),
}
}
}
@@ -491,7 +495,11 @@ impl NamespacedPolynomialReference {
}
pub fn try_to_identifier(&self) -> Option<&String> {
self.path.try_to_identifier()
if self.type_args.is_none() {
self.path.try_to_identifier()
} else {
None
}
}
}

View File

@@ -508,4 +508,34 @@ namespace N(2);
let printed = format!("{}", parse(Some("input"), input).unwrap_err_to_stderr());
assert_eq!(input.trim(), printed.trim());
}
/// Round-trip check for explicit type arguments (`::<...>`): the source is
/// parsed, printed again, and the printed text must equal the source
/// (ignoring leading/trailing whitespace).
#[test]
fn type_args() {
    let source = r#"
namespace N(2);
let<T: Ord> max: T, T -> T = (|a, b| if (a < b) { b } else { a });
let<T1, T2> left: T1, T2 -> T1 = (|a, b| a);
let seven = max::<int>(3, 7);
let five = left::<int, fe[]>(5, [7]);
let also_five = five::<>;
"#;
    // Render the parsed file through its Display implementation.
    let rendered = parse(Some("input"), source)
        .unwrap_err_to_stderr()
        .to_string();
    assert_eq!(source.trim(), rendered.trim());
}
/// Whitespace around the `::` of a type-argument suffix is accepted by the
/// parser; the printed form is expected to use the compact `max::<int>` spelling.
#[test]
fn type_args_with_space() {
    let source = r#"
namespace N(2);
let<T: Ord> max: T, T -> T = (|a, b| if (a < b) { b } else { a });
let seven = max :: <int>(3, 7);
"#;
    let normalized = r#"
namespace N(2);
let<T: Ord> max: T, T -> T = (|a, b| if (a < b) { b } else { a });
let seven = max::<int>(3, 7);
"#;
    // Render the parsed file through its Display implementation.
    let rendered = parse(Some("input"), source)
        .unwrap_err_to_stderr()
        .to_string();
    assert_eq!(normalized.trim(), rendered.trim());
}
}

View File

@@ -60,6 +60,19 @@ Part: Part = {
<name:Identifier> => Part::Named(name),
}
// Same as SymbolPath plus we allow "::<...>" at the end.
// Yields a pair: the parsed path and the optional list of type arguments
// (None when no "::<...>" suffix is present).
GenericSymbolPath: (SymbolPath, Option<Vec<Type<Expression>>>) = {
// If we "inline" SymbolPath here, we get an ambiguity error.
// The path is assembled from three pieces: an optional leading "::", which is
// turned into a single named part with an empty name; the "Part ::" prefixes;
// and the final Part.
<abs:"::"?> <parts:( <Part> "::" )*> <end:Part> <types:("::" "<" <TypeTermList> ">")?> => (
SymbolPath::from_parts([
abs.map(|_| vec![Part::Named(String::new())]).unwrap_or_default(),
parts,
vec![end],
].concat()),
types
),
}
/// Same as SymbolPath except that we do not allow 'int' and 'fe' to be parsed as identifiers.
TypeSymbolPath: SymbolPath = {
<abs:"::"?> <parts:( <TypeSymbolPathPart> "::" )*> <end:TypeSymbolPathPart> => {
@@ -558,7 +571,7 @@ Term: Box<Expression> = {
IndexAccess => Box::new(Expression::IndexAccess(<>)),
FunctionCall => Box::new(Expression::FunctionCall(<>)),
ConstantIdentifier => Box::new(Expression::Reference(NamespacedPolynomialReference::from_identifier(<>))),
NamespacedPolynomialReference => Box::new(Expression::Reference(<>)),
GenericReference => Box::new(Expression::Reference(<>)),
PublicIdentifier => Box::new(Expression::PublicReference(<>)),
Number => Box::new(Expression::Number(<>.into(), None)),
StringLiteral => Box::new(Expression::String(<>)),
@@ -585,6 +598,11 @@ NamespacedPolynomialReference: NamespacedPolynomialReference = {
<namespace:Identifier> "." <name:Identifier> => SymbolPath::from_parts([namespace, name].into_iter().map(Part::Named)).into(),
}
// A reference to a symbol that may carry explicit type arguments.
// Two spellings are accepted: the dotted "namespace.name" form and a
// GenericSymbolPath.
GenericReference: NamespacedPolynomialReference = {
// Dotted form: built from the two identifiers and converted from a SymbolPath
// via `.into()`, so no type arguments are attached here.
<namespace:Identifier> "." <name:Identifier> => SymbolPath::from_parts([namespace, name].into_iter().map(Part::Named)).into(),
// Path form: the tuple from GenericSymbolPath is unpacked into path (.0) and type_args (.1).
<path:GenericSymbolPath> => NamespacedPolynomialReference{path: path.0, type_args: path.1},
}
// A match expression: the keyword "match", a boxed expression, and the match arms
// in braces. Both captured values are passed to Expression::MatchExpression and
// the result is boxed.
MatchExpression: Box<Expression> = {
"match" <BoxedExpression> "{" <MatchArms> "}" => Box::new(Expression::MatchExpression(<>))
}

View File

@@ -1,30 +1,33 @@
use std::collections::HashMap;
use core::panic;
use std::collections::{HashMap, HashSet};
use powdr_ast::{
analyzed::{Expression, PolynomialReference, Reference, RepeatedArray},
parsed::{
self, asm::SymbolPath, ArrayExpression, ArrayLiteral, IfExpression, LambdaExpression,
self, ArrayExpression, ArrayLiteral, IfExpression, LambdaExpression,
LetStatementInsideBlock, MatchArm, NamespacedPolynomialReference, Pattern,
SelectedExpressions, StatementInsideBlock,
},
};
use powdr_number::DegreeType;
use crate::AnalysisDriver;
use crate::{type_processor::TypeProcessor, AnalysisDriver};
/// The ExpressionProcessor turns parsed expressions into analyzed expressions.
/// Its main job is to resolve references:
/// It turns simple references into fully namespaced references and resolves local function variables.
pub struct ExpressionProcessor<D: AnalysisDriver> {
pub struct ExpressionProcessor<'a, D: AnalysisDriver> {
driver: D,
type_vars: &'a HashSet<&'a String>,
local_variables: HashMap<String, u64>,
local_variable_counter: u64,
}
impl<D: AnalysisDriver> ExpressionProcessor<D> {
pub fn new(driver: D) -> Self {
impl<'a, D: AnalysisDriver> ExpressionProcessor<'a, D> {
pub fn new(driver: D, type_vars: &'a HashSet<&'a String>) -> Self {
Self {
driver,
type_vars,
local_variables: Default::default(),
local_variable_counter: 0,
}
@@ -168,7 +171,7 @@ impl<D: AnalysisDriver> ExpressionProcessor<D> {
let id = self.local_variables[name];
Reference::LocalVar(id, name.to_string())
}
_ => Reference::Poly(self.process_namespaced_polynomial_reference(&reference.path)),
_ => Reference::Poly(self.process_namespaced_polynomial_reference(reference)),
}
}
@@ -225,15 +228,18 @@ impl<D: AnalysisDriver> ExpressionProcessor<D> {
pub fn process_namespaced_polynomial_reference(
&mut self,
path: &SymbolPath,
reference: NamespacedPolynomialReference,
) -> PolynomialReference {
let type_processor = TypeProcessor::new(self.driver, self.type_vars);
let type_args = reference.type_args.map(|args| {
args.into_iter()
.map(|t| type_processor.process_type(t))
.collect()
});
PolynomialReference {
name: self.driver.resolve_value_ref(path),
name: self.driver.resolve_value_ref(&reference.path),
poly_id: None,
// These will be filled by the type checker.
// TODO at some point we should support the turbofish operator
// in the parser.
type_args: Default::default(),
type_args,
}
}

View File

@@ -349,7 +349,8 @@ impl PILAnalyzer {
fn handle_namespace(&mut self, name: SymbolPath, degree: Option<parsed::Expression>) {
if let Some(degree) = degree {
let degree = ExpressionProcessor::new(self.driver()).process_expression(degree);
let degree = ExpressionProcessor::new(self.driver(), &Default::default())
.process_expression(degree);
// TODO we should maybe implement a separate evaluator that is able to run before type checking
// and is field-independent (only uses integers)?
let namespace_degree: u64 = u64::try_from(

View File

@@ -312,7 +312,10 @@ where
source,
IdentityKind::Polynomial,
SelectedExpressions {
selector: Some(self.process_expression(expression)),
selector: Some(
self.expression_processor(&Default::default())
.process_expression(expression),
),
expressions: vec![],
},
SelectedExpressions::default(),
@@ -320,25 +323,33 @@ where
PilStatement::PlookupIdentity(source, key, haystack) => (
source,
IdentityKind::Plookup,
self.process_selected_expressions(key),
self.process_selected_expressions(haystack),
self.expression_processor(&Default::default())
.process_selected_expressions(key),
self.expression_processor(&Default::default())
.process_selected_expressions(haystack),
),
PilStatement::PermutationIdentity(source, left, right) => (
source,
IdentityKind::Permutation,
self.process_selected_expressions(left),
self.process_selected_expressions(right),
self.expression_processor(&Default::default())
.process_selected_expressions(left),
self.expression_processor(&Default::default())
.process_selected_expressions(right),
),
PilStatement::ConnectIdentity(source, left, right) => (
source,
IdentityKind::Connect,
SelectedExpressions {
selector: None,
expressions: self.expression_processor().process_expressions(left),
expressions: self
.expression_processor(&Default::default())
.process_expressions(left),
},
SelectedExpressions {
selector: None,
expressions: self.expression_processor().process_expressions(right),
expressions: self
.expression_processor(&Default::default())
.process_expressions(right),
},
),
// TODO at some point, these should all be caught by the type checker.
@@ -457,15 +468,21 @@ where
));
assert!(type_scheme.is_none() || type_scheme == Some(Type::Col.into()));
}
let type_vars = type_scheme
.as_ref()
.map(|ts| ts.vars.vars().collect())
.unwrap_or_default();
FunctionValueDefinition::Expression(TypedExpression {
e: self.process_expression(expr),
e: self
.expression_processor(&type_vars)
.process_expression(expr),
type_scheme,
})
}
FunctionDefinition::Array(value) => {
let size = value.solve(self.degree.unwrap());
let expression = self
.expression_processor()
.expression_processor(&Default::default())
.process_array_expression(value, size);
assert_eq!(
expression.iter().map(|e| e.size()).sum::<DegreeType>(),
@@ -489,8 +506,8 @@ where
) -> Vec<PILItem> {
let id = self.counters.dispense_public_id();
let polynomial = self
.expression_processor()
.process_namespaced_polynomial_reference(&poly.path);
.expression_processor(&Default::default())
.process_namespaced_polynomial_reference(poly);
let array_index = array_index.map(|i| {
let index: u64 = untyped_evaluator::evaluate_expression_to_int(self.driver, i)
.unwrap()
@@ -512,20 +529,11 @@ where
})]
}
fn expression_processor(&self) -> ExpressionProcessor<D> {
ExpressionProcessor::new(self.driver)
}
fn process_expression(&self, expr: parsed::Expression) -> Expression {
self.expression_processor().process_expression(expr)
}
fn process_selected_expressions(
&self,
expr: parsed::SelectedExpressions<parsed::Expression>,
) -> SelectedExpressions<Expression> {
self.expression_processor()
.process_selected_expressions(expr)
fn expression_processor<'b>(
&'b self,
type_vars: &'b HashSet<&'b String>,
) -> ExpressionProcessor<'b, D> {
ExpressionProcessor::new(self.driver, type_vars)
}
fn type_processor<'b>(&'b self, type_vars: &'b HashSet<&'b String>) -> TypeProcessor<'b, D> {

View File

@@ -57,6 +57,8 @@ struct TypeChecker<'a> {
/// Declared types for all symbols. Contains the unmodified type scheme for symbols
/// with generic types and newly created type variables for symbols without declared type.
declared_types: HashMap<String, TypeScheme>,
/// Current mapping of declared type vars to type. Reset before checking each definition.
declared_type_vars: HashMap<String, Type>,
unifier: Unifier,
/// Last used type variable index.
last_type_var: usize,
@@ -68,6 +70,7 @@ impl<'a> TypeChecker<'a> {
statement_type,
local_var_types: Default::default(),
declared_types: Default::default(),
declared_type_vars: Default::default(),
unifier: Default::default(),
last_type_var: Default::default(),
}
@@ -133,8 +136,14 @@ impl<'a> TypeChecker<'a> {
let declared_type = self.declared_types[&name].clone();
let result = if declared_type.vars.is_empty() {
self.declared_type_vars.clear();
self.process_concrete_symbol(&name, declared_type.ty.clone(), value)
} else {
self.declared_type_vars = declared_type
.vars
.vars()
.map(|v| (v.clone(), self.new_type_var()))
.collect();
self.infer_type_of_expression(value).map(|ty| {
inferred_types.insert(name.to_string(), ty);
})
@@ -145,6 +154,7 @@ impl<'a> TypeChecker<'a> {
));
}
}
self.declared_type_vars.clear();
self.check_expressions(expressions)?;
@@ -454,9 +464,22 @@ impl<'a> TypeChecker<'a> {
poly_id: _,
type_args,
})) => {
// The generic args (some of them) could be pre-filled by the parser, but we do not yet support that.
assert!(type_args.is_none());
let (ty, args) = self.instantiate_scheme(self.declared_types[name].clone());
if let Some(requested_type_args) = type_args {
if requested_type_args.len() != args.len() {
return Err(format!(
"Expected {} type arguments for symbol {name}, but got {}: {}",
args.len(),
requested_type_args.len(),
requested_type_args.iter().join(", ")
));
}
for (requested, inferred) in requested_type_args.iter_mut().zip(&args) {
requested.substitute_type_vars(&self.declared_type_vars);
self.unifier
.unify_types(requested.clone(), inferred.clone())?;
}
}
*type_args = Some(args);
type_for_reference(&ty)
}

View File

@@ -17,7 +17,7 @@ pub fn evaluate_expression_to_int(
expr: parsed::Expression,
) -> Result<BigInt, EvalError> {
evaluator::evaluate_expression::<GoldilocksField>(
&ExpressionProcessor::new(driver).process_expression(expr),
&ExpressionProcessor::new(driver, &Default::default()).process_expression(expr),
driver.definitions(),
)?
.try_to_integer()

View File

@@ -6,9 +6,6 @@ use pretty_assertions::assert_eq;
#[test]
fn parse_print_analyzed() {
// Re-add this line once we can parse the turbofish operator.
// col witness X_free_value(__i) query match std::prover::eval(T.pc) { 0 => std::prover::Query::Input(1), 3 => std::prover::Query::Input(std::convert::int::<fe>(std::prover::eval(T.CNT) + 1)), 7 => std::prover::Query::Input(0), };
// This is rather a test for the Display trait than for the analyzer.
let input = r#"constant %N = 65536;
public P = T.pc(2);
@@ -52,6 +49,7 @@ namespace T(65536);
T.A' = (((T.first_step' * 0) + (T.reg_write_X_A * T.X)) + ((1 - (T.first_step' + T.reg_write_X_A)) * T.A));
col witness X_free_value(__i) query match std::prover::eval(T.pc) {
0 => std::prover::Query::Input(1),
3 => std::prover::Query::Input(std::convert::int::<fe>((std::prover::eval(T.CNT) + 1))),
7 => std::prover::Query::Input(0),
};
col fixed p_X_const = [0, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*;

View File

@@ -333,6 +333,78 @@ fn query_with_wrong_type() {
type_check(input, &[]);
}
// Explicitly instantiating `bn` with `int[]` is expected to be rejected:
// the test passes only if type checking panics with a message containing
// "Type int[] does not satisfy trait".
#[test]
#[should_panic = "Type int[] does not satisfy trait"]
fn wrong_type_args() {
let input = "
let<T: FromLiteral + Mul + Add> bn: T, T -> T = |a, b| a * 0x100000000 + b;
let t: int = bn::<int[]>(5, 6);
";
type_check(input, &[]);
}
// `t` declares no type variables of its own, so the `T` used in `x::<T>` is
// not in scope there. The test expects a panic containing "Symbol not found: T".
#[test]
#[should_panic = "Symbol not found: T"]
fn specialization_non_declared_type_var() {
let input = "
let<T: FromLiteral> x: T = 1;
let t: int = x::<T>;
";
type_check(input, &[]);
}
// Supplying one type argument to the non-generic symbol `x` must fail with the
// arity error quoted in the `should_panic` attribute (0 expected, 1 given).
#[test]
#[should_panic = "Expected 0 type arguments for symbol x, but got 1: int[]"]
fn specialization_of_non_generic_symbol() {
let input = "
let x: int = 1;
let t: int = x::<int[]>;
";
type_check(input, &[]);
}
// Counterpart to the test above: an empty type-argument list (`x::<>`) on a
// non-generic symbol is expected to type-check without panicking.
#[test]
fn specialization_of_non_generic_symbol2() {
let input = "
let x: int = 1;
let t: int = x::<>;
";
type_check(input, &[]);
}
// A generic symbol can be specialized with a mix of concrete types and the
// type variables of the enclosing definition: `fold::<T, int[]>` and
// `fold::<int, T>` each fix one of `fold`'s two type parameters and leave the
// other generic.
#[test]
fn partial_specialization() {
let input = "
let<T1, T2> fold: int, (int -> T1), T2, (T2, T1 -> T2) -> T2 = |length, f, initial, folder|
if length <= 0 {
initial
} else {
folder(fold((length - 1), f, initial, folder), f((length - 1)))
};
let<T> fold_to_int_arr: int, (int -> T), int[], (int[], T -> int[]) -> int[] = fold::<T, int[]>;
let<T> fold_int: int, (int -> int), T, (T, int -> T) -> T = fold::<int, T>;
let y = fold_to_int_arr(4, |i| i, [], |acc, x| acc + [x]);
let z = fold_int(4, |i| i, 0, |acc, x| acc + x);
";
// Expected results for `y` and `z`; presumably each tuple is
// (symbol name, type variables, type) - TODO confirm against `type_check`.
type_check(input, &[("y", "", "int[]"), ("z", "", "int")]);
}
// Variant of `partial_specialization` where both type parameters of `fold`
// are instantiated with the same type variable (`fold::<T, T>`).
#[test]
fn partial_specialization2() {
let input = "
let<T1, T2> fold: int, (int -> T1), T2, (T2, T1 -> T2) -> T2 = |length, f, initial, folder|
if length <= 0 {
initial
} else {
folder(fold((length - 1), f, initial, folder), f((length - 1)))
};
// This just forces the two type vars to be the same.
let<T> fold_to_same: int, (int -> T), T, (T, T -> T) -> T = fold::<T, T>;
let y = fold_to_same(4, |i| i, 0, |acc, x| acc + x);
";
// Expected result for `y`; presumably the tuple is
// (symbol name, type variables, type) - TODO confirm against `type_check`.
type_check(input, &[("y", "", "int")]);
}
#[test]
fn type_from_pattern() {
let input = "