diff --git a/tfhe-benchmark/benches/high_level_api/erc20.rs b/tfhe-benchmark/benches/high_level_api/erc20.rs index 1ee1136f3..1e8e81b3b 100644 --- a/tfhe-benchmark/benches/high_level_api/erc20.rs +++ b/tfhe-benchmark/benches/high_level_api/erc20.rs @@ -20,24 +20,24 @@ use tfhe::{set_server_key, ClientKey, CompressedServerKey, FheBool, FheUint64}; /// Transfer as written in the original FHEvm white-paper, /// it uses a comparison to check if the sender has enough, -/// and cmuxes based on the comparison result +/// and uses a cmux to compute the actual amount transferred. +/// https://docs.zama.org/protocol/zama-protocol-litepaper#creating-confidential-applications pub fn transfer_whitepaper( from_amount: &FheType, to_amount: &FheType, amount: &FheType, ) -> (FheType, FheType) where - FheType: Add + for<'a> FheOrd<&'a FheType>, + FheType: Add + for<'a> FheOrd<&'a FheType> + FheTrivialEncrypt, FheBool: IfThenElse, for<'a> &'a FheType: Add + Sub, { let has_enough_funds = (from_amount).ge(amount); + let zero_amount = FheType::encrypt_trivial(0u64); + let amount_to_transfer = has_enough_funds.select(amount, &zero_amount); - let mut new_to_amount = to_amount + amount; - new_to_amount = has_enough_funds.if_then_else(&new_to_amount, to_amount); - - let mut new_from_amount = from_amount - amount; - new_from_amount = has_enough_funds.if_then_else(&new_from_amount, from_amount); + let new_to_amount = to_amount + &amount_to_transfer; + let new_from_amount = from_amount - &amount_to_transfer; (new_from_amount, new_to_amount) } @@ -49,21 +49,18 @@ pub fn par_transfer_whitepaper( amount: &FheType, ) -> (FheType, FheType) where - FheType: Add + for<'a> FheOrd<&'a FheType> + Send + Sync, + FheType: + Add + for<'a> FheOrd<&'a FheType> + Send + Sync + FheTrivialEncrypt, FheBool: IfThenElse, for<'a> &'a FheType: Add + Sub, { let has_enough_funds = (from_amount).ge(amount); + let zero_amount = FheType::encrypt_trivial(0u64); + let amount_to_transfer = has_enough_funds.select(amount, &zero_amount); let (new_to_amount, new_from_amount) = rayon::join( - || { - let new_to_amount = to_amount + amount; - has_enough_funds.if_then_else(&new_to_amount, to_amount) - }, - || { - let new_from_amount = from_amount - amount; - has_enough_funds.if_then_else(&new_from_amount, from_amount) - }, + || to_amount + &amount_to_transfer, + || from_amount - &amount_to_transfer, ); (new_from_amount, new_to_amount)