Witness generation for assembly.

This commit is contained in:
chriseth
2023-02-27 14:15:36 +01:00
parent 2f664ddc6f
commit 40a9cef4fa
4 changed files with 139 additions and 45 deletions

View File

@@ -650,6 +650,7 @@ pol commit XInv;
pol commit XIsZero;
XIsZero = (1 - (X * XInv));
(XIsZero * X) = 0;
(XIsZero * (1 - XIsZero)) = 0;
pol commit instr_jmpz;
pol commit instr_jmpz_param_l;
pol commit instr_jmp;

View File

@@ -130,7 +130,10 @@ where
// TODO maybe better to generate a dependency graph than looping multiple times.
// TODO at least we could cache the affine expressions between loops.
let mut identity_failed;
loop {
identity_failed = false;
self.progress = false;
self.failure_reasons.clear();
@@ -152,6 +155,9 @@ where
_ => Ok(vec![]),
}
.map_err(|err| format!("No progress on {identity}:\n {err}"));
if result.is_err() {
identity_failed = true;
}
self.handle_eval_result(result);
}
if !self.progress {
@@ -169,7 +175,7 @@ where
break;
}
}
if self.next.iter().any(|v| v.is_none()) {
if identity_failed && self.next.iter().any(|v| v.is_none()) {
eprintln!(
"Error: Row {next_row}: Unable to derive values for committed polynomials: {}",
self.next
@@ -183,12 +189,32 @@ where
.collect::<Vec<String>>()
.join(", ")
);
eprintln!("Reasons: {}", self.failure_reasons.join("\n"));
eprintln!("Reasons: {}", self.failure_reasons.join("\n\n"));
eprintln!(
"Current values:\n{}",
self.next
.iter()
.enumerate()
.map(|(i, v)| format!(
"{} = {}",
self.committed_names[i],
v.as_ref()
.map(|v| format!("{v}"))
.unwrap_or("<unknown>".to_string())
))
.collect::<Vec<_>>()
.join("\n")
);
panic!();
} else {
std::mem::swap(&mut self.next, &mut self.current);
self.next = vec![None; self.current.len()];
self.current.iter().map(|v| v.clone().unwrap()).collect()
// TODO check a bit better that "None" values do not
// violate constraints.
self.current
.iter()
.map(|v| v.clone().unwrap_or_default())
.collect()
}
}
@@ -196,7 +222,7 @@ where
&mut self,
column: &&WitnessColumn,
) -> Result<Vec<(usize, AbstractNumberType)>, String> {
let query = self.interpolate_query(column.query.unwrap());
let query = self.interpolate_query(column.query.unwrap())?;
if let Some(value) = self.query_callback.as_mut().unwrap()(&query) {
Ok(vec![(column.id, value)])
} else {
@@ -207,28 +233,28 @@ where
}
}
fn interpolate_query(&self, query: &Expression) -> String {
fn interpolate_query(&self, query: &Expression) -> Result<String, String> {
if let Ok(v) = self.evaluate(query, EvaluationRow::Next) {
if v.is_constant() {
return Ok(self.format_affine_expression(v));
}
}
// TODO combine that with the constant evaluator and the commit evaluator...
match query {
Expression::Tuple(items) => items
Expression::Tuple(items) => Ok(items
.iter()
.map(|i| self.interpolate_query(i))
.collect::<Vec<_>>()
.join(", "),
.collect::<Result<Vec<_>, _>>()?
.join(", ")),
Expression::LocalVariableReference(i) => {
assert!(*i == 0);
format!("{}", self.next_row)
Ok(format!("{}", self.next_row))
}
Expression::Constant(_) => todo!(),
Expression::PolynomialReference(_) => todo!(),
Expression::PublicReference(_) => todo!(),
Expression::Number(n) => format!("{n}"),
Expression::String(s) => {
format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
}
Expression::BinaryOperation(_, _, _) => todo!(),
Expression::UnaryOperation(_, _) => todo!(),
Expression::FunctionCall(_, _) => todo!(),
Expression::String(s) => Ok(format!(
"\"{}\"",
s.replace('\\', "\\\\").replace('"', "\\\"")
)),
_ => Err(format!("Cannot handle / evaluate {query}")),
}
}
@@ -241,12 +267,16 @@ where
EvaluationRow::Next
};
let evaluated = self.evaluate(identity, row)?;
match evaluated.solve() {
Some((id, value)) => Ok(vec![(id, value)]),
None => Err(format!(
"Could not solve expression {}",
self.format_affine_expression(evaluated)
)),
if evaluated.constant_value() == Some(0.into()) {
Ok(vec![])
} else {
match evaluated.solve() {
Some((id, value)) => Ok(vec![(id, value)]),
None => Err(format!(
"Could not solve expression {} (might be an invalid constraint)",
self.format_affine_expression(evaluated)
)),
}
}
}
@@ -345,7 +375,7 @@ where
match evaluated.solve() {
Some((id, value)) => Ok(vec![(id, value)]),
None => Err(format!(
"Could not solve expression {}",
"Could not solve expression {} (might be an invalid constraint)",
self.format_affine_expression(evaluated)
)),
}
@@ -431,11 +461,13 @@ where
self.evaluate_binary_operation(left, op, right, row)
}
Expression::UnaryOperation(op, expr) => self.evaluate_unary_operation(op, expr, row),
Expression::Tuple(_) => panic!(),
Expression::String(_) => panic!(),
Expression::LocalVariableReference(_) => panic!(),
Expression::PublicReference(_) => panic!(),
Expression::FunctionCall(_, _) => panic!(),
Expression::Tuple(_) => Err("Tuple not implemented.".to_string()),
Expression::String(_) => Err("String not implemented.".to_string()),
Expression::LocalVariableReference(_) => {
Err("Local variable references not implemented.".to_string())
}
Expression::PublicReference(_) => Err("Public references not implemented.".to_string()),
Expression::FunctionCall(_, _) => Err("Function calls not implemented.".to_string()),
}
}
@@ -555,7 +587,16 @@ where
.iter()
.enumerate()
.filter(|(_, c)| !is_zero(c))
.map(|(i, c)| format!("{} * {c}", self.committed_names[i]))
.map(|(i, c)| {
let name = self.committed_names[i];
if *c == 1.into() {
name.clone()
} else if *c == (-1).into() {
format!("-{name}")
} else {
format!("{c} * {name}")
}
})
.chain(e.constant_value().map(|v| format!("{v}")))
.collect::<Vec<_>>()
.join(" + ")

View File

@@ -1,26 +1,70 @@
use std::{path::Path, process::Command};
use std::{fs, path::Path, process::Command};
use powdr::compiler;
use powdr::number::AbstractNumberType;
fn verify(file_name: &str, query_callback: Option<fn(&str) -> Option<AbstractNumberType>>) {
fn verify_pil(file_name: &str, query_callback: Option<fn(&str) -> Option<AbstractNumberType>>) {
let input_file = Path::new(&format!("./tests/{file_name}"))
.canonicalize()
.unwrap();
let temp_dir = mktemp::Temp::new_dir().unwrap();
compiler::compile_pil(&input_file, &temp_dir, query_callback);
assert!(compiler::compile_pil(
&input_file,
&temp_dir,
query_callback
));
verify(file_name, &temp_dir);
}
fn verify_asm(file_name: &str, inputs: Vec<AbstractNumberType>) {
let contents = fs::read_to_string(format!("./tests/{file_name}")).unwrap();
let pil = powdr::asm_compiler::compile(Some(file_name), &contents).unwrap();
let pil_file_name = "asm.pil";
let temp_dir = mktemp::Temp::new_dir().unwrap();
assert!(compiler::compile_pil_ast(
&pil,
pil_file_name,
&temp_dir,
Some(|input: &str| {
let items = input.split(',').map(|s| s.trim()).collect::<Vec<_>>();
let mut it = items.iter();
let _current_step = it.next().unwrap();
let current_pc = it.next().unwrap();
while let Some(pc_check) = it.next() {
if pc_check == current_pc {
assert_eq!(*it.next().unwrap(), "\"input\"");
let index: usize = it.next().map(|s| s.parse().unwrap()).unwrap();
return Some(inputs[index].clone());
} else {
it.next();
it.next();
}
}
None
}),
));
verify(pil_file_name, &temp_dir);
}
fn verify(file_name: &str, temp_dir: &Path) {
let pilcom = std::env::var("PILCOM")
.expect("Please set the PILCOM environment variable to the path to the pilcom repository.");
let constants_file = format!("{}/constants.bin", temp_dir.to_string_lossy());
let commits_file = format!("{}/commits.bin", temp_dir.to_string_lossy());
assert!(
fs::metadata(&constants_file).unwrap().len() > 0,
"Empty constants file"
);
let verifier_output = Command::new("node")
.args([
format!("{pilcom}/src/main_pilverifier.js"),
format!("{}/commits.bin", temp_dir.as_path().to_string_lossy()),
commits_file,
"-j".to_string(),
format!("{}/{file_name}.json", temp_dir.as_path().to_string_lossy()),
format!("{}/{file_name}.json", temp_dir.to_string_lossy()),
"-c".to_string(),
format!("{}/constants.bin", temp_dir.as_path().to_string_lossy()),
constants_file,
])
.output()
.expect("failed to run pil verifier");
@@ -36,28 +80,26 @@ fn verify(file_name: &str, query_callback: Option<fn(&str) -> Option<AbstractNum
panic!("Verified did not say 'PIL OK': {output}");
}
}
drop(temp_dir);
}
#[test]
fn test_fibonacci() {
verify("fibonacci.pil", None);
verify_pil("fibonacci.pil", None);
}
#[test]
fn test_fibonacci_macro() {
verify("fib_macro.pil", None);
verify_pil("fib_macro.pil", None);
}
#[test]
fn test_global() {
verify("global.pil", None);
verify_pil("global.pil", None);
}
#[test]
fn test_sum_via_witness_query() {
verify(
verify_pil(
"sum_via_witness_query.pil",
Some(|q| {
match q {
@@ -73,7 +115,7 @@ fn test_sum_via_witness_query() {
#[test]
fn test_witness_lookup() {
verify(
verify_pil(
"witness_lookup.pil",
Some(|q| match q {
"\"input\", 0" => Some(3.into()),
@@ -83,3 +125,12 @@ fn test_witness_lookup() {
}),
);
}
#[test]
#[ignore]
fn simple_sum_asm() {
verify_asm(
"simple_sum.asm",
[16, 4, 1, 2, 8, 5].iter().map(|&x| x.into()).collect(),
);
}

View File

@@ -17,6 +17,7 @@ pil{
col witness XIsZero;
XIsZero = 1 - X * XInv;
XIsZero * X = 0;
XIsZero * (1 - XIsZero) = 0;
}
instr jmpz <=X= c, l: label { pc' = XIsZero * l + (1 - XIsZero) * (pc + 1) }