mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-05-13 03:00:26 -04:00
Merge pull request #175 from chriseth/use_match_for_prover_query
Use match for prover query.
This commit is contained in:
@@ -7,8 +7,6 @@ use std::path::Path;
|
||||
mod verify;
|
||||
pub use verify::{verify, verify_asm_string};
|
||||
|
||||
use itertools::Itertools;
|
||||
|
||||
use executor::{analyzer, constant_evaluator, json_exporter};
|
||||
use number::{DegreeType, FieldElement};
|
||||
use parser::ast::PILFile;
|
||||
@@ -91,18 +89,18 @@ pub fn compile_asm_string(
|
||||
|
||||
let query_callback = |query: &str| -> Option<FieldElement> {
|
||||
let items = query.split(',').map(|s| s.trim()).collect::<Vec<_>>();
|
||||
let mut it = items.iter();
|
||||
let _current_step = it.next().unwrap();
|
||||
let current_pc = it.next().unwrap();
|
||||
assert!(it.clone().len() % 3 == 0);
|
||||
for (pc_check, input, index) in it.tuples() {
|
||||
if pc_check == current_pc {
|
||||
assert_eq!(*input, "\"input\"");
|
||||
let index: usize = index.parse().unwrap();
|
||||
return inputs.get(index).cloned();
|
||||
assert_eq!(items.len(), 2);
|
||||
match items[0] {
|
||||
"\"input\"" => {
|
||||
let index = items[1].parse::<usize>().unwrap();
|
||||
let value = inputs.get(index).cloned();
|
||||
if let Some(value) = value {
|
||||
log::trace!("Input query: Index {index} -> {value}");
|
||||
}
|
||||
value
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
None
|
||||
};
|
||||
compile_pil_ast(
|
||||
&pil,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::{fs, path::Path, process::Command};
|
||||
|
||||
use itertools::Itertools;
|
||||
use number::FieldElement;
|
||||
|
||||
#[allow(unused)]
|
||||
@@ -12,20 +11,20 @@ pub fn verify_asm_string(file_name: &str, contents: &str, inputs: Vec<FieldEleme
|
||||
&pil,
|
||||
pil_file_name,
|
||||
&temp_dir,
|
||||
Some(|input: &str| {
|
||||
let items = input.split(',').map(|s| s.trim()).collect::<Vec<_>>();
|
||||
let mut it = items.iter();
|
||||
let _current_step = it.next().unwrap();
|
||||
let current_pc = it.next().unwrap();
|
||||
assert!(it.clone().len() % 3 == 0);
|
||||
for (pc_check, input, index) in it.tuples() {
|
||||
if pc_check == current_pc {
|
||||
assert_eq!(*input, "\"input\"");
|
||||
let index: usize = index.parse().unwrap();
|
||||
return inputs.get(index).cloned();
|
||||
Some(|query: &str| {
|
||||
let items = query.split(',').map(|s| s.trim()).collect::<Vec<_>>();
|
||||
assert_eq!(items.len(), 2);
|
||||
match items[0] {
|
||||
"\"input\"" => {
|
||||
let index = items[1].parse::<usize>().unwrap();
|
||||
let value = inputs.get(index).cloned();
|
||||
if let Some(value) = value {
|
||||
log::trace!("Input query: Index {index} -> {value}");
|
||||
}
|
||||
value
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
None
|
||||
}),
|
||||
));
|
||||
verify(pil_file_name, &temp_dir);
|
||||
|
||||
@@ -516,7 +516,7 @@ impl PILContext {
|
||||
(
|
||||
n.as_ref().map(|n| {
|
||||
self.evaluate_expression(n).unwrap_or_else(|| {
|
||||
panic!("Left side of match arm must be a constant, found {}", n)
|
||||
panic!("Left side of match arm must be a constant, found {n}")
|
||||
})
|
||||
}),
|
||||
self.process_expression(e),
|
||||
|
||||
@@ -279,10 +279,29 @@ where
|
||||
"\"{}\"",
|
||||
s.replace('\\', "\\\\").replace('"', "\\\"")
|
||||
)),
|
||||
Expression::MatchExpression(scrutinee, arms) => self
|
||||
.interpolate_match_expression_for_query(scrutinee.as_ref(), arms)
|
||||
.map_err(|e| format!("Cannot handle / evaluate {query}: {e}")),
|
||||
_ => Err(format!("Cannot handle / evaluate {query}")),
|
||||
}
|
||||
}
|
||||
|
||||
fn interpolate_match_expression_for_query(
|
||||
&self,
|
||||
scrutinee: &Expression,
|
||||
arms: &[(Option<FieldElement>, Expression)],
|
||||
) -> Result<String, EvalError> {
|
||||
let v = self
|
||||
.evaluate(scrutinee, EvaluationRow::Next)?
|
||||
.constant_value()
|
||||
.ok_or_else(|| "Match scrutinee not constant".to_string())?;
|
||||
let (_, expr) = arms
|
||||
.iter()
|
||||
.find(|(n, _)| n.is_none() || n.as_ref() == Some(&v))
|
||||
.ok_or_else(|| format!("Match arm not found for value {v}"))?;
|
||||
self.interpolate_query(expr).map_err(|e| e.into())
|
||||
}
|
||||
|
||||
fn process_polynomial_identity(&self, identity: &Expression) -> EvalResult {
|
||||
// If there is no "next" reference in the expression,
|
||||
// we just evaluate it directly on the "next" row.
|
||||
|
||||
@@ -539,17 +539,9 @@ impl ASMPILConverter {
|
||||
.iter()
|
||||
.map(|n| (n, vec![FieldElement::from(0); self.code_lines.len()]))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
let mut free_value_queries = self
|
||||
let mut free_value_query_arms = self
|
||||
.assignment_registers()
|
||||
.map(|r| {
|
||||
(
|
||||
r.clone(),
|
||||
vec![
|
||||
direct_reference("i"),
|
||||
direct_reference(self.pc_name.as_ref().unwrap()),
|
||||
],
|
||||
)
|
||||
})
|
||||
.map(|r| (r.clone(), vec![]))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
|
||||
let label_positions = self.compute_label_positions();
|
||||
@@ -584,9 +576,10 @@ impl ASMPILConverter {
|
||||
program_constants
|
||||
.get_mut(&format!("p_{assign_reg}_read_free"))
|
||||
.unwrap()[i] = *coeff;
|
||||
free_value_queries.get_mut(assign_reg).unwrap().push(
|
||||
Expression::Tuple(vec![build_number(i as u64), expr.clone()]),
|
||||
);
|
||||
free_value_query_arms.get_mut(assign_reg).unwrap().push((
|
||||
Some(build_number(FieldElement::from(i as u64))),
|
||||
expr.clone(),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -620,6 +613,7 @@ impl ASMPILConverter {
|
||||
assert!(line.instruction_literal_args.is_empty());
|
||||
}
|
||||
}
|
||||
let pc_name = self.pc_name.clone();
|
||||
let free_value_pil = self
|
||||
.assignment_registers()
|
||||
.map(|reg| {
|
||||
@@ -629,7 +623,10 @@ impl ASMPILConverter {
|
||||
free_value,
|
||||
Some(FunctionDefinition::Query(
|
||||
vec!["i".to_string()],
|
||||
Expression::Tuple(free_value_queries[reg].clone()),
|
||||
Expression::MatchExpression(
|
||||
Box::new(direct_reference(pc_name.as_ref().unwrap())),
|
||||
free_value_query_arms[reg].clone(),
|
||||
),
|
||||
)),
|
||||
)
|
||||
})
|
||||
@@ -950,7 +947,7 @@ A' = (((first_step' * 0) + (reg_write_X_A * X)) + ((1 - (first_step' + reg_write
|
||||
CNT' = ((((first_step' * 0) + (reg_write_X_CNT * X)) + (instr_dec_CNT * (CNT - 1))) + ((1 - ((first_step' + reg_write_X_CNT) + instr_dec_CNT)) * CNT));
|
||||
pc' = ((1 - first_step') * (((instr_jmpz * ((XIsZero * instr_jmpz_param_l) + ((1 - XIsZero) * (pc + 1)))) + (instr_jmp * instr_jmp_param_l)) + ((1 - (instr_jmpz + instr_jmp)) * (pc + 1))));
|
||||
pol constant line(i) { i };
|
||||
pol commit X_free_value(i) query (i, pc, (0, ("input", 1)), (3, ("input", (CNT + 1))), (7, ("input", 0)));
|
||||
pol commit X_free_value(i) query match pc { 0 => ("input", 1), 3 => ("input", (CNT + 1)), 7 => ("input", 0), };
|
||||
pol constant p_X_const = [0, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*;
|
||||
pol constant p_X_read_free = [1, 0, 0, 1, 0, 0, 0, -1, 0] + [0]*;
|
||||
pol constant p_instr_assert_zero = [0, 0, 0, 0, 0, 0, 0, 0, 1] + [0]*;
|
||||
|
||||
Reference in New Issue
Block a user