Prover functions for small arith. (#2608)

Tests working for "small":
- memory
- add_sub
- rotate
- shift
- memory
- arith256

Not working:
- keccak
- poseidon
- poseidon2

---------

Co-authored-by: Georg Wiese <georgwiese@gmail.com>
This commit is contained in:
chriseth
2025-05-19 14:40:55 +02:00
committed by GitHub
parent 76bbebb53c
commit 044de121ff
7 changed files with 73 additions and 76 deletions

View File

@@ -418,7 +418,13 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
format!(
"{}[{}] = prover_function_{function_index}(mutable_state, input_from_channel, output_to_channel, row_offset + {row_offset}, &[{}]);",
if is_top_level { "let " } else { "" },
targets.iter().map(variable_to_string).format(", "),
targets.iter().map(|v|
if let Some(v) = v {
variable_to_string(v)
} else {
"_".to_string()
}
).format(", "),
inputs.iter().map(variable_to_string).format(", ")
)
}
@@ -1230,7 +1236,7 @@ extern \"C\" fn witgen(
let y = cell("y", 1, 0);
let z = cell("z", 2, 0);
let effects = vec![Effect::ProverFunctionCall(ProverFunctionCall {
targets: vec![x.clone()],
targets: vec![Some(x.clone())],
function_index: 0,
row_offset: 0,
inputs: vec![y.clone(), z.clone()],

View File

@@ -78,7 +78,7 @@ impl<T: FieldElement> Effect<T, Variable> {
.flat_map(|(v, known)| (!known).then_some(v)),
),
Effect::ProverFunctionCall(ProverFunctionCall { targets, .. }) => {
Box::new(targets.iter())
Box::new(targets.iter().flatten())
}
Effect::Branch(_, first, second) => {
Box::new(first.iter().chain(second).flat_map(|e| e.written_vars()))
@@ -106,7 +106,7 @@ impl<T: FieldElement, V: Hash + Eq> Effect<T, V> {
Effect::MachineCall(_, _, args) => Box::new(args.iter()),
Effect::ProverFunctionCall(ProverFunctionCall {
targets, inputs, ..
}) => Box::new(targets.iter().chain(inputs)),
}) => Box::new(targets.iter().flatten().chain(inputs)),
Effect::Branch(branch_condition, first, second) => Box::new(
branch_condition.value.referenced_symbols().chain(
[first, second]
@@ -122,8 +122,8 @@ impl<T: FieldElement, V: Hash + Eq> Effect<T, V> {
#[derive(Clone, PartialEq, Eq)]
pub struct ProverFunctionCall<V> {
/// Which variables to assign the result to.
pub targets: Vec<V>,
/// Which variables to assign the result to. If an element is None, it is ignored.
pub targets: Vec<Option<V>>,
/// The index of the prover function in the list.
pub function_index: usize,
/// The row offset to supply to the prover function.
@@ -173,7 +173,13 @@ pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
}) => {
format!(
"[{}] = prover_function_{function_index}({row_offset}, [{}]);",
targets.iter().join(", "),
targets
.iter()
.map(|v| v
.as_ref()
.map(|v| v.to_string())
.unwrap_or_else(|| "_".to_string()))
.join(", "),
inputs.iter().join(", ")
)
}

View File

@@ -152,7 +152,7 @@ enum MachineCallArgumentIdx {
/// Version of ``effect::ProverFunctionCall`` with variables replaced by their indices.
#[derive(Debug)]
struct IndexedProverFunctionCall {
pub targets: Vec<usize>,
pub targets: Vec<Option<usize>>,
pub function_index: usize,
pub row_offset: i32,
pub inputs: Vec<usize>,
@@ -273,7 +273,10 @@ impl<'a, T: FieldElement> EffectsInterpreter<'a, T> {
row_offset,
inputs,
}) => {
let targets = targets.iter().map(|v| var_mapper.map_var(v)).collect();
let targets = targets
.iter()
.map(|v| v.as_ref().map(|v| var_mapper.map_var(v)))
.collect();
let inputs = inputs.iter().map(|v| var_mapper.map_var(v)).collect();
InterpreterAction::ProverFunctionCall(IndexedProverFunctionCall {
targets,
@@ -402,7 +405,9 @@ impl<'a, T: FieldElement> EffectsInterpreter<'a, T> {
let result =
self.evaluate_prover_function(call, row_offset, inputs, fixed_data);
for (idx, val) in call.targets.iter().zip_eq(result) {
vars[*idx] = val;
if let Some(idx) = idx {
vars[*idx] = val;
}
}
}
InterpreterAction::Assertion(e1, e2, expected_equal) => {
@@ -587,7 +592,7 @@ impl<T: FieldElement> InterpreterAction<T> {
}
}),
InterpreterAction::ProverFunctionCall(call) => {
set.extend(call.targets.iter().copied());
set.extend(call.targets.iter().flatten().copied());
}
InterpreterAction::Branch(_branch_test, if_actions, else_actions) => {
set.extend(

View File

@@ -196,8 +196,8 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
.iter()
.map(|t| Variable::from_reference(t, row_offset))
.collect::<Vec<_>>();
// Only continue if none of the targets are known.
if targets.iter().any(|t| self.is_known(t)) {
// Continue if at least one of the targets is unknown.
if targets.iter().all(|t| self.is_known(t)) {
return Ok(vec![]);
}
let inputs = prover_function
@@ -218,7 +218,10 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
}
let effect = Effect::ProverFunctionCall(ProverFunctionCall {
targets,
targets: targets
.into_iter()
.map(|v| (!self.is_known(&v)).then_some(v))
.collect(),
function_index: prover_function.index,
row_offset,
inputs,
@@ -335,7 +338,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
inputs,
}) => {
let mut some_known = false;
for t in targets {
for t in targets.iter().flatten() {
if self.record_known(t.clone()) {
some_known = true;
updated_variables.push(t.clone());
@@ -344,7 +347,13 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
if some_known {
log::trace!(
"[{}] := prover_function_{function_index}({row_offset}, {})",
targets.iter().format(", "),
targets
.iter()
.map(|v| v
.as_ref()
.map(|v| v.to_string())
.unwrap_or_else(|| "_".to_string()))
.format(", "),
inputs.iter().format(", ")
);

View File

@@ -40,54 +40,6 @@ machine Arith(byte: Byte, byte2: Byte2) with
// Constrain that y2 = 0 when operation is div.
array::new(4, |i| is_division * y2[i] = 0);
// We need to provide hints for the quotient and remainder, because they are not unique under our current constraints.
// They are unique given additional main machine constraints, but it's still good to provide hints for the solver.
let quotient_hint = query |limb| match(eval(is_division)) {
1 => {
if x1_int() == 0 {
// Quotient is unconstrained, use zero.
Query::Hint(0)
} else {
let y3 = y3_int();
let x1 = x1_int();
let quotient = y3 / x1;
Query::Hint(fe(select_limb(quotient, limb)))
}
},
_ => Query::None
};
col witness y1_0(i) query quotient_hint(0);
col witness y1_1(i) query quotient_hint(1);
col witness y1_2(i) query quotient_hint(2);
col witness y1_3(i) query quotient_hint(3);
let y1: expr[] = [y1_0, y1_1, y1_2, y1_3];
let remainder_hint = query |limb| match(eval(is_division)) {
1 => {
let y3 = y3_int();
let x1 = x1_int();
if x1 == 0 {
// To satisfy x1 * y1 + x2 = y3, we need to set x2 = y3.
Query::Hint(fe(select_limb(y3, limb)))
} else {
let remainder = y3 % x1;
Query::Hint(fe(select_limb(remainder, limb)))
}
},
_ => Query::None
};
col witness x2_0(i) query remainder_hint(0);
col witness x2_1(i) query remainder_hint(1);
col witness x2_2(i) query remainder_hint(2);
col witness x2_3(i) query remainder_hint(3);
let x2: expr[] = [x2_0, x2_1, x2_2, x2_3];
pol commit x1[4], y2[4], y3[4];
// Selects the ith limb of x (little endian)
// All limbs are 8 bits
let select_limb = |x, i| if i >= 0 {
@@ -96,13 +48,32 @@ machine Arith(byte: Byte, byte2: Byte2) with
0
};
let limbs_to_int: expr[] -> int = query |limbs| array::sum(array::map_enumerated(limbs, |i, limb| int(eval(limb)) << (i * 8)));
let limbs_to_int: fe[] -> int = |limbs| array::sum(array::map_enumerated(limbs, |i, limb| int(limb) << (i * 8)));
let int_to_limbs: int -> fe[] = |x| array::new(4, |i| fe(select_limb(x, i)));
let x1_int = query || limbs_to_int(x1);
let y1_int = query || limbs_to_int(y1);
let x2_int = query || limbs_to_int(x2);
let y2_int = query || limbs_to_int(y2);
let y3_int = query || limbs_to_int(y3);
// We need to provide hints for the quotient and remainder, because they are not unique under our current constraints.
// They are unique given additional main machine constraints, but it's still good to provide hints for the solver.
query |i| std::prover::compute_from_multi_if(
is_division = 1,
y1 + x2,
i,
y3 + x1,
|values| {
let y3_value = limbs_to_int([values[0], values[1], values[2], values[3]]);
let x1_value = limbs_to_int([values[4], values[5], values[6], values[7]]);
if x1_value == 0 {
// Quotient is unconstrained, use zero for y1
// and set remainder x2 = y3.
[0, 0, 0, 0] + int_to_limbs(y3_value)
} else {
let quotient = y3_value / x1_value;
let remainder = y3_value % x1_value;
int_to_limbs(quotient) + int_to_limbs(remainder)
}
}
);
pol commit x1[4], x2[4], y1[4], y2[4], y3[4];
let combine: expr[] -> expr[] = |x| array::new(array::len(x) / 2, |i| x[2 * i + 1] * 2**8 + x[2 * i]);
// Intermediate polynomials, arrays of 16 columns, 16 bit per column.
@@ -172,6 +143,6 @@ machine Arith(byte: Byte, byte2: Byte2) with
carry * CLK8[0] = 0;
// Putting everything together
col eq0_sum = sum(8, |i| eq0(i) * CLK8[i]);
let eq0_sum = sum(8, |i| eq0(i) * CLK8[i]);
eq0_sum + carry = carry' * 2**8;
}

View File

@@ -141,7 +141,7 @@ machine Arith256 with
*
*******/
col eq0_sum = sum(64, |i| eq0(i) * CLK64[i]);
let eq0_sum = sum(64, |i| eq0(i) * CLK64[i]);
eq0_sum + carry = carry' * 2**8;
}

View File

@@ -20,9 +20,9 @@ let provide_if_unknown: expr, int, (-> fe) -> () = query |column, row, f| match
_ => (),
};
/// Returns true if all the provided columns are unknown.
let all_unknown: expr[] -> bool = query |columns|
std::array::fold(columns, true, |acc, c| acc && match try_eval(c) {
/// Returns true if at least one of the provided columns is unknown.
let some_unknown: expr[] -> bool = query |columns|
std::array::fold(columns, false, |acc, c| acc || match try_eval(c) {
Option::None => true,
Option::Some(_) => false,
});
@@ -33,7 +33,7 @@ let compute_from: expr, int, expr[], (fe[] -> fe) -> () = query |dest_col, row,
/// Computes the value of multiple columns in a row based on the values of other columns.
let compute_from_multi: expr[], int, expr[], (fe[] -> fe[]) -> () = query |dest_cols, row, input_cols, f|
if all_unknown(dest_cols) {
if some_unknown(dest_cols) {
let values = f(std::array::map(input_cols, eval));
let _ = std::array::zip(dest_cols, values, |c, v| provide_value(c, row, v));
} else {