Files
powdr/std/protocols/permutation.asm
Georg Wiese a917e4f35a Add permutation to PIL STD (#1297)
First part of #1296 

This PR adds a `permutation()` function to the standard library. The
code is inspired by the `permutation_via_challenges` test (now removed),
[this comment](https://github.com/powdr-labs/powdr/issues/424#issuecomment-1931686047)
by @chriseth, and [this Halo2
implementation](https://github.com/privacy-scaling-explorations/halo2/blob/main/halo2_proofs/examples/shuffle.rs).
2024-04-25 07:32:35 +00:00

42 lines
1.7 KiB
Rust

use std::prover::challenge;
use std::array::fold;
use std::array::len;
use std::convert::int;
use std::check::assert;
use std::field::modulus;
// Fixed column that is 1 in the first row (row 0) and 0 in every other row.
let is_first: col = |row| if row != 0 { 0 } else { 1 };
// Get two phase-2 challenges to use in all permutation arguments.
// Note that this assumes that globally no other challenge of these IDs is used.
// NOTE(review): the first argument of `challenge` (0 here) presumably selects the stage
// after which the challenge is drawn; confirm against `std::prover::challenge`.
// `alpha` is used to compress a tuple of expressions into a single expression.
let alpha: expr = challenge(0, 1);
// `beta` is the shift used in each row factor `beta - compressed_tuple` of the grand product.
let beta: expr = challenge(0, 2);
// Compresses an array of expressions into a single expression using the challenge `alpha`,
// evaluated Horner-style: [x_1, x_2, ..., x_n] becomes
// alpha**(n - 1) * x_1 + alpha**(n - 2) * x_2 + ... + alpha * x_(n-1) + x_n.
// An empty array compresses to 0.
let compress_expression_array = |expr_array| fold(
    expr_array,
    0,
    |compressed, next_expr| alpha * compressed + next_expr
);
// Adds constraints that enforce that the selected rows of `lhs` are a permutation
// of the selected rows of `rhs`.
// Arguments:
// - acc: A phase-2 witness column to be used as the accumulator
// - lhs_selector: (assumed to be) binary selector to check which elements from the LHS to include
// - lhs: An array of expressions
// - rhs_selector: (assumed to be) binary selector to check which elements from the RHS to include
// - rhs: An array of expressions
// Returns: an array of two constraints (initial value of `acc` and its update rule).
//
// Each row contributes a factor (beta - compressed tuple) if it is selected and a factor of 1
// otherwise. Unselected rows must not contribute `beta - 0 = beta`: that would make them
// indistinguishable from a selected tuple that compresses to 0 (e.g. an all-zero tuple).
// A selected zero tuple on one side could then be "matched" by an unselected row on the other.
let permutation = |acc, lhs_selector, lhs, rhs_selector, rhs| {
    let _ = assert(len(lhs) == len(rhs), || "LHS and RHS should have equal length");
    let _ = assert(modulus() > 2**100, || "This implementation assumes a large field");
    // selector * (x - 1) + 1 evaluates to x when selector = 1 and to 1 when selector = 0.
    let lhs_factor = lhs_selector * (beta - compress_expression_array(lhs) - 1) + 1;
    let rhs_factor = rhs_selector * (beta - compress_expression_array(rhs) - 1) + 1;
    [
        // The accumulator starts at 1 in the first row. Because of wrapping, acc' in the
        // last row refers to the first row. So the accumulated value after the last row
        // must also be 1, i.e. the LHS and RHS products are equal.
        is_first * (acc - 1) = 0,
        // Update rule:
        // acc' = acc * lhs_factor / rhs_factor
        rhs_factor * acc' = acc * lhs_factor
    ]
};