mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 20:48:06 -05:00
bit reverse (#528)
This PR adds bit-reverse operation support to icicle.
This commit is contained in:
@@ -40,6 +40,40 @@ impl<'a> VecOpsConfig<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct BitReverseConfig<'a> {
|
||||
/// Details related to the device such as its id and stream id. See [DeviceContext](@ref device_context::DeviceContext).
|
||||
pub ctx: DeviceContext<'a>,
|
||||
|
||||
/// True if inputs are on device and false if they're on host. Default value: false.
|
||||
pub is_input_on_device: bool,
|
||||
|
||||
/// If true, output is preserved on device, otherwise on host. Default value: false.
|
||||
pub is_output_on_device: bool,
|
||||
|
||||
/// Whether to run the vector operations asynchronously. If set to `true`, the functions will be non-blocking and you'd need to synchronize
|
||||
/// it explicitly by running `stream.synchronize()`. If set to false, the functions will block the current CPU thread.
|
||||
pub is_async: bool,
|
||||
}
|
||||
|
||||
impl<'a> Default for BitReverseConfig<'a> {
|
||||
fn default() -> Self {
|
||||
Self::default_for_device(DEFAULT_DEVICE_ID)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> BitReverseConfig<'a> {
|
||||
pub fn default_for_device(device_id: usize) -> Self {
|
||||
BitReverseConfig {
|
||||
ctx: DeviceContext::default_for_device(device_id),
|
||||
is_input_on_device: false,
|
||||
is_output_on_device: false,
|
||||
is_async: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub trait VecOps<F> {
|
||||
fn add(
|
||||
@@ -72,6 +106,17 @@ pub trait VecOps<F> {
|
||||
on_device: bool,
|
||||
is_async: bool,
|
||||
) -> IcicleResult<()>;
|
||||
|
||||
fn bit_reverse(
|
||||
input: &(impl HostOrDeviceSlice<F> + ?Sized),
|
||||
cfg: &BitReverseConfig,
|
||||
output: &mut (impl HostOrDeviceSlice<F> + ?Sized),
|
||||
) -> IcicleResult<()>;
|
||||
|
||||
fn bit_reverse_inplace(
|
||||
input: &mut (impl HostOrDeviceSlice<F> + ?Sized),
|
||||
cfg: &BitReverseConfig,
|
||||
) -> IcicleResult<()>;
|
||||
}
|
||||
|
||||
fn check_vec_ops_args<'a, F>(
|
||||
@@ -111,6 +156,42 @@ fn check_vec_ops_args<'a, F>(
|
||||
res_cfg.is_result_on_device = result.is_on_device();
|
||||
res_cfg
|
||||
}
|
||||
fn check_bit_reverse_args<'a, F>(
|
||||
input: &(impl HostOrDeviceSlice<F> + ?Sized),
|
||||
cfg: &BitReverseConfig<'a>,
|
||||
output: &(impl HostOrDeviceSlice<F> + ?Sized),
|
||||
) -> BitReverseConfig<'a> {
|
||||
if input.len() & (input.len() - 1) != 0 {
|
||||
panic!("input length must be a power of 2, input length: {}", input.len());
|
||||
}
|
||||
if input.len() != output.len() {
|
||||
panic!(
|
||||
"input and output lengths {}; {} do not match",
|
||||
input.len(),
|
||||
output.len()
|
||||
);
|
||||
}
|
||||
let ctx_device_id = cfg
|
||||
.ctx
|
||||
.device_id;
|
||||
if let Some(device_id) = input.device_id() {
|
||||
assert_eq!(
|
||||
device_id, ctx_device_id,
|
||||
"Device ids in input and context are different"
|
||||
);
|
||||
}
|
||||
if let Some(device_id) = output.device_id() {
|
||||
assert_eq!(
|
||||
device_id, ctx_device_id,
|
||||
"Device ids in output and context are different"
|
||||
);
|
||||
}
|
||||
check_device(ctx_device_id);
|
||||
let mut res_cfg = cfg.clone();
|
||||
res_cfg.is_input_on_device = input.is_on_device();
|
||||
res_cfg.is_output_on_device = output.is_on_device();
|
||||
res_cfg
|
||||
}
|
||||
|
||||
pub fn add_scalars<F>(
|
||||
a: &(impl HostOrDeviceSlice<F> + ?Sized),
|
||||
@@ -170,6 +251,31 @@ where
|
||||
<<F as FieldImpl>::Config as VecOps<F>>::transpose(input, row_size, column_size, output, ctx, on_device, is_async)
|
||||
}
|
||||
|
||||
pub fn bit_reverse<F>(
|
||||
input: &(impl HostOrDeviceSlice<F> + ?Sized),
|
||||
cfg: &BitReverseConfig,
|
||||
output: &mut (impl HostOrDeviceSlice<F> + ?Sized),
|
||||
) -> IcicleResult<()>
|
||||
where
|
||||
F: FieldImpl,
|
||||
<F as FieldImpl>::Config: VecOps<F>,
|
||||
{
|
||||
let cfg = check_bit_reverse_args(input, cfg, output);
|
||||
<<F as FieldImpl>::Config as VecOps<F>>::bit_reverse(input, &cfg, output)
|
||||
}
|
||||
|
||||
pub fn bit_reverse_inplace<F>(
|
||||
input: &mut (impl HostOrDeviceSlice<F> + ?Sized),
|
||||
cfg: &BitReverseConfig,
|
||||
) -> IcicleResult<()>
|
||||
where
|
||||
F: FieldImpl,
|
||||
<F as FieldImpl>::Config: VecOps<F>,
|
||||
{
|
||||
let cfg = check_bit_reverse_args(input, cfg, input);
|
||||
<<F as FieldImpl>::Config as VecOps<F>>::bit_reverse_inplace(input, &cfg)
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_vec_ops_field {
|
||||
(
|
||||
@@ -180,6 +286,7 @@ macro_rules! impl_vec_ops_field {
|
||||
) => {
|
||||
mod $field_prefix_ident {
|
||||
use crate::vec_ops::{$field, CudaError, DeviceContext, HostOrDeviceSlice};
|
||||
use icicle_core::vec_ops::BitReverseConfig;
|
||||
use icicle_core::vec_ops::VecOpsConfig;
|
||||
|
||||
extern "C" {
|
||||
@@ -220,6 +327,14 @@ macro_rules! impl_vec_ops_field {
|
||||
on_device: bool,
|
||||
is_async: bool,
|
||||
) -> CudaError;
|
||||
|
||||
#[link_name = concat!($field_prefix, "_bit_reverse_cuda")]
|
||||
pub(crate) fn bit_reverse_cuda(
|
||||
input: *const $field,
|
||||
size: u64,
|
||||
config: *const BitReverseConfig,
|
||||
output: *mut $field,
|
||||
) -> CudaError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -300,6 +415,37 @@ macro_rules! impl_vec_ops_field {
|
||||
.wrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn bit_reverse(
|
||||
input: &(impl HostOrDeviceSlice<$field> + ?Sized),
|
||||
cfg: &BitReverseConfig,
|
||||
output: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
|
||||
) -> IcicleResult<()> {
|
||||
unsafe {
|
||||
$field_prefix_ident::bit_reverse_cuda(
|
||||
input.as_ptr(),
|
||||
input.len() as u64,
|
||||
cfg as *const BitReverseConfig,
|
||||
output.as_mut_ptr(),
|
||||
)
|
||||
.wrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn bit_reverse_inplace(
|
||||
input: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
|
||||
cfg: &BitReverseConfig,
|
||||
) -> IcicleResult<()> {
|
||||
unsafe {
|
||||
$field_prefix_ident::bit_reverse_cuda(
|
||||
input.as_ptr(),
|
||||
input.len() as u64,
|
||||
cfg as *const BitReverseConfig,
|
||||
input.as_mut_ptr(),
|
||||
)
|
||||
.wrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -313,5 +459,14 @@ macro_rules! impl_vec_add_tests {
|
||||
pub fn test_vec_add_scalars() {
|
||||
check_vec_ops_scalars::<$field>()
|
||||
}
|
||||
|
||||
#[test]
|
||||
pub fn test_bit_reverse() {
|
||||
check_bit_reverse::<$field>()
|
||||
}
|
||||
#[test]
|
||||
pub fn test_bit_reverse_inplace() {
|
||||
check_bit_reverse_inplace::<$field>()
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
use crate::traits::GenerateRandom;
|
||||
use crate::vec_ops::{add_scalars, mul_scalars, sub_scalars, FieldImpl, VecOps, VecOpsConfig};
|
||||
use icicle_cuda_runtime::memory::HostSlice;
|
||||
use crate::vec_ops::{
|
||||
add_scalars, bit_reverse, bit_reverse_inplace, mul_scalars, sub_scalars, BitReverseConfig, FieldImpl, VecOps,
|
||||
VecOpsConfig,
|
||||
};
|
||||
use icicle_cuda_runtime::memory::{DeviceVec, HostSlice};
|
||||
|
||||
pub fn check_vec_ops_scalars<F: FieldImpl>()
|
||||
where
|
||||
@@ -32,3 +35,65 @@ where
|
||||
|
||||
assert_eq!(a[0], result3[0]);
|
||||
}
|
||||
|
||||
pub fn check_bit_reverse<F: FieldImpl>()
|
||||
where
|
||||
<F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
|
||||
{
|
||||
const LOG_SIZE: u32 = 20;
|
||||
const TEST_SIZE: usize = 1 << LOG_SIZE;
|
||||
let input_vec = F::Config::generate_random(TEST_SIZE);
|
||||
let input = HostSlice::from_slice(&input_vec);
|
||||
let mut intermediate = DeviceVec::<F>::cuda_malloc(TEST_SIZE).unwrap();
|
||||
let cfg = BitReverseConfig::default();
|
||||
bit_reverse(input, &cfg, &mut intermediate[..]).unwrap();
|
||||
|
||||
let mut intermediate_host = vec![F::one(); TEST_SIZE];
|
||||
intermediate
|
||||
.copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..]))
|
||||
.unwrap();
|
||||
let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
|
||||
intermediate_host
|
||||
.iter()
|
||||
.enumerate()
|
||||
.for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)]));
|
||||
|
||||
let mut result = vec![F::one(); TEST_SIZE];
|
||||
let result = HostSlice::from_mut_slice(&mut result);
|
||||
let cfg = BitReverseConfig::default();
|
||||
bit_reverse(&intermediate[..], &cfg, result).unwrap();
|
||||
assert_eq!(input.as_slice(), result.as_slice());
|
||||
}
|
||||
|
||||
pub fn check_bit_reverse_inplace<F: FieldImpl>()
|
||||
where
|
||||
<F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
|
||||
{
|
||||
const LOG_SIZE: u32 = 20;
|
||||
const TEST_SIZE: usize = 1 << LOG_SIZE;
|
||||
let input_vec = F::Config::generate_random(TEST_SIZE);
|
||||
let input = HostSlice::from_slice(&input_vec);
|
||||
let mut intermediate = DeviceVec::<F>::cuda_malloc(TEST_SIZE).unwrap();
|
||||
intermediate
|
||||
.copy_from_host(&input)
|
||||
.unwrap();
|
||||
let cfg = BitReverseConfig::default();
|
||||
bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap();
|
||||
|
||||
let mut intermediate_host = vec![F::one(); TEST_SIZE];
|
||||
intermediate
|
||||
.copy_to_host(HostSlice::from_mut_slice(&mut intermediate_host[..]))
|
||||
.unwrap();
|
||||
let index_reverser = |i: usize| i.reverse_bits() >> (usize::BITS - LOG_SIZE);
|
||||
intermediate_host
|
||||
.iter()
|
||||
.enumerate()
|
||||
.for_each(|(i, val)| assert_eq!(val, &input_vec[index_reverser(i)]));
|
||||
|
||||
bit_reverse_inplace(&mut intermediate[..], &cfg).unwrap();
|
||||
let mut result_host = vec![F::one(); TEST_SIZE];
|
||||
intermediate
|
||||
.copy_to_host(HostSlice::from_mut_slice(&mut result_host[..]))
|
||||
.unwrap();
|
||||
assert_eq!(input.as_slice(), result_host.as_slice());
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::curve::{ScalarCfg, ScalarField};
|
||||
use icicle_core::error::IcicleResult;
|
||||
use icicle_core::impl_vec_ops_field;
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
|
||||
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::curve::{ScalarCfg, ScalarField};
|
||||
use icicle_core::error::IcicleResult;
|
||||
use icicle_core::impl_vec_ops_field;
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
|
||||
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::curve::{ScalarCfg, ScalarField};
|
||||
use icicle_core::error::IcicleResult;
|
||||
use icicle_core::impl_vec_ops_field;
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
|
||||
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::curve::{ScalarCfg, ScalarField};
|
||||
use icicle_core::error::IcicleResult;
|
||||
use icicle_core::impl_vec_ops_field;
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
|
||||
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::field::{ExtensionCfg, ExtensionField, ScalarCfg, ScalarField};
|
||||
use icicle_core::error::IcicleResult;
|
||||
use icicle_core::impl_vec_ops_field;
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
|
||||
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
@@ -18,6 +18,7 @@ pub(crate) mod tests {
|
||||
use icicle_core::vec_ops::tests::*;
|
||||
|
||||
impl_vec_add_tests!(ScalarField);
|
||||
|
||||
mod extension {
|
||||
use super::*;
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ use crate::field::{ScalarCfg, ScalarField};
|
||||
use icicle_core::error::IcicleResult;
|
||||
use icicle_core::impl_vec_ops_field;
|
||||
use icicle_core::traits::IcicleResultWrap;
|
||||
use icicle_core::vec_ops::{VecOps, VecOpsConfig};
|
||||
use icicle_core::vec_ops::{BitReverseConfig, VecOps, VecOpsConfig};
|
||||
use icicle_cuda_runtime::device_context::DeviceContext;
|
||||
use icicle_cuda_runtime::error::CudaError;
|
||||
use icicle_cuda_runtime::memory::HostOrDeviceSlice;
|
||||
|
||||
Reference in New Issue
Block a user