diff --git a/compiler/tests/pil.rs b/compiler/tests/pil.rs index 3a1039e4d..7cdbf6829 100644 --- a/compiler/tests/pil.rs +++ b/compiler/tests/pil.rs @@ -84,3 +84,17 @@ fn test_simple_sum_asm_pil() { }), ) } + +#[test] +fn test_simple_sum_asm_macro_pil() { + verify_pil( + "simple_sum_asm_macro.pil", + Some(|q| match q { + "\"input\", 0" => Some(13.into()), + "\"input\", 1" => Some(2.into()), + "\"input\", 2" => Some(11.into()), + "\"input\", 3" => Some(2.into()), + _ => Some(7.into()), + }), + ) +} diff --git a/parser/src/asm_ast.rs b/parser/src/asm_ast.rs index 2a2d6fac1..7b3925f86 100644 --- a/parser/src/asm_ast.rs +++ b/parser/src/asm_ast.rs @@ -64,6 +64,7 @@ pub enum InstructionBodyElement { PlookupOperator, SelectedExpressions, ), + FunctionCall(String, Vec>), } #[derive(Debug, PartialEq, Eq, Clone)] diff --git a/parser/src/ast.rs b/parser/src/ast.rs index b81522270..cb8fe4e47 100644 --- a/parser/src/ast.rs +++ b/parser/src/ast.rs @@ -1,3 +1,5 @@ +use std::{iter::once, ops::ControlFlow}; + use number::{DegreeType, FieldElement}; use crate::asm_ast::ASMStatement; @@ -184,3 +186,98 @@ impl ArrayExpression { } } } + +/// Traverses the expression tree and calls `f` in post-order. +pub fn postvisit_expression_mut(e: &mut Expression, f: &mut F) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + match e { + Expression::PolynomialReference(_) + | Expression::Constant(_) + | Expression::PublicReference(_) + | Expression::Number(_) + | Expression::String(_) => {} + Expression::BinaryOperation(left, _, right) => { + postvisit_expression_mut(left, f)?; + postvisit_expression_mut(right, f)?; + } + Expression::UnaryOperation(_, e) => postvisit_expression_mut(e.as_mut(), f)?, + Expression::Tuple(items) | Expression::FunctionCall(_, items) => items + .iter_mut() + .try_for_each(|item| postvisit_expression_mut(item, f))?, + Expression::FreeInput(query) => postvisit_expression_mut(query.as_mut(), f)?, + Expression::MatchExpression(scrutinee, arms) => { + once(scrutinee.as_mut()) + .chain(arms.iter_mut().map(|(_n, e)| e)) + .try_for_each(|item| postvisit_expression_mut(item, f))?; + } + }; + f(e) +} + +/// Traverses the expression trees of the statement and calls `f` in post-order. +/// Does not enter ASMBlocks or macro definitions. +pub fn postvisit_expression_in_statement_mut( + statement: &mut Statement, + f: &mut F, +) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + match statement { + Statement::FunctionCall(_, _, arguments) => arguments + .iter_mut() + .try_for_each(|e| postvisit_expression_mut(e, f)), + Statement::PlookupIdentity(_, left, right) + | Statement::PermutationIdentity(_, left, right) => left + .selector + .iter_mut() + .chain(left.expressions.iter_mut()) + .chain(right.selector.iter_mut()) + .chain(right.expressions.iter_mut()) + .try_for_each(|e| postvisit_expression_mut(e, f)), + Statement::ConnectIdentity(_start, left, right) => left + .iter_mut() + .chain(right.iter_mut()) + .try_for_each(|e| postvisit_expression_mut(e, f)), + + Statement::Namespace(_, _, e) + | Statement::PolynomialDefinition(_, _, e) + | Statement::PolynomialIdentity(_, e) + | Statement::PublicDeclaration(_, _, _, e) + | Statement::ConstantDefinition(_, _, e) => postvisit_expression_mut(e, f), + + Statement::PolynomialConstantDefinition(_, _, fundef) + | Statement::PolynomialCommitDeclaration(_, _, Some(fundef)) => match fundef { + FunctionDefinition::Query(_, e) | FunctionDefinition::Mapping(_, e) => { + postvisit_expression_mut(e, f) + } + FunctionDefinition::Array(ae) => postvisit_expression_in_array_expression_mut(ae, f), + }, + Statement::PolynomialCommitDeclaration(_, _, None) + | Statement::Include(_, _) + | Statement::PolynomialConstantDeclaration(_, _) + | Statement::MacroDefinition(_, _, _, _, _) + | Statement::ASMBlock(_, _) => ControlFlow::Continue(()), + } +} + +fn postvisit_expression_in_array_expression_mut( + ae: &mut ArrayExpression, + f: &mut F, +) -> ControlFlow +where + F: FnMut(&mut Expression) -> ControlFlow, +{ + match ae { + ArrayExpression::Value(expressions) | ArrayExpression::RepeatedValue(expressions) => { + expressions + .iter_mut() + .try_for_each(|e| postvisit_expression_mut(e, f)) + } + ArrayExpression::Concat(a1, a2) => [a1, a2] + .iter_mut() + .try_for_each(|e| postvisit_expression_in_array_expression_mut(e, f)), + } +} diff --git a/parser/src/lib.rs b/parser/src/lib.rs index f5e2aacad..1a2cdd074 100644 --- a/parser/src/lib.rs +++ b/parser/src/lib.rs @@ -8,6 +8,7 @@ use parser_util::{handle_parse_error, ParseError}; pub mod asm_ast; pub mod ast; pub mod display; +pub mod macro_expansion; lalrpop_mod!( #[allow(clippy::all)] diff --git a/parser/src/macro_expansion.rs b/parser/src/macro_expansion.rs new file mode 100644 index 000000000..11cc9420f --- /dev/null +++ b/parser/src/macro_expansion.rs @@ -0,0 +1,142 @@ +use std::{ + collections::{HashMap, HashSet}, + ops::ControlFlow, +}; + +use crate::ast::*; +use number::FieldElement; + +#[derive(Debug, Default)] +pub struct MacroExpander { + macros: HashMap>, + arguments: Vec>, + parameter_names: HashMap, + shadowing_locals: HashSet, + statements: Vec>, +} + +#[derive(Debug)] +struct MacroDefinition { + pub parameters: Vec, + pub identities: Vec>, + pub expression: Option>, +} + +impl MacroExpander +where + T: FieldElement, +{ + pub fn new() -> Self { + Default::default() + } + + /// Expands all macro references inside the statements and also adds + /// any macros defined therein to the list of macros. + /// + /// Note that macros are not namespaced! + pub fn expand_macros(&mut self, statements: Vec>) -> Vec> { + assert!(self.statements.is_empty()); + for statement in statements { + self.handle_statement(statement); + } + std::mem::take(&mut self.statements) + } + + pub fn handle_statement(&mut self, mut statement: Statement) { + let mut added_locals = false; + if let Statement::PolynomialConstantDefinition(_, _, f) + | Statement::PolynomialCommitDeclaration(_, _, Some(f)) = &statement + { + if let FunctionDefinition::Mapping(params, _) | FunctionDefinition::Query(params, _) = f + { + assert!(self.shadowing_locals.is_empty()); + self.shadowing_locals.extend(params.iter().cloned()); + added_locals = true; + } + } + + postvisit_expression_in_statement_mut(&mut statement, &mut |e| self.process_expression(e)); + + match &mut statement { + Statement::FunctionCall(_start, name, arguments) => { + if !self.macros.contains_key(name) { + panic!( + "Macro {name} not found - only macros allowed at this point, no fixed columns." + ); + } + if self.expand_macro(name, std::mem::take(arguments)).is_some() { + panic!("Invoked a macro in statement context with non-empty expression."); + } + } + Statement::MacroDefinition(_start, name, parameters, statements, expression) => { + // We expand lazily. Is that a mistake? + let is_new = self + .macros + .insert( + std::mem::take(name), + MacroDefinition { + parameters: std::mem::take(parameters), + identities: std::mem::take(statements), + expression: std::mem::take(expression), + }, + ) + .is_none(); + assert!(is_new); + } + _ => self.statements.push(statement), + }; + + if added_locals { + self.shadowing_locals.clear(); + } + } + + fn expand_macro(&mut self, name: &str, arguments: Vec>) -> Option> { + let old_arguments = std::mem::replace(&mut self.arguments, arguments); + + let mac = &self + .macros + .get(name) + .unwrap_or_else(|| panic!("Macro {name} not found.")); + let parameters = mac + .parameters + .iter() + .enumerate() + .map(|(i, n)| (n.clone(), i)) + .collect(); + let old_parameters = std::mem::replace(&mut self.parameter_names, parameters); + + let mut expression = mac.expression.clone(); + let identities = mac.identities.clone(); + for identity in identities { + self.handle_statement(identity) + } + if let Some(e) = &mut expression { + postvisit_expression_mut(e, &mut |e| self.process_expression(e)); + }; + + self.arguments = old_arguments; + self.parameter_names = old_parameters; + expression + } + + fn process_expression(&mut self, e: &mut Expression) -> ControlFlow<()> { + if let Expression::PolynomialReference(poly) = e { + if poly.namespace.is_none() && self.parameter_names.contains_key(&poly.name) { + // TODO to make this work inside macros, "next" and "index" need to be + // their own ast nodes / operators. + assert!(!poly.next); + assert!(poly.index.is_none()); + *e = self.arguments[self.parameter_names[&poly.name]].clone() + } + } else if let Expression::FunctionCall(name, arguments) = e { + if self.macros.contains_key(name.as_str()) { + *e = self + .expand_macro(name, std::mem::take(arguments)) + .expect("Invoked a macro in expression context with empty expression.") + } + } + + ControlFlow::<()>::Continue(()) + } +} diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index 26953179f..bd3556393 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -188,6 +188,7 @@ InstructionBodyElements: Vec> = { InstructionBodyElement: InstructionBodyElement = { "=" => InstructionBodyElement::Expression(Expression::BinaryOperation(l, BinaryOperator::Sub, r)), => InstructionBodyElement::PlookupIdentity(<>), + "(" ")" => InstructionBodyElement::FunctionCall(<>) } // This is only valid in instructions, not in PIL in general. diff --git a/pil_analyzer/src/pil_analyzer.rs b/pil_analyzer/src/pil_analyzer.rs index 7640cb8e0..449b923d2 100644 --- a/pil_analyzer/src/pil_analyzer.rs +++ b/pil_analyzer/src/pil_analyzer.rs @@ -6,6 +6,7 @@ use number::{BigInt, DegreeType, FieldElement}; use parser::asm_ast::ASMStatement; use parser::ast; pub use parser::ast::{BinaryOperator, UnaryOperator}; +use parser::macro_expansion::MacroExpander; use crate::util::previsit_expressions_in_pil_file_mut; @@ -31,7 +32,6 @@ struct PILContext { constants: HashMap, definitions: HashMap>)>, public_declarations: HashMap, - macros: HashMap>, identities: Vec>, /// The order in which definitions and identities /// appear in the source. @@ -44,17 +44,7 @@ struct PILContext { intermediate_poly_counter: u64, identity_counter: HashMap, local_variables: HashMap, - /// If we are evaluating a macro, this holds the arguments. - macro_arguments: Option>>, -} - -#[derive(Debug)] -pub struct MacroDefinition { - pub source: SourceRef, - pub absolute_name: String, - pub parameters: Vec, - pub identities: Vec>, - pub expression: Option>, + macro_expander: MacroExpander, } impl From> for Analyzed { @@ -133,7 +123,9 @@ impl PILContext { }); for statement in pil_file.0 { - self.handle_statement(statement); + for statement in self.macro_expander.expand_macros(vec![statement]) { + self.handle_statement(statement); + } } self.current_file = old_current_file; @@ -192,15 +184,9 @@ impl PILContext { Statement::ConstantDefinition(_, name, value) => { self.handle_constant_definition(name, value) } - Statement::MacroDefinition(start, name, params, statments, expression) => self - .handle_macro_definition( - self.to_source_ref(start), - name, - params, - statments, - expression, - ), - + Statement::MacroDefinition(_, _, _, _, _) => { + panic!("Macros should have been eliminated."); + } Statement::ASMBlock(start, asm_statements) => { self.handle_assembly(self.to_source_ref(start), asm_statements) } @@ -219,20 +205,6 @@ impl PILContext { } fn handle_identity_statement(&mut self, statement: ast::Statement) { - if let ast::Statement::FunctionCall(_start, name, arguments) = statement { - if !self.macros.contains_key(&name) { - panic!( - "Macro {name} not found - only macros allowed at this point, no fixed columns." - ); - } - // TODO check that it does not contain local variable references. - // But we also need to do some other well-formedness checks. - if self.process_macro_call(name, arguments).is_some() { - panic!("Invoked a macro in statement context with non-empty expression."); - } - return; - } - let (start, kind, left, right) = match statement { ast::Statement::PolynomialIdentity(start, expression) => ( start, @@ -435,33 +407,8 @@ impl PILContext { id } - fn handle_macro_definition( - &mut self, - source: SourceRef, - name: String, - params: Vec, - statements: Vec>, - expression: Option>, - ) { - let absolute_name = self.namespaced(&name); - let is_new = self - .macros - .insert( - name, - MacroDefinition { - source, - absolute_name, - parameters: params.to_vec(), - identities: statements.to_vec(), - expression, - }, - ) - .is_none(); - assert!(is_new); - } - fn handle_assembly(&mut self, _source: SourceRef, asm_statements: Vec>) { - let statements = pilgen::asm_to_pil(asm_statements.into_iter()); + let statements = pilgen::asm_to_pil(asm_statements.into_iter(), &mut self.macro_expander); for s in statements { self.handle_statement(s) } @@ -526,15 +473,9 @@ impl PILContext { ast::Expression::PolynomialReference(poly) => { if poly.namespace.is_none() && self.local_variables.contains_key(&poly.name) { let id = self.local_variables[&poly.name]; - // TODO to make this work inside macros, "next" and "index" need to be - // their own ast nodes / operators. assert!(!poly.next); assert!(poly.index.is_none()); - if let Some(arguments) = &self.macro_arguments { - arguments[id as usize].clone() - } else { - Expression::LocalVariableReference(id) - } + Expression::LocalVariableReference(id) } else { Expression::PolynomialReference(self.process_polynomial_reference(poly)) } @@ -561,10 +502,6 @@ impl PILContext { Expression::UnaryOperation(op, Box::new(self.process_expression(*value))) } } - ast::Expression::FunctionCall(name, arguments) if self.macros.contains_key(&name) => { - self.process_macro_call(name, arguments) - .expect("Invoked a macro in expression context with empty expression.") - } ast::Expression::FunctionCall(name, arguments) => Expression::FunctionCall( self.namespaced(&name), self.process_expressions(arguments), @@ -588,38 +525,6 @@ impl PILContext { } } - fn process_macro_call( - &mut self, - name: String, - arguments: Vec>, - ) -> Option> { - let arguments = Some(self.process_expressions(arguments)); - let old_arguments = std::mem::replace(&mut self.macro_arguments, arguments); - - let old_locals = std::mem::take(&mut self.local_variables); - - let mac = &self - .macros - .get(&name) - .unwrap_or_else(|| panic!("Macro {name} not found.")); - self.local_variables = mac - .parameters - .iter() - .enumerate() - .map(|(i, n)| (n.clone(), i as u64)) - .collect(); - // TODO avoid clones - let expression = mac.expression.clone(); - let identities = mac.identities.clone(); - for identity in identities { - self.handle_identity_statement(identity); - } - let result = expression.map(|expr| self.process_expression(expr)); - self.macro_arguments = old_arguments; - self.local_variables = old_locals; - result - } - fn process_polynomial_reference( &self, poly: ast::PolynomialReference, diff --git a/pilgen/src/lib.rs b/pilgen/src/lib.rs index 9ec7607d4..ac52fedf1 100644 --- a/pilgen/src/lib.rs +++ b/pilgen/src/lib.rs @@ -6,6 +6,7 @@ use number::FieldElement; use parser::asm_ast::*; use parser::ast::*; +use parser::macro_expansion::MacroExpander; use parser_util::ParseError; /// Compiles a stand-alone assembly file to PIL. @@ -13,16 +14,19 @@ pub fn compile<'a, T: FieldElement>( file_name: Option<&str>, input: &'a str, ) -> Result, ParseError<'a>> { - let statements = parser::parse_asm(file_name, input) - .map(|ast| ASMPILConverter::new().convert(ast.0, ASMKind::StandAlone))?; + let statements = parser::parse_asm(file_name, input).map(|ast| { + let mut macro_expander = MacroExpander::new(); + ASMPILConverter::new(&mut macro_expander).convert(ast.0, ASMKind::StandAlone) + })?; Ok(PILFile(statements)) } /// Compiles inline assembly to PIL. pub fn asm_to_pil( statements: impl IntoIterator>, + macro_expander: &mut MacroExpander, ) -> Vec> { - ASMPILConverter::new().convert(statements, ASMKind::Inline) + ASMPILConverter::new(macro_expander).convert(statements, ASMKind::Inline) } #[derive(PartialEq)] @@ -31,8 +35,8 @@ enum ASMKind { StandAlone, } -#[derive(Default)] -struct ASMPILConverter { +struct ASMPILConverter<'a, T> { + macro_expander: &'a mut MacroExpander, pil: Vec>, pc_name: Option, registers: BTreeMap>, @@ -44,9 +48,18 @@ struct ASMPILConverter { program_constant_names: Vec, } -impl ASMPILConverter { - fn new() -> Self { - Default::default() +impl<'a, T: FieldElement> ASMPILConverter<'a, T> { + fn new(macro_expander: &'a mut MacroExpander) -> Self { + Self { + macro_expander, + pil: Default::default(), + pc_name: None, + registers: Default::default(), + instructions: Default::default(), + code_lines: Default::default(), + line_lookup: Default::default(), + program_constant_names: Default::default(), + } } fn convert( @@ -91,7 +104,9 @@ impl ASMPILConverter { ASMStatement::InstructionDeclaration(start, name, params, body) => { self.handle_instruction_def(start, body, name, params); } - ASMStatement::InlinePil(_start, statements) => self.pil.extend(statements.clone()), + ASMStatement::InlinePil(_start, statements) => self + .pil + .extend(self.macro_expander.expand_macros(statements)), ASMStatement::Assignment(start, write_regs, assign_reg, value) => match *value { Expression::FunctionCall(function_name, args) => { self.handle_functional_instruction( @@ -263,6 +278,30 @@ impl ASMPILConverter { let instr = Instruction { inputs, outputs }; + // First transform into PIL so that we can apply macro expansion. + let mut statements = body + .into_iter() + .map(|el| match el { + InstructionBodyElement::Expression(expr) => { + Statement::PolynomialIdentity(start, expr) + } + InstructionBodyElement::PlookupIdentity(left, op, right) => { + assert!( + left.selector.is_none(), + "LHS selector not supported, could and-combine with instruction flag later." + ); + match op { + PlookupOperator::In => Statement::PlookupIdentity(start, left, right), + PlookupOperator::Is => Statement::PermutationIdentity(start, left, right), + } + } + InstructionBodyElement::FunctionCall(name, arguments) => { + Statement::FunctionCall(start, name, arguments) + } + }) + .collect::>(); + + // Substitute parameter references by the column names let substitutions = instr .literal_arg_names() .map(|arg_name| { @@ -270,40 +309,52 @@ impl ASMPILConverter { self.create_witness_fixed_pair(start, ¶m_col_name); (arg_name.clone(), param_col_name) }) - .collect(); - - for expr in body { - match expr { - InstructionBodyElement::Expression(expr) => { - let expr = substitute(expr, &substitutions); - match extract_update(expr) { - (Some(var), expr) => { - self.registers - .get_mut(&var) - .unwrap() - .conditioned_updates - .push((direct_reference(&instruction_flag), expr)); - } - (None, expr) => self.pil.push(Statement::PolynomialIdentity( - 0, - build_mul(direct_reference(&instruction_flag), expr.clone()), - )), + .collect::>(); + statements.iter_mut().for_each(|s| { + postvisit_expression_in_statement_mut(s, &mut |e| { + if let Expression::PolynomialReference(r) = e { + if let Some(sub) = substitutions.get(&r.name) { + r.name = sub.clone(); } } - InstructionBodyElement::PlookupIdentity(left, op, right) => { - assert!(left.selector.is_none(), "LHS selector not supported, could and-combine with instruction flag later."); - let left = SelectedExpressions { - selector: Some(direct_reference(&instruction_flag)), - expressions: substitute_vec(left.expressions, &substitutions), - }; - let right = substitute_selected_exprs(right, &substitutions); - self.pil.push(match op { - PlookupOperator::In => Statement::PlookupIdentity(start, left, right), - PlookupOperator::Is => Statement::PermutationIdentity(start, left, right), - }) + std::ops::ControlFlow::Continue::<()>(()) + }); + }); + + // Expand macros and analyze resulting statements. + for mut statement in self.macro_expander.expand_macros(statements) { + if let Statement::PolynomialIdentity(_start, expr) = statement { + match extract_update(expr) { + (Some(var), expr) => { + self.registers + .get_mut(&var) + .unwrap() + .conditioned_updates + .push((direct_reference(&instruction_flag), expr)); + } + (None, expr) => self.pil.push(Statement::PolynomialIdentity( + 0, + build_mul(direct_reference(&instruction_flag), expr.clone()), + )), + } + } else { + match &mut statement { + Statement::PermutationIdentity(_, left, _) + | Statement::PlookupIdentity(_, left, _) => { + assert!( + left.selector.is_none(), + "LHS selector not supported, could and-combine with instruction flag later." + ); + left.selector = Some(direct_reference(&instruction_flag)); + self.pil.push(statement) + } + _ => { + panic!("Invalid statement for instruction body: {statement}"); + } } } } + self.instructions.insert(name, instr); } @@ -879,10 +930,6 @@ fn build_binary_expr( Expression::BinaryOperation(Box::new(left), op, Box::new(right)) } -fn build_unary_expr(op: UnaryOperator, exp: Expression) -> Expression { - Expression::UnaryOperation(op, Box::new(exp)) -} - fn build_number>(value: V) -> Expression { Expression::Number(value.into()) } @@ -908,77 +955,6 @@ fn extract_update(expr: Expression) -> (Option, Expr } } -fn substitute( - input: Expression, - substitution: &HashMap, -) -> Expression { - match input { - // TODO namespace - Expression::PolynomialReference(r) => { - Expression::PolynomialReference(PolynomialReference { - name: substitute_string(&r.name, substitution), - ..r.clone() - }) - } - Expression::BinaryOperation(left, op, right) => build_binary_expr( - substitute(*left, substitution), - op, - substitute(*right, substitution), - ), - Expression::UnaryOperation(op, exp) => build_unary_expr(op, substitute(*exp, substitution)), - Expression::FunctionCall(name, args) => Expression::FunctionCall( - name, - args.into_iter() - .map(|e| substitute(e, substitution)) - .collect(), - ), - Expression::Tuple(items) => Expression::Tuple( - items - .into_iter() - .map(|e| substitute(e, substitution)) - .collect(), - ), - Expression::Constant(_) - | Expression::PublicReference(_) - | Expression::Number(_) - | Expression::String(_) - | Expression::FreeInput(_) => input.clone(), - Expression::MatchExpression(scrutinee, arms) => Expression::MatchExpression( - Box::new(substitute(*scrutinee, substitution)), - arms.into_iter() - .map(|(n, e)| (n, substitute(e, substitution))) - .collect(), - ), - } -} - -fn substitute_selected_exprs( - input: SelectedExpressions, - substitution: &HashMap, -) -> SelectedExpressions { - SelectedExpressions { - selector: input.selector.map(|s| substitute(s, substitution)), - expressions: substitute_vec(input.expressions, substitution), - } -} - -fn substitute_vec( - input: Vec>, - substitution: &HashMap, -) -> Vec> { - input - .into_iter() - .map(|e| substitute(e, substitution)) - .collect() -} - -fn substitute_string(input: &str, substitution: &HashMap) -> String { - substitution - .get(input) - .cloned() - .unwrap_or_else(|| input.to_string()) -} - #[cfg(test)] mod test { use std::fs; diff --git a/test_data/pil/simple_sum_asm_macro.pil b/test_data/pil/simple_sum_asm_macro.pil new file mode 100644 index 000000000..43454380e --- /dev/null +++ b/test_data/pil/simple_sum_asm_macro.pil @@ -0,0 +1,42 @@ +namespace Main(2**10); + col witness XInv; + col witness XIsZero; + XIsZero * (1 - XIsZero) = 0; + + macro if_then_else(condition, true_value, false_value) { condition * true_value + (1 - condition) * false_value }; + macro jump_to(target) { pc' = target; }; + macro jump_to_if(condition, target) { jump_to(if_then_else(condition, target, pc + 1)); }; + +assembly { + reg pc[@pc]; + reg X[<=]; + reg A; + reg CNT; + + pil { + // Just to test if pil-inside-assembly-inside-pil works. + XIsZero = 1 - X * XInv; + XIsZero * X = 0; + } + + instr jmpz X, l: label { jump_to_if(XIsZero, l) } + instr jmp l: label { jump_to(l) } + instr dec_CNT { CNT' = CNT - 1 } + instr assert_zero X { XIsZero = 1 } + + CNT <=X= ${ ("input", 1) }; + + start:: + jmpz CNT, check; + A <=X= A + ${ ("input", CNT + 1) }; + // Could use "CNT <=X= CNT - 1", but that would need X. + dec_CNT; + jmp start; + + check:: + A <=X= A - ${ ("input", 0) }; + assert_zero A; + + end:: + jmp end; +}; \ No newline at end of file