Merge pull request #485 from powdr-labs/multi-return

Implement Multi-return
This commit is contained in:
Leo
2023-09-29 06:50:48 +00:00
committed by GitHub
15 changed files with 302 additions and 99 deletions

View File

@@ -1,7 +1,8 @@
//! Infer assignment registers in asm statements
use ast::asm_analysis::{
AnalysisASMFile, AssignmentStatement, Expression, FunctionStatement, Machine,
use ast::{
asm_analysis::{AnalysisASMFile, Expression, FunctionStatement, Machine},
parsed::asm::AssignmentRegister,
};
use number::FieldElement;
@@ -33,47 +34,47 @@ fn infer_machine<T: FieldElement>(mut machine: Machine<T>) -> Result<Machine<T>,
for f in machine.callable.functions_mut() {
for s in f.body.statements.iter_mut() {
if let FunctionStatement::Assignment(a) = s {
let expr_reg = match &*a.rhs {
// Map function calls to the list of assignment registers and all other expressions to a list of None.
let expr_regs = match &*a.rhs {
Expression::FunctionCall(c) => {
let def = machine
.instructions
.iter()
.find(|i| i.name == c.id)
.unwrap();
let output = {
let outputs = def.instruction.params.outputs.as_ref().unwrap();
assert!(outputs.params.len() == 1);
&outputs.params[0]
};
assert!(output.ty.is_none());
Some(output.name.clone())
let outputs = def.instruction.params.outputs.clone().unwrap_or_default();
outputs
.params
.iter()
.map(|o| {
assert!(o.ty.is_none());
AssignmentRegister::Register(o.name.clone())
})
.collect::<Vec<_>>()
}
_ => None,
_ => vec![AssignmentRegister::Wildcard; a.lhs_with_reg.len()],
};
match (&mut a.using_reg, expr_reg) {
(Some(using_reg), Some(expr_reg)) if *using_reg != expr_reg => {
errors.push(format!("Assignment register `{}` is incompatible with `{}`. Try replacing `<={}=` by `<==`.", using_reg, a.rhs, using_reg));
}
(Some(_), _) => {}
(None, Some(expr_reg)) => {
// infer the assignment register to that of the rhs
a.using_reg = Some(expr_reg);
}
(None, None) => {
let hint = AssignmentStatement {
using_reg: Some(
machine
.registers
.iter()
.find(|r| r.ty.is_assignment())
.unwrap()
.name
.clone(),
),
..a.clone()
};
errors.push(format!("Impossible to infer the assignment register for `{a}`. Try using an assignment register like `{hint}`."));
assert_eq!(expr_regs.len(), a.lhs_with_reg.len());
for ((w, reg), expr_reg) in a.lhs_with_reg.iter_mut().zip(expr_regs) {
match (&reg, expr_reg) {
(
AssignmentRegister::Register(using_reg),
AssignmentRegister::Register(expr_reg),
) if *using_reg != expr_reg => {
errors.push(format!("Assignment register `{}` is incompatible with `{}`. Try using `<==` with no explicit assignment registers.", using_reg, a.rhs));
}
(AssignmentRegister::Register(_), _) => {}
(AssignmentRegister::Wildcard, AssignmentRegister::Register(expr_reg)) => {
// infer the assignment register to that of the rhs
*reg = AssignmentRegister::Register(expr_reg);
}
(AssignmentRegister::Wildcard, AssignmentRegister::Wildcard) => {
errors.push(format!("Impossible to infer the assignment register to write to register `{w}`"));
}
}
}
}
@@ -115,8 +116,8 @@ mod tests {
let file = infer_str::<Bn254Field>(file).unwrap();
if let FunctionStatement::Assignment(AssignmentStatement { using_reg, .. }) = file.machines
[&parse_absolute_path("Machine")]
if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = file
.machines[&parse_absolute_path("Machine")]
.functions()
.next()
.unwrap()
@@ -126,7 +127,10 @@ mod tests {
.next()
.unwrap()
{
assert_eq!(*using_reg, Some("X".to_string()));
assert_eq!(
lhs_with_reg[0].1,
AssignmentRegister::Register("X".to_string())
);
} else {
panic!()
};
@@ -151,8 +155,8 @@ mod tests {
let file = infer_str::<Bn254Field>(file).unwrap();
if let FunctionStatement::Assignment(AssignmentStatement { using_reg, .. }) = &file.machines
[&parse_absolute_path("Machine")]
if let FunctionStatement::Assignment(AssignmentStatement { lhs_with_reg, .. }) = &file
.machines[&parse_absolute_path("Machine")]
.functions()
.next()
.unwrap()
@@ -162,7 +166,10 @@ mod tests {
.next()
.unwrap()
{
assert_eq!(*using_reg, Some("X".to_string()));
assert_eq!(
lhs_with_reg[0].1,
AssignmentRegister::Register("X".to_string())
);
} else {
panic!()
};
@@ -185,7 +192,7 @@ mod tests {
}
"#;
assert_eq!(infer_str::<Bn254Field>(file).unwrap_err(), vec!["Assignment register `Y` is incompatible with `foo()`. Try replacing `<=Y=` by `<==`."]);
assert_eq!(infer_str::<Bn254Field>(file).unwrap_err(), vec!["Assignment register `Y` is incompatible with `foo()`. Try using `<==` with no explicit assignment registers."]);
}
#[test]
@@ -203,6 +210,11 @@ mod tests {
}
"#;
assert_eq!(infer_str::<Bn254Field>(file).unwrap_err(), vec!["Impossible to infer the assignment register for `A <== 1;`. Try using an assignment register like `A <=X= 1;`.".to_string()]);
assert_eq!(
infer_str::<Bn254Field>(file).unwrap_err(),
vec![
"Impossible to infer the assignment register to write to register `A`".to_string()
]
);
}
}

View File

@@ -28,7 +28,7 @@ pub mod utils {
InstructionStatement, LabelStatement, RegisterDeclarationStatement, RegisterTy,
},
parsed::{
asm::{InstructionBody, MachineStatement, RegisterFlag},
asm::{AssignmentRegister, InstructionBody, MachineStatement, RegisterFlag},
PilStatement,
},
};
@@ -76,11 +76,15 @@ pub mod utils {
.parse::<T>(input)
.unwrap()
{
ast::parsed::asm::FunctionStatement::Assignment(start, lhs, using_reg, rhs) => {
ast::parsed::asm::FunctionStatement::Assignment(start, lhs, reg, rhs) => {
AssignmentStatement {
start,
lhs,
using_reg,
lhs_with_reg: {
let lhs_len = lhs.len();
lhs.into_iter()
.zip(reg.unwrap_or(vec![AssignmentRegister::Wildcard; lhs_len]))
.collect()
},
rhs,
}
.into()

View File

@@ -241,15 +241,22 @@ impl<T: FieldElement> ASMPILConverter<T> {
match statement {
FunctionStatement::Assignment(AssignmentStatement {
start,
lhs,
using_reg,
lhs_with_reg,
rhs,
}) => match *rhs {
Expression::FunctionCall(c) => {
self.handle_functional_instruction(lhs, using_reg.unwrap(), c.id, c.arguments)
}) => {
let lhs_with_reg = lhs_with_reg
.into_iter()
// All assignment registers should be inferred at this point.
.map(|(lhs, reg)| (lhs, reg.unwrap()))
.collect();
match *rhs {
Expression::FunctionCall(c) => {
self.handle_functional_instruction(lhs_with_reg, c.id, c.arguments)
}
_ => self.handle_non_functional_assignment(start, lhs_with_reg, *rhs),
}
_ => self.handle_assignment(start, lhs, using_reg, *rhs),
},
}
FunctionStatement::Instruction(InstructionStatement {
instruction,
inputs,
@@ -436,22 +443,22 @@ impl<T: FieldElement> ASMPILConverter<T> {
res
}
fn handle_assignment(
fn handle_non_functional_assignment(
&mut self,
_start: usize,
write_regs: Vec<String>,
assign_reg: Option<String>,
lhs_with_reg: Vec<(String, String)>,
value: Expression<T>,
) -> CodeLine<T> {
assert!(write_regs.len() <= 1);
assert!(
assign_reg.is_some(),
"Implicit assign register not yet supported."
lhs_with_reg.len() == 1,
"Multi assignments are only implemented for function calls."
);
let assign_reg = assign_reg.unwrap();
let (write_regs, assign_reg) = lhs_with_reg.into_iter().next().unwrap();
let value = self.process_assignment_value(value);
CodeLine {
write_regs: [(assign_reg.clone(), write_regs)].into_iter().collect(),
write_regs: [(assign_reg.clone(), vec![write_regs])]
.into_iter()
.collect(),
value: [(assign_reg, value)].into(),
..Default::default()
}
@@ -459,25 +466,24 @@ impl<T: FieldElement> ASMPILConverter<T> {
fn handle_functional_instruction(
&mut self,
write_regs: Vec<String>,
assign_reg: String,
lhs_with_regs: Vec<(String, String)>,
instr_name: String,
args: Vec<Expression<T>>,
mut args: Vec<Expression<T>>,
) -> CodeLine<T> {
assert!(write_regs.len() == 1);
let instr = &self
.instructions
.get(&instr_name)
.unwrap_or_else(|| panic!("Instruction not found: {instr_name}"));
assert_eq!(instr.outputs.len(), 1);
let output = instr.outputs[0].clone();
assert!(
output == assign_reg,
"The instruction {instr_name} uses the assignment register {output}, but the caller uses {assign_reg} to further process the value.",
);
let output = instr.outputs.clone();
let mut args = args;
args.push(direct_reference(write_regs.first().unwrap().clone()));
for (o, (_, r)) in output.iter().zip(lhs_with_regs.iter()) {
assert!(
o == r,
"The instruction {instr_name} uses the output register {o}, but the caller uses {r} to further process the value.",
);
}
args.extend(lhs_with_regs.iter().map(|(lhs, _)| direct_reference(lhs)));
self.handle_instruction(instr_name, args)
}

View File

@@ -96,11 +96,8 @@ impl<T: Display> Display for AssignmentStatement<T> {
write!(
f,
"{} <={}= {};",
self.lhs.join(", "),
self.using_reg
.as_ref()
.map(ToString::to_string)
.unwrap_or_default(),
self.lhs().format(", "),
self.assignment_registers().format(", "),
self.rhs
)
}

View File

@@ -14,7 +14,9 @@ use num_bigint::BigUint;
use number::FieldElement;
use crate::parsed::{
asm::{AbsoluteSymbolPath, CallableRef, InstructionBody, OperationId, Params},
asm::{
AbsoluteSymbolPath, AssignmentRegister, CallableRef, InstructionBody, OperationId, Params,
},
PilStatement,
};
@@ -570,11 +572,20 @@ impl<T> From<Return<T>> for FunctionStatement<T> {
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct AssignmentStatement<T> {
pub start: usize,
pub lhs: Vec<String>,
pub using_reg: Option<String>,
pub lhs_with_reg: Vec<(String, AssignmentRegister)>,
pub rhs: Box<Expression<T>>,
}
impl<T> AssignmentStatement<T> {
fn lhs(&self) -> impl Iterator<Item = &String> {
self.lhs_with_reg.iter().map(|(lhs, _)| lhs)
}
fn assignment_registers(&self) -> impl Iterator<Item = &AssignmentRegister> {
self.lhs_with_reg.iter().map(|(_, reg)| reg)
}
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct InstructionStatement<T> {
pub start: usize,

View File

@@ -292,9 +292,29 @@ pub enum InstructionBody<T> {
CallableRef(CallableRef),
}
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum AssignmentRegister {
Register(String),
Wildcard,
}
impl AssignmentRegister {
pub fn unwrap(self) -> String {
match self {
AssignmentRegister::Register(r) => r,
AssignmentRegister::Wildcard => panic!("cannot unwrap wildcard"),
}
}
}
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
pub enum FunctionStatement<T> {
Assignment(usize, Vec<String>, Option<String>, Box<Expression<T>>),
Assignment(
usize,
Vec<String>,
Option<Vec<AssignmentRegister>>,
Box<Expression<T>>,
),
Instruction(usize, String, Vec<Expression<T>>),
Label(usize, String),
DebugDirective(usize, DebugDirective),

View File

@@ -177,6 +177,19 @@ impl<T: Display> Display for OperationId<T> {
}
}
impl Display for AssignmentRegister {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
write!(
f,
"{}",
match self {
Self::Register(r) => r.to_string(),
Self::Wildcard => "_".to_string(),
}
)
}
}
impl<T: Display> Display for FunctionStatement<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result {
match self {
@@ -186,7 +199,7 @@ impl<T: Display> Display for FunctionStatement<T> {
write_regs.join(", "),
assignment_reg
.as_ref()
.map(ToString::to_string)
.map(|s| s.iter().format(", ").to_string())
.unwrap_or_default(),
expression
),

View File

@@ -24,16 +24,22 @@ Labels allow referring to a location in a function by name.
### Assignments
Assignments allow setting the value of a write register to the value of an [expression](#expressions) using an assignment register.
Assignments allow setting the values of some write registers to the values of some expressions [expression](#expressions) using assignment registers.
```
{{#include ../../../test_data/asm/book/function.asm:literals}}
```
If the right-hand side of the assignment is an instruction, assignment registers can be inferred and are optional:
```
{{#include ../../../test_data/asm/book/function.asm:instruction}}
```
One important requirement is for the assignment register of the assignment to be compatible with that of the expression. This is especially relevant for instructions: the assignment register of the instruction output must match that of the assignment. In this example, we use `Y` in the assignment as the output of `square` is `Y`:
This will be inferred to be the same as `A, B <=Y, Z= square_and_double(A);` from the definition of the instruction:
```
{{#include ../../../test_data/asm/book/function.asm:square}}
{{#include ../../../test_data/asm/book/function.asm:square_and_double}}
```
### Instructions

View File

@@ -187,6 +187,31 @@ fn test_multi_assign() {
gen_estark_proof(f, slice_to_vec(&i));
}
#[test]
fn test_multi_return() {
let f = "multi_return.asm";
let i = [];
verify_asm::<GoldilocksField>(f, slice_to_vec(&i));
gen_halo2_proof(f, slice_to_vec(&i));
gen_estark_proof(f, Default::default());
}
#[test]
#[should_panic = "called `Result::unwrap()` on an `Err` value: [\"Assignment register `Z` is incompatible with `square_and_double(3)`. Try using `<==` with no explicit assignment registers.\", \"Assignment register `Y` is incompatible with `square_and_double(3)`. Try using `<==` with no explicit assignment registers.\"]"]
fn test_multi_return_wrong_assignment_registers() {
let f = "multi_return_wrong_assignment_registers.asm";
let i = [];
verify_asm::<GoldilocksField>(f, slice_to_vec(&i));
}
#[test]
#[should_panic = "Result::unwrap()` on an `Err` value: [\"Mismatched number of registers for assignment A, B <=Y= square_and_double(3);\"]"]
fn test_multi_return_wrong_assignment_register_length() {
let f = "multi_return_wrong_assignment_register_length.asm";
let i = [];
verify_asm::<GoldilocksField>(f, slice_to_vec(&i));
}
#[test]
fn test_bit_access() {
let f = "bit_access.asm";

View File

@@ -304,8 +304,19 @@ IdentifierList: Vec<String> = {
=> vec![]
}
AssignOperator: Option<String> = {
"<=" <Identifier?> "="
AssignOperator: Option<Vec<AssignmentRegister>> = {
"<==" => None,
"<=" <AssignmentRegisterList> "=" => Some(<>)
}
AssignmentRegisterList: Vec<AssignmentRegister> = {
<mut list:( <AssignmentRegister> "," )*> <end:AssignmentRegister> => { list.push(end); list },
=> vec![]
}
AssignmentRegister: AssignmentRegister = {
<Identifier> => AssignmentRegister::Register(<>),
"_" => AssignmentRegister::Wildcard,
}
ReturnStatement: FunctionStatement<T> = {

View File

@@ -7,8 +7,10 @@ machine Machine {
reg pc[@pc];
reg X[<=];
reg Y[<=];
reg Z[<=];
reg CNT;
reg A;
reg B;
// an instruction to assert that a number is zero
instr assert_zero X {
@@ -25,12 +27,13 @@ machine Machine {
pc' = XIsZero * l + (1 - XIsZero) * (pc + 1)
}
// an instruction to return the square of an input
// ANCHOR: square
instr square X -> Y {
Y = X * X
// an instruction to return the square of an input as well as its double
// ANCHOR: square_and_double
instr square_and_double X -> Y, Z {
Y = X * X,
Z = 2 * X
}
// ANCHOR_END: square
// ANCHOR_END: square_and_double
function main {
// initialise `A` to 2
@@ -48,9 +51,9 @@ machine Machine {
// ANCHOR: read_register
CNT <=X= CNT - 1;
// ANCHOR_END: read_register
// square `A`
// get the square and the double of `A`
// ANCHOR: instruction
A <== square(A);
A, B <== square_and_double(A);
// ANCHOR_END: instruction
// jump back to `start`
jmp start;
@@ -59,6 +62,8 @@ machine Machine {
// ANCHOR: instruction_statement
assert_zero A - ((2**2)**2)**2;
// ANCHOR_END: instruction_statement
// check that `B == ((2**2)**2)*2`
assert_zero B - ((2**2)**2)*2;
return;
}

View File

@@ -0,0 +1,31 @@
machine MultiAssign {
degree 16;
reg pc[@pc];
reg X[<=];
reg Y[<=];
reg Z[<=];
reg A;
reg B;
instr assert_eq X, Y { X = Y }
instr square_and_double X -> Y, Z {
Y = X * X,
Z = 2 * X
}
function main {
// Different ways of expressing the same thing...
A, B <== square_and_double(3);
A, B <=Y,Z= square_and_double(3);
A, B <=Y,_= square_and_double(3);
A, B <=_,Z= square_and_double(3);
assert_eq A, 9;
assert_eq B, 6;
return;
}
}

View File

@@ -0,0 +1,23 @@
machine MultiAssign {
degree 16;
reg pc[@pc];
reg X[<=];
reg Y[<=];
reg Z[<=];
reg A;
reg B;
instr square_and_double X -> Y, Z {
Y = X * X,
Z = 2 * X
}
function main {
// Should be using assignment registers Y, Z
A, B <=Y= square_and_double(3);
return;
}
}

View File

@@ -0,0 +1,23 @@
machine MultiAssign {
degree 16;
reg pc[@pc];
reg X[<=];
reg Y[<=];
reg Z[<=];
reg A;
reg B;
instr square_and_double X -> Y, Z {
Y = X * X,
Z = 2 * X
}
function main {
// Should be using assignment registers Y, Z
A, B <=Z,Y= square_and_double(3);
return;
}
}

View File

@@ -11,8 +11,9 @@ use ast::{
parsed::{
self,
asm::{
self, ASMModule, ASMProgram, AbsoluteSymbolPath, FunctionStatement, InstructionBody,
LinkDeclaration, MachineStatement, ModuleStatement, RegisterFlag, SymbolDefinition,
self, ASMModule, ASMProgram, AbsoluteSymbolPath, AssignmentRegister, FunctionStatement,
InstructionBody, LinkDeclaration, MachineStatement, ModuleStatement, RegisterFlag,
SymbolDefinition,
},
},
};
@@ -98,13 +99,28 @@ impl<T: FieldElement> TypeChecker<T> {
MachineStatement::FunctionDeclaration(start, name, params, statements) => {
let mut function_statements = vec![];
for s in statements {
let statement_string = s.to_string();
match s {
FunctionStatement::Assignment(start, lhs, using_reg, rhs) => {
if let Some(using_reg) = &using_reg {
if using_reg.len() != lhs.len() {
errors.push(format!(
"Mismatched number of registers for assignment {}",
statement_string
));
}
}
let using_reg = using_reg.unwrap_or_else(|| {
vec![AssignmentRegister::Wildcard; lhs.len()]
});
let lhs_with_reg = lhs
.into_iter()
.zip(using_reg.into_iter())
.collect::<Vec<_>>();
function_statements.push(
AssignmentStatement {
start,
lhs,
using_reg,
lhs_with_reg,
rhs,
}
.into(),