mirror of
https://github.com/zama-ai/tfhe-rs.git
synced 2026-01-09 22:57:59 -05:00
fix(gpu): fix regression on ERC20 throughput
- partially revert changes done in fd79c4f972
- transfers for the GPU case should be measured using sequential
operations (without rayon!)
This commit is contained in:
@@ -71,7 +71,7 @@ where
|
|||||||
|
|
||||||
/// This one also uses a comparison, but it leverages the 'boolean' multiplication
|
/// This one also uses a comparison, but it leverages the 'boolean' multiplication
|
||||||
/// instead of cmuxes, so it is faster
|
/// instead of cmuxes, so it is faster
|
||||||
#[cfg(not(feature = "hpu"))]
|
#[cfg(all(feature = "gpu", not(feature = "hpu")))]
|
||||||
fn transfer_no_cmux<FheType>(
|
fn transfer_no_cmux<FheType>(
|
||||||
from_amount: &FheType,
|
from_amount: &FheType,
|
||||||
to_amount: &FheType,
|
to_amount: &FheType,
|
||||||
@@ -87,6 +87,29 @@ where
|
|||||||
|
|
||||||
let amount = amount * FheType::cast_from(has_enough_funds);
|
let amount = amount * FheType::cast_from(has_enough_funds);
|
||||||
|
|
||||||
|
let new_to_amount = to_amount + &amount;
|
||||||
|
let new_from_amount = from_amount - &amount;
|
||||||
|
|
||||||
|
(new_from_amount, new_to_amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parallel variant of [`transfer_no_cmux`].
|
||||||
|
#[cfg(not(feature = "hpu"))]
|
||||||
|
fn par_transfer_no_cmux<FheType>(
|
||||||
|
from_amount: &FheType,
|
||||||
|
to_amount: &FheType,
|
||||||
|
amount: &FheType,
|
||||||
|
) -> (FheType, FheType)
|
||||||
|
where
|
||||||
|
FheType: Add<Output = FheType> + CastFrom<FheBool> + for<'a> FheOrd<&'a FheType> + Send + Sync,
|
||||||
|
FheBool: IfThenElse<FheType>,
|
||||||
|
for<'a> &'a FheType:
|
||||||
|
Add<Output = FheType> + Sub<Output = FheType> + Mul<FheType, Output = FheType>,
|
||||||
|
{
|
||||||
|
let has_enough_funds = (from_amount).ge(amount);
|
||||||
|
|
||||||
|
let amount = amount * FheType::cast_from(has_enough_funds);
|
||||||
|
|
||||||
let (new_to_amount, new_from_amount) =
|
let (new_to_amount, new_from_amount) =
|
||||||
rayon::join(|| to_amount + &amount, || from_amount - &amount);
|
rayon::join(|| to_amount + &amount, || from_amount - &amount);
|
||||||
|
|
||||||
@@ -95,12 +118,36 @@ where
|
|||||||
|
|
||||||
/// This one uses overflowing sub to remove the need for comparison
|
/// This one uses overflowing sub to remove the need for comparison
|
||||||
/// it also uses the 'boolean' multiplication
|
/// it also uses the 'boolean' multiplication
|
||||||
#[cfg(not(feature = "hpu"))]
|
#[cfg(all(feature = "gpu", not(feature = "hpu")))]
|
||||||
fn transfer_overflow<FheType>(
|
fn transfer_overflow<FheType>(
|
||||||
from_amount: &FheType,
|
from_amount: &FheType,
|
||||||
to_amount: &FheType,
|
to_amount: &FheType,
|
||||||
amount: &FheType,
|
amount: &FheType,
|
||||||
) -> (FheType, FheType)
|
) -> (FheType, FheType)
|
||||||
|
where
|
||||||
|
FheType: CastFrom<FheBool> + for<'a> FheOrd<&'a FheType> + Send + Sync,
|
||||||
|
FheBool: IfThenElse<FheType>,
|
||||||
|
for<'a> &'a FheType: Add<FheType, Output = FheType>
|
||||||
|
+ OverflowingSub<&'a FheType, Output = FheType>
|
||||||
|
+ Mul<FheType, Output = FheType>,
|
||||||
|
{
|
||||||
|
let (new_from, did_not_have_enough) = (from_amount).overflowing_sub(amount);
|
||||||
|
|
||||||
|
let new_from_amount = did_not_have_enough.if_then_else(from_amount, &new_from);
|
||||||
|
|
||||||
|
let had_enough_funds = !did_not_have_enough;
|
||||||
|
let new_to_amount = to_amount + (amount * FheType::cast_from(had_enough_funds));
|
||||||
|
|
||||||
|
(new_from_amount, new_to_amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parallel variant of [`transfer_overflow`].
|
||||||
|
#[cfg(not(feature = "hpu"))]
|
||||||
|
fn par_transfer_overflow<FheType>(
|
||||||
|
from_amount: &FheType,
|
||||||
|
to_amount: &FheType,
|
||||||
|
amount: &FheType,
|
||||||
|
) -> (FheType, FheType)
|
||||||
where
|
where
|
||||||
FheType: CastFrom<FheBool> + for<'a> FheOrd<&'a FheType> + Send + Sync,
|
FheType: CastFrom<FheBool> + for<'a> FheOrd<&'a FheType> + Send + Sync,
|
||||||
FheBool: IfThenElse<FheType>,
|
FheBool: IfThenElse<FheType>,
|
||||||
@@ -122,12 +169,36 @@ where
|
|||||||
|
|
||||||
/// This one uses both overflowing_add/sub to check that both
|
/// This one uses both overflowing_add/sub to check that both
|
||||||
/// the sender has enough funds, and the receiver will not overflow its balance
|
/// the sender has enough funds, and the receiver will not overflow its balance
|
||||||
#[cfg(not(feature = "hpu"))]
|
#[cfg(all(feature = "gpu", not(feature = "hpu")))]
|
||||||
fn transfer_safe<FheType>(
|
fn transfer_safe<FheType>(
|
||||||
from_amount: &FheType,
|
from_amount: &FheType,
|
||||||
to_amount: &FheType,
|
to_amount: &FheType,
|
||||||
amount: &FheType,
|
amount: &FheType,
|
||||||
) -> (FheType, FheType)
|
) -> (FheType, FheType)
|
||||||
|
where
|
||||||
|
FheType: Send + Sync,
|
||||||
|
for<'a> &'a FheType: OverflowingSub<&'a FheType, Output = FheType>
|
||||||
|
+ OverflowingAdd<&'a FheType, Output = FheType>,
|
||||||
|
FheBool: IfThenElse<FheType>,
|
||||||
|
{
|
||||||
|
let (new_from, did_not_have_enough_funds) = (from_amount).overflowing_sub(amount);
|
||||||
|
let (new_to, did_not_have_enough_space) = (to_amount).overflowing_add(amount);
|
||||||
|
|
||||||
|
let something_not_ok = did_not_have_enough_funds | did_not_have_enough_space;
|
||||||
|
|
||||||
|
let new_from_amount = something_not_ok.if_then_else(from_amount, &new_from);
|
||||||
|
let new_to_amount = something_not_ok.if_then_else(to_amount, &new_to);
|
||||||
|
|
||||||
|
(new_from_amount, new_to_amount)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parallel variant of [`transfer_safe`].
|
||||||
|
#[cfg(not(feature = "hpu"))]
|
||||||
|
fn par_transfer_safe<FheType>(
|
||||||
|
from_amount: &FheType,
|
||||||
|
to_amount: &FheType,
|
||||||
|
amount: &FheType,
|
||||||
|
) -> (FheType, FheType)
|
||||||
where
|
where
|
||||||
FheType: Send + Sync,
|
FheType: Send + Sync,
|
||||||
for<'a> &'a FheType: OverflowingSub<&'a FheType, Output = FheType>
|
for<'a> &'a FheType: OverflowingSub<&'a FheType, Output = FheType>
|
||||||
@@ -358,71 +429,69 @@ fn cuda_bench_transfer_throughput<FheType, F>(
|
|||||||
.map(|i| compressed_server_key.decompress_to_specific_gpu(GpuIndex::new(i as u32)))
|
.map(|i| compressed_server_key.decompress_to_specific_gpu(GpuIndex::new(i as u32)))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
for num_elems in [10 * num_gpus, 100 * num_gpus, 500 * num_gpus] {
|
// 200 * num_gpus seems to be enough for maximum throughput on 8xH100 SXM5
|
||||||
group.throughput(Throughput::Elements(num_elems));
|
let num_elems = 200 * num_gpus;
|
||||||
let bench_id =
|
|
||||||
format!("{bench_name}::throughput::{fn_name}::{type_name}::{num_elems}_elems");
|
|
||||||
group.bench_with_input(&bench_id, &num_elems, |b, &num_elems| {
|
|
||||||
let from_amounts = (0..num_elems)
|
|
||||||
.map(|_| FheType::encrypt(rng.gen::<u64>(), client_key))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let to_amounts = (0..num_elems)
|
|
||||||
.map(|_| FheType::encrypt(rng.gen::<u64>(), client_key))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
let amounts = (0..num_elems)
|
|
||||||
.map(|_| FheType::encrypt(rng.gen::<u64>(), client_key))
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let num_streams_per_gpu = 8; // Hard coded stream value for FheUint64
|
group.throughput(Throughput::Elements(num_elems));
|
||||||
let chunk_size = (num_elems / num_gpus) as usize;
|
let bench_id = format!("{bench_name}::throughput::{fn_name}::{type_name}::{num_elems}_elems");
|
||||||
|
group.bench_with_input(&bench_id, &num_elems, |b, &num_elems| {
|
||||||
|
let from_amounts = (0..num_elems)
|
||||||
|
.map(|_| FheType::encrypt(rng.gen::<u64>(), client_key))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let to_amounts = (0..num_elems)
|
||||||
|
.map(|_| FheType::encrypt(rng.gen::<u64>(), client_key))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let amounts = (0..num_elems)
|
||||||
|
.map(|_| FheType::encrypt(rng.gen::<u64>(), client_key))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
b.iter(|| {
|
let num_streams_per_gpu = 8; // Hard coded stream value for FheUint64
|
||||||
from_amounts
|
let chunk_size = (num_elems / num_gpus) as usize;
|
||||||
.par_chunks(chunk_size) // Split into chunks of num_gpus
|
|
||||||
.zip(
|
b.iter(|| {
|
||||||
to_amounts
|
from_amounts
|
||||||
.par_chunks(chunk_size)
|
.par_chunks(chunk_size) // Split into chunks of num_gpus
|
||||||
.zip(amounts.par_chunks(chunk_size)),
|
.zip(
|
||||||
) // Zip with the other data
|
to_amounts
|
||||||
.enumerate() // Get the index for GPU
|
.par_chunks(chunk_size)
|
||||||
.for_each(
|
.zip(amounts.par_chunks(chunk_size)),
|
||||||
|(i, (from_amount_gpu_i, (to_amount_gpu_i, amount_gpu_i)))| {
|
) // Zip with the other data
|
||||||
// Process chunks within each GPU
|
.enumerate() // Get the index for GPU
|
||||||
let stream_chunk_size = from_amount_gpu_i.len() / num_streams_per_gpu;
|
.for_each(
|
||||||
from_amount_gpu_i
|
|(i, (from_amount_gpu_i, (to_amount_gpu_i, amount_gpu_i)))| {
|
||||||
.par_chunks(stream_chunk_size)
|
// Process chunks within each GPU
|
||||||
.zip(to_amount_gpu_i.par_chunks(stream_chunk_size))
|
let stream_chunk_size = from_amount_gpu_i.len() / num_streams_per_gpu;
|
||||||
.zip(amount_gpu_i.par_chunks(stream_chunk_size))
|
from_amount_gpu_i
|
||||||
.for_each(
|
.par_chunks(stream_chunk_size)
|
||||||
|((from_amount_chunk, to_amount_chunk), amount_chunk)| {
|
.zip(to_amount_gpu_i.par_chunks(stream_chunk_size))
|
||||||
// Set the server key for the current GPU
|
.zip(amount_gpu_i.par_chunks(stream_chunk_size))
|
||||||
set_server_key(sks_vec[i].clone());
|
.for_each(|((from_amount_chunk, to_amount_chunk), amount_chunk)| {
|
||||||
// Sequential iteration over the elements of each chunk
|
// Set the server key for the current GPU
|
||||||
from_amount_chunk
|
set_server_key(sks_vec[i].clone());
|
||||||
.iter()
|
// Sequential iteration over the elements of each chunk
|
||||||
.zip(to_amount_chunk.iter().zip(amount_chunk.iter()))
|
from_amount_chunk
|
||||||
.for_each(|(from_amount, (to_amount, amount))| {
|
.iter()
|
||||||
transfer_func(from_amount, to_amount, amount);
|
.zip(to_amount_chunk.iter().zip(amount_chunk.iter()))
|
||||||
});
|
.for_each(|(from_amount, (to_amount, amount))| {
|
||||||
},
|
transfer_func(from_amount, to_amount, amount);
|
||||||
);
|
});
|
||||||
},
|
});
|
||||||
);
|
},
|
||||||
});
|
);
|
||||||
});
|
});
|
||||||
|
});
|
||||||
|
|
||||||
let params = client_key.computation_parameters();
|
let params = client_key.computation_parameters();
|
||||||
|
|
||||||
write_to_json::<u64, _>(
|
write_to_json::<u64, _>(
|
||||||
&bench_id,
|
&bench_id,
|
||||||
params,
|
params,
|
||||||
params.name(),
|
params.name(),
|
||||||
"erc20-transfer",
|
"erc20-transfer",
|
||||||
&OperatorType::Atomic,
|
&OperatorType::Atomic,
|
||||||
64,
|
64,
|
||||||
vec![],
|
vec![],
|
||||||
);
|
);
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "hpu")]
|
#[cfg(feature = "hpu")]
|
||||||
@@ -517,14 +586,19 @@ fn main() {
|
|||||||
"transfer::whitepaper",
|
"transfer::whitepaper",
|
||||||
par_transfer_whitepaper::<FheUint64>,
|
par_transfer_whitepaper::<FheUint64>,
|
||||||
);
|
);
|
||||||
print_transfer_pbs_counts(&cks, "FheUint64", "no_cmux", transfer_no_cmux::<FheUint64>);
|
print_transfer_pbs_counts(
|
||||||
|
&cks,
|
||||||
|
"FheUint64",
|
||||||
|
"no_cmux",
|
||||||
|
par_transfer_no_cmux::<FheUint64>,
|
||||||
|
);
|
||||||
print_transfer_pbs_counts(
|
print_transfer_pbs_counts(
|
||||||
&cks,
|
&cks,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::overflow",
|
"transfer::overflow",
|
||||||
transfer_overflow::<FheUint64>,
|
par_transfer_overflow::<FheUint64>,
|
||||||
);
|
);
|
||||||
print_transfer_pbs_counts(&cks, "FheUint64", "safe", transfer_safe::<FheUint64>);
|
print_transfer_pbs_counts(&cks, "FheUint64", "safe", par_transfer_safe::<FheUint64>);
|
||||||
}
|
}
|
||||||
|
|
||||||
// FheUint64 latency
|
// FheUint64 latency
|
||||||
@@ -544,7 +618,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::no_cmux",
|
"transfer::no_cmux",
|
||||||
transfer_no_cmux::<FheUint64>,
|
par_transfer_no_cmux::<FheUint64>,
|
||||||
);
|
);
|
||||||
bench_transfer_latency(
|
bench_transfer_latency(
|
||||||
&mut group,
|
&mut group,
|
||||||
@@ -552,7 +626,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::overflow",
|
"transfer::overflow",
|
||||||
transfer_overflow::<FheUint64>,
|
par_transfer_overflow::<FheUint64>,
|
||||||
);
|
);
|
||||||
bench_transfer_latency(
|
bench_transfer_latency(
|
||||||
&mut group,
|
&mut group,
|
||||||
@@ -560,7 +634,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::safe",
|
"transfer::safe",
|
||||||
transfer_safe::<FheUint64>,
|
par_transfer_safe::<FheUint64>,
|
||||||
);
|
);
|
||||||
|
|
||||||
group.finish();
|
group.finish();
|
||||||
@@ -583,7 +657,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::no_cmux",
|
"transfer::no_cmux",
|
||||||
transfer_no_cmux::<FheUint64>,
|
par_transfer_no_cmux::<FheUint64>,
|
||||||
);
|
);
|
||||||
bench_transfer_throughput(
|
bench_transfer_throughput(
|
||||||
&mut group,
|
&mut group,
|
||||||
@@ -591,7 +665,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::overflow",
|
"transfer::overflow",
|
||||||
transfer_overflow::<FheUint64>,
|
par_transfer_overflow::<FheUint64>,
|
||||||
);
|
);
|
||||||
bench_transfer_throughput(
|
bench_transfer_throughput(
|
||||||
&mut group,
|
&mut group,
|
||||||
@@ -599,7 +673,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::safe",
|
"transfer::safe",
|
||||||
transfer_safe::<FheUint64>,
|
par_transfer_safe::<FheUint64>,
|
||||||
);
|
);
|
||||||
|
|
||||||
group.finish();
|
group.finish();
|
||||||
@@ -631,14 +705,19 @@ fn main() {
|
|||||||
"transfer::whitepaper",
|
"transfer::whitepaper",
|
||||||
par_transfer_whitepaper::<FheUint64>,
|
par_transfer_whitepaper::<FheUint64>,
|
||||||
);
|
);
|
||||||
print_transfer_pbs_counts(&cks, "FheUint64", "no_cmux", transfer_no_cmux::<FheUint64>);
|
print_transfer_pbs_counts(
|
||||||
|
&cks,
|
||||||
|
"FheUint64",
|
||||||
|
"no_cmux",
|
||||||
|
par_transfer_no_cmux::<FheUint64>,
|
||||||
|
);
|
||||||
print_transfer_pbs_counts(
|
print_transfer_pbs_counts(
|
||||||
&cks,
|
&cks,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::overflow",
|
"transfer::overflow",
|
||||||
transfer_overflow::<FheUint64>,
|
par_transfer_overflow::<FheUint64>,
|
||||||
);
|
);
|
||||||
print_transfer_pbs_counts(&cks, "FheUint64", "safe", transfer_safe::<FheUint64>);
|
print_transfer_pbs_counts(&cks, "FheUint64", "safe", par_transfer_safe::<FheUint64>);
|
||||||
}
|
}
|
||||||
|
|
||||||
// FheUint64 latency
|
// FheUint64 latency
|
||||||
@@ -658,7 +737,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::no_cmux",
|
"transfer::no_cmux",
|
||||||
transfer_no_cmux::<FheUint64>,
|
par_transfer_no_cmux::<FheUint64>,
|
||||||
);
|
);
|
||||||
bench_transfer_latency(
|
bench_transfer_latency(
|
||||||
&mut group,
|
&mut group,
|
||||||
@@ -666,7 +745,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::overflow",
|
"transfer::overflow",
|
||||||
transfer_overflow::<FheUint64>,
|
par_transfer_overflow::<FheUint64>,
|
||||||
);
|
);
|
||||||
bench_transfer_latency(
|
bench_transfer_latency(
|
||||||
&mut group,
|
&mut group,
|
||||||
@@ -674,7 +753,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::safe",
|
"transfer::safe",
|
||||||
transfer_safe::<FheUint64>,
|
par_transfer_safe::<FheUint64>,
|
||||||
);
|
);
|
||||||
|
|
||||||
group.finish();
|
group.finish();
|
||||||
@@ -689,7 +768,7 @@ fn main() {
|
|||||||
bench_name,
|
bench_name,
|
||||||
"FheUint64",
|
"FheUint64",
|
||||||
"transfer::whitepaper",
|
"transfer::whitepaper",
|
||||||
par_transfer_whitepaper::<FheUint64>,
|
transfer_whitepaper::<FheUint64>,
|
||||||
);
|
);
|
||||||
cuda_bench_transfer_throughput(
|
cuda_bench_transfer_throughput(
|
||||||
&mut group,
|
&mut group,
|
||||||
|
|||||||
Reference in New Issue
Block a user