mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-05-13 03:00:26 -04:00
Merge pull request #614 from powdr-labs/arrays
Support types for declarations
This commit is contained in:
@@ -10,7 +10,10 @@ use std::{
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use self::parsed::asm::{AbsoluteSymbolPath, SymbolPath};
|
||||
use self::{
|
||||
parsed::asm::{AbsoluteSymbolPath, SymbolPath},
|
||||
types::{ArrayType, FunctionType, TupleType, Type},
|
||||
};
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -47,7 +50,18 @@ impl<T: Display> Display for Analyzed<T> {
|
||||
};
|
||||
write!(f, " col {kind}{name}")?;
|
||||
if let Some(length) = symbol.length {
|
||||
write!(f, "[{length}]")?;
|
||||
if let PolynomialType::Committed = poly_type {
|
||||
write!(f, "[{length}]")?;
|
||||
assert!(definition.is_none());
|
||||
} else {
|
||||
// Do not print an array size, because we will do it as part of the type.
|
||||
assert!(matches!(
|
||||
definition,
|
||||
Some(FunctionValueDefinition::Expression(
|
||||
TypedExpression { e: _, ty: Some(_) }
|
||||
))
|
||||
));
|
||||
}
|
||||
}
|
||||
if let Some(value) = definition {
|
||||
writeln!(f, "{value};")?
|
||||
@@ -57,11 +71,18 @@ impl<T: Display> Display for Analyzed<T> {
|
||||
}
|
||||
SymbolKind::Constant() => {
|
||||
let indentation = if is_local { " " } else { "" };
|
||||
writeln!(
|
||||
f,
|
||||
"{indentation}constant {name}{};",
|
||||
definition.as_ref().unwrap()
|
||||
)?;
|
||||
let Some(FunctionValueDefinition::Expression(TypedExpression {
|
||||
e,
|
||||
ty: Some(Type::Fe),
|
||||
})) = &definition
|
||||
else {
|
||||
panic!(
|
||||
"Invalid constant value: {}",
|
||||
definition.as_ref().unwrap()
|
||||
);
|
||||
};
|
||||
|
||||
writeln!(f, "{indentation}constant {name} = {e};",)?;
|
||||
}
|
||||
SymbolKind::Other() => {
|
||||
write!(f, " let {name}")?;
|
||||
@@ -108,7 +129,17 @@ impl<T: Display> Display for FunctionValueDefinition<T> {
|
||||
write!(f, " = {}", items.iter().format(" + "))
|
||||
}
|
||||
FunctionValueDefinition::Query(e) => format_outer_function(e, Some("query"), f),
|
||||
FunctionValueDefinition::Expression(e) => format_outer_function(e, None, f),
|
||||
FunctionValueDefinition::Expression(TypedExpression { e, ty: None }) => {
|
||||
format_outer_function(e, None, f)
|
||||
}
|
||||
FunctionValueDefinition::Expression(TypedExpression { e, ty: Some(ty) })
|
||||
if *ty == Type::col() =>
|
||||
{
|
||||
format_outer_function(e, None, f)
|
||||
}
|
||||
FunctionValueDefinition::Expression(TypedExpression { e, ty: Some(ty) }) => {
|
||||
write!(f, ": {ty} = {e}")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -248,3 +279,65 @@ impl Display for PolynomialReference {
|
||||
write!(f, "{}", self.name,)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Type {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
match self {
|
||||
Type::Bool => write!(f, "bool"),
|
||||
Type::Int => write!(f, "int"),
|
||||
Type::Fe => write!(f, "fe"),
|
||||
Type::String => write!(f, "string"),
|
||||
Type::Expr => write!(f, "expr"),
|
||||
Type::Constr => write!(f, "constr"),
|
||||
Type::Array(ar) => write!(f, "{ar}"),
|
||||
Type::Tuple(tu) => write!(f, "{tu}"),
|
||||
Type::Function(fun) => write!(f, "{fun}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ArrayType {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
let length = self.length.iter().format("");
|
||||
if self.base.needs_parentheses() {
|
||||
write!(f, "({})[{length}]", self.base)
|
||||
} else {
|
||||
write!(f, "{}[{length}]", self.base)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for TupleType {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
write!(f, "({})", format_list_of_types(&self.items))
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for FunctionType {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
if *self == Self::col() {
|
||||
write!(f, "col")
|
||||
} else {
|
||||
write!(
|
||||
f,
|
||||
"{} -> {}",
|
||||
format_list_of_types(&self.params),
|
||||
self.value
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn format_list_of_types(types: &[Type]) -> String {
|
||||
types
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if x.needs_parentheses() {
|
||||
format!("({x})")
|
||||
} else {
|
||||
x.to_string()
|
||||
}
|
||||
})
|
||||
.format(", ")
|
||||
.to_string()
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
mod display;
|
||||
pub mod types;
|
||||
pub mod visitor;
|
||||
|
||||
use core::hash::Hash;
|
||||
@@ -15,6 +16,8 @@ pub use crate::parsed::UnaryOperator;
|
||||
use crate::parsed::{self, SelectedExpressions};
|
||||
use crate::SourceRef;
|
||||
|
||||
use self::types::TypedExpression;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum StatementIdentifier {
|
||||
/// Either an intermediate column or a definition.
|
||||
@@ -303,7 +306,9 @@ impl<T> Analyzed<T> {
|
||||
.iter_mut()
|
||||
.flat_map(|e| e.pattern.iter_mut())
|
||||
.for_each(|e| e.post_visit_expressions_mut(f)),
|
||||
Some(FunctionValueDefinition::Expression(e)) => e.post_visit_expressions_mut(f),
|
||||
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
|
||||
e.post_visit_expressions_mut(f)
|
||||
}
|
||||
None => {}
|
||||
});
|
||||
}
|
||||
@@ -467,7 +472,7 @@ pub enum SymbolKind {
|
||||
pub enum FunctionValueDefinition<T> {
|
||||
Array(Vec<RepeatedArray<T>>),
|
||||
Query(Expression<T>),
|
||||
Expression(Expression<T>),
|
||||
Expression(TypedExpression<T>),
|
||||
}
|
||||
|
||||
/// An array of elements that might be repeated.
|
||||
|
||||
138
ast/src/analyzed/types.rs
Normal file
138
ast/src/analyzed/types.rs
Normal file
@@ -0,0 +1,138 @@
|
||||
use std::fmt::Display;
|
||||
|
||||
use powdr_number::FieldElement;
|
||||
|
||||
use crate::parsed::{ArrayTypeName, Expression, FunctionTypeName, TupleTypeName, TypeName};
|
||||
|
||||
use super::Reference;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct TypedExpression<T, Ref = Reference> {
|
||||
pub e: Expression<T, Ref>,
|
||||
pub ty: Option<Type>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub enum Type {
|
||||
/// Boolean
|
||||
Bool,
|
||||
/// Integer (arbitrary precision)
|
||||
Int,
|
||||
/// Field element (unspecified field)
|
||||
Fe,
|
||||
/// String
|
||||
String,
|
||||
/// Algebraic expression
|
||||
Expr,
|
||||
/// Polynomial identity or lookup (not yet supported)
|
||||
Constr,
|
||||
Array(ArrayType),
|
||||
Tuple(TupleType),
|
||||
Function(FunctionType),
|
||||
}
|
||||
|
||||
impl Type {
|
||||
/// Returns the column type `int -> fe`.
|
||||
pub fn col() -> Self {
|
||||
Type::Function(FunctionType::col())
|
||||
}
|
||||
|
||||
/// Returns true if the type name needs parentheses around it during formatting
|
||||
/// when used inside a complex expression.
|
||||
pub fn needs_parentheses(&self) -> bool {
|
||||
match self {
|
||||
Type::Bool
|
||||
| Type::Int
|
||||
| Type::Fe
|
||||
| Type::String
|
||||
| Type::Expr
|
||||
| Type::Constr
|
||||
| Type::Array(_)
|
||||
| Type::Tuple(_) => false,
|
||||
Type::Function(fun) => fun.needs_parentheses(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FieldElement, Ref: Display> From<TypeName<Expression<T, Ref>>> for Type {
|
||||
fn from(value: TypeName<Expression<T, Ref>>) -> Self {
|
||||
match value {
|
||||
TypeName::Bool => Type::Bool,
|
||||
TypeName::Int => Type::Int,
|
||||
TypeName::Fe => Type::Fe,
|
||||
TypeName::String => Type::String,
|
||||
TypeName::Expr => Type::Expr,
|
||||
TypeName::Constr => Type::Constr,
|
||||
TypeName::Col => Type::col(),
|
||||
TypeName::Array(ar) => Type::Array(ar.into()),
|
||||
TypeName::Tuple(tu) => Type::Tuple(tu.into()),
|
||||
TypeName::Function(fun) => Type::Function(fun.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct ArrayType {
|
||||
pub base: Box<Type>,
|
||||
pub length: Option<u64>,
|
||||
}
|
||||
|
||||
impl<T: FieldElement, Ref: Display> From<ArrayTypeName<Expression<T, Ref>>> for ArrayType {
|
||||
fn from(name: ArrayTypeName<Expression<T, Ref>>) -> Self {
|
||||
let length = name.length.as_ref().map(|l| {
|
||||
if let Expression::Number(n) = l {
|
||||
n.to_degree()
|
||||
} else {
|
||||
panic!(
|
||||
"Array length expression not resolved in type name prior to conversion: {name}"
|
||||
);
|
||||
}
|
||||
});
|
||||
ArrayType {
|
||||
base: Box::new(Type::from(*name.base)),
|
||||
length,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct TupleType {
|
||||
pub items: Vec<Type>,
|
||||
}
|
||||
|
||||
impl<T: FieldElement, Ref: Display> From<TupleTypeName<Expression<T, Ref>>> for TupleType {
|
||||
fn from(value: TupleTypeName<Expression<T, Ref>>) -> Self {
|
||||
TupleType {
|
||||
items: value.items.into_iter().map(Into::into).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct FunctionType {
|
||||
pub params: Vec<Type>,
|
||||
pub value: Box<Type>,
|
||||
}
|
||||
|
||||
impl FunctionType {
|
||||
/// Returns the column type `int -> fe`.
|
||||
pub fn col() -> Self {
|
||||
FunctionType {
|
||||
params: vec![Type::Int],
|
||||
value: Box::new(Type::Fe),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn needs_parentheses(&self) -> bool {
|
||||
*self != Self::col()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: FieldElement, Ref: Display> From<FunctionTypeName<Expression<T, Ref>>> for FunctionType {
|
||||
fn from(name: FunctionTypeName<Expression<T, Ref>>) -> Self {
|
||||
FunctionType {
|
||||
params: name.params.into_iter().map(Into::into).collect(),
|
||||
value: Box::new(Type::from(*name.value)),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -86,7 +86,8 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
|
||||
F: FnMut(&mut Expression<T>) -> ControlFlow<B>,
|
||||
{
|
||||
match self {
|
||||
FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => {
|
||||
FunctionValueDefinition::Query(e)
|
||||
| FunctionValueDefinition::Expression(TypedExpression { e, ty: _ }) => {
|
||||
e.visit_expressions_mut(f, o)
|
||||
}
|
||||
FunctionValueDefinition::Array(array) => array
|
||||
@@ -101,7 +102,8 @@ impl<T> ExpressionVisitable<Expression<T>> for FunctionValueDefinition<T> {
|
||||
F: FnMut(&Expression<T>) -> ControlFlow<B>,
|
||||
{
|
||||
match self {
|
||||
FunctionValueDefinition::Query(e) | FunctionValueDefinition::Expression(e) => {
|
||||
FunctionValueDefinition::Query(e)
|
||||
| FunctionValueDefinition::Expression(TypedExpression { e, ty: _ }) => {
|
||||
e.visit_expressions(f, o)
|
||||
}
|
||||
FunctionValueDefinition::Array(array) => array
|
||||
|
||||
@@ -7,7 +7,10 @@ use itertools::Itertools;
|
||||
|
||||
use crate::{
|
||||
indent,
|
||||
parsed::asm::{AbsoluteSymbolPath, Part},
|
||||
parsed::{
|
||||
asm::{AbsoluteSymbolPath, Part},
|
||||
ExpressionWithTypeName,
|
||||
},
|
||||
write_indented_by, write_items_indented,
|
||||
};
|
||||
|
||||
@@ -44,9 +47,15 @@ impl<T: Display> Display for AnalysisASMFile<T> {
|
||||
Item::Machine(machine) => {
|
||||
write_indented_by(f, format!("machine {name}{machine}"), current_path.len())?;
|
||||
}
|
||||
Item::Expression(expression) => write_indented_by(
|
||||
Item::Expression(ExpressionWithTypeName { e, type_name }) => write_indented_by(
|
||||
f,
|
||||
format!("let {name} = {expression};\n"),
|
||||
format!(
|
||||
"let {name}{} = {e};\n",
|
||||
type_name
|
||||
.as_ref()
|
||||
.map(|tn| format!(": {tn}"))
|
||||
.unwrap_or_default()
|
||||
),
|
||||
current_path.len(),
|
||||
)?,
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ use crate::parsed::{
|
||||
AbsoluteSymbolPath, AssignmentRegister, CallableRef, InstructionBody, OperationId, Params,
|
||||
},
|
||||
visitor::{ExpressionVisitable, VisitOrder},
|
||||
NamespacedPolynomialReference, PilStatement,
|
||||
ExpressionWithTypeName, NamespacedPolynomialReference, PilStatement,
|
||||
};
|
||||
use crate::SourceRef;
|
||||
|
||||
@@ -672,7 +672,7 @@ pub struct SubmachineDeclaration {
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum Item<T> {
|
||||
Machine(Machine<T>),
|
||||
Expression(Expression<T>),
|
||||
Expression(ExpressionWithTypeName<T>),
|
||||
}
|
||||
|
||||
impl<T> Item<T> {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
use std::fmt::{Display, Formatter, Result};
|
||||
|
||||
use crate::parsed::ExpressionWithTypeName;
|
||||
|
||||
use super::{Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph};
|
||||
|
||||
impl Display for Location {
|
||||
@@ -11,8 +13,15 @@ impl Display for Location {
|
||||
impl<T: Display> Display for PILGraph<T> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
writeln!(f, "// Utilities")?;
|
||||
for (name, e) in &self.definitions {
|
||||
writeln!(f, "let {name} = {e};")?;
|
||||
for (name, ExpressionWithTypeName { e, type_name }) in &self.definitions {
|
||||
writeln!(
|
||||
f,
|
||||
"let {name}{} = {e};",
|
||||
type_name
|
||||
.as_ref()
|
||||
.map(|tn| format!(": {tn}"))
|
||||
.unwrap_or_default()
|
||||
)?;
|
||||
}
|
||||
for (location, object) in &self.objects {
|
||||
writeln!(f, "// Object {}", location)?;
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::collections::BTreeMap;
|
||||
|
||||
use crate::parsed::{
|
||||
asm::{AbsoluteSymbolPath, Params},
|
||||
Expression, PilStatement,
|
||||
Expression, ExpressionWithTypeName, PilStatement,
|
||||
};
|
||||
|
||||
mod display;
|
||||
@@ -30,7 +30,7 @@ pub struct PILGraph<T> {
|
||||
pub main: Machine,
|
||||
pub entry_points: Vec<Operation<T>>,
|
||||
pub objects: BTreeMap<Location, Object<T>>,
|
||||
pub definitions: BTreeMap<AbsoluteSymbolPath, Expression<T>>,
|
||||
pub definitions: BTreeMap<AbsoluteSymbolPath, ExpressionWithTypeName<T>>,
|
||||
}
|
||||
|
||||
#[derive(Default, Clone)]
|
||||
|
||||
@@ -11,7 +11,7 @@ use derive_more::From;
|
||||
|
||||
use crate::SourceRef;
|
||||
|
||||
use super::{Expression, PilStatement};
|
||||
use super::{Expression, ExpressionWithTypeName, PilStatement};
|
||||
|
||||
#[derive(Default, Clone, Debug, PartialEq, Eq)]
|
||||
pub struct ASMProgram<T> {
|
||||
@@ -51,7 +51,7 @@ pub enum SymbolValue<T> {
|
||||
/// A module definition
|
||||
Module(Module<T>),
|
||||
/// A generic symbol / function.
|
||||
Expression(Expression<T>),
|
||||
Expression(ExpressionWithTypeName<T>),
|
||||
}
|
||||
|
||||
impl<T> SymbolValue<T> {
|
||||
@@ -74,7 +74,7 @@ pub enum SymbolValueRef<'a, T> {
|
||||
/// A module definition
|
||||
Module(ModuleRef<'a, T>),
|
||||
/// A generic symbol / function.
|
||||
Expression(&'a Expression<T>),
|
||||
Expression(&'a ExpressionWithTypeName<T>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, From)]
|
||||
|
||||
@@ -57,8 +57,15 @@ impl<T: Display> Display for ModuleStatement<T> {
|
||||
SymbolValue::Module(m @ Module::Local(_)) => {
|
||||
write!(f, "mod {name} {m}")
|
||||
}
|
||||
SymbolValue::Expression(e) => {
|
||||
write!(f, "let {name} = {e};")
|
||||
SymbolValue::Expression(ExpressionWithTypeName { e, type_name }) => {
|
||||
write!(
|
||||
f,
|
||||
"let {name}{} = {e};",
|
||||
type_name
|
||||
.as_ref()
|
||||
.map(|t| format!(": {t}"))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
@@ -373,9 +380,15 @@ impl<T: Display> Display for PilStatement<T> {
|
||||
PilStatement::Namespace(_, name, poly_length) => {
|
||||
write!(f, "namespace {name}({poly_length});")
|
||||
}
|
||||
PilStatement::LetStatement(_, name, None) => write!(f, " let {name};"),
|
||||
PilStatement::LetStatement(_, name, Some(expr)) => {
|
||||
write!(f, " let {name} = {expr};")
|
||||
PilStatement::LetStatement(_, name, type_name, value) => {
|
||||
write!(f, " let {name}")?;
|
||||
if let Some(type_name) = type_name {
|
||||
write!(f, ": {type_name}")?;
|
||||
}
|
||||
if let Some(value) = &value {
|
||||
write!(f, " = {value}")?;
|
||||
}
|
||||
write!(f, ";")
|
||||
}
|
||||
PilStatement::PolynomialDefinition(_, name, value) => {
|
||||
write!(f, " pol {name} = {value};")
|
||||
@@ -584,6 +597,74 @@ impl Display for UnaryOperator {
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Display> Display for TypeName<E> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
match self {
|
||||
TypeName::Bool => write!(f, "bool"),
|
||||
TypeName::Int => write!(f, "int"),
|
||||
TypeName::Fe => write!(f, "fe"),
|
||||
TypeName::String => write!(f, "string"),
|
||||
TypeName::Col => write!(f, "col"),
|
||||
TypeName::Expr => write!(f, "expr"),
|
||||
TypeName::Constr => write!(f, "constr"),
|
||||
TypeName::Array(array) => write!(f, "{array}"),
|
||||
TypeName::Tuple(tuple) => write!(f, "{tuple}"),
|
||||
TypeName::Function(fun) => write!(f, "{fun}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Display> Display for ArrayTypeName<E> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
if self.base.needs_parentheses() {
|
||||
write!(f, "({})", self.base)
|
||||
} else {
|
||||
write!(f, "{}", self.base)
|
||||
}?;
|
||||
write!(
|
||||
f,
|
||||
"[{}]",
|
||||
self.length
|
||||
.as_ref()
|
||||
.map(|l| l.to_string())
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Display> Display for TupleTypeName<E> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
write!(f, "({})", self.items.iter().format(", "))
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Display> Display for FunctionTypeName<E> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
|
||||
let params = self
|
||||
.params
|
||||
.iter()
|
||||
.map(|x| {
|
||||
if x.needs_parentheses() {
|
||||
format!("({x})")
|
||||
} else {
|
||||
format!("{x}")
|
||||
}
|
||||
})
|
||||
.join(", ")
|
||||
+ if self.params.is_empty() { "" } else { " " };
|
||||
|
||||
write!(
|
||||
f,
|
||||
"{params}-> {}",
|
||||
if self.value.needs_parentheses() {
|
||||
format!("({})", self.value)
|
||||
} else {
|
||||
format!("{}", self.value)
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
||||
|
||||
@@ -24,7 +24,12 @@ pub enum PilStatement<T> {
|
||||
Include(SourceRef, String),
|
||||
/// Name of namespace and polynomial degree (constant)
|
||||
Namespace(SourceRef, SymbolPath, Expression<T>),
|
||||
LetStatement(SourceRef, String, Option<Expression<T>>),
|
||||
LetStatement(
|
||||
SourceRef,
|
||||
String,
|
||||
Option<TypeName<Expression<T>>>,
|
||||
Option<Expression<T>>,
|
||||
),
|
||||
PolynomialDefinition(SourceRef, String, Expression<T>),
|
||||
PublicDeclaration(
|
||||
SourceRef,
|
||||
@@ -68,7 +73,7 @@ impl<T> PilStatement<T> {
|
||||
| PilStatement::PolynomialConstantDefinition(_, name, _)
|
||||
| PilStatement::ConstantDefinition(_, name, _)
|
||||
| PilStatement::PublicDeclaration(_, name, _, _, _)
|
||||
| PilStatement::LetStatement(_, name, _) => Box::new(once(name)),
|
||||
| PilStatement::LetStatement(_, name, _, _) => Box::new(once(name)),
|
||||
PilStatement::PolynomialConstantDeclaration(_, polynomials)
|
||||
| PilStatement::PolynomialCommitDeclaration(_, polynomials, _) => {
|
||||
Box::new(polynomials.iter().map(|p| &p.name))
|
||||
@@ -98,8 +103,11 @@ impl<T> PilStatement<T> {
|
||||
| PilStatement::Namespace(_, _, e)
|
||||
| PilStatement::PolynomialDefinition(_, _, e)
|
||||
| PilStatement::PolynomialIdentity(_, e)
|
||||
| PilStatement::ConstantDefinition(_, _, e)
|
||||
| PilStatement::LetStatement(_, _, Some(e)) => Box::new(once(e)),
|
||||
| PilStatement::ConstantDefinition(_, _, e) => Box::new(once(e)),
|
||||
|
||||
PilStatement::LetStatement(_, _, type_name, value) => {
|
||||
Box::new(type_name.iter().flat_map(|t| t.expressions()).chain(value))
|
||||
}
|
||||
|
||||
PilStatement::PublicDeclaration(_, _, _, i, e) => Box::new(i.iter().chain(once(e))),
|
||||
|
||||
@@ -107,8 +115,7 @@ impl<T> PilStatement<T> {
|
||||
| PilStatement::PolynomialCommitDeclaration(_, _, Some(fundef)) => fundef.expressions(),
|
||||
PilStatement::PolynomialCommitDeclaration(_, _, None)
|
||||
| PilStatement::Include(_, _)
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _)
|
||||
| PilStatement::LetStatement(_, _, None) => Box::new(empty()),
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _) => Box::new(empty()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -126,8 +133,14 @@ impl<T> PilStatement<T> {
|
||||
| PilStatement::Namespace(_, _, e)
|
||||
| PilStatement::PolynomialDefinition(_, _, e)
|
||||
| PilStatement::PolynomialIdentity(_, e)
|
||||
| PilStatement::ConstantDefinition(_, _, e)
|
||||
| PilStatement::LetStatement(_, _, Some(e)) => Box::new(once(e)),
|
||||
| PilStatement::ConstantDefinition(_, _, e) => Box::new(once(e)),
|
||||
|
||||
PilStatement::LetStatement(_, _, type_name, value) => Box::new(
|
||||
type_name
|
||||
.iter_mut()
|
||||
.flat_map(|t| t.expressions_mut())
|
||||
.chain(value),
|
||||
),
|
||||
|
||||
PilStatement::PublicDeclaration(_, _, _, i, e) => Box::new(i.iter_mut().chain(once(e))),
|
||||
|
||||
@@ -137,8 +150,7 @@ impl<T> PilStatement<T> {
|
||||
}
|
||||
PilStatement::PolynomialCommitDeclaration(_, _, None)
|
||||
| PilStatement::Include(_, _)
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _)
|
||||
| PilStatement::LetStatement(_, _, None) => Box::new(empty()),
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _) => Box::new(empty()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -496,3 +508,141 @@ impl<T> ArrayExpression<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub enum TypeName<E> {
|
||||
/// Boolean
|
||||
Bool,
|
||||
/// Integer (arbitrary precision)
|
||||
Int,
|
||||
/// Field element (unspecified field)
|
||||
Fe,
|
||||
/// String
|
||||
String,
|
||||
/// Column, shorthand for "int -> fe"
|
||||
Col,
|
||||
/// Algebraic expression
|
||||
Expr,
|
||||
/// Polynomial identity
|
||||
Constr,
|
||||
Array(ArrayTypeName<E>),
|
||||
Tuple(TupleTypeName<E>),
|
||||
Function(FunctionTypeName<E>),
|
||||
}
|
||||
|
||||
impl<E> TypeName<E> {
|
||||
/// Returns true if the type name needs parentheses during formatting
|
||||
/// when used inside a complex expression.
|
||||
pub fn needs_parentheses(&self) -> bool {
|
||||
match self {
|
||||
TypeName::Bool
|
||||
| TypeName::Int
|
||||
| TypeName::Fe
|
||||
| TypeName::String
|
||||
| TypeName::Col
|
||||
| TypeName::Expr
|
||||
| TypeName::Constr
|
||||
| TypeName::Array(_)
|
||||
| TypeName::Tuple(_) => false,
|
||||
TypeName::Function(_) => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
|
||||
match self {
|
||||
TypeName::Bool
|
||||
| TypeName::Int
|
||||
| TypeName::Fe
|
||||
| TypeName::String
|
||||
| TypeName::Col
|
||||
| TypeName::Expr
|
||||
| TypeName::Constr => Box::new(empty()),
|
||||
TypeName::Array(a) => a.expressions(),
|
||||
TypeName::Tuple(t) => t.expressions(),
|
||||
TypeName::Function(f) => f.expressions(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
|
||||
match self {
|
||||
TypeName::Bool
|
||||
| TypeName::Int
|
||||
| TypeName::Fe
|
||||
| TypeName::String
|
||||
| TypeName::Col
|
||||
| TypeName::Expr
|
||||
| TypeName::Constr => Box::new(empty()),
|
||||
TypeName::Array(a) => a.expressions_mut(),
|
||||
TypeName::Tuple(t) => t.expressions_mut(),
|
||||
TypeName::Function(f) => f.expressions_mut(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct ArrayTypeName<E> {
|
||||
pub base: Box<TypeName<E>>,
|
||||
pub length: Option<E>,
|
||||
}
|
||||
|
||||
impl<E> ArrayTypeName<E> {
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
|
||||
Box::new(self.base.expressions().chain(self.length.iter()))
|
||||
}
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
|
||||
Box::new(self.base.expressions_mut().chain(self.length.iter_mut()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct TupleTypeName<E> {
|
||||
pub items: Vec<TypeName<E>>,
|
||||
}
|
||||
|
||||
impl<E> TupleTypeName<E> {
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
|
||||
Box::new(self.items.iter().flat_map(|t| t.expressions()))
|
||||
}
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
|
||||
Box::new(self.items.iter_mut().flat_map(|t| t.expressions_mut()))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct FunctionTypeName<E> {
|
||||
pub params: Vec<TypeName<E>>,
|
||||
pub value: Box<TypeName<E>>,
|
||||
}
|
||||
|
||||
impl<E> FunctionTypeName<E> {
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions(&self) -> Box<dyn Iterator<Item = &E> + '_> {
|
||||
Box::new(
|
||||
self.params
|
||||
.iter()
|
||||
.flat_map(|t| t.expressions())
|
||||
.chain(self.value.expressions()),
|
||||
)
|
||||
}
|
||||
/// Returns an iterator over all (top-level) expressions in this type name.
|
||||
pub fn expressions_mut(&mut self) -> Box<dyn Iterator<Item = &mut E> + '_> {
|
||||
Box::new(
|
||||
self.params
|
||||
.iter_mut()
|
||||
.flat_map(|t| t.expressions_mut())
|
||||
.chain(self.value.expressions_mut()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
|
||||
pub struct ExpressionWithTypeName<T, Ref = NamespacedPolynomialReference> {
|
||||
pub e: Expression<T, Ref>,
|
||||
pub type_name: Option<TypeName<Expression<T, Ref>>>,
|
||||
}
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use std::{iter::once, ops::ControlFlow};
|
||||
|
||||
use super::{
|
||||
ArrayExpression, ArrayLiteral, Expression, FunctionCall, FunctionDefinition, IfExpression,
|
||||
IndexAccess, LambdaExpression, MatchArm, MatchPattern, NamespacedPolynomialReference,
|
||||
PilStatement, SelectedExpressions,
|
||||
ArrayExpression, ArrayLiteral, ArrayTypeName, Expression, FunctionCall, FunctionDefinition,
|
||||
FunctionTypeName, IfExpression, IndexAccess, LambdaExpression, MatchArm, MatchPattern,
|
||||
NamespacedPolynomialReference, PilStatement, SelectedExpressions, TupleTypeName, TypeName,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@@ -205,8 +205,17 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
|
||||
| PilStatement::PolynomialDefinition(_, _, e)
|
||||
| PilStatement::PolynomialIdentity(_, e)
|
||||
| PilStatement::PublicDeclaration(_, _, _, None, e)
|
||||
| PilStatement::ConstantDefinition(_, _, e)
|
||||
| PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions_mut(f, o),
|
||||
| PilStatement::ConstantDefinition(_, _, e) => e.visit_expressions_mut(f, o),
|
||||
|
||||
PilStatement::LetStatement(_, _, type_name, value) => {
|
||||
if let Some(t) = type_name {
|
||||
t.visit_expressions_mut(f, o)?;
|
||||
};
|
||||
if let Some(v) = value {
|
||||
v.visit_expressions_mut(f, o)?;
|
||||
};
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
PilStatement::PublicDeclaration(_, _, _, Some(i), e) => [i, e]
|
||||
.into_iter()
|
||||
@@ -218,8 +227,7 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
|
||||
}
|
||||
PilStatement::PolynomialCommitDeclaration(_, _, None)
|
||||
| PilStatement::Include(_, _)
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _)
|
||||
| PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()),
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _) => ControlFlow::Continue(()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -242,8 +250,17 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
|
||||
| PilStatement::PolynomialDefinition(_, _, e)
|
||||
| PilStatement::PolynomialIdentity(_, e)
|
||||
| PilStatement::PublicDeclaration(_, _, _, None, e)
|
||||
| PilStatement::ConstantDefinition(_, _, e)
|
||||
| PilStatement::LetStatement(_, _, Some(e)) => e.visit_expressions(f, o),
|
||||
| PilStatement::ConstantDefinition(_, _, e) => e.visit_expressions(f, o),
|
||||
|
||||
PilStatement::LetStatement(_, _, type_name, value) => {
|
||||
if let Some(t) = type_name {
|
||||
t.visit_expressions(f, o)?;
|
||||
};
|
||||
if let Some(v) = value {
|
||||
v.visit_expressions(f, o)?;
|
||||
};
|
||||
ControlFlow::Continue(())
|
||||
}
|
||||
|
||||
PilStatement::PublicDeclaration(_, _, _, Some(i), e) => [i, e]
|
||||
.into_iter()
|
||||
@@ -255,8 +272,7 @@ impl<T> ExpressionVisitable<Expression<T, NamespacedPolynomialReference>> for Pi
|
||||
}
|
||||
PilStatement::PolynomialCommitDeclaration(_, _, None)
|
||||
| PilStatement::Include(_, _)
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _)
|
||||
| PilStatement::LetStatement(_, _, None) => ControlFlow::Continue(()),
|
||||
| PilStatement::PolynomialConstantDeclaration(_, _) => ControlFlow::Continue(()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -478,3 +494,105 @@ impl<T, Ref> ExpressionVisitable<Expression<T, Ref>> for IfExpression<T, Ref> {
|
||||
.try_for_each(|e| e.visit_expressions(f, o))
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for TypeName<E> {
|
||||
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut E) -> ControlFlow<B>,
|
||||
{
|
||||
match self {
|
||||
TypeName::Bool
|
||||
| TypeName::Int
|
||||
| TypeName::Fe
|
||||
| TypeName::String
|
||||
| TypeName::Col
|
||||
| TypeName::Expr
|
||||
| TypeName::Constr => ControlFlow::Continue(()),
|
||||
TypeName::Array(a) => a.visit_expressions_mut(f, o),
|
||||
TypeName::Tuple(t) => t.visit_expressions_mut(f, o),
|
||||
TypeName::Function(fun) => fun.visit_expressions_mut(f, o),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&E) -> ControlFlow<B>,
|
||||
{
|
||||
match self {
|
||||
TypeName::Bool
|
||||
| TypeName::Int
|
||||
| TypeName::Fe
|
||||
| TypeName::String
|
||||
| TypeName::Col
|
||||
| TypeName::Expr
|
||||
| TypeName::Constr => ControlFlow::Continue(()),
|
||||
TypeName::Array(a) => a.visit_expressions(f, o),
|
||||
TypeName::Tuple(t) => t.visit_expressions(f, o),
|
||||
TypeName::Function(fun) => fun.visit_expressions(f, o),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for ArrayTypeName<E> {
|
||||
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut E) -> ControlFlow<B>,
|
||||
{
|
||||
self.base.visit_expressions_mut(f, o)?;
|
||||
self.length
|
||||
.iter_mut()
|
||||
.try_for_each(|e| e.visit_expressions_mut(f, o))
|
||||
}
|
||||
|
||||
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&E) -> ControlFlow<B>,
|
||||
{
|
||||
self.base.visit_expressions(f, o)?;
|
||||
self.length
|
||||
.iter()
|
||||
.try_for_each(|e| e.visit_expressions(f, o))
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for TupleTypeName<E> {
|
||||
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut E) -> ControlFlow<B>,
|
||||
{
|
||||
self.items
|
||||
.iter_mut()
|
||||
.try_for_each(|i| i.visit_expressions_mut(f, o))
|
||||
}
|
||||
|
||||
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&E) -> ControlFlow<B>,
|
||||
{
|
||||
self.items
|
||||
.iter()
|
||||
.try_for_each(|i| i.visit_expressions(f, o))
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: ExpressionVisitable<E>> ExpressionVisitable<E> for FunctionTypeName<E> {
|
||||
fn visit_expressions_mut<F, B>(&mut self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut E) -> ControlFlow<B>,
|
||||
{
|
||||
self.params
|
||||
.iter_mut()
|
||||
.chain(once(self.value.as_mut()))
|
||||
.try_for_each(|i| i.visit_expressions_mut(f, o))
|
||||
}
|
||||
|
||||
fn visit_expressions<F, B>(&self, f: &mut F, o: VisitOrder) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&E) -> ControlFlow<B>,
|
||||
{
|
||||
self.params
|
||||
.iter()
|
||||
.chain(once(self.value.as_ref()))
|
||||
.try_for_each(|i| i.visit_expressions(f, o))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,37 +1,26 @@
|
||||
# Declarations
|
||||
|
||||
Powdr-pil allows the same syntax to declare various kinds of symbols. This includes
|
||||
constants, fixed columns, witness columns and even macros. It deduces the symbol kind
|
||||
by its type and the way the symbol is used.
|
||||
constants, fixed columns, witness columns and even higher-order functions. It deduces the symbol kind
|
||||
from the type of the symbol and the way the symbol is used.
|
||||
|
||||
Symbols can be declared using ``let <name>;`` and they can be declared and defined
|
||||
using ``let <name> = <value>;``, where ``<value>`` is an expression.
|
||||
using ``let <name> = <value>;``, where ``<value>`` is an expression. The [type](./types.md) of the symbol
|
||||
can be explicitly specified using ``let <name>: <type>;`` and ``let <name>: <type> = <value>;``.
|
||||
|
||||
This syntax can be used for constants, fixed columns, witness columns and even (higher-order)
|
||||
functions that can transform expressions. The kind of symbol is deduced by its type and the
|
||||
way the symbol is used:
|
||||
|
||||
- symbols without a value are witness columns,
|
||||
- symbols evaluating to a number are constants,
|
||||
- symbols defined as a function with a single parameter are fixed columns and
|
||||
- everything else is a "generic symbol" that is not a column.
|
||||
- Symbols without a value are witness columns. Their type can be omitted. If it is given, it must be ``int -> fe`` or its shorthand ``col``.
|
||||
- Symbols evaluating to a number or with type ``fe`` are constants.
|
||||
- Symbols without type but with a value that is a function with a single parameter are fixed columns.
|
||||
- Symbols defined with a value and type ``int -> fe`` or its shorthand ``col`` are also fixed columns.
|
||||
- Everything else is a "generic symbol" that is not a column or constant.
|
||||
|
||||
Examples:
|
||||
|
||||
|
||||
```rust
|
||||
// This defines a constant
|
||||
let rows = 2**16;
|
||||
// This defines a fixed column that contains the row number in each row.
|
||||
let step = |i| i;
|
||||
// Here, we have a witness column.
|
||||
let x;
|
||||
// This functions returns the square of its input (classified as a fixed column).
|
||||
let square = |x| x*x;
|
||||
// A recursive function, taking a function and an integer as parameter
|
||||
let sum = |f, i| match i {
|
||||
0 => f(0),
|
||||
_ => f(i) + sum(f, i - 1)
|
||||
};
|
||||
// The same function as "square" above, but employing a trick to avoid it
|
||||
// being classified as a column.
|
||||
let square_non_column = (|| |x| x*x)();
|
||||
{{#include ../../../test_data/pil/book/declarations.pil:declarations}}
|
||||
```
|
||||
@@ -1,7 +1,10 @@
|
||||
use std::{collections::HashMap, fmt::Display, rc::Rc};
|
||||
|
||||
use itertools::Itertools;
|
||||
use powdr_ast::analyzed::{Analyzed, FunctionValueDefinition};
|
||||
use powdr_ast::analyzed::{
|
||||
types::{Type, TypedExpression},
|
||||
Analyzed, FunctionValueDefinition,
|
||||
};
|
||||
use powdr_number::{DegreeType, FieldElement};
|
||||
use powdr_pil_analyzer::evaluator::{self, Custom, EvalError, SymbolLookup, Value};
|
||||
use rayon::prelude::{IntoParallelIterator, ParallelIterator};
|
||||
@@ -43,20 +46,25 @@ fn generate_values<T: FieldElement>(
|
||||
};
|
||||
// TODO we should maybe pre-compute some symbols here.
|
||||
let result = match body {
|
||||
FunctionValueDefinition::Expression(e) => (0..degree)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
// We could try to avoid the first evaluation to be run for each iteration,
|
||||
// but the data is not thread-safe.
|
||||
let fun = evaluator::evaluate(e, &symbols).unwrap();
|
||||
evaluator::evaluate_function_call(
|
||||
fun,
|
||||
vec![Rc::new(Value::Integer(num_bigint::BigInt::from(i)))],
|
||||
&symbols,
|
||||
)
|
||||
.and_then(|v| v.try_to_field_element())
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>(),
|
||||
FunctionValueDefinition::Expression(TypedExpression { e, ty }) => {
|
||||
if let Some(ty) = ty {
|
||||
assert_eq!(ty, &Type::col())
|
||||
};
|
||||
(0..degree)
|
||||
.into_par_iter()
|
||||
.map(|i| {
|
||||
// We could try to avoid the first evaluation to be run for each iteration,
|
||||
// but the data is not thread-safe.
|
||||
let fun = evaluator::evaluate(e, &symbols).unwrap();
|
||||
evaluator::evaluate_function_call(
|
||||
fun,
|
||||
vec![Rc::new(Value::Integer(num_bigint::BigInt::from(i)))],
|
||||
&symbols,
|
||||
)
|
||||
.and_then(|v| v.try_to_field_element())
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()
|
||||
}
|
||||
FunctionValueDefinition::Array(values) => values
|
||||
.iter()
|
||||
.map(|elements| {
|
||||
@@ -103,8 +111,8 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, FixedColumnRef<'a>> for Symbols<'a
|
||||
Value::Custom(FixedColumnRef { name })
|
||||
} else if let Some((_, value)) = self.analyzed.definitions.get(&name.to_string()) {
|
||||
match value {
|
||||
Some(FunctionValueDefinition::Expression(value)) => {
|
||||
evaluator::evaluate(value, self)?
|
||||
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
|
||||
evaluator::evaluate(e, self)?
|
||||
}
|
||||
Some(_) => Err(EvalError::Unsupported(
|
||||
"Cannot evaluate arrays and queries.".to_string(),
|
||||
@@ -268,7 +276,7 @@ mod test {
|
||||
let src = r#"
|
||||
constant %N = 8;
|
||||
namespace F(%N);
|
||||
let minus_one = [|x| x - 1][0];
|
||||
let minus_one: int -> int = |x| x - 1;
|
||||
pol constant EVEN(i) { 2 * minus_one(i) + 2 };
|
||||
"#;
|
||||
let analyzed = analyze_string(src);
|
||||
|
||||
@@ -15,7 +15,8 @@ use powdr_ast::{
|
||||
},
|
||||
folder::Folder,
|
||||
visitor::ExpressionVisitable,
|
||||
ArrayLiteral, FunctionCall, IndexAccess, LambdaExpression, MatchArm,
|
||||
ArrayLiteral, ExpressionWithTypeName, FunctionCall, IndexAccess, LambdaExpression,
|
||||
MatchArm,
|
||||
},
|
||||
};
|
||||
|
||||
@@ -79,9 +80,14 @@ impl<'a, T> Folder<T> for Canonicalizer<'a> {
|
||||
.map(Some)
|
||||
.transpose(),
|
||||
},
|
||||
SymbolValue::Expression(mut e) => {
|
||||
canonicalize_inside_expression(&mut e, &self.path, self.paths);
|
||||
Some(Ok(SymbolValue::Expression(e)))
|
||||
SymbolValue::Expression(mut exp) => {
|
||||
for tne in
|
||||
exp.type_name.iter_mut().flat_map(|tn| tn.expressions_mut())
|
||||
{
|
||||
canonicalize_inside_expression(tne, &self.path, self.paths);
|
||||
}
|
||||
canonicalize_inside_expression(&mut exp.e, &self.path, self.paths);
|
||||
Some(Ok(SymbolValue::Expression(exp)))
|
||||
}
|
||||
}
|
||||
.map(|value| value.map(|value| SymbolDefinition { name, value }.into()))
|
||||
@@ -321,7 +327,10 @@ fn check_module<T: Clone>(
|
||||
check_module(location.with_part(name), m, state)?;
|
||||
}
|
||||
SymbolValue::Import(s) => check_import(location.clone(), s.clone(), state)?,
|
||||
SymbolValue::Expression(e) => {
|
||||
SymbolValue::Expression(ExpressionWithTypeName { e, type_name }) => {
|
||||
for tne in type_name.iter().flat_map(|tn| tn.expressions()) {
|
||||
check_expression(&location, tne, state, &HashSet::default())?
|
||||
}
|
||||
check_expression(&location, e, state, &HashSet::default())?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use powdr_ast::{
|
||||
asm::AbsoluteSymbolPath,
|
||||
asm::SymbolPath,
|
||||
build::{direct_reference, index_access, namespaced_reference},
|
||||
Expression, PILFile, PilStatement, SelectedExpressions,
|
||||
Expression, ExpressionWithTypeName, PILFile, PilStatement, SelectedExpressions,
|
||||
},
|
||||
SourceRef,
|
||||
};
|
||||
@@ -42,9 +42,14 @@ pub fn link<T: FieldElement>(graph: PILGraph<T>) -> Result<PILFile<T>, Vec<Strin
|
||||
// Group by namespace and then sort by name.
|
||||
(namespace, name)
|
||||
})
|
||||
.flat_map(|(mut namespace, e)| {
|
||||
.flat_map(|(mut namespace, ExpressionWithTypeName { e, type_name })| {
|
||||
let name = namespace.pop().unwrap();
|
||||
let def = PilStatement::LetStatement(SourceRef::unknown(), name.to_string(), Some(e));
|
||||
let def = PilStatement::LetStatement(
|
||||
SourceRef::unknown(),
|
||||
name.to_string(),
|
||||
type_name,
|
||||
Some(e),
|
||||
);
|
||||
|
||||
// If there is a namespace change, insert a namespace statement.
|
||||
if current_namespace != namespace {
|
||||
|
||||
@@ -224,7 +224,7 @@ mod test {
|
||||
match stmt {
|
||||
PilStatement::Include(s, _)
|
||||
| PilStatement::Namespace(s, _, _)
|
||||
| PilStatement::LetStatement(s, _, _)
|
||||
| PilStatement::LetStatement(s, _, _, _)
|
||||
| PilStatement::PolynomialDefinition(s, _, _)
|
||||
| PilStatement::PublicDeclaration(s, _, _, _, _)
|
||||
| PilStatement::PolynomialConstantDeclaration(s, _)
|
||||
@@ -385,7 +385,7 @@ mod test {
|
||||
constant %N = 16;
|
||||
namespace Fibonacci(%N);
|
||||
constant %last_row = (%N - 1);
|
||||
let bool = [(|X| (X * (1 - X)))][0];
|
||||
let bool: expr -> expr = (|X| (X * (1 - X)));
|
||||
let one_hot = (|i, which| match i { which => 1, _ => 0, });
|
||||
pol constant ISLAST(i) { one_hot(i, %last_row) };
|
||||
pol commit arr[8];
|
||||
@@ -440,5 +440,39 @@ namespace Fibonacci(%N);
|
||||
);
|
||||
assert_eq!(input.trim(), printed.trim());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn type_names_simple() {
|
||||
let input = r#"
|
||||
let a: col;
|
||||
let b: int;
|
||||
let c: fe;
|
||||
let d: int[];
|
||||
let e: int[7];
|
||||
let f: (int, fe, fe[3])[2];"#;
|
||||
let printed = format!(
|
||||
"{}",
|
||||
parse::<GoldilocksField>(Some("input"), input).unwrap()
|
||||
);
|
||||
assert_eq!(input.trim(), printed.trim());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn type_names_complex() {
|
||||
let input = r#"
|
||||
let a: int -> fe;
|
||||
let b: int -> ();
|
||||
let c: -> ();
|
||||
let d: int, int -> fe;
|
||||
let e: int, int -> (fe, int[2]);
|
||||
let f: ((int, fe), fe[2] -> (fe -> int))[];
|
||||
let g: (int -> fe) -> int;
|
||||
let h: int -> (fe -> int);"#;
|
||||
let printed = format!(
|
||||
"{}",
|
||||
parse::<GoldilocksField>(Some("input"), input).unwrap()
|
||||
);
|
||||
assert_eq!(input.trim(), printed.trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -57,7 +57,11 @@ Part: Part = {
|
||||
}
|
||||
|
||||
LetStatementAtModuleLevel: SymbolDefinition<T> = {
|
||||
"let" <name:Identifier> "=" <value:Expression> ";" => SymbolDefinition { name, value: SymbolValue::Expression(value) }
|
||||
"let" <name:Identifier> <type_name:(":" <TypeName>)?> "=" <value:Expression> ";" =>
|
||||
SymbolDefinition {
|
||||
name,
|
||||
value: SymbolValue::Expression(ExpressionWithTypeName{ e: value, type_name })
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------- PIL part -----------------------------
|
||||
@@ -88,7 +92,7 @@ Namespace: PilStatement<T> = {
|
||||
}
|
||||
|
||||
LetStatement: PilStatement<T> = {
|
||||
<start:@L> "let" <id:Identifier> <expr:( "=" <Expression> )?> => PilStatement::LetStatement(ctx.source_ref(start), id, expr)
|
||||
<start:@L> "let" <id:Identifier> <type_name:(":" <TypeName>)?> <expr:( "=" <Expression> )?> => PilStatement::LetStatement(ctx.source_ref(start), id, type_name, expr)
|
||||
}
|
||||
|
||||
ConstantDefinition: PilStatement<T> = {
|
||||
@@ -540,6 +544,36 @@ IfExpression: Box<Expression<T>> = {
|
||||
"{" <else_body:BoxedExpression> "}" => Box::new(Expression::IfExpression(IfExpression{<>}))
|
||||
}
|
||||
|
||||
// ---------------------------- Type Names -----------------------------
|
||||
|
||||
TypeName: TypeName<Expression<T>> = {
|
||||
<params:TypeNameTermList> "->" <value:TypeNameTermBox> => TypeName::Function(FunctionTypeName{<>}),
|
||||
TypeNameTerm
|
||||
}
|
||||
|
||||
TypeNameTermList: Vec<TypeName<Expression<T>>> = {
|
||||
=> vec![],
|
||||
<mut list:( <TypeNameTerm> "," )*> <end:TypeNameTerm> => { list.push(end); list }
|
||||
}
|
||||
|
||||
TypeNameTermBox: Box<TypeName<Expression<T>>> = {
|
||||
TypeNameTerm => Box::new(<>)
|
||||
}
|
||||
|
||||
TypeNameTerm: TypeName<Expression<T>> = {
|
||||
"bool" => TypeName::Bool,
|
||||
"int" => TypeName::Int,
|
||||
"fe" => TypeName::Fe,
|
||||
"string" => TypeName::String,
|
||||
"col" => TypeName::Col,
|
||||
"expr" => TypeName::Expr,
|
||||
"constr" => TypeName::Constr,
|
||||
<base:TypeNameTerm> "[" <length:Expression?> "]" => TypeName::Array(ArrayTypeName{base: Box::new(base), length}),
|
||||
"(" <mut items:( <TypeNameTerm> "," )+> <end:TypeNameTerm> ")" => { items.push(end); TypeName::Tuple(TupleTypeName{items}) },
|
||||
"(" ")" => TypeName::Tuple(TupleTypeName{items: vec![]}),
|
||||
"(" <TypeName> ")",
|
||||
}
|
||||
|
||||
// ---------------------------- Terminals -----------------------------
|
||||
|
||||
|
||||
@@ -561,6 +595,8 @@ SpecialIdentifier: &'input str = {
|
||||
"insn",
|
||||
"int",
|
||||
"fe",
|
||||
"expr",
|
||||
"constr",
|
||||
"bool",
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ use std::{collections::HashMap, fmt::Display, rc::Rc};
|
||||
use itertools::Itertools;
|
||||
use powdr_ast::{
|
||||
analyzed::{
|
||||
types::{Type, TypedExpression},
|
||||
AlgebraicExpression, AlgebraicReference, Analyzed, Expression, FunctionValueDefinition,
|
||||
Identity, IdentityKind, PolynomialReference, PolynomialType, PublicDeclaration, Reference,
|
||||
StatementIdentifier, Symbol, SymbolKind,
|
||||
@@ -58,9 +59,10 @@ pub fn condense<T: FieldElement>(
|
||||
let Some(FunctionValueDefinition::Expression(e)) = definition else {
|
||||
panic!("Expected expression")
|
||||
};
|
||||
assert!(e.ty.is_none() || e.ty == Some(Type::col()));
|
||||
Some((
|
||||
name.clone(),
|
||||
(symbol.clone(), condenser.condense_expression(e)),
|
||||
(symbol.clone(), condenser.condense_expression(&e.e)),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
@@ -231,7 +233,7 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate<T>> for &'a Condenser<T
|
||||
}
|
||||
} else {
|
||||
match value {
|
||||
Some(FunctionValueDefinition::Expression(value)) => {
|
||||
Some(FunctionValueDefinition::Expression(TypedExpression { e: value, ty: _ })) => {
|
||||
evaluator::evaluate(value, self)?
|
||||
}
|
||||
_ => Err(EvalError::Unsupported(
|
||||
@@ -278,8 +280,8 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, Condensate<T>> for &'a Condenser<T
|
||||
};
|
||||
|
||||
match self.symbols[&name].1.as_ref() {
|
||||
Some(FunctionValueDefinition::Expression(v)) => {
|
||||
let function = evaluate(v, self)?;
|
||||
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
|
||||
let function = evaluate(e, self)?;
|
||||
evaluate_function_call(function, arguments, self)
|
||||
}
|
||||
None => Err(EvalError::SymbolNotFound(format!(
|
||||
@@ -390,7 +392,7 @@ impl<T: FieldElement> Custom for Condensate<T> {
|
||||
fn type_name(&self) -> String {
|
||||
match self {
|
||||
Condensate::Expression(_) => "expr".to_string(),
|
||||
Condensate::Identity(_, _) => "identity".to_string(),
|
||||
Condensate::Identity(_, _) => "constr".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use std::{
|
||||
|
||||
use itertools::Itertools;
|
||||
use powdr_ast::{
|
||||
analyzed::{Expression, FunctionValueDefinition, Reference, Symbol},
|
||||
analyzed::{types::TypedExpression, Expression, FunctionValueDefinition, Reference, Symbol},
|
||||
parsed::{
|
||||
display::quote, BinaryOperator, FunctionCall, LambdaExpression, MatchArm, MatchPattern,
|
||||
UnaryOperator,
|
||||
@@ -189,7 +189,7 @@ const BUILTINS: [(&str, BuiltinFunction); 6] = [
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Debug)]
|
||||
pub enum BuiltinFunction {
|
||||
/// std::array::len: [_] -> int, returns the length of an array
|
||||
/// std::array::len: _[] -> int, returns the length of an array
|
||||
ArrayLen,
|
||||
/// std::field::modulus: -> int, returns the field modulus as int
|
||||
Modulus,
|
||||
@@ -282,7 +282,9 @@ impl<'a, T: FieldElement> SymbolLookup<'a, T, NoCustom> for Definitions<'a, T> {
|
||||
fn lookup(&self, name: &'a str) -> Result<Value<'a, T, NoCustom>, EvalError> {
|
||||
Ok(match self.0.get(&name.to_string()) {
|
||||
Some((_, value)) => match value {
|
||||
Some(FunctionValueDefinition::Expression(value)) => evaluate(value, self)?,
|
||||
Some(FunctionValueDefinition::Expression(TypedExpression { e, ty: _ })) => {
|
||||
evaluate(e, self)?
|
||||
}
|
||||
_ => Err(EvalError::Unsupported(
|
||||
"Cannot evaluate arrays and queries.".to_string(),
|
||||
))?,
|
||||
@@ -640,7 +642,8 @@ mod test {
|
||||
|
||||
fn parse_and_evaluate_symbol(input: &str, symbol: &str) -> String {
|
||||
let analyzed = analyze_string::<GoldilocksField>(input);
|
||||
let Some(FunctionValueDefinition::Expression(symbol)) = &analyzed.definitions[symbol].1
|
||||
let Some(FunctionValueDefinition::Expression(TypedExpression { e: symbol, ty: _ })) =
|
||||
&analyzed.definitions[symbol].1
|
||||
else {
|
||||
panic!()
|
||||
};
|
||||
@@ -686,8 +689,8 @@ mod test {
|
||||
#[test]
|
||||
pub fn capturing() {
|
||||
let src = r#"namespace Main(16);
|
||||
let f = |n, g| match n { 99 => |i| n, 1 => g(3) };
|
||||
let result = f(1, f(99, |x| x + 3000));
|
||||
let f: int, (int -> int) -> (int -> int) = |n, g| match n { 99 => |i| n, 1 => g };
|
||||
let result = f(1, f(99, |x| x + 3000))(0);
|
||||
"#;
|
||||
// If the lambda function returned by the expression f(99, ...) does not
|
||||
// properly capture the value of n in a closure, then f(1, ...) would return 1.
|
||||
|
||||
@@ -505,6 +505,58 @@ namespace N(16);
|
||||
let expected = r#"namespace N(16);
|
||||
let w = (|| 2);
|
||||
constant x = (|i| (|| N.w()))(2)();
|
||||
"#;
|
||||
let formatted = analyze_string::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simple_type_resolution() {
|
||||
let input = r#"namespace N(16);
|
||||
let w: col[3 + 4];
|
||||
"#;
|
||||
let expected = r#"namespace N(16);
|
||||
col witness w[7];
|
||||
"#;
|
||||
let formatted = analyze_string::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn complex_type_resolution() {
|
||||
let input = r#"namespace N(16);
|
||||
let f: int -> int = |i| i + 10;
|
||||
let x: (int -> int), int -> int = |k, i| k(2**i);
|
||||
let y: (int -> fe)[x(f, 2)];
|
||||
let z: (((int -> int), int -> int)[x(|i| i, 3)], col) = ([x, x, x, x, x, x, x, x], y[0]);
|
||||
"#;
|
||||
let expected = r#"namespace N(16);
|
||||
let f: int -> int = (|i| (i + 10));
|
||||
let x: (int -> int), int -> int = (|k, i| k((2 ** i)));
|
||||
col witness y[14];
|
||||
let z: (((int -> int), int -> int)[8], col) = ([N.x, N.x, N.x, N.x, N.x, N.x, N.x, N.x], N.y[0]);
|
||||
"#;
|
||||
let formatted = analyze_string::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn expr_and_identity() {
|
||||
let input = r#"namespace N(16);
|
||||
let f: expr, expr -> constr[] = |x, y| [x == y];
|
||||
let g: expr -> constr[1] = |x| [x == 0];
|
||||
let x: col;
|
||||
let y: col;
|
||||
f(x, y);
|
||||
g((x));
|
||||
"#;
|
||||
let expected = r#"namespace N(16);
|
||||
let f: expr, expr -> constr[] = (|x, y| [(x == y)]);
|
||||
let g: expr -> constr[1] = (|x| [(x == 0)]);
|
||||
col witness x;
|
||||
col witness y;
|
||||
N.x = N.y;
|
||||
N.x = 0;
|
||||
"#;
|
||||
let formatted = analyze_string::<GoldilocksField>(input).to_string();
|
||||
assert_eq!(formatted, expected);
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
use std::collections::{BTreeMap, HashMap};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use powdr_ast::analyzed::types::{ArrayType, Type, TypedExpression};
|
||||
use powdr_ast::parsed::{
|
||||
self, FunctionDefinition, PilStatement, PolynomialName, SelectedExpressions,
|
||||
self, FunctionDefinition, PilStatement, PolynomialName, SelectedExpressions, TypeName,
|
||||
};
|
||||
use powdr_ast::SourceRef;
|
||||
use powdr_number::{DegreeType, FieldElement};
|
||||
@@ -103,8 +104,8 @@ where
|
||||
.handle_symbol_definition(
|
||||
source,
|
||||
name,
|
||||
None,
|
||||
SymbolKind::Poly(PolynomialType::Intermediate),
|
||||
Some(Type::col()),
|
||||
Some(FunctionDefinition::Expression(value)),
|
||||
),
|
||||
PilStatement::PublicDeclaration(source, name, polynomial, array_index, index) => {
|
||||
@@ -117,8 +118,8 @@ where
|
||||
.handle_symbol_definition(
|
||||
source,
|
||||
name,
|
||||
None,
|
||||
SymbolKind::Poly(PolynomialType::Constant),
|
||||
Some(Type::col()),
|
||||
Some(definition),
|
||||
),
|
||||
PilStatement::PolynomialCommitDeclaration(source, polynomials, None) => {
|
||||
@@ -130,12 +131,14 @@ where
|
||||
Some(definition),
|
||||
) => {
|
||||
assert!(polynomials.len() == 1);
|
||||
let name = polynomials.pop().unwrap();
|
||||
let (name, ty) =
|
||||
self.name_and_type_from_polynomial_name(polynomials.pop().unwrap());
|
||||
|
||||
self.handle_symbol_definition(
|
||||
source,
|
||||
name.name,
|
||||
name.array_size,
|
||||
name,
|
||||
SymbolKind::Poly(PolynomialType::Committed),
|
||||
ty,
|
||||
Some(definition),
|
||||
)
|
||||
}
|
||||
@@ -147,24 +150,53 @@ where
|
||||
self.handle_symbol_definition(
|
||||
source,
|
||||
name,
|
||||
None,
|
||||
SymbolKind::Constant(),
|
||||
Some(Type::Fe),
|
||||
Some(FunctionDefinition::Expression(value)),
|
||||
)
|
||||
}
|
||||
PilStatement::LetStatement(source, name, value) => {
|
||||
self.handle_generic_definition(source, name, value)
|
||||
PilStatement::LetStatement(source, name, type_name, value) => {
|
||||
self.handle_generic_definition(source, name, type_name, value)
|
||||
}
|
||||
_ => self.handle_identity_statement(statement),
|
||||
}
|
||||
}
|
||||
|
||||
fn name_and_type_from_polynomial_name(
|
||||
&mut self,
|
||||
PolynomialName { name, array_size }: PolynomialName<T>,
|
||||
) -> (String, Option<Type>) {
|
||||
let ty = Some(match array_size {
|
||||
None => Type::col(),
|
||||
Some(len) => {
|
||||
let length = self
|
||||
.evaluate_expression(len)
|
||||
.map_err(|e| {
|
||||
panic!("Error evaluating length of array of witness columns {name}:\n{e}")
|
||||
})
|
||||
.map(|length| length.to_degree())
|
||||
.ok();
|
||||
Type::Array(ArrayType {
|
||||
base: Box::new(Type::col()),
|
||||
length,
|
||||
})
|
||||
}
|
||||
});
|
||||
(name, ty)
|
||||
}
|
||||
|
||||
fn handle_generic_definition(
|
||||
&mut self,
|
||||
source: SourceRef,
|
||||
name: String,
|
||||
value: Option<::powdr_ast::parsed::Expression<T>>,
|
||||
type_name: Option<TypeName<parsed::Expression<T>>>,
|
||||
value: Option<parsed::Expression<T>>,
|
||||
) -> Vec<PILItem<T>> {
|
||||
let ty = type_name.map(|n|
|
||||
self.resolve_type_name(n.clone())
|
||||
.map_err(|e| panic!("Error evaluating expressions in type name \"{n}\" to reduce it to a type:\n{e})"))
|
||||
.unwrap()
|
||||
);
|
||||
// Determine whether this is a fixed column, a constant or something else
|
||||
// depending on the structure of the value and if we can evaluate
|
||||
// it to a single number.
|
||||
@@ -172,30 +204,56 @@ where
|
||||
match value {
|
||||
None => {
|
||||
// No value provided => treat it as a witness column.
|
||||
let ty = ty
|
||||
.map(|t| {
|
||||
if let Type::Array(ArrayType { base, length }) = &t {
|
||||
if base.as_ref() != &Type::col() {
|
||||
panic!("Symbol {name} is declared without value and thus must be a witness column array, but its type is {t} instead of col[].");
|
||||
}
|
||||
if length.is_none() {
|
||||
panic!("Explicit array length required for column {name}: {t}");
|
||||
}
|
||||
t
|
||||
} else {
|
||||
if t != Type::col() {
|
||||
panic!("Symbol {name} is declared without value and thus must be a witness column, but its type is {t} instead of col.");
|
||||
}
|
||||
t
|
||||
}
|
||||
})
|
||||
.unwrap_or(Type::col());
|
||||
self.handle_symbol_definition(
|
||||
source,
|
||||
name,
|
||||
None,
|
||||
SymbolKind::Poly(PolynomialType::Committed),
|
||||
Some(ty),
|
||||
None,
|
||||
)
|
||||
}
|
||||
Some(value) => {
|
||||
let symbol_kind = if matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1)
|
||||
// TODO if we have proper type deduction here in the future, we can rely only on the type.
|
||||
let (ty, symbol_kind) = if ty == Some(Type::col())
|
||||
|| (ty.is_none()
|
||||
&& matches!(&value, parsed::Expression::LambdaExpression(lambda) if lambda.params.len() == 1))
|
||||
{
|
||||
(
|
||||
Some(Type::col()),
|
||||
SymbolKind::Poly(PolynomialType::Constant),
|
||||
)
|
||||
} else if ty == Some(Type::Fe)
|
||||
|| (ty.is_none() && self.evaluate_expression(value.clone()).is_ok())
|
||||
{
|
||||
SymbolKind::Poly(PolynomialType::Constant)
|
||||
} else if self.evaluate_expression(value.clone()).is_ok() {
|
||||
// Value evaluates to a constant number => treat it as a constant
|
||||
SymbolKind::Constant()
|
||||
(Some(Type::Fe), SymbolKind::Constant())
|
||||
} else {
|
||||
// Otherwise, treat it as "generic definition"
|
||||
SymbolKind::Other()
|
||||
(ty, SymbolKind::Other())
|
||||
};
|
||||
self.handle_symbol_definition(
|
||||
source,
|
||||
name,
|
||||
None,
|
||||
symbol_kind,
|
||||
ty,
|
||||
Some(FunctionDefinition::Expression(value)),
|
||||
)
|
||||
}
|
||||
@@ -261,12 +319,13 @@ where
|
||||
) -> Vec<PILItem<T>> {
|
||||
polynomials
|
||||
.into_iter()
|
||||
.flat_map(|PolynomialName { name, array_size }| {
|
||||
.flat_map(|poly_name| {
|
||||
let (name, ty) = self.name_and_type_from_polynomial_name(poly_name);
|
||||
self.handle_symbol_definition(
|
||||
source.clone(),
|
||||
name,
|
||||
array_size,
|
||||
SymbolKind::Poly(polynomial_type),
|
||||
ty,
|
||||
None,
|
||||
)
|
||||
})
|
||||
@@ -277,17 +336,20 @@ where
|
||||
&mut self,
|
||||
source: SourceRef,
|
||||
name: String,
|
||||
array_size: Option<::powdr_ast::parsed::Expression<T>>,
|
||||
symbol_kind: SymbolKind,
|
||||
ty: Option<Type>,
|
||||
value: Option<FunctionDefinition<T>>,
|
||||
) -> Vec<PILItem<T>> {
|
||||
let have_array_size = array_size.is_some();
|
||||
let length = array_size
|
||||
.map(|l| self.evaluate_expression(l).unwrap())
|
||||
.map(|l| l.to_degree());
|
||||
if length.is_some() {
|
||||
assert!(value.is_none());
|
||||
}
|
||||
let length = ty.as_ref().and_then(|t| {
|
||||
if let Type::Array(ArrayType { length, base: _ }) = t {
|
||||
if length.is_none() && symbol_kind != SymbolKind::Other() {
|
||||
panic!("Explicit array length required for column {name}.");
|
||||
}
|
||||
*length
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
let id = self.counters.dispense_symbol_id(symbol_kind, length);
|
||||
let name = self.driver.resolve_decl(&name);
|
||||
let symbol = Symbol {
|
||||
@@ -300,13 +362,15 @@ where
|
||||
|
||||
let value = value.map(|v| match v {
|
||||
FunctionDefinition::Expression(expr) => {
|
||||
assert!(!have_array_size);
|
||||
assert!(symbol_kind != SymbolKind::Poly(PolynomialType::Committed));
|
||||
FunctionValueDefinition::Expression(self.process_expression(expr))
|
||||
FunctionValueDefinition::Expression(TypedExpression {
|
||||
e: self.process_expression(expr),
|
||||
ty,
|
||||
})
|
||||
}
|
||||
FunctionDefinition::Query(expr) => {
|
||||
assert!(!have_array_size);
|
||||
assert_eq!(symbol_kind, SymbolKind::Poly(PolynomialType::Committed));
|
||||
assert!(ty.is_none() || ty == Some(Type::col()));
|
||||
FunctionValueDefinition::Query(self.process_expression(expr))
|
||||
}
|
||||
FunctionDefinition::Array(value) => {
|
||||
@@ -318,6 +382,7 @@ where
|
||||
expression.iter().map(|e| e.size()).sum::<DegreeType>(),
|
||||
self.degree.unwrap()
|
||||
);
|
||||
assert!(ty.is_none() || ty == Some(Type::col()));
|
||||
FunctionValueDefinition::Array(expression)
|
||||
}
|
||||
});
|
||||
@@ -351,10 +416,18 @@ where
|
||||
})]
|
||||
}
|
||||
|
||||
fn evaluate_expression(
|
||||
&self,
|
||||
expr: ::powdr_ast::parsed::Expression<T>,
|
||||
) -> Result<T, EvalError> {
|
||||
/// Resolves a type name into a concrete type.
|
||||
/// This routine mainly evaluates array length expressions.
|
||||
fn resolve_type_name(&self, mut n: TypeName<parsed::Expression<T>>) -> Result<Type, EvalError> {
|
||||
// Replace all expressions by number literals.
|
||||
for e in n.expressions_mut() {
|
||||
let v = self.evaluate_expression(e.clone())?;
|
||||
*e = parsed::Expression::Number(v);
|
||||
}
|
||||
Ok(n.into())
|
||||
}
|
||||
|
||||
fn evaluate_expression(&self, expr: parsed::Expression<T>) -> Result<T, EvalError> {
|
||||
evaluator::evaluate_expression(
|
||||
&ExpressionProcessor::new(self.driver).process_expression(expr),
|
||||
self.driver.definitions(),
|
||||
@@ -366,13 +439,13 @@ where
|
||||
ExpressionProcessor::new(self.driver)
|
||||
}
|
||||
|
||||
fn process_expression(&self, expr: ::powdr_ast::parsed::Expression<T>) -> Expression<T> {
|
||||
fn process_expression(&self, expr: parsed::Expression<T>) -> Expression<T> {
|
||||
self.expression_processor().process_expression(expr)
|
||||
}
|
||||
|
||||
fn process_selected_expressions(
|
||||
&self,
|
||||
expr: ::powdr_ast::parsed::SelectedExpressions<::powdr_ast::parsed::Expression<T>>,
|
||||
expr: parsed::SelectedExpressions<parsed::Expression<T>>,
|
||||
) -> SelectedExpressions<Expression<T>> {
|
||||
self.expression_processor()
|
||||
.process_selected_expressions(expr)
|
||||
|
||||
@@ -137,11 +137,11 @@ machine Arith(CLK32_31, operation_id){
|
||||
/// returns |n| a(0) * b(n) + ... + a(n) * b(0)
|
||||
let product = |a, b| |n| dot_prod(n + 1, a, |i| b(n - i));
|
||||
/// Converts array to function, extended by zeros.
|
||||
let array_as_fun = [|arr| |i| if 0 <= i && i < array::len(arr) {
|
||||
let array_as_fun: expr[] -> (int -> expr) = |arr| |i| if 0 <= i && i < array::len(arr) {
|
||||
arr[i]
|
||||
} else {
|
||||
0
|
||||
}][0];
|
||||
};
|
||||
let shift_right = |fn, amount| |i| fn(i - amount);
|
||||
|
||||
let x1f = array_as_fun(x1);
|
||||
@@ -151,12 +151,11 @@ machine Arith(CLK32_31, operation_id){
|
||||
let y3f = array_as_fun(y3);
|
||||
|
||||
// Defined for arguments from 0 to 31 (inclusive)
|
||||
let eq0 = (|| |nr|
|
||||
let eq0: int -> expr = |nr|
|
||||
product(x1f, y1f)(nr)
|
||||
+ x2f(nr)
|
||||
- shift_right(y2f, 16)(nr)
|
||||
- y3f(nr)
|
||||
)();
|
||||
- y3f(nr);
|
||||
|
||||
// Note that Polygon uses a single 22-Bit column. However, this approach allows for a lower degree (2**16)
|
||||
// while still preventing overflows: The 32-bit carry gets added to 32 16-Bit values, which can't overflow
|
||||
|
||||
@@ -13,4 +13,5 @@ let map = |arr, f| new(len(arr), |i| f(arr[i]));
|
||||
let fold = |arr, initial, folder| std::utils::fold(len(arr), |i| arr[i], initial, folder);
|
||||
|
||||
/// Returns the sum of the array elements.
|
||||
let sum = [|arr| fold(arr, 0, |a, b| a + b)][0];
|
||||
/// This actually also works on field elements, so the type is currently too restrictive.
|
||||
let sum: int[] -> int = |arr| fold(arr, 0, |a, b| a + b);
|
||||
@@ -2,6 +2,6 @@
|
||||
/// when evaluated.
|
||||
/// It returns an empty array so that it can be used at constraint level.
|
||||
/// This symbol is not an empty array, the actual semantics are overridden.
|
||||
let print = [];
|
||||
let print: string -> constr[] = [];
|
||||
|
||||
let println = [|msg| print(msg + "\n")][0];
|
||||
let println: string -> constr[] = |msg| print(msg + "\n");
|
||||
@@ -22,4 +22,4 @@ let sum = |length, f| fold(length, f, 0, |acc, e| (acc + e));
|
||||
let unchanged_until = |c, latch| (c' - c) * (1 - latch) == 0;
|
||||
|
||||
/// Evaluates to a constraint that forces `c` to be either 0 or 1.
|
||||
let force_bool = [|c| c * (1 - c) == 0][0];
|
||||
let force_bool: expr -> constr = |c| c * (1 - c) == 0;
|
||||
@@ -26,11 +26,12 @@ mod R {
|
||||
machine FullConstant {
|
||||
degree 2;
|
||||
|
||||
let C = |i| match i % 2 {
|
||||
let C: int -> fe = |i| match i % 2 {
|
||||
0 => x,
|
||||
1 => y,
|
||||
};
|
||||
col commit w[2];
|
||||
// Use some weird type just for the sake of it.
|
||||
let w: col[sum(2, |i| 1)];
|
||||
|
||||
// This and the next line are the same.
|
||||
super::utils::sum(2, |i| w[i]) == 8;
|
||||
|
||||
@@ -29,8 +29,7 @@ namespace Arith(N);
|
||||
/// returns f(0) + f(1) + ... + f(length - 1)
|
||||
let sum = |length, f| fold(length, f, 0, |acc, e| acc + e);
|
||||
|
||||
// TODO the weird syntax is needed so that this is not classified as a constant column
|
||||
let force_boolean = (|| |x| x * (1 - x) == 0)();
|
||||
let force_boolean: expr -> constr = |x| x * (1 - x) == 0;
|
||||
|
||||
let clock = |j, row| if row % 32 == j { 1 } else { 0 };
|
||||
// Arrays of fixed columns are not supported yet.
|
||||
@@ -140,8 +139,7 @@ namespace Arith(N);
|
||||
// That way we could even support functions returning lookups.
|
||||
|
||||
// x can only change between two blocks of 32 rows.
|
||||
// TODO the weird syntax is needed so that this is not classified as a fixed column.
|
||||
let fixed_inside_32_block = (|| |x| (x - x') * (1 - CLK32[31]) == 0)();
|
||||
let fixed_inside_32_block: expr -> constr = |x| (x - x') * (1 - CLK32[31]) == 0;
|
||||
|
||||
make_array(16, |i| fixed_inside_32_block(x1[i]));
|
||||
make_array(16, |i| fixed_inside_32_block(y1[i]));
|
||||
@@ -209,12 +207,11 @@ namespace Arith(N);
|
||||
let q2f = array_as_fun(q2, 16);
|
||||
|
||||
// Defined for arguments from 0 to 31 (inclusive)
|
||||
let eq0 = (|| |nr|
|
||||
let eq0: int -> expr = |nr|
|
||||
product(x1f, y1f)(nr)
|
||||
+ x2f(nr)
|
||||
- shift_right(y2f, 16)(nr)
|
||||
- y3f(nr)
|
||||
)();
|
||||
- y3f(nr);
|
||||
|
||||
|
||||
/*******
|
||||
@@ -224,16 +221,16 @@ namespace Arith(N);
|
||||
*******/
|
||||
|
||||
// 0xffffffffffffffffffffffffffffffffffffffffffffffffffff fffe ffff fc2f
|
||||
let p = array_as_fun([
|
||||
let p: col = array_as_fun([
|
||||
0xfc2f, 0xffff, 0xfffe, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff,
|
||||
0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff
|
||||
], 16);
|
||||
|
||||
// The "- 4 * shift_right(p, 16)" effectively subtracts 4 * (p << 16 * 16) = 2 ** 258 * p
|
||||
// As a result, the term computes `(x - 2 ** 258) * p`.
|
||||
let product_with_p = (|| |x| |nr| product(p, x)(nr) - 4 * shift_right(p, 16)(nr))();
|
||||
// The "- 4 * shift_right(p, 16)" effectively subtracts 4 * (p << 16 * 16) = 2 ** 258 * p
|
||||
// As a result, the term computes `(x - 2 ** 258) * p`.
|
||||
let product_with_p: int -> (int -> expr) = |x| |nr| product(p, x)(nr) - 4 * shift_right(p, 16)(nr);
|
||||
|
||||
let eq1 = (|| |nr| product(sf, x2f)(nr) - product(sf, x1f)(nr) - y2f(nr) + y1f(nr) + product_with_p(q0f)(nr))();
|
||||
let eq1: int -> expr = |nr| product(sf, x2f)(nr) - product(sf, x1f)(nr) - y2f(nr) + y1f(nr) + product_with_p(q0f)(nr);
|
||||
|
||||
/*******
|
||||
*
|
||||
@@ -241,7 +238,7 @@ namespace Arith(N);
|
||||
*
|
||||
*******/
|
||||
|
||||
let eq2 = (|| |nr| 2 * product(sf, y1f)(nr) - 3 * product(x1f, x1f)(nr) + product_with_p(q0f)(nr))();
|
||||
let eq2: int -> expr = |nr| 2 * product(sf, y1f)(nr) - 3 * product(x1f, x1f)(nr) + product_with_p(q0f)(nr);
|
||||
|
||||
/*******
|
||||
*
|
||||
@@ -249,7 +246,7 @@ namespace Arith(N);
|
||||
*
|
||||
*******/
|
||||
|
||||
let eq3 = (|| |nr| product(sf, sf)(nr) - x1f(nr) - x2f(nr) - x3f(nr) + product_with_p(q1f)(nr))();
|
||||
let eq3: int -> expr = |nr| product(sf, sf)(nr) - x1f(nr) - x2f(nr) - x3f(nr) + product_with_p(q1f)(nr);
|
||||
|
||||
|
||||
/*******
|
||||
@@ -258,7 +255,7 @@ namespace Arith(N);
|
||||
*
|
||||
*******/
|
||||
|
||||
let eq4 = (|| |nr| product(sf, x1f)(nr) - product(sf, x3f)(nr) - y1f(nr) - y3f(nr) + product_with_p(q2f)(nr))();
|
||||
let eq4: int -> expr = |nr| product(sf, x1f)(nr) - product(sf, x3f)(nr) - y1f(nr) - y3f(nr) + product_with_p(q2f)(nr);
|
||||
|
||||
pol commit selEq[4];
|
||||
|
||||
|
||||
30
test_data/pil/book/declarations.pil
Normal file
30
test_data/pil/book/declarations.pil
Normal file
@@ -0,0 +1,30 @@
|
||||
namespace Main(16);
|
||||
// ANCHOR: declarations
|
||||
// This defines a constant
|
||||
let rows = 2**16;
|
||||
// This defines a fixed column that contains the row number in each row.
|
||||
let step = |i| i;
|
||||
// This defines a copy of the column, also a fixed column because the type
|
||||
// is explicitly specified.
|
||||
let also_step: col = step;
|
||||
// Here, we have a witness column.
|
||||
let x;
|
||||
// This functions defines a fixed column where each cell contains the
|
||||
// square of its row number.
|
||||
let square = |x| x*x;
|
||||
// The same function as `square` above, but now its type is given as
|
||||
// `int -> int` and thus it is *not* classified as a column. Instead,
|
||||
// it is stored as a utility function. If utility functions are
|
||||
// referenced in constraints, they have to be evaluated, meaning that
|
||||
// the constraint `w = square_non_column;` is invalid but both
|
||||
// `w = square_non_column(7);` and `w = square;` are valid constraints.
|
||||
let square_non_column: int -> int = |x| x*x;
|
||||
// A recursive function, taking a function and an integer as parameter
|
||||
let sum = |f, i| match i {
|
||||
0 => f(0),
|
||||
_ => f(i) + sum(f, i - 1)
|
||||
};
|
||||
// ANCHOR_END: declarations
|
||||
// We need at least one constraint to create a proof in the test.
|
||||
let w;
|
||||
w + square = 0;
|
||||
@@ -6,20 +6,17 @@ namespace Main(16);
|
||||
};
|
||||
// returns f(0) + f(1) + ... + f(length - 1)
|
||||
let sum = |length, f| fold(length, f, 0, |acc, e| acc + e);
|
||||
// If called with a single value, this function evaluates the equality,
|
||||
// otherwise, it returns a constraint (if called with a column or
|
||||
// an algebraic expression).
|
||||
// If we write "|x| x == 20", it will be classified as a fixed column,
|
||||
// so we use a trick that makes it not look like a function with a single
|
||||
// parameter.
|
||||
let equals_twenty = [|x| x == 20][0];
|
||||
// declares an array of 16 witness columns.
|
||||
// This function takes an algebraic expression (a column or expression
|
||||
// involving columns) and returns an identity that forces this expression
|
||||
// to equal 20.
|
||||
let equals_twenty: expr -> constr = |x| x == 20;
|
||||
// This declares an array of 16 witness columns.
|
||||
col witness wit[16];
|
||||
// This expression has to evaluate to a constraint, but we can still use
|
||||
// This expression has to evaluate to an identity, but we can still use
|
||||
// higher order functions and all the flexibility of the language.
|
||||
// The sub-expression "sum(16, |i| wit[i]" evaluates to the algebraic
|
||||
// The sub-expression `sum(16, |i| wit[i])` evaluates to the algebraic
|
||||
// expression "wit[0] + wit[1] + ... + wit[15]", which is then
|
||||
// turned by "equals_twenty" into the constraint
|
||||
// turned into the identity by `equals_twenty`
|
||||
// wit[0] + wit[1] + ... + wit[15] == 20.
|
||||
equals_twenty(sum(16, |i| wit[i]));
|
||||
|
||||
|
||||
Reference in New Issue
Block a user