From 044de121ffbe540d296366a512c72fd5758afa8f Mon Sep 17 00:00:00 2001 From: chriseth Date: Mon, 19 May 2025 14:40:55 +0200 Subject: [PATCH] Prover functions for small arith. (#2608) Tests working for "small": - memory - add_sub - rotate - shift - memory - arith256 Not working: - keccak - poseidon - poseidon2 --------- Co-authored-by: Georg Wiese --- executor/src/witgen/jit/compiler.rs | 10 ++- executor/src/witgen/jit/effect.rs | 16 ++-- executor/src/witgen/jit/interpreter.rs | 13 +++- executor/src/witgen/jit/witgen_inference.rs | 19 +++-- std/machines/small_field/arith.asm | 81 +++++++-------------- std/machines/small_field/arith256.asm | 2 +- std/prover.asm | 8 +- 7 files changed, 73 insertions(+), 76 deletions(-) diff --git a/executor/src/witgen/jit/compiler.rs b/executor/src/witgen/jit/compiler.rs index ee871b71a..081cb7118 100644 --- a/executor/src/witgen/jit/compiler.rs +++ b/executor/src/witgen/jit/compiler.rs @@ -418,7 +418,13 @@ fn format_effect(effect: &Effect, is_top_level: bo format!( "{}[{}] = prover_function_{function_index}(mutable_state, input_from_channel, output_to_channel, row_offset + {row_offset}, &[{}]);", if is_top_level { "let " } else { "" }, - targets.iter().map(variable_to_string).format(", "), + targets.iter().map(|v| + if let Some(v) = v { + variable_to_string(v) + } else { + "_".to_string() + } + ).format(", "), inputs.iter().map(variable_to_string).format(", ") ) } @@ -1230,7 +1236,7 @@ extern \"C\" fn witgen( let y = cell("y", 1, 0); let z = cell("z", 2, 0); let effects = vec![Effect::ProverFunctionCall(ProverFunctionCall { - targets: vec![x.clone()], + targets: vec![Some(x.clone())], function_index: 0, row_offset: 0, inputs: vec![y.clone(), z.clone()], diff --git a/executor/src/witgen/jit/effect.rs b/executor/src/witgen/jit/effect.rs index 8228d1abf..3b27b785b 100644 --- a/executor/src/witgen/jit/effect.rs +++ b/executor/src/witgen/jit/effect.rs @@ -78,7 +78,7 @@ impl Effect { .flat_map(|(v, known)| (!known).then_some(v)), ), Effect::ProverFunctionCall(ProverFunctionCall { targets, .. }) => { - Box::new(targets.iter()) + Box::new(targets.iter().flatten()) } Effect::Branch(_, first, second) => { Box::new(first.iter().chain(second).flat_map(|e| e.written_vars())) @@ -106,7 +106,7 @@ impl Effect { Effect::MachineCall(_, _, args) => Box::new(args.iter()), Effect::ProverFunctionCall(ProverFunctionCall { targets, inputs, .. - }) => Box::new(targets.iter().chain(inputs)), + }) => Box::new(targets.iter().flatten().chain(inputs)), Effect::Branch(branch_condition, first, second) => Box::new( branch_condition.value.referenced_symbols().chain( [first, second] @@ -122,8 +122,8 @@ impl Effect { #[derive(Clone, PartialEq, Eq)] pub struct ProverFunctionCall { - /// Which variables to assign the result to. - pub targets: Vec, + /// Which variables to assign the result to. If an element is None, it is ignored. + pub targets: Vec>, /// The index of the prover function in the list. pub function_index: usize, /// The row offset to supply to the prover function. @@ -173,7 +173,13 @@ pub fn format_code(effects: &[Effect]) -> String { }) => { format!( "[{}] = prover_function_{function_index}({row_offset}, [{}]);", - targets.iter().join(", "), + targets + .iter() + .map(|v| v + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "_".to_string())) + .join(", "), inputs.iter().join(", ") ) } diff --git a/executor/src/witgen/jit/interpreter.rs b/executor/src/witgen/jit/interpreter.rs index cf6ecd207..42ade0c00 100644 --- a/executor/src/witgen/jit/interpreter.rs +++ b/executor/src/witgen/jit/interpreter.rs @@ -152,7 +152,7 @@ enum MachineCallArgumentIdx { /// Version of ``effect::ProverFunctionCall`` with variables replaced by their indices. #[derive(Debug)] struct IndexedProverFunctionCall { - pub targets: Vec, + pub targets: Vec>, pub function_index: usize, pub row_offset: i32, pub inputs: Vec, @@ -273,7 +273,10 @@ impl<'a, T: FieldElement> EffectsInterpreter<'a, T> { row_offset, inputs, }) => { - let targets = targets.iter().map(|v| var_mapper.map_var(v)).collect(); + let targets = targets + .iter() + .map(|v| v.as_ref().map(|v| var_mapper.map_var(v))) + .collect(); let inputs = inputs.iter().map(|v| var_mapper.map_var(v)).collect(); InterpreterAction::ProverFunctionCall(IndexedProverFunctionCall { targets, @@ -402,7 +405,9 @@ impl<'a, T: FieldElement> EffectsInterpreter<'a, T> { let result = self.evaluate_prover_function(call, row_offset, inputs, fixed_data); for (idx, val) in call.targets.iter().zip_eq(result) { - vars[*idx] = val; + if let Some(idx) = idx { + vars[*idx] = val; + } } } InterpreterAction::Assertion(e1, e2, expected_equal) => { @@ -587,7 +592,7 @@ impl InterpreterAction { } }), InterpreterAction::ProverFunctionCall(call) => { - set.extend(call.targets.iter().copied()); + set.extend(call.targets.iter().flatten().copied()); } InterpreterAction::Branch(_branch_test, if_actions, else_actions) => { set.extend( diff --git a/executor/src/witgen/jit/witgen_inference.rs b/executor/src/witgen/jit/witgen_inference.rs index 06ca20644..e80f83048 100644 --- a/executor/src/witgen/jit/witgen_inference.rs +++ b/executor/src/witgen/jit/witgen_inference.rs @@ -196,8 +196,8 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F .iter() .map(|t| Variable::from_reference(t, row_offset)) .collect::>(); - // Only continue if none of the targets are known. - if targets.iter().any(|t| self.is_known(t)) { + // Continue if at least one of the targets is unknown. + if targets.iter().all(|t| self.is_known(t)) { return Ok(vec![]); } let inputs = prover_function @@ -218,7 +218,10 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F } let effect = Effect::ProverFunctionCall(ProverFunctionCall { - targets, + targets: targets + .into_iter() + .map(|v| (!self.is_known(&v)).then_some(v)) + .collect(), function_index: prover_function.index, row_offset, inputs, @@ -335,7 +338,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F inputs, }) => { let mut some_known = false; - for t in targets { + for t in targets.iter().flatten() { if self.record_known(t.clone()) { some_known = true; updated_variables.push(t.clone()); @@ -344,7 +347,13 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator> WitgenInference<'a, T, F if some_known { log::trace!( "[{}] := prover_function_{function_index}({row_offset}, {})", - targets.iter().format(", "), + targets + .iter() + .map(|v| v + .as_ref() + .map(|v| v.to_string()) + .unwrap_or_else(|| "_".to_string())) + .format(", "), inputs.iter().format(", ") ); diff --git a/std/machines/small_field/arith.asm b/std/machines/small_field/arith.asm index a70d36909..fd9b9e29d 100644 --- a/std/machines/small_field/arith.asm +++ b/std/machines/small_field/arith.asm @@ -40,54 +40,6 @@ machine Arith(byte: Byte, byte2: Byte2) with // Constrain that y2 = 0 when operation is div. array::new(4, |i| is_division * y2[i] = 0); - // We need to provide hints for the quotient and remainder, because they are not unique under our current constraints. - // They are unique given additional main machine constraints, but it's still good to provide hints for the solver. - let quotient_hint = query |limb| match(eval(is_division)) { - 1 => { - if x1_int() == 0 { - // Quotient is unconstrained, use zero. - Query::Hint(0) - } else { - let y3 = y3_int(); - let x1 = x1_int(); - let quotient = y3 / x1; - Query::Hint(fe(select_limb(quotient, limb))) - } - }, - _ => Query::None - }; - - col witness y1_0(i) query quotient_hint(0); - col witness y1_1(i) query quotient_hint(1); - col witness y1_2(i) query quotient_hint(2); - col witness y1_3(i) query quotient_hint(3); - - let y1: expr[] = [y1_0, y1_1, y1_2, y1_3]; - - let remainder_hint = query |limb| match(eval(is_division)) { - 1 => { - let y3 = y3_int(); - let x1 = x1_int(); - if x1 == 0 { - // To satisfy x1 * y1 + x2 = y3, we need to set x2 = y3. - Query::Hint(fe(select_limb(y3, limb))) - } else { - let remainder = y3 % x1; - Query::Hint(fe(select_limb(remainder, limb))) - } - }, - _ => Query::None - }; - - col witness x2_0(i) query remainder_hint(0); - col witness x2_1(i) query remainder_hint(1); - col witness x2_2(i) query remainder_hint(2); - col witness x2_3(i) query remainder_hint(3); - - let x2: expr[] = [x2_0, x2_1, x2_2, x2_3]; - - pol commit x1[4], y2[4], y3[4]; - // Selects the ith limb of x (little endian) // All limbs are 8 bits let select_limb = |x, i| if i >= 0 { @@ -96,13 +48,32 @@ machine Arith(byte: Byte, byte2: Byte2) with 0 }; - let limbs_to_int: expr[] -> int = query |limbs| array::sum(array::map_enumerated(limbs, |i, limb| int(eval(limb)) << (i * 8))); + let limbs_to_int: fe[] -> int = |limbs| array::sum(array::map_enumerated(limbs, |i, limb| int(limb) << (i * 8))); + let int_to_limbs: int -> fe[] = |x| array::new(4, |i| fe(select_limb(x, i))); - let x1_int = query || limbs_to_int(x1); - let y1_int = query || limbs_to_int(y1); - let x2_int = query || limbs_to_int(x2); - let y2_int = query || limbs_to_int(y2); - let y3_int = query || limbs_to_int(y3); + // We need to provide hints for the quotient and remainder, because they are not unique under our current constraints. + // They are unique given additional main machine constraints, but it's still good to provide hints for the solver. + query |i| std::prover::compute_from_multi_if( + is_division = 1, + y1 + x2, + i, + y3 + x1, + |values| { + let y3_value = limbs_to_int([values[0], values[1], values[2], values[3]]); + let x1_value = limbs_to_int([values[4], values[5], values[6], values[7]]); + if x1_value == 0 { + // Quotient is unconstrained, use zero for y1 + // and set remainder x2 = y3. + [0, 0, 0, 0] + int_to_limbs(y3_value) + } else { + let quotient = y3_value / x1_value; + let remainder = y3_value % x1_value; + int_to_limbs(quotient) + int_to_limbs(remainder) + } + } + ); + + pol commit x1[4], x2[4], y1[4], y2[4], y3[4]; let combine: expr[] -> expr[] = |x| array::new(array::len(x) / 2, |i| x[2 * i + 1] * 2**8 + x[2 * i]); // Intermediate polynomials, arrays of 16 columns, 16 bit per column. @@ -172,6 +143,6 @@ machine Arith(byte: Byte, byte2: Byte2) with carry * CLK8[0] = 0; // Putting everything together - col eq0_sum = sum(8, |i| eq0(i) * CLK8[i]); + let eq0_sum = sum(8, |i| eq0(i) * CLK8[i]); eq0_sum + carry = carry' * 2**8; } diff --git a/std/machines/small_field/arith256.asm b/std/machines/small_field/arith256.asm index 93bcb0563..88f49eeea 100644 --- a/std/machines/small_field/arith256.asm +++ b/std/machines/small_field/arith256.asm @@ -141,7 +141,7 @@ machine Arith256 with * *******/ - col eq0_sum = sum(64, |i| eq0(i) * CLK64[i]); + let eq0_sum = sum(64, |i| eq0(i) * CLK64[i]); eq0_sum + carry = carry' * 2**8; } diff --git a/std/prover.asm b/std/prover.asm index 476bb12a8..36d2e4305 100644 --- a/std/prover.asm +++ b/std/prover.asm @@ -20,9 +20,9 @@ let provide_if_unknown: expr, int, (-> fe) -> () = query |column, row, f| match _ => (), }; -/// Returns true if all the provided columns are unknown. -let all_unknown: expr[] -> bool = query |columns| - std::array::fold(columns, true, |acc, c| acc && match try_eval(c) { +/// Returns true if at least one of the provided columns is unknown. +let some_unknown: expr[] -> bool = query |columns| + std::array::fold(columns, false, |acc, c| acc || match try_eval(c) { Option::None => true, Option::Some(_) => false, }); @@ -33,7 +33,7 @@ let compute_from: expr, int, expr[], (fe[] -> fe) -> () = query |dest_col, row, /// Computes the value of multiple columns in a row based on the values of other columns. let compute_from_multi: expr[], int, expr[], (fe[] -> fe[]) -> () = query |dest_cols, row, input_cols, f| - if all_unknown(dest_cols) { + if some_unknown(dest_cols) { let values = f(std::array::map(input_cols, eval)); let _ = std::array::zip(dest_cols, values, |c, v| provide_value(c, row, v)); } else {