diff --git a/src/ed25519_sha512.rs b/src/ed25519_sha512.rs index 3e6a459..f0dc1c9 100644 --- a/src/ed25519_sha512.rs +++ b/src/ed25519_sha512.rs @@ -4,6 +4,7 @@ use crate::sha512::Sha512; use crate::field::{Field}; use num_bigint::BigUint; use core::ops::Sub; +use num_traits::Zero; // implementation based on https://ed25519.cr.yp.to/ed25519-20110926.pdf @@ -34,35 +35,36 @@ pub fn gen_priv_key(k: &[u8; 32]) -> KeyPair { a[31] &= 0b1111_1000; // clear least significant 3 bits // q = 2^255 - 19 - let q = (BigUint::from(2u8).pow(255u32)).sub(19u8); + let q = BigUint::from(2u8).pow(255u32).sub(19u8); let F_q = Field::new(&q); - let q = F_q.elem(&q); - println!("Q={:?}", q); // base point is (x, 4/5) w/ positive x - let bp_y = F_q.elem(&5u8) * &4u8; -println!("1"); + let bp_y = F_q.elem(&4u8) / &5u8; + println!("Base point y={:?}", bp_y.n); + // d = -121665 / 121666 let d = -F_q.elem(&121665u32) / &121666u32; -println!("2"); + println!("d={}", d.n); // xx = x^2 = (y^2 - 1) / (1 + d*y^2) - let xx = (&bp_y * &bp_y - &1u8) / &1u8 + &(&d * &bp_y.sq()); -println!("3"); + let xx = (&bp_y.sq() - &1u8) / &(&(&d * &bp_y.sq()) + &1u8); + println!("xx={}", xx.n); - //let I = F_q.elem(&2u8).pow(&(&q - &1u8)) / &4u8; - let I = F_q.elem(&2u8); //.pow(&(&q - &1u8)); // / &4u8; -println!("3.1"); - let i3 = &q - &1u8; -println!("3.2"); - let i2 = I.pow(&i3); + // calculate the square root of xx assuming a^((p-1)/4) = 1 mod q + let mut bp_x = (&xx).pow(&((&q + &3u8) / &8u8)); -println!("4 = {:?}", i2); - let mut x = &xx.pow(&(&q + &3u8)) / &1u8; - if ((&x * &x) - &xx).n != BigUint::from(0u8) { // if x is not the solution, multiply I - x = x * &I; + // if that that's match, calculate the square root of xx again assuming a^((p-1)/4) = -1 mod q + if &bp_x.sq().n != &xx.n { + let I = F_q.elem(&2u8).pow(&((&q - &1u8) / &4u8)); + bp_x = &bp_x * &I; } - println!("Base point x={:?}", x); + // if bp_x is odd number, it's representing the negative x coordinate. + // in such a case, since base point x is positive, the value needs to be negated + if !(&bp_x.n % 2u8).is_zero() { + bp_x = -&bp_x; + } + println!("Base point x={:?}", bp_x.n); + // x should be positive // if least significant bit of x is 1, convert it to positive by // x = q - x