From bfff61516e8ece01e4d69dc0032bb8a2903a85ae Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Thu, 7 Dec 2023 17:25:25 +0100 Subject: [PATCH] Move continuations into riscv crate --- powdr_cli/src/main.rs | 176 +------------------------------------ riscv/Cargo.toml | 4 +- riscv/src/continuations.rs | 173 ++++++++++++++++++++++++++++++++++++ riscv/src/lib.rs | 1 + 4 files changed, 180 insertions(+), 174 deletions(-) create mode 100644 riscv/src/continuations.rs diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index 6fc084f52..ec3b3789f 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -12,12 +12,9 @@ use log::LevelFilter; use number::write_polys_file; use number::{read_polys_csv_file, write_polys_csv_file, CsvRenderMode}; use number::{Bn254Field, FieldElement, GoldilocksField}; -use riscv::bootloader::{ - default_input, BYTES_PER_WORD, PAGE_SIZE_BYTES_LOG, PC_INDEX, REGISTER_NAMES, -}; +use riscv::continuations::rust_continuations; use riscv::{compile_riscv_asm, compile_rust}; -use riscv_executor::ExecutionTrace; -use std::collections::{BTreeSet, HashMap}; +use std::collections::HashMap; use std::io::{self, BufReader, BufWriter, Read}; use std::{borrow::Cow, fs, io::Write, path::Path}; use strum::{Display, EnumString, EnumVariantNames}; @@ -441,7 +438,7 @@ fn run_command(command: Commands) { riscv_executor::execute::( &contents, &inputs, - &default_input(), + &[], riscv_executor::ExecMode::Fast, ); } @@ -596,7 +593,7 @@ fn handle_riscv_asm( riscv_executor::execute::( contents, &inputs_hash, - &default_input(), + &[], riscv_executor::ExecMode::Fast, ); } @@ -619,171 +616,6 @@ fn handle_riscv_asm( Ok(()) } -fn transposed_trace(trace: &ExecutionTrace) -> HashMap> { - let mut reg_values: HashMap<&str, Vec> = HashMap::with_capacity(trace.reg_map.len()); - - for row in trace.regs_rows() { - for (reg_name, &index) in trace.reg_map.iter() { - reg_values - .entry(reg_name) - .or_default() - .push(row[index].0.into()); - } - } - - reg_values - 
.into_iter() - .map(|(n, c)| (format!("main.{}", n), c)) - .collect() -} - -fn rust_continuations(file_name: &str, contents: &str, inputs: Vec) { - let mut bootloader_inputs = default_input(); - - let program = - compiler::compile_asm_string_to_analyzed_ast::(file_name, contents, None).unwrap(); - - let inputs: HashMap> = vec![(F::from(0), inputs)].into_iter().collect(); - - log::info!("Executing powdr-asm..."); - let (full_trace, memory_accesses) = { - let trace = riscv_executor::execute_ast::( - &program, - &inputs, - &bootloader_inputs, - usize::MAX, - riscv_executor::ExecMode::Trace, - ) - .0; - (transposed_trace::(&trace), trace.mem) - }; - - let full_trace_length = full_trace["main.pc"].len(); - log::info!("Total trace length: {}", full_trace_length); - - let mut proven_trace = 0; - let mut chunk_index = 0; - let mut memory_snapshot = HashMap::new(); - - loop { - log::info!("\nRunning chunk {}...", chunk_index); - // Run for 2**degree - 2 steps, because the executor doesn't run the dispatcher, - // which takes 2 rows. 
- let degree = program - .machines - .iter() - .fold(None, |acc, (_, m)| acc.or(m.degree.clone())) - .unwrap() - .degree; - let degree = F::from(degree).to_degree(); - let num_rows = degree as usize - 2; - let (chunk_trace, memory_snapshot_update) = { - let (trace, memory_snapshot_update) = riscv_executor::execute_ast::( - &program, - &inputs, - &bootloader_inputs, - num_rows, - riscv_executor::ExecMode::Trace, - ); - (transposed_trace(&trace), memory_snapshot_update) - }; - log::info!("{} memory slots updated.", memory_snapshot_update.len()); - memory_snapshot.extend(memory_snapshot_update); - log::info!("Chunk trace length: {}", chunk_trace["main.pc"].len()); - - log::info!("Validating chunk..."); - let (start, _) = chunk_trace["main.pc"] - .iter() - .enumerate() - .find(|(_, &pc)| pc == bootloader_inputs[PC_INDEX]) - .unwrap(); - let full_trace_start = match chunk_index { - // The bootloader execution in the first chunk is part of the full trace. - 0 => start, - // Any other chunk starts at where we left off in the full trace. - _ => proven_trace - 1, - }; - log::info!("Bootloader used {} rows.", start); - for i in 0..(chunk_trace["main.pc"].len() - start) { - for &reg in REGISTER_NAMES.iter() { - let chunk_i = i + start; - let full_i = i + full_trace_start; - if chunk_trace[reg][chunk_i] != full_trace[reg][full_i] { - log::error!("The Chunk trace differs from the full trace!"); - log::error!( - "Started comparing from row {start} in the chunk to row {full_trace_start} in the full trace; the difference is at offset {i}." 
- ); - log::error!( - "The PCs are {} and {}.", - chunk_trace["main.pc"][chunk_i], - full_trace["main.pc"][full_i] - ); - log::error!( - "The first difference is in register {}: {} != {} ", - reg, - chunk_trace[reg][chunk_i], - full_trace[reg][full_i], - ); - panic!(); - } - } - } - - if chunk_trace["main.pc"].len() < num_rows { - log::info!("Done!"); - break; - } - - let new_rows = match chunk_index { - 0 => num_rows, - // Minus 1 because the first row was proven already. - _ => num_rows - start - 1, - }; - proven_trace += new_rows; - log::info!("Proved {} rows.", new_rows); - - log::info!("Building inputs for chunk {}...", chunk_index + 1); - let mut accessed_pages = BTreeSet::new(); - let start_idx = memory_accesses - .binary_search_by_key(&proven_trace, |a| a.idx) - .unwrap_or_else(|v| v); - - for access in &memory_accesses[start_idx..] { - // proven_trace + num_rows is an upper bound for the last row index we'll reach in the next chunk. - // In practice, we'll stop earlier, because the bootloader needs to run as well, but we don't know for - // how long as that depends on the number of pages. 
- if access.idx >= proven_trace + num_rows { - break; - } - accessed_pages.insert(access.address >> PAGE_SIZE_BYTES_LOG); - } - log::info!( - "{} accessed pages: {:?}", - accessed_pages.len(), - accessed_pages - ); - - bootloader_inputs = vec![]; - for &reg in REGISTER_NAMES.iter() { - bootloader_inputs.push(*chunk_trace[reg].last().unwrap()); - } - bootloader_inputs.push((accessed_pages.len() as u64).into()); - for &page in accessed_pages.iter() { - let start_addr = page << PAGE_SIZE_BYTES_LOG; - bootloader_inputs.push(page.into()); - let words_per_page = (1 << (PAGE_SIZE_BYTES_LOG)) / BYTES_PER_WORD; - for i in 0..words_per_page { - let addr = start_addr + (i * BYTES_PER_WORD) as u32; - bootloader_inputs.push((*memory_snapshot.get(&addr).unwrap_or(&0)).into()); - } - } - - log::info!("Inputs length: {}", bootloader_inputs.len()); - - chunk_index += 1; - } -} - #[allow(clippy::too_many_arguments)] fn compile_with_csv_export( file: String, diff --git a/riscv/Cargo.toml b/riscv/Cargo.toml index 7ebb371eb..462049e43 100644 --- a/riscv/Cargo.toml +++ b/riscv/Cargo.toml @@ -18,6 +18,8 @@ serde_json = "1.0" regex-syntax = { version = "0.6", default_features = false, features = [ "unicode", ] } +riscv_executor = { path = "../riscv_executor" } +compiler = { path = "../compiler" } [build-dependencies] lalrpop = "^0.19" @@ -27,5 +29,3 @@ test-log = "0.2.12" env_logger = "0.10.0" hex = "0.4.3" number = { path = "../number" } -compiler = { path = "../compiler" } -riscv_executor = { path = "../riscv_executor" } diff --git a/riscv/src/continuations.rs b/riscv/src/continuations.rs new file mode 100644 index 000000000..23038f0b0 --- /dev/null +++ b/riscv/src/continuations.rs @@ -0,0 +1,173 @@ +use std::collections::{BTreeSet, HashMap}; + +use number::FieldElement; +use riscv_executor::ExecutionTrace; + +use crate::bootloader::{ + default_input, BYTES_PER_WORD, PAGE_SIZE_BYTES_LOG, PC_INDEX, REGISTER_NAMES, +}; + +fn transposed_trace(trace: &ExecutionTrace) -> HashMap> { + let mut 
reg_values: HashMap<&str, Vec> = HashMap::with_capacity(trace.reg_map.len()); + + for row in trace.regs_rows() { + for (reg_name, &index) in trace.reg_map.iter() { + reg_values + .entry(reg_name) + .or_default() + .push(row[index].0.into()); + } + } + + reg_values + .into_iter() + .map(|(n, c)| (format!("main.{}", n), c)) + .collect() +} + +pub fn rust_continuations(file_name: &str, contents: &str, inputs: Vec) { + let mut bootloader_inputs = default_input(); + + let program = + compiler::compile_asm_string_to_analyzed_ast::(file_name, contents, None).unwrap(); + + let inputs: HashMap> = vec![(F::from(0), inputs)].into_iter().collect(); + + log::info!("Executing powdr-asm..."); + let (full_trace, memory_accesses) = { + let trace = riscv_executor::execute_ast::( + &program, + &inputs, + &bootloader_inputs, + usize::MAX, + riscv_executor::ExecMode::Trace, + ) + .0; + (transposed_trace::(&trace), trace.mem) + }; + + let full_trace_length = full_trace["main.pc"].len(); + log::info!("Total trace length: {}", full_trace_length); + + let mut proven_trace = 0; + let mut chunk_index = 0; + let mut memory_snapshot = HashMap::new(); + + loop { + log::info!("\nRunning chunk {}...", chunk_index); + // Run for 2**degree - 2 steps, because the executor doesn't run the dispatcher, + // which takes 2 rows. 
+ let degree = program + .machines + .iter() + .fold(None, |acc, (_, m)| acc.or(m.degree.clone())) + .unwrap() + .degree; + let degree = F::from(degree).to_degree(); + let num_rows = degree as usize - 2; + let (chunk_trace, memory_snapshot_update) = { + let (trace, memory_snapshot_update) = riscv_executor::execute_ast::( + &program, + &inputs, + &bootloader_inputs, + num_rows, + riscv_executor::ExecMode::Trace, + ); + (transposed_trace(&trace), memory_snapshot_update) + }; + log::info!("{} memory slots updated.", memory_snapshot_update.len()); + memory_snapshot.extend(memory_snapshot_update); + log::info!("Chunk trace length: {}", chunk_trace["main.pc"].len()); + + log::info!("Validating chunk..."); + let (start, _) = chunk_trace["main.pc"] + .iter() + .enumerate() + .find(|(_, &pc)| pc == bootloader_inputs[PC_INDEX]) + .unwrap(); + let full_trace_start = match chunk_index { + // The bootloader execution in the first chunk is part of the full trace. + 0 => start, + // Any other chunk starts at where we left off in the full trace. + _ => proven_trace - 1, + }; + log::info!("Bootloader used {} rows.", start); + for i in 0..(chunk_trace["main.pc"].len() - start) { + for &reg in REGISTER_NAMES.iter() { + let chunk_i = i + start; + let full_i = i + full_trace_start; + if chunk_trace[reg][chunk_i] != full_trace[reg][full_i] { + log::error!("The Chunk trace differs from the full trace!"); + log::error!( + "Started comparing from row {start} in the chunk to row {full_trace_start} in the full trace; the difference is at offset {i}." 
+ ); + log::error!( + "The PCs are {} and {}.", + chunk_trace["main.pc"][chunk_i], + full_trace["main.pc"][full_i] + ); + log::error!( + "The first difference is in register {}: {} != {} ", + reg, + chunk_trace[reg][chunk_i], + full_trace[reg][full_i], + ); + panic!(); + } + } + } + + if chunk_trace["main.pc"].len() < num_rows { + log::info!("Done!"); + break; + } + + let new_rows = match chunk_index { + 0 => num_rows, + // Minus 1 because the first row was proven already. + _ => num_rows - start - 1, + }; + proven_trace += new_rows; + log::info!("Proved {} rows.", new_rows); + + log::info!("Building inputs for chunk {}...", chunk_index + 1); + let mut accessed_pages = BTreeSet::new(); + let start_idx = memory_accesses + .binary_search_by_key(&proven_trace, |a| a.idx) + .unwrap_or_else(|v| v); + + for access in &memory_accesses[start_idx..] { + // proven_trace + num_rows is an upper bound for the last row index we'll reach in the next chunk. + // In practice, we'll stop earlier, because the bootloader needs to run as well, but we don't know for + // how long as that depends on the number of pages. 
+ if access.idx >= proven_trace + num_rows { + break; + } + accessed_pages.insert(access.address >> PAGE_SIZE_BYTES_LOG); + } + log::info!( + "{} accessed pages: {:?}", + accessed_pages.len(), + accessed_pages + ); + + bootloader_inputs = vec![]; + for &reg in REGISTER_NAMES.iter() { + bootloader_inputs.push(*chunk_trace[reg].last().unwrap()); + } + bootloader_inputs.push((accessed_pages.len() as u64).into()); + for &page in accessed_pages.iter() { + let start_addr = page << PAGE_SIZE_BYTES_LOG; + bootloader_inputs.push(page.into()); + let words_per_page = (1 << (PAGE_SIZE_BYTES_LOG)) / BYTES_PER_WORD; + for i in 0..words_per_page { + let addr = start_addr + (i * BYTES_PER_WORD) as u32; + bootloader_inputs.push((*memory_snapshot.get(&addr).unwrap_or(&0)).into()); + } + } + + log::info!("Inputs length: {}", bootloader_inputs.len()); + + chunk_index += 1; + } +} diff --git a/riscv/src/lib.rs b/riscv/src/lib.rs index 102590d5c..2cce61363 100644 --- a/riscv/src/lib.rs +++ b/riscv/src/lib.rs @@ -17,6 +17,7 @@ pub use crate::coprocessors::CoProcessors; pub mod bootloader; pub mod compiler; +pub mod continuations; mod coprocessors; mod disambiguator; pub mod parser;