Merge pull request #614 from powdr-labs/arrays

Support types for declarations
This commit is contained in:
Leo
2024-02-06 10:08:46 +00:00
committed by GitHub
30 changed files with 1028 additions and 187 deletions

View File

@@ -10,7 +10,10 @@ use std::{
use itertools::Itertools;
use self::parsed::asm::{AbsoluteSymbolPath, SymbolPath};
use self::{
parsed::asm::{AbsoluteSymbolPath, SymbolPath},
types::{ArrayType, FunctionType, TupleType, Type},
};
use super::*;
@@ -47,7 +50,18 @@ impl<T: Display> Display for Analyzed<T> {
};
write!(f, " col {kind}{name}")?;
if let Some(length) = symbol.length {
write!(f, "[{length}]")?;
if let PolynomialType::Committed = poly_type {
write!(f, "[{length}]")?;
assert!(definition.is_none());
} else {
// Do not print an array size, because we will do it as part of the type.
assert!(matches!(
definition,
Some(FunctionValueDefinition::Expression(
TypedExpression { e: _, ty: Some(_) }
))
));
}
}
if let Some(value) = definition {
writeln!(f, "{value};")?
@@ -57,11 +71,18 @@ impl<T: Display> Display for Analyzed<T> {
}
SymbolKind::Constant() => {
let indentation = if is_local { " " } else { "" };
writeln!(
f,
"{indentation}constant {name}{};",
definition.as_ref().unwrap()
)?;
let Some(FunctionValueDefinition::Expression(TypedExpression {
e,
ty: Some(Type::Fe),
})) = &definition
else {
panic!(
"Invalid constant value: {}",
definition.as_ref().unwrap()
);
};
writeln!(f, "{indentation}constant {name} = {e};",)?;
}
SymbolKind::Other() => {
write!(f, " let {name}")?;
@@ -108,7 +129,17 @@ impl<T: Display> Display for FunctionValueDefinition<T> {
write!(f, " = {}", items.iter().format(" + "))
}
FunctionValueDefinition::Query(e) => format_outer_function(e, Some("query"), f),
FunctionValueDefinition::Expression(e) => format_outer_function(e, None, f),
FunctionValueDefinition::Expression(TypedExpression { e, ty: None }) => {
format_outer_function(e, None, f)
}
FunctionValueDefinition::Expression(TypedExpression { e, ty: Some(ty) })
if *ty == Type::col() =>
{
format_outer_function(e, None, f)
}
FunctionValueDefinition::Expression(TypedExpression { e, ty: Some(ty) }) => {
write!(f, ": {ty} = {e}")
}
}
}
}
@@ -248,3 +279,65 @@ impl Display for PolynomialReference {
write!(f, "{}", self.name,)
}
}
impl Display for Type {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
Type::Bool => write!(f, "bool"),
Type::Int => write!(f, "int"),
Type::Fe => write!(f, "fe"),
Type::String => write!(f, "string"),
Type::Expr => write!(f, "expr"),
Type::Constr => write!(f, "constr"),
Type::Array(ar) => write!(f, "{ar}"),
Type::Tuple(tu) => write!(f, "{tu}"),
Type::Function(fun) => write!(f, "{fun}"),
}
}
}
impl Display for ArrayType {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
let length = self.length.iter().format("");
if self.base.needs_parentheses() {
write!(f, "({})[{length}]", self.base)
} else {
write!(f, "{}[{length}]", self.base)
}
}
}
impl Display for TupleType {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "({})", format_list_of_types(&self.items))
}
}
impl Display for FunctionType {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
if *self == Self::col() {
write!(f, "col")
} else {
write!(
f,
"{} -> {}",
format_list_of_types(&self.params),
self.value
)
}
}
}
fn format_list_of_types(types: &[Type]) -> String {
types
.iter()
.map(|x| {
if x.needs_parentheses() {
format!("({x})")
} else {
x.to_string()
}
})
.format(", ")
.to_string()
}

View File

@@ -1,4 +1,5 @@
mod display;
pub mod types;
pub mod visitor;
use core::hash::Hash;
@@ -15,6 +16,8 @@ pub use crate::parsed::UnaryOperator;
use crate::parsed::{self, SelectedExpressions};
use crate::SourceRef;
use self::types::TypedExpression;
#[derive(Debug, Clone)]
pub enum StatementIdentifier {
/// Either an intermediate column or a definition.
@@ -303,7 +306,9 @@ impl<T> Analyzed<T> {
.iter_mut()
.flat_map(|e| e.pattern.iter_mut())
.for_each(|e| e.post_visit_expressions_mut(f)),
Some(FunctionValueDefinition::Expression(e)) => e.post_visit_expressions_mut(f),
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
e.post_visit_expressions_mut(f)
}
None => {}
});
}
@@ -467,7 +472,7 @@ pub enum SymbolKind {
pub enum FunctionValueDefinition<T> {
Array(Vec<RepeatedArray<T>>),
Query(Expression<T>),
Expression(Expression<T>),
Expression(TypedExpression<T>),
}
/// An array of elements that might be repeated.

138
ast/src/analyzed/types.rs Normal file
View File

@@ -0,0 +1,138 @@
use std::fmt::Display;
use powdr_number::FieldElement;
use crate::parsed::{ArrayTypeName, Expression, FunctionTypeName, TupleTypeName, TypeName};
use super::Reference;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct TypedExpression<T, Ref = Reference> {
pub e: Expression<T, Ref>,
pub ty: Option<Type>,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum Type {
/// Boolean
Bool,
/// Integer (arbitrary precision)
Int,
/// Field element (unspecified field)
Fe,
/// String
String,
/// Algebraic expression
Expr,
/// Polynomial identity or lookup (not yet supported)
Constr,
Array(ArrayType),
Tuple(TupleType),
Function(FunctionType),
}
impl Type {
/// Returns the column type `int -> fe`.
pub fn col() -> Self {
Type::Function(FunctionType::col())
}
/// Returns true if the type name needs parentheses around it during formatting
/// when used inside a complex expression.
pub fn needs_parentheses(&self) -> bool {
match self {
Type::Bool
| Type::Int
| Type::Fe
| Type::String
| Type::Expr
| Type::Constr
| Type::Array(_)
| Type::Tuple(_) => false,
Type::Function(fun) => fun.needs_parentheses(),
}
}
}
impl<T: FieldElement, Ref: Display> From<TypeName<Expression<T, Ref>>> for Type {
fn from(value: TypeName<Expression<T, Ref>>) -> Self {
match value {
TypeName::Bool => Type::Bool,
TypeName::Int => Type::Int,
TypeName::Fe => Type::Fe,
TypeName::String => Type::String,
TypeName::Expr => Type::Expr,
TypeName::Constr => Type::Constr,
TypeName::Col => Type::col(),
TypeName::Array(ar) => Type::Array(ar.into()),
TypeName::Tuple(tu) => Type::Tuple(tu.into()),
TypeName::Function(fun) => Type::Function(fun.into()),
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct ArrayType {
pub base: Box<Type>,
pub length: Option<u64>,
}
impl<T: FieldElement, Ref: Display> From<ArrayTypeName<Expression<T, Ref>>> for ArrayType {
fn from(name: ArrayTypeName<Expression<T, Ref>>) -> Self {
let length = name.length.as_ref().map(|l| {
if let Expression::Number(n) = l {
n.to_degree()
} else {
panic!(
"Array length expression not resolved in type name prior to conversion: {name}"
);
}
});
ArrayType {
base: Box::new(Type::from(*name.base)),
length,
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct TupleType {
pub items: Vec<Type>,
}
impl<T: FieldElement, Ref: Display> From<TupleTypeName<Expression<T, Ref>>> for TupleType {
fn from(value: TupleTypeName<Expression<T, Ref>>) -> Self {
TupleType {
items: value.items.into_iter().map(Into::into).collect(),
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct FunctionType {
pub params: Vec<Type>,
pub value: Box<Type>,
}
impl FunctionType {
/// Returns the column type `int -> fe`.
pub fn col() -> Self {
FunctionType {
params: vec![Type::Int],
value: Box::new(Type::Fe),
}
}
pub fn needs_parentheses(&self) -> bool {
*self != Self::col()
}
}
impl<T: FieldElement, Ref: Display> From<FunctionTypeName<Expression<T, Ref>>> for FunctionType {
fn from(name: FunctionTypeName<Expression<T, Ref>>) -> Self {
FunctionType {
params: name.params.into_iter().map(Into::into).collect(),
value: Box::new(Type::from(*name.value)),
}
}
}

View File

@@ -86,7 +86,8 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
F: FnMut(&mut Expression<T>) -> ControlFlow<B>,
{
match self {
FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => {
FunctionValueDefinition::Query(e)
| FunctionValueDefinition::Expression(TypedExpression { e, ty: _ }) => {
e.visit_expressions_mut(f, o)
}
FunctionValueDefinition::Array(array) => array
@@ -101,7 +102,8 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
F: FnMut(&Expression<T>) -> ControlFlow<B>,
{
match self {
FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => {
FunctionValueDefinition::Query(e)
| FunctionValueDefinition::Expression(TypedExpression { e, ty: _ }) => {
e.visit_expressions(f, o)
}
FunctionValueDefinition::Array(array) => array

View File

@@ -7,7 +7,10 @@ use itertools::Itertools;
use crate::{
indent,
parsed::asm::{AbsoluteSymbolPath, Part},
parsed::{
asm::{AbsoluteSymbolPath, Part},
ExpressionWithTypeName,
},
write_indented_by, write_items_indented,
};
@@ -44,9 +47,15 @@ impl<T: Display> Display for AnalysisASMFile<T> {
Item::Machine(machine) => {
write_indented_by(f, format!("machine {name}{machine}"), current_path.len())?;
}
Item::Expression(expression) => write_indented_by(
Item::Expression(ExpressionWithTypeName { e, type_name }) => write_indented_by(
f,
format!("let {name} = {expression};\n"),
format!(
"let {name}{} = {e};\n",
type_name
.as_ref()
.map(|tn| format!(": {tn}"))
.unwrap_or_default()
),
current_path.len(),
)?,
}

View File

@@ -18,7 +18,7 @@ use crate::parsed::{
AbsoluteSymbolPath, AssignmentRegister, CallableRef, InstructionBody, OperationId, Params,
},
visitor::{ExpressionVisitable, VisitOrder},
NamespacedPolynomialReference, PilStatement,
ExpressionWithTypeName, NamespacedPolynomialReference, PilStatement,
};
use crate::SourceRef;
@@ -672,7 +672,7 @@ pub struct SubmachineDeclaration {
#[derive(Clone, Debug)]
pub enum Item<T> {
Machine(Machine<T>),
Expression(Expression<T>),
Expression(ExpressionWithTypeName<T>),
}
impl<T> Item<T> {

View File

@@ -1,5 +1,7 @@
use std::fmt::{Display, Formatter, Result};
use crate::parsed::ExpressionWithTypeName;
use super::{Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph};
impl Display for Location {
@@ -11,8 +13,15 @@ impl Display for Location {
impl<T: Display> Display for PILGraph<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
writeln!(f, "// Utilities")?;
for (name, e) in &self.definitions {
writeln!(f, "let {name} = {e};")?;
for (name, ExpressionWithTypeName { e, type_name }) in &self.definitions {
writeln!(
f,
"let {name}{} = {e};",
type_name
.as_ref()
.map(|tn| format!(": {tn}"))
.unwrap_or_default()
)?;
}
for (location, object) in &self.objects {
writeln!(f, "// Object {}", location)?;

View File

@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
use crate::parsed::{
asm::{AbsoluteSymbolPath, Params},
Expression, PilStatement,
Expression, ExpressionWithTypeName, PilStatement,
};
mod display;
@@ -30,7 +30,7 @@ pub struct PILGraph<T> {
pub main: Machine,
pub entry_points: Vec<Operation<T>>,
pub objects: BTreeMap<Location, Object<T>>,
pub definitions: BTreeMap<AbsoluteSymbolPath, Expression<T>>,
pub definitions: BTreeMap<AbsoluteSymbolPath, ExpressionWithTypeName<T>>,
}
#[derive(Default, Clone)]

View File

@@ -11,7 +11,7 @@ use derive_more::From;
use crate::SourceRef;
use super::{Expression, PilStatement};
use super::{Expression, ExpressionWithTypeName, PilStatement};
#[derive(Default, Clone, Debug, PartialEq, Eq)]
pub struct ASMProgram<T> {
@@ -51,7 +51,7 @@ pub enum SymbolValue<T> {
/// A module definition
Module(Module<T>),
/// A generic symbol / function.
Expression(Expression<T>),
Expression(ExpressionWithTypeName<T>),
}
impl<T> SymbolValue<T> {
@@ -74,7 +74,7 @@ pub enum SymbolValueRef<'a, T> {
/// A module definition
Module(ModuleRef<'a, T>),
/// A generic symbol / function.
Expression(&'a Expression<T>),
Expression(&'a ExpressionWithTypeName<T>),
}
#[derive(Debug, Clone, PartialEq, Eq, From)]

View File

@@ -57,8 +57,15 @@ impl<T: Display> Display for ModuleStatement<T> {
SymbolValue::Module(m @ Module::Local(_)) => {
write!(f, "mod {name} {m}")
}
SymbolValue::Expression(e) => {
write!(f, "let {name} = {e};")
SymbolValue::Expression(ExpressionWithTypeName { e, type_name }) => {
write!(
f,
"let {name}{} = {e};",
type_name
.as_ref()
.map(|t| format!(": {t}"))
.unwrap_or_default()
)
}
},
}
@@ -373,9 +380,15 @@ impl<T: Display> Display for PilStatement<T> {
PilStatement::Namespace(_, name, poly_length) => {
write!(f, "namespace {name}({poly_length});")
}
PilStatement::LetStatement(_, name, None) => write!(f, " let {name};"),
PilStatement::LetStatement(_, name, Some(expr)) => {
write!(f, " let {name} = {expr};")
PilStatement::LetStatement(_, name, type_name, value) => {
write!(f, " let {name}")?;
if let Some(type_name) = type_name {
write!(f, ": {type_name}")?;
}
if let Some(value) = &value {
write!(f, " = {value}")?;
}
write!(f, ";")
}
PilStatement::PolynomialDefinition(_, name, value) => {
write!(f, " pol {name} = {value};")
@@ -584,6 +597,74 @@ impl Display for UnaryOperator {
}
}
impl<E: Display> Display for TypeName<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
TypeName::Bool => write!(f, "bool"),
TypeName::Int => write!(f, "int"),
TypeName::Fe => write!(f, "fe"),
TypeName::String => write!(f, "string"),
TypeName::Col => write!(f, "col"),
TypeName::Expr => write!(f, "expr"),
TypeName::Constr => write!(f, "constr"),
TypeName::Array(array) => write!(f, "{array}"),
TypeName::Tuple(tuple) => write!(f, "{tuple}"),
TypeName::Function(fun) => write!(f, "{fun}"),
}
}
}
impl<E: Display> Display for ArrayTypeName<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
if self.base.needs_parentheses() {
write!(f, "({})", self.base)
} else {
write!(f, "{}", self.base)
}?;
write!(
f,
"[{}]",
self.length
.as_ref()
.map(|l| l.to_string())
.unwrap_or_default()
)
}
}
impl<E: Display> Display for TupleTypeName<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(f, "({})", self.items.iter().format(", "))
}
}
impl<E: Display> Display for FunctionTypeName<E> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
let params = self
.params
.iter()
.map(|x| {
if x.needs_parentheses() {
format!("({x})")
} else {
format!("{x}")
}
})
.join(", ")
+ if self.params.is_empty() { "" } else { " " };
write!(
f,
"{params}-> {}",
if self.value.needs_parentheses() {
format!("({})", self.value)
} else {
format!("{}", self.value)
}
)
}
}
#[cfg(test)]
mod tests {

View File

@@ -24,7 +24,12 @@ pub enum PilStatement<T> {
Include(SourceRef, String),
/// Name of namespace and polynomial degree (constant)
Namespace(SourceRef, SymbolPath, Expression<T>),
LetStatement(SourceRef, String, Option<Expression<T>>),
LetStatement(
SourceRef,
String,
Option<TypeName<Expression<T>>>,
Option<Expression<T>>,
),
PolynomialDefinition(SourceRef, String, Expression<T>),
PublicDeclaration(
SourceRef,
@@ -68,7 +73,7 @@ impl<T> PilStatement<T> {
| PilStatement::PolynomialConstantDefinition(_, name, _)
| PilStatement::ConstantDefinition(_, name, _)
| PilStatement::PublicDeclaration(_, name, _, _, _)
| PilStatement::LetStatement(_, name, _) => Box::new(once(name)),
| PilStatement::LetStatement(_, name, _, _) => Box::new(once(name)),
PilStatement::PolynomialConstantDeclaration(_, polynomials)
| PilStatement::PolynomialCommitDeclaration(_, polynomials, _) => {
Box::new(polynomials.iter().map(|p| &p.name))
@@ -98,8 +103,11 @@ impl<T> PilStatement<T> {
| PilStatement::Namespace(_, _, e)
| PilStatement::PolynomialDefinition(_, _, e)
| PilStatement::PolynomialIdentity(_, e)
| PilStatement::ConstantDefinition(_, _, e)
| PilStatement::LetStatement(_, _, Some(e)) => Box::new(once(e)),
| PilStatement::ConstantDefinition(_, _, e) => Box::new(once(e)),
PilStatement::LetStatement(_, _, type_name, value) => {
Box::new(type_name.iter().flat_map(|t| t.expressions()).chain(value))
}
PilStatement::PublicDeclaration(_, _, _, i, e) => Box::new(i.iter().chain(once(e))),
@@ -107,8 +115,7 @@ impl<T> PilStatement<T> {
| PilStatement::PolynomialCommitDeclaration(_, _, Some(fundef)) => fundef.expressions(),
PilStatement::PolynomialCommitDeclaration(_, _, None)
| PilStatement::Include(_, _)
| PilStatement::PolynomialConstantDeclaration(_, _)
| PilStatement::LetStatement(_, _, None) => Box::new(empty()),
| PilStatement::PolynomialConstantDeclaration(_, _) => Box::new(empty()),
}
}
@@ -126,8 +133,14 @@ impl<T> PilStatement<T> {
| PilStatement::Namespace(_, _, e)
| PilStatement::PolynomialDefinition(_, _, e)
| PilStatement::PolynomialIdentity(_, e)
| PilStatement::ConstantDefinition(_, _, e)
| PilStatement::LetStatement(_, _, Some(e)) => Box::new(once(e)),
| PilStatement::ConstantDefinition(_, _, e) => Box::new(once(e)),
PilStatement::LetStatement(_, _, type_name, value) => Box::new(
type_name
.iter_mut()
.flat_map(|t| t.expressions_mut())
.chain(value),
),
PilStatement::PublicDeclaration(_, _, _, i, e) => Box::new(i.iter_mut().chain(once(e))),
@@ -137,8 +150,7 @@ impl<T> PilStatement<T> {
}
PilStatement::PolynomialCommitDeclaration(_, _, None)
| PilStatement::Include(_, _)
| PilStatement::PolynomialConstantDeclaration(_, _)
| PilStatement::LetStatement(_, _, None) => Box::new(empty()),
| PilStatement::PolynomialConstantDeclaration(_, _) => Box::new(empty()),
}
}
}
@@ -496,3 +508,141 @@ impl<T> ArrayExpression<T> {
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum TypeName<E> {
/// Boolean
Bool,
/// Integer (arbitrary precision)
Int,
/// Field element (unspecified field)
Fe,
/// String
String,
/// Column, shorthand for "int -> fe"
Col,
/// Algebraic expression
Expr,
/// Polynomial identity
Constr,
Array(ArrayTypeName<E>),
Tuple(TupleTypeName<E>),
Function(FunctionTypeName<E>),
}
impl<E> TypeName<E> {
/// Returns true if the type name needs parentheses during formatting
/// when used inside a complex expression.
pub fn needs_parentheses(&self) -> bool {
match self {
TypeName::Bool
| TypeName::Int
| TypeName::Fe
| TypeName::String
| TypeName::Col
| TypeName::Expr
| TypeName::Constr
| TypeName::Array(_)
| TypeName::Tuple(_) => false,
TypeName::Function(_) => true,
}
}
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
match self {
TypeName::Bool
| TypeName::Int
| TypeName::Fe
| TypeName::String
| TypeName::Col
| TypeName::Expr
| TypeName::Constr => Box::new(empty()),
TypeName::Array(a) => a.expressions(),
TypeName::Tuple(t) => t.expressions(),
TypeName::Function(f) => f.expressions(),
}
}
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
match self {
TypeName::Bool
| TypeName::Int
| TypeName::Fe
| TypeName::String
| TypeName::Col
| TypeName::Expr
| TypeName::Constr => Box::new(empty()),
TypeName::Array(a) => a.expressions_mut(),
TypeName::Tuple(t) => t.expressions_mut(),
TypeName::Function(f) => f.expressions_mut(),
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct ArrayTypeName<E> {
pub base: Box<TypeName<E>>,
pub length: Option<E>,
}
impl<E> ArrayTypeName<E> {
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
Box::new(self.base.expressions().chain(self.length.iter()))
}
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
Box::new(self.base.expressions_mut().chain(self.length.iter_mut()))
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct TupleTypeName<E> {
pub items: Vec<TypeName<E>>,
}
impl<E> TupleTypeName<E> {
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
Box::new(self.items.iter().flat_map(|t| t.expressions()))
}
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
Box::new(self.items.iter_mut().flat_map(|t| t.expressions_mut()))
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct FunctionTypeName<E> {
pub params: Vec<TypeName<E>>,
pub value: Box<TypeName<E>>,
}
impl<E> FunctionTypeName<E> {
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
Box::new(
self.params
.iter()
.flat_map(|t| t.expressions())
.chain(self.value.expressions()),
)
}
/// Returns an iterator over all (top-level) expressions in this type name.
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
Box::new(
self.params
.iter_mut()
.flat_map(|t| t.expressions_mut())
.chain(self.value.expressions_mut()),
)
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct ExpressionWithTypeName<T, Ref = NamespacedPolynomialReference> {
pub e: Expression<T, Ref>,
pub type_name: Option<TypeName<Expression<T, Ref>>>,
}

View File

@@ -1,9 +1,9 @@
use std::{iter::once, ops::ControlFlow};
use super::{
ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IfExpression,
IndexAccess, LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference,
PilStatement, SelectedExpressions,
ArrayExpression, ArrayLiteral, ArrayTypeName, Expression, FunctionCall, FunctionDefinition,
FunctionTypeName, IfExpression, IndexAccess, LambdaExpression, MatchArm, MatchPattern,
NamespacedPolynomialReference, PilStatement, SelectedExpressions, TupleTypeName, TypeName,
};
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
@@ -205,8 +205,17 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
| PilStatement::PolynomialDefinition(_, _, e)
| PilStatement::PolynomialIdentity(_, e)
| PilStatement::PublicDeclaration(_, _, _, None, e)
| PilStatement::ConstantDefinition(_, _, e)
| PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions_mut(f, o),
| PilStatement::ConstantDefinition(_, _, e) => e.visit_expressions_mut(f, o),
PilStatement::LetStatement(_, _, type_name, value) => {
if let Some(t) = type_name {
t.visit_expressions_mut(f, o)?;
};
if let Some(v) = value {
v.visit_expressions_mut(f, o)?;
};
ControlFlow::Continue(())
}
PilStatement::PublicDeclaration(_, _, _, Some(i), e) => [i, e]
.into_iter()
@@ -218,8 +227,7 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
}
PilStatement::PolynomialCommitDeclaration(_, _, None)
| PilStatement::Include(_, _)
| PilStatement::PolynomialConstantDeclaration(_, _)
| PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()),
| PilStatement::PolynomialConstantDeclaration(_, _) => ControlFlow::Continue(()),
}
}
@@ -242,8 +250,17 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
| PilStatement::PolynomialDefinition(_, _, e)
| PilStatement::PolynomialIdentity(_, e)
| PilStatement::PublicDeclaration(_, _, _, None, e)
| PilStatement::ConstantDefinition(_, _, e)
| PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions(f, o),
| PilStatement::ConstantDefinition(_, _, e) => e.visit_expressions(f, o),
PilStatement::LetStatement(_, _, type_name, value) => {
if let Some(t) = type_name {
t.visit_expressions(f, o)?;
};
if let Some(v) = value {
v.visit_expressions(f, o)?;
};
ControlFlow::Continue(())
}
PilStatement::PublicDeclaration(_, _, _, Some(i), e) => [i, e]
.into_iter()
@@ -255,8 +272,7 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
}
PilStatement::PolynomialCommitDeclaration(_, _, None)
| PilStatement::Include(_, _)
| PilStatement::PolynomialConstantDeclaration(_, _)
| PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()),
| PilStatement::PolynomialConstantDeclaration(_, _) => ControlFlow::Continue(()),
}
}
}
@@ -478,3 +494,105 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for IfExpression<T, Ref> {
.try_for_each(|e| e.visit_expressions(f, o))
}
}
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for TypeName<E> {
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&mut E) -> ControlFlow<B>,
{
match self {
TypeName::Bool
| TypeName::Int
| TypeName::Fe
| TypeName::String
| TypeName::Col
| TypeName::Expr
| TypeName::Constr => ControlFlow::Continue(()),
TypeName::Array(a) => a.visit_expressions_mut(f, o),
TypeName::Tuple(t) => t.visit_expressions_mut(f, o),
TypeName::Function(fun) => fun.visit_expressions_mut(f, o),
}
}
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&E) -> ControlFlow<B>,
{
match self {
TypeName::Bool
| TypeName::Int
| TypeName::Fe
| TypeName::String
| TypeName::Col
| TypeName::Expr
| TypeName::Constr => ControlFlow::Continue(()),
TypeName::Array(a) => a.visit_expressions(f, o),
TypeName::Tuple(t) => t.visit_expressions(f, o),
TypeName::Function(fun) => fun.visit_expressions(f, o),
}
}
}
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for ArrayTypeName<E> {
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&mut E) -> ControlFlow<B>,
{
self.base.visit_expressions_mut(f, o)?;
self.length
.iter_mut()
.try_for_each(|e| e.visit_expressions_mut(f, o))
}
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&E) -> ControlFlow<B>,
{
self.base.visit_expressions(f, o)?;
self.length
.iter()
.try_for_each(|e| e.visit_expressions(f, o))
}
}
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for TupleTypeName<E> {
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&mut E) -> ControlFlow<B>,
{
self.items
.iter_mut()
.try_for_each(|i| i.visit_expressions_mut(f, o))
}
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&E) -> ControlFlow<B>,
{
self.items
.iter()
.try_for_each(|i| i.visit_expressions(f, o))
}
}
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for FunctionTypeName<E> {
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&mut E) -> ControlFlow<B>,
{
self.params
.iter_mut()
.chain(once(self.value.as_mut()))
.try_for_each(|i| i.visit_expressions_mut(f, o))
}
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
where
F: FnMut(&E) -> ControlFlow<B>,
{
self.params
.iter()
.chain(once(self.value.as_ref()))
.try_for_each(|i| i.visit_expressions(f, o))
}
}

View File

@@ -1,37 +1,26 @@
# Declarations
Powdr-pil allows the same syntax to declare various kinds of symbols. This includes
constants, fixed columns, witness columns and even macros. It deduces the symbol kind
by its type and the way the symbol is used.
constants, fixed columns, witness columns and even higher-order functions. It deduces the symbol kind
from the type of the symbol and the way the symbol is used.
Symbols can be declared using ``let <name>;`` and they can be declared and defined
using ``let <name> = <value>;``, where ``<value>`` is an expression.
using ``let <name> = <value>;``, where ``<value>`` is an expression. The [type](./types.md) of the symbol
can be explicitly specified using ``let <name>: <type>;`` and ``let <name>: <type> = <value>;``.
This syntax can be used for constants, fixed columns, witness columns and even (higher-order)
functions that can transform expressions. The kind of symbol is deduced by its type and the
way the symbol is used:
- symbols without a value are witness columns,
- symbols evaluating to a number are constants,
- symbols defined as a function with a single parameter are fixed columns and
- everything else is a "generic symbol" that is not a column.
- Symbols without a value are witness columns. Their type can be omitted. If it is given, it must be ``int -> fe`` or its shorthand ``col``.
- Symbols evaluating to a number or with type ``fe`` are constants.
- Symbols without type but with a value that is a function with a single parameter are fixed columns.
- Symbols defined with a value and type ``int -> fe`` or its shorthand ``col`` are also fixed columns.
- Everything else is a "generic symbol" that is not a column or constant.
Examples:
```rust
// This defines a constant
let rows = 2**16;
// This defines a fixed column that contains the row number in each row.
let step = |i| i;
// Here, we have a witness column.
let x;
// This functions returns the square of its input (classified as a fixed column).
let square = |x| x*x;
// A recursive function, taking a function and an integer as parameter
let sum = |f, i| match i {
0 => f(0),
_ => f(i) + sum(f, i - 1)
};
// The same function as "square" above, but employing a trick to avoid it
// being classified as a column.
let square_non_column = (|| |x| x*x)();
{{#include ../../../test_data/pil/book/declarations.pil:declarations}}
```

View File

@@ -1,7 +1,10 @@
use std::{collections::HashMap, fmt::Display, rc::Rc};
use itertools::Itertools;
use powdr_ast::analyzed::{Analyzed, FunctionValueDefinition};
use powdr_ast::analyzed::{
types::{Type, TypedExpression},
Analyzed, FunctionValueDefinition,
};
use powdr_number::{DegreeType, FieldElement};
use powdr_pil_analyzer::evaluator::{self, Custom, EvalError, SymbolLookup, Value};
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
@@ -43,20 +46,25 @@ fn generate_values<T: FieldElement>(
};
// TODO we should maybe pre-compute some symbols here.
let result = match body {
FunctionValueDefinition::Expression(e) => (0..degree)
.into_par_iter()
.map(|i| {
// We could try to avoid the first evaluation to be run for each iteration,
// but the data is not thread-safe.
let fun = evaluator::evaluate(e, &symbols).unwrap();
evaluator::evaluate_function_call(
fun,
vec![Rc::new(Value::Integer(num_bigint::BigInt::from(i)))],
&symbols,
)
.and_then(|v| v.try_to_field_element())
})
.collect::<Result<Vec<_>, _>>(),
FunctionValueDefinition::Expression(TypedExpression { e, ty }) => {
if let Some(ty) = ty {
assert_eq!(ty, &Type::col())
};
(0..degree)
.into_par_iter()
.map(|i| {
// We could try to avoid the first evaluation to be run for each iteration,
// but the data is not thread-safe.
let fun = evaluator::evaluate(e, &symbols).unwrap();
evaluator::evaluate_function_call(
fun,
vec![Rc::new(Value::Integer(num_bigint::BigInt::from(i)))],
&symbols,
)
.and_then(|v| v.try_to_field_element())
})
.collect::<Result<Vec<_>, _>>()
}
FunctionValueDefinition::Array(values) => values
.iter()
.map(|elements| {
@@ -103,8 +111,8 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, FixedColumnRef<'a>> for Symbols<'a
Value::Custom(FixedColumnRef { name })
} else if let Some((_, value)) = self.analyzed.definitions.get(&name.to_string()) {
match value {
Some(FunctionValueDefinition::Expression(value)) => {
evaluator::evaluate(value, self)?
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
evaluator::evaluate(e, self)?
}
Some(_) => Err(EvalError::Unsupported(
"Cannot evaluate arrays and queries.".to_string(),
@@ -268,7 +276,7 @@ mod test {
let src = r#"
constant %N = 8;
namespace F(%N);
let minus_one = [|x| x - 1][0];
let minus_one: int -> int = |x| x - 1;
pol constant EVEN(i) { 2 * minus_one(i) + 2 };
"#;
let analyzed = analyze_string(src);

View File

@@ -15,7 +15,8 @@ use powdr_ast::{
},
folder::Folder,
visitor::ExpressionVisitable,
ArrayLiteral, FunctionCall, IndexAccess, LambdaExpression, MatchArm,
ArrayLiteral, ExpressionWithTypeName, FunctionCall, IndexAccess, LambdaExpression,
MatchArm,
},
};
@@ -79,9 +80,14 @@ impl<'a, T> Folder<T> for Canonicalizer<'a> {
.map(Some)
.transpose(),
},
SymbolValue::Expression(mut e) => {
canonicalize_inside_expression(&mut e, &self.path, self.paths);
Some(Ok(SymbolValue::Expression(e)))
SymbolValue::Expression(mut exp) => {
for tne in
exp.type_name.iter_mut().flat_map(|tn| tn.expressions_mut())
{
canonicalize_inside_expression(tne, &self.path, self.paths);
}
canonicalize_inside_expression(&mut exp.e, &self.path, self.paths);
Some(Ok(SymbolValue::Expression(exp)))
}
}
.map(|value| value.map(|value| SymbolDefinition { name, value }.into()))
@@ -321,7 +327,10 @@ fn check_module<T: Clone>(
check_module(location.with_part(name), m, state)?;
}
SymbolValue::Import(s) => check_import(location.clone(), s.clone(), state)?,
SymbolValue::Expression(e) => {
SymbolValue::Expression(ExpressionWithTypeName { e, type_name }) => {
for tne in type_name.iter().flat_map(|tn| tn.expressions()) {
check_expression(&location, tne, state, &HashSet::default())?
}
check_expression(&location, e, state, &HashSet::default())?
}
}

View File

@@ -7,7 +7,7 @@ use powdr_ast::{
asm::AbsoluteSymbolPath,
asm::SymbolPath,
build::{direct_reference, index_access, namespaced_reference},
Expression, PILFile, PilStatement, SelectedExpressions,
Expression, ExpressionWithTypeName, PILFile, PilStatement, SelectedExpressions,
},
SourceRef,
};
@@ -42,9 +42,14 @@ pub fn link<T: FieldElement>(graph: PILGraph<T>) -> Result<PILFile<T>, Vec<Strin
// Group by namespace and then sort by name.
(namespace, name)
})
.flat_map(|(mut namespace, e)| {
.flat_map(|(mut namespace, ExpressionWithTypeName { e, type_name })| {
let name = namespace.pop().unwrap();
let def = PilStatement::LetStatement(SourceRef::unknown(), name.to_string(), Some(e));
let def = PilStatement::LetStatement(
SourceRef::unknown(),
name.to_string(),
type_name,
Some(e),
);
// If there is a namespace change, insert a namespace statement.
if current_namespace != namespace {

View File

@@ -224,7 +224,7 @@ mod test {
match stmt {
PilStatement::Include(s, _)
| PilStatement::Namespace(s, _, _)
| PilStatement::LetStatement(s, _, _)
| PilStatement::LetStatement(s, _, _, _)
| PilStatement::PolynomialDefinition(s, _, _)
| PilStatement::PublicDeclaration(s, _, _, _, _)
| PilStatement::PolynomialConstantDeclaration(s, _)
@@ -385,7 +385,7 @@ mod test {
constant %N = 16;
namespace Fibonacci(%N);
constant %last_row = (%N - 1);
let bool = [(|X| (X * (1 - X)))][0];
let bool: expr -> expr = (|X| (X * (1 - X)));
let one_hot = (|i, which| match i { which => 1, _ => 0, });
pol constant ISLAST(i) { one_hot(i, %last_row) };
pol commit arr[8];
@@ -440,5 +440,39 @@ namespace Fibonacci(%N);
);
assert_eq!(input.trim(), printed.trim());
}
#[test]
fn type_names_simple() {
let input = r#"
let a: col;
let b: int;
let c: fe;
let d: int[];
let e: int[7];
let f: (int, fe, fe[3])[2];"#;
let printed = format!(
"{}",
parse::<GoldilocksField>(Some("input"), input).unwrap()
);
assert_eq!(input.trim(), printed.trim());
}
#[test]
fn type_names_complex() {
let input = r#"
let a: int -> fe;
let b: int -> ();
let c: -> ();
let d: int, int -> fe;
let e: int, int -> (fe, int[2]);
let f: ((int, fe), fe[2] -> (fe -> int))[];
let g: (int -> fe) -> int;
let h: int -> (fe -> int);"#;
let printed = format!(
"{}",
parse::<GoldilocksField>(Some("input"), input).unwrap()
);
assert_eq!(input.trim(), printed.trim());
}
}
}

View File

@@ -57,7 +57,11 @@ Part: Part = {
}
LetStatementAtModuleLevel: SymbolDefinition<T> = {
"let" <name:Identifier> "=" <value:Expression> ";" => SymbolDefinition { name, value: SymbolValue::Expression(value) }
"let" <name:Identifier> <type_name:(":" <TypeName>)?> "=" <value:Expression> ";" =>
SymbolDefinition {
name,
value: SymbolValue::Expression(ExpressionWithTypeName{ e: value, type_name })
}
}
// ---------------------------- PIL part -----------------------------
@@ -88,7 +92,7 @@ Namespace: PilStatement<T> = {
}
LetStatement: PilStatement<T> = {
<start:@L> "let" <id:Identifier> <expr:( "=" <Expression> )?> => PilStatement::LetStatement(ctx.source_ref(start), id, expr)
<start:@L> "let" <id:Identifier> <type_name:(":" <TypeName>)?> <expr:( "=" <Expression> )?> => PilStatement::LetStatement(ctx.source_ref(start), id, type_name, expr)
}
ConstantDefinition: PilStatement<T> = {
@@ -540,6 +544,36 @@ IfExpression: Box<Expression<T>> = {
"{" <else_body:BoxedExpression> "}" => Box::new(Expression::IfExpression(IfExpression{<>}))
}
// ---------------------------- Type Names -----------------------------
TypeName: TypeName<Expression<T>> = {
<params:TypeNameTermList> "->" <value:TypeNameTermBox> => TypeName::Function(FunctionTypeName{<>}),
TypeNameTerm
}
TypeNameTermList: Vec<TypeName<Expression<T>>> = {
=> vec![],
<mut list:( <TypeNameTerm> "," )*> <end:TypeNameTerm> => { list.push(end); list }
}
TypeNameTermBox: Box<TypeName<Expression<T>>> = {
TypeNameTerm => Box::new(<>)
}
TypeNameTerm: TypeName<Expression<T>> = {
"bool" => TypeName::Bool,
"int" => TypeName::Int,
"fe" => TypeName::Fe,
"string" => TypeName::String,
"col" => TypeName::Col,
"expr" => TypeName::Expr,
"constr" => TypeName::Constr,
<base:TypeNameTerm> "[" <length:Expression?> "]" => TypeName::Array(ArrayTypeName{base: Box::new(base), length}),
"(" <mut items:( <TypeNameTerm> "," )+> <end:TypeNameTerm> ")" => { items.push(end); TypeName::Tuple(TupleTypeName{items}) },
"(" ")" => TypeName::Tuple(TupleTypeName{items: vec![]}),
"(" <TypeName> ")",
}
// ---------------------------- Terminals -----------------------------
@@ -561,6 +595,8 @@ SpecialIdentifier: &'input str = {
"insn",
"int",
"fe",
"expr",
"constr",
"bool",
}

View File

@@ -6,6 +6,7 @@ use std::{collections::HashMap, fmt::Display, rc::Rc};
use itertools::Itertools;
use powdr_ast::{
analyzed::{
types::{Type, TypedExpression},
AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition,
Identity, IdentityKind, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
StatementIdentifier, Symbol, SymbolKind,
@@ -58,9 +59,10 @@ pub fn condense<T: FieldElement>(
let Some(FunctionValueDefinition::Expression(e)) = definition else {
panic!("Expected expression")
};
assert!(e.ty.is_none() || e.ty == Some(Type::col()));
Some((
name.clone(),
(symbol.clone(), condenser.condense_expression(e)),
(symbol.clone(), condenser.condense_expression(&e.e)),
))
} else {
None
@@ -231,7 +233,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate<T>> for &'a Condenser<T
}
} else {
match value {
Some(FunctionValueDefinition::Expression(value)) => {
Some(FunctionValueDefinition::Expression(TypedExpression { e: value, ty: _ })) => {
evaluator::evaluate(value, self)?
}
_ => Err(EvalError::Unsupported(
@@ -278,8 +280,8 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate<T>> for &'a Condenser<T
};
match self.symbols[&name].1.as_ref() {
Some(FunctionValueDefinition::Expression(v)) => {
let function = evaluate(v, self)?;
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
let function = evaluate(e, self)?;
evaluate_function_call(function, arguments, self)
}
None => Err(EvalError::SymbolNotFound(format!(
@@ -390,7 +392,7 @@ impl<T: FieldElement> Custom for Condensate<T> {
fn type_name(&self) -> String {
match self {
Condensate::Expression(_) => "expr".to_string(),
Condensate::Identity(_, _) => "identity".to_string(),
Condensate::Identity(_, _) => "constr".to_string(),
}
}
}

View File

@@ -6,7 +6,7 @@ use std::{
use itertools::Itertools;
use powdr_ast::{
analyzed::{Expression, FunctionValueDefinition, Reference, Symbol},
analyzed::{types::TypedExpression, Expression, FunctionValueDefinition, Reference, Symbol},
parsed::{
display::quote, BinaryOperator, FunctionCall, LambdaExpression, MatchArm, MatchPattern,
UnaryOperator,
@@ -189,7 +189,7 @@ const BUILTINS: [(&str, BuiltinFunction); 6] = [
#[derive(Clone, Copy, PartialEq, Debug)]
pub enum BuiltinFunction {
/// std::array::len: [_] -> int, returns the length of an array
/// std::array::len: _[] -> int, returns the length of an array
ArrayLen,
/// std::field::modulus: -> int, returns the field modulus as int
Modulus,
@@ -282,7 +282,9 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> {
fn lookup(&self, name: &'a str) -> Result<Value<'a, T, NoCustom>, EvalError> {
Ok(match self.0.get(&name.to_string()) {
Some((_, value)) => match value {
Some(FunctionValueDefinition::Expression(value)) => evaluate(value, self)?,
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
evaluate(e, self)?
}
_ => Err(EvalError::Unsupported(
"Cannot evaluate arrays and queries.".to_string(),
))?,
@@ -640,7 +642,8 @@ mod test {
fn parse_and_evaluate_symbol(input: &str, symbol: &str) -> String {
let analyzed = analyze_string::<GoldilocksField>(input);
let Some(FunctionValueDefinition::Expression(symbol)) = &analyzed.definitions[symbol].1
let Some(FunctionValueDefinition::Expression(TypedExpression { e: symbol, ty: _ })) =
&analyzed.definitions[symbol].1
else {
panic!()
};
@@ -686,8 +689,8 @@ mod test {
#[test]
pub fn capturing() {
let src = r#"namespace Main(16);
let f = |n, g| match n { 99 => |i| n, 1 => g(3) };
let result = f(1, f(99, |x| x + 3000));
let f: int, (int -> int) -> (int -> int) = |n, g| match n { 99 => |i| n, 1 => g };
let result = f(1, f(99, |x| x + 3000))(0);
"#;
// If the lambda function returned by the expression f(99, ...) does not
// properly capture the value of n in a closure, then f(1, ...) would return 1.

View File

@@ -505,6 +505,58 @@ namespace N(16);
let expected = r#"namespace N(16);
let w = (|| 2);
constant x = (|i| (|| N.w()))(2)();
"#;
let formatted = analyze_string::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
fn simple_type_resolution() {
let input = r#"namespace N(16);
let w: col[3 + 4];
"#;
let expected = r#"namespace N(16);
col witness w[7];
"#;
let formatted = analyze_string::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
fn complex_type_resolution() {
let input = r#"namespace N(16);
let f: int -> int = |i| i + 10;
let x: (int -> int), int -> int = |k, i| k(2**i);
let y: (int -> fe)[x(f, 2)];
let z: (((int -> int), int -> int)[x(|i| i, 3)], col) = ([x, x, x, x, x, x, x, x], y[0]);
"#;
let expected = r#"namespace N(16);
let f: int -> int = (|i| (i + 10));
let x: (int -> int), int -> int = (|k, i| k((2 ** i)));
col witness y[14];
let z: (((int -> int), int -> int)[8], col) = ([N.x, N.x, N.x, N.x, N.x, N.x, N.x, N.x], N.y[0]);
"#;
let formatted = analyze_string::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);
}
#[test]
fn expr_and_identity() {
let input = r#"namespace N(16);
let f: expr, expr -> constr[] = |x, y| [x == y];
let g: expr -> constr[1] = |x| [x == 0];
let x: col;
let y: col;
f(x, y);
g((x));
"#;
let expected = r#"namespace N(16);
let f: expr, expr -> constr[] = (|x, y| [(x == y)]);
let g: expr -> constr[1] = (|x| [(x == 0)]);
col witness x;
col witness y;
N.x = N.y;
N.x = 0;
"#;
let formatted = analyze_string::<GoldilocksField>(input).to_string();
assert_eq!(formatted, expected);

View File

@@ -1,8 +1,9 @@
use std::collections::{BTreeMap, HashMap};
use std::marker::PhantomData;
use powdr_ast::analyzed::types::{ArrayType, Type, TypedExpression};
use powdr_ast::parsed::{
self, FunctionDefinition, PilStatement, PolynomialName, SelectedExpressions,
self, FunctionDefinition, PilStatement, PolynomialName, SelectedExpressions, TypeName,
};
use powdr_ast::SourceRef;
use powdr_number::{DegreeType, FieldElement};
@@ -103,8 +104,8 @@ where
.handle_symbol_definition(
source,
name,
None,
SymbolKind::Poly(PolynomialType::Intermediate),
Some(Type::col()),
Some(FunctionDefinition::Expression(value)),
),
PilStatement::PublicDeclaration(source, name, polynomial, array_index, index) => {
@@ -117,8 +118,8 @@ where
.handle_symbol_definition(
source,
name,
None,
SymbolKind::Poly(PolynomialType::Constant),
Some(Type::col()),
Some(definition),
),
PilStatement::PolynomialCommitDeclaration(source, polynomials, None) => {
@@ -130,12 +131,14 @@ where
Some(definition),
) => {
assert!(polynomials.len() == 1);
let name = polynomials.pop().unwrap();
let (name, ty) =
self.name_and_type_from_polynomial_name(polynomials.pop().unwrap());
self.handle_symbol_definition(
source,
name.name,
name.array_size,
name,
SymbolKind::Poly(PolynomialType::Committed),
ty,
Some(definition),
)
}
@@ -147,24 +150,53 @@ where
self.handle_symbol_definition(
source,
name,
None,
SymbolKind::Constant(),
Some(Type::Fe),
Some(FunctionDefinition::Expression(value)),
)
}
PilStatement::LetStatement(source, name, value) => {
self.handle_generic_definition(source, name, value)
PilStatement::LetStatement(source, name, type_name, value) => {
self.handle_generic_definition(source, name, type_name, value)
}
_ => self.handle_identity_statement(statement),
}
}
fn name_and_type_from_polynomial_name(
&mut self,
PolynomialName { name, array_size }: PolynomialName<T>,
) -> (String, Option<Type>) {
let ty = Some(match array_size {
None => Type::col(),
Some(len) => {
let length = self
.evaluate_expression(len)
.map_err(|e| {
panic!("Error evaluating length of array of witness columns {name}:\n{e}")
})
.map(|length| length.to_degree())
.ok();
Type::Array(ArrayType {
base: Box::new(Type::col()),
length,
})
}
});
(name, ty)
}
fn handle_generic_definition(
&mut self,
source: SourceRef,
name: String,
value: Option<::powdr_ast::parsed::Expression<T>>,
type_name: Option<TypeName<parsed::Expression<T>>>,
value: Option<parsed::Expression<T>>,
) -> Vec<PILItem<T>> {
let ty = type_name.map(|n|
self.resolve_type_name(n.clone())
.map_err(|e| panic!("Error evaluating expressions in type name \"{n}\" to reduce it to a type:\n{e})"))
.unwrap()
);
// Determine whether this is a fixed column, a constant or something else
// depending on the structure of the value and if we can evaluate
// it to a single number.
@@ -172,30 +204,56 @@ where
match value {
None => {
// No value provided => treat it as a witness column.
let ty = ty
.map(|t| {
if let Type::Array(ArrayType { base, length }) = &t {
if base.as_ref() != &Type::col() {
panic!("Symbol {name} is declared without value and thus must be a witness column array, but its type is {t} instead of col[].");
}
if length.is_none() {
panic!("Explicit array length required for column {name}: {t}");
}
t
} else {
if t != Type::col() {
panic!("Symbol {name} is declared without value and thus must be a witness column, but its type is {t} instead of col.");
}
t
}
})
.unwrap_or(Type::col());
self.handle_symbol_definition(
source,
name,
None,
SymbolKind::Poly(PolynomialType::Committed),
Some(ty),
None,
)
}
Some(value) => {
let symbol_kind = if matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1)
// TODO if we have proper type deduction here in the future, we can rely only on the type.
let (ty, symbol_kind) = if ty == Some(Type::col())
|| (ty.is_none()
&& matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1))
{
(
Some(Type::col()),
SymbolKind::Poly(PolynomialType::Constant),
)
} else if ty == Some(Type::Fe)
|| (ty.is_none() && self.evaluate_expression(value.clone()).is_ok())
{
SymbolKind::Poly(PolynomialType::Constant)
} else if self.evaluate_expression(value.clone()).is_ok() {
// Value evaluates to a constant number => treat it as a constant
SymbolKind::Constant()
(Some(Type::Fe), SymbolKind::Constant())
} else {
// Otherwise, treat it as "generic definition"
SymbolKind::Other()
(ty, SymbolKind::Other())
};
self.handle_symbol_definition(
source,
name,
None,
symbol_kind,
ty,
Some(FunctionDefinition::Expression(value)),
)
}
@@ -261,12 +319,13 @@ where
) -> Vec<PILItem<T>> {
polynomials
.into_iter()
.flat_map(|PolynomialName { name, array_size }| {
.flat_map(|poly_name| {
let (name, ty) = self.name_and_type_from_polynomial_name(poly_name);
self.handle_symbol_definition(
source.clone(),
name,
array_size,
SymbolKind::Poly(polynomial_type),
ty,
None,
)
})
@@ -277,17 +336,20 @@ where
&mut self,
source: SourceRef,
name: String,
array_size: Option<::powdr_ast::parsed::Expression<T>>,
symbol_kind: SymbolKind,
ty: Option<Type>,
value: Option<FunctionDefinition<T>>,
) -> Vec<PILItem<T>> {
let have_array_size = array_size.is_some();
let length = array_size
.map(|l| self.evaluate_expression(l).unwrap())
.map(|l| l.to_degree());
if length.is_some() {
assert!(value.is_none());
}
let length = ty.as_ref().and_then(|t| {
if let Type::Array(ArrayType { length, base: _ }) = t {
if length.is_none() && symbol_kind != SymbolKind::Other() {
panic!("Explicit array length required for column {name}.");
}
*length
} else {
None
}
});
let id = self.counters.dispense_symbol_id(symbol_kind, length);
let name = self.driver.resolve_decl(&name);
let symbol = Symbol {
@@ -300,13 +362,15 @@ where
let value = value.map(|v| match v {
FunctionDefinition::Expression(expr) => {
assert!(!have_array_size);
assert!(symbol_kind != SymbolKind::Poly(PolynomialType::Committed));
FunctionValueDefinition::Expression(self.process_expression(expr))
FunctionValueDefinition::Expression(TypedExpression {
e: self.process_expression(expr),
ty,
})
}
FunctionDefinition::Query(expr) => {
assert!(!have_array_size);
assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed));
assert!(ty.is_none() || ty == Some(Type::col()));
FunctionValueDefinition::Query(self.process_expression(expr))
}
FunctionDefinition::Array(value) => {
@@ -318,6 +382,7 @@ where
expression.iter().map(|e| e.size()).sum::<DegreeType>(),
self.degree.unwrap()
);
assert!(ty.is_none() || ty == Some(Type::col()));
FunctionValueDefinition::Array(expression)
}
});
@@ -351,10 +416,18 @@ where
})]
}
fn evaluate_expression(
&self,
expr: ::powdr_ast::parsed::Expression<T>,
) -> Result<T, EvalError> {
/// Resolves a type name into a concrete type.
/// This routine mainly evaluates array length expressions.
fn resolve_type_name(&self, mut n: TypeName<parsed::Expression<T>>) -> Result<Type, EvalError> {
// Replace all expressions by number literals.
for e in n.expressions_mut() {
let v = self.evaluate_expression(e.clone())?;
*e = parsed::Expression::Number(v);
}
Ok(n.into())
}
fn evaluate_expression(&self, expr: parsed::Expression<T>) -> Result<T, EvalError> {
evaluator::evaluate_expression(
&ExpressionProcessor::new(self.driver).process_expression(expr),
self.driver.definitions(),
@@ -366,13 +439,13 @@ where
ExpressionProcessor::new(self.driver)
}
fn process_expression(&self, expr: ::powdr_ast::parsed::Expression<T>) -> Expression<T> {
fn process_expression(&self, expr: parsed::Expression<T>) -> Expression<T> {
self.expression_processor().process_expression(expr)
}
fn process_selected_expressions(
&self,
expr: ::powdr_ast::parsed::SelectedExpressions<::powdr_ast::parsed::Expression<T>>,
expr: parsed::SelectedExpressions<parsed::Expression<T>>,
) -> SelectedExpressions<Expression<T>> {
self.expression_processor()
.process_selected_expressions(expr)

View File

@@ -137,11 +137,11 @@ machine Arith(CLK32_31, operation_id){
/// returns |n| a(0) * b(n) + ... + a(n) * b(0)
let product = |a, b| |n| dot_prod(n + 1, a, |i| b(n - i));
/// Converts array to function, extended by zeros.
let array_as_fun = [|arr| |i| if 0 <= i && i < array::len(arr) {
let array_as_fun: expr[] -> (int -> expr) = |arr| |i| if 0 <= i && i < array::len(arr) {
arr[i]
} else {
0
}][0];
};
let shift_right = |fn, amount| |i| fn(i - amount);
let x1f = array_as_fun(x1);
@@ -151,12 +151,11 @@ machine Arith(CLK32_31, operation_id){
let y3f = array_as_fun(y3);
// Defined for arguments from 0 to 31 (inclusive)
let eq0 = (|| |nr|
let eq0: int -> expr = |nr|
product(x1f, y1f)(nr)
+ x2f(nr)
- shift_right(y2f, 16)(nr)
- y3f(nr)
)();
- y3f(nr);
// Note that Polygon uses a single 22-Bit column. However, this approach allows for a lower degree (2**16)
// while still preventing overflows: The 32-bit carry gets added to 32 16-Bit values, which can't overflow

View File

@@ -13,4 +13,5 @@ let map = |arr, f| new(len(arr), |i| f(arr[i]));
let fold = |arr, initial, folder| std::utils::fold(len(arr), |i| arr[i], initial, folder);
/// Returns the sum of the array elements.
let sum = [|arr| fold(arr, 0, |a, b| a + b)][0];
/// This actually also works on field elements, so the type is currently too restrictive.
let sum: int[] -> int = |arr| fold(arr, 0, |a, b| a + b);

View File

@@ -2,6 +2,6 @@
/// when evaluated.
/// It returns an empty array so that it can be used at constraint level.
/// This symbol is not an empty array, the actual semantics are overridden.
let print = [];
let print: string -> constr[] = [];
let println = [|msg| print(msg + "\n")][0];
let println: string -> constr[] = |msg| print(msg + "\n");

View File

@@ -22,4 +22,4 @@ let sum = |length, f| fold(length, f, 0, |acc, e| (acc + e));
let unchanged_until = |c, latch| (c' - c) * (1 - latch) == 0;
/// Evaluates to a constraint that forces `c` to be either 0 or 1.
let force_bool = [|c| c * (1 - c) == 0][0];
let force_bool: expr -> constr = |c| c * (1 - c) == 0;

View File

@@ -26,11 +26,12 @@ mod R {
machine FullConstant {
degree 2;
let C = |i| match i % 2 {
let C: int -> fe = |i| match i % 2 {
0 => x,
1 => y,
};
col commit w[2];
// Use some weird type just for the sake of it.
let w: col[sum(2, |i| 1)];
// This and the next line are the same.
super::utils::sum(2, |i| w[i]) == 8;

View File

@@ -29,8 +29,7 @@ namespace Arith(N);
/// returns f(0) + f(1) + ... + f(length - 1)
let sum = |length, f| fold(length, f, 0, |acc, e| acc + e);
// TODO the weird syntax is needed so that this is not classified as a constant column
let force_boolean = (|| |x| x * (1 - x) == 0)();
let force_boolean: expr -> constr = |x| x * (1 - x) == 0;
let clock = |j, row| if row % 32 == j { 1 } else { 0 };
// Arrays of fixed columns are not supported yet.
@@ -140,8 +139,7 @@ namespace Arith(N);
// That way we could even support functions returning lookups.
// x can only change between two blocks of 32 rows.
// TODO the weird syntax is needed so that this is not classified as a fixed column.
let fixed_inside_32_block = (|| |x| (x - x') * (1 - CLK32[31]) == 0)();
let fixed_inside_32_block: expr -> constr = |x| (x - x') * (1 - CLK32[31]) == 0;
make_array(16, |i| fixed_inside_32_block(x1[i]));
make_array(16, |i| fixed_inside_32_block(y1[i]));
@@ -209,12 +207,11 @@ namespace Arith(N);
let q2f = array_as_fun(q2, 16);
// Defined for arguments from 0 to 31 (inclusive)
let eq0 = (|| |nr|
let eq0: int -> expr = |nr|
product(x1f, y1f)(nr)
+ x2f(nr)
- shift_right(y2f, 16)(nr)
- y3f(nr)
)();
- y3f(nr);
/*******
@@ -224,16 +221,16 @@ namespace Arith(N);
*******/
// 0xffffffffffffffffffffffffffffffffffffffffffffffffffff fffe ffff fc2f
let p = array_as_fun([
let p: col = array_as_fun([
0xfc2f, 0xffff, 0xfffe, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff
], 16);
// The "- 4 * shift_right(p, 16)" effectively subtracts 4 * (p << 16 * 16) = 2 ** 258 * p
// As a result, the term computes `(x - 2 ** 258) * p`.
let product_with_p = (|| |x| |nr| product(p, x)(nr) - 4 * shift_right(p, 16)(nr))();
// The "- 4 * shift_right(p, 16)" effectively subtracts 4 * (p << 16 * 16) = 2 ** 258 * p
// As a result, the term computes `(x - 2 ** 258) * p`.
let product_with_p: int -> (int -> expr) = |x| |nr| product(p, x)(nr) - 4 * shift_right(p, 16)(nr);
let eq1 = (|| |nr| product(sf, x2f)(nr) - product(sf, x1f)(nr) - y2f(nr) + y1f(nr) + product_with_p(q0f)(nr))();
let eq1: int -> expr = |nr| product(sf, x2f)(nr) - product(sf, x1f)(nr) - y2f(nr) + y1f(nr) + product_with_p(q0f)(nr);
/*******
*
@@ -241,7 +238,7 @@ namespace Arith(N);
*
*******/
let eq2 = (|| |nr| 2 * product(sf, y1f)(nr) - 3 * product(x1f, x1f)(nr) + product_with_p(q0f)(nr))();
let eq2: int -> expr = |nr| 2 * product(sf, y1f)(nr) - 3 * product(x1f, x1f)(nr) + product_with_p(q0f)(nr);
/*******
*
@@ -249,7 +246,7 @@ namespace Arith(N);
*
*******/
let eq3 = (|| |nr| product(sf, sf)(nr) - x1f(nr) - x2f(nr) - x3f(nr) + product_with_p(q1f)(nr))();
let eq3: int -> expr = |nr| product(sf, sf)(nr) - x1f(nr) - x2f(nr) - x3f(nr) + product_with_p(q1f)(nr);
/*******
@@ -258,7 +255,7 @@ namespace Arith(N);
*
*******/
let eq4 = (|| |nr| product(sf, x1f)(nr) - product(sf, x3f)(nr) - y1f(nr) - y3f(nr) + product_with_p(q2f)(nr))();
let eq4: int -> expr = |nr| product(sf, x1f)(nr) - product(sf, x3f)(nr) - y1f(nr) - y3f(nr) + product_with_p(q2f)(nr);
pol commit selEq[4];

View File

@@ -0,0 +1,30 @@
namespace Main(16);
// ANCHOR: declarations
// This defines a constant
let rows = 2**16;
// This defines a fixed column that contains the row number in each row.
let step = |i| i;
// This defines a copy of the column, also a fixed column because the type
// is explicitly specified.
let also_step: col = step;
// Here, we have a witness column.
let x;
// This functions defines a fixed column where each cell contains the
// square of its row number.
let square = |x| x*x;
// The same function as `square` above, but now its type is given as
// `int -> int` and thus it is *not* classified as a column. Instead,
// it is stored as a utility function. If utility functions are
// referenced in constraints, they have to be evaluated, meaning that
// the constraint `w = square_non_column;` is invalid but both
// `w = square_non_column(7);` and `w = square;` are valid constraints.
let square_non_column: int -> int = |x| x*x;
// A recursive function, taking a function and an integer as parameter
let sum = |f, i| match i {
0 => f(0),
_ => f(i) + sum(f, i - 1)
};
// ANCHOR_END: declarations
// We need at least one constraint to create a proof in the test.
let w;
w + square = 0;

View File

@@ -6,20 +6,17 @@ namespace Main(16);
};
// returns f(0) + f(1) + ... + f(length - 1)
let sum = |length, f| fold(length, f, 0, |acc, e| acc + e);
// If called with a single value, this function evaluates the equality,
// otherwise, it returns a constraint (if called with a column or
// an algebraic expression).
// If we write "|x| x == 20", it will be classified as a fixed column,
// so we use a trick that makes it not look like a function with a single
// parameter.
let equals_twenty = [|x| x == 20][0];
// declares an array of 16 witness columns.
// This function takes an algebraic expression (a column or expression
// involving columns) and returns an identity that forces this expression
// to equal 20.
let equals_twenty: expr -> constr = |x| x == 20;
// This declares an array of 16 witness columns.
col witness wit[16];
// This expression has to evaluate to a constraint, but we can still use
// This expression has to evaluate to an identity, but we can still use
// higher order functions and all the flexibility of the language.
// The sub-expression "sum(16, |i| wit[i]" evaluates to the algebraic
// The sub-expression `sum(16, |i| wit[i])` evaluates to the algebraic
// expression "wit[0] + wit[1] + ... + wit[15]", which is then
// turned by "equals_twenty" into the constraint
// turned into the identity by `equals_twenty`
// wit[0] + wit[1] + ... + wit[15] == 20.
equals_twenty(sum(16, |i| wit[i]));