Make instruction bodies contain pil statements. (#531)

This commit is contained in:
chriseth
2023-09-05 19:59:22 +02:00
committed by GitHub
parent 67faa7f6c2
commit f129f80840
8 changed files with 103 additions and 122 deletions

View File

@@ -4,9 +4,9 @@ use std::{
};
use ast::parsed::{
asm::{ASMFile, Instruction, InstructionBody, InstructionBodyElement, MachineStatement},
asm::{ASMFile, Instruction, InstructionBody, MachineStatement},
postvisit_expression_in_statement_mut, postvisit_expression_mut, Expression,
FunctionDefinition, PilStatement, SelectedExpressions,
FunctionDefinition, PilStatement,
};
use number::FieldElement;
@@ -45,21 +45,7 @@ where
MachineStatement::InstructionDeclaration(_, _, Instruction { body, .. }) => {
match body {
InstructionBody::Local(body) => {
body.iter_mut().for_each(|e| match e {
InstructionBodyElement::PolynomialIdentity(left, right) => {
self.process_expression(left);
self.process_expression(right);
}
InstructionBodyElement::PlookupIdentity(left, _, right) => {
self.process_selected_expressions(left);
self.process_selected_expressions(right);
}
InstructionBodyElement::FunctionCall(c) => {
c.arguments.iter_mut().for_each(|i| {
self.process_expression(i);
});
}
});
*body = expander.expand_macros(std::mem::take(body))
}
InstructionBody::CallableRef(..) => {}
}
@@ -184,21 +170,4 @@ where
ControlFlow::<()>::Continue(())
}
fn process_expressions(&mut self, exprs: &mut [Expression<T>]) -> ControlFlow<()> {
for e in exprs.iter_mut() {
self.process_expression(e)?;
}
ControlFlow::Continue(())
}
fn process_selected_expressions(
&mut self,
exprs: &mut SelectedExpressions<T>,
) -> ControlFlow<()> {
if let Some(e) = &mut exprs.selector {
self.process_expression(e)?;
};
self.process_expressions(&mut exprs.expressions)
}
}

View File

@@ -9,7 +9,7 @@ use ast::{
LinkDefinitionStatement, Machine, PilBlock, RegisterDeclarationStatement, RegisterTy, Rom,
},
parsed::{
asm::{InstructionBody, InstructionBodyElement, PlookupOperator},
asm::InstructionBody,
build::{
build_add, build_binary_expr, build_mul, build_number, build_sub, direct_reference,
next_reference,
@@ -342,33 +342,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
// First transform into PIL so that we can apply macro expansion.
let res = match s.instruction.body {
InstructionBody::Local(body) => {
let mut statements = body
.into_iter()
.map(|el| match el {
InstructionBodyElement::PolynomialIdentity(left, right) => {
PilStatement::PolynomialIdentity(s.start, build_sub(left, right))
}
InstructionBodyElement::PlookupIdentity(left, op, right) => {
assert!(
left.selector.is_none(),
"LHS selector not supported, could and-combine with instruction flag later."
);
match op {
PlookupOperator::In => {
PilStatement::PlookupIdentity(s.start, left, right)
}
PlookupOperator::Is => {
PilStatement::PermutationIdentity(s.start, left, right)
}
}
}
InstructionBodyElement::FunctionCall(c) => {
PilStatement::FunctionCall(s.start, c.id, c.arguments)
}
})
.collect::<Vec<_>>();
InstructionBody::Local(mut body) => {
// Substitute parameter references by the column names
let substitutions = instruction
.literal_arg_names()
@@ -378,7 +352,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
(arg_name.clone(), param_col_name)
})
.collect::<HashMap<_, _>>();
statements.iter_mut().for_each(|s| {
body.iter_mut().for_each(|s| {
postvisit_expression_in_statement_mut(s, &mut |e| {
if let Expression::PolynomialReference(r) = e {
if let Some(sub) = substitutions.get(r.name()) {
@@ -389,8 +363,7 @@ impl<T: FieldElement> ASMPILConverter<T> {
});
});
// Expand macros and analyze resulting statements.
for mut statement in statements {
for mut statement in body {
if let PilStatement::PolynomialIdentity(_start, expr) = statement {
match extract_update(expr) {
(Some(var), expr) => {
@@ -412,9 +385,9 @@ impl<T: FieldElement> ASMPILConverter<T> {
PilStatement::PermutationIdentity(_, left, _)
| PilStatement::PlookupIdentity(_, left, _) => {
assert!(
left.selector.is_none(),
"LHS selector not supported, could and-combine with instruction flag later."
);
left.selector.is_none(),
"LHS selector not supported, could and-combine with instruction flag later."
);
left.selector = Some(direct_reference(&instruction_flag));
self.pil.push(statement)
}

View File

@@ -1,6 +1,6 @@
use number::AbstractNumberType;
use super::{Expression, PilStatement, SelectedExpressions};
use super::{Expression, PilStatement};
#[derive(Debug, PartialEq, Eq)]
pub struct ASMFile<T> {
@@ -101,7 +101,7 @@ pub struct CallableRef {
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum InstructionBody<T> {
Local(Vec<InstructionBodyElement<T>>),
Local(Vec<PilStatement<T>>),
CallableRef(CallableRef),
}
@@ -133,17 +133,6 @@ pub struct Param {
pub ty: Option<String>,
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum InstructionBodyElement<T> {
PolynomialIdentity(Expression<T>, Expression<T>),
PlookupIdentity(
SelectedExpressions<T>,
PlookupOperator,
SelectedExpressions<T>,
),
FunctionCall(FunctionCall<T>),
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub struct FunctionCall<T> {
pub id: String,

View File

@@ -253,22 +253,6 @@ impl<T: Display> Display for FunctionCall<T> {
}
}
impl<T: Display> Display for InstructionBodyElement<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
InstructionBodyElement::PolynomialIdentity(left, right) => {
write!(f, "{left} = {right}")
}
InstructionBodyElement::PlookupIdentity(left, operator, right) => {
write!(f, "{left} {operator} {right}")
}
InstructionBodyElement::FunctionCall(c) => {
write!(f, "{c}")
}
}
}
}
impl Display for PlookupOperator {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {

View File

@@ -173,3 +173,10 @@ fn hello_world_asm_fail() {
let i = [1];
verify_asm::<GoldilocksField>(f, slice_to_vec(&i));
}
#[test]
fn test_macros_in_instructions() {
let f = "macros_in_instructions.asm";
verify_asm::<GoldilocksField>(f, Default::default());
gen_halo2_proof(f, Default::default());
}

View File

@@ -212,24 +212,17 @@ pub CallableRef: CallableRef = {
<instance:Identifier> "." <callable:Identifier> => CallableRef { instance, callable }
}
InstructionBodyElements: Vec<InstructionBodyElement<T>> = {
InstructionBodyElements: Vec<PilStatement<T>> = {
<mut list:( <InstructionBodyElement> "," )*> <end:InstructionBodyElement> => { list.push(end); list },
=> vec![]
}
InstructionBodyElement: InstructionBodyElement<T> = {
<l:Expression> "=" <r:Expression> => InstructionBodyElement::PolynomialIdentity(l, r),
<SelectedExpressions> <PlookupOperator> <SelectedExpressions> => InstructionBodyElement::PlookupIdentity(<>),
<id:Identifier> "(" <arguments:ExpressionList> ")" => InstructionBodyElement::FunctionCall(FunctionCall {<>})
}
// This is only valid in instructions, not in PIL in general.
// "connect" is not supported because it does not support selectors
// and we need that for the instruction.
PlookupOperator: PlookupOperator = {
"in" => PlookupOperator::In,
"is" => PlookupOperator::Is,
InstructionBodyElement: PilStatement<T> = {
PolynomialIdentity,
PlookupIdentity,
PermutationIdentity,
// We could use FunctionCallStatement here, but it makes lalrpop fail to build
<@L> <Identifier> "(" <ExpressionList> ")" => PilStatement::FunctionCall(<>)
}
Params: Params = {

View File

@@ -0,0 +1,31 @@
machine MacroAsm {
reg pc[@pc];
reg X[<=];
reg Y[<=];
reg A;
constraints {
macro branch_if(condition, target) {
pc' = condition * target + (1 - condition) * (pc + 1);
};
col witness XInv;
col witness XIsZero;
XIsZero = 1 - X * XInv;
XIsZero * X = 0;
XIsZero * (1 - XIsZero) = 0;
}
instr bz X, target: label { branch_if(XIsZero, target) }
instr fail { X = X + 1 }
instr assert_zero X { XIsZero = 1 }
function main {
A <=X= 0;
bz A, is_zero;
fail;
is_zero::
assert_zero A;
return;
}
}

View File

@@ -8,7 +8,13 @@ use ast::{
LinkDefinitionStatement, Machine, OperationSymbol, PilBlock, RegisterDeclarationStatement,
RegisterTy, Return, SubmachineDeclaration,
},
parsed::asm::{ASMFile, FunctionStatement, LinkDeclaration, MachineStatement, RegisterFlag},
parsed::{
self,
asm::{
ASMFile, FunctionStatement, InstructionBody, LinkDeclaration, MachineStatement,
RegisterFlag,
},
},
};
use number::FieldElement;
@@ -66,17 +72,14 @@ impl<T: FieldElement> TypeChecker<T> {
registers.push(RegisterDeclarationStatement { start, name, ty });
}
MachineStatement::InstructionDeclaration(start, name, instruction) => {
if name == "return" {
errors.push("Instruction cannot use reserved name `return`".into());
match self.check_instruction(&name, instruction) {
Ok(instruction) => instructions.push(InstructionDefinitionStatement {
start,
name,
instruction,
}),
Err(e) => errors.extend(e),
}
instructions.push(InstructionDefinitionStatement {
start,
name,
instruction: Instruction {
params: instruction.params,
body: instruction.body,
},
});
}
MachineStatement::LinkDeclaration(LinkDeclaration {
start,
@@ -268,4 +271,36 @@ impl<T: FieldElement> TypeChecker<T> {
Ok(AnalysisASMFile { machines })
}
}
fn check_instruction(
&mut self,
name: &str,
instruction: parsed::asm::Instruction<T>,
) -> Result<Instruction<T>, Vec<String>> {
if name == "return" {
return Err(vec!["Instruction cannot use reserved name `return`".into()]);
}
if let InstructionBody::Local(statements) = &instruction.body {
let errors: Vec<_> = statements
.iter()
.filter_map(|s| match s {
ast::parsed::PilStatement::PolynomialIdentity(_, _) => None,
ast::parsed::PilStatement::PermutationIdentity(_, l, _)
| ast::parsed::PilStatement::PlookupIdentity(_, l, _) => l
.selector
.is_some()
.then_some(format!("LHS selector not yet supported in {s}.")),
_ => Some(format!("Statement not allowed in instruction body: {s}")),
})
.collect();
if !errors.is_empty() {
return Err(errors);
}
}
Ok(Instruction {
params: instruction.params,
body: instruction.body,
})
}
}