mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-09 23:48:10 -05:00
transpose kernel in vec_ops and rust binding (#462)
## Describe the changes This PR adds an extern C link to the transpose kernel, now in vec_ops.cu. Also Rust binding, and I updated the test check_ntt_batch to use the new transpose function. The test passes. ## Linked Issues Resolves # --------- Co-authored-by: LeonHibnik <leon@ingonyama.com>
This commit is contained in:
@@ -9,6 +9,7 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator};
|
||||
use crate::{
|
||||
ntt::{initialize_domain, initialize_domain_fast_twiddles_mode, ntt, ntt_inplace, NTTDir, NttAlgorithm, Ordering},
|
||||
traits::{ArkConvertible, FieldImpl, GenerateRandom},
|
||||
vec_ops::{transpose_matrix, VecOps},
|
||||
};
|
||||
|
||||
use super::{NTTConfig, NTT};
|
||||
@@ -235,6 +236,7 @@ where
|
||||
pub fn check_ntt_batch<F: FieldImpl>()
|
||||
where
|
||||
<F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
|
||||
<F as FieldImpl>::Config: VecOps<F>,
|
||||
{
|
||||
let test_sizes = [1 << 4, 1 << 12];
|
||||
let batch_sizes = [1, 1 << 4, 100];
|
||||
@@ -278,18 +280,38 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
let row_size = test_size as u32;
|
||||
let column_size = batch_size as u32;
|
||||
let on_device = false;
|
||||
let is_async = false;
|
||||
// for now, columns batching only works with MixedRadix NTT
|
||||
config.batch_size = batch_size as i32;
|
||||
config.columns_batch = true;
|
||||
let transposed_input =
|
||||
HostOrDeviceSlice::on_host(transpose_flattened_matrix(&scalars[..], batch_size));
|
||||
let mut transposed_input = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
|
||||
transpose_matrix(
|
||||
&scalars,
|
||||
row_size,
|
||||
column_size,
|
||||
&mut transposed_input,
|
||||
&config.ctx,
|
||||
on_device,
|
||||
is_async,
|
||||
)
|
||||
.unwrap();
|
||||
let mut col_batch_ntt_result =
|
||||
HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
|
||||
ntt(&transposed_input, is_inverse, &config, &mut col_batch_ntt_result).unwrap();
|
||||
assert_eq!(
|
||||
batch_ntt_result[..],
|
||||
transpose_flattened_matrix(&col_batch_ntt_result[..], test_size)
|
||||
);
|
||||
transpose_matrix(
|
||||
&col_batch_ntt_result,
|
||||
column_size,
|
||||
row_size,
|
||||
&mut transposed_input,
|
||||
&config.ctx,
|
||||
on_device,
|
||||
is_async,
|
||||
)
|
||||
.unwrap();
|
||||
assert_eq!(batch_ntt_result[..], *transposed_input.as_slice());
|
||||
config.columns_batch = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -63,6 +63,16 @@ pub trait VecOps<F> {
|
||||
result: &mut HostOrDeviceSlice<F>,
|
||||
cfg: &VecOpsConfig,
|
||||
) -> IcicleResult<()>;
|
||||
|
||||
fn transpose(
|
||||
input: &HostOrDeviceSlice<F>,
|
||||
row_size: u32,
|
||||
column_size: u32,
|
||||
output: &mut HostOrDeviceSlice<F>,
|
||||
ctx: &DeviceContext,
|
||||
on_device: bool,
|
||||
is_async: bool,
|
||||
) -> IcicleResult<()>;
|
||||
}
|
||||
|
||||
fn check_vec_ops_args<F>(a: &HostOrDeviceSlice<F>, b: &HostOrDeviceSlice<F>, result: &mut HostOrDeviceSlice<F>) {
|
||||
@@ -118,6 +128,22 @@ where
|
||||
<<F as FieldImpl>::Config as VecOps<F>>::mul(a, b, result, cfg)
|
||||
}
|
||||
|
||||
pub fn transpose_matrix<F>(
|
||||
input: &HostOrDeviceSlice<F>,
|
||||
row_size: u32,
|
||||
column_size: u32,
|
||||
output: &mut HostOrDeviceSlice<F>,
|
||||
ctx: &DeviceContext,
|
||||
on_device: bool,
|
||||
is_async: bool,
|
||||
) -> IcicleResult<()>
|
||||
where
|
||||
F: FieldImpl,
|
||||
<F as FieldImpl>::Config: VecOps<F>,
|
||||
{
|
||||
<<F as FieldImpl>::Config as VecOps<F>>::transpose(input, row_size, column_size, output, ctx, on_device, is_async)
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! impl_vec_ops_field {
|
||||
(
|
||||
@@ -157,6 +183,17 @@ macro_rules! impl_vec_ops_field {
|
||||
cfg: *const VecOpsConfig,
|
||||
result: *mut $field,
|
||||
) -> CudaError;
|
||||
|
||||
#[link_name = concat!($field_prefix, "TransposeMatrix")]
|
||||
pub(crate) fn transpose_cuda(
|
||||
input: *const $field,
|
||||
row_size: u32,
|
||||
column_size: u32,
|
||||
output: *mut $field,
|
||||
ctx: *const DeviceContext,
|
||||
on_device: bool,
|
||||
is_async: bool,
|
||||
) -> CudaError;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -214,6 +251,29 @@ macro_rules! impl_vec_ops_field {
|
||||
.wrap()
|
||||
}
|
||||
}
|
||||
|
||||
fn transpose(
|
||||
input: &HostOrDeviceSlice<$field>,
|
||||
row_size: u32,
|
||||
column_size: u32,
|
||||
output: &mut HostOrDeviceSlice<$field>,
|
||||
ctx: &DeviceContext,
|
||||
on_device: bool,
|
||||
is_async: bool,
|
||||
) -> IcicleResult<()> {
|
||||
unsafe {
|
||||
$field_prefix_ident::transpose_cuda(
|
||||
input.as_ptr(),
|
||||
row_size,
|
||||
column_size,
|
||||
output.as_mut_ptr(),
|
||||
ctx as *const _ as *const DeviceContext,
|
||||
on_device,
|
||||
is_async,
|
||||
)
|
||||
.wrap()
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user