mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-04-20 03:03:25 -04:00
Merge pull request #311 from powdr-org/macros_for_asm
Add macros to assembly.
This commit is contained in:
@@ -84,3 +84,17 @@ fn test_simple_sum_asm_pil() {
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_sum_asm_macro_pil() {
|
||||
verify_pil(
|
||||
"simple_sum_asm_macro.pil",
|
||||
Some(|q| match q {
|
||||
"\"input\", 0" => Some(13.into()),
|
||||
"\"input\", 1" => Some(2.into()),
|
||||
"\"input\", 2" => Some(11.into()),
|
||||
"\"input\", 3" => Some(2.into()),
|
||||
_ => Some(7.into()),
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -64,6 +64,7 @@ pub enum InstructionBodyElement<T> {
|
||||
PlookupOperator,
|
||||
SelectedExpressions<T>,
|
||||
),
|
||||
FunctionCall(String, Vec<Expression<T>>),
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
use std::{iter::once, ops::ControlFlow};
|
||||
|
||||
use number::{DegreeType, FieldElement};
|
||||
|
||||
use crate::asm_ast::ASMStatement;
|
||||
@@ -184,3 +186,98 @@ impl<T> ArrayExpression<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Traverses the expression tree and calls `f` in post-order.
|
||||
pub fn postvisit_expression_mut<T, F, B>(e: &mut Expression<T>, f: &mut F) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut Expression<T>) -> ControlFlow<B>,
|
||||
{
|
||||
match e {
|
||||
Expression::PolynomialReference(_)
|
||||
| Expression::Constant(_)
|
||||
| Expression::PublicReference(_)
|
||||
| Expression::Number(_)
|
||||
| Expression::String(_) => {}
|
||||
Expression::BinaryOperation(left, _, right) => {
|
||||
postvisit_expression_mut(left, f)?;
|
||||
postvisit_expression_mut(right, f)?;
|
||||
}
|
||||
Expression::UnaryOperation(_, e) => postvisit_expression_mut(e.as_mut(), f)?,
|
||||
Expression::Tuple(items) | Expression::FunctionCall(_, items) => items
|
||||
.iter_mut()
|
||||
.try_for_each(|item| postvisit_expression_mut(item, f))?,
|
||||
Expression::FreeInput(query) => postvisit_expression_mut(query.as_mut(), f)?,
|
||||
Expression::MatchExpression(scrutinee, arms) => {
|
||||
once(scrutinee.as_mut())
|
||||
.chain(arms.iter_mut().map(|(_n, e)| e))
|
||||
.try_for_each(|item| postvisit_expression_mut(item, f))?;
|
||||
}
|
||||
};
|
||||
f(e)
|
||||
}
|
||||
|
||||
/// Traverses the expression trees of the statement and calls `f` in post-order.
|
||||
/// Does not enter ASMBlocks or macro definitions.
|
||||
pub fn postvisit_expression_in_statement_mut<T, F, B>(
|
||||
statement: &mut Statement<T>,
|
||||
f: &mut F,
|
||||
) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut Expression<T>) -> ControlFlow<B>,
|
||||
{
|
||||
match statement {
|
||||
Statement::FunctionCall(_, _, arguments) => arguments
|
||||
.iter_mut()
|
||||
.try_for_each(|e| postvisit_expression_mut(e, f)),
|
||||
Statement::PlookupIdentity(_, left, right)
|
||||
| Statement::PermutationIdentity(_, left, right) => left
|
||||
.selector
|
||||
.iter_mut()
|
||||
.chain(left.expressions.iter_mut())
|
||||
.chain(right.selector.iter_mut())
|
||||
.chain(right.expressions.iter_mut())
|
||||
.try_for_each(|e| postvisit_expression_mut(e, f)),
|
||||
Statement::ConnectIdentity(_start, left, right) => left
|
||||
.iter_mut()
|
||||
.chain(right.iter_mut())
|
||||
.try_for_each(|e| postvisit_expression_mut(e, f)),
|
||||
|
||||
Statement::Namespace(_, _, e)
|
||||
| Statement::PolynomialDefinition(_, _, e)
|
||||
| Statement::PolynomialIdentity(_, e)
|
||||
| Statement::PublicDeclaration(_, _, _, e)
|
||||
| Statement::ConstantDefinition(_, _, e) => postvisit_expression_mut(e, f),
|
||||
|
||||
Statement::PolynomialConstantDefinition(_, _, fundef)
|
||||
| Statement::PolynomialCommitDeclaration(_, _, Some(fundef)) => match fundef {
|
||||
FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => {
|
||||
postvisit_expression_mut(e, f)
|
||||
}
|
||||
FunctionDefinition::Array(ae) => postvisit_expression_in_array_expression_mut(ae, f),
|
||||
},
|
||||
Statement::PolynomialCommitDeclaration(_, _, None)
|
||||
| Statement::Include(_, _)
|
||||
| Statement::PolynomialConstantDeclaration(_, _)
|
||||
| Statement::MacroDefinition(_, _, _, _, _)
|
||||
| Statement::ASMBlock(_, _) => ControlFlow::Continue(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn postvisit_expression_in_array_expression_mut<T, F, B>(
|
||||
ae: &mut ArrayExpression<T>,
|
||||
f: &mut F,
|
||||
) -> ControlFlow<B>
|
||||
where
|
||||
F: FnMut(&mut Expression<T>) -> ControlFlow<B>,
|
||||
{
|
||||
match ae {
|
||||
ArrayExpression::Value(expressions) | ArrayExpression::RepeatedValue(expressions) => {
|
||||
expressions
|
||||
.iter_mut()
|
||||
.try_for_each(|e| postvisit_expression_mut(e, f))
|
||||
}
|
||||
ArrayExpression::Concat(a1, a2) => [a1, a2]
|
||||
.iter_mut()
|
||||
.try_for_each(|e| postvisit_expression_in_array_expression_mut(e, f)),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ use parser_util::{handle_parse_error, ParseError};
|
||||
pub mod asm_ast;
|
||||
pub mod ast;
|
||||
pub mod display;
|
||||
pub mod macro_expansion;
|
||||
|
||||
lalrpop_mod!(
|
||||
#[allow(clippy::all)]
|
||||
|
||||
142
parser/src/macro_expansion.rs
Normal file
142
parser/src/macro_expansion.rs
Normal file
@@ -0,0 +1,142 @@
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
ops::ControlFlow,
|
||||
};
|
||||
|
||||
use crate::ast::*;
|
||||
use number::FieldElement;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct MacroExpander<T> {
|
||||
macros: HashMap<String, MacroDefinition<T>>,
|
||||
arguments: Vec<Expression<T>>,
|
||||
parameter_names: HashMap<String, usize>,
|
||||
shadowing_locals: HashSet<String>,
|
||||
statements: Vec<Statement<T>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct MacroDefinition<T> {
|
||||
pub parameters: Vec<String>,
|
||||
pub identities: Vec<Statement<T>>,
|
||||
pub expression: Option<Expression<T>>,
|
||||
}
|
||||
|
||||
impl<T> MacroExpander<T>
|
||||
where
|
||||
T: FieldElement,
|
||||
{
|
||||
pub fn new() -> Self {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
/// Expands all macro references inside the statements and also adds
|
||||
/// any macros defined therein to the list of macros.
|
||||
///
|
||||
/// Note that macros are not namespaced!
|
||||
pub fn expand_macros(&mut self, statements: Vec<Statement<T>>) -> Vec<Statement<T>> {
|
||||
assert!(self.statements.is_empty());
|
||||
for statement in statements {
|
||||
self.handle_statement(statement);
|
||||
}
|
||||
std::mem::take(&mut self.statements)
|
||||
}
|
||||
|
||||
pub fn handle_statement(&mut self, mut statement: Statement<T>) {
|
||||
let mut added_locals = false;
|
||||
if let Statement::PolynomialConstantDefinition(_, _, f)
|
||||
| Statement::PolynomialCommitDeclaration(_, _, Some(f)) = &statement
|
||||
{
|
||||
if let FunctionDefinition::Mapping(params, _) | FunctionDefinition::Query(params, _) = f
|
||||
{
|
||||
assert!(self.shadowing_locals.is_empty());
|
||||
self.shadowing_locals.extend(params.iter().cloned());
|
||||
added_locals = true;
|
||||
}
|
||||
}
|
||||
|
||||
postvisit_expression_in_statement_mut(&mut statement, &mut |e| self.process_expression(e));
|
||||
|
||||
match &mut statement {
|
||||
Statement::FunctionCall(_start, name, arguments) => {
|
||||
if !self.macros.contains_key(name) {
|
||||
panic!(
|
||||
"Macro {name} not found - only macros allowed at this point, no fixed columns."
|
||||
);
|
||||
}
|
||||
if self.expand_macro(name, std::mem::take(arguments)).is_some() {
|
||||
panic!("Invoked a macro in statement context with non-empty expression.");
|
||||
}
|
||||
}
|
||||
Statement::MacroDefinition(_start, name, parameters, statements, expression) => {
|
||||
// We expand lazily. Is that a mistake?
|
||||
let is_new = self
|
||||
.macros
|
||||
.insert(
|
||||
std::mem::take(name),
|
||||
MacroDefinition {
|
||||
parameters: std::mem::take(parameters),
|
||||
identities: std::mem::take(statements),
|
||||
expression: std::mem::take(expression),
|
||||
},
|
||||
)
|
||||
.is_none();
|
||||
assert!(is_new);
|
||||
}
|
||||
_ => self.statements.push(statement),
|
||||
};
|
||||
|
||||
if added_locals {
|
||||
self.shadowing_locals.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_macro(&mut self, name: &str, arguments: Vec<Expression<T>>) -> Option<Expression<T>> {
|
||||
let old_arguments = std::mem::replace(&mut self.arguments, arguments);
|
||||
|
||||
let mac = &self
|
||||
.macros
|
||||
.get(name)
|
||||
.unwrap_or_else(|| panic!("Macro {name} not found."));
|
||||
let parameters = mac
|
||||
.parameters
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, n)| (n.clone(), i))
|
||||
.collect();
|
||||
let old_parameters = std::mem::replace(&mut self.parameter_names, parameters);
|
||||
|
||||
let mut expression = mac.expression.clone();
|
||||
let identities = mac.identities.clone();
|
||||
for identity in identities {
|
||||
self.handle_statement(identity)
|
||||
}
|
||||
if let Some(e) = &mut expression {
|
||||
postvisit_expression_mut(e, &mut |e| self.process_expression(e));
|
||||
};
|
||||
|
||||
self.arguments = old_arguments;
|
||||
self.parameter_names = old_parameters;
|
||||
expression
|
||||
}
|
||||
|
||||
fn process_expression(&mut self, e: &mut Expression<T>) -> ControlFlow<()> {
|
||||
if let Expression::PolynomialReference(poly) = e {
|
||||
if poly.namespace.is_none() && self.parameter_names.contains_key(&poly.name) {
|
||||
// TODO to make this work inside macros, "next" and "index" need to be
|
||||
// their own ast nodes / operators.
|
||||
assert!(!poly.next);
|
||||
assert!(poly.index.is_none());
|
||||
*e = self.arguments[self.parameter_names[&poly.name]].clone()
|
||||
}
|
||||
} else if let Expression::FunctionCall(name, arguments) = e {
|
||||
if self.macros.contains_key(name.as_str()) {
|
||||
*e = self
|
||||
.expand_macro(name, std::mem::take(arguments))
|
||||
.expect("Invoked a macro in expression context with empty expression.")
|
||||
}
|
||||
}
|
||||
|
||||
ControlFlow::<()>::Continue(())
|
||||
}
|
||||
}
|
||||
@@ -188,6 +188,7 @@ InstructionBodyElements: Vec<InstructionBodyElement<T>> = {
|
||||
InstructionBodyElement: InstructionBodyElement<T> = {
|
||||
<l:BoxedExpression> "=" <r:BoxedExpression> => InstructionBodyElement::Expression(Expression::BinaryOperation(l, BinaryOperator::Sub, r)),
|
||||
<SelectedExpressions> <PlookupOperator> <SelectedExpressions> => InstructionBodyElement::PlookupIdentity(<>),
|
||||
<Identifier> "(" <ExpressionList> ")" => InstructionBodyElement::FunctionCall(<>)
|
||||
}
|
||||
|
||||
// This is only valid in instructions, not in PIL in general.
|
||||
|
||||
@@ -6,6 +6,7 @@ use number::{BigInt, DegreeType, FieldElement};
|
||||
use parser::asm_ast::ASMStatement;
|
||||
use parser::ast;
|
||||
pub use parser::ast::{BinaryOperator, UnaryOperator};
|
||||
use parser::macro_expansion::MacroExpander;
|
||||
|
||||
use crate::util::previsit_expressions_in_pil_file_mut;
|
||||
|
||||
@@ -31,7 +32,6 @@ struct PILContext<T> {
|
||||
constants: HashMap<String, T>,
|
||||
definitions: HashMap<String, (Polynomial, Option<FunctionValueDefinition<T>>)>,
|
||||
public_declarations: HashMap<String, PublicDeclaration>,
|
||||
macros: HashMap<String, MacroDefinition<T>>,
|
||||
identities: Vec<Identity<T>>,
|
||||
/// The order in which definitions and identities
|
||||
/// appear in the source.
|
||||
@@ -44,17 +44,7 @@ struct PILContext<T> {
|
||||
intermediate_poly_counter: u64,
|
||||
identity_counter: HashMap<IdentityKind, u64>,
|
||||
local_variables: HashMap<String, u64>,
|
||||
/// If we are evaluating a macro, this holds the arguments.
|
||||
macro_arguments: Option<Vec<Expression<T>>>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct MacroDefinition<T> {
|
||||
pub source: SourceRef,
|
||||
pub absolute_name: String,
|
||||
pub parameters: Vec<String>,
|
||||
pub identities: Vec<ast::Statement<T>>,
|
||||
pub expression: Option<ast::Expression<T>>,
|
||||
macro_expander: MacroExpander<T>,
|
||||
}
|
||||
|
||||
impl<T> From<PILContext<T>> for Analyzed<T> {
|
||||
@@ -133,7 +123,9 @@ impl<T: FieldElement> PILContext<T> {
|
||||
});
|
||||
|
||||
for statement in pil_file.0 {
|
||||
self.handle_statement(statement);
|
||||
for statement in self.macro_expander.expand_macros(vec![statement]) {
|
||||
self.handle_statement(statement);
|
||||
}
|
||||
}
|
||||
|
||||
self.current_file = old_current_file;
|
||||
@@ -192,15 +184,9 @@ impl<T: FieldElement> PILContext<T> {
|
||||
Statement::ConstantDefinition(_, name, value) => {
|
||||
self.handle_constant_definition(name, value)
|
||||
}
|
||||
Statement::MacroDefinition(start, name, params, statments, expression) => self
|
||||
.handle_macro_definition(
|
||||
self.to_source_ref(start),
|
||||
name,
|
||||
params,
|
||||
statments,
|
||||
expression,
|
||||
),
|
||||
|
||||
Statement::MacroDefinition(_, _, _, _, _) => {
|
||||
panic!("Macros should have been eliminated.");
|
||||
}
|
||||
Statement::ASMBlock(start, asm_statements) => {
|
||||
self.handle_assembly(self.to_source_ref(start), asm_statements)
|
||||
}
|
||||
@@ -219,20 +205,6 @@ impl<T: FieldElement> PILContext<T> {
|
||||
}
|
||||
|
||||
fn handle_identity_statement(&mut self, statement: ast::Statement<T>) {
|
||||
if let ast::Statement::FunctionCall(_start, name, arguments) = statement {
|
||||
if !self.macros.contains_key(&name) {
|
||||
panic!(
|
||||
"Macro {name} not found - only macros allowed at this point, no fixed columns."
|
||||
);
|
||||
}
|
||||
// TODO check that it does not contain local variable references.
|
||||
// But we also need to do some other well-formedness checks.
|
||||
if self.process_macro_call(name, arguments).is_some() {
|
||||
panic!("Invoked a macro in statement context with non-empty expression.");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
let (start, kind, left, right) = match statement {
|
||||
ast::Statement::PolynomialIdentity(start, expression) => (
|
||||
start,
|
||||
@@ -435,33 +407,8 @@ impl<T: FieldElement> PILContext<T> {
|
||||
id
|
||||
}
|
||||
|
||||
fn handle_macro_definition(
|
||||
&mut self,
|
||||
source: SourceRef,
|
||||
name: String,
|
||||
params: Vec<String>,
|
||||
statements: Vec<ast::Statement<T>>,
|
||||
expression: Option<ast::Expression<T>>,
|
||||
) {
|
||||
let absolute_name = self.namespaced(&name);
|
||||
let is_new = self
|
||||
.macros
|
||||
.insert(
|
||||
name,
|
||||
MacroDefinition {
|
||||
source,
|
||||
absolute_name,
|
||||
parameters: params.to_vec(),
|
||||
identities: statements.to_vec(),
|
||||
expression,
|
||||
},
|
||||
)
|
||||
.is_none();
|
||||
assert!(is_new);
|
||||
}
|
||||
|
||||
fn handle_assembly(&mut self, _source: SourceRef, asm_statements: Vec<ASMStatement<T>>) {
|
||||
let statements = pilgen::asm_to_pil(asm_statements.into_iter());
|
||||
let statements = pilgen::asm_to_pil(asm_statements.into_iter(), &mut self.macro_expander);
|
||||
for s in statements {
|
||||
self.handle_statement(s)
|
||||
}
|
||||
@@ -526,15 +473,9 @@ impl<T: FieldElement> PILContext<T> {
|
||||
ast::Expression::PolynomialReference(poly) => {
|
||||
if poly.namespace.is_none() && self.local_variables.contains_key(&poly.name) {
|
||||
let id = self.local_variables[&poly.name];
|
||||
// TODO to make this work inside macros, "next" and "index" need to be
|
||||
// their own ast nodes / operators.
|
||||
assert!(!poly.next);
|
||||
assert!(poly.index.is_none());
|
||||
if let Some(arguments) = &self.macro_arguments {
|
||||
arguments[id as usize].clone()
|
||||
} else {
|
||||
Expression::LocalVariableReference(id)
|
||||
}
|
||||
Expression::LocalVariableReference(id)
|
||||
} else {
|
||||
Expression::PolynomialReference(self.process_polynomial_reference(poly))
|
||||
}
|
||||
@@ -561,10 +502,6 @@ impl<T: FieldElement> PILContext<T> {
|
||||
Expression::UnaryOperation(op, Box::new(self.process_expression(*value)))
|
||||
}
|
||||
}
|
||||
ast::Expression::FunctionCall(name, arguments) if self.macros.contains_key(&name) => {
|
||||
self.process_macro_call(name, arguments)
|
||||
.expect("Invoked a macro in expression context with empty expression.")
|
||||
}
|
||||
ast::Expression::FunctionCall(name, arguments) => Expression::FunctionCall(
|
||||
self.namespaced(&name),
|
||||
self.process_expressions(arguments),
|
||||
@@ -588,38 +525,6 @@ impl<T: FieldElement> PILContext<T> {
|
||||
}
|
||||
}
|
||||
|
||||
fn process_macro_call(
|
||||
&mut self,
|
||||
name: String,
|
||||
arguments: Vec<ast::Expression<T>>,
|
||||
) -> Option<Expression<T>> {
|
||||
let arguments = Some(self.process_expressions(arguments));
|
||||
let old_arguments = std::mem::replace(&mut self.macro_arguments, arguments);
|
||||
|
||||
let old_locals = std::mem::take(&mut self.local_variables);
|
||||
|
||||
let mac = &self
|
||||
.macros
|
||||
.get(&name)
|
||||
.unwrap_or_else(|| panic!("Macro {name} not found."));
|
||||
self.local_variables = mac
|
||||
.parameters
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, n)| (n.clone(), i as u64))
|
||||
.collect();
|
||||
// TODO avoid clones
|
||||
let expression = mac.expression.clone();
|
||||
let identities = mac.identities.clone();
|
||||
for identity in identities {
|
||||
self.handle_identity_statement(identity);
|
||||
}
|
||||
let result = expression.map(|expr| self.process_expression(expr));
|
||||
self.macro_arguments = old_arguments;
|
||||
self.local_variables = old_locals;
|
||||
result
|
||||
}
|
||||
|
||||
fn process_polynomial_reference(
|
||||
&self,
|
||||
poly: ast::PolynomialReference<T>,
|
||||
|
||||
@@ -6,6 +6,7 @@ use number::FieldElement;
|
||||
|
||||
use parser::asm_ast::*;
|
||||
use parser::ast::*;
|
||||
use parser::macro_expansion::MacroExpander;
|
||||
use parser_util::ParseError;
|
||||
|
||||
/// Compiles a stand-alone assembly file to PIL.
|
||||
@@ -13,16 +14,19 @@ pub fn compile<'a, T: FieldElement>(
|
||||
file_name: Option<&str>,
|
||||
input: &'a str,
|
||||
) -> Result<PILFile<T>, ParseError<'a>> {
|
||||
let statements = parser::parse_asm(file_name, input)
|
||||
.map(|ast| ASMPILConverter::new().convert(ast.0, ASMKind::StandAlone))?;
|
||||
let statements = parser::parse_asm(file_name, input).map(|ast| {
|
||||
let mut macro_expander = MacroExpander::new();
|
||||
ASMPILConverter::new(&mut macro_expander).convert(ast.0, ASMKind::StandAlone)
|
||||
})?;
|
||||
Ok(PILFile(statements))
|
||||
}
|
||||
|
||||
/// Compiles inline assembly to PIL.
|
||||
pub fn asm_to_pil<T: FieldElement>(
|
||||
statements: impl IntoIterator<Item = ASMStatement<T>>,
|
||||
macro_expander: &mut MacroExpander<T>,
|
||||
) -> Vec<Statement<T>> {
|
||||
ASMPILConverter::new().convert(statements, ASMKind::Inline)
|
||||
ASMPILConverter::new(macro_expander).convert(statements, ASMKind::Inline)
|
||||
}
|
||||
|
||||
#[derive(PartialEq)]
|
||||
@@ -31,8 +35,8 @@ enum ASMKind {
|
||||
StandAlone,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct ASMPILConverter<T> {
|
||||
struct ASMPILConverter<'a, T> {
|
||||
macro_expander: &'a mut MacroExpander<T>,
|
||||
pil: Vec<Statement<T>>,
|
||||
pc_name: Option<String>,
|
||||
registers: BTreeMap<String, Register<T>>,
|
||||
@@ -44,9 +48,18 @@ struct ASMPILConverter<T> {
|
||||
program_constant_names: Vec<String>,
|
||||
}
|
||||
|
||||
impl<T: FieldElement> ASMPILConverter<T> {
|
||||
fn new() -> Self {
|
||||
Default::default()
|
||||
impl<'a, T: FieldElement> ASMPILConverter<'a, T> {
|
||||
fn new(macro_expander: &'a mut MacroExpander<T>) -> Self {
|
||||
Self {
|
||||
macro_expander,
|
||||
pil: Default::default(),
|
||||
pc_name: None,
|
||||
registers: Default::default(),
|
||||
instructions: Default::default(),
|
||||
code_lines: Default::default(),
|
||||
line_lookup: Default::default(),
|
||||
program_constant_names: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert(
|
||||
@@ -91,7 +104,9 @@ impl<T: FieldElement> ASMPILConverter<T> {
|
||||
ASMStatement::InstructionDeclaration(start, name, params, body) => {
|
||||
self.handle_instruction_def(start, body, name, params);
|
||||
}
|
||||
ASMStatement::InlinePil(_start, statements) => self.pil.extend(statements.clone()),
|
||||
ASMStatement::InlinePil(_start, statements) => self
|
||||
.pil
|
||||
.extend(self.macro_expander.expand_macros(statements)),
|
||||
ASMStatement::Assignment(start, write_regs, assign_reg, value) => match *value {
|
||||
Expression::FunctionCall(function_name, args) => {
|
||||
self.handle_functional_instruction(
|
||||
@@ -263,6 +278,30 @@ impl<T: FieldElement> ASMPILConverter<T> {
|
||||
|
||||
let instr = Instruction { inputs, outputs };
|
||||
|
||||
// First transform into PIL so that we can apply macro expansion.
|
||||
let mut statements = body
|
||||
.into_iter()
|
||||
.map(|el| match el {
|
||||
InstructionBodyElement::Expression(expr) => {
|
||||
Statement::PolynomialIdentity(start, expr)
|
||||
}
|
||||
InstructionBodyElement::PlookupIdentity(left, op, right) => {
|
||||
assert!(
|
||||
left.selector.is_none(),
|
||||
"LHS selector not supported, could and-combine with instruction flag later."
|
||||
);
|
||||
match op {
|
||||
PlookupOperator::In => Statement::PlookupIdentity(start, left, right),
|
||||
PlookupOperator::Is => Statement::PermutationIdentity(start, left, right),
|
||||
}
|
||||
}
|
||||
InstructionBodyElement::FunctionCall(name, arguments) => {
|
||||
Statement::FunctionCall(start, name, arguments)
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Substitute parameter references by the column names
|
||||
let substitutions = instr
|
||||
.literal_arg_names()
|
||||
.map(|arg_name| {
|
||||
@@ -270,40 +309,52 @@ impl<T: FieldElement> ASMPILConverter<T> {
|
||||
self.create_witness_fixed_pair(start, ¶m_col_name);
|
||||
(arg_name.clone(), param_col_name)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for expr in body {
|
||||
match expr {
|
||||
InstructionBodyElement::Expression(expr) => {
|
||||
let expr = substitute(expr, &substitutions);
|
||||
match extract_update(expr) {
|
||||
(Some(var), expr) => {
|
||||
self.registers
|
||||
.get_mut(&var)
|
||||
.unwrap()
|
||||
.conditioned_updates
|
||||
.push((direct_reference(&instruction_flag), expr));
|
||||
}
|
||||
(None, expr) => self.pil.push(Statement::PolynomialIdentity(
|
||||
0,
|
||||
build_mul(direct_reference(&instruction_flag), expr.clone()),
|
||||
)),
|
||||
.collect::<HashMap<_, _>>();
|
||||
statements.iter_mut().for_each(|s| {
|
||||
postvisit_expression_in_statement_mut(s, &mut |e| {
|
||||
if let Expression::PolynomialReference(r) = e {
|
||||
if let Some(sub) = substitutions.get(&r.name) {
|
||||
r.name = sub.clone();
|
||||
}
|
||||
}
|
||||
InstructionBodyElement::PlookupIdentity(left, op, right) => {
|
||||
assert!(left.selector.is_none(), "LHS selector not supported, could and-combine with instruction flag later.");
|
||||
let left = SelectedExpressions {
|
||||
selector: Some(direct_reference(&instruction_flag)),
|
||||
expressions: substitute_vec(left.expressions, &substitutions),
|
||||
};
|
||||
let right = substitute_selected_exprs(right, &substitutions);
|
||||
self.pil.push(match op {
|
||||
PlookupOperator::In => Statement::PlookupIdentity(start, left, right),
|
||||
PlookupOperator::Is => Statement::PermutationIdentity(start, left, right),
|
||||
})
|
||||
std::ops::ControlFlow::Continue::<()>(())
|
||||
});
|
||||
});
|
||||
|
||||
// Expand macros and analyze resulting statements.
|
||||
for mut statement in self.macro_expander.expand_macros(statements) {
|
||||
if let Statement::PolynomialIdentity(_start, expr) = statement {
|
||||
match extract_update(expr) {
|
||||
(Some(var), expr) => {
|
||||
self.registers
|
||||
.get_mut(&var)
|
||||
.unwrap()
|
||||
.conditioned_updates
|
||||
.push((direct_reference(&instruction_flag), expr));
|
||||
}
|
||||
(None, expr) => self.pil.push(Statement::PolynomialIdentity(
|
||||
0,
|
||||
build_mul(direct_reference(&instruction_flag), expr.clone()),
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
match &mut statement {
|
||||
Statement::PermutationIdentity(_, left, _)
|
||||
| Statement::PlookupIdentity(_, left, _) => {
|
||||
assert!(
|
||||
left.selector.is_none(),
|
||||
"LHS selector not supported, could and-combine with instruction flag later."
|
||||
);
|
||||
left.selector = Some(direct_reference(&instruction_flag));
|
||||
self.pil.push(statement)
|
||||
}
|
||||
_ => {
|
||||
panic!("Invalid statement for instruction body: {statement}");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
self.instructions.insert(name, instr);
|
||||
}
|
||||
|
||||
@@ -879,10 +930,6 @@ fn build_binary_expr<T>(
|
||||
Expression::BinaryOperation(Box::new(left), op, Box::new(right))
|
||||
}
|
||||
|
||||
fn build_unary_expr<T>(op: UnaryOperator, exp: Expression<T>) -> Expression<T> {
|
||||
Expression::UnaryOperation(op, Box::new(exp))
|
||||
}
|
||||
|
||||
fn build_number<T: FieldElement, V: Into<T>>(value: V) -> Expression<T> {
|
||||
Expression::Number(value.into())
|
||||
}
|
||||
@@ -908,77 +955,6 @@ fn extract_update<T: FieldElement>(expr: Expression<T>) -> (Option<String>, Expr
|
||||
}
|
||||
}
|
||||
|
||||
fn substitute<T: FieldElement>(
|
||||
input: Expression<T>,
|
||||
substitution: &HashMap<String, String>,
|
||||
) -> Expression<T> {
|
||||
match input {
|
||||
// TODO namespace
|
||||
Expression::PolynomialReference(r) => {
|
||||
Expression::PolynomialReference(PolynomialReference {
|
||||
name: substitute_string(&r.name, substitution),
|
||||
..r.clone()
|
||||
})
|
||||
}
|
||||
Expression::BinaryOperation(left, op, right) => build_binary_expr(
|
||||
substitute(*left, substitution),
|
||||
op,
|
||||
substitute(*right, substitution),
|
||||
),
|
||||
Expression::UnaryOperation(op, exp) => build_unary_expr(op, substitute(*exp, substitution)),
|
||||
Expression::FunctionCall(name, args) => Expression::FunctionCall(
|
||||
name,
|
||||
args.into_iter()
|
||||
.map(|e| substitute(e, substitution))
|
||||
.collect(),
|
||||
),
|
||||
Expression::Tuple(items) => Expression::Tuple(
|
||||
items
|
||||
.into_iter()
|
||||
.map(|e| substitute(e, substitution))
|
||||
.collect(),
|
||||
),
|
||||
Expression::Constant(_)
|
||||
| Expression::PublicReference(_)
|
||||
| Expression::Number(_)
|
||||
| Expression::String(_)
|
||||
| Expression::FreeInput(_) => input.clone(),
|
||||
Expression::MatchExpression(scrutinee, arms) => Expression::MatchExpression(
|
||||
Box::new(substitute(*scrutinee, substitution)),
|
||||
arms.into_iter()
|
||||
.map(|(n, e)| (n, substitute(e, substitution)))
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn substitute_selected_exprs<T: FieldElement>(
|
||||
input: SelectedExpressions<T>,
|
||||
substitution: &HashMap<String, String>,
|
||||
) -> SelectedExpressions<T> {
|
||||
SelectedExpressions {
|
||||
selector: input.selector.map(|s| substitute(s, substitution)),
|
||||
expressions: substitute_vec(input.expressions, substitution),
|
||||
}
|
||||
}
|
||||
|
||||
fn substitute_vec<T: FieldElement>(
|
||||
input: Vec<Expression<T>>,
|
||||
substitution: &HashMap<String, String>,
|
||||
) -> Vec<Expression<T>> {
|
||||
input
|
||||
.into_iter()
|
||||
.map(|e| substitute(e, substitution))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn substitute_string(input: &str, substitution: &HashMap<String, String>) -> String {
|
||||
substitution
|
||||
.get(input)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| input.to_string())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use std::fs;
|
||||
|
||||
42
test_data/pil/simple_sum_asm_macro.pil
Normal file
42
test_data/pil/simple_sum_asm_macro.pil
Normal file
@@ -0,0 +1,42 @@
|
||||
namespace Main(2**10);
|
||||
col witness XInv;
|
||||
col witness XIsZero;
|
||||
XIsZero * (1 - XIsZero) = 0;
|
||||
|
||||
macro if_then_else(condition, true_value, false_value) { condition * true_value + (1 - condition) * false_value };
|
||||
macro jump_to(target) { pc' = target; };
|
||||
macro jump_to_if(condition, target) { jump_to(if_then_else(condition, target, pc + 1)); };
|
||||
|
||||
assembly {
|
||||
reg pc[@pc];
|
||||
reg X[<=];
|
||||
reg A;
|
||||
reg CNT;
|
||||
|
||||
pil {
|
||||
// Just to test if pil-inside-assembly-inside-pil works.
|
||||
XIsZero = 1 - X * XInv;
|
||||
XIsZero * X = 0;
|
||||
}
|
||||
|
||||
instr jmpz X, l: label { jump_to_if(XIsZero, l) }
|
||||
instr jmp l: label { jump_to(l) }
|
||||
instr dec_CNT { CNT' = CNT - 1 }
|
||||
instr assert_zero X { XIsZero = 1 }
|
||||
|
||||
CNT <=X= ${ ("input", 1) };
|
||||
|
||||
start::
|
||||
jmpz CNT, check;
|
||||
A <=X= A + ${ ("input", CNT + 1) };
|
||||
// Could use "CNT <=X= CNT - 1", but that would need X.
|
||||
dec_CNT;
|
||||
jmp start;
|
||||
|
||||
check::
|
||||
A <=X= A - ${ ("input", 0) };
|
||||
assert_zero A;
|
||||
|
||||
end::
|
||||
jmp end;
|
||||
};
|
||||
Reference in New Issue
Block a user