feat: only allow two inputs

This commit is contained in:
Daniel Tehrani
2023-05-05 12:48:59 +02:00
parent ebf2b5f6cc
commit 4d72a30b63
3 changed files with 29 additions and 41 deletions

View File

@@ -27,14 +27,14 @@ impl<F: PrimeField> PoseidonConstants<F> {
}
pub struct Poseidon<F: PrimeField> {
pub state: Vec<F>,
pub state: [F; 3],
pub constants: PoseidonConstants<F>,
pub pos: usize,
}
impl<F: PrimeField> Poseidon<F> {
pub fn new(constants: PoseidonConstants<F>) -> Self {
let state = vec![F::zero(); 3];
let state = [F::zero(); 3];
Self {
state,
constants,
@@ -42,12 +42,10 @@ impl<F: PrimeField> Poseidon<F> {
}
}
pub fn hash(&mut self, input: Vec<F>) -> F {
// add padding
let mut input = input.clone();
let domain_tag = 3; // 2^arity - 1
input.insert(0, F::from(domain_tag));
pub fn hash(&mut self, input: &[F; 2]) -> F {
// add the domain tag
let domain_tag = F::from(3); // 2^arity - 1
let input = [domain_tag, input[0], input[1]];
self.state = input;
@@ -80,14 +78,14 @@ impl<F: PrimeField> Poseidon<F> {
// MDS matrix multiplication
fn matrix_mul(&mut self) {
let mut result = Vec::new();
let mut result = [F::zero(); 3];
for val in self.constants.mds_matrix.iter() {
for (i, val) in self.constants.mds_matrix.iter().enumerate() {
let mut tmp = F::zero();
for (j, element) in self.state.iter().enumerate() {
tmp += val[j] * element
}
result.push(tmp)
result[i] = tmp;
}
self.state = result;
@@ -124,40 +122,30 @@ impl<F: PrimeField> Poseidon<F> {
#[cfg(test)]
mod tests {
use super::*;
use k256_consts::*;
use secq256k1::field::{field_secp, BaseField};
#[test]
fn test_k256() {
type Scalar = field_secp::FieldElement;
let input = vec![
let input = [
Scalar::from_str_vartime("1234567").unwrap(),
Scalar::from_str_vartime("109987").unwrap(),
];
let round_constants: Vec<Scalar> = k256_consts::ROUND_CONSTANTS
.iter()
.map(|x| Scalar::from_str_vartime(x).unwrap())
.collect();
let mds_matrix: Vec<Vec<Scalar>> = k256_consts::MDS_MATRIX
.iter()
.map(|x| {
x.iter()
.map(|y| Scalar::from_str_vartime(y).unwrap())
.collect::<Vec<Scalar>>()
})
.collect();
let constants = PoseidonConstants::<Scalar>::new(
round_constants,
mds_matrix,
k256_consts::NUM_FULL_ROUNDS,
k256_consts::NUM_PARTIAL_ROUNDS,
let constants = PoseidonConstants::<FieldElement>::new(
ROUND_CONSTANTS.to_vec(),
vec![
MDS_MATRIX[0].to_vec(),
MDS_MATRIX[1].to_vec(),
MDS_MATRIX[2].to_vec(),
],
NUM_FULL_ROUNDS,
NUM_PARTIAL_ROUNDS,
);
let mut poseidon = Poseidon::new(constants);
let digest = poseidon.hash(input);
let digest = poseidon.hash(&input);
assert_eq!(
digest,

View File

@@ -3,7 +3,7 @@ use ff::PrimeField;
pub use secq256k1::field::field_secp::FieldElement;
#[allow(dead_code)]
pub fn hash(input: Vec<FieldElement>) -> FieldElement {
pub fn hash(input: &[FieldElement; 2]) -> FieldElement {
let round_constants: Vec<FieldElement> = k256_consts::ROUND_CONSTANTS
.iter()
.map(|x| FieldElement::from_str_vartime(x).unwrap())

View File

@@ -84,14 +84,14 @@ pub fn verify(circuit: &[u8], proof: &[u8], public_input: &[u8]) -> Result<bool,
#[wasm_bindgen]
pub fn poseidon(input_bytes: &[u8]) -> Result<Vec<u8>, JsValue> {
let mut input = Vec::new();
for i in 0..(input_bytes.len() / 32) {
let f: [u8; 32] = input_bytes[(i * 32)..(i + 1) * 32].try_into().unwrap();
let val = FieldElement::from_bytes(&f).unwrap();
input.push(FieldElement::from(val));
}
assert_eq!(input_bytes.len(), 64);
let result = hash(input);
let input = [
FieldElement::from_bytes(&input_bytes[0..32].try_into().unwrap()).unwrap(),
FieldElement::from_bytes(&input_bytes[32..64].try_into().unwrap()).unwrap(),
];
let result = hash(&input);
Ok(result.to_bytes().to_vec())
}