Fix CSV export #380 (#380)

This commit is contained in:
Georg Wiese
2023-07-20 10:58:17 +02:00
committed by GitHub
parent 6780ccb218
commit 4fdea44740
5 changed files with 151 additions and 56 deletions

View File

@@ -30,23 +30,23 @@ pub fn no_callback<T>() -> Option<fn(&str) -> Option<T>> {
/// Compiles a .pil or .asm file and runs witness generation.
/// If the file ends in .asm, converts it to .pil first.
/// Returns the compilation result if any compilation took place.
pub fn compile_pil_or_asm<T: FieldElement>(
file_name: &str,
inputs: Vec<T>,
output_dir: &Path,
force_overwrite: bool,
prove_with: Option<Backend>,
) -> PathBuf {
) -> Option<CompilationResult<T>> {
if file_name.ends_with(".asm") {
compile_asm(file_name, inputs, output_dir, force_overwrite, prove_with)
} else {
compile_pil(
Some(compile_pil(
Path::new(file_name),
output_dir,
Some(inputs_to_query_callback(inputs)),
prove_with,
);
PathBuf::from(file_name)
))
}
}
@@ -56,14 +56,14 @@ pub fn analyze_pil<T: FieldElement>(pil_file: &Path) -> Analyzed<T> {
/// Compiles a .pil file to its json form and also tries to generate
/// constants and committed polynomials.
/// @returns true if all committed/witness and constant/fixed polynomials
/// could be generated.
/// @returns a compilation result, containing witness and fixed columns
/// if they could be successfully generated.
pub fn compile_pil<T: FieldElement, QueryCallback>(
pil_file: &Path,
output_dir: &Path,
query_callback: Option<QueryCallback>,
prove_with: Option<Backend>,
) -> bool
) -> CompilationResult<T>
where
QueryCallback: FnMut(&str) -> Option<T> + Sync + Send,
{
@@ -76,13 +76,15 @@ where
)
}
/// Compiles a given PIL and tries to generate fixed and witness columns.
/// @returns a compilation result, containing witness and fixed columns
pub fn compile_pil_ast<T: FieldElement, QueryCallback>(
pil: &PILFile<T>,
file_name: &OsStr,
output_dir: &Path,
query_callback: Option<QueryCallback>,
prove_with: Option<Backend>,
) -> bool
) -> CompilationResult<T>
where
QueryCallback: FnMut(&str) -> Option<T> + Sync + Send,
{
@@ -99,13 +101,14 @@ where
/// Compiles a .asm file, outputs the PIL on stdout and tries to generate
/// fixed and witness columns.
/// @returns a compilation result if any compilation was done.
pub fn compile_asm<T: FieldElement>(
file_name: &str,
inputs: Vec<T>,
output_dir: &Path,
force_overwrite: bool,
prove_with: Option<Backend>,
) -> PathBuf {
) -> Option<CompilationResult<T>> {
let contents = fs::read_to_string(file_name).unwrap();
compile_asm_string(
file_name,
@@ -115,12 +118,13 @@ pub fn compile_asm<T: FieldElement>(
force_overwrite,
prove_with,
)
.1
}
/// Compiles the contents of a .asm file, outputs the PIL on stdout and tries to generate
/// fixed and witness columns.
///
/// Returns the relative pil file name.
/// Returns the relative pil file name and the compilation result if any compilation was done.
pub fn compile_asm_string<T: FieldElement>(
file_name: &str,
contents: &str,
@@ -128,7 +132,7 @@ pub fn compile_asm_string<T: FieldElement>(
output_dir: &Path,
force_overwrite: bool,
prove_with: Option<Backend>,
) -> PathBuf {
) -> (PathBuf, Option<CompilationResult<T>>) {
let parsed = parser::parse_asm(Some(file_name), contents).unwrap_or_else(|err| {
eprintln!("Error parsing .asm file:");
err.output_to_stderr();
@@ -149,28 +153,41 @@ pub fn compile_asm_string<T: FieldElement>(
"Target file {} already exists. Not overwriting.",
pil_file_path.to_str().unwrap()
);
return pil_file_path;
return (pil_file_path, None);
}
fs::write(pil_file_path.clone(), format!("{pil}")).unwrap();
compile_pil_ast(
&pil,
pil_file_path.file_name().unwrap(),
output_dir,
Some(inputs_to_query_callback(inputs)),
prove_with,
);
pil_file_path
let pil_file_name = pil_file_path.file_name().unwrap();
(
pil_file_path.clone(),
Some(compile_pil_ast(
&pil,
pil_file_name,
output_dir,
Some(inputs_to_query_callback(inputs)),
prove_with,
)),
)
}
pub struct CompilationResult<T: FieldElement> {
/// Whether all committed/witness and constant/fixed polynomials could be generated.
pub success: bool,
/// Constant columns, potentially incomplete (if success is false)
pub constants: Vec<(String, Vec<T>)>,
/// Witness columns, potentially None (if success is false)
pub witness: Option<Vec<(String, Vec<T>)>>,
}
/// Optimizes a given pil and tries to generate constants and committed polynomials.
/// @returns a compilation result, containing witness and fixed columns, if successful.
fn compile<T: FieldElement, QueryCallback>(
analyzed: Analyzed<T>,
file_name: &OsStr,
output_dir: &Path,
query_callback: Option<QueryCallback>,
prove_with: Option<Backend>,
) -> bool
) -> CompilationResult<T>
where
QueryCallback: FnMut(&str) -> Option<T> + Send + Sync,
{
@@ -187,6 +204,8 @@ where
log::info!("Evaluating fixed columns...");
let (constants, degree) = constant_evaluator::generate(&analyzed);
log::info!("Took {}", start.elapsed().as_secs_f32());
let mut witness = None;
if analyzed.constant_count() == constants.len() {
write_constants_to_fs(&constants, output_dir, degree);
log::info!("Generated constants.");
@@ -200,10 +219,17 @@ where
if let Some(Backend::Halo2) = prove_with {
let degree = usize::BITS - degree.leading_zeros() + 1;
let params = halo2::kzg_params(degree as usize);
let proof = halo2::prove_ast(&analyzed, constants, commits, params);
let proof = halo2::prove_ast(&analyzed, constants.clone(), commits.clone(), params);
write_proof_to_fs(&proof, output_dir);
log::info!("Generated proof.");
}
witness = Some(
commits
.into_iter()
.map(|(name, c)| (name.to_owned(), c))
.collect(),
);
} else {
log::warn!("Not writing constants.bin because not all declared constants are defined (or there are none).");
success = false;
@@ -212,7 +238,16 @@ where
write_compiled_json_to_fs(&json_out, file_name, output_dir);
log::info!("Compiled PIL source code.");
success
let constants = constants
.into_iter()
.map(|(name, c)| (name.to_owned(), c))
.collect();
CompilationResult {
success,
constants,
witness,
}
}
pub fn inputs_to_query_callback<T: FieldElement>(inputs: Vec<T>) -> impl Fn(&str) -> Option<T> {

View File

@@ -5,7 +5,7 @@ use crate::compile_asm_string;
pub fn verify_asm_string<T: FieldElement>(file_name: &str, contents: &str, inputs: Vec<T>) {
let temp_dir = mktemp::Temp::new_dir().unwrap();
let pil_file_path = compile_asm_string(file_name, contents, inputs, &temp_dir, true, None);
let pil_file_path = compile_asm_string(file_name, contents, inputs, &temp_dir, true, None).0;
let pil_file_name = pil_file_path.file_name().unwrap().to_string_lossy();
verify(&pil_file_name, &temp_dir);
}

View File

@@ -10,12 +10,7 @@ pub fn verify_pil(file_name: &str, query_callback: Option<fn(&str) -> Option<Gol
.unwrap();
let temp_dir = mktemp::Temp::new_dir().unwrap();
assert!(compiler::compile_pil(
&input_file,
&temp_dir,
query_callback,
None,
));
assert!(compiler::compile_pil(&input_file, &temp_dir, query_callback, None,).success);
compiler::verify(file_name, &temp_dir);
}

View File

@@ -16,6 +16,9 @@ backend = { path = "../backend" }
pilopt = { path = "../pilopt" }
strum = { version = "0.24.1", features = ["derive"] }
[dev-dependencies]
tempfile = "3.6"
[[bin]]
name = "powdr"
path = "src/main.rs"

View File

@@ -242,6 +242,10 @@ fn main() {
.init();
let command = Cli::parse().command;
run_command(command);
}
fn run_command(command: Commands) {
match command {
Commands::Rust {
file,
@@ -301,22 +305,15 @@ fn main() {
export_csv,
csv_mode,
} => {
let pil_filename = call_with_field!(compile_pil_or_asm::<field>(
&file,
split_inputs(&inputs),
Path::new(&output_directory),
call_with_field!(compile_with_csv_export::<field>(
file,
output_directory,
inputs,
force,
prove_with
prove_with,
export_csv,
csv_mode
));
if export_csv {
let pil = Path::new(&pil_filename);
let dir = Path::new(&output_directory);
let csv_path = dir.join("columns.csv");
call_with_field!(export_columns_to_csv::<field>(
pil, dir, &csv_path, csv_mode
));
}
}
Commands::Prove {
file,
@@ -374,23 +371,47 @@ fn write_params_to_fs(params: &[u8], output_dir: &Path) {
log::info!("Wrote params.bin.");
}
fn compile_with_csv_export<T: FieldElement>(
file: String,
output_directory: String,
inputs: String,
force: bool,
prove_with: Option<Backend>,
export_csv: bool,
csv_mode: CsvRenderMode,
) {
let result = compile_pil_or_asm::<T>(
&file,
split_inputs(&inputs),
Path::new(&output_directory),
force,
prove_with,
);
if export_csv {
// Compilation result is None if the ASM file has not been compiled
// (e.g. it has been compiled before and the force flag is not set)
if let Some(compilation_result) = result {
let csv_path = Path::new(&output_directory).join("columns.csv");
export_columns_to_csv::<T>(
compilation_result.constants,
compilation_result.witness,
&csv_path,
csv_mode,
);
}
}
}
fn export_columns_to_csv<T: FieldElement>(
file: &Path,
dir: &Path,
fixed: Vec<(String, Vec<T>)>,
witness: Option<Vec<(String, Vec<T>)>>,
csv_path: &Path,
render_mode: CsvRenderMode,
) {
let pil = compiler::analyze_pil::<T>(file);
let fixed = compiler::util::read_fixed(&pil, dir);
let witness = compiler::util::read_witness(&pil, dir);
assert_eq!(fixed.1, witness.1);
let columns = fixed
.0
.into_iter()
.chain(witness.0.into_iter())
.map(|(name, values)| (name.to_owned(), values))
.chain(witness.unwrap_or(vec![]).into_iter())
.collect::<Vec<_>>();
let mut csv_file = fs::File::create(csv_path).unwrap();
@@ -480,3 +501,44 @@ fn optimize_and_output<T: FieldElement>(file: &str) {
pilopt::optimize(compiler::analyze_pil::<T>(Path::new(file)))
);
}
#[cfg(test)]
mod test {
use backend::Backend;
use tempfile;
use crate::{run_command, Commands, CsvRenderMode, FieldArgument};
#[test]
fn test_simple_sum() {
let output_dir = tempfile::tempdir().unwrap();
let output_dir_str = output_dir.path().to_string_lossy().to_string();
let pil_command = Commands::Pil {
file: "../test_data/asm/simple_sum.asm".into(),
field: FieldArgument::Bn254,
output_directory: output_dir_str.clone(),
inputs: "3,2,1,2".into(),
force: false,
prove_with: None,
export_csv: true,
csv_mode: CsvRenderMode::Hex,
};
run_command(pil_command);
let file = output_dir
.path()
.join("simple_sum_opt.pil")
.to_string_lossy()
.to_string();
let prove_command = Commands::Prove {
file,
dir: output_dir_str,
field: FieldArgument::Bn254,
backend: Backend::Halo2Mock,
proof: None,
params: None,
};
run_command(prove_command);
}
}