From c0e2952506a69e414742b90ce4b69aa9987060dc Mon Sep 17 00:00:00 2001 From: chriseth Date: Tue, 18 Apr 2023 19:41:34 +0200 Subject: [PATCH] Use match for prover query. --- compiler/src/lib.rs | 22 ++++++++++------------ compiler/src/verify.rs | 25 ++++++++++++------------- executor/src/analyzer/pil_analyzer.rs | 2 +- executor/src/witgen/generator.rs | 19 +++++++++++++++++++ pilgen/src/lib.rs | 27 ++++++++++++--------------- 5 files changed, 54 insertions(+), 41 deletions(-) diff --git a/compiler/src/lib.rs b/compiler/src/lib.rs index 846f098f1..c5ccf6c84 100644 --- a/compiler/src/lib.rs +++ b/compiler/src/lib.rs @@ -7,8 +7,6 @@ use std::path::Path; mod verify; pub use verify::{verify, verify_asm_string}; -use itertools::Itertools; - use executor::{analyzer, constant_evaluator, json_exporter}; use number::{DegreeType, FieldElement}; use parser::ast::PILFile; @@ -91,18 +89,18 @@ pub fn compile_asm_string( let query_callback = |query: &str| -> Option { let items = query.split(',').map(|s| s.trim()).collect::>(); - let mut it = items.iter(); - let _current_step = it.next().unwrap(); - let current_pc = it.next().unwrap(); - assert!(it.clone().len() % 3 == 0); - for (pc_check, input, index) in it.tuples() { - if pc_check == current_pc { - assert_eq!(*input, "\"input\""); - let index: usize = index.parse().unwrap(); - return inputs.get(index).cloned(); + assert_eq!(items.len(), 2); + match items[0] { + "\"input\"" => { + let index = items[1].parse::().unwrap(); + let value = inputs.get(index).cloned(); + if let Some(value) = value { + log::trace!("Input query: Index {index} -> {value}"); + } + value } + _ => None, } - None }; compile_pil_ast( &pil, diff --git a/compiler/src/verify.rs b/compiler/src/verify.rs index 77b29e89d..ee04b095e 100644 --- a/compiler/src/verify.rs +++ b/compiler/src/verify.rs @@ -1,6 +1,5 @@ use std::{fs, path::Path, process::Command}; -use itertools::Itertools; use number::FieldElement; #[allow(unused)] @@ -12,20 +11,20 @@ pub fn verify_asm_string(file_name: &str, contents: &str, inputs: Vec>(); - let mut it = items.iter(); - let _current_step = it.next().unwrap(); - let current_pc = it.next().unwrap(); - assert!(it.clone().len() % 3 == 0); - for (pc_check, input, index) in it.tuples() { - if pc_check == current_pc { - assert_eq!(*input, "\"input\""); - let index: usize = index.parse().unwrap(); - return inputs.get(index).cloned(); + Some(|query: &str| { + let items = query.split(',').map(|s| s.trim()).collect::>(); + assert_eq!(items.len(), 2); + match items[0] { + "\"input\"" => { + let index = items[1].parse::().unwrap(); + let value = inputs.get(index).cloned(); + if let Some(value) = value { + log::trace!("Input query: Index {index} -> {value}"); + } + value } + _ => None, } - None }), )); verify(pil_file_name, &temp_dir); diff --git a/executor/src/analyzer/pil_analyzer.rs b/executor/src/analyzer/pil_analyzer.rs index 39a31dc65..0a130e70f 100644 --- a/executor/src/analyzer/pil_analyzer.rs +++ b/executor/src/analyzer/pil_analyzer.rs @@ -516,7 +516,7 @@ impl PILContext { ( n.as_ref().map(|n| { self.evaluate_expression(n).unwrap_or_else(|| { - panic!("Left side of match arm must be a constant, found {}", n) + panic!("Left side of match arm must be a constant, found {n}") }) }), self.process_expression(e), diff --git a/executor/src/witgen/generator.rs b/executor/src/witgen/generator.rs index a36807501..ddb7337b5 100644 --- a/executor/src/witgen/generator.rs +++ b/executor/src/witgen/generator.rs @@ -279,10 +279,29 @@ where "\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\"") )), + Expression::MatchExpression(scrutinee, arms) => self + .interpolate_match_expression_for_query(scrutinee.as_ref(), arms) + .map_err(|e| format!("Cannot handle / evaluate {query}: {e}")), _ => Err(format!("Cannot handle / evaluate {query}")), } } + fn interpolate_match_expression_for_query( + &self, + scrutinee: &Expression, + arms: &[(Option, Expression)], + ) -> Result { + let v = self + .evaluate(scrutinee, EvaluationRow::Next)? + .constant_value() + .ok_or_else(|| "Match scrutinee not constant".to_string())?; + let (_, expr) = arms + .iter() + .find(|(n, _)| n.is_none() || n.as_ref() == Some(&v)) + .ok_or_else(|| format!("Match arm not found for value {v}"))?; + self.interpolate_query(expr).map_err(|e| e.into()) + } + fn process_polynomial_identity(&self, identity: &Expression) -> EvalResult { // If there is no "next" reference in the expression, // we just evaluate it directly on the "next" row. diff --git a/pilgen/src/lib.rs b/pilgen/src/lib.rs index c2093e94c..fc13c9367 100644 --- a/pilgen/src/lib.rs +++ b/pilgen/src/lib.rs @@ -539,17 +539,9 @@ impl ASMPILConverter { .iter() .map(|n| (n, vec![FieldElement::from(0); self.code_lines.len()])) .collect::>(); - let mut free_value_queries = self + let mut free_value_query_arms = self .assignment_registers() - .map(|r| { - ( - r.clone(), - vec![ - direct_reference("i"), - direct_reference(self.pc_name.as_ref().unwrap()), - ], - ) - }) + .map(|r| (r.clone(), vec![])) .collect::>(); let label_positions = self.compute_label_positions(); @@ -584,9 +576,10 @@ impl ASMPILConverter { program_constants .get_mut(&format!("p_{assign_reg}_read_free")) .unwrap()[i] = *coeff; - free_value_queries.get_mut(assign_reg).unwrap().push( - Expression::Tuple(vec![build_number(i as u64), expr.clone()]), - ); + free_value_query_arms.get_mut(assign_reg).unwrap().push(( + Some(build_number(FieldElement::from(i as u64))), + expr.clone(), + )); } } } @@ -620,6 +613,7 @@ impl ASMPILConverter { assert!(line.instruction_literal_args.is_empty()); } } + let pc_name = self.pc_name.clone(); let free_value_pil = self .assignment_registers() .map(|reg| { @@ -629,7 +623,10 @@ impl ASMPILConverter { free_value, Some(FunctionDefinition::Query( vec!["i".to_string()], - Expression::Tuple(free_value_queries[reg].clone()), + Expression::MatchExpression( + Box::new(direct_reference(pc_name.as_ref().unwrap())), + free_value_query_arms[reg].clone(), + ), )), ) }) @@ -950,7 +947,7 @@ A' = (((first_step' * 0) + (reg_write_X_A * X)) + ((1 - (first_step' + reg_write CNT' = ((((first_step' * 0) + (reg_write_X_CNT * X)) + (instr_dec_CNT * (CNT - 1))) + ((1 - ((first_step' + reg_write_X_CNT) + instr_dec_CNT)) * CNT)); pc' = ((1 - first_step') * (((instr_jmpz * ((XIsZero * instr_jmpz_param_l) + ((1 - XIsZero) * (pc + 1)))) + (instr_jmp * instr_jmp_param_l)) + ((1 - (instr_jmpz + instr_jmp)) * (pc + 1)))); pol constant line(i) { i }; -pol commit X_free_value(i) query (i, pc, (0, ("input", 1)), (3, ("input", (CNT + 1))), (7, ("input", 0))); +pol commit X_free_value(i) query match pc { 0 => ("input", 1), 3 => ("input", (CNT + 1)), 7 => ("input", 0), }; pol constant p_X_const = [0, 0, 0, 0, 0, 0, 0, 0, 0] + [0]*; pol constant p_X_read_free = [1, 0, 0, 1, 0, 0, 0, -1, 0] + [0]*; pol constant p_instr_assert_zero = [0, 0, 0, 0, 0, 0, 0, 0, 1] + [0]*;