transpose kernel in vec_ops and rust binding (#462)

## Describe the changes

This PR adds an extern "C" entry point for the transpose kernel, which now lives in
vec_ops.cu, along with a Rust binding (`transpose_matrix` in the `vec_ops` module).
I also updated the `check_ntt_batch` test to use the new transpose function; the test passes.
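
For reviewers, here is a minimal sketch of how the new binding is intended to be called from plain Rust. The `transpose_matrix` signature and `HostOrDeviceSlice::on_host` come from the diff below; the concrete field (`ScalarField` from `icicle_bn254`), the module paths for `HostOrDeviceSlice`/`DeviceContext`, and the `DeviceContext::default()` constructor are assumptions for illustration only:

```rust
// Sketch only: crate/field names, module paths and the DeviceContext constructor
// below are assumptions for illustration; they are not taken from this diff.
use icicle_bn254::curve::ScalarField;
use icicle_core::traits::FieldImpl;
use icicle_core::vec_ops::transpose_matrix;
use icicle_cuda_runtime::device_context::DeviceContext;
use icicle_cuda_runtime::memory::HostOrDeviceSlice;

fn transpose_on_host() {
    // 2 rows of 4 elements, flattened row-major: row_size = 4, column_size = 2,
    // mirroring how the updated check_ntt_batch test passes test_size / batch_size.
    let (row_size, column_size) = (4u32, 2u32);
    let len = (row_size * column_size) as usize;
    let input = HostOrDeviceSlice::on_host(vec![ScalarField::zero(); len]);
    let mut output = HostOrDeviceSlice::on_host(vec![ScalarField::zero(); len]);

    let ctx = DeviceContext::default(); // assumed constructor; the test reuses NTTConfig's ctx
    // Both buffers live on the host (on_device = false) and the call is synchronous.
    transpose_matrix(&input, row_size, column_size, &mut output, &ctx, false, false).unwrap();
}
```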

## Linked Issues

Resolves #

---------

Co-authored-by: LeonHibnik <leon@ingonyama.com>
Commit 4a35eece51 (parent 4c9b3c00a5), authored by Vlad on 2024-04-09 07:47:33 +02:00 and committed by GitHub.
4 changed files with 179 additions and 6 deletions.


@@ -9,6 +9,7 @@ use rayon::iter::{IntoParallelIterator, ParallelIterator};
 use crate::{
     ntt::{initialize_domain, initialize_domain_fast_twiddles_mode, ntt, ntt_inplace, NTTDir, NttAlgorithm, Ordering},
     traits::{ArkConvertible, FieldImpl, GenerateRandom},
+    vec_ops::{transpose_matrix, VecOps},
 };
 use super::{NTTConfig, NTT};
@@ -235,6 +236,7 @@ where
 pub fn check_ntt_batch<F: FieldImpl>()
 where
     <F as FieldImpl>::Config: NTT<F> + GenerateRandom<F>,
+    <F as FieldImpl>::Config: VecOps<F>,
 {
     let test_sizes = [1 << 4, 1 << 12];
     let batch_sizes = [1, 1 << 4, 100];
@@ -278,18 +280,38 @@ where
                     }
                 }

+                let row_size = test_size as u32;
+                let column_size = batch_size as u32;
+                let on_device = false;
+                let is_async = false;
                 // for now, columns batching only works with MixedRadix NTT
                 config.batch_size = batch_size as i32;
                 config.columns_batch = true;
-                let transposed_input =
-                    HostOrDeviceSlice::on_host(transpose_flattened_matrix(&scalars[..], batch_size));
+                let mut transposed_input = HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
+                transpose_matrix(
+                    &scalars,
+                    row_size,
+                    column_size,
+                    &mut transposed_input,
+                    &config.ctx,
+                    on_device,
+                    is_async,
+                )
+                .unwrap();
                 let mut col_batch_ntt_result =
                     HostOrDeviceSlice::on_host(vec![F::zero(); batch_size * test_size]);
                 ntt(&transposed_input, is_inverse, &config, &mut col_batch_ntt_result).unwrap();
-                assert_eq!(
-                    batch_ntt_result[..],
-                    transpose_flattened_matrix(&col_batch_ntt_result[..], test_size)
-                );
+                transpose_matrix(
+                    &col_batch_ntt_result,
+                    column_size,
+                    row_size,
+                    &mut transposed_input,
+                    &config.ctx,
+                    on_device,
+                    is_async,
+                )
+                .unwrap();
+                assert_eq!(batch_ntt_result[..], *transposed_input.as_slice());
                 config.columns_batch = false;
             }
         }
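
In effect, the test now transposes the row-major batch into column order, runs the NTT with `columns_batch = true`, transposes the result back, and compares it against the plain row-batched NTT output. For reference, a plain host-side transpose matching the indexing convention the call sites appear to imply (`row_size` elements per row, `column_size` rows, row-major storage; the CUDA kernel itself is not part of this diff) would look like:

```rust
/// Reference transpose of a row-major matrix with `column_size` rows of `row_size`
/// elements each (convention inferred from the call sites above, not stated in the diff).
fn transpose_ref<T: Copy + Default>(input: &[T], row_size: usize, column_size: usize) -> Vec<T> {
    assert_eq!(input.len(), row_size * column_size);
    let mut out = vec![T::default(); row_size * column_size];
    for r in 0..column_size {
        for c in 0..row_size {
            // Element (r, c) of the input lands at (c, r) of the output.
            out[c * column_size + r] = input[r * row_size + c];
        }
    }
    out
}
```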


@@ -63,6 +63,16 @@ pub trait VecOps<F> {
         result: &mut HostOrDeviceSlice<F>,
         cfg: &VecOpsConfig,
     ) -> IcicleResult<()>;
+
+    fn transpose(
+        input: &HostOrDeviceSlice<F>,
+        row_size: u32,
+        column_size: u32,
+        output: &mut HostOrDeviceSlice<F>,
+        ctx: &DeviceContext,
+        on_device: bool,
+        is_async: bool,
+    ) -> IcicleResult<()>;
 }

 fn check_vec_ops_args<F>(a: &HostOrDeviceSlice<F>, b: &HostOrDeviceSlice<F>, result: &mut HostOrDeviceSlice<F>) {
@@ -118,6 +128,22 @@ where
     <<F as FieldImpl>::Config as VecOps<F>>::mul(a, b, result, cfg)
 }

+pub fn transpose_matrix<F>(
+    input: &HostOrDeviceSlice<F>,
+    row_size: u32,
+    column_size: u32,
+    output: &mut HostOrDeviceSlice<F>,
+    ctx: &DeviceContext,
+    on_device: bool,
+    is_async: bool,
+) -> IcicleResult<()>
+where
+    F: FieldImpl,
+    <F as FieldImpl>::Config: VecOps<F>,
+{
+    <<F as FieldImpl>::Config as VecOps<F>>::transpose(input, row_size, column_size, output, ctx, on_device, is_async)
+}
+
 #[macro_export]
 macro_rules! impl_vec_ops_field {
     (
@@ -157,6 +183,17 @@ macro_rules! impl_vec_ops_field {
                 cfg: *const VecOpsConfig,
                 result: *mut $field,
             ) -> CudaError;
+
+            #[link_name = concat!($field_prefix, "TransposeMatrix")]
+            pub(crate) fn transpose_cuda(
+                input: *const $field,
+                row_size: u32,
+                column_size: u32,
+                output: *mut $field,
+                ctx: *const DeviceContext,
+                on_device: bool,
+                is_async: bool,
+            ) -> CudaError;
         }
     }
@@ -214,6 +251,29 @@ macro_rules! impl_vec_ops_field {
                     .wrap()
                 }
             }
+
+            fn transpose(
+                input: &HostOrDeviceSlice<$field>,
+                row_size: u32,
+                column_size: u32,
+                output: &mut HostOrDeviceSlice<$field>,
+                ctx: &DeviceContext,
+                on_device: bool,
+                is_async: bool,
+            ) -> IcicleResult<()> {
+                unsafe {
+                    $field_prefix_ident::transpose_cuda(
+                        input.as_ptr(),
+                        row_size,
+                        column_size,
+                        output.as_mut_ptr(),
+                        ctx as *const _ as *const DeviceContext,
+                        on_device,
+                        is_async,
+                    )
+                    .wrap()
+                }
+            }
         }
     };
 }
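
To make the linkage concrete: for a field crate that invokes `impl_vec_ops_field!` with a prefix such as `"bn254"` (a placeholder here), the macro expands to roughly the following FFI declaration. Only the `TransposeMatrix` suffix and the argument layout come from the macro above; the field, context, and error types below are stand-in stubs so the snippet is self-contained, not the real icicle types:

```rust
// Hypothetical, simplified expansion for $field_prefix = "bn254"; types are stubs.
#[repr(C)]
#[derive(Clone, Copy)]
struct ScalarField([u32; 8]); // stand-in for the macro's $field

#[repr(C)]
struct DeviceContext {
    _private: [u8; 0], // stand-in for the real DeviceContext
}

#[repr(transparent)]
struct CudaError(i32); // stand-in for the real CudaError

extern "C" {
    // Symbol name is $field_prefix concatenated with "TransposeMatrix".
    #[link_name = "bn254TransposeMatrix"]
    fn transpose_cuda(
        input: *const ScalarField,
        row_size: u32,
        column_size: u32,
        output: *mut ScalarField,
        ctx: *const DeviceContext,
        on_device: bool,
        is_async: bool,
    ) -> CudaError;
}
```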