From 376019c7ebb00a8941f510b73bbc0f3cb36c57f8 Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Mon, 11 Dec 2023 16:45:29 +0100 Subject: [PATCH] Refactor --- compiler/src/pipeline.rs | 8 + powdr_cli/src/main.rs | 317 +++++++++++++++++++------------------ riscv/src/continuations.rs | 5 +- 3 files changed, 173 insertions(+), 157 deletions(-) diff --git a/compiler/src/pipeline.rs b/compiler/src/pipeline.rs index 4be95b15f..cd2ec8a98 100644 --- a/compiler/src/pipeline.rs +++ b/compiler/src/pipeline.rs @@ -641,6 +641,14 @@ impl Pipeline { Ok(()) } + pub fn asm_string(mut self) -> Result> { + self.advance_to(Stage::AsmString)?; + match self.artifact.unwrap() { + Artifact::AsmString(_, asm_string) => Ok(asm_string), + _ => panic!(), + } + } + pub fn analyzed_asm(mut self) -> Result, Vec> { self.advance_to(Stage::AnalyzedAsm)?; let Artifact::AnalyzedAsm(analyzed_asm) = self.artifact.unwrap() else { diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index d4d21bbf1..b4f4c9672 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -18,6 +18,35 @@ use std::path::PathBuf; use std::{borrow::Cow, fs, io::Write, path::Path}; use strum::{Display, EnumString, EnumVariantNames}; +fn add_external_witness_values( + pipeline: Pipeline, + witness_values: Option, +) -> Pipeline { + let external_witness_values = witness_values + .map(|csv_path| { + let csv_file = fs::File::open(csv_path).unwrap(); + let mut csv_writer = BufReader::new(&csv_file); + read_polys_csv_file::(&mut csv_writer) + }) + .unwrap_or(vec![]); + + pipeline.with_external_witness_values(external_witness_values) +} + +fn add_csv_settings( + pipeline: Pipeline, + export_csv: bool, + csv_mode: CsvRenderModeCLI, +) -> Pipeline { + let csv_mode = match csv_mode { + CsvRenderModeCLI::SignedBase10 => CsvRenderMode::SignedBase10, + CsvRenderModeCLI::UnsignedBase10 => CsvRenderMode::UnsignedBase10, + CsvRenderModeCLI::Hex => CsvRenderMode::Hex, + }; + + pipeline.with_witness_csv_settings(export_csv, csv_mode) +} + #[derive(Clone, EnumString, EnumVariantNames, Display)] pub enum FieldArgument { #[strum(serialize = "gl")] @@ -26,7 +55,7 @@ pub enum FieldArgument { Bn254, } -#[derive(Clone, EnumString, EnumVariantNames, Display)] +#[derive(Clone, Copy, EnumString, EnumVariantNames, Display)] pub enum CsvRenderModeCLI { #[strum(serialize = "i")] SignedBase10, @@ -420,69 +449,20 @@ fn run_command(command: Commands) { csv_mode, just_execute, continuations, - } => match (just_execute, continuations) { - (true, true) => { - assert!(matches!(field, FieldArgument::Gl)); - let inputs = split_inputs::(&inputs); - rust_continuations_dry_run( - Pipeline::default().from_asm_file(PathBuf::from(file)), - inputs, - ); - } - (true, false) => { - let contents = fs::read_to_string(&file).unwrap(); - let inputs = split_inputs::(&inputs); - let inputs: HashMap> = - vec![(GoldilocksField::from(0), inputs)] - .into_iter() - .collect(); - riscv_executor::execute::( - &contents, - &inputs, - &[], - riscv_executor::ExecMode::Fast, - ); - } - (false, true) => { - assert!(matches!(field, FieldArgument::Gl)); - let inputs = split_inputs::(&inputs); - let pipeline_factory = || { - Pipeline::default() - .from_asm_file(PathBuf::from(&file)) - .with_prover_inputs(vec![]) - }; - let pipeline_callback = - |mut pipeline: Pipeline| -> Result<(), Vec> { - pipeline.advance_to(Stage::GeneratedWitness)?; - if let Some(backend) = prove_with { - pipeline.with_backend(backend).proof()?; - } - Ok(()) - }; - - rust_continuations(pipeline_factory, pipeline_callback, inputs.clone()).unwrap(); - } - (false, false) => { - match call_with_field!(compile_with_csv_export::( - file, - output_directory, - witness_values, - inputs, - force, - prove_with, - export_csv, - csv_mode - )) { - Ok(()) => {} - Err(errors) => { - eprintln!("Errors:"); - for e in errors { - eprintln!("{e}"); - } - } - }; - } - }, + } => { + call_with_field!(run_pil::( + file, + output_directory, + witness_values, + inputs, + force, + prove_with, + export_csv, + csv_mode, + just_execute, + continuations + )); + } Commands::Prove { file, dir, @@ -541,12 +521,25 @@ fn run_rust( ) .ok_or_else(|| vec!["could not compile rust".to_string()])?; - handle_riscv_asm( - asm_file_path.to_str().unwrap(), - &asm_contents, - inputs, - output_dir, + let pipeline_factory = || { + Pipeline::::default().from_asm_string( + asm_contents.clone(), + Some(PathBuf::from(asm_file_path.to_str().unwrap())), + ) + }; + + let pipeline_factory = make_pipeline_factory( + pipeline_factory, + inputs.clone(), + output_dir.to_path_buf(), force_overwrite, + None, // witness_values, + false, // export_csv, + CsvRenderModeCLI::Hex, // csv_mode, + ); + run( + pipeline_factory, + inputs, prove_with, just_execute, continuations, @@ -576,12 +569,25 @@ fn run_riscv_asm( ) .ok_or_else(|| vec!["could not compile RISC-V assembly".to_string()])?; - handle_riscv_asm( - asm_file_path.to_str().unwrap(), - &asm_contents, - inputs, - output_dir, + let pipeline_factory = || { + Pipeline::::default().from_asm_string( + asm_contents.clone(), + Some(PathBuf::from(asm_file_path.to_str().unwrap())), + ) + }; + + let pipeline_factory = make_pipeline_factory( + pipeline_factory, + inputs.clone(), + output_dir.to_path_buf(), force_overwrite, + None, // witness_values, + false, // export_csv, + CsvRenderModeCLI::Hex, // csv_mode, + ); + run( + pipeline_factory, + inputs, prove_with, just_execute, continuations, @@ -590,69 +596,27 @@ fn run_riscv_asm( } #[allow(clippy::too_many_arguments)] -fn handle_riscv_asm( - file_name: &str, - contents: &str, +fn make_pipeline_factory( + pipeline_factory: impl Fn() -> Pipeline, inputs: Vec, - output_dir: &Path, + output_dir: PathBuf, force_overwrite: bool, - prove_with: Option, - just_execute: bool, - continuations: bool, -) -> Result<(), Vec> { - match (just_execute, continuations) { - (true, true) => { - rust_continuations_dry_run( - Pipeline::default() - .from_asm_string(contents.to_string(), Some(PathBuf::from(file_name))), - inputs, - ); - } - (true, false) => { - let mut inputs_hash: HashMap> = HashMap::default(); - inputs_hash.insert(0u32.into(), inputs); - riscv_executor::execute::( - contents, - &inputs_hash, - &[], - riscv_executor::ExecMode::Fast, - ); - } - (false, true) => { - let pipeline_factory = || { - Pipeline::default() - .with_output(output_dir.to_path_buf(), force_overwrite) - .from_asm_string(contents.to_string(), Some(PathBuf::from(file_name))) - .with_prover_inputs(inputs.clone()) - }; - let pipeline_callback = |mut pipeline: Pipeline| -> Result<(), Vec> { - pipeline.advance_to(Stage::GeneratedWitness)?; - if let Some(backend) = prove_with { - pipeline.with_backend(backend).proof()?; - } - Ok(()) - }; + witness_values: Option, + export_csv: bool, + csv_mode: CsvRenderModeCLI, +) -> impl Fn() -> Pipeline { + move || { + let pipeline = pipeline_factory() + .with_output(output_dir.clone(), force_overwrite) + .with_prover_inputs(inputs.clone()); - rust_continuations(pipeline_factory, pipeline_callback, inputs.clone())?; - } - (false, false) => { - let mut pipeline = Pipeline::default() - .with_output(output_dir.to_path_buf(), force_overwrite) - .from_asm_string(contents.to_string(), Some(PathBuf::from(file_name))) - .with_prover_inputs(inputs) - .with_backend(BackendType::PilStarkCli); - pipeline.advance_to(Stage::GeneratedWitness).unwrap(); - if let Some(backend) = prove_with { - pipeline = pipeline.with_backend(backend); - pipeline.proof().unwrap(); - } - } + let pipeline = add_external_witness_values(pipeline, witness_values.clone()); + add_csv_settings(pipeline, export_csv, csv_mode) } - Ok(()) } #[allow(clippy::too_many_arguments)] -fn compile_with_csv_export( +fn run_pil( file: String, output_directory: String, witness_values: Option, @@ -661,33 +625,80 @@ fn compile_with_csv_export( prove_with: Option, export_csv: bool, csv_mode: CsvRenderModeCLI, + just_execute: bool, + continuations: bool, +) { + let pipeline_factory = || Pipeline::::default().from_asm_file(PathBuf::from(&file)); + let inputs = split_inputs::(&inputs); + + let pipeline_factory = make_pipeline_factory( + pipeline_factory, + inputs.clone(), + PathBuf::from(output_directory), + force, + witness_values, + export_csv, + csv_mode, + ); + if let Err(errors) = run( + pipeline_factory, + inputs, + prove_with, + just_execute, + continuations, + ) { + eprintln!("Errors:"); + for e in errors { + eprintln!("{e}"); + } + }; +} + +#[allow(clippy::too_many_arguments)] +fn run( + pipeline_factory: impl Fn() -> Pipeline, + inputs: Vec, + prove_with: Option, + just_execute: bool, + continuations: bool, ) -> Result<(), Vec> { - let external_witness_values = witness_values - .map(|csv_path| { - let csv_file = fs::File::open(csv_path).unwrap(); - let mut csv_writer = BufReader::new(&csv_file); - read_polys_csv_file::(&mut csv_writer) - }) - .unwrap_or(vec![]); - - let output_dir = Path::new(&output_directory); - - let csv_mode = match csv_mode { - CsvRenderModeCLI::SignedBase10 => CsvRenderMode::SignedBase10, - CsvRenderModeCLI::UnsignedBase10 => CsvRenderMode::UnsignedBase10, - CsvRenderModeCLI::Hex => CsvRenderMode::Hex, + let bootloader_inputs = if continuations { + rust_continuations_dry_run(pipeline_factory(), inputs.clone()) + } else { + vec![] }; - let mut pipeline = Pipeline::default() - .with_output(output_dir.to_path_buf(), force) - .from_file(PathBuf::from(file)) - .with_external_witness_values(external_witness_values) - .with_witness_csv_settings(export_csv, csv_mode) - .with_prover_inputs(split_inputs(&inputs)); - - pipeline.advance_to(Stage::GeneratedWitness).unwrap(); - prove_with.map(|backend| pipeline.with_backend(backend).proof().unwrap()); + match (just_execute, continuations) { + (true, true) => { + // Nothing to do + } + (true, false) => { + let mut inputs_hash: HashMap> = HashMap::default(); + inputs_hash.insert(0u32.into(), inputs); + riscv_executor::execute::( + &pipeline_factory().asm_string().unwrap(), + &inputs_hash, + &[], + riscv_executor::ExecMode::Fast, + ); + } + (false, true) => { + let pipeline_callback = |mut pipeline: Pipeline| -> Result<(), Vec> { + pipeline.advance_to(Stage::GeneratedWitness)?; + if let Some(backend) = prove_with { + pipeline.with_backend(backend).proof()?; + } + Ok(()) + }; + rust_continuations(pipeline_factory, pipeline_callback, bootloader_inputs)?; + } + (false, false) => { + let mut pipeline = pipeline_factory(); + pipeline.advance_to(Stage::GeneratedWitness).unwrap(); + prove_with.map(|backend| pipeline.with_backend(backend).proof().unwrap()); + } + } Ok(()) } diff --git a/riscv/src/continuations.rs b/riscv/src/continuations.rs index 08d46fdae..52176a50e 100644 --- a/riscv/src/continuations.rs +++ b/riscv/src/continuations.rs @@ -49,15 +49,12 @@ fn transposed_trace(trace: &ExecutionTrace) -> HashMap( pipeline_factory: PipelineFactory, pipeline_callback: PipelineCallback, - inputs: Vec, + bootloader_inputs: Vec>, ) -> Result<(), E> where PipelineFactory: Fn() -> Pipeline, PipelineCallback: Fn(Pipeline) -> Result<(), E>, { - log::info!("Dry running execution to collect bootloader inputs..."); - let pipeline = pipeline_factory(); - let bootloader_inputs = rust_continuations_dry_run(pipeline, inputs.clone()); let num_chunks = bootloader_inputs.len(); bootloader_inputs