diff --git a/ast/src/analyzed/display.rs b/ast/src/analyzed/display.rs index 636dc5558..95a2ec562 100644 --- a/ast/src/analyzed/display.rs +++ b/ast/src/analyzed/display.rs @@ -10,7 +10,10 @@ use std::{ use itertools::Itertools; -use self::parsed::asm::{AbsoluteSymbolPath, SymbolPath}; +use self::{ + parsed::asm::{AbsoluteSymbolPath, SymbolPath}, + types::{ArrayType, FunctionType, TupleType, Type}, +}; use super::*; @@ -47,7 +50,18 @@ impl Display for Analyzed { }; write!(f, " col {kind}{name}")?; if let Some(length) = symbol.length { - write!(f, "[{length}]")?; + if let PolynomialType::Committed = poly_type { + write!(f, "[{length}]")?; + assert!(definition.is_none()); + } else { + // Do not print an array size, because we will do it as part of the type. + assert!(matches!( + definition, + Some(FunctionValueDefinition::Expression( + TypedExpression { e: _, ty: Some(_) } + )) + )); + } } if let Some(value) = definition { writeln!(f, "{value};")? @@ -57,11 +71,18 @@ impl Display for Analyzed { } SymbolKind::Constant() => { let indentation = if is_local { " " } else { "" }; - writeln!( - f, - "{indentation}constant {name}{};", - definition.as_ref().unwrap() - )?; + let Some(FunctionValueDefinition::Expression(TypedExpression { + e, + ty: Some(Type::Fe), + })) = &definition + else { + panic!( + "Invalid constant value: {}", + definition.as_ref().unwrap() + ); + }; + + writeln!(f, "{indentation}constant {name} = {e};",)?; } SymbolKind::Other() => { write!(f, " let {name}")?; @@ -108,7 +129,17 @@ impl Display for FunctionValueDefinition { write!(f, " = {}", items.iter().format(" + ")) } FunctionValueDefinition::Query(e) => format_outer_function(e, Some("query"), f), - FunctionValueDefinition::Expression(e) => format_outer_function(e, None, f), + FunctionValueDefinition::Expression(TypedExpression { e, ty: None }) => { + format_outer_function(e, None, f) + } + FunctionValueDefinition::Expression(TypedExpression { e, ty: Some(ty) }) + if *ty == Type::col() => + { + format_outer_function(e, None, f) + } + FunctionValueDefinition::Expression(TypedExpression { e, ty: Some(ty) }) => { + write!(f, ": {ty} = {e}") + } } } } @@ -248,3 +279,65 @@ impl Display for PolynomialReference { write!(f, "{}", self.name,) } } + +impl Display for Type { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match self { + Type::Bool => write!(f, "bool"), + Type::Int => write!(f, "int"), + Type::Fe => write!(f, "fe"), + Type::String => write!(f, "string"), + Type::Expr => write!(f, "expr"), + Type::Constr => write!(f, "constr"), + Type::Array(ar) => write!(f, "{ar}"), + Type::Tuple(tu) => write!(f, "{tu}"), + Type::Function(fun) => write!(f, "{fun}"), + } + } +} + +impl Display for ArrayType { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let length = self.length.iter().format(""); + if self.base.needs_parentheses() { + write!(f, "({})[{length}]", self.base) + } else { + write!(f, "{}[{length}]", self.base) + } + } +} + +impl Display for TupleType { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "({})", format_list_of_types(&self.items)) + } +} + +impl Display for FunctionType { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + if *self == Self::col() { + write!(f, "col") + } else { + write!( + f, + "{} -> {}", + format_list_of_types(&self.params), + self.value + ) + } + } +} + +fn format_list_of_types(types: &[Type]) -> String { + types + .iter() + .map(|x| { + if x.needs_parentheses() { + format!("({x})") + } else { + x.to_string() + } + }) + .format(", ") + .to_string() +} diff --git a/ast/src/analyzed/mod.rs b/ast/src/analyzed/mod.rs index 613619310..77f59c60e 100644 --- a/ast/src/analyzed/mod.rs +++ b/ast/src/analyzed/mod.rs @@ -1,4 +1,5 @@ mod display; +pub mod types; pub mod visitor; use core::hash::Hash; @@ -15,6 +16,8 @@ pub use crate::parsed::UnaryOperator; use crate::parsed::{self, SelectedExpressions}; use crate::SourceRef; +use self::types::TypedExpression; + #[derive(Debug, Clone)] pub enum StatementIdentifier { /// Either an intermediate column or a definition. @@ -303,7 +306,9 @@ impl Analyzed { .iter_mut() .flat_map(|e| e.pattern.iter_mut()) .for_each(|e| e.post_visit_expressions_mut(f)), - Some(FunctionValueDefinition::Expression(e)) => e.post_visit_expressions_mut(f), + Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => { + e.post_visit_expressions_mut(f) + } None => {} }); } @@ -467,7 +472,7 @@ pub enum SymbolKind { pub enum FunctionValueDefinition { Array(Vec>), Query(Expression), - Expression(Expression), + Expression(TypedExpression), } /// An array of elements that might be repeated. diff --git a/ast/src/analyzed/types.rs b/ast/src/analyzed/types.rs new file mode 100644 index 000000000..c6eec9587 --- /dev/null +++ b/ast/src/analyzed/types.rs @@ -0,0 +1,138 @@ +use std::fmt::Display; + +use powdr_number::FieldElement; + +use crate::parsed::{ArrayTypeName, Expression, FunctionTypeName, TupleTypeName, TypeName}; + +use super::Reference; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct TypedExpression { + pub e: Expression, + pub ty: Option, +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub enum Type { + /// Boolean + Bool, + /// Integer (arbitrary precision) + Int, + /// Field element (unspecified field) + Fe, + /// String + String, + /// Algebraic expression + Expr, + /// Polynomial identity or lookup (not yet supported) + Constr, + Array(ArrayType), + Tuple(TupleType), + Function(FunctionType), +} + +impl Type { + /// Returns the column type `int -> fe`. + pub fn col() -> Self { + Type::Function(FunctionType::col()) + } + + /// Returns true if the type name needs parentheses around it during formatting + /// when used inside a complex expression. + pub fn needs_parentheses(&self) -> bool { + match self { + Type::Bool + | Type::Int + | Type::Fe + | Type::String + | Type::Expr + | Type::Constr + | Type::Array(_) + | Type::Tuple(_) => false, + Type::Function(fun) => fun.needs_parentheses(), + } + } +} + +impl From>> for Type { + fn from(value: TypeName>) -> Self { + match value { + TypeName::Bool => Type::Bool, + TypeName::Int => Type::Int, + TypeName::Fe => Type::Fe, + TypeName::String => Type::String, + TypeName::Expr => Type::Expr, + TypeName::Constr => Type::Constr, + TypeName::Col => Type::col(), + TypeName::Array(ar) => Type::Array(ar.into()), + TypeName::Tuple(tu) => Type::Tuple(tu.into()), + TypeName::Function(fun) => Type::Function(fun.into()), + } + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct ArrayType { + pub base: Box, + pub length: Option, +} + +impl From>> for ArrayType { + fn from(name: ArrayTypeName>) -> Self { + let length = name.length.as_ref().map(|l| { + if let Expression::Number(n) = l { + n.to_degree() + } else { + panic!( + "Array length expression not resolved in type name prior to conversion: {name}" + ); + } + }); + ArrayType { + base: Box::new(Type::from(*name.base)), + length, + } + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct TupleType { + pub items: Vec, +} + +impl From>> for TupleType { + fn from(value: TupleTypeName>) -> Self { + TupleType { + items: value.items.into_iter().map(Into::into).collect(), + } + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct FunctionType { + pub params: Vec, + pub value: Box, +} + +impl FunctionType { + /// Returns the column type `int -> fe`. + pub fn col() -> Self { + FunctionType { + params: vec![Type::Int], + value: Box::new(Type::Fe), + } + } + + pub fn needs_parentheses(&self) -> bool { + *self != Self::col() + } +} + +impl From>> for FunctionType { + fn from(name: FunctionTypeName>) -> Self { + FunctionType { + params: name.params.into_iter().map(Into::into).collect(), + value: Box::new(Type::from(*name.value)), + } + } +} diff --git a/ast/src/analyzed/visitor.rs b/ast/src/analyzed/visitor.rs index dc17e79f0..f70732518 100644 --- a/ast/src/analyzed/visitor.rs +++ b/ast/src/analyzed/visitor.rs @@ -86,7 +86,8 @@ impl ExpressionVisitable> for FunctionValueDefinition { F: FnMut(&mut Expression) -> ControlFlow, { match self { - FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => { + FunctionValueDefinition::Query(e) + | FunctionValueDefinition::Expression(TypedExpression { e, ty: _ }) => { e.visit_expressions_mut(f, o) } FunctionValueDefinition::Array(array) => array @@ -101,7 +102,8 @@ impl ExpressionVisitable> for FunctionValueDefinition { F: FnMut(&Expression) -> ControlFlow, { match self { - FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => { + FunctionValueDefinition::Query(e) + | FunctionValueDefinition::Expression(TypedExpression { e, ty: _ }) => { e.visit_expressions(f, o) } FunctionValueDefinition::Array(array) => array diff --git a/ast/src/asm_analysis/display.rs b/ast/src/asm_analysis/display.rs index b00bebebf..6f2109dbe 100644 --- a/ast/src/asm_analysis/display.rs +++ b/ast/src/asm_analysis/display.rs @@ -7,7 +7,10 @@ use itertools::Itertools; use crate::{ indent, - parsed::asm::{AbsoluteSymbolPath, Part}, + parsed::{ + asm::{AbsoluteSymbolPath, Part}, + ExpressionWithTypeName, + }, write_indented_by, write_items_indented, }; @@ -44,9 +47,15 @@ impl Display for AnalysisASMFile { Item::Machine(machine) => { write_indented_by(f, format!("machine {name}{machine}"), current_path.len())?; } - Item::Expression(expression) => write_indented_by( + Item::Expression(ExpressionWithTypeName { e, type_name }) => write_indented_by( f, - format!("let {name} = {expression};\n"), + format!( + "let {name}{} = {e};\n", + type_name + .as_ref() + .map(|tn| format!(": {tn}")) + .unwrap_or_default() + ), current_path.len(), )?, } diff --git a/ast/src/asm_analysis/mod.rs b/ast/src/asm_analysis/mod.rs index 6bcbd9def..3faa262d5 100644 --- a/ast/src/asm_analysis/mod.rs +++ b/ast/src/asm_analysis/mod.rs @@ -18,7 +18,7 @@ use crate::parsed::{ AbsoluteSymbolPath, AssignmentRegister, CallableRef, InstructionBody, OperationId, Params, }, visitor::{ExpressionVisitable, VisitOrder}, - NamespacedPolynomialReference, PilStatement, + ExpressionWithTypeName, NamespacedPolynomialReference, PilStatement, }; use crate::SourceRef; @@ -672,7 +672,7 @@ pub struct SubmachineDeclaration { #[derive(Clone, Debug)] pub enum Item { Machine(Machine), - Expression(Expression), + Expression(ExpressionWithTypeName), } impl Item { diff --git a/ast/src/object/display.rs b/ast/src/object/display.rs index 50c647f36..c95fb1bea 100644 --- a/ast/src/object/display.rs +++ b/ast/src/object/display.rs @@ -1,5 +1,7 @@ use std::fmt::{Display, Formatter, Result}; +use crate::parsed::ExpressionWithTypeName; + use super::{Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph}; impl Display for Location { @@ -11,8 +13,15 @@ impl Display for Location { impl Display for PILGraph { fn fmt(&self, f: &mut Formatter<'_>) -> Result { writeln!(f, "// Utilities")?; - for (name, e) in &self.definitions { - writeln!(f, "let {name} = {e};")?; + for (name, ExpressionWithTypeName { e, type_name }) in &self.definitions { + writeln!( + f, + "let {name}{} = {e};", + type_name + .as_ref() + .map(|tn| format!(": {tn}")) + .unwrap_or_default() + )?; } for (location, object) in &self.objects { writeln!(f, "// Object {}", location)?; diff --git a/ast/src/object/mod.rs b/ast/src/object/mod.rs index 7c6201701..05771bed1 100644 --- a/ast/src/object/mod.rs +++ b/ast/src/object/mod.rs @@ -2,7 +2,7 @@ use std::collections::BTreeMap; use crate::parsed::{ asm::{AbsoluteSymbolPath, Params}, - Expression, PilStatement, + Expression, ExpressionWithTypeName, PilStatement, }; mod display; @@ -30,7 +30,7 @@ pub struct PILGraph { pub main: Machine, pub entry_points: Vec>, pub objects: BTreeMap>, - pub definitions: BTreeMap>, + pub definitions: BTreeMap>, } #[derive(Default, Clone)] diff --git a/ast/src/parsed/asm.rs b/ast/src/parsed/asm.rs index a43fec214..d216eef48 100644 --- a/ast/src/parsed/asm.rs +++ b/ast/src/parsed/asm.rs @@ -11,7 +11,7 @@ use derive_more::From; use crate::SourceRef; -use super::{Expression, PilStatement}; +use super::{Expression, ExpressionWithTypeName, PilStatement}; #[derive(Default, Clone, Debug, PartialEq, Eq)] pub struct ASMProgram { @@ -51,7 +51,7 @@ pub enum SymbolValue { /// A module definition Module(Module), /// A generic symbol / function. - Expression(Expression), + Expression(ExpressionWithTypeName), } impl SymbolValue { @@ -74,7 +74,7 @@ pub enum SymbolValueRef<'a, T> { /// A module definition Module(ModuleRef<'a, T>), /// A generic symbol / function. - Expression(&'a Expression), + Expression(&'a ExpressionWithTypeName), } #[derive(Debug, Clone, PartialEq, Eq, From)] diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index 928753ac9..bbfae8f47 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -57,8 +57,15 @@ impl Display for ModuleStatement { SymbolValue::Module(m @ Module::Local(_)) => { write!(f, "mod {name} {m}") } - SymbolValue::Expression(e) => { - write!(f, "let {name} = {e};") + SymbolValue::Expression(ExpressionWithTypeName { e, type_name }) => { + write!( + f, + "let {name}{} = {e};", + type_name + .as_ref() + .map(|t| format!(": {t}")) + .unwrap_or_default() + ) } }, } @@ -373,9 +380,15 @@ impl Display for PilStatement { PilStatement::Namespace(_, name, poly_length) => { write!(f, "namespace {name}({poly_length});") } - PilStatement::LetStatement(_, name, None) => write!(f, " let {name};"), - PilStatement::LetStatement(_, name, Some(expr)) => { - write!(f, " let {name} = {expr};") + PilStatement::LetStatement(_, name, type_name, value) => { + write!(f, " let {name}")?; + if let Some(type_name) = type_name { + write!(f, ": {type_name}")?; + } + if let Some(value) = &value { + write!(f, " = {value}")?; + } + write!(f, ";") } PilStatement::PolynomialDefinition(_, name, value) => { write!(f, " pol {name} = {value};") @@ -584,6 +597,74 @@ impl Display for UnaryOperator { } } +impl Display for TypeName { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + match self { + TypeName::Bool => write!(f, "bool"), + TypeName::Int => write!(f, "int"), + TypeName::Fe => write!(f, "fe"), + TypeName::String => write!(f, "string"), + TypeName::Col => write!(f, "col"), + TypeName::Expr => write!(f, "expr"), + TypeName::Constr => write!(f, "constr"), + TypeName::Array(array) => write!(f, "{array}"), + TypeName::Tuple(tuple) => write!(f, "{tuple}"), + TypeName::Function(fun) => write!(f, "{fun}"), + } + } +} + +impl Display for ArrayTypeName { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + if self.base.needs_parentheses() { + write!(f, "({})", self.base) + } else { + write!(f, "{}", self.base) + }?; + write!( + f, + "[{}]", + self.length + .as_ref() + .map(|l| l.to_string()) + .unwrap_or_default() + ) + } +} + +impl Display for TupleTypeName { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + write!(f, "({})", self.items.iter().format(", ")) + } +} + +impl Display for FunctionTypeName { + fn fmt(&self, f: &mut Formatter<'_>) -> Result { + let params = self + .params + .iter() + .map(|x| { + if x.needs_parentheses() { + format!("({x})") + } else { + format!("{x}") + } + }) + .join(", ") + + if self.params.is_empty() { "" } else { " " }; + + write!( + f, + "{params}-> {}", + if self.value.needs_parentheses() { + format!("({})", self.value) + } else { + format!("{}", self.value) + } + ) + } +} + #[cfg(test)] mod tests { diff --git a/ast/src/parsed/mod.rs b/ast/src/parsed/mod.rs index a81803489..83bc86972 100644 --- a/ast/src/parsed/mod.rs +++ b/ast/src/parsed/mod.rs @@ -24,7 +24,12 @@ pub enum PilStatement { Include(SourceRef, String), /// Name of namespace and polynomial degree (constant) Namespace(SourceRef, SymbolPath, Expression), - LetStatement(SourceRef, String, Option>), + LetStatement( + SourceRef, + String, + Option>>, + Option>, + ), PolynomialDefinition(SourceRef, String, Expression), PublicDeclaration( SourceRef, @@ -68,7 +73,7 @@ impl PilStatement { | PilStatement::PolynomialConstantDefinition(_, name, _) | PilStatement::ConstantDefinition(_, name, _) | PilStatement::PublicDeclaration(_, name, _, _, _) - | PilStatement::LetStatement(_, name, _) => Box::new(once(name)), + | PilStatement::LetStatement(_, name, _, _) => Box::new(once(name)), PilStatement::PolynomialConstantDeclaration(_, polynomials) | PilStatement::PolynomialCommitDeclaration(_, polynomials, _) => { Box::new(polynomials.iter().map(|p| &p.name)) @@ -98,8 +103,11 @@ impl PilStatement { | PilStatement::Namespace(_, _, e) | PilStatement::PolynomialDefinition(_, _, e) | PilStatement::PolynomialIdentity(_, e) - | PilStatement::ConstantDefinition(_, _, e) - | PilStatement::LetStatement(_, _, Some(e)) => Box::new(once(e)), + | PilStatement::ConstantDefinition(_, _, e) => Box::new(once(e)), + + PilStatement::LetStatement(_, _, type_name, value) => { + Box::new(type_name.iter().flat_map(|t| t.expressions()).chain(value)) + } PilStatement::PublicDeclaration(_, _, _, i, e) => Box::new(i.iter().chain(once(e))), @@ -107,8 +115,7 @@ impl PilStatement { | PilStatement::PolynomialCommitDeclaration(_, _, Some(fundef)) => fundef.expressions(), PilStatement::PolynomialCommitDeclaration(_, _, None) | PilStatement::Include(_, _) - | PilStatement::PolynomialConstantDeclaration(_, _) - | PilStatement::LetStatement(_, _, None) => Box::new(empty()), + | PilStatement::PolynomialConstantDeclaration(_, _) => Box::new(empty()), } } @@ -126,8 +133,14 @@ impl PilStatement { | PilStatement::Namespace(_, _, e) | PilStatement::PolynomialDefinition(_, _, e) | PilStatement::PolynomialIdentity(_, e) - | PilStatement::ConstantDefinition(_, _, e) - | PilStatement::LetStatement(_, _, Some(e)) => Box::new(once(e)), + | PilStatement::ConstantDefinition(_, _, e) => Box::new(once(e)), + + PilStatement::LetStatement(_, _, type_name, value) => Box::new( + type_name + .iter_mut() + .flat_map(|t| t.expressions_mut()) + .chain(value), + ), PilStatement::PublicDeclaration(_, _, _, i, e) => Box::new(i.iter_mut().chain(once(e))), @@ -137,8 +150,7 @@ impl PilStatement { } PilStatement::PolynomialCommitDeclaration(_, _, None) | PilStatement::Include(_, _) - | PilStatement::PolynomialConstantDeclaration(_, _) - | PilStatement::LetStatement(_, _, None) => Box::new(empty()), + | PilStatement::PolynomialConstantDeclaration(_, _) => Box::new(empty()), } } } @@ -496,3 +508,141 @@ impl ArrayExpression { } } } + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub enum TypeName { + /// Boolean + Bool, + /// Integer (arbitrary precision) + Int, + /// Field element (unspecified field) + Fe, + /// String + String, + /// Column, shorthand for "int -> fe" + Col, + /// Algebraic expression + Expr, + /// Polynomial identity + Constr, + Array(ArrayTypeName), + Tuple(TupleTypeName), + Function(FunctionTypeName), +} + +impl TypeName { + /// Returns true if the type name needs parentheses during formatting + /// when used inside a complex expression. + pub fn needs_parentheses(&self) -> bool { + match self { + TypeName::Bool + | TypeName::Int + | TypeName::Fe + | TypeName::String + | TypeName::Col + | TypeName::Expr + | TypeName::Constr + | TypeName::Array(_) + | TypeName::Tuple(_) => false, + TypeName::Function(_) => true, + } + } + + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions(&self) -> Box + '_> { + match self { + TypeName::Bool + | TypeName::Int + | TypeName::Fe + | TypeName::String + | TypeName::Col + | TypeName::Expr + | TypeName::Constr => Box::new(empty()), + TypeName::Array(a) => a.expressions(), + TypeName::Tuple(t) => t.expressions(), + TypeName::Function(f) => f.expressions(), + } + } + + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions_mut(&mut self) -> Box + '_> { + match self { + TypeName::Bool + | TypeName::Int + | TypeName::Fe + | TypeName::String + | TypeName::Col + | TypeName::Expr + | TypeName::Constr => Box::new(empty()), + TypeName::Array(a) => a.expressions_mut(), + TypeName::Tuple(t) => t.expressions_mut(), + TypeName::Function(f) => f.expressions_mut(), + } + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct ArrayTypeName { + pub base: Box>, + pub length: Option, +} + +impl ArrayTypeName { + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions(&self) -> Box + '_> { + Box::new(self.base.expressions().chain(self.length.iter())) + } + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions_mut(&mut self) -> Box + '_> { + Box::new(self.base.expressions_mut().chain(self.length.iter_mut())) + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct TupleTypeName { + pub items: Vec>, +} + +impl TupleTypeName { + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions(&self) -> Box + '_> { + Box::new(self.items.iter().flat_map(|t| t.expressions())) + } + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions_mut(&mut self) -> Box + '_> { + Box::new(self.items.iter_mut().flat_map(|t| t.expressions_mut())) + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct FunctionTypeName { + pub params: Vec>, + pub value: Box>, +} + +impl FunctionTypeName { + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions(&self) -> Box + '_> { + Box::new( + self.params + .iter() + .flat_map(|t| t.expressions()) + .chain(self.value.expressions()), + ) + } + /// Returns an iterator over all (top-level) expressions in this type name. + pub fn expressions_mut(&mut self) -> Box + '_> { + Box::new( + self.params + .iter_mut() + .flat_map(|t| t.expressions_mut()) + .chain(self.value.expressions_mut()), + ) + } +} + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] +pub struct ExpressionWithTypeName { + pub e: Expression, + pub type_name: Option>>, +} diff --git a/ast/src/parsed/visitor.rs b/ast/src/parsed/visitor.rs index 5cdb1534d..2be5dc947 100644 --- a/ast/src/parsed/visitor.rs +++ b/ast/src/parsed/visitor.rs @@ -1,9 +1,9 @@ use std::{iter::once, ops::ControlFlow}; use super::{ - ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IfExpression, - IndexAccess, LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference, - PilStatement, SelectedExpressions, + ArrayExpression, ArrayLiteral, ArrayTypeName, Expression, FunctionCall, FunctionDefinition, + FunctionTypeName, IfExpression, IndexAccess, LambdaExpression, MatchArm, MatchPattern, + NamespacedPolynomialReference, PilStatement, SelectedExpressions, TupleTypeName, TypeName, }; #[derive(Clone, Copy, Debug, PartialEq, Eq)] @@ -205,8 +205,17 @@ impl ExpressionVisitable> for Pi | PilStatement::PolynomialDefinition(_, _, e) | PilStatement::PolynomialIdentity(_, e) | PilStatement::PublicDeclaration(_, _, _, None, e) - | PilStatement::ConstantDefinition(_, _, e) - | PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions_mut(f, o), + | PilStatement::ConstantDefinition(_, _, e) => e.visit_expressions_mut(f, o), + + PilStatement::LetStatement(_, _, type_name, value) => { + if let Some(t) = type_name { + t.visit_expressions_mut(f, o)?; + }; + if let Some(v) = value { + v.visit_expressions_mut(f, o)?; + }; + ControlFlow::Continue(()) + } PilStatement::PublicDeclaration(_, _, _, Some(i), e) => [i, e] .into_iter() @@ -218,8 +227,7 @@ impl ExpressionVisitable> for Pi } PilStatement::PolynomialCommitDeclaration(_, _, None) | PilStatement::Include(_, _) - | PilStatement::PolynomialConstantDeclaration(_, _) - | PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()), + | PilStatement::PolynomialConstantDeclaration(_, _) => ControlFlow::Continue(()), } } @@ -242,8 +250,17 @@ impl ExpressionVisitable> for Pi | PilStatement::PolynomialDefinition(_, _, e) | PilStatement::PolynomialIdentity(_, e) | PilStatement::PublicDeclaration(_, _, _, None, e) - | PilStatement::ConstantDefinition(_, _, e) - | PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions(f, o), + | PilStatement::ConstantDefinition(_, _, e) => e.visit_expressions(f, o), + + PilStatement::LetStatement(_, _, type_name, value) => { + if let Some(t) = type_name { + t.visit_expressions(f, o)?; + }; + if let Some(v) = value { + v.visit_expressions(f, o)?; + }; + ControlFlow::Continue(()) + } PilStatement::PublicDeclaration(_, _, _, Some(i), e) => [i, e] .into_iter() @@ -255,8 +272,7 @@ impl ExpressionVisitable> for Pi } PilStatement::PolynomialCommitDeclaration(_, _, None) | PilStatement::Include(_, _) - | PilStatement::PolynomialConstantDeclaration(_, _) - | PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()), + | PilStatement::PolynomialConstantDeclaration(_, _) => ControlFlow::Continue(()), } } } @@ -478,3 +494,105 @@ impl ExpressionVisitable> for IfExpression { .try_for_each(|e| e.visit_expressions(f, o)) } } + +impl> ExpressionVisitable for TypeName { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut E) -> ControlFlow, + { + match self { + TypeName::Bool + | TypeName::Int + | TypeName::Fe + | TypeName::String + | TypeName::Col + | TypeName::Expr + | TypeName::Constr => ControlFlow::Continue(()), + TypeName::Array(a) => a.visit_expressions_mut(f, o), + TypeName::Tuple(t) => t.visit_expressions_mut(f, o), + TypeName::Function(fun) => fun.visit_expressions_mut(f, o), + } + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&E) -> ControlFlow, + { + match self { + TypeName::Bool + | TypeName::Int + | TypeName::Fe + | TypeName::String + | TypeName::Col + | TypeName::Expr + | TypeName::Constr => ControlFlow::Continue(()), + TypeName::Array(a) => a.visit_expressions(f, o), + TypeName::Tuple(t) => t.visit_expressions(f, o), + TypeName::Function(fun) => fun.visit_expressions(f, o), + } + } +} + +impl> ExpressionVisitable for ArrayTypeName { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut E) -> ControlFlow, + { + self.base.visit_expressions_mut(f, o)?; + self.length + .iter_mut() + .try_for_each(|e| e.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&E) -> ControlFlow, + { + self.base.visit_expressions(f, o)?; + self.length + .iter() + .try_for_each(|e| e.visit_expressions(f, o)) + } +} + +impl> ExpressionVisitable for TupleTypeName { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut E) -> ControlFlow, + { + self.items + .iter_mut() + .try_for_each(|i| i.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&E) -> ControlFlow, + { + self.items + .iter() + .try_for_each(|i| i.visit_expressions(f, o)) + } +} + +impl> ExpressionVisitable for FunctionTypeName { + fn visit_expressions_mut(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&mut E) -> ControlFlow, + { + self.params + .iter_mut() + .chain(once(self.value.as_mut())) + .try_for_each(|i| i.visit_expressions_mut(f, o)) + } + + fn visit_expressions(&self, f: &mut F, o: VisitOrder) -> ControlFlow + where + F: FnMut(&E) -> ControlFlow, + { + self.params + .iter() + .chain(once(self.value.as_ref())) + .try_for_each(|i| i.visit_expressions(f, o)) + } +} diff --git a/book/src/pil/declarations.md b/book/src/pil/declarations.md index c5b7524af..10fe1bf40 100644 --- a/book/src/pil/declarations.md +++ b/book/src/pil/declarations.md @@ -1,37 +1,26 @@ # Declarations Powdr-pil allows the same syntax to declare various kinds of symbols. This includes -constants, fixed columns, witness columns and even macros. It deduces the symbol kind -by its type and the way the symbol is used. +constants, fixed columns, witness columns and even higher-order functions. It deduces the symbol kind +from the type of the symbol and the way the symbol is used. Symbols can be declared using ``let ;`` and they can be declared and defined -using ``let = ;``, where ```` is an expression. +using ``let = ;``, where ```` is an expression. The [type](./types.md) of the symbol +can be explicitly specified using ``let : ;`` and ``let : = ;``. + This syntax can be used for constants, fixed columns, witness columns and even (higher-order) functions that can transform expressions. The kind of symbol is deduced by its type and the way the symbol is used: -- symbols without a value are witness columns, -- symbols evaluating to a number are constants, -- symbols defined as a function with a single parameter are fixed columns and -- everything else is a "generic symbol" that is not a column. +- Symbols without a value are witness columns. Their type can be omitted. If it is given, it must be ``int -> fe`` or its shorthand ``col``. +- Symbols evaluating to a number or with type ``fe`` are constants. +- Symbols without type but with a value that is a function with a single parameter are fixed columns. +- Symbols defined with a value and type ``int -> fe`` or its shorthand ``col`` are also fixed columns. +- Everything else is a "generic symbol" that is not a column or constant. Examples: + ```rust -// This defines a constant -let rows = 2**16; -// This defines a fixed column that contains the row number in each row. -let step = |i| i; -// Here, we have a witness column. -let x; -// This functions returns the square of its input (classified as a fixed column). -let square = |x| x*x; -// A recursive function, taking a function and an integer as parameter -let sum = |f, i| match i { - 0 => f(0), - _ => f(i) + sum(f, i - 1) -}; -// The same function as "square" above, but employing a trick to avoid it -// being classified as a column. -let square_non_column = (|| |x| x*x)(); +{{#include ../../../test_data/pil/book/declarations.pil:declarations}} ``` \ No newline at end of file diff --git a/executor/src/constant_evaluator/mod.rs b/executor/src/constant_evaluator/mod.rs index c089e9ccc..1c4e14f9e 100644 --- a/executor/src/constant_evaluator/mod.rs +++ b/executor/src/constant_evaluator/mod.rs @@ -1,7 +1,10 @@ use std::{collections::HashMap, fmt::Display, rc::Rc}; use itertools::Itertools; -use powdr_ast::analyzed::{Analyzed, FunctionValueDefinition}; +use powdr_ast::analyzed::{ + types::{Type, TypedExpression}, + Analyzed, FunctionValueDefinition, +}; use powdr_number::{DegreeType, FieldElement}; use powdr_pil_analyzer::evaluator::{self, Custom, EvalError, SymbolLookup, Value}; use rayon::prelude::{IntoParallelIterator, ParallelIterator}; @@ -43,20 +46,25 @@ fn generate_values( }; // TODO we should maybe pre-compute some symbols here. let result = match body { - FunctionValueDefinition::Expression(e) => (0..degree) - .into_par_iter() - .map(|i| { - // We could try to avoid the first evaluation to be run for each iteration, - // but the data is not thread-safe. - let fun = evaluator::evaluate(e, &symbols).unwrap(); - evaluator::evaluate_function_call( - fun, - vec![Rc::new(Value::Integer(num_bigint::BigInt::from(i)))], - &symbols, - ) - .and_then(|v| v.try_to_field_element()) - }) - .collect::, _>>(), + FunctionValueDefinition::Expression(TypedExpression { e, ty }) => { + if let Some(ty) = ty { + assert_eq!(ty, &Type::col()) + }; + (0..degree) + .into_par_iter() + .map(|i| { + // We could try to avoid the first evaluation to be run for each iteration, + // but the data is not thread-safe. + let fun = evaluator::evaluate(e, &symbols).unwrap(); + evaluator::evaluate_function_call( + fun, + vec![Rc::new(Value::Integer(num_bigint::BigInt::from(i)))], + &symbols, + ) + .and_then(|v| v.try_to_field_element()) + }) + .collect::, _>>() + } FunctionValueDefinition::Array(values) => values .iter() .map(|elements| { @@ -103,8 +111,8 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, FixedColumnRef<'a>> for Symbols<'a Value::Custom(FixedColumnRef { name }) } else if let Some((_, value)) = self.analyzed.definitions.get(&name.to_string()) { match value { - Some(FunctionValueDefinition::Expression(value)) => { - evaluator::evaluate(value, self)? + Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => { + evaluator::evaluate(e, self)? } Some(_) => Err(EvalError::Unsupported( "Cannot evaluate arrays and queries.".to_string(), @@ -268,7 +276,7 @@ mod test { let src = r#" constant %N = 8; namespace F(%N); - let minus_one = [|x| x - 1][0]; + let minus_one: int -> int = |x| x - 1; pol constant EVEN(i) { 2 * minus_one(i) + 2 }; "#; let analyzed = analyze_string(src); diff --git a/importer/src/path_canonicalizer.rs b/importer/src/path_canonicalizer.rs index f0be06844..7a4420395 100644 --- a/importer/src/path_canonicalizer.rs +++ b/importer/src/path_canonicalizer.rs @@ -15,7 +15,8 @@ use powdr_ast::{ }, folder::Folder, visitor::ExpressionVisitable, - ArrayLiteral, FunctionCall, IndexAccess, LambdaExpression, MatchArm, + ArrayLiteral, ExpressionWithTypeName, FunctionCall, IndexAccess, LambdaExpression, + MatchArm, }, }; @@ -79,9 +80,14 @@ impl<'a, T> Folder for Canonicalizer<'a> { .map(Some) .transpose(), }, - SymbolValue::Expression(mut e) => { - canonicalize_inside_expression(&mut e, &self.path, self.paths); - Some(Ok(SymbolValue::Expression(e))) + SymbolValue::Expression(mut exp) => { + for tne in + exp.type_name.iter_mut().flat_map(|tn| tn.expressions_mut()) + { + canonicalize_inside_expression(tne, &self.path, self.paths); + } + canonicalize_inside_expression(&mut exp.e, &self.path, self.paths); + Some(Ok(SymbolValue::Expression(exp))) } } .map(|value| value.map(|value| SymbolDefinition { name, value }.into())) @@ -321,7 +327,10 @@ fn check_module( check_module(location.with_part(name), m, state)?; } SymbolValue::Import(s) => check_import(location.clone(), s.clone(), state)?, - SymbolValue::Expression(e) => { + SymbolValue::Expression(ExpressionWithTypeName { e, type_name }) => { + for tne in type_name.iter().flat_map(|tn| tn.expressions()) { + check_expression(&location, tne, state, &HashSet::default())? + } check_expression(&location, e, state, &HashSet::default())? } } diff --git a/linker/src/lib.rs b/linker/src/lib.rs index 862f7d97f..8b72ee79b 100644 --- a/linker/src/lib.rs +++ b/linker/src/lib.rs @@ -7,7 +7,7 @@ use powdr_ast::{ asm::AbsoluteSymbolPath, asm::SymbolPath, build::{direct_reference, index_access, namespaced_reference}, - Expression, PILFile, PilStatement, SelectedExpressions, + Expression, ExpressionWithTypeName, PILFile, PilStatement, SelectedExpressions, }, SourceRef, }; @@ -42,9 +42,14 @@ pub fn link(graph: PILGraph) -> Result, Vec expr = (|X| (X * (1 - X))); let one_hot = (|i, which| match i { which => 1, _ => 0, }); pol constant ISLAST(i) { one_hot(i, %last_row) }; pol commit arr[8]; @@ -440,5 +440,39 @@ namespace Fibonacci(%N); ); assert_eq!(input.trim(), printed.trim()); } + + #[test] + fn type_names_simple() { + let input = r#" + let a: col; + let b: int; + let c: fe; + let d: int[]; + let e: int[7]; + let f: (int, fe, fe[3])[2];"#; + let printed = format!( + "{}", + parse::(Some("input"), input).unwrap() + ); + assert_eq!(input.trim(), printed.trim()); + } + + #[test] + fn type_names_complex() { + let input = r#" + let a: int -> fe; + let b: int -> (); + let c: -> (); + let d: int, int -> fe; + let e: int, int -> (fe, int[2]); + let f: ((int, fe), fe[2] -> (fe -> int))[]; + let g: (int -> fe) -> int; + let h: int -> (fe -> int);"#; + let printed = format!( + "{}", + parse::(Some("input"), input).unwrap() + ); + assert_eq!(input.trim(), printed.trim()); + } } } diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 1cd55a1d8..e09af9e2e 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -57,7 +57,11 @@ Part: Part = { } LetStatementAtModuleLevel: SymbolDefinition = { - "let" "=" ";" => SymbolDefinition { name, value: SymbolValue::Expression(value) } + "let" )?> "=" ";" => + SymbolDefinition { + name, + value: SymbolValue::Expression(ExpressionWithTypeName{ e: value, type_name }) + } } // ---------------------------- PIL part ----------------------------- @@ -88,7 +92,7 @@ Namespace: PilStatement = { } LetStatement: PilStatement = { - "let" )?> "}" => Box::new(Expression::IfExpression(IfExpression{<>})) } +// ---------------------------- Type Names ----------------------------- + +TypeName: TypeName> = { + "->" => TypeName::Function(FunctionTypeName{<>}), + TypeNameTerm +} + +TypeNameTermList: Vec>> = { + => vec![], + "," )*> => { list.push(end); list } +} + +TypeNameTermBox: Box>> = { + TypeNameTerm => Box::new(<>) +} + +TypeNameTerm: TypeName> = { + "bool" => TypeName::Bool, + "int" => TypeName::Int, + "fe" => TypeName::Fe, + "string" => TypeName::String, + "col" => TypeName::Col, + "expr" => TypeName::Expr, + "constr" => TypeName::Constr, + "[" "]" => TypeName::Array(ArrayTypeName{base: Box::new(base), length}), + "(" "," )+> ")" => { items.push(end); TypeName::Tuple(TupleTypeName{items}) }, + "(" ")" => TypeName::Tuple(TupleTypeName{items: vec![]}), + "(" ")", +} + // ---------------------------- Terminals ----------------------------- @@ -561,6 +595,8 @@ SpecialIdentifier: &'input str = { "insn", "int", "fe", + "expr", + "constr", "bool", } diff --git a/pil-analyzer/src/condenser.rs b/pil-analyzer/src/condenser.rs index 71441e172..0ca92fbbf 100644 --- a/pil-analyzer/src/condenser.rs +++ b/pil-analyzer/src/condenser.rs @@ -6,6 +6,7 @@ use std::{collections::HashMap, fmt::Display, rc::Rc}; use itertools::Itertools; use powdr_ast::{ analyzed::{ + types::{Type, TypedExpression}, AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition, Identity, IdentityKind, PolynomialReference, PolynomialType, PublicDeclaration, Reference, StatementIdentifier, Symbol, SymbolKind, @@ -58,9 +59,10 @@ pub fn condense( let Some(FunctionValueDefinition::Expression(e)) = definition else { panic!("Expected expression") }; + assert!(e.ty.is_none() || e.ty == Some(Type::col())); Some(( name.clone(), - (symbol.clone(), condenser.condense_expression(e)), + (symbol.clone(), condenser.condense_expression(&e.e)), )) } else { None @@ -231,7 +233,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate> for &'a Condenser { + Some(FunctionValueDefinition::Expression(TypedExpression { e: value, ty: _ })) => { evaluator::evaluate(value, self)? } _ => Err(EvalError::Unsupported( @@ -278,8 +280,8 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate> for &'a Condenser { - let function = evaluate(v, self)?; + Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => { + let function = evaluate(e, self)?; evaluate_function_call(function, arguments, self) } None => Err(EvalError::SymbolNotFound(format!( @@ -390,7 +392,7 @@ impl Custom for Condensate { fn type_name(&self) -> String { match self { Condensate::Expression(_) => "expr".to_string(), - Condensate::Identity(_, _) => "identity".to_string(), + Condensate::Identity(_, _) => "constr".to_string(), } } } diff --git a/pil-analyzer/src/evaluator.rs b/pil-analyzer/src/evaluator.rs index 17131dd01..9b18ee5b5 100644 --- a/pil-analyzer/src/evaluator.rs +++ b/pil-analyzer/src/evaluator.rs @@ -6,7 +6,7 @@ use std::{ use itertools::Itertools; use powdr_ast::{ - analyzed::{Expression, FunctionValueDefinition, Reference, Symbol}, + analyzed::{types::TypedExpression, Expression, FunctionValueDefinition, Reference, Symbol}, parsed::{ display::quote, BinaryOperator, FunctionCall, LambdaExpression, MatchArm, MatchPattern, UnaryOperator, @@ -189,7 +189,7 @@ const BUILTINS: [(&str, BuiltinFunction); 6] = [ #[derive(Clone, Copy, PartialEq, Debug)] pub enum BuiltinFunction { - /// std::array::len: [_] -> int, returns the length of an array + /// std::array::len: _[] -> int, returns the length of an array ArrayLen, /// std::field::modulus: -> int, returns the field modulus as int Modulus, @@ -282,7 +282,9 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> { fn lookup(&self, name: &'a str) -> Result, EvalError> { Ok(match self.0.get(&name.to_string()) { Some((_, value)) => match value { - Some(FunctionValueDefinition::Expression(value)) => evaluate(value, self)?, + Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => { + evaluate(e, self)? + } _ => Err(EvalError::Unsupported( "Cannot evaluate arrays and queries.".to_string(), ))?, @@ -640,7 +642,8 @@ mod test { fn parse_and_evaluate_symbol(input: &str, symbol: &str) -> String { let analyzed = analyze_string::(input); - let Some(FunctionValueDefinition::Expression(symbol)) = &analyzed.definitions[symbol].1 + let Some(FunctionValueDefinition::Expression(TypedExpression { e: symbol, ty: _ })) = + &analyzed.definitions[symbol].1 else { panic!() }; @@ -686,8 +689,8 @@ mod test { #[test] pub fn capturing() { let src = r#"namespace Main(16); - let f = |n, g| match n { 99 => |i| n, 1 => g(3) }; - let result = f(1, f(99, |x| x + 3000)); + let f: int, (int -> int) -> (int -> int) = |n, g| match n { 99 => |i| n, 1 => g }; + let result = f(1, f(99, |x| x + 3000))(0); "#; // If the lambda function returned by the expression f(99, ...) does not // properly capture the value of n in a closure, then f(1, ...) would return 1. diff --git a/pil-analyzer/src/pil_analyzer.rs b/pil-analyzer/src/pil_analyzer.rs index 81bcecdd0..f51cc60ba 100644 --- a/pil-analyzer/src/pil_analyzer.rs +++ b/pil-analyzer/src/pil_analyzer.rs @@ -505,6 +505,58 @@ namespace N(16); let expected = r#"namespace N(16); let w = (|| 2); constant x = (|i| (|| N.w()))(2)(); +"#; + let formatted = analyze_string::(input).to_string(); + assert_eq!(formatted, expected); + } + + #[test] + fn simple_type_resolution() { + let input = r#"namespace N(16); + let w: col[3 + 4]; + "#; + let expected = r#"namespace N(16); + col witness w[7]; +"#; + let formatted = analyze_string::(input).to_string(); + assert_eq!(formatted, expected); + } + + #[test] + fn complex_type_resolution() { + let input = r#"namespace N(16); + let f: int -> int = |i| i + 10; + let x: (int -> int), int -> int = |k, i| k(2**i); + let y: (int -> fe)[x(f, 2)]; + let z: (((int -> int), int -> int)[x(|i| i, 3)], col) = ([x, x, x, x, x, x, x, x], y[0]); + "#; + let expected = r#"namespace N(16); + let f: int -> int = (|i| (i + 10)); + let x: (int -> int), int -> int = (|k, i| k((2 ** i))); + col witness y[14]; + let z: (((int -> int), int -> int)[8], col) = ([N.x, N.x, N.x, N.x, N.x, N.x, N.x, N.x], N.y[0]); +"#; + let formatted = analyze_string::(input).to_string(); + assert_eq!(formatted, expected); + } + + #[test] + fn expr_and_identity() { + let input = r#"namespace N(16); + let f: expr, expr -> constr[] = |x, y| [x == y]; + let g: expr -> constr[1] = |x| [x == 0]; + let x: col; + let y: col; + f(x, y); + g((x)); + "#; + let expected = r#"namespace N(16); + let f: expr, expr -> constr[] = (|x, y| [(x == y)]); + let g: expr -> constr[1] = (|x| [(x == 0)]); + col witness x; + col witness y; + N.x = N.y; + N.x = 0; "#; let formatted = analyze_string::(input).to_string(); assert_eq!(formatted, expected); diff --git a/pil-analyzer/src/statement_processor.rs b/pil-analyzer/src/statement_processor.rs index 4358c593c..61c242a41 100644 --- a/pil-analyzer/src/statement_processor.rs +++ b/pil-analyzer/src/statement_processor.rs @@ -1,8 +1,9 @@ use std::collections::{BTreeMap, HashMap}; use std::marker::PhantomData; +use powdr_ast::analyzed::types::{ArrayType, Type, TypedExpression}; use powdr_ast::parsed::{ - self, FunctionDefinition, PilStatement, PolynomialName, SelectedExpressions, + self, FunctionDefinition, PilStatement, PolynomialName, SelectedExpressions, TypeName, }; use powdr_ast::SourceRef; use powdr_number::{DegreeType, FieldElement}; @@ -103,8 +104,8 @@ where .handle_symbol_definition( source, name, - None, SymbolKind::Poly(PolynomialType::Intermediate), + Some(Type::col()), Some(FunctionDefinition::Expression(value)), ), PilStatement::PublicDeclaration(source, name, polynomial, array_index, index) => { @@ -117,8 +118,8 @@ where .handle_symbol_definition( source, name, - None, SymbolKind::Poly(PolynomialType::Constant), + Some(Type::col()), Some(definition), ), PilStatement::PolynomialCommitDeclaration(source, polynomials, None) => { @@ -130,12 +131,14 @@ where Some(definition), ) => { assert!(polynomials.len() == 1); - let name = polynomials.pop().unwrap(); + let (name, ty) = + self.name_and_type_from_polynomial_name(polynomials.pop().unwrap()); + self.handle_symbol_definition( source, - name.name, - name.array_size, + name, SymbolKind::Poly(PolynomialType::Committed), + ty, Some(definition), ) } @@ -147,24 +150,53 @@ where self.handle_symbol_definition( source, name, - None, SymbolKind::Constant(), + Some(Type::Fe), Some(FunctionDefinition::Expression(value)), ) } - PilStatement::LetStatement(source, name, value) => { - self.handle_generic_definition(source, name, value) + PilStatement::LetStatement(source, name, type_name, value) => { + self.handle_generic_definition(source, name, type_name, value) } _ => self.handle_identity_statement(statement), } } + fn name_and_type_from_polynomial_name( + &mut self, + PolynomialName { name, array_size }: PolynomialName, + ) -> (String, Option) { + let ty = Some(match array_size { + None => Type::col(), + Some(len) => { + let length = self + .evaluate_expression(len) + .map_err(|e| { + panic!("Error evaluating length of array of witness columns {name}:\n{e}") + }) + .map(|length| length.to_degree()) + .ok(); + Type::Array(ArrayType { + base: Box::new(Type::col()), + length, + }) + } + }); + (name, ty) + } + fn handle_generic_definition( &mut self, source: SourceRef, name: String, - value: Option<::powdr_ast::parsed::Expression>, + type_name: Option>>, + value: Option>, ) -> Vec> { + let ty = type_name.map(|n| + self.resolve_type_name(n.clone()) + .map_err(|e| panic!("Error evaluating expressions in type name \"{n}\" to reduce it to a type:\n{e})")) + .unwrap() + ); // Determine whether this is a fixed column, a constant or something else // depending on the structure of the value and if we can evaluate // it to a single number. @@ -172,30 +204,56 @@ where match value { None => { // No value provided => treat it as a witness column. + let ty = ty + .map(|t| { + if let Type::Array(ArrayType { base, length }) = &t { + if base.as_ref() != &Type::col() { + panic!("Symbol {name} is declared without value and thus must be a witness column array, but its type is {t} instead of col[]."); + } + if length.is_none() { + panic!("Explicit array length required for column {name}: {t}"); + } + t + } else { + if t != Type::col() { + panic!("Symbol {name} is declared without value and thus must be a witness column, but its type is {t} instead of col."); + } + t + } + }) + .unwrap_or(Type::col()); self.handle_symbol_definition( source, name, - None, SymbolKind::Poly(PolynomialType::Committed), + Some(ty), None, ) } Some(value) => { - let symbol_kind = if matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1) + // TODO if we have proper type deduction here in the future, we can rely only on the type. + let (ty, symbol_kind) = if ty == Some(Type::col()) + || (ty.is_none() + && matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1)) + { + ( + Some(Type::col()), + SymbolKind::Poly(PolynomialType::Constant), + ) + } else if ty == Some(Type::Fe) + || (ty.is_none() && self.evaluate_expression(value.clone()).is_ok()) { - SymbolKind::Poly(PolynomialType::Constant) - } else if self.evaluate_expression(value.clone()).is_ok() { // Value evaluates to a constant number => treat it as a constant - SymbolKind::Constant() + (Some(Type::Fe), SymbolKind::Constant()) } else { // Otherwise, treat it as "generic definition" - SymbolKind::Other() + (ty, SymbolKind::Other()) }; self.handle_symbol_definition( source, name, - None, symbol_kind, + ty, Some(FunctionDefinition::Expression(value)), ) } @@ -261,12 +319,13 @@ where ) -> Vec> { polynomials .into_iter() - .flat_map(|PolynomialName { name, array_size }| { + .flat_map(|poly_name| { + let (name, ty) = self.name_and_type_from_polynomial_name(poly_name); self.handle_symbol_definition( source.clone(), name, - array_size, SymbolKind::Poly(polynomial_type), + ty, None, ) }) @@ -277,17 +336,20 @@ where &mut self, source: SourceRef, name: String, - array_size: Option<::powdr_ast::parsed::Expression>, symbol_kind: SymbolKind, + ty: Option, value: Option>, ) -> Vec> { - let have_array_size = array_size.is_some(); - let length = array_size - .map(|l| self.evaluate_expression(l).unwrap()) - .map(|l| l.to_degree()); - if length.is_some() { - assert!(value.is_none()); - } + let length = ty.as_ref().and_then(|t| { + if let Type::Array(ArrayType { length, base: _ }) = t { + if length.is_none() && symbol_kind != SymbolKind::Other() { + panic!("Explicit array length required for column {name}."); + } + *length + } else { + None + } + }); let id = self.counters.dispense_symbol_id(symbol_kind, length); let name = self.driver.resolve_decl(&name); let symbol = Symbol { @@ -300,13 +362,15 @@ where let value = value.map(|v| match v { FunctionDefinition::Expression(expr) => { - assert!(!have_array_size); assert!(symbol_kind != SymbolKind::Poly(PolynomialType::Committed)); - FunctionValueDefinition::Expression(self.process_expression(expr)) + FunctionValueDefinition::Expression(TypedExpression { + e: self.process_expression(expr), + ty, + }) } FunctionDefinition::Query(expr) => { - assert!(!have_array_size); assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed)); + assert!(ty.is_none() || ty == Some(Type::col())); FunctionValueDefinition::Query(self.process_expression(expr)) } FunctionDefinition::Array(value) => { @@ -318,6 +382,7 @@ where expression.iter().map(|e| e.size()).sum::(), self.degree.unwrap() ); + assert!(ty.is_none() || ty == Some(Type::col())); FunctionValueDefinition::Array(expression) } }); @@ -351,10 +416,18 @@ where })] } - fn evaluate_expression( - &self, - expr: ::powdr_ast::parsed::Expression, - ) -> Result { + /// Resolves a type name into a concrete type. + /// This routine mainly evaluates array length expressions. + fn resolve_type_name(&self, mut n: TypeName>) -> Result { + // Replace all expressions by number literals. + for e in n.expressions_mut() { + let v = self.evaluate_expression(e.clone())?; + *e = parsed::Expression::Number(v); + } + Ok(n.into()) + } + + fn evaluate_expression(&self, expr: parsed::Expression) -> Result { evaluator::evaluate_expression( &ExpressionProcessor::new(self.driver).process_expression(expr), self.driver.definitions(), @@ -366,13 +439,13 @@ where ExpressionProcessor::new(self.driver) } - fn process_expression(&self, expr: ::powdr_ast::parsed::Expression) -> Expression { + fn process_expression(&self, expr: parsed::Expression) -> Expression { self.expression_processor().process_expression(expr) } fn process_selected_expressions( &self, - expr: ::powdr_ast::parsed::SelectedExpressions<::powdr_ast::parsed::Expression>, + expr: parsed::SelectedExpressions>, ) -> SelectedExpressions> { self.expression_processor() .process_selected_expressions(expr) diff --git a/std/arith.asm b/std/arith.asm index 5067bcefd..8104f68b0 100644 --- a/std/arith.asm +++ b/std/arith.asm @@ -137,11 +137,11 @@ machine Arith(CLK32_31, operation_id){ /// returns |n| a(0) * b(n) + ... + a(n) * b(0) let product = |a, b| |n| dot_prod(n + 1, a, |i| b(n - i)); /// Converts array to function, extended by zeros. - let array_as_fun = [|arr| |i| if 0 <= i && i < array::len(arr) { + let array_as_fun: expr[] -> (int -> expr) = |arr| |i| if 0 <= i && i < array::len(arr) { arr[i] } else { 0 - }][0]; + }; let shift_right = |fn, amount| |i| fn(i - amount); let x1f = array_as_fun(x1); @@ -151,12 +151,11 @@ machine Arith(CLK32_31, operation_id){ let y3f = array_as_fun(y3); // Defined for arguments from 0 to 31 (inclusive) - let eq0 = (|| |nr| + let eq0: int -> expr = |nr| product(x1f, y1f)(nr) + x2f(nr) - shift_right(y2f, 16)(nr) - - y3f(nr) - )(); + - y3f(nr); // Note that Polygon uses a single 22-Bit column. However, this approach allows for a lower degree (2**16) // while still preventing overflows: The 32-bit carry gets added to 32 16-Bit values, which can't overflow diff --git a/std/array.asm b/std/array.asm index 724fa42fa..19cca4b64 100644 --- a/std/array.asm +++ b/std/array.asm @@ -13,4 +13,5 @@ let map = |arr, f| new(len(arr), |i| f(arr[i])); let fold = |arr, initial, folder| std::utils::fold(len(arr), |i| arr[i], initial, folder); /// Returns the sum of the array elements. -let sum = [|arr| fold(arr, 0, |a, b| a + b)][0]; \ No newline at end of file +/// This actually also works on field elements, so the type is currently too restrictive. +let sum: int[] -> int = |arr| fold(arr, 0, |a, b| a + b); \ No newline at end of file diff --git a/std/debug.asm b/std/debug.asm index e7f90a10c..3ea28f6b0 100644 --- a/std/debug.asm +++ b/std/debug.asm @@ -2,6 +2,6 @@ /// when evaluated. /// It returns an empty array so that it can be used at constraint level. /// This symbol is not an empty array, the actual semantics are overridden. -let print = []; +let print: string -> constr[] = []; -let println = [|msg| print(msg + "\n")][0]; \ No newline at end of file +let println: string -> constr[] = |msg| print(msg + "\n"); \ No newline at end of file diff --git a/std/utils.asm b/std/utils.asm index 134ffbcad..ae251d2b6 100644 --- a/std/utils.asm +++ b/std/utils.asm @@ -22,4 +22,4 @@ let sum = |length, f| fold(length, f, 0, |acc, e| (acc + e)); let unchanged_until = |c, latch| (c' - c) * (1 - latch) == 0; /// Evaluates to a constraint that forces `c` to be either 0 or 1. -let force_bool = [|c| c * (1 - c) == 0][0]; \ No newline at end of file +let force_bool: expr -> constr = |c| c * (1 - c) == 0; \ No newline at end of file diff --git a/test_data/asm/pil_at_module_level.asm b/test_data/asm/pil_at_module_level.asm index abb515a35..e68b29766 100644 --- a/test_data/asm/pil_at_module_level.asm +++ b/test_data/asm/pil_at_module_level.asm @@ -26,11 +26,12 @@ mod R { machine FullConstant { degree 2; - let C = |i| match i % 2 { + let C: int -> fe = |i| match i % 2 { 0 => x, 1 => y, }; - col commit w[2]; + // Use some weird type just for the sake of it. + let w: col[sum(2, |i| 1)]; // This and the next line are the same. super::utils::sum(2, |i| w[i]) == 8; diff --git a/test_data/pil/arith_improved.pil b/test_data/pil/arith_improved.pil index 7ae10e4f1..232db3121 100644 --- a/test_data/pil/arith_improved.pil +++ b/test_data/pil/arith_improved.pil @@ -29,8 +29,7 @@ namespace Arith(N); /// returns f(0) + f(1) + ... + f(length - 1) let sum = |length, f| fold(length, f, 0, |acc, e| acc + e); - // TODO the weird syntax is needed so that this is not classified as a constant column - let force_boolean = (|| |x| x * (1 - x) == 0)(); + let force_boolean: expr -> constr = |x| x * (1 - x) == 0; let clock = |j, row| if row % 32 == j { 1 } else { 0 }; // Arrays of fixed columns are not supported yet. @@ -140,8 +139,7 @@ namespace Arith(N); // That way we could even support functions returning lookups. // x can only change between two blocks of 32 rows. - // TODO the weird syntax is needed so that this is not classified as a fixed column. - let fixed_inside_32_block = (|| |x| (x - x') * (1 - CLK32[31]) == 0)(); + let fixed_inside_32_block: expr -> constr = |x| (x - x') * (1 - CLK32[31]) == 0; make_array(16, |i| fixed_inside_32_block(x1[i])); make_array(16, |i| fixed_inside_32_block(y1[i])); @@ -209,12 +207,11 @@ namespace Arith(N); let q2f = array_as_fun(q2, 16); // Defined for arguments from 0 to 31 (inclusive) - let eq0 = (|| |nr| + let eq0: int -> expr = |nr| product(x1f, y1f)(nr) + x2f(nr) - shift_right(y2f, 16)(nr) - - y3f(nr) - )(); + - y3f(nr); /******* @@ -224,16 +221,16 @@ namespace Arith(N); *******/ // 0xffffffffffffffffffffffffffffffffffffffffffffffffffff fffe ffff fc2f - let p = array_as_fun([ + let p: col = array_as_fun([ 0xfc2f, 0xffff, 0xfffe, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff ], 16); - // The "- 4 * shift_right(p, 16)" effectively subtracts 4 * (p << 16 * 16) = 2 ** 258 * p - // As a result, the term computes `(x - 2 ** 258) * p`. - let product_with_p = (|| |x| |nr| product(p, x)(nr) - 4 * shift_right(p, 16)(nr))(); + // The "- 4 * shift_right(p, 16)" effectively subtracts 4 * (p << 16 * 16) = 2 ** 258 * p + // As a result, the term computes `(x - 2 ** 258) * p`. + let product_with_p: int -> (int -> expr) = |x| |nr| product(p, x)(nr) - 4 * shift_right(p, 16)(nr); - let eq1 = (|| |nr| product(sf, x2f)(nr) - product(sf, x1f)(nr) - y2f(nr) + y1f(nr) + product_with_p(q0f)(nr))(); + let eq1: int -> expr = |nr| product(sf, x2f)(nr) - product(sf, x1f)(nr) - y2f(nr) + y1f(nr) + product_with_p(q0f)(nr); /******* * @@ -241,7 +238,7 @@ namespace Arith(N); * *******/ - let eq2 = (|| |nr| 2 * product(sf, y1f)(nr) - 3 * product(x1f, x1f)(nr) + product_with_p(q0f)(nr))(); + let eq2: int -> expr = |nr| 2 * product(sf, y1f)(nr) - 3 * product(x1f, x1f)(nr) + product_with_p(q0f)(nr); /******* * @@ -249,7 +246,7 @@ namespace Arith(N); * *******/ - let eq3 = (|| |nr| product(sf, sf)(nr) - x1f(nr) - x2f(nr) - x3f(nr) + product_with_p(q1f)(nr))(); + let eq3: int -> expr = |nr| product(sf, sf)(nr) - x1f(nr) - x2f(nr) - x3f(nr) + product_with_p(q1f)(nr); /******* @@ -258,7 +255,7 @@ namespace Arith(N); * *******/ - let eq4 = (|| |nr| product(sf, x1f)(nr) - product(sf, x3f)(nr) - y1f(nr) - y3f(nr) + product_with_p(q2f)(nr))(); + let eq4: int -> expr = |nr| product(sf, x1f)(nr) - product(sf, x3f)(nr) - y1f(nr) - y3f(nr) + product_with_p(q2f)(nr); pol commit selEq[4]; diff --git a/test_data/pil/book/declarations.pil b/test_data/pil/book/declarations.pil new file mode 100644 index 000000000..de4020fa2 --- /dev/null +++ b/test_data/pil/book/declarations.pil @@ -0,0 +1,30 @@ +namespace Main(16); +// ANCHOR: declarations + // This defines a constant + let rows = 2**16; + // This defines a fixed column that contains the row number in each row. + let step = |i| i; + // This defines a copy of the column, also a fixed column because the type + // is explicitly specified. + let also_step: col = step; + // Here, we have a witness column. + let x; + // This functions defines a fixed column where each cell contains the + // square of its row number. + let square = |x| x*x; + // The same function as `square` above, but now its type is given as + // `int -> int` and thus it is *not* classified as a column. Instead, + // it is stored as a utility function. If utility functions are + // referenced in constraints, they have to be evaluated, meaning that + // the constraint `w = square_non_column;` is invalid but both + // `w = square_non_column(7);` and `w = square;` are valid constraints. + let square_non_column: int -> int = |x| x*x; + // A recursive function, taking a function and an integer as parameter + let sum = |f, i| match i { + 0 => f(0), + _ => f(i) + sum(f, i - 1) + }; +// ANCHOR_END: declarations + // We need at least one constraint to create a proof in the test. + let w; + w + square = 0; \ No newline at end of file diff --git a/test_data/pil/book/generic_to_algebraic.pil b/test_data/pil/book/generic_to_algebraic.pil index d9cfc5674..9ee4a981a 100644 --- a/test_data/pil/book/generic_to_algebraic.pil +++ b/test_data/pil/book/generic_to_algebraic.pil @@ -6,20 +6,17 @@ namespace Main(16); }; // returns f(0) + f(1) + ... + f(length - 1) let sum = |length, f| fold(length, f, 0, |acc, e| acc + e); - // If called with a single value, this function evaluates the equality, - // otherwise, it returns a constraint (if called with a column or - // an algebraic expression). - // If we write "|x| x == 20", it will be classified as a fixed column, - // so we use a trick that makes it not look like a function with a single - // parameter. - let equals_twenty = [|x| x == 20][0]; - // declares an array of 16 witness columns. + // This function takes an algebraic expression (a column or expression + // involving columns) and returns an identity that forces this expression + // to equal 20. + let equals_twenty: expr -> constr = |x| x == 20; + // This declares an array of 16 witness columns. col witness wit[16]; - // This expression has to evaluate to a constraint, but we can still use + // This expression has to evaluate to an identity, but we can still use // higher order functions and all the flexibility of the language. - // The sub-expression "sum(16, |i| wit[i]" evaluates to the algebraic + // The sub-expression `sum(16, |i| wit[i])` evaluates to the algebraic // expression "wit[0] + wit[1] + ... + wit[15]", which is then - // turned by "equals_twenty" into the constraint + // turned into the identity by `equals_twenty` // wit[0] + wit[1] + ... + wit[15] == 20. equals_twenty(sum(16, |i| wit[i]));