Merge pull request #746 from powdr-labs/external-witgen-cli

Enable passing an externally-generated witness via the command line
This commit is contained in:
Georg Wiese
2023-11-03 15:13:22 +00:00
committed by GitHub
9 changed files with 174 additions and 54 deletions

View File

@@ -39,16 +39,24 @@ pub fn compile_pil_or_asm<T: FieldElement>(
output_dir: &Path,
force_overwrite: bool,
prove_with: Option<BackendType>,
external_witness_values: Vec<(&str, Vec<T>)>,
) -> Result<Option<CompilationResult<T>>, Vec<String>> {
if file_name.ends_with(".asm") {
compile_asm(file_name, inputs, output_dir, force_overwrite, prove_with)
compile_asm(
file_name,
inputs,
output_dir,
force_overwrite,
prove_with,
external_witness_values,
)
} else {
Ok(Some(compile_pil(
Path::new(file_name),
output_dir,
inputs_to_query_callback(inputs),
prove_with,
vec![],
external_witness_values,
)))
}
}
@@ -86,6 +94,7 @@ pub fn compile_pil_ast<T: FieldElement, Q: QueryCallback<T>>(
output_dir: &Path,
query_callback: Q,
prove_with: Option<BackendType>,
external_witness_values: Vec<(&str, Vec<T>)>,
) -> CompilationResult<T> {
// TODO exporting this to string as a hack because the parser
// is tied into the analyzer due to imports.
@@ -95,7 +104,7 @@ pub fn compile_pil_ast<T: FieldElement, Q: QueryCallback<T>>(
output_dir,
query_callback,
prove_with,
vec![],
external_witness_values,
)
}
@@ -108,6 +117,7 @@ pub fn compile_asm<T: FieldElement>(
output_dir: &Path,
force_overwrite: bool,
prove_with: Option<BackendType>,
external_witness_values: Vec<(&str, Vec<T>)>,
) -> Result<Option<CompilationResult<T>>, Vec<String>> {
let contents = fs::read_to_string(file_name).unwrap();
Ok(compile_asm_string(
@@ -117,6 +127,7 @@ pub fn compile_asm<T: FieldElement>(
output_dir,
force_overwrite,
prove_with,
external_witness_values,
)?
.1)
}
@@ -132,6 +143,7 @@ pub fn compile_asm_string<T: FieldElement>(
output_dir: &Path,
force_overwrite: bool,
prove_with: Option<BackendType>,
external_witness_values: Vec<(&str, Vec<T>)>,
) -> Result<(PathBuf, Option<CompilationResult<T>>), Vec<String>> {
let parsed = parser::parse_asm(Some(file_name), contents).unwrap_or_else(|err| {
eprintln!("Error parsing .asm file:");
@@ -179,6 +191,7 @@ pub fn compile_asm_string<T: FieldElement>(
output_dir,
inputs_to_query_callback(inputs),
prove_with,
external_witness_values,
)),
))
}

View File

@@ -13,6 +13,7 @@ pub fn verify_asm_string<T: FieldElement>(file_name: &str, contents: &str, input
&temp_dir,
true,
Some(BackendType::PilStarkCli),
vec![],
)
.unwrap();
verify(&temp_dir);

View File

@@ -25,6 +25,7 @@ fn gen_estark_proof(file_name: &str, inputs: Vec<GoldilocksField>) {
&mktemp::Temp::new_dir().unwrap(),
true,
Some(backend::BackendType::EStark),
vec![],
)
.unwrap();
}
@@ -41,6 +42,7 @@ fn gen_halo2_proof(file_name: &str, inputs: Vec<Bn254Field>) {
&mktemp::Temp::new_dir().unwrap(),
true,
Some(backend::BackendType::Halo2),
vec![],
)
.unwrap();
}

View File

@@ -45,6 +45,7 @@ fn gen_estark_proof(file_name: &str, inputs: Vec<GoldilocksField>) {
&mktemp::Temp::new_dir().unwrap(),
true,
Some(BackendType::EStark),
vec![],
)
.unwrap();
}
@@ -61,6 +62,7 @@ fn gen_halo2_proof(file_name: &str, inputs: Vec<Bn254Field>) {
&mktemp::Temp::new_dir().unwrap(),
true,
Some(BackendType::Halo2),
vec![],
)
.unwrap();
}

View File

@@ -25,6 +25,7 @@ fn gen_estark_proof(file_name: &str, inputs: Vec<GoldilocksField>) {
&mktemp::Temp::new_dir().unwrap(),
true,
Some(backend::BackendType::EStark),
vec![],
)
.unwrap();
}
@@ -41,6 +42,7 @@ fn gen_halo2_proof(file_name: &str, inputs: Vec<Bn254Field>) {
&mktemp::Temp::new_dir().unwrap(),
true,
Some(backend::BackendType::Halo2Mock),
vec![],
)
.unwrap();
}

View File

@@ -4,10 +4,13 @@ version = "0.1.0"
edition = "2021"
[dependencies]
ark-bn254 = { version = "0.4.0", default-features = false, features = ["scalar_field"] }
ark-bn254 = { version = "0.4.0", default-features = false, features = [
"scalar_field",
] }
ark-ff = "0.4.2"
num-bigint = "0.4.3"
num-traits = "0.2.15"
csv = "1.3"
[dev-dependencies]
test-log = "0.2.12"

View File

@@ -9,7 +9,9 @@ mod goldilocks;
mod serialize;
mod traits;
pub use serialize::{read_polys_file, write_polys_file};
pub use serialize::{
read_polys_csv_file, read_polys_file, write_polys_csv_file, write_polys_file, CsvRenderMode,
};
pub use bn254::Bn254Field;
pub use goldilocks::GoldilocksField;

View File

@@ -1,7 +1,82 @@
use std::io::{Read, Write};
use csv::{Reader, Writer};
use crate::{DegreeType, FieldElement};
/// How field element values are rendered when written to a CSV file.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum CsvRenderMode {
    /// Decimal, using the signed representation (values above half the
    /// modulus are shown as negative numbers).
    SignedBase10,
    /// Decimal, using the canonical unsigned integer representation.
    UnsignedBase10,
    /// Hexadecimal, prefixed with "0x".
    Hex,
}
// Header of the synthetic row-index column that write_polys_csv_file()
// prepends and read_polys_csv_file() strips again.
const ROW_NAME: &str = "Row";
/// Writes the given named columns to `file` as CSV, prepending a synthetic
/// "Row" column containing the 0-based row index.
///
/// Values are rendered according to `render_mode`. An empty `polys` slice
/// produces a header-only file instead of panicking.
///
/// # Panics
/// Panics if a column is named "Row" (it would collide with the synthetic
/// index column) or if the columns do not all have the same length.
pub fn write_polys_csv_file<T: FieldElement>(
    file: &mut impl Write,
    render_mode: CsvRenderMode,
    polys: &[(String, Vec<T>)],
) {
    let mut writer = Writer::from_writer(file);

    // Write headers, adding a "Row" column
    let mut headers = vec![ROW_NAME];
    headers.extend(polys.iter().map(|(name, _)| {
        // A user-provided "Row" column would be indistinguishable from the
        // synthetic index column when reading the file back.
        assert!(name != ROW_NAME);
        name.as_str()
    }));
    writer.write_record(&headers).unwrap();

    // `polys.first()` avoids the panic `polys[0]` would cause on an empty
    // column set; with no columns there are simply no rows to write.
    let len = polys.first().map(|(_, values)| values.len()).unwrap_or(0);
    for row_index in 0..len {
        let mut row = Vec::with_capacity(polys.len() + 1);
        row.push(format!("{}", row_index));
        for (_, values) in polys {
            // All columns must have equal length.
            assert!(values.len() == len);
            let value = match render_mode {
                CsvRenderMode::SignedBase10 => format!("{}", values[row_index]),
                CsvRenderMode::UnsignedBase10 => format!("{}", values[row_index].to_integer()),
                CsvRenderMode::Hex => format!("0x{:x}", values[row_index].to_integer()),
            };
            row.push(value);
        }
        writer.write_record(&row).unwrap();
    }

    writer.flush().unwrap();
}
/// Reads named columns back from a CSV file previously written by
/// `write_polys_csv_file`, dropping the synthetic "Row" index column.
///
/// Each field is parsed as hex when prefixed with "0x", as a negated value
/// when prefixed with "-", and as an unsigned decimal otherwise.
pub fn read_polys_csv_file<T: FieldElement>(file: &mut impl Read) -> Vec<(String, Vec<T>)> {
    let mut reader = Reader::from_reader(file);

    // One (name, values) accumulator per CSV header.
    let mut columns: Vec<(String, Vec<T>)> = reader
        .headers()
        .unwrap()
        .iter()
        .map(|header| (header.to_string(), Vec::new()))
        .collect();

    for result in reader.records() {
        let record = result.unwrap();
        // The csv reader guarantees each record has as many fields as the
        // header, so zipping pairs every field with its column.
        for (column, field) in columns.iter_mut().zip(record.iter()) {
            let parsed = if let Some(hex) = field.strip_prefix("0x") {
                T::from_str_radix(hex, 16).unwrap()
            } else if let Some(abs) = field.strip_prefix('-') {
                -T::from_str(abs)
            } else {
                T::from_str(field)
            };
            column.1.push(parsed);
        }
    }

    // Remove "Row" column, which was added by write_polys_csv_file()
    columns
        .into_iter()
        .filter(|(name, _)| name != ROW_NAME)
        .collect()
}
/// Integer division of `num` by `div`, rounding up.
///
/// Implemented via quotient and remainder instead of `(num + div - 1) / div`,
/// which overflows when `num + div - 1` exceeds `usize::MAX`.
fn ceil_div(num: usize, div: usize) -> usize {
    let quotient = num / div;
    // Round up only when the division left a remainder.
    if num % div == 0 {
        quotient
    } else {
        quotient + 1
    }
}
@@ -57,15 +132,19 @@ mod tests {
use super::*;
use test_log::test;
/// Builds a pair of fixture columns: "a" holds 0..16 ascending and "b"
/// holds -16..0, exercising both positive and negative field values.
fn test_polys() -> Vec<(&'static str, Vec<Bn254Field>)> {
    let ascending: Vec<Bn254Field> = (0..16).map(Bn254Field::from).collect();
    let negative: Vec<Bn254Field> = (-16..0).map(Bn254Field::from).collect();
    vec![("a", ascending), ("b", negative)]
}
#[test]
fn write_read() {
let mut buf: Vec<u8> = vec![];
let degree = 4;
let polys = vec![
("a", vec![Bn254Field::from(0); degree]),
("b", vec![Bn254Field::from(1); degree]),
];
let polys = test_polys();
let degree = polys[0].1.len();
write_polys_file(&mut buf, degree as u64, &polys);
let (read_polys, read_degree) =
@@ -74,4 +153,24 @@ mod tests {
assert_eq!(read_polys, polys);
assert_eq!(read_degree, degree as u64);
}
#[test]
fn write_read_csv() {
    // The CSV functions expect owned column names, while test_polys()
    // returns &'static str — convert once up front.
    let polys = test_polys()
        .into_iter()
        .map(|(name, values)| (name.to_string(), values))
        .collect::<Vec<_>>();

    // A write/read round-trip must be lossless in every render mode,
    // including the signed representation of the negative "b" values.
    for render_mode in &[
        CsvRenderMode::SignedBase10,
        CsvRenderMode::UnsignedBase10,
        CsvRenderMode::Hex,
    ] {
        let mut buf: Vec<u8> = vec![];
        write_polys_csv_file(&mut buf, *render_mode, &polys);
        let read_polys = read_polys_csv_file::<Bn254Field>(&mut Cursor::new(buf));
        assert_eq!(read_polys, polys);
    }
}
}

View File

@@ -9,10 +9,11 @@ use compiler::{compile_asm_string, compile_pil_or_asm, write_proving_results_to_
use env_logger::fmt::Color;
use env_logger::{Builder, Target};
use log::LevelFilter;
use number::{read_polys_csv_file, write_polys_csv_file, CsvRenderMode};
use number::{Bn254Field, FieldElement, GoldilocksField};
use riscv::{compile_riscv_asm, compile_rust};
use std::io::{self, BufWriter, Read};
use std::{borrow::Cow, collections::HashSet, fs, io::Write, path::Path};
use std::io::{self, BufReader, BufWriter, Read};
use std::{borrow::Cow, fs, io::Write, path::Path};
use strum::{Display, EnumString, EnumVariantNames};
#[derive(Clone, EnumString, EnumVariantNames, Display)]
@@ -24,7 +25,7 @@ pub enum FieldArgument {
}
#[derive(Clone, EnumString, EnumVariantNames, Display)]
pub enum CsvRenderMode {
pub enum CsvRenderModeCLI {
#[strum(serialize = "i")]
SignedBase10,
#[strum(serialize = "ui")]
@@ -63,6 +64,10 @@ enum Commands {
#[arg(default_value_t = String::from("."))]
output_directory: String,
/// Path to a CSV file containing externally computed witness values.
#[arg(short, long)]
witness_values: Option<String>,
/// Comma-separated list of free inputs (numbers). Assumes queries to have the form
/// ("input", <index>).
#[arg(short, long)]
@@ -86,9 +91,9 @@ enum Commands {
/// How to render field elements in the csv file
#[arg(long)]
#[arg(default_value_t = CsvRenderMode::Hex)]
#[arg(value_parser = clap_enum_variants!(CsvRenderMode))]
csv_mode: CsvRenderMode,
#[arg(default_value_t = CsvRenderModeCLI::Hex)]
#[arg(value_parser = clap_enum_variants!(CsvRenderModeCLI))]
csv_mode: CsvRenderModeCLI,
},
/// Compiles (no-std) rust code to riscv assembly, then to powdr assembly
/// and finally to PIL and generates fixed and witness columns.
@@ -366,6 +371,7 @@ fn run_command(command: Commands) {
file,
field,
output_directory,
witness_values,
inputs,
force,
prove_with,
@@ -375,6 +381,7 @@ fn run_command(command: Commands) {
match call_with_field!(compile_with_csv_export::<field>(
file,
output_directory,
witness_values,
inputs,
force,
prove_with,
@@ -447,6 +454,7 @@ fn run_rust<F: FieldElement>(
output_dir,
force_overwrite,
prove_with,
vec![],
)?;
Ok(())
}
@@ -476,25 +484,41 @@ fn run_riscv_asm<F: FieldElement>(
output_dir,
force_overwrite,
prove_with,
vec![],
)?;
Ok(())
}
#[allow(clippy::too_many_arguments)]
fn compile_with_csv_export<T: FieldElement>(
file: String,
output_directory: String,
witness_values: Option<String>,
inputs: String,
force: bool,
prove_with: Option<BackendType>,
export_csv: bool,
csv_mode: CsvRenderMode,
csv_mode: CsvRenderModeCLI,
) -> Result<(), Vec<String>> {
let external_witness_values = witness_values
.map(|csv_path| {
let csv_file = fs::File::open(csv_path).unwrap();
let mut csv_writer = BufReader::new(&csv_file);
read_polys_csv_file::<T>(&mut csv_writer)
})
.unwrap_or(vec![]);
// Convert Vec<(String, Vec<T>)> to Vec<(&str, Vec<T>)>
let (strings, values): (Vec<_>, Vec<_>) = external_witness_values.into_iter().unzip();
let external_witness_values = strings.iter().map(AsRef::as_ref).zip(values).collect();
let result = compile_pil_or_asm::<T>(
&file,
split_inputs(&inputs),
Path::new(&output_directory),
force,
prove_with,
external_witness_values,
)?;
if export_csv {
@@ -517,7 +541,7 @@ fn export_columns_to_csv<T: FieldElement>(
fixed: Vec<(String, Vec<T>)>,
witness: Option<Vec<(String, Vec<T>)>>,
csv_path: &Path,
render_mode: CsvRenderMode,
render_mode: CsvRenderModeCLI,
) {
let columns = fixed
.into_iter()
@@ -527,42 +551,13 @@ fn export_columns_to_csv<T: FieldElement>(
let mut csv_file = fs::File::create(csv_path).unwrap();
let mut csv_writer = BufWriter::new(&mut csv_file);
// Remove prefixes (e.g. "Assembly.") if column names are still unique after
let headers = columns
.iter()
.map(|(header, _)| header.to_owned())
.collect::<Vec<_>>();
let headers_without_prefix = headers
.iter()
.map(|header| {
let suffix_start = header.rfind('.').map(|i| i + 1).unwrap_or(0);
header[suffix_start..].to_owned()
})
.collect::<Vec<_>>();
let unique_elements = headers_without_prefix.iter().collect::<HashSet<_>>();
let headers = if unique_elements.len() == headers.len() {
headers_without_prefix
} else {
headers
let render_mode = match render_mode {
CsvRenderModeCLI::SignedBase10 => CsvRenderMode::SignedBase10,
CsvRenderModeCLI::UnsignedBase10 => CsvRenderMode::UnsignedBase10,
CsvRenderModeCLI::Hex => CsvRenderMode::Hex,
};
writeln!(csv_writer, "Row,{}", headers.join(",")).unwrap();
// Write the column values
let row_count = columns[0].1.len();
for row_index in 0..row_count {
// format!("{}", values[row_index].to_integer()
let row_values: Vec<String> = columns
.iter()
.map(|(_, values)| match render_mode {
CsvRenderMode::SignedBase10 => format!("{}", values[row_index]),
CsvRenderMode::UnsignedBase10 => format!("{}", values[row_index].to_integer()),
CsvRenderMode::Hex => format!("0x{:x}", values[row_index].to_integer()),
})
.collect();
writeln!(csv_writer, "{row_index},{}", row_values.join(",")).unwrap();
}
write_polys_csv_file(&mut csv_writer, render_mode, &columns);
}
fn read_and_prove<T: FieldElement>(
@@ -613,7 +608,7 @@ fn optimize_and_output<T: FieldElement>(file: &str) {
#[cfg(test)]
mod test {
use crate::{run_command, Commands, CsvRenderMode, FieldArgument};
use crate::{run_command, Commands, CsvRenderModeCLI, FieldArgument};
use backend::BackendType;
#[test]
@@ -629,11 +624,12 @@ mod test {
file,
field: FieldArgument::Bn254,
output_directory: output_dir_str.clone(),
witness_values: None,
inputs: "3,2,1,2".into(),
force: false,
prove_with: Some(BackendType::PilStarkCli),
export_csv: true,
csv_mode: CsvRenderMode::Hex,
csv_mode: CsvRenderModeCLI::Hex,
};
run_command(pil_command);