Pass degree to submachines mvp (#1983)

When passing arguments to a machine which has n submachines, the first n
arguments are interpreted as the submachines, then two more optional
arguments are allowed, which are interpreted as the min and max degree
for that submachine.

---------

Co-authored-by: Leo Alt <leo@powdrlabs.com>
This commit is contained in:
Thibaut Schaeffer
2024-10-30 14:28:37 +01:00
committed by GitHub
parent cf099fc109
commit 2a461140a7
9 changed files with 143 additions and 76 deletions

View File

@@ -5,7 +5,7 @@
use std::collections::BTreeMap;
use powdr_ast::{
asm_analysis::{self, combine_flags, AnalysisASMFile, LinkDefinition},
asm_analysis::{self, combine_flags, AnalysisASMFile, LinkDefinition, MachineDegree},
object::{Link, LinkFrom, LinkTo, Location, Machine, Object, Operation, PILGraph},
parsed::{
asm::{parse_absolute_path, AbsoluteSymbolPath, CallableRef, MachineParams},
@@ -59,29 +59,71 @@ pub fn compile(input: AnalysisASMFile) -> PILGraph {
};
// get a list of all machines to instantiate and their arguments. The order does not matter.
let mut queue = vec![(main_location.clone(), main_ty.clone(), vec![])];
let mut queue = vec![(
main_location.clone(),
Instance {
machine_ty: main_ty.clone(),
submachine_locations: vec![],
min_degree: None,
max_degree: None,
},
)];
// map instance location to (type, arguments)
let mut instances = BTreeMap::default();
while let Some((location, ty, args)) = queue.pop() {
let machine = &input.get_machine(&ty).unwrap();
while let Some((
location,
Instance {
machine_ty,
submachine_locations,
min_degree,
max_degree,
},
)) = queue.pop()
{
let machine = &input.get_machine(&machine_ty).unwrap();
queue.extend(machine.submachines.iter().map(|def| {
let called_machine = &input.get_machine(&def.ty).unwrap();
// we need to pass at least as many arguments as we have submachines
assert!(def.args.len() >= called_machine.params.0.len());
// and at most as many as submachines plus a min degree and a max degree
assert!(def.args.len() <= called_machine.params.0.len() + 2);
let mut def_args = def.args.clone().into_iter();
let submachine_args: Vec<_> = (&mut def_args)
.take(called_machine.params.0.len())
.collect();
let min_degree = def_args.next();
let max_degree = def_args.next();
(
// get the absolute name for this submachine
location.clone().join(def.name.clone()),
// submachine type
def.ty.clone(),
// resolve each given machine arg to a proper instance location
def.args
.iter()
.map(|a| resolve_submachine_arg(&location, machine, &args, a))
.collect(),
Instance {
machine_ty: def.ty.clone(),
// resolve each given machine arg to a proper instance location
submachine_locations: submachine_args
.iter()
.map(|a| {
resolve_submachine_arg(&location, machine, &submachine_locations, a)
})
.collect(),
min_degree,
max_degree,
},
)
}));
instances.insert(location, (ty, args));
instances.insert(
location,
Instance {
machine_ty,
submachine_locations,
min_degree,
max_degree,
},
);
}
// count incoming permutations for each machine.
@@ -215,9 +257,16 @@ struct SubmachineRef {
pub ty: AbsoluteSymbolPath,
}
struct Instance {
machine_ty: AbsoluteSymbolPath,
submachine_locations: Vec<Location>,
min_degree: Option<Expression>,
max_degree: Option<Expression>,
}
struct ASMPILConverter<'a> {
/// Map of all machine instances to their type and passed arguments
instances: &'a BTreeMap<Location, (AbsoluteSymbolPath, Vec<Location>)>,
instances: &'a BTreeMap<Location, Instance>,
/// Current machine instance
location: &'a Location,
/// Input definitions and machines.
@@ -232,7 +281,7 @@ struct ASMPILConverter<'a> {
impl<'a> ASMPILConverter<'a> {
fn new(
instances: &'a BTreeMap<Location, (AbsoluteSymbolPath, Vec<Location>)>,
instances: &'a BTreeMap<Location, Instance>,
location: &'a Location,
input: &'a AnalysisASMFile,
incoming_permutations: &'a mut BTreeMap<Location, u64>,
@@ -252,7 +301,7 @@ impl<'a> ASMPILConverter<'a> {
}
fn convert_machine(
instances: &'a BTreeMap<Location, (AbsoluteSymbolPath, Vec<Location>)>,
instances: &'a BTreeMap<Location, Instance>,
location: &'a Location,
input: &'a AnalysisASMFile,
incoming_permutations: &'a mut BTreeMap<Location, u64>,
@@ -261,11 +310,19 @@ impl<'a> ASMPILConverter<'a> {
}
fn convert_machine_inner(mut self) -> Object {
let (ty, args) = self.instances.get(self.location).as_ref().unwrap();
// TODO: This clone doubles the current memory usage
let input = self.input.get_machine(ty).unwrap().clone();
let instance = self.instances.get(self.location).unwrap();
let degree = input.degree;
// TODO: This clone doubles the current memory usage
let input = self
.input
.get_machine(&instance.machine_ty)
.unwrap()
.clone();
// the passed degrees have priority over the ones defined in the machine type
let min = instance.min_degree.clone().or(input.degree.min);
let max = instance.max_degree.clone().or(input.degree.max);
let degree = MachineDegree { min, max };
self.submachines = input
.submachines
@@ -283,7 +340,7 @@ impl<'a> ASMPILConverter<'a> {
assert!(input.callable.is_only_operations());
// process machine parameters
self.handle_parameters(input.params, args);
self.handle_parameters(input.params, &instance.submachine_locations);
for block in input.pil {
self.handle_pil_statement(block);
@@ -341,7 +398,7 @@ impl<'a> ASMPILConverter<'a> {
.iter()
.find(|s| s.name == instance)
.unwrap_or_else(|| {
let (ty, _) = self.instances.get(self.location).unwrap();
let ty = &self.instances.get(self.location).unwrap().machine_ty;
panic!("could not find submachine named `{instance}` in machine `{ty}`");
});
// get the machine type from the machine map

View File

@@ -84,12 +84,6 @@ impl TypeChecker {
pil.push(statement);
}
MachineStatement::Submachine(_, ty, name, args) => {
args.iter().for_each(|arg| {
if arg.try_to_identifier().is_none() {
errors
.push(format!("submachine argument not a machine instance: {arg}"))
}
});
submachines.push(SubmachineDeclaration {
name,
ty: AbsoluteSymbolPath::default().join(ty),

View File

@@ -194,7 +194,14 @@ fn riscv_machine(
format!(
r#"
{}
machine Main with min_degree: {}, max_degree: {} {{
let MIN_DEGREE_LOG: int = {};
let MIN_DEGREE: int = 2**MIN_DEGREE_LOG;
let MAX_DEGREE_LOG: int = {};
let MAIN_MAX_DEGREE: int = 2**MAX_DEGREE_LOG;
let LARGE_SUBMACHINES_MAX_DEGREE: int = 2**(MAX_DEGREE_LOG + 2);
machine Main with min_degree: MIN_DEGREE, max_degree: {} {{
{}
{}
@@ -209,15 +216,12 @@ let initial_memory: (fe, fe)[] = [
}}
"#,
runtime.submachines_import(),
1 << (options
.min_degree_log
.unwrap_or(powdr_linker::MIN_DEGREE_LOG as u8)),
// We expect some machines (e.g. register memory) to use up to 4x the number
// of rows as main. By setting the max degree of main to be smaller by a factor
// of 4, we ensure that we don't run out of rows in those machines.
1 << options
.max_degree_log
.unwrap_or(*powdr_linker::MAX_DEGREE_LOG as u8 - 2),
options.min_degree_log,
options.max_degree_log,
// We're passing this as well because continuations requires
// Main's max_degree to be a constant.
// We should fix that in the continuations code and remove this.
1 << options.max_degree_log,
runtime.submachines_declare(),
preamble,
initial_memory
@@ -274,7 +278,7 @@ fn preamble(field: KnownField, runtime: &Runtime, with_bootloader: bool) -> Stri
+ &memory(with_bootloader)
+ r#"
// =============== Register memory =======================
"# + "std::machines::large_field::memory::Memory regs(byte2);"
"# + "std::machines::large_field::memory::Memory regs(byte2, MIN_DEGREE, LARGE_SUBMACHINES_MAX_DEGREE);"
+ r#"
// Get the value in register Y.
instr get_reg Y -> X link ~> X = regs.mload(Y, STEP);
@@ -595,7 +599,7 @@ fn mul_instruction(field: KnownField, runtime: &Runtime) -> &'static str {
fn memory(with_bootloader: bool) -> String {
let memory_machine = if with_bootloader {
r#"
std::machines::large_field::memory_with_bootloader_write::MemoryWithBootloaderWrite memory(byte2);
std::machines::large_field::memory_with_bootloader_write::MemoryWithBootloaderWrite memory(byte2, MIN_DEGREE, MAIN_MAX_DEGREE);
// Stores val(W) at address (V = val(X) - val(Z) + Y) % 2**32.
// V can be between 0 and 2**33.
@@ -610,7 +614,7 @@ fn memory(with_bootloader: bool) -> String {
"#
} else {
r#"
std::machines::large_field::memory::Memory memory(byte2);
std::machines::large_field::memory::Memory memory(byte2, MIN_DEGREE, MAIN_MAX_DEGREE);
"#
};

View File

@@ -47,7 +47,7 @@ impl Runtime {
"std::machines::large_field::binary::Binary",
None,
"binary",
vec!["byte_binary"],
vec!["byte_binary", "MIN_DEGREE", "LARGE_SUBMACHINES_MAX_DEGREE"],
[
r#"instr and X, Y, Z, W
link ~> tmp1_col = regs.mload(X, STEP)
@@ -73,7 +73,7 @@ impl Runtime {
"std::machines::large_field::shift::Shift",
None,
"shift",
vec!["byte_shift"],
vec!["byte_shift", "MIN_DEGREE", "LARGE_SUBMACHINES_MAX_DEGREE"],
[
r#"instr shl X, Y, Z, W
link ~> tmp1_col = regs.mload(X, STEP)
@@ -94,7 +94,7 @@ impl Runtime {
"std::machines::split::split_gl::SplitGL",
None,
"split_gl",
vec!["byte_compare"],
vec!["byte_compare", "MIN_DEGREE", "MAIN_MAX_DEGREE"],
[r#"instr split_gl X, Z, W
link ~> tmp1_col = regs.mload(X, STEP)
link ~> (tmp3_col, tmp4_col) = split_gl.split(tmp1_col)

View File

@@ -63,8 +63,8 @@ pub struct CompilerOptions {
pub field: KnownField,
pub libs: RuntimeLibs,
pub continuations: bool,
pub min_degree_log: Option<u8>,
pub max_degree_log: Option<u8>,
pub min_degree_log: u8,
pub max_degree_log: u8,
}
impl CompilerOptions {
@@ -73,8 +73,8 @@ impl CompilerOptions {
field,
libs,
continuations,
min_degree_log: None,
max_degree_log: None,
min_degree_log: 5,
max_degree_log: 20,
}
}
@@ -83,8 +83,8 @@ impl CompilerOptions {
field: KnownField::BabyBearField,
libs: RuntimeLibs::new(),
continuations: false,
min_degree_log: None,
max_degree_log: None,
min_degree_log: 5,
max_degree_log: 20,
}
}
@@ -93,21 +93,21 @@ impl CompilerOptions {
field: KnownField::GoldilocksField,
libs: RuntimeLibs::new(),
continuations: false,
min_degree_log: None,
max_degree_log: None,
min_degree_log: 5,
max_degree_log: 20,
}
}
pub fn with_min_degree_log(self, log_min_degree: u8) -> Self {
pub fn with_min_degree_log(self, min_degree_log: u8) -> Self {
Self {
min_degree_log: Some(log_min_degree),
min_degree_log,
..self
}
}
pub fn with_max_degree_log(self, log_max_degree: u8) -> Self {
pub fn with_max_degree_log(self, max_degree_log: u8) -> Self {
Self {
max_degree_log: Some(log_max_degree),
max_degree_log,
..self
}
}

View File

@@ -28,6 +28,7 @@ pub fn translate_program(program: impl RiscVProgram, options: CompilerOptions) -
translate_program_impl(program, options.field, &runtime, options.continuations);
riscv_machine(
options,
&runtime,
&preamble(options.field, &runtime, options.continuations),
initial_mem,
@@ -195,6 +196,7 @@ fn translate_program_impl(
}
fn riscv_machine(
options: CompilerOptions,
runtime: &Runtime,
preamble: &str,
initial_memory: Vec<String>,
@@ -205,9 +207,17 @@ fn riscv_machine(
{}
use std::machines::small_field::add_sub::AddSub;
use std::machines::small_field::arith::Arith;
machine Main with min_degree: {}, max_degree: {} {{
AddSub add_sub(byte2);
Arith arith_mul(byte, byte2);
let MIN_DEGREE_LOG: int = {};
let MIN_DEGREE: int = 2**MIN_DEGREE_LOG;
let MAX_DEGREE_LOG: int = {};
let MAIN_MAX_DEGREE: int = 2**MAX_DEGREE_LOG;
let LARGE_SUBMACHINES_MAX_DEGREE: int = 2**(MAX_DEGREE_LOG + 2);
machine Main with min_degree: MIN_DEGREE, max_degree: {} {{
AddSub add_sub(byte2, MIN_DEGREE, LARGE_SUBMACHINES_MAX_DEGREE);
Arith arith_mul(byte, byte2, MIN_DEGREE, MAIN_MAX_DEGREE);
{}
{}
@@ -219,14 +229,15 @@ let initial_memory: (fe, fe)[] = [
function main {{
{}
}}
}}
}}
"#,
runtime.submachines_import(),
1 << powdr_linker::MIN_DEGREE_LOG,
// We expect some machines (e.g. register memory) to use up to 4x the number
// of rows as main. By setting the max degree of main to be smaller by a factor
// of 4, we ensure that we don't run out of rows in those machines.
1 << (*powdr_linker::MAX_DEGREE_LOG - 2),
options.min_degree_log,
options.max_degree_log,
// We're passing this as well because continuations requires
// Main's max_degree to be a constant.
// We should fix that in the continuations code and remove this.
1 << options.max_degree_log,
runtime.submachines_declare(),
preamble,
initial_memory
@@ -288,7 +299,7 @@ fn preamble(field: KnownField, runtime: &Runtime, with_bootloader: bool) -> Stri
+ &memory(with_bootloader)
+ r#"
// =============== Register memory =======================
"# + "std::machines::small_field::memory::Memory regs(bit12, byte2);"
"# + "std::machines::small_field::memory::Memory regs(bit12, byte2, MIN_DEGREE, LARGE_SUBMACHINES_MAX_DEGREE);"
+ r#"
// Get the value in register YL.
instr get_reg YL -> XH, XL link ~> (XH, XL) = regs.mload(0, YL, STEP);
@@ -664,7 +675,7 @@ fn memory(with_bootloader: bool) -> String {
todo!()
} else {
r#"
std::machines::small_field::memory::Memory memory(bit12, byte2);
std::machines::small_field::memory::Memory memory(bit12, byte2, MIN_DEGREE, MAIN_MAX_DEGREE);
"#
};

View File

@@ -48,7 +48,7 @@ impl Runtime {
"std::machines::small_field::binary::Binary",
None,
"binary",
vec!["byte_binary"],
vec!["byte_binary", "MIN_DEGREE", "LARGE_SUBMACHINES_MAX_DEGREE"],
[
r#"instr and XL, YL, ZH, ZL, WL
link ~> (tmp1_h, tmp1_l) = regs.mload(0, XL, STEP)
@@ -77,7 +77,7 @@ impl Runtime {
"std::machines::small_field::shift::Shift",
None,
"shift",
vec!["byte_shift"],
vec!["byte_shift", "MIN_DEGREE", "LARGE_SUBMACHINES_MAX_DEGREE"],
[
r#"instr shl XL, YL, ZH, ZL, WL
link ~> (tmp1_h, tmp1_l) = regs.mload(0, XL, STEP)

View File

@@ -1,6 +1,7 @@
let N: int = 8;
// calls a constrained machine from a constrained machine
machine Arith with
degree: 8,
latch: latch,
operation_id: operation_id
{
@@ -15,11 +16,11 @@ machine Arith with
}
machine Main with
degree: 8,
degree: N,
latch: latch,
operation_id: operation_id
{
Arith arith;
Arith arith(N, N);
// return `3*x + 3*y`, adding twice locally and twice externally
operation main<0>;

View File

@@ -14,8 +14,8 @@ machine Main with degree: N {
col fixed STEP(i) { i };
Byte2 byte2;
Memory memory(byte2);
Child sub(memory);
Memory memory(byte2, 2 * N, 2 * N);
Child sub(memory, N, N);
instr mload X -> Y link ~> Y = memory.mload(X, STEP);
instr mstore X, Y -> link ~> memory.mstore(X, STEP, Y);
@@ -73,7 +73,7 @@ machine Main with degree: N {
}
}
machine Child(mem: Memory) with degree: N {
machine Child(mem: Memory) {
reg pc[@pc];
reg X[<=];
reg Y[<=];
@@ -81,7 +81,7 @@ machine Child(mem: Memory) with degree: N {
reg A;
reg B;
GrandChild sub(mem);
GrandChild sub(mem, N, N);
instr mload X, Y -> Z link ~> Z = mem.mload(X, Y);
instr mstore X, Y, Z -> link ~> mem.mstore(X, Y, Z);
@@ -110,7 +110,7 @@ machine Child(mem: Memory) with degree: N {
}
}
machine GrandChild(mem: Memory) with degree: N {
machine GrandChild(mem: Memory) {
reg pc[@pc];
reg X[<=];
reg Y[<=];