mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-01-09 14:48:16 -05:00
Macros
This commit is contained in:
28
README.md
28
README.md
@@ -118,6 +118,34 @@ fn<T> mul(a: T[], b: T[]) -> T[] = [a[i] * b[i] | i: 0..a.len()];
|
||||
|
||||
We will stick as much to Rust as possible for now. This means there is a trait for the multiplication operator that we define.
|
||||
|
||||
### Macros
|
||||
|
||||
As a "quick and dirty" hack, we implemented syntactic macros for now:
|
||||
|
||||
```
|
||||
macro ite(C, A, B) { C * A + (1 - C) * B }
|
||||
```
|
||||
|
||||
Macros can evaluate to zero or more statements (constraints / identities) and
|
||||
zero or one expression.
|
||||
The statements are terminated by `;` and the last element is the expression.
|
||||
Macros can of course also invoke other macros:
|
||||
|
||||
```
|
||||
macro bool(X) { X * (1 - X) = 0; }
|
||||
macro ite(C, A, B) { bool(C); C * A + (1 - C) * B }
|
||||
```
|
||||
|
||||
In the example above, `bool` evaluates to one polynomial identity constraint and no expression.
|
||||
The macro `ite` adds the identity constraint generated through the invocation of `bool`
|
||||
to the list and evaluates to an expression at the same time.
|
||||
|
||||
If a macro is used in statement context, it cannot have an expression and
|
||||
if it is used in expression context, it must have an expression (but can also have statements).
|
||||
|
||||
The optimizer will of course ensure that redundant constraints are removed
|
||||
(be it because the are just duplicated or because they are already implied by lookups).
|
||||
|
||||
### Instruction / Assembly language
|
||||
|
||||
The second layer of this langauge is to define an assembly-like language that helps in defining complex constants.
|
||||
|
||||
@@ -26,6 +26,7 @@ struct Context {
|
||||
constants: HashMap<String, ConstantNumberType>,
|
||||
definitions: HashMap<String, (Polynomial, Option<Expression>)>,
|
||||
public_declarations: HashMap<String, PublicDeclaration>,
|
||||
macros: HashMap<String, MacroDefinition>,
|
||||
identities: Vec<Identity>,
|
||||
/// The order in which definitions and identities
|
||||
/// appear in the source.
|
||||
@@ -38,6 +39,8 @@ struct Context {
|
||||
intermediate_poly_counter: u64,
|
||||
identity_counter: HashMap<IdentityKind, u64>,
|
||||
local_variables: HashMap<String, u64>,
|
||||
/// If we are evaluating a macro, this holds the arguments.
|
||||
macro_arguments: Option<Vec<Expression>>,
|
||||
}
|
||||
|
||||
pub enum StatementIdentifier {
|
||||
@@ -209,6 +212,15 @@ pub enum PolynomialType {
|
||||
Intermediate,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MacroDefinition {
|
||||
pub source: SourceRef,
|
||||
pub absolute_name: String,
|
||||
pub parameters: Vec<String>,
|
||||
pub identities: Vec<ast::Statement>,
|
||||
pub expression: Option<ast::Expression>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct SourceRef {
|
||||
pub file: String, // TODO should maybe be a shared pointer
|
||||
@@ -290,11 +302,16 @@ impl Context {
|
||||
Statement::ConstantDefinition(_, name, value) => {
|
||||
self.handle_constant_definition(name, value)
|
||||
}
|
||||
Statement::MacroDefinition(start, name, params, statments, expression) => self
|
||||
.handle_macro_definition(
|
||||
self.to_source_ref(*start),
|
||||
name,
|
||||
params,
|
||||
statments,
|
||||
expression,
|
||||
),
|
||||
_ => {
|
||||
let identity = self.process_identity_statement(statement);
|
||||
let id = self.identities.len();
|
||||
self.identities.push(identity);
|
||||
self.source_order.push(StatementIdentifier::Identity(id));
|
||||
self.handle_identity_statement(statement);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -311,7 +328,16 @@ impl Context {
|
||||
}
|
||||
}
|
||||
|
||||
fn process_identity_statement(&mut self, statement: &ast::Statement) -> Identity {
|
||||
fn handle_identity_statement(&mut self, statement: &ast::Statement) {
|
||||
if let ast::Statement::FunctionCall(_start, name, arguments) = statement {
|
||||
// TODO check that it does not contain local variable references.
|
||||
// But we also need to do some other well-formedness checks.
|
||||
if self.process_macro_call(name, arguments).is_some() {
|
||||
panic!("Invoked a macro in statement context with non-empty expression.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let (start, kind, left, right) = match statement {
|
||||
ast::Statement::PolynomialIdentity(start, expression) => (
|
||||
start,
|
||||
@@ -352,13 +378,16 @@ impl Context {
|
||||
}
|
||||
};
|
||||
let id = self.dispense_id(kind);
|
||||
Identity {
|
||||
let identity = Identity {
|
||||
id,
|
||||
kind,
|
||||
source: self.to_source_ref(*start),
|
||||
left,
|
||||
right,
|
||||
}
|
||||
};
|
||||
let id = self.identities.len();
|
||||
self.identities.push(identity);
|
||||
self.source_order.push(StatementIdentifier::Identity(id));
|
||||
}
|
||||
|
||||
fn handle_include(&mut self, path: &str) {
|
||||
@@ -422,6 +451,7 @@ impl Context {
|
||||
length,
|
||||
};
|
||||
let name = poly.absolute_name.clone();
|
||||
assert!(self.local_variables.is_empty());
|
||||
self.local_variables = parameters
|
||||
.map(|p| {
|
||||
p.iter()
|
||||
@@ -431,7 +461,7 @@ impl Context {
|
||||
})
|
||||
.unwrap_or_default();
|
||||
let value = value.map(|e| self.process_expression(e));
|
||||
self.local_variables = HashMap::default();
|
||||
self.local_variables.clear();
|
||||
let is_new = self
|
||||
.definitions
|
||||
.insert(name.clone(), (poly, value))
|
||||
@@ -480,6 +510,30 @@ impl Context {
|
||||
id
|
||||
}
|
||||
|
||||
fn handle_macro_definition(
|
||||
&mut self,
|
||||
source: SourceRef,
|
||||
name: &String,
|
||||
params: &[String],
|
||||
statements: &[ast::Statement],
|
||||
expression: &Option<ast::Expression>,
|
||||
) {
|
||||
let is_new = self
|
||||
.macros
|
||||
.insert(
|
||||
name.clone(),
|
||||
MacroDefinition {
|
||||
source,
|
||||
absolute_name: self.namespaced(name),
|
||||
parameters: params.to_vec(),
|
||||
identities: statements.to_vec(),
|
||||
expression: expression.clone(),
|
||||
},
|
||||
)
|
||||
.is_none();
|
||||
assert!(is_new);
|
||||
}
|
||||
|
||||
fn namespaced(&self, name: &String) -> String {
|
||||
self.namespaced_ref(&None, name)
|
||||
}
|
||||
@@ -488,26 +542,35 @@ impl Context {
|
||||
format!("{}.{name}", namespace.as_ref().unwrap_or(&self.namespace))
|
||||
}
|
||||
|
||||
fn process_selected_expression(&self, expr: &ast::SelectedExpressions) -> SelectedExpressions {
|
||||
fn process_selected_expression(
|
||||
&mut self,
|
||||
expr: &ast::SelectedExpressions,
|
||||
) -> SelectedExpressions {
|
||||
SelectedExpressions {
|
||||
selector: expr.selector.as_ref().map(|e| self.process_expression(e)),
|
||||
expressions: self.process_expressions(&expr.expressions),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_expressions(&self, exprs: &[ast::Expression]) -> Vec<Expression> {
|
||||
fn process_expressions(&mut self, exprs: &[ast::Expression]) -> Vec<Expression> {
|
||||
exprs.iter().map(|e| self.process_expression(e)).collect()
|
||||
}
|
||||
|
||||
fn process_expression(&self, expr: &ast::Expression) -> Expression {
|
||||
fn process_expression(&mut self, expr: &ast::Expression) -> Expression {
|
||||
match expr {
|
||||
ast::Expression::Constant(name) => Expression::Constant(name.clone()),
|
||||
ast::Expression::PolynomialReference(poly) => {
|
||||
if poly.namespace.is_none() && self.local_variables.contains_key(&poly.name) {
|
||||
let id = self.local_variables[&poly.name];
|
||||
// TODO to make this work inside macros, "next" and "index" need to be
|
||||
// their own ast nodes / operators.
|
||||
assert!(!poly.next);
|
||||
assert!(poly.index.is_none());
|
||||
Expression::LocalVariableReference(id)
|
||||
if let Some(arguments) = &self.macro_arguments {
|
||||
arguments[id as usize].clone()
|
||||
} else {
|
||||
Expression::LocalVariableReference(id)
|
||||
}
|
||||
} else {
|
||||
Expression::PolynomialReference(self.process_polynomial_reference(poly))
|
||||
}
|
||||
@@ -532,9 +595,41 @@ impl Context {
|
||||
Expression::UnaryOperation(*op, Box::new(self.process_expression(value)))
|
||||
}
|
||||
}
|
||||
ast::Expression::FunctionCall(name, arguments) => self
|
||||
.process_macro_call(name, arguments)
|
||||
.expect("Invoked a macro in expression context with empty expression."),
|
||||
}
|
||||
}
|
||||
|
||||
fn process_macro_call(
|
||||
&mut self,
|
||||
name: &String,
|
||||
arguments: &[ast::Expression],
|
||||
) -> Option<Expression> {
|
||||
let arguments = Some(self.process_expressions(arguments));
|
||||
let old_arguments = std::mem::replace(&mut self.macro_arguments, arguments);
|
||||
|
||||
let old_locals = std::mem::take(&mut self.local_variables);
|
||||
|
||||
let mac = &self.macros[name];
|
||||
self.local_variables = mac
|
||||
.parameters
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, n)| (n.clone(), i as u64))
|
||||
.collect();
|
||||
// TODO avoid clones
|
||||
let expression = mac.expression.clone();
|
||||
let identities = mac.identities.clone();
|
||||
for identity in &identities {
|
||||
self.handle_identity_statement(identity);
|
||||
}
|
||||
let result = expression.map(|expr| self.process_expression(&expr));
|
||||
self.macro_arguments = old_arguments;
|
||||
self.local_variables = old_locals;
|
||||
result
|
||||
}
|
||||
|
||||
fn process_polynomial_reference(&self, poly: &ast::PolynomialReference) -> PolynomialReference {
|
||||
let index = poly
|
||||
.index
|
||||
@@ -562,6 +657,7 @@ impl Context {
|
||||
self.evaluate_binary_operation(left, op, right)
|
||||
}
|
||||
ast::Expression::UnaryOperation(op, value) => self.evaluate_unary_operation(op, value),
|
||||
ast::Expression::FunctionCall(_, _) => None, // TODO we should also try to evaluate through macro calls.
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -140,4 +140,45 @@ mod test {
|
||||
vec![(&"F.EVEN".to_string(), vec![-2, 0, 2, 4, 6, 8, 10, 12])]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_macro() {
|
||||
let src = r#"
|
||||
constant %N = 8;
|
||||
namespace F(%N);
|
||||
macro minus_one(X) { X - 1 };
|
||||
pol constant EVEN(i) { 2 * minus_one(i) };
|
||||
"#;
|
||||
let analyzed = analyze_string(src);
|
||||
let (constants, degree) = generate(&analyzed);
|
||||
assert_eq!(degree, 8);
|
||||
assert_eq!(
|
||||
constants,
|
||||
vec![(&"F.EVEN".to_string(), vec![-2, 0, 2, 4, 6, 8, 10, 12])]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_macro_double() {
|
||||
let src = r#"
|
||||
constant %N = 12;
|
||||
namespace F(%N);
|
||||
macro is_nonzero(X) { X / X };
|
||||
macro is_zero(X) { 1 - is_nonzero(X) };
|
||||
macro is_one(X) { is_zero(1 - X) };
|
||||
macro is_equal(A, B) { is_zero(A - B) };
|
||||
macro ite(C, T, F) { is_one(C) * T + is_zero(C) * F };
|
||||
pol constant TEN(i) { ite(is_equal(i, 10), 1, 0) };
|
||||
"#;
|
||||
let analyzed = analyze_string(src);
|
||||
let (constants, degree) = generate(&analyzed);
|
||||
assert_eq!(degree, 12);
|
||||
assert_eq!(
|
||||
constants,
|
||||
vec![(
|
||||
&"F.TEN".to_string(),
|
||||
[[0; 10].to_vec(), [1, 0].to_vec()].concat()
|
||||
)]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
pub struct PILFile(pub Vec<Statement>);
|
||||
|
||||
#[derive(Debug, PartialEq, Eq)]
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum Statement {
|
||||
/// File name
|
||||
Include(usize, String),
|
||||
@@ -17,6 +17,14 @@ pub enum Statement {
|
||||
PermutationIdentity(usize, SelectedExpressions, SelectedExpressions),
|
||||
ConnectIdentity(usize, Vec<Expression>, Vec<Expression>),
|
||||
ConstantDefinition(usize, String, Expression),
|
||||
MacroDefinition(
|
||||
usize,
|
||||
String,
|
||||
Vec<String>,
|
||||
Vec<Statement>,
|
||||
Option<Expression>,
|
||||
),
|
||||
FunctionCall(usize, String, Vec<Expression>),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
@@ -37,9 +45,10 @@ pub enum Expression {
|
||||
Number(ConstantNumberType),
|
||||
BinaryOperation(Box<Expression>, BinaryOperator, Box<Expression>),
|
||||
UnaryOperation(UnaryOperator, Box<Expression>),
|
||||
FunctionCall(String, Vec<Expression>),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Default)]
|
||||
#[derive(Debug, PartialEq, Eq, Default, Clone)]
|
||||
pub struct PolynomialName {
|
||||
pub name: String,
|
||||
pub array_size: Option<Expression>,
|
||||
|
||||
@@ -155,4 +155,46 @@ mod test {
|
||||
parse_file("test_files/rom.pil");
|
||||
parse_file("test_files/storage.pil");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn simple_macro() {
|
||||
let parsed = pil::PILFileParser::new()
|
||||
.parse("macro f(x) { x in g; x + 1 };")
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
parsed,
|
||||
PILFile(vec![Statement::MacroDefinition(
|
||||
0,
|
||||
"f".to_string(),
|
||||
vec!["x".to_string()],
|
||||
vec![Statement::PlookupIdentity(
|
||||
13,
|
||||
SelectedExpressions {
|
||||
selector: None,
|
||||
expressions: vec![Expression::PolynomialReference(PolynomialReference {
|
||||
name: "x".to_string(),
|
||||
..Default::default()
|
||||
})]
|
||||
},
|
||||
SelectedExpressions {
|
||||
selector: None,
|
||||
expressions: vec![Expression::PolynomialReference(PolynomialReference {
|
||||
name: "g".to_string(),
|
||||
..Default::default()
|
||||
})]
|
||||
}
|
||||
)],
|
||||
Some(Expression::BinaryOperation(
|
||||
Box::new(Expression::PolynomialReference(PolynomialReference {
|
||||
namespace: None,
|
||||
name: "x".to_string(),
|
||||
index: None,
|
||||
next: false
|
||||
})),
|
||||
BinaryOperator::Add,
|
||||
Box::new(Expression::Number(1))
|
||||
))
|
||||
)])
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,8 @@ Statement = {
|
||||
PlookupIdentity,
|
||||
PermutationIdentity,
|
||||
ConnectIdentity,
|
||||
MacroDefinition,
|
||||
FunctionCallStatement
|
||||
};
|
||||
|
||||
Include: Statement = {
|
||||
@@ -96,6 +98,17 @@ ConnectIdentity: Statement = {
|
||||
<@L> "{" <ExpressionList> "}" "connect" "{" <ExpressionList> "}" => Statement::ConnectIdentity(<>)
|
||||
}
|
||||
|
||||
MacroDefinition: Statement = {
|
||||
<@L> "macro" <Identifier> "(" <ParameterList> ")" "{" <( <Statement> ";")*> <Expression?> "}"
|
||||
=> Statement::MacroDefinition(<>)
|
||||
}
|
||||
|
||||
FunctionCallStatement: Statement = {
|
||||
<@L> <Identifier> "(" <ExpressionList> ")" => Statement::FunctionCall(<>)
|
||||
}
|
||||
|
||||
|
||||
|
||||
ExpressionList: Vec<Expression> = {
|
||||
<mut list:( <Expression> "," )*> <end:Expression> => { list.push(end); list }
|
||||
}
|
||||
@@ -144,6 +157,7 @@ UnaryOp: UnaryOperator = {
|
||||
}
|
||||
|
||||
Term: Box<Expression> = {
|
||||
FunctionCall => Box::new(<>),
|
||||
ConstantIdentifier => Box::new(Expression::Constant(<>)),
|
||||
PolynomialReference => Box::new(Expression::PolynomialReference(<>)),
|
||||
PublicReference => Box::new(Expression::PublicReference(<>)),
|
||||
@@ -151,6 +165,10 @@ Term: Box<Expression> = {
|
||||
"(" <BoxedExpression> ")",
|
||||
}
|
||||
|
||||
FunctionCall: Expression = {
|
||||
<Identifier> "(" <ExpressionList> ")" => Expression::FunctionCall(<>)
|
||||
}
|
||||
|
||||
PolynomialReference: PolynomialReference = {
|
||||
<namespace:( <Identifier> "." )?>
|
||||
<name:Identifier>
|
||||
|
||||
32
tests/fib_macro.pil
Normal file
32
tests/fib_macro.pil
Normal file
@@ -0,0 +1,32 @@
|
||||
constant %N = 16;
|
||||
|
||||
namespace Fibonacci(%N);
|
||||
constant %last_row = %N - 1;
|
||||
|
||||
macro bool(X) { X * (1 - X) = 0; };
|
||||
macro is_nonzero(X) { X / X }; // 0 / 0 == 0 makes this work...
|
||||
macro is_zero(X) { 1 - is_nonzero(X) };
|
||||
macro is_equal(A, B) { is_zero(A - B) };
|
||||
macro is_one(X) { is_equal(X, 1) };
|
||||
macro ite(C, A, B) { is_nonzero(C) * A + is_zero(C) * B};
|
||||
|
||||
macro one_hot(i, index) { ite(is_equal(i, index), 1, 0) };
|
||||
|
||||
pol constant ISLAST(i) { one_hot(i, %last_row) };
|
||||
pol commit x, y;
|
||||
|
||||
macro constrain_equal_expr(A, B) { A - B };
|
||||
macro force_equal_on_last_row(poly, value) { ISLAST * constrain_equal_expr(poly, value) = 0; };
|
||||
|
||||
// TODO would be easier if we could use "'" as an operator,
|
||||
// then we could write a "force_equal_on_first_row" macro,
|
||||
// and the macro would add a "'" to the parameter.
|
||||
force_equal_on_last_row(x', 1);
|
||||
force_equal_on_last_row(y', 1);
|
||||
|
||||
macro on_regular_row(cond) { (1 - ISLAST) * cond = 0; };
|
||||
|
||||
on_regular_row(constrain_equal_expr(x', y));
|
||||
on_regular_row(constrain_equal_expr(y', x + y));
|
||||
|
||||
public out = y(%last_row);
|
||||
@@ -2,9 +2,8 @@ use std::{path::Path, process::Command};
|
||||
|
||||
use powdr::compiler;
|
||||
|
||||
#[test]
|
||||
fn test_fibonaccy() {
|
||||
compiler::compile(Path::new("./tests/fibonacci.pil"));
|
||||
fn verify(file_name: &str) {
|
||||
compiler::compile(Path::new(&format!("./tests/{file_name}")));
|
||||
|
||||
let pilcom = std::env::var("PILCOM")
|
||||
.expect("Please set the PILCOM environment variable to the path to the pilcom repository.");
|
||||
@@ -14,7 +13,7 @@ fn test_fibonaccy() {
|
||||
format!("{pilcom}/src/main_pilverifier.js"),
|
||||
"commits.bin".to_string(),
|
||||
"-p".to_string(),
|
||||
"fibonacci.pil.json".to_string(),
|
||||
format!("{file_name}.json"),
|
||||
"-c".to_string(),
|
||||
"constants.bin".to_string(),
|
||||
])
|
||||
@@ -33,3 +32,13 @@ fn test_fibonaccy() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fibonacci() {
|
||||
verify("fibonacci.pil");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fibonacci_macro() {
|
||||
verify("fib_macro.pil");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user