use rayon as __rayon_reexport; use rayon::prelude::*; use std::io::{stdin, Read}; use std::mem::MaybeUninit; use std::{array, iter}; use tfhe::prelude::*; use tfhe::shortint::parameters::current_params::{ V1_4_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, V1_4_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, }; use tfhe::{set_server_key, ClientKey, CompressedServerKey, ConfigBuilder, Device, FheUint32}; // might improve error message on type error #[doc(hidden)] pub fn __requires_sendable_closure R + Send>(x: F) -> F { x } #[doc(hidden)] macro_rules! __join_implementation { ($len:expr; $($f:ident $r:ident $a:expr),*; $b:expr, $($c:expr,)*) => { $crate::__join_implementation!{$len + 1; $($f $r $a,)* f r $b; $($c,)* } }; ($len:expr; $($f:ident $r:ident $a:expr),* ;) => { match ($(Some($crate::__requires_sendable_closure($a)),)*) { ($(mut $f,)*) => { $(let mut $r = None;)* let array: [&mut (dyn FnMut() + Send); $len] = [ $(&mut || $r = Some((&mut $f).take().unwrap()())),* ]; $crate::__rayon_reexport::iter::ParallelIterator::for_each( $crate::__rayon_reexport::iter::IntoParallelIterator::into_par_iter(array), |f| f(), ); ($($r.unwrap(),)*) } } }; } pub(crate) use __join_implementation; macro_rules! join { ($($($a:expr),+$(,)?)?) => { $crate::__join_implementation!{0;;$($($a,)+)?} }; } // In-House implementation of array_chunk // as the one in stdlib is not stable. pub struct ArrayChunks where T: Iterator, { source: T, } impl ArrayChunks where T: Iterator, { fn new(iterator: T) -> Self { Self { source: iterator } } } impl Iterator for ArrayChunks where T: Iterator, ::Item: Sized, [T::Item; N]: Sized, { type Item = [T::Item; N]; fn next(&mut self) -> Option { // The `assume_init` is // safe because the type we are claiming to have initialized here is a // bunch of `MaybeUninit`s, which do not require initialization. let mut data: [MaybeUninit; N] = unsafe { MaybeUninit::uninit().assume_init() }; // We don't use a loop that has an early return // because we want to avoid potential memory leaks let mut i = 0; for elem in self.source.by_ref() { data[i].write(elem); i += 1; if i == N { break; } } if i == N { // This is not allowed // Some(unsafe { std::mem::transmute(data) }) // https://github.com/rust-lang/rust/issues/61956 assert_eq!( std::mem::size_of::>(), std::mem::size_of::() ); assert_eq!( std::mem::size_of::<[MaybeUninit; N]>(), std::mem::size_of::<[T::Item; N]>() ); let ptr = &mut data as *mut _ as *mut [T::Item; N]; let res = unsafe { ptr.read() }; #[allow(clippy::forget_non_drop)] core::mem::forget(data); Some(res) } else { // For each item in the array, drop if we allocated it. for elem in &mut data[0..i] { unsafe { elem.assume_init_drop(); } } None } } } #[derive(Debug)] struct Args { device: Device, parallel: bool, trivial: bool, multibit: Option, } impl Default for Args { fn default() -> Self { Self { device: Device::Cpu, parallel: false, trivial: false, multibit: None, } } } impl Args { fn from_arg_list(mut program_args: std::env::Args) -> Self { let mut args = Args::default(); let mut had_invalid = false; program_args.next().unwrap(); // This is argv[0], the program name/path while let Some(arg) = program_args.next() { if arg == "--parallel" { args.parallel = true; } else if arg == "--trivial" { args.trivial = true; } else if arg == "--device" { let Some(value) = program_args.next() else { panic!("Expected value after --device"); }; match value.to_lowercase().as_str() { "cpu" => args.device = Device::Cpu, #[cfg(feature = "gpu")] "gpu" | "cuda" => args.device = Device::CudaGpu, #[cfg(not(feature = "gpu"))] "gpu" | "cuda" => { panic!("Needs to be compiled with gpu feature to support gpu") } _ => panic!("Unsupported device {value}"), } } else if arg == "--multibit" { let Some(value) = program_args.next() else { panic!("Expected value after --multibit"); }; args.multibit = Some(value.parse().unwrap()); } else { println!("Unknown argument '{arg}'"); had_invalid = true; } } if had_invalid { panic!("Invalid argument found, aborting"); } args } } fn main() -> Result<(), std::io::Error> { let args = Args::from_arg_list(std::env::args()); println!("Args: {args:?}"); println!("key gen start"); let config = match args.multibit { None => ConfigBuilder::default(), Some(2) => ConfigBuilder::with_custom_parameters( V1_4_PARAM_MULTI_BIT_GROUP_2_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, ), Some(3) => ConfigBuilder::with_custom_parameters( V1_4_PARAM_MULTI_BIT_GROUP_3_MESSAGE_2_CARRY_2_KS_PBS_GAUSSIAN_2M64, ), Some(v) => { panic!("Invalid multibit setting {v}"); } } .build(); let client_key = ClientKey::generate(config); let csks = CompressedServerKey::new(&client_key); match (args.device, args.parallel) { (Device::Cpu, false) => { let server_key = csks.decompress(); set_server_key(server_key); } (Device::Cpu, true) => { let server_key = csks.decompress(); rayon::broadcast(|_| { set_server_key(server_key.clone()); }); set_server_key(server_key); } #[cfg(feature = "gpu")] (Device::CudaGpu, false) => { let server_key = csks.decompress_to_gpu(); set_server_key(server_key); } #[cfg(feature = "gpu")] (Device::CudaGpu, true) => { let server_key = csks.decompress_to_gpu(); rayon::broadcast(|_| { set_server_key(server_key.clone()); }); set_server_key(server_key); } #[cfg(feature = "hpu")] (Device::Hpu, _) => { println!("Hpu is not supported"); std::process::exit(1); } } println!("key gen end"); let mut buf = vec![]; stdin().read_to_end(&mut buf)?; let client_key = if args.trivial { None } else { Some(client_key) }; let encrypted_input = encrypt_data(buf, client_key.as_ref()); let encrypted_hash = if args.parallel { sha256_fhe_parallel(encrypted_input) } else { sha256_fhe(encrypted_input) }; let decrypted_hash = decrypt_hash(encrypted_hash, client_key.as_ref()); println!("{}", hex::encode(decrypted_hash)); Ok(()) } const K: [u32; 64] = [ 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, ]; const INIT: [u32; 8] = [ 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, ]; fn par_rotr(input: &FheUint32, amounts: [u32; N]) -> [FheUint32; N] { let mut result = array::from_fn(|_| input.clone()); // TODO use input.rotate_right(amounts) when tfhe-rs adds it result .par_iter_mut() .zip(amounts.into_par_iter()) .for_each(|(elem, amount)| elem.rotate_right_assign(amount)); result } fn rotr(input: &FheUint32, amounts: [u32; N]) -> [FheUint32; N] { let mut result = array::from_fn(|_| input.clone()); // TODO use input.rotate_right(amounts) when tfhe-rs adds it result .iter_mut() .zip(amounts) .for_each(|(elem, amount)| elem.rotate_right_assign(amount)); result } fn encrypt_data>(input: T, client_key: Option<&ClientKey>) -> Vec { let len = input.as_ref().len(); let remainder = (len + 9) % 64; let bytes_iter = input .as_ref() .iter() .copied() .chain(iter::once(0x80)) .chain(std::iter::repeat_n( 0x00, if remainder == 0 { 0 } else { 64 - remainder }, )) .chain(((len * 8) as u64).to_be_bytes()); ArrayChunks::<_, 4>::new(bytes_iter) .map(|bytes| { if let Some(cks) = client_key { FheUint32::encrypt(u32::from_be_bytes(bytes), cks) } else { FheUint32::encrypt_trivial(u32::from_be_bytes(bytes)) } }) .collect() } fn decrypt_hash(encrypted_hash: [FheUint32; 8], client_key: Option<&ClientKey>) -> [u8; 32] { let mut decrypted_hash = [0u8; 32]; encrypted_hash .iter() .zip(decrypted_hash.chunks_exact_mut(4)) .for_each(|(ciphertext, out_clear)| { let clear: u32 = if let Some(cks) = client_key { ciphertext.decrypt(cks) } else { ciphertext.try_decrypt_trivial().unwrap() }; out_clear.copy_from_slice(&clear.to_be_bytes()); }); decrypted_hash } fn sha256_fhe(input: Vec) -> [FheUint32; 8] { println!("len: {}", input.len()); let k = K.map(|x: u32| FheUint32::encrypt_trivial(x)); let mut hash = INIT.map(|x: u32| FheUint32::encrypt_trivial(x)); let all_ones = FheUint32::encrypt_trivial(0xffffffff_u32); let mut w: [_; 64] = array::from_fn(|_| FheUint32::encrypt_trivial(0_u32)); let len = input.len(); let total_timer = std::time::Instant::now(); println!("Starting main loop"); for (chunk_index, mut chunk) in ArrayChunks::<_, 16>::new(input.into_iter()).enumerate() { let bfr = std::time::Instant::now(); println!("Start chunk: {} / {}", chunk_index + 1, len / 16); w[0..16].swap_with_slice(&mut chunk); for i in 16..64 { let s0 = { let rotations = rotr(&w[i - 15], [7u32, 18]); &rotations[0] ^ &rotations[1] ^ (&w[i - 15] >> 3u32) }; let s1 = { let rotations = rotr(&w[i - 2], [17u32, 19]); &rotations[0] ^ &rotations[1] ^ (&w[i - 2] >> 10u32) }; w[i] = [&w[i - 16], &s0, &w[i - 7], &s1].iter().copied().sum(); } let mut a = hash[0].clone(); let mut b = hash[1].clone(); let mut c = hash[2].clone(); let mut d = hash[3].clone(); let mut e = hash[4].clone(); let mut f = hash[5].clone(); let mut g = hash[6].clone(); let mut h = hash[7].clone(); for i in 0..64 { let s1 = { let rotations = rotr(&e, [6u32, 11, 25]); &rotations[0] ^ &rotations[1] ^ &rotations[2] }; let ch = (&e & &f) ^ ((&e ^ &all_ones) & &g); // let t1 = [&h, &s1, &ch, &k[i], &w[i]].into_iter().sum::(); let t1 = FheUint32::sum([&h, &s1, &ch, &k[i], &w[i]]); let s0 = { let rotations = rotr(&a, [2u32, 13, 22]); &rotations[0] ^ &rotations[1] ^ &rotations[2] }; let maj = (&a & &b) ^ (&a & &c) ^ (&b & &c); let t2 = s0 + maj; h = g; g = f; f = e; e = d + &t1; d = c; c = b; b = a; a = t1 + t2; } hash[0] += a; hash[1] += b; hash[2] += c; hash[3] += d; hash[4] += e; hash[5] += f; hash[6] += g; hash[7] += h; println!("Processed in: {:?}", bfr.elapsed()); } println!("Total time: {:?}", total_timer.elapsed()); hash } fn sha256_fhe_parallel(input: Vec) -> [FheUint32; 8] { let k = K.map(|x: u32| FheUint32::encrypt_trivial(x)); let mut hash = INIT.map(|x: u32| FheUint32::encrypt_trivial(x)); let all_ones = FheUint32::encrypt_trivial(0xffffffff_u32); let mut w: [_; 64] = array::from_fn(|_| FheUint32::encrypt_trivial(0_u32)); let len = input.len(); let total_timer = std::time::Instant::now(); println!("Starting main loop"); for (chunk_index, mut chunk) in ArrayChunks::<_, 16>::new(input.into_iter()).enumerate() { println!("Start chunk: {} / {}", chunk_index + 1, len / 16); let bfr = std::time::Instant::now(); w[0..16].swap_with_slice(&mut chunk); for i in 16..64 { let (s0_a, s0_b, s1_a, s1_b) = join!( || par_rotr(&w[i - 15], [7u32, 18]), || &w[i - 15] >> 3u32, || par_rotr(&w[i - 2], [17u32, 19]), || &w[i - 2] >> 10u32, ); let (s0, s1) = rayon::join(|| &s0_a[0] ^ &s0_a[1] ^ s0_b, || &s1_a[0] ^ &s1_a[1] ^ s1_b); w[i] = [&w[i - 16], &s0, &w[i - 7], &s1].into_iter().sum(); } let mut a = hash[0].clone(); let mut b = hash[1].clone(); let mut c = hash[2].clone(); let mut d = hash[3].clone(); let mut e = hash[4].clone(); let mut f = hash[5].clone(); let mut g = hash[6].clone(); let mut h = hash[7].clone(); for i in 0..64 { // Please clippy let e_rotations = || { let rotations = par_rotr(&e, [6u32, 11, 25]); &rotations[0] ^ &rotations[1] ^ &rotations[2] }; let a_rotations = || { let rotations = par_rotr(&a, [2u32, 13, 22]); &rotations[0] ^ &rotations[1] ^ &rotations[2] }; let (s1, ch, s0, maj) = join!( e_rotations, || (&e & &f) ^ ((&e ^ &all_ones) & &g), a_rotations, || (&a & &b) ^ (&a & &c) ^ (&b & &c) ); let (t1, t2) = rayon::join( || [&h, &s1, &ch, &k[i], &w[i]].into_iter().sum(), || s0 + maj, ); let (d_plus_t1, t1_plus_t2) = rayon::join(|| d + &t1, || &t1 + t2); h = g; g = f; f = e; e = d_plus_t1; d = c; c = b; b = a; a = t1_plus_t2; } let hash2 = [a, b, c, d, e, f, g, h]; hash.par_iter_mut() .zip(hash2.par_iter()) .for_each(|(dest, src)| *dest += src); println!("Processed in: {:?}", bfr.elapsed()); } println!("Total time: {:?}", total_timer.elapsed()); hash }