From 40a9cef4fa243c6e16049baebd837fee9b583c5d Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 27 Feb 2023 14:15:36 +0100 Subject: [PATCH] Witness generation for assembly. --- src/asm_compiler/mod.rs | 1 + src/commit_evaluator/mod.rs | 105 +++++++++++++++++++++++++----------- tests/integration.rs | 77 +++++++++++++++++++++----- tests/simple_sum.asm | 1 + 4 files changed, 139 insertions(+), 45 deletions(-) diff --git a/src/asm_compiler/mod.rs b/src/asm_compiler/mod.rs index 1759859b0..187f608ac 100644 --- a/src/asm_compiler/mod.rs +++ b/src/asm_compiler/mod.rs @@ -650,6 +650,7 @@ pol commit XInv; pol commit XIsZero; XIsZero = (1 - (X * XInv)); (XIsZero * X) = 0; +(XIsZero * (1 - XIsZero)) = 0; pol commit instr_jmpz; pol commit instr_jmpz_param_l; pol commit instr_jmp; diff --git a/src/commit_evaluator/mod.rs b/src/commit_evaluator/mod.rs index e9b1148f1..100647df8 100644 --- a/src/commit_evaluator/mod.rs +++ b/src/commit_evaluator/mod.rs @@ -130,7 +130,10 @@ where // TODO maybe better to generate a dependency graph than looping multiple times. // TODO at least we could cache the affine expressions between loops. + + let mut identity_failed; loop { + identity_failed = false; self.progress = false; self.failure_reasons.clear(); @@ -152,6 +155,9 @@ where _ => Ok(vec![]), } .map_err(|err| format!("No progress on {identity}:\n {err}")); + if result.is_err() { + identity_failed = true; + } self.handle_eval_result(result); } if !self.progress { @@ -169,7 +175,7 @@ where break; } } - if self.next.iter().any(|v| v.is_none()) { + if identity_failed && self.next.iter().any(|v| v.is_none()) { eprintln!( "Error: Row {next_row}: Unable to derive values for committed polynomials: {}", self.next @@ -183,12 +189,32 @@ where .collect::>() .join(", ") ); - eprintln!("Reasons: {}", self.failure_reasons.join("\n")); + eprintln!("Reasons: {}", self.failure_reasons.join("\n\n")); + eprintln!( + "Current values:\n{}", + self.next + .iter() + .enumerate() + .map(|(i, v)| format!( + "{} = {}", + self.committed_names[i], + v.as_ref() + .map(|v| format!("{v}")) + .unwrap_or("".to_string()) + )) + .collect::>() + .join("\n") + ); panic!(); } else { std::mem::swap(&mut self.next, &mut self.current); self.next = vec![None; self.current.len()]; - self.current.iter().map(|v| v.clone().unwrap()).collect() + // TODO check a bit better that "None" values do not + // violate constraints. + self.current + .iter() + .map(|v| v.clone().unwrap_or_default()) + .collect() } } @@ -196,7 +222,7 @@ where &mut self, column: &&WitnessColumn, ) -> Result, String> { - let query = self.interpolate_query(column.query.unwrap()); + let query = self.interpolate_query(column.query.unwrap())?; if let Some(value) = self.query_callback.as_mut().unwrap()(&query) { Ok(vec![(column.id, value)]) } else { @@ -207,28 +233,28 @@ where } } - fn interpolate_query(&self, query: &Expression) -> String { + fn interpolate_query(&self, query: &Expression) -> Result { + if let Ok(v) = self.evaluate(query, EvaluationRow::Next) { + if v.is_constant() { + return Ok(self.format_affine_expression(v)); + } + } // TODO combine that with the constant evaluator and the commit evaluator... match query { - Expression::Tuple(items) => items + Expression::Tuple(items) => Ok(items .iter() .map(|i| self.interpolate_query(i)) - .collect::>() - .join(", "), + .collect::, _>>()? + .join(", ")), Expression::LocalVariableReference(i) => { assert!(*i == 0); - format!("{}", self.next_row) + Ok(format!("{}", self.next_row)) } - Expression::Constant(_) => todo!(), - Expression::PolynomialReference(_) => todo!(), - Expression::PublicReference(_) => todo!(), - Expression::Number(n) => format!("{n}"), - Expression::String(s) => { - format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\"")) - } - Expression::BinaryOperation(_, _, _) => todo!(), - Expression::UnaryOperation(_, _) => todo!(), - Expression::FunctionCall(_, _) => todo!(), + Expression::String(s) => Ok(format!( + "\"{}\"", + s.replace('\\', "\\\\").replace('"', "\\\"") + )), + _ => Err(format!("Cannot handle / evaluate {query}")), } } @@ -241,12 +267,16 @@ where EvaluationRow::Next }; let evaluated = self.evaluate(identity, row)?; - match evaluated.solve() { - Some((id, value)) => Ok(vec![(id, value)]), - None => Err(format!( - "Could not solve expression {}", - self.format_affine_expression(evaluated) - )), + if evaluated.constant_value() == Some(0.into()) { + Ok(vec![]) + } else { + match evaluated.solve() { + Some((id, value)) => Ok(vec![(id, value)]), + None => Err(format!( + "Could not solve expression {} (might be an invalid constraint)", + self.format_affine_expression(evaluated) + )), + } } } @@ -345,7 +375,7 @@ where match evaluated.solve() { Some((id, value)) => Ok(vec![(id, value)]), None => Err(format!( - "Could not solve expression {}", + "Could not solve expression {} (might be an invalid constraint)", self.format_affine_expression(evaluated) )), } @@ -431,11 +461,13 @@ where self.evaluate_binary_operation(left, op, right, row) } Expression::UnaryOperation(op, expr) => self.evaluate_unary_operation(op, expr, row), - Expression::Tuple(_) => panic!(), - Expression::String(_) => panic!(), - Expression::LocalVariableReference(_) => panic!(), - Expression::PublicReference(_) => panic!(), - Expression::FunctionCall(_, _) => panic!(), + Expression::Tuple(_) => Err("Tuple not implemented.".to_string()), + Expression::String(_) => Err("String not implemented.".to_string()), + Expression::LocalVariableReference(_) => { + Err("Local variable references not implemented.".to_string()) + } + Expression::PublicReference(_) => Err("Public references not implemented.".to_string()), + Expression::FunctionCall(_, _) => Err("Function calls not implemented.".to_string()), } } @@ -555,7 +587,16 @@ where .iter() .enumerate() .filter(|(_, c)| !is_zero(c)) - .map(|(i, c)| format!("{} * {c}", self.committed_names[i])) + .map(|(i, c)| { + let name = self.committed_names[i]; + if *c == 1.into() { + name.clone() + } else if *c == (-1).into() { + format!("-{name}") + } else { + format!("{c} * {name}") + } + }) .chain(e.constant_value().map(|v| format!("{v}"))) .collect::>() .join(" + ") diff --git a/tests/integration.rs b/tests/integration.rs index bfcb46036..7a481f454 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,26 +1,70 @@ -use std::{path::Path, process::Command}; +use std::{fs, path::Path, process::Command}; use powdr::compiler; use powdr::number::AbstractNumberType; -fn verify(file_name: &str, query_callback: Option Option>) { +fn verify_pil(file_name: &str, query_callback: Option Option>) { let input_file = Path::new(&format!("./tests/{file_name}")) .canonicalize() .unwrap(); let temp_dir = mktemp::Temp::new_dir().unwrap(); - compiler::compile_pil(&input_file, &temp_dir, query_callback); + assert!(compiler::compile_pil( + &input_file, + &temp_dir, + query_callback + )); + verify(file_name, &temp_dir); +} +fn verify_asm(file_name: &str, inputs: Vec) { + let contents = fs::read_to_string(format!("./tests/{file_name}")).unwrap(); + let pil = powdr::asm_compiler::compile(Some(file_name), &contents).unwrap(); + let pil_file_name = "asm.pil"; + let temp_dir = mktemp::Temp::new_dir().unwrap(); + assert!(compiler::compile_pil_ast( + &pil, + pil_file_name, + &temp_dir, + Some(|input: &str| { + let items = input.split(',').map(|s| s.trim()).collect::>(); + let mut it = items.iter(); + let _current_step = it.next().unwrap(); + let current_pc = it.next().unwrap(); + while let Some(pc_check) = it.next() { + if pc_check == current_pc { + assert_eq!(*it.next().unwrap(), "\"input\""); + let index: usize = it.next().map(|s| s.parse().unwrap()).unwrap(); + return Some(inputs[index].clone()); + } else { + it.next(); + it.next(); + } + } + None + }), + )); + verify(pil_file_name, &temp_dir); +} + +fn verify(file_name: &str, temp_dir: &Path) { let pilcom = std::env::var("PILCOM") .expect("Please set the PILCOM environment variable to the path to the pilcom repository."); + let constants_file = format!("{}/constants.bin", temp_dir.to_string_lossy()); + let commits_file = format!("{}/commits.bin", temp_dir.to_string_lossy()); + assert!( + fs::metadata(&constants_file).unwrap().len() > 0, + "Empty constants file" + ); + let verifier_output = Command::new("node") .args([ format!("{pilcom}/src/main_pilverifier.js"), - format!("{}/commits.bin", temp_dir.as_path().to_string_lossy()), + commits_file, "-j".to_string(), - format!("{}/{file_name}.json", temp_dir.as_path().to_string_lossy()), + format!("{}/{file_name}.json", temp_dir.to_string_lossy()), "-c".to_string(), - format!("{}/constants.bin", temp_dir.as_path().to_string_lossy()), + constants_file, ]) .output() .expect("failed to run pil verifier"); @@ -36,28 +80,26 @@ fn verify(file_name: &str, query_callback: Option Option Some(3.into()), @@ -83,3 +125,12 @@ fn test_witness_lookup() { }), ); } + +#[test] +#[ignore] +fn simple_sum_asm() { + verify_asm( + "simple_sum.asm", + [16, 4, 1, 2, 8, 5].iter().map(|&x| x.into()).collect(), + ); +} diff --git a/tests/simple_sum.asm b/tests/simple_sum.asm index 0e0781adf..85d4ea599 100644 --- a/tests/simple_sum.asm +++ b/tests/simple_sum.asm @@ -17,6 +17,7 @@ pil{ col witness XIsZero; XIsZero = 1 - X * XInv; XIsZero * X = 0; + XIsZero * (1 - XIsZero) = 0; } instr jmpz <=X= c, l: label { pc' = XIsZero * l + (1 - XIsZero) * (pc + 1) }