Continuations Prototype

This commit is contained in:
Georg Wiese
2023-11-28 18:13:33 +01:00
parent aca7536821
commit 85b5ef030e
5 changed files with 288 additions and 42 deletions

View File

@@ -139,7 +139,7 @@ pub fn compile_asm<T: FieldElement>(
pub fn compile_asm_string_to_analyzed_ast<T: FieldElement>(
file_name: &str,
contents: &str,
monitor: &mut DiffMonitor,
monitor: Option<&mut DiffMonitor>,
) -> Result<AnalysisASMFile<T>, Vec<String>> {
let parsed = parser::parse_asm(Some(file_name), contents).unwrap_or_else(|err| {
eprintln!("Error parsing .asm file:");
@@ -150,6 +150,8 @@ pub fn compile_asm_string_to_analyzed_ast<T: FieldElement>(
let resolved =
importer::resolve(Some(PathBuf::from(file_name)), parsed).map_err(|e| vec![e])?;
log::debug!("Run analysis");
let mut default_monitor = DiffMonitor::default();
let monitor = monitor.unwrap_or(&mut default_monitor);
let analyzed = analyze(resolved, monitor)?;
log::debug!("Analysis done");
log::trace!("{analyzed}");
@@ -226,7 +228,7 @@ pub fn compile_asm_string<T: FieldElement>(
external_witness_values: Vec<(&str, Vec<T>)>,
) -> Result<(PathBuf, Option<CompilationResult<T>>), Vec<String>> {
let mut monitor = DiffMonitor::default();
let analyzed = compile_asm_string_to_analyzed_ast(file_name, contents, &mut monitor)?;
let analyzed = compile_asm_string_to_analyzed_ast(file_name, contents, Some(&mut monitor))?;
if let Some(hook) = analyzed_hook {
hook(&analyzed);
};
@@ -331,6 +333,10 @@ fn compile<T: FieldElement, Q: QueryCallback<T>>(
#[allow(clippy::print_stdout)]
pub fn inputs_to_query_callback<T: FieldElement>(inputs: Vec<T>) -> impl QueryCallback<T> {
// TODO: Pass bootloader inputs into this function
// Right now, accessing bootloader inputs will always fail, because it will be out of bounds
let bootloader_inputs = [];
move |query: &str| -> Result<Option<T>, String> {
// TODO In the future, when match statements need to be exhaustive,
// This function probably gets an Option as argument and it should
@@ -358,6 +364,21 @@ pub fn inputs_to_query_callback<T: FieldElement>(inputs: Vec<T>) -> impl QueryCa
))
}
}
["\"bootloader_input\"", index] => {
let index = index
.parse::<usize>()
.map_err(|e| format!("Error parsing index: {e})"))?;
let value = bootloader_inputs.get(index).cloned();
if let Some(value) = value {
log::trace!("Bootloader input query: Index {index} -> {value}");
Ok(Some(value))
} else {
Err(format!(
"Error accessing bootloader inputs: Index {index} out of bounds {}",
inputs.len()
))
}
}
["\"print_char\"", ch] => {
print!(
"{}",

View File

@@ -12,8 +12,10 @@ use log::LevelFilter;
use number::write_polys_file;
use number::{read_polys_csv_file, write_polys_csv_file, CsvRenderMode};
use number::{Bn254Field, FieldElement, GoldilocksField};
use riscv::bootloader::default_input;
use riscv::bootloader::{default_input, PC_INDEX, REGISTER_NAMES};
use riscv::{compile_riscv_asm, compile_rust};
use riscv_executor::ExecutionTrace;
use std::collections::{BTreeSet, HashMap};
use std::io::{self, BufReader, BufWriter, Read};
use std::{borrow::Cow, fs, io::Write, path::Path};
use strum::{Display, EnumString, EnumVariantNames};
@@ -101,6 +103,11 @@ enum Commands {
#[arg(short, long)]
#[arg(default_value_t = false)]
just_execute: bool,
/// Run a long execution in chunks (Experimental and not sound!)
#[arg(short, long)]
#[arg(default_value_t = false)]
continuations: bool,
},
/// Compiles (no-std) rust code to riscv assembly, then to powdr assembly
/// and finally to PIL and generates fixed and witness columns.
@@ -143,6 +150,11 @@ enum Commands {
#[arg(short, long)]
#[arg(default_value_t = false)]
just_execute: bool,
/// Run a long execution in chunks (Experimental and not sound!)
#[arg(short, long)]
#[arg(default_value_t = false)]
continuations: bool,
},
/// Compiles riscv assembly to powdr assembly and then to PIL
@@ -186,6 +198,11 @@ enum Commands {
#[arg(short, long)]
#[arg(default_value_t = false)]
just_execute: bool,
/// Run a long execution in chunks (Experimental and not sound!)
#[arg(short, long)]
#[arg(default_value_t = false)]
continuations: bool,
},
Prove {
@@ -318,6 +335,7 @@ fn run_command(command: Commands) {
prove_with,
coprocessors,
just_execute,
continuations,
} => {
let coprocessors = match coprocessors {
Some(list) => {
@@ -332,7 +350,8 @@ fn run_command(command: Commands) {
force,
prove_with,
coprocessors,
just_execute
just_execute,
continuations
)) {
eprintln!("Errors:");
for e in errors {
@@ -349,6 +368,7 @@ fn run_command(command: Commands) {
prove_with,
coprocessors,
just_execute,
continuations,
} => {
assert!(!files.is_empty());
let name = if files.len() == 1 {
@@ -371,7 +391,8 @@ fn run_command(command: Commands) {
force,
prove_with,
coprocessors,
just_execute
just_execute,
continuations
)) {
eprintln!("Errors:");
for e in errors {
@@ -400,13 +421,22 @@ fn run_command(command: Commands) {
export_csv,
csv_mode,
just_execute,
} => {
if just_execute {
// assume input is riscv asm and just execute it
let contents = fs::read_to_string(file).unwrap();
let inputs = split_inputs(&inputs);
continuations,
} => match (just_execute, continuations) {
(true, true) => {
let contents = fs::read_to_string(&file).unwrap();
let inputs = split_inputs::<GoldilocksField>(&inputs);
rust_continuations(file.as_str(), contents.as_str(), inputs);
}
(true, false) => {
let contents = fs::read_to_string(&file).unwrap();
let inputs = split_inputs::<GoldilocksField>(&inputs);
riscv_executor::execute::<GoldilocksField>(&contents, &inputs, &default_input());
} else {
}
(false, true) => {
unimplemented!("Running witgen with continuations is not supported yet.")
}
(false, false) => {
match call_with_field!(compile_with_csv_export::<field>(
file,
output_directory,
@@ -426,7 +456,7 @@ fn run_command(command: Commands) {
}
};
}
}
},
Commands::Prove {
file,
dir,
@@ -465,6 +495,7 @@ fn write_backend_to_fs<F: FieldElement>(be: &dyn Backend<F>, output_dir: &Path)
log::info!("Wrote params.bin.");
}
#[allow(clippy::too_many_arguments)]
fn run_rust<F: FieldElement>(
file_name: &str,
inputs: Vec<F>,
@@ -473,10 +504,16 @@ fn run_rust<F: FieldElement>(
prove_with: Option<BackendType>,
coprocessors: riscv::CoProcessors,
just_execute: bool,
continuations: bool,
) -> Result<(), Vec<String>> {
let (asm_file_path, asm_contents) =
compile_rust(file_name, output_dir, force_overwrite, &coprocessors, false)
.ok_or_else(|| vec!["could not compile rust".to_string()])?;
let (asm_file_path, asm_contents) = compile_rust(
file_name,
output_dir,
force_overwrite,
&coprocessors,
continuations,
)
.ok_or_else(|| vec!["could not compile rust".to_string()])?;
handle_riscv_asm(
asm_file_path.to_str().unwrap(),
@@ -486,6 +523,7 @@ fn run_rust<F: FieldElement>(
force_overwrite,
prove_with,
just_execute,
continuations,
)?;
Ok(())
}
@@ -500,6 +538,7 @@ fn run_riscv_asm<F: FieldElement>(
prove_with: Option<BackendType>,
coprocessors: riscv::CoProcessors,
just_execute: bool,
continuations: bool,
) -> Result<(), Vec<String>> {
let (asm_file_path, asm_contents) = compile_riscv_asm(
original_file_name,
@@ -519,10 +558,12 @@ fn run_riscv_asm<F: FieldElement>(
force_overwrite,
prove_with,
just_execute,
continuations,
)?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn handle_riscv_asm<F: FieldElement>(
file_name: &str,
contents: &str,
@@ -531,24 +572,180 @@ fn handle_riscv_asm<F: FieldElement>(
force_overwrite: bool,
prove_with: Option<BackendType>,
just_execute: bool,
continuations: bool,
) -> Result<(), Vec<String>> {
if just_execute {
riscv_executor::execute::<F>(contents, &inputs, &default_input());
} else {
compile_asm_string(
file_name,
contents,
inputs,
None,
output_dir,
force_overwrite,
prove_with,
vec![],
)?;
match (just_execute, continuations) {
(true, true) => {
rust_continuations(file_name, contents, inputs);
}
(true, false) => {
riscv_executor::execute::<F>(contents, &inputs, &default_input());
}
(false, true) => {
unimplemented!("Running witgen with continuations is not supported yet.")
}
(false, false) => {
compile_asm_string(
file_name,
contents,
inputs,
None,
output_dir,
force_overwrite,
prove_with,
vec![],
)?;
}
}
Ok(())
}
fn transposed_trace<F: FieldElement>(trace: &ExecutionTrace) -> HashMap<String, Vec<F>> {
let mut reg_values: HashMap<&str, Vec<F>> = HashMap::with_capacity(trace.reg_map.len());
for row in trace.regs_rows() {
for (reg_name, &index) in trace.reg_map.iter() {
reg_values
.entry(reg_name)
.or_default()
.push(row[index].0.into());
}
}
reg_values
.into_iter()
.map(|(n, c)| (format!("main.{}", n), c))
.collect()
}
fn rust_continuations<F: FieldElement>(file_name: &str, contents: &str, inputs: Vec<F>) {
let mut bootloader_inputs = default_input();
let program =
compiler::compile_asm_string_to_analyzed_ast::<F>(file_name, contents, None).unwrap();
log::info!("Executing powdr-asm...");
let (full_trace, memory_accesses) = {
let trace =
riscv_executor::execute_ast::<F>(&program, &inputs, &bootloader_inputs, usize::MAX).0;
(transposed_trace::<F>(&trace), trace.mem)
};
let full_trace_length = full_trace["main.pc"].len();
log::info!("Total trace length: {}", full_trace_length);
let mut proven_trace = 0;
let mut chunk_index = 0;
let mut memory_snapshot = HashMap::new();
loop {
log::info!("\nRunning chunk {}...", chunk_index);
// Run for 2**degree - 2 steps, because the executor doesn't run the dispatcher,
// which takes 2 rows.
let degree = program
.machines
.iter()
.fold(None, |acc, (_, m)| acc.or(m.degree.clone()))
.unwrap()
.degree;
let degree = F::from(degree).to_degree();
let num_rows = degree as usize - 2;
let (chunk_trace, memory_snapshot_update) = {
let (trace, memory_snapshot_update) =
riscv_executor::execute_ast::<F>(&program, &inputs, &bootloader_inputs, num_rows);
(transposed_trace(&trace), memory_snapshot_update)
};
log::info!("{} memory slots updated.", memory_snapshot_update.len());
memory_snapshot.extend(memory_snapshot_update);
log::info!("Chunk trace length: {}", chunk_trace["main.pc"].len());
log::info!("Validating chunk...");
let (start, _) = chunk_trace["main.pc"]
.iter()
.enumerate()
.find(|(_, &pc)| pc == bootloader_inputs[PC_INDEX])
.unwrap();
let full_trace_start = match chunk_index {
// The bootloader execution in the first chunk is part of the full trace.
0 => start,
// Any other chunk starts at where we left off in the full trace.
_ => proven_trace - 1,
};
for i in 0..(chunk_trace["main.pc"].len() - start) {
for &reg in REGISTER_NAMES.iter() {
let chunk_i = i + start;
let full_i = i + full_trace_start;
if chunk_trace[reg][chunk_i] != full_trace[reg][full_i] {
log::error!("The Chunk trace differs from the full trace!");
log::error!(
"Started comparing from row {start} in the chunk to row {full_trace_start} in the full trace; the difference is at offset {i}."
);
log::error!(
"The PCs are {} and {}.",
chunk_trace["main.pc"][chunk_i],
full_trace["main.pc"][full_i]
);
log::error!(
"The first difference is in register {}: {} != {} ",
reg,
chunk_trace[reg][chunk_i],
full_trace[reg][full_i],
);
panic!();
}
}
}
if chunk_trace["main.pc"].len() < num_rows {
log::info!("Done!");
break;
}
let new_rows = match chunk_index {
0 => num_rows,
// Minus 1 because the first row was proven already.
_ => num_rows - start - 1,
};
proven_trace += new_rows;
log::info!("Proved {} rows.", new_rows);
log::info!("Building inputs for chunk {}...", chunk_index + 1);
let mut accessed_pages = BTreeSet::new();
let start_idx = memory_accesses
.binary_search_by_key(&proven_trace, |a| a.idx)
.unwrap_or_else(|v| v);
for access in &memory_accesses[start_idx..] {
// proven_trace + num_rows is an upper bound for the last row index we'll reach in the next chunk.
// In practice, we'll stop earlier, because the bootloader needs to run as well, but we don't know for
// how long as that depends on the number of pages.
if access.idx >= proven_trace + num_rows {
break;
}
accessed_pages.insert(access.address >> 10);
}
log::info!("Accessed pages: {:?}", accessed_pages);
bootloader_inputs = vec![];
for &reg in REGISTER_NAMES.iter() {
bootloader_inputs.push(*chunk_trace[reg].last().unwrap());
}
bootloader_inputs.push((accessed_pages.len() as u64).into());
for page in accessed_pages.iter() {
let start_addr = page << 10;
bootloader_inputs.push(start_addr.into());
for i in 0..256 {
let addr = start_addr + i * 4;
bootloader_inputs.push((*memory_snapshot.get(&addr).unwrap_or(&0)).into());
}
}
log::info!("Inputs length: {}", bootloader_inputs.len());
chunk_index += 1;
}
}
#[allow(clippy::too_many_arguments)]
fn compile_with_csv_export<T: FieldElement>(
file: String,
@@ -762,6 +959,7 @@ mod test {
export_csv: true,
csv_mode: CsvRenderModeCLI::Hex,
just_execute: false,
continuations: false,
};
run_command(pil_command);

View File

@@ -133,6 +133,9 @@ pub const REGISTER_NAMES: [&str; 37] = [
"main.pc",
];
/// Index of the PC in the bootloader input.
pub const PC_INDEX: usize = REGISTER_NAMES.len() - 1;
/// The bootloader input that is equivalent to not using a bootloader, i.e.:
/// - No pages are initialized
/// - All registers are set to 0

View File

@@ -0,0 +1,22 @@
#![no_std]
extern crate alloc;
use alloc::vec::Vec;
#[no_mangle]
pub fn main() {
let mut foo = Vec::new();
foo.push(1);
// Compute some fibonacci numbers
// -> Does not access memory but also does not get optimized out...
let mut a = 1;
let mut b = 1;
for _ in 0..1000000 {
let tmp = a + b;
a = b;
b = tmp;
}
// Don't optimize me away :/
assert!(a > 0);
}

View File

@@ -72,11 +72,13 @@ impl From<usize> for Elem {
pub type MemoryState = HashMap<u32, u32>;
#[derive(Debug)]
pub enum MemOperationKind {
Read,
Write,
}
#[derive(Debug)]
pub struct MemOperation {
/// Line of the register trace the memory operation happened.
pub idx: usize,
@@ -109,7 +111,7 @@ impl<'a> ExecutionTrace<'a> {
}
mod builder {
use std::collections::HashMap;
use std::{cmp, collections::HashMap};
use ast::asm_analysis::{Machine, RegisterTy};
use number::FieldElement;
@@ -288,11 +290,7 @@ mod builder {
address: addr,
});
if val != 0 {
self.mem.insert(addr, val);
} else {
self.mem.remove(&addr);
}
self.mem.insert(addr, val);
}
pub(crate) fn get_mem(&mut self, addr: u32) -> u32 {
@@ -328,11 +326,15 @@ mod builder {
self.set_reg_idx(
self.pc_idx,
if self.next_statement_line >= line_of_next_batch {
assert_eq!(self.next_statement_line, line_of_next_batch);
curr_pc + 1
} else {
curr_pc
match self.next_statement_line.cmp(&line_of_next_batch) {
cmp::Ordering::Less => curr_pc,
cmp::Ordering::Equal => curr_pc + 1,
cmp::Ordering::Greater => {
panic!(
"next_statement_line: {} > line_of_next_batch: {}",
self.next_statement_line, line_of_next_batch
);
}
}
.into(),
);
@@ -739,7 +741,7 @@ pub fn execute_ast<'a, T: FieldElement>(
loop {
let stm = statements[curr_pc as usize];
//println!("l {curr_pc}: {stm}",);
log::trace!("l {curr_pc}: {stm}",);
let is_nop = match stm {
FunctionStatement::Assignment(a) => {
@@ -761,10 +763,10 @@ pub fn execute_ast<'a, T: FieldElement>(
match &dd.directive {
DebugDirective::Loc(file, line, column) => {
let (dir, file) = debug_files[file - 1];
println!("Executed {dir}/{file}:{line}:{column}");
log::debug!("Executed {dir}/{file}:{line}:{column}");
}
DebugDirective::OriginalInstruction(insn) => {
println!(" {insn}");
log::debug!(" {insn}");
}
DebugDirective::File(_, _, _) => unreachable!(),
};