mirror of
https://github.com/powdr-labs/powdr.git
synced 2026-01-10 01:28:12 -05:00
Prover functions for small arith. (#2608)
Tests working for "small": - memory - add_sub - rotate - shift - memory - arith256 Not working: - keccak - poseidon - poseidon2 --------- Co-authored-by: Georg Wiese <georgwiese@gmail.com>
This commit is contained in:
@@ -418,7 +418,13 @@ fn format_effect<T: FieldElement>(effect: &Effect<T, Variable>, is_top_level: bo
|
||||
format!(
|
||||
"{}[{}] = prover_function_{function_index}(mutable_state, input_from_channel, output_to_channel, row_offset + {row_offset}, &[{}]);",
|
||||
if is_top_level { "let " } else { "" },
|
||||
targets.iter().map(variable_to_string).format(", "),
|
||||
targets.iter().map(|v|
|
||||
if let Some(v) = v {
|
||||
variable_to_string(v)
|
||||
} else {
|
||||
"_".to_string()
|
||||
}
|
||||
).format(", "),
|
||||
inputs.iter().map(variable_to_string).format(", ")
|
||||
)
|
||||
}
|
||||
@@ -1230,7 +1236,7 @@ extern \"C\" fn witgen(
|
||||
let y = cell("y", 1, 0);
|
||||
let z = cell("z", 2, 0);
|
||||
let effects = vec![Effect::ProverFunctionCall(ProverFunctionCall {
|
||||
targets: vec![x.clone()],
|
||||
targets: vec![Some(x.clone())],
|
||||
function_index: 0,
|
||||
row_offset: 0,
|
||||
inputs: vec![y.clone(), z.clone()],
|
||||
|
||||
@@ -78,7 +78,7 @@ impl<T: FieldElement> Effect<T, Variable> {
|
||||
.flat_map(|(v, known)| (!known).then_some(v)),
|
||||
),
|
||||
Effect::ProverFunctionCall(ProverFunctionCall { targets, .. }) => {
|
||||
Box::new(targets.iter())
|
||||
Box::new(targets.iter().flatten())
|
||||
}
|
||||
Effect::Branch(_, first, second) => {
|
||||
Box::new(first.iter().chain(second).flat_map(|e| e.written_vars()))
|
||||
@@ -106,7 +106,7 @@ impl<T: FieldElement, V: Hash + Eq> Effect<T, V> {
|
||||
Effect::MachineCall(_, _, args) => Box::new(args.iter()),
|
||||
Effect::ProverFunctionCall(ProverFunctionCall {
|
||||
targets, inputs, ..
|
||||
}) => Box::new(targets.iter().chain(inputs)),
|
||||
}) => Box::new(targets.iter().flatten().chain(inputs)),
|
||||
Effect::Branch(branch_condition, first, second) => Box::new(
|
||||
branch_condition.value.referenced_symbols().chain(
|
||||
[first, second]
|
||||
@@ -122,8 +122,8 @@ impl<T: FieldElement, V: Hash + Eq> Effect<T, V> {
|
||||
|
||||
#[derive(Clone, PartialEq, Eq)]
|
||||
pub struct ProverFunctionCall<V> {
|
||||
/// Which variables to assign the result to.
|
||||
pub targets: Vec<V>,
|
||||
/// Which variables to assign the result to. If an element is None, it is ignored.
|
||||
pub targets: Vec<Option<V>>,
|
||||
/// The index of the prover function in the list.
|
||||
pub function_index: usize,
|
||||
/// The row offset to supply to the prover function.
|
||||
@@ -173,7 +173,13 @@ pub fn format_code<T: FieldElement>(effects: &[Effect<T, Variable>]) -> String {
|
||||
}) => {
|
||||
format!(
|
||||
"[{}] = prover_function_{function_index}({row_offset}, [{}]);",
|
||||
targets.iter().join(", "),
|
||||
targets
|
||||
.iter()
|
||||
.map(|v| v
|
||||
.as_ref()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "_".to_string()))
|
||||
.join(", "),
|
||||
inputs.iter().join(", ")
|
||||
)
|
||||
}
|
||||
|
||||
@@ -152,7 +152,7 @@ enum MachineCallArgumentIdx {
|
||||
/// Version of ``effect::ProverFunctionCall`` with variables replaced by their indices.
|
||||
#[derive(Debug)]
|
||||
struct IndexedProverFunctionCall {
|
||||
pub targets: Vec<usize>,
|
||||
pub targets: Vec<Option<usize>>,
|
||||
pub function_index: usize,
|
||||
pub row_offset: i32,
|
||||
pub inputs: Vec<usize>,
|
||||
@@ -273,7 +273,10 @@ impl<'a, T: FieldElement> EffectsInterpreter<'a, T> {
|
||||
row_offset,
|
||||
inputs,
|
||||
}) => {
|
||||
let targets = targets.iter().map(|v| var_mapper.map_var(v)).collect();
|
||||
let targets = targets
|
||||
.iter()
|
||||
.map(|v| v.as_ref().map(|v| var_mapper.map_var(v)))
|
||||
.collect();
|
||||
let inputs = inputs.iter().map(|v| var_mapper.map_var(v)).collect();
|
||||
InterpreterAction::ProverFunctionCall(IndexedProverFunctionCall {
|
||||
targets,
|
||||
@@ -402,7 +405,9 @@ impl<'a, T: FieldElement> EffectsInterpreter<'a, T> {
|
||||
let result =
|
||||
self.evaluate_prover_function(call, row_offset, inputs, fixed_data);
|
||||
for (idx, val) in call.targets.iter().zip_eq(result) {
|
||||
vars[*idx] = val;
|
||||
if let Some(idx) = idx {
|
||||
vars[*idx] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
InterpreterAction::Assertion(e1, e2, expected_equal) => {
|
||||
@@ -587,7 +592,7 @@ impl<T: FieldElement> InterpreterAction<T> {
|
||||
}
|
||||
}),
|
||||
InterpreterAction::ProverFunctionCall(call) => {
|
||||
set.extend(call.targets.iter().copied());
|
||||
set.extend(call.targets.iter().flatten().copied());
|
||||
}
|
||||
InterpreterAction::Branch(_branch_test, if_actions, else_actions) => {
|
||||
set.extend(
|
||||
|
||||
@@ -196,8 +196,8 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
|
||||
.iter()
|
||||
.map(|t| Variable::from_reference(t, row_offset))
|
||||
.collect::<Vec<_>>();
|
||||
// Only continue if none of the targets are known.
|
||||
if targets.iter().any(|t| self.is_known(t)) {
|
||||
// Continue if at least one of the targets is unknown.
|
||||
if targets.iter().all(|t| self.is_known(t)) {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
let inputs = prover_function
|
||||
@@ -218,7 +218,10 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
|
||||
}
|
||||
|
||||
let effect = Effect::ProverFunctionCall(ProverFunctionCall {
|
||||
targets,
|
||||
targets: targets
|
||||
.into_iter()
|
||||
.map(|v| (!self.is_known(&v)).then_some(v))
|
||||
.collect(),
|
||||
function_index: prover_function.index,
|
||||
row_offset,
|
||||
inputs,
|
||||
@@ -335,7 +338,7 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
|
||||
inputs,
|
||||
}) => {
|
||||
let mut some_known = false;
|
||||
for t in targets {
|
||||
for t in targets.iter().flatten() {
|
||||
if self.record_known(t.clone()) {
|
||||
some_known = true;
|
||||
updated_variables.push(t.clone());
|
||||
@@ -344,7 +347,13 @@ impl<'a, T: FieldElement, FixedEval: FixedEvaluator<T>> WitgenInference<'a, T, F
|
||||
if some_known {
|
||||
log::trace!(
|
||||
"[{}] := prover_function_{function_index}({row_offset}, {})",
|
||||
targets.iter().format(", "),
|
||||
targets
|
||||
.iter()
|
||||
.map(|v| v
|
||||
.as_ref()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or_else(|| "_".to_string()))
|
||||
.format(", "),
|
||||
inputs.iter().format(", ")
|
||||
);
|
||||
|
||||
|
||||
@@ -40,54 +40,6 @@ machine Arith(byte: Byte, byte2: Byte2) with
|
||||
// Constrain that y2 = 0 when operation is div.
|
||||
array::new(4, |i| is_division * y2[i] = 0);
|
||||
|
||||
// We need to provide hints for the quotient and remainder, because they are not unique under our current constraints.
|
||||
// They are unique given additional main machine constraints, but it's still good to provide hints for the solver.
|
||||
let quotient_hint = query |limb| match(eval(is_division)) {
|
||||
1 => {
|
||||
if x1_int() == 0 {
|
||||
// Quotient is unconstrained, use zero.
|
||||
Query::Hint(0)
|
||||
} else {
|
||||
let y3 = y3_int();
|
||||
let x1 = x1_int();
|
||||
let quotient = y3 / x1;
|
||||
Query::Hint(fe(select_limb(quotient, limb)))
|
||||
}
|
||||
},
|
||||
_ => Query::None
|
||||
};
|
||||
|
||||
col witness y1_0(i) query quotient_hint(0);
|
||||
col witness y1_1(i) query quotient_hint(1);
|
||||
col witness y1_2(i) query quotient_hint(2);
|
||||
col witness y1_3(i) query quotient_hint(3);
|
||||
|
||||
let y1: expr[] = [y1_0, y1_1, y1_2, y1_3];
|
||||
|
||||
let remainder_hint = query |limb| match(eval(is_division)) {
|
||||
1 => {
|
||||
let y3 = y3_int();
|
||||
let x1 = x1_int();
|
||||
if x1 == 0 {
|
||||
// To satisfy x1 * y1 + x2 = y3, we need to set x2 = y3.
|
||||
Query::Hint(fe(select_limb(y3, limb)))
|
||||
} else {
|
||||
let remainder = y3 % x1;
|
||||
Query::Hint(fe(select_limb(remainder, limb)))
|
||||
}
|
||||
},
|
||||
_ => Query::None
|
||||
};
|
||||
|
||||
col witness x2_0(i) query remainder_hint(0);
|
||||
col witness x2_1(i) query remainder_hint(1);
|
||||
col witness x2_2(i) query remainder_hint(2);
|
||||
col witness x2_3(i) query remainder_hint(3);
|
||||
|
||||
let x2: expr[] = [x2_0, x2_1, x2_2, x2_3];
|
||||
|
||||
pol commit x1[4], y2[4], y3[4];
|
||||
|
||||
// Selects the ith limb of x (little endian)
|
||||
// All limbs are 8 bits
|
||||
let select_limb = |x, i| if i >= 0 {
|
||||
@@ -96,13 +48,32 @@ machine Arith(byte: Byte, byte2: Byte2) with
|
||||
0
|
||||
};
|
||||
|
||||
let limbs_to_int: expr[] -> int = query |limbs| array::sum(array::map_enumerated(limbs, |i, limb| int(eval(limb)) << (i * 8)));
|
||||
let limbs_to_int: fe[] -> int = |limbs| array::sum(array::map_enumerated(limbs, |i, limb| int(limb) << (i * 8)));
|
||||
let int_to_limbs: int -> fe[] = |x| array::new(4, |i| fe(select_limb(x, i)));
|
||||
|
||||
let x1_int = query || limbs_to_int(x1);
|
||||
let y1_int = query || limbs_to_int(y1);
|
||||
let x2_int = query || limbs_to_int(x2);
|
||||
let y2_int = query || limbs_to_int(y2);
|
||||
let y3_int = query || limbs_to_int(y3);
|
||||
// We need to provide hints for the quotient and remainder, because they are not unique under our current constraints.
|
||||
// They are unique given additional main machine constraints, but it's still good to provide hints for the solver.
|
||||
query |i| std::prover::compute_from_multi_if(
|
||||
is_division = 1,
|
||||
y1 + x2,
|
||||
i,
|
||||
y3 + x1,
|
||||
|values| {
|
||||
let y3_value = limbs_to_int([values[0], values[1], values[2], values[3]]);
|
||||
let x1_value = limbs_to_int([values[4], values[5], values[6], values[7]]);
|
||||
if x1_value == 0 {
|
||||
// Quotient is unconstrained, use zero for y1
|
||||
// and set remainder x2 = y3.
|
||||
[0, 0, 0, 0] + int_to_limbs(y3_value)
|
||||
} else {
|
||||
let quotient = y3_value / x1_value;
|
||||
let remainder = y3_value % x1_value;
|
||||
int_to_limbs(quotient) + int_to_limbs(remainder)
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
pol commit x1[4], x2[4], y1[4], y2[4], y3[4];
|
||||
|
||||
let combine: expr[] -> expr[] = |x| array::new(array::len(x) / 2, |i| x[2 * i + 1] * 2**8 + x[2 * i]);
|
||||
// Intermediate polynomials, arrays of 16 columns, 16 bit per column.
|
||||
@@ -172,6 +143,6 @@ machine Arith(byte: Byte, byte2: Byte2) with
|
||||
carry * CLK8[0] = 0;
|
||||
|
||||
// Putting everything together
|
||||
col eq0_sum = sum(8, |i| eq0(i) * CLK8[i]);
|
||||
let eq0_sum = sum(8, |i| eq0(i) * CLK8[i]);
|
||||
eq0_sum + carry = carry' * 2**8;
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ machine Arith256 with
|
||||
*
|
||||
*******/
|
||||
|
||||
col eq0_sum = sum(64, |i| eq0(i) * CLK64[i]);
|
||||
let eq0_sum = sum(64, |i| eq0(i) * CLK64[i]);
|
||||
|
||||
eq0_sum + carry = carry' * 2**8;
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ let provide_if_unknown: expr, int, (-> fe) -> () = query |column, row, f| match
|
||||
_ => (),
|
||||
};
|
||||
|
||||
/// Returns true if all the provided columns are unknown.
|
||||
let all_unknown: expr[] -> bool = query |columns|
|
||||
std::array::fold(columns, true, |acc, c| acc && match try_eval(c) {
|
||||
/// Returns true if at least one of the provided columns is unknown.
|
||||
let some_unknown: expr[] -> bool = query |columns|
|
||||
std::array::fold(columns, false, |acc, c| acc || match try_eval(c) {
|
||||
Option::None => true,
|
||||
Option::Some(_) => false,
|
||||
});
|
||||
@@ -33,7 +33,7 @@ let compute_from: expr, int, expr[], (fe[] -> fe) -> () = query |dest_col, row,
|
||||
|
||||
/// Computes the value of multiple columns in a row based on the values of other columns.
|
||||
let compute_from_multi: expr[], int, expr[], (fe[] -> fe[]) -> () = query |dest_cols, row, input_cols, f|
|
||||
if all_unknown(dest_cols) {
|
||||
if some_unknown(dest_cols) {
|
||||
let values = f(std::array::map(input_cols, eval));
|
||||
let _ = std::array::zip(dest_cols, values, |c, v| provide_value(c, row, v));
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user