diff --git a/analysis/src/macro_expansion.rs b/analysis/src/macro_expansion.rs index 2d3a691e3..cde19c1b5 100644 --- a/analysis/src/macro_expansion.rs +++ b/analysis/src/macro_expansion.rs @@ -4,9 +4,9 @@ use std::{ }; use ast::parsed::{ - asm::{ASMFile, Instruction, InstructionBody, InstructionBodyElement, MachineStatement}, + asm::{ASMFile, Instruction, InstructionBody, MachineStatement}, postvisit_expression_in_statement_mut, postvisit_expression_mut, Expression, - FunctionDefinition, PilStatement, SelectedExpressions, + FunctionDefinition, PilStatement, }; use number::FieldElement; @@ -45,21 +45,7 @@ where MachineStatement::InstructionDeclaration(_, _, Instruction { body, .. }) => { match body { InstructionBody::Local(body) => { - body.iter_mut().for_each(|e| match e { - InstructionBodyElement::PolynomialIdentity(left, right) => { - self.process_expression(left); - self.process_expression(right); - } - InstructionBodyElement::PlookupIdentity(left, _, right) => { - self.process_selected_expressions(left); - self.process_selected_expressions(right); - } - InstructionBodyElement::FunctionCall(c) => { - c.arguments.iter_mut().for_each(|i| { - self.process_expression(i); - }); - } - }); + *body = expander.expand_macros(std::mem::take(body)) } InstructionBody::CallableRef(..) => {} } @@ -184,21 +170,4 @@ where ControlFlow::<()>::Continue(()) } - - fn process_expressions(&mut self, exprs: &mut [Expression]) -> ControlFlow<()> { - for e in exprs.iter_mut() { - self.process_expression(e)?; - } - ControlFlow::Continue(()) - } - - fn process_selected_expressions( - &mut self, - exprs: &mut SelectedExpressions, - ) -> ControlFlow<()> { - if let Some(e) = &mut exprs.selector { - self.process_expression(e)?; - }; - self.process_expressions(&mut exprs.expressions) - } } diff --git a/asm_to_pil/src/vm_to_constrained.rs b/asm_to_pil/src/vm_to_constrained.rs index 9ffafeeb0..1b035e409 100644 --- a/asm_to_pil/src/vm_to_constrained.rs +++ b/asm_to_pil/src/vm_to_constrained.rs @@ -9,7 +9,7 @@ use ast::{ LinkDefinitionStatement, Machine, PilBlock, RegisterDeclarationStatement, RegisterTy, Rom, }, parsed::{ - asm::{InstructionBody, InstructionBodyElement, PlookupOperator}, + asm::InstructionBody, build::{ build_add, build_binary_expr, build_mul, build_number, build_sub, direct_reference, next_reference, @@ -342,33 +342,7 @@ impl ASMPILConverter { // First transform into PIL so that we can apply macro expansion. let res = match s.instruction.body { - InstructionBody::Local(body) => { - let mut statements = body - .into_iter() - .map(|el| match el { - InstructionBodyElement::PolynomialIdentity(left, right) => { - PilStatement::PolynomialIdentity(s.start, build_sub(left, right)) - } - InstructionBodyElement::PlookupIdentity(left, op, right) => { - assert!( - left.selector.is_none(), - "LHS selector not supported, could and-combine with instruction flag later." - ); - match op { - PlookupOperator::In => { - PilStatement::PlookupIdentity(s.start, left, right) - } - PlookupOperator::Is => { - PilStatement::PermutationIdentity(s.start, left, right) - } - } - } - InstructionBodyElement::FunctionCall(c) => { - PilStatement::FunctionCall(s.start, c.id, c.arguments) - } - }) - .collect::>(); - + InstructionBody::Local(mut body) => { // Substitute parameter references by the column names let substitutions = instruction .literal_arg_names() @@ -378,7 +352,7 @@ impl ASMPILConverter { (arg_name.clone(), param_col_name) }) .collect::>(); - statements.iter_mut().for_each(|s| { + body.iter_mut().for_each(|s| { postvisit_expression_in_statement_mut(s, &mut |e| { if let Expression::PolynomialReference(r) = e { if let Some(sub) = substitutions.get(r.name()) { @@ -389,8 +363,7 @@ impl ASMPILConverter { }); }); - // Expand macros and analyze resulting statements. - for mut statement in statements { + for mut statement in body { if let PilStatement::PolynomialIdentity(_start, expr) = statement { match extract_update(expr) { (Some(var), expr) => { @@ -412,9 +385,9 @@ impl ASMPILConverter { PilStatement::PermutationIdentity(_, left, _) | PilStatement::PlookupIdentity(_, left, _) => { assert!( - left.selector.is_none(), - "LHS selector not supported, could and-combine with instruction flag later." - ); + left.selector.is_none(), + "LHS selector not supported, could and-combine with instruction flag later." + ); left.selector = Some(direct_reference(&instruction_flag)); self.pil.push(statement) } diff --git a/ast/src/parsed/asm.rs b/ast/src/parsed/asm.rs index 94411397a..5c7e1c545 100644 --- a/ast/src/parsed/asm.rs +++ b/ast/src/parsed/asm.rs @@ -1,6 +1,6 @@ use number::AbstractNumberType; -use super::{Expression, PilStatement, SelectedExpressions}; +use super::{Expression, PilStatement}; #[derive(Debug, PartialEq, Eq)] pub struct ASMFile { @@ -101,7 +101,7 @@ pub struct CallableRef { #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub enum InstructionBody { - Local(Vec>), + Local(Vec>), CallableRef(CallableRef), } @@ -133,17 +133,6 @@ pub struct Param { pub ty: Option, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] -pub enum InstructionBodyElement { - PolynomialIdentity(Expression, Expression), - PlookupIdentity( - SelectedExpressions, - PlookupOperator, - SelectedExpressions, - ), - FunctionCall(FunctionCall), -} - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)] pub struct FunctionCall { pub id: String, diff --git a/ast/src/parsed/display.rs b/ast/src/parsed/display.rs index cf36281ca..5f982c0e7 100644 --- a/ast/src/parsed/display.rs +++ b/ast/src/parsed/display.rs @@ -253,22 +253,6 @@ impl Display for FunctionCall { } } -impl Display for InstructionBodyElement { - fn fmt(&self, f: &mut Formatter<'_>) -> Result { - match self { - InstructionBodyElement::PolynomialIdentity(left, right) => { - write!(f, "{left} = {right}") - } - InstructionBodyElement::PlookupIdentity(left, operator, right) => { - write!(f, "{left} {operator} {right}") - } - InstructionBodyElement::FunctionCall(c) => { - write!(f, "{c}") - } - } - } -} - impl Display for PlookupOperator { fn fmt(&self, f: &mut Formatter<'_>) -> Result { match self { diff --git a/compiler/tests/asm.rs b/compiler/tests/asm.rs index f0cf251dc..175655746 100644 --- a/compiler/tests/asm.rs +++ b/compiler/tests/asm.rs @@ -173,3 +173,10 @@ fn hello_world_asm_fail() { let i = [1]; verify_asm::(f, slice_to_vec(&i)); } + +#[test] +fn test_macros_in_instructions() { + let f = "macros_in_instructions.asm"; + verify_asm::(f, Default::default()); + gen_halo2_proof(f, Default::default()); +} diff --git a/parser/src/powdr.lalrpop b/parser/src/powdr.lalrpop index cc672e3a1..c929e1372 100644 --- a/parser/src/powdr.lalrpop +++ b/parser/src/powdr.lalrpop @@ -212,24 +212,17 @@ pub CallableRef: CallableRef = { "." => CallableRef { instance, callable } } -InstructionBodyElements: Vec> = { +InstructionBodyElements: Vec> = { "," )*> => { list.push(end); list }, => vec![] } -InstructionBodyElement: InstructionBodyElement = { - "=" => InstructionBodyElement::PolynomialIdentity(l, r), - => InstructionBodyElement::PlookupIdentity(<>), - "(" ")" => InstructionBodyElement::FunctionCall(FunctionCall {<>}) -} - -// This is only valid in instructions, not in PIL in general. -// "connect" is not supported because it does not support selectors -// and we need that for the instruction. - -PlookupOperator: PlookupOperator = { - "in" => PlookupOperator::In, - "is" => PlookupOperator::Is, +InstructionBodyElement: PilStatement = { + PolynomialIdentity, + PlookupIdentity, + PermutationIdentity, + // We could use FunctionCallStatement here, but it makes lalrpop fail to build + <@L> "(" ")" => PilStatement::FunctionCall(<>) } Params: Params = { diff --git a/test_data/asm/macros_in_instructions.asm b/test_data/asm/macros_in_instructions.asm new file mode 100644 index 000000000..819f8e4d4 --- /dev/null +++ b/test_data/asm/macros_in_instructions.asm @@ -0,0 +1,31 @@ +machine MacroAsm { + reg pc[@pc]; + reg X[<=]; + reg Y[<=]; + reg A; + + constraints { + macro branch_if(condition, target) { + pc' = condition * target + (1 - condition) * (pc + 1); + }; + + col witness XInv; + col witness XIsZero; + XIsZero = 1 - X * XInv; + XIsZero * X = 0; + XIsZero * (1 - XIsZero) = 0; + } + + instr bz X, target: label { branch_if(XIsZero, target) } + instr fail { X = X + 1 } + instr assert_zero X { XIsZero = 1 } + + function main { + A <=X= 0; + bz A, is_zero; + fail; + is_zero:: + assert_zero A; + return; + } +} \ No newline at end of file diff --git a/type_check/src/lib.rs b/type_check/src/lib.rs index 2190e28c3..f4f6e3557 100644 --- a/type_check/src/lib.rs +++ b/type_check/src/lib.rs @@ -8,7 +8,13 @@ use ast::{ LinkDefinitionStatement, Machine, OperationSymbol, PilBlock, RegisterDeclarationStatement, RegisterTy, Return, SubmachineDeclaration, }, - parsed::asm::{ASMFile, FunctionStatement, LinkDeclaration, MachineStatement, RegisterFlag}, + parsed::{ + self, + asm::{ + ASMFile, FunctionStatement, InstructionBody, LinkDeclaration, MachineStatement, + RegisterFlag, + }, + }, }; use number::FieldElement; @@ -66,17 +72,14 @@ impl TypeChecker { registers.push(RegisterDeclarationStatement { start, name, ty }); } MachineStatement::InstructionDeclaration(start, name, instruction) => { - if name == "return" { - errors.push("Instruction cannot use reserved name `return`".into()); + match self.check_instruction(&name, instruction) { + Ok(instruction) => instructions.push(InstructionDefinitionStatement { + start, + name, + instruction, + }), + Err(e) => errors.extend(e), } - instructions.push(InstructionDefinitionStatement { - start, - name, - instruction: Instruction { - params: instruction.params, - body: instruction.body, - }, - }); } MachineStatement::LinkDeclaration(LinkDeclaration { start, @@ -268,4 +271,36 @@ impl TypeChecker { Ok(AnalysisASMFile { machines }) } } + + fn check_instruction( + &mut self, + name: &str, + instruction: parsed::asm::Instruction, + ) -> Result, Vec> { + if name == "return" { + return Err(vec!["Instruction cannot use reserved name `return`".into()]); + } + + if let InstructionBody::Local(statements) = &instruction.body { + let errors: Vec<_> = statements + .iter() + .filter_map(|s| match s { + ast::parsed::PilStatement::PolynomialIdentity(_, _) => None, + ast::parsed::PilStatement::PermutationIdentity(_, l, _) + | ast::parsed::PilStatement::PlookupIdentity(_, l, _) => l + .selector + .is_some() + .then_some(format!("LHS selector not yet supported in {s}.")), + _ => Some(format!("Statement not allowed in instruction body: {s}")), + }) + .collect(); + if !errors.is_empty() { + return Err(errors); + } + } + Ok(Instruction { + params: instruction.params, + body: instruction.body, + }) + } }