accumulate stwo (#535)

Adds in-place vector addition and exposes it in the API as `accumulate`.
This commit is contained in:
VitaliiH
2024-06-10 12:24:58 +02:00
committed by GitHub
parent 9c55d888ae
commit e19a869691
15 changed files with 127 additions and 14 deletions

View File

@@ -83,6 +83,12 @@ pub trait VecOps<F> {
cfg: &VecOpsConfig,
) -> IcicleResult<()>;
/// In-place vector addition: adds `b` into `a`, storing the result in `a`.
///
/// `a` is taken mutably because it is both an input and the output; `b` is
/// only read. `cfg` is handed to the implementation unchanged. The returned
/// `IcicleResult` reports whether the implementation succeeded.
fn accumulate(
a: &mut (impl HostOrDeviceSlice<F> + ?Sized),
b: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &VecOpsConfig,
) -> IcicleResult<()>;
fn sub(
a: &(impl HostOrDeviceSlice<F> + ?Sized),
b: &(impl HostOrDeviceSlice<F> + ?Sized),
@@ -207,6 +213,19 @@ where
<<F as FieldImpl>::Config as VecOps<F>>::add(a, b, result, &cfg)
}
/// Accumulates `b` into `a` in place, i.e. `a` receives the sum of `a` and `b`.
///
/// The arguments are first passed through `check_vec_ops_args`, with `a`
/// supplied both as the first operand and as the result slice because the
/// output overwrites it. The configuration returned by that check is the one
/// forwarded to the field-specific [`VecOps::accumulate`] implementation.
///
/// # Errors
///
/// Returns whatever error the underlying `VecOps::accumulate` call reports.
pub fn accumulate_scalars<F>(
    a: &mut (impl HostOrDeviceSlice<F> + ?Sized),
    b: &(impl HostOrDeviceSlice<F> + ?Sized),
    cfg: &VecOpsConfig,
) -> IcicleResult<()>
where
    F: FieldImpl,
    F::Config: VecOps<F>,
{
    // `a` plays two roles in the check: left-hand operand and destination.
    let checked_cfg = check_vec_ops_args(a, b, a, cfg);
    <F::Config as VecOps<F>>::accumulate(a, b, &checked_cfg)
}
pub fn sub_scalars<F>(
a: &(impl HostOrDeviceSlice<F> + ?Sized),
b: &(impl HostOrDeviceSlice<F> + ?Sized),
@@ -299,6 +318,14 @@ macro_rules! impl_vec_ops_field {
result: *mut $field,
) -> CudaError;
/// Foreign declaration for the in-place vector addition, bound to the symbol
/// `<field_prefix>_accumulate_cuda`.
///
/// `a` is a mutable pointer (input and output), `b` is a read-only pointer,
/// and `size` is the element count; the Rust wrapper in this file passes
/// `a.len()` for it. `cfg` points to the `VecOpsConfig` to use. The returned
/// `CudaError` is converted into an `IcicleResult` by the caller via `.wrap()`.
#[link_name = concat!($field_prefix, "_accumulate_cuda")]
pub(crate) fn accumulate_scalars_cuda(
a: *mut $field,
b: *const $field,
size: u32,
cfg: *const VecOpsConfig,
) -> CudaError;
#[link_name = concat!($field_prefix, "_sub_cuda")]
pub(crate) fn sub_scalars_cuda(
a: *const $field,
@@ -357,6 +384,22 @@ macro_rules! impl_vec_ops_field {
}
}
/// In-place vector addition for this field: forwards to the foreign
/// `accumulate_scalars_cuda` function, which writes its result into `a`.
///
/// The foreign call's `CudaError` is converted to `IcicleResult<()>` with
/// `.wrap()`.
fn accumulate(
a: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
b: &(impl HostOrDeviceSlice<$field> + ?Sized),
cfg: &VecOpsConfig,
) -> IcicleResult<()> {
// SAFETY (review note): the pointers are taken from the `a` and `b` slices
// passed in, and the element count is `a.len()`. Nothing here checks that
// `b` holds at least `a.len()` elements; presumably the public
// `accumulate_scalars` wrapper guarantees it through `check_vec_ops_args`
// before calling this. TODO confirm.
unsafe {
$field_prefix_ident::accumulate_scalars_cuda(
a.as_mut_ptr(),
b.as_ptr(),
// NOTE: `as u32` silently truncates if the length exceeds `u32::MAX`.
a.len() as u32,
cfg as *const VecOpsConfig,
)
.wrap()
}
}
fn sub(
a: &(impl HostOrDeviceSlice<$field> + ?Sized),
b: &(impl HostOrDeviceSlice<$field> + ?Sized),
@@ -457,7 +500,7 @@ macro_rules! impl_vec_add_tests {
) => {
#[test]
pub fn test_vec_add_scalars() {
check_vec_ops_scalars::<$field>()
check_vec_ops_scalars::<$field>();
}
#[test]

View File

@@ -5,19 +5,21 @@ use crate::vec_ops::{
};
use icicle_cuda_runtime::memory::{DeviceVec, HostSlice};
use super::accumulate_scalars;
pub fn check_vec_ops_scalars<F: FieldImpl>()
where
<F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
{
let test_size = 1 << 14;
let a = F::Config::generate_random(test_size);
let mut a = F::Config::generate_random(test_size);
let b = F::Config::generate_random(test_size);
let ones = vec![F::one(); test_size];
let mut result = vec![F::zero(); test_size];
let mut result2 = vec![F::zero(); test_size];
let mut result3 = vec![F::zero(); test_size];
let a = HostSlice::from_slice(&a);
let a = HostSlice::from_mut_slice(&mut a);
let b = HostSlice::from_slice(&b);
let ones = HostSlice::from_slice(&ones);
let result = HostSlice::from_mut_slice(&mut result);
@@ -34,6 +36,12 @@ where
mul_scalars(a, ones, result3, &cfg).unwrap();
assert_eq!(a[0], result3[0]);
add_scalars(a, b, result, &cfg).unwrap();
accumulate_scalars(a, b, &cfg).unwrap();
assert_eq!(a[0], result[0]);
}
pub fn check_bit_reverse<F: FieldImpl>()