From 9d518bc5be697d9730365d9c6fe230392e3bb1fa Mon Sep 17 00:00:00 2001 From: Georg Wiese Date: Thu, 2 Nov 2023 17:39:01 +0000 Subject: [PATCH] Enable passing externally-generated witness via command line --- compiler/src/lib.rs | 19 ++++++- compiler/src/verify.rs | 1 + compiler/tests/asm.rs | 2 + compiler/tests/pil.rs | 2 + compiler/tests/powdr_std.rs | 2 + number/Cargo.toml | 5 +- number/src/lib.rs | 4 +- number/src/serialize.rs | 109 ++++++++++++++++++++++++++++++++++-- powdr_cli/src/main.rs | 84 +++++++++++++-------------- 9 files changed, 174 insertions(+), 54 deletions(-) diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 5bb8b2e2d..47b2294f2 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -39,16 +39,24 @@ pub fn compile_pil_or_asm( output_dir: &Path, force_overwrite: bool, prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, ) -> Result>, Vec> { if file_name.ends_with(".asm") { - compile_asm(file_name, inputs, output_dir, force_overwrite, prove_with) + compile_asm( + file_name, + inputs, + output_dir, + force_overwrite, + prove_with, + external_witness_values, + ) } else { Ok(Some(compile_pil( Path::new(file_name), output_dir, inputs_to_query_callback(inputs), prove_with, - vec![], + external_witness_values, ))) } } @@ -86,6 +94,7 @@ pub fn compile_pil_ast>( output_dir: &Path, query_callback: Q, prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, ) -> CompilationResult { // TODO exporting this to string as a hack because the parser // is tied into the analyzer due to imports. @@ -95,7 +104,7 @@ pub fn compile_pil_ast>( output_dir, query_callback, prove_with, - vec![], + external_witness_values, ) } @@ -108,6 +117,7 @@ pub fn compile_asm( output_dir: &Path, force_overwrite: bool, prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, ) -> Result>, Vec> { let contents = fs::read_to_string(file_name).unwrap(); Ok(compile_asm_string( @@ -117,6 +127,7 @@ pub fn compile_asm( output_dir, force_overwrite, prove_with, + external_witness_values, )? .1) } @@ -132,6 +143,7 @@ pub fn compile_asm_string( output_dir: &Path, force_overwrite: bool, prove_with: Option, + external_witness_values: Vec<(&str, Vec)>, ) -> Result<(PathBuf, Option>), Vec> { let parsed = parser::parse_asm(Some(file_name), contents).unwrap_or_else(|err| { eprintln!("Error parsing .asm file:"); @@ -179,6 +191,7 @@ pub fn compile_asm_string( output_dir, inputs_to_query_callback(inputs), prove_with, + external_witness_values, )), )) } diff --git a/compiler/src/verify.rs b/compiler/src/verify.rs index b04a96659..a3272e1c4 100644 --- a/compiler/src/verify.rs +++ b/compiler/src/verify.rs @@ -13,6 +13,7 @@ pub fn verify_asm_string(file_name: &str, contents: &str, input &temp_dir, true, Some(BackendType::PilStarkCli), + vec![], ) .unwrap(); verify(&temp_dir); diff --git a/compiler/tests/asm.rs b/compiler/tests/asm.rs index 12565e4ec..03d458274 100644 --- a/compiler/tests/asm.rs +++ b/compiler/tests/asm.rs @@ -25,6 +25,7 @@ fn gen_estark_proof(file_name: &str, inputs: Vec) { &mktemp::Temp::new_dir().unwrap(), true, Some(backend::BackendType::EStark), + vec![], ) .unwrap(); } @@ -41,6 +42,7 @@ fn gen_halo2_proof(file_name: &str, inputs: Vec) { &mktemp::Temp::new_dir().unwrap(), true, Some(backend::BackendType::Halo2), + vec![], ) .unwrap(); } diff --git a/compiler/tests/pil.rs b/compiler/tests/pil.rs index a7c5befd8..1c39d8eea 100644 --- a/compiler/tests/pil.rs +++ b/compiler/tests/pil.rs @@ -45,6 +45,7 @@ fn gen_estark_proof(file_name: &str, inputs: Vec) { &mktemp::Temp::new_dir().unwrap(), true, Some(BackendType::EStark), + vec![], ) .unwrap(); } @@ -61,6 +62,7 @@ fn gen_halo2_proof(file_name: &str, inputs: Vec) { &mktemp::Temp::new_dir().unwrap(), true, Some(BackendType::Halo2), + vec![], ) .unwrap(); } diff --git a/compiler/tests/powdr_std.rs b/compiler/tests/powdr_std.rs index 1270f5e37..dce760aae 100644 --- a/compiler/tests/powdr_std.rs +++ b/compiler/tests/powdr_std.rs @@ -25,6 +25,7 @@ fn gen_estark_proof(file_name: &str, inputs: Vec) { &mktemp::Temp::new_dir().unwrap(), true, Some(backend::BackendType::EStark), + vec![], ) .unwrap(); } @@ -41,6 +42,7 @@ fn gen_halo2_proof(file_name: &str, inputs: Vec) { &mktemp::Temp::new_dir().unwrap(), true, Some(backend::BackendType::Halo2Mock), + vec![], ) .unwrap(); } diff --git a/number/Cargo.toml b/number/Cargo.toml index 2c6dd1f48..b2c6b7dbe 100644 --- a/number/Cargo.toml +++ b/number/Cargo.toml @@ -4,10 +4,13 @@ version = "0.1.0" edition = "2021" [dependencies] -ark-bn254 = { version = "0.4.0", default-features = false, features = ["scalar_field"] } +ark-bn254 = { version = "0.4.0", default-features = false, features = [ + "scalar_field", +] } ark-ff = "0.4.2" num-bigint = "0.4.3" num-traits = "0.2.15" +csv = "1.3" [dev-dependencies] test-log = "0.2.12" diff --git a/number/src/lib.rs b/number/src/lib.rs index 060552bd8..9270a79ca 100644 --- a/number/src/lib.rs +++ b/number/src/lib.rs @@ -9,7 +9,9 @@ mod goldilocks; mod serialize; mod traits; -pub use serialize::{read_polys_file, write_polys_file}; +pub use serialize::{ + read_polys_csv_file, read_polys_file, write_polys_csv_file, write_polys_file, CsvRenderMode, +}; pub use bn254::Bn254Field; pub use goldilocks::GoldilocksField; diff --git a/number/src/serialize.rs b/number/src/serialize.rs index ab5344cc7..bd9c6e059 100644 --- a/number/src/serialize.rs +++ b/number/src/serialize.rs @@ -1,7 +1,82 @@ use std::io::{Read, Write}; +use csv::{Reader, Writer}; + use crate::{DegreeType, FieldElement}; +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum CsvRenderMode { + SignedBase10, + UnsignedBase10, + Hex, +} + +const ROW_NAME: &str = "Row"; + +pub fn write_polys_csv_file( + file: &mut impl Write, + render_mode: CsvRenderMode, + polys: &[(String, Vec)], +) { + let mut writer = Writer::from_writer(file); + + // Write headers, adding a "Row" column + let mut headers = vec![ROW_NAME]; + headers.extend(polys.iter().map(|(name, _)| { + assert!(name != ROW_NAME); + name.as_str() + })); + writer.write_record(&headers).unwrap(); + + let len = polys[0].1.len(); + for row_index in 0..len { + let mut row = Vec::new(); + row.push(format!("{}", row_index)); + for (_, values) in polys { + assert!(values.len() == len); + let value = match render_mode { + CsvRenderMode::SignedBase10 => format!("{}", values[row_index]), + CsvRenderMode::UnsignedBase10 => format!("{}", values[row_index].to_integer()), + CsvRenderMode::Hex => format!("0x{:x}", values[row_index].to_integer()), + }; + row.push(value); + } + writer.write_record(&row).unwrap(); + } + + writer.flush().unwrap(); +} + +pub fn read_polys_csv_file(file: &mut impl Read) -> Vec<(String, Vec)> { + let mut reader = Reader::from_reader(file); + let headers = reader.headers().unwrap(); + + let mut polys = headers + .iter() + .map(|name| (name.to_string(), Vec::new())) + .collect::>(); + + for result in reader.records() { + let record = result.unwrap(); + for (idx, value) in record.iter().enumerate() { + let value = if let Some(value) = value.strip_prefix("0x") { + T::from_str_radix(value, 16).unwrap() + } else if let Some(value) = value.strip_prefix('-') { + -T::from_str(value) + } else { + T::from_str(value) + }; + polys[idx].1.push(value); + } + } + + // Remove "Row" column, which was added by write_polys_csv_file() + polys + .into_iter() + .filter(|(name, _)| name != ROW_NAME) + .collect() +} + fn ceil_div(num: usize, div: usize) -> usize { (num + div - 1) / div } @@ -57,15 +132,19 @@ mod tests { use super::*; use test_log::test; + fn test_polys() -> Vec<(&'static str, Vec)> { + vec![ + ("a", (0..16).map(Bn254Field::from).collect()), + ("b", (-16..0).map(Bn254Field::from).collect()), + ] + } + #[test] fn write_read() { let mut buf: Vec = vec![]; - let degree = 4; - let polys = vec![ - ("a", vec![Bn254Field::from(0); degree]), - ("b", vec![Bn254Field::from(1); degree]), - ]; + let polys = test_polys(); + let degree = polys[0].1.len(); write_polys_file(&mut buf, degree as u64, &polys); let (read_polys, read_degree) = @@ -74,4 +153,24 @@ mod tests { assert_eq!(read_polys, polys); assert_eq!(read_degree, degree as u64); } + + #[test] + fn write_read_csv() { + let polys = test_polys() + .into_iter() + .map(|(name, values)| (name.to_string(), values)) + .collect::>(); + + for render_mode in &[ + CsvRenderMode::SignedBase10, + CsvRenderMode::UnsignedBase10, + CsvRenderMode::Hex, + ] { + let mut buf: Vec = vec![]; + write_polys_csv_file(&mut buf, *render_mode, &polys); + let read_polys = read_polys_csv_file::(&mut Cursor::new(buf)); + + assert_eq!(read_polys, polys); + } + } } diff --git a/powdr_cli/src/main.rs b/powdr_cli/src/main.rs index d2e5fb214..12181d068 100644 --- a/powdr_cli/src/main.rs +++ b/powdr_cli/src/main.rs @@ -9,10 +9,11 @@ use compiler::{compile_asm_string, compile_pil_or_asm, write_proving_results_to_ use env_logger::fmt::Color; use env_logger::{Builder, Target}; use log::LevelFilter; +use number::{read_polys_csv_file, write_polys_csv_file, CsvRenderMode}; use number::{Bn254Field, FieldElement, GoldilocksField}; use riscv::{compile_riscv_asm, compile_rust}; -use std::io::{self, BufWriter, Read}; -use std::{borrow::Cow, collections::HashSet, fs, io::Write, path::Path}; +use std::io::{self, BufReader, BufWriter, Read}; +use std::{borrow::Cow, fs, io::Write, path::Path}; use strum::{Display, EnumString, EnumVariantNames}; #[derive(Clone, EnumString, EnumVariantNames, Display)] @@ -24,7 +25,7 @@ pub enum FieldArgument { } #[derive(Clone, EnumString, EnumVariantNames, Display)] -pub enum CsvRenderMode { +pub enum CsvRenderModeCLI { #[strum(serialize = "i")] SignedBase10, #[strum(serialize = "ui")] @@ -63,6 +64,10 @@ enum Commands { #[arg(default_value_t = String::from("."))] output_directory: String, + /// Path to a CSV file containing externally computed witness values. + #[arg(short, long)] + witness_values: Option, + /// Comma-separated list of free inputs (numbers). Assumes queries to have the form /// ("input", ). #[arg(short, long)] @@ -86,9 +91,9 @@ enum Commands { /// How to render field elements in the csv file #[arg(long)] - #[arg(default_value_t = CsvRenderMode::Hex)] - #[arg(value_parser = clap_enum_variants!(CsvRenderMode))] - csv_mode: CsvRenderMode, + #[arg(default_value_t = CsvRenderModeCLI::Hex)] + #[arg(value_parser = clap_enum_variants!(CsvRenderModeCLI))] + csv_mode: CsvRenderModeCLI, }, /// Compiles (no-std) rust code to riscv assembly, then to powdr assembly /// and finally to PIL and generates fixed and witness columns. @@ -366,6 +371,7 @@ fn run_command(command: Commands) { file, field, output_directory, + witness_values, inputs, force, prove_with, @@ -375,6 +381,7 @@ fn run_command(command: Commands) { match call_with_field!(compile_with_csv_export::( file, output_directory, + witness_values, inputs, force, prove_with, @@ -447,6 +454,7 @@ fn run_rust( output_dir, force_overwrite, prove_with, + vec![], )?; Ok(()) } @@ -476,25 +484,41 @@ fn run_riscv_asm( output_dir, force_overwrite, prove_with, + vec![], )?; Ok(()) } +#[allow(clippy::too_many_arguments)] fn compile_with_csv_export( file: String, output_directory: String, + witness_values: Option, inputs: String, force: bool, prove_with: Option, export_csv: bool, - csv_mode: CsvRenderMode, + csv_mode: CsvRenderModeCLI, ) -> Result<(), Vec> { + let external_witness_values = witness_values + .map(|csv_path| { + let csv_file = fs::File::open(csv_path).unwrap(); + let mut csv_writer = BufReader::new(&csv_file); + read_polys_csv_file::(&mut csv_writer) + }) + .unwrap_or(vec![]); + + // Convert Vec<(String, Vec)> to Vec<(&str, Vec)> + let (strings, values): (Vec<_>, Vec<_>) = external_witness_values.into_iter().unzip(); + let external_witness_values = strings.iter().map(AsRef::as_ref).zip(values).collect(); + let result = compile_pil_or_asm::( &file, split_inputs(&inputs), Path::new(&output_directory), force, prove_with, + external_witness_values, )?; if export_csv { @@ -517,7 +541,7 @@ fn export_columns_to_csv( fixed: Vec<(String, Vec)>, witness: Option)>>, csv_path: &Path, - render_mode: CsvRenderMode, + render_mode: CsvRenderModeCLI, ) { let columns = fixed .into_iter() @@ -527,42 +551,13 @@ fn export_columns_to_csv( let mut csv_file = fs::File::create(csv_path).unwrap(); let mut csv_writer = BufWriter::new(&mut csv_file); - // Remove prefixes (e.g. "Assembly.") if column names are still unique after - let headers = columns - .iter() - .map(|(header, _)| header.to_owned()) - .collect::>(); - let headers_without_prefix = headers - .iter() - .map(|header| { - let suffix_start = header.rfind('.').map(|i| i + 1).unwrap_or(0); - header[suffix_start..].to_owned() - }) - .collect::>(); - - let unique_elements = headers_without_prefix.iter().collect::>(); - let headers = if unique_elements.len() == headers.len() { - headers_without_prefix - } else { - headers + let render_mode = match render_mode { + CsvRenderModeCLI::SignedBase10 => CsvRenderMode::SignedBase10, + CsvRenderModeCLI::UnsignedBase10 => CsvRenderMode::UnsignedBase10, + CsvRenderModeCLI::Hex => CsvRenderMode::Hex, }; - writeln!(csv_writer, "Row,{}", headers.join(",")).unwrap(); - - // Write the column values - let row_count = columns[0].1.len(); - for row_index in 0..row_count { - // format!("{}", values[row_index].to_integer() - let row_values: Vec = columns - .iter() - .map(|(_, values)| match render_mode { - CsvRenderMode::SignedBase10 => format!("{}", values[row_index]), - CsvRenderMode::UnsignedBase10 => format!("{}", values[row_index].to_integer()), - CsvRenderMode::Hex => format!("0x{:x}", values[row_index].to_integer()), - }) - .collect(); - writeln!(csv_writer, "{row_index},{}", row_values.join(",")).unwrap(); - } + write_polys_csv_file(&mut csv_writer, render_mode, &columns); } fn read_and_prove( @@ -613,7 +608,7 @@ fn optimize_and_output(file: &str) { #[cfg(test)] mod test { - use crate::{run_command, Commands, CsvRenderMode, FieldArgument}; + use crate::{run_command, Commands, CsvRenderModeCLI, FieldArgument}; use backend::BackendType; #[test] @@ -629,11 +624,12 @@ mod test { file, field: FieldArgument::Bn254, output_directory: output_dir_str.clone(), + witness_values: None, inputs: "3,2,1,2".into(), force: false, prove_with: Some(BackendType::PilStarkCli), export_csv: true, - csv_mode: CsvRenderMode::Hex, + csv_mode: CsvRenderModeCLI::Hex, }; run_command(pil_command);