bit reverse (#528)

This PR adds bit reverse operation support to icicle
This commit is contained in:
nonam3e
2024-06-02 16:37:58 +07:00
committed by GitHub
parent 417ca77f61
commit 8e62bde16d
29 changed files with 634 additions and 262 deletions

View File

@@ -40,6 +40,40 @@ impl<'a> VecOpsConfig<'a> {
}
}
#[repr(C)]
#[derive(Debug, Clone)]
pub struct BitReverseConfig<'a> {
/// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext).
pub ctx: DeviceContext<'a>,
/// True if inputs are on device and false if they're on host. Default value: false.
pub is_input_on_device: bool,
/// If true, output is preserved on device, otherwise on host. Default value: false.
pub is_output_on_device: bool,
/// Whether to run the vector operations asynchronously. If set to `true`, the functions will be non-blocking and you'd need to synchronize
/// it explicitly by running `stream.synchronize()`. If set to false, the functions will block the current CPU thread.
pub is_async: bool,
}
impl<'a> Default for BitReverseConfig<'a> {
fn default() -> Self {
Self::default_for_device(DEFAULT_DEVICE_ID)
}
}
impl<'a> BitReverseConfig<'a> {
pub fn default_for_device(device_id: usize) -> Self {
BitReverseConfig {
ctx: DeviceContext::default_for_device(device_id),
is_input_on_device: false,
is_output_on_device: false,
is_async: false,
}
}
}
#[doc(hidden)]
pub trait VecOps<F> {
fn add(
@@ -72,6 +106,17 @@ pub trait VecOps<F> {
on_device: bool,
is_async: bool,
) -> IcicleResult<()>;
fn bit_reverse(
input: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &BitReverseConfig,
output: &mut (impl HostOrDeviceSlice<F> + ?Sized),
) -> IcicleResult<()>;
fn bit_reverse_inplace(
input: &mut (impl HostOrDeviceSlice<F> + ?Sized),
cfg: &BitReverseConfig,
) -> IcicleResult<()>;
}
fn check_vec_ops_args<'a, F>(
@@ -111,6 +156,42 @@ fn check_vec_ops_args<'a, F>(
res_cfg.is_result_on_device = result.is_on_device();
res_cfg
}
fn check_bit_reverse_args<'a, F>(
input: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &BitReverseConfig<'a>,
output: &(impl HostOrDeviceSlice<F> + ?Sized),
) -> BitReverseConfig<'a> {
if input.len() & (input.len() - 1) != 0 {
panic!("input length must be a power of 2, input length: {}", input.len());
}
if input.len() != output.len() {
panic!(
"input and output lengths {}; {} do not match",
input.len(),
output.len()
);
}
let ctx_device_id = cfg
.ctx
.device_id;
if let Some(device_id) = input.device_id() {
assert_eq!(
device_id, ctx_device_id,
"Device ids in input and context are different"
);
}
if let Some(device_id) = output.device_id() {
assert_eq!(
device_id, ctx_device_id,
"Device ids in output and context are different"
);
}
check_device(ctx_device_id);
let mut res_cfg = cfg.clone();
res_cfg.is_input_on_device = input.is_on_device();
res_cfg.is_output_on_device = output.is_on_device();
res_cfg
}
pub fn add_scalars<F>(
a: &(impl HostOrDeviceSlice<F> + ?Sized),
@@ -170,6 +251,31 @@ where
<<F as FieldImpl>::Config as VecOps<F>>::transpose(input, row_size, column_size, output, ctx, on_device, is_async)
}
pub fn bit_reverse<F>(
input: &(impl HostOrDeviceSlice<F> + ?Sized),
cfg: &BitReverseConfig,
output: &mut (impl HostOrDeviceSlice<F> + ?Sized),
) -> IcicleResult<()>
where
F: FieldImpl,
<F as FieldImpl>::Config: VecOps<F>,
{
let cfg = check_bit_reverse_args(input, cfg, output);
<<F as FieldImpl>::Config as VecOps<F>>::bit_reverse(input, &cfg, output)
}
pub fn bit_reverse_inplace<F>(
input: &mut (impl HostOrDeviceSlice<F> + ?Sized),
cfg: &BitReverseConfig,
) -> IcicleResult<()>
where
F: FieldImpl,
<F as FieldImpl>::Config: VecOps<F>,
{
let cfg = check_bit_reverse_args(input, cfg, input);
<<F as FieldImpl>::Config as VecOps<F>>::bit_reverse_inplace(input, &cfg)
}
#[macro_export]
macro_rules! impl_vec_ops_field {
(
@@ -180,6 +286,7 @@ macro_rules! impl_vec_ops_field {
) => {
mod $field_prefix_ident {
use crate::vec_ops::{$field, CudaError, DeviceContext, HostOrDeviceSlice};
use icicle_core::vec_ops::BitReverseConfig;
use icicle_core::vec_ops::VecOpsConfig;
extern "C" {
@@ -220,6 +327,14 @@ macro_rules! impl_vec_ops_field {
on_device: bool,
is_async: bool,
) -> CudaError;
#[link_name = concat!($field_prefix, "_bit_reverse_cuda")]
pub(crate) fn bit_reverse_cuda(
input: *const $field,
size: u64,
config: *const BitReverseConfig,
output: *mut $field,
) -> CudaError;
}
}
@@ -300,6 +415,37 @@ macro_rules! impl_vec_ops_field {
.wrap()
}
}
fn bit_reverse(
input: &(impl HostOrDeviceSlice<$field> + ?Sized),
cfg: &BitReverseConfig,
output: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
) -> IcicleResult<()> {
unsafe {
$field_prefix_ident::bit_reverse_cuda(
input.as_ptr(),
input.len() as u64,
cfg as *const BitReverseConfig,
output.as_mut_ptr(),
)
.wrap()
}
}
fn bit_reverse_inplace(
input: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
cfg: &BitReverseConfig,
) -> IcicleResult<()> {
unsafe {
$field_prefix_ident::bit_reverse_cuda(
input.as_ptr(),
input.len() as u64,
cfg as *const BitReverseConfig,
input.as_mut_ptr(),
)
.wrap()
}
}
}
};
}
@@ -313,5 +459,14 @@ macro_rules! impl_vec_add_tests {
pub fn test_vec_add_scalars() {
check_vec_ops_scalars::<$field>()
}
#[test]
pub fn test_bit_reverse() {
check_bit_reverse::<$field>()
}
#[test]
pub fn test_bit_reverse_inplace() {
check_bit_reverse_inplace::<$field>()
}
};
}

View File

@@ -1,6 +1,9 @@
use crate::traits::GenerateRandom;
use crate::vec_ops::{add_scalars, mul_scalars, sub_scalars, FieldImpl, VecOps, VecOpsConfig};
use icicle_cuda_runtime::memory::HostSlice;
use crate::vec_ops::{
add_scalars, bit_reverse, bit_reverse_inplace, mul_scalars, sub_scalars, BitReverseConfig, FieldImpl, VecOps,
VecOpsConfig,
};
use icicle_cuda_runtime::memory::{DeviceVec, HostSlice};
pub fn check_vec_ops_scalars<F: FieldImpl>()
where
@@ -32,3 +35,65 @@ where
assert_eq!(a[0], result3[0]);
}
pub fn check_bit_reverse<F: FieldImpl>()
where
<F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
{
const LOG_SIZE: u32 = 20;
const TEST_SIZE: usize = 1 << LOG_SIZE;
let input_vec = F::Config::generate_random(TEST_SIZE);
let input = HostSlice::from_slice(&input_vec);
let mut intermediate = DeviceVec::<F>::cuda_malloc(TEST_SIZE).unwrap();
let cfg = BitReverseConfig::default();
bit_reverse(input, &cfg, &mut intermediate[..]).unwrap();
let mut intermediate_host = vec![F::one(); TEST_SIZE];
intermediate
.copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..]))
.unwrap();
let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
intermediate_host
.iter()
.enumerate()
.for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)]));
let mut result = vec![F::one(); TEST_SIZE];
let result = HostSlice::from_mut_slice(&mut result);
let cfg = BitReverseConfig::default();
bit_reverse(&intermediate[..], &cfg, result).unwrap();
assert_eq!(input.as_slice(), result.as_slice());
}
pub fn check_bit_reverse_inplace<F: FieldImpl>()
where
<F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
{
const LOG_SIZE: u32 = 20;
const TEST_SIZE: usize = 1 << LOG_SIZE;
let input_vec = F::Config::generate_random(TEST_SIZE);
let input = HostSlice::from_slice(&input_vec);
let mut intermediate = DeviceVec::<F>::cuda_malloc(TEST_SIZE).unwrap();
intermediate
.copy_from_host(&input)
.unwrap();
let cfg = BitReverseConfig::default();
bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap();
let mut intermediate_host = vec![F::one(); TEST_SIZE];
intermediate
.copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..]))
.unwrap();
let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
intermediate_host
.iter()
.enumerate()
.for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)]));
bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap();
let mut result_host = vec![F::one(); TEST_SIZE];
intermediate
.copy_to_host(HostSlice::from_mut_slice(&mut result_host[..]))
.unwrap();
assert_eq!(input.as_slice(), result_host.as_slice());
}

View File

@@ -5,7 +5,7 @@ use crate::curve::{ScalarCfg, ScalarField};
use icicle_core::error::IcicleResult;
use icicle_core::impl_vec_ops_field;
use icicle_core::traits::IcicleResultWrap;
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -3,7 +3,7 @@ use crate::curve::{ScalarCfg, ScalarField};
use icicle_core::error::IcicleResult;
use icicle_core::impl_vec_ops_field;
use icicle_core::traits::IcicleResultWrap;
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -3,7 +3,7 @@ use crate::curve::{ScalarCfg, ScalarField};
use icicle_core::error::IcicleResult;
use icicle_core::impl_vec_ops_field;
use icicle_core::traits::IcicleResultWrap;
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -3,7 +3,7 @@ use crate::curve::{ScalarCfg, ScalarField};
use icicle_core::error::IcicleResult;
use icicle_core::impl_vec_ops_field;
use icicle_core::traits::IcicleResultWrap;
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

View File

@@ -3,7 +3,7 @@ use crate::field::{ExtensionCfg, ExtensionField, ScalarCfg, ScalarField};
use icicle_core::error::IcicleResult;
use icicle_core::impl_vec_ops_field;
use icicle_core::traits::IcicleResultWrap;
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
@@ -18,6 +18,7 @@ pub(crate) mod tests {
use icicle_core::vec_ops::tests::*;
impl_vec_add_tests!(ScalarField);
mod extension {
use super::*;

View File

@@ -3,7 +3,7 @@ use crate::field::{ScalarCfg, ScalarField};
use icicle_core::error::IcicleResult;
use icicle_core::impl_vec_ops_field;
use icicle_core::traits::IcicleResultWrap;
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::error::CudaError;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;