Compare commits

...

2 Commits

Author SHA1 Message Date
Vitalii
ae5c0681ad batch mult 2023-07-12 17:42:12 +02:00
Vitalii
58a2492d60 enforce c++17 standard 2023-06-26 21:10:07 +02:00
3 changed files with 121 additions and 1 deletions

View File

@@ -20,6 +20,7 @@ fn main() {
nvcc.cuda(true);
nvcc.debug(false);
nvcc.flag(&arch);
nvcc.flag("--std=c++17");
nvcc.files([
"./icicle/appUtils/ntt/ntt.cu",
"./icicle/appUtils/msm/msm.cu",

View File

@@ -93,6 +93,83 @@ int matrix_mod_mult(E *matrix_elements, E *vector_elements, E *result, size_t di
return 0;
}
/**
 * In-place batched element-wise multiplication kernel.
 *
 * `element_vec` is treated as `batch_size` consecutive chunks of `n_mult`
 * elements each. Every element `i` is overwritten with
 * `mult_vec[i % n_mult] * element_vec[i]`, i.e. each chunk is multiplied
 * element-wise by the same multiplier vector.
 *
 * Launch layout: 1D grid of 1D blocks. Each thread processes at most one
 * element, so the launch must supply at least `n_mult * batch_size` threads;
 * surplus threads do nothing.
 *
 * @param element_vec Elements to be multiplied; overwritten with the products.
 * @param mult_vec    Multipliers, indexed by `i % n_mult`.
 * @param n_mult      Number of multipliers (length of one chunk).
 * @param batch_size  Number of chunks in `element_vec`.
 */
template <typename E, typename S>
__global__ void batch_vector_mult_kernel(E *element_vec, S *mult_vec, unsigned n_mult, unsigned batch_size)
{
  // All index arithmetic is widened to 64 bits: the 32-bit products
  // `blockDim.x * blockIdx.x` and `n_mult * batch_size` wrap around for large
  // inputs, and a signed 32-bit thread index would turn negative.
  const size_t total = static_cast<size_t>(n_mult) * batch_size;
  const size_t tid = static_cast<size_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  if (tid < total)
  {
    // `total > 0` implies `n_mult > 0`, so the modulo below cannot divide by zero.
    const size_t mult_id = tid % n_mult;
    element_vec[tid] = mult_vec[mult_id] * element_vec[tid];
  }
}
/**
 * Host-side driver for the batched in-place multiplication.
 *
 * Copies `n_mult` multipliers and `n_mult * batch_size` elements from host
 * memory to freshly allocated device buffers, runs
 * `batch_vector_mult_kernel` on them, and copies the products back over
 * `element_vec`.
 *
 * @param element_vec Host pointer to `n_mult * batch_size` elements; overwritten with the result.
 * @param mult_vec    Host pointer to `n_mult` multipliers.
 * @param n_mult      Number of multipliers (length of one chunk).
 * @param batch_size  Number of chunks in `element_vec`.
 * @return 0 on success (including the empty-input case).
 * @throws std::runtime_error if the input is too large to address/launch or
 *         if any CUDA runtime call fails; device buffers are released first.
 */
template <typename E, typename S>
int batch_vector_mult_template(E *element_vec, S *mult_vec, unsigned n_mult, unsigned batch_size)
{
  // Widen before multiplying: the product of two 32-bit counts may not fit in 32 bits.
  const size_t total_elements = static_cast<size_t>(n_mult) * batch_size;
  if (total_elements == 0)
    return 0; // nothing to do; also avoids launching an invalid zero-block grid

  if (total_elements > static_cast<size_t>(-1) / sizeof(E))
    throw std::runtime_error("batch_vector_mult: element buffer size overflows size_t");

  // Set the grid and block dimensions: one thread per element.
  const unsigned num_threads = MAX_THREADS_PER_BLOCK;
  const size_t num_blocks = (total_elements + num_threads - 1) / num_threads;
  // The x-dimension of a grid is limited to 2^31 - 1 blocks.
  if (num_blocks > 0x7fffffffu)
    throw std::runtime_error("batch_vector_mult: input too large for a single 1D grid");

  const size_t n_mult_size = static_cast<size_t>(n_mult) * sizeof(S);
  const size_t full_size = total_elements * sizeof(E);

  // Device copies of the multipliers and of the elements (updated in place).
  S *d_mult_vec = nullptr;
  E *d_element_vec = nullptr;

  // Each step runs only if everything before it succeeded, so the first
  // failure is the one that gets reported.
  cudaError_t status = cudaMalloc(&d_mult_vec, n_mult_size);
  if (status == cudaSuccess)
    status = cudaMalloc(&d_element_vec, full_size);
  if (status == cudaSuccess)
    status = cudaMemcpy(d_mult_vec, mult_vec, n_mult_size, cudaMemcpyHostToDevice);
  if (status == cudaSuccess)
    status = cudaMemcpy(d_element_vec, element_vec, full_size, cudaMemcpyHostToDevice);
  if (status == cudaSuccess)
  {
    batch_vector_mult_kernel<<<static_cast<unsigned>(num_blocks), num_threads>>>(d_element_vec, d_mult_vec, n_mult, batch_size);
    status = cudaGetLastError(); // launch-configuration errors are only visible here
  }
  // The blocking copy back also reports errors raised while the kernel was running.
  if (status == cudaSuccess)
    status = cudaMemcpy(element_vec, d_element_vec, full_size, cudaMemcpyDeviceToHost);

  // Always release the buffers; cudaFree on a null pointer is a no-op.
  cudaFree(d_mult_vec);
  cudaFree(d_element_vec);

  if (status != cudaSuccess)
    throw std::runtime_error(cudaGetErrorString(status));
  return 0;
}
/**
 * C entry point: multiplies, in place, a batch of projective points by a
 * vector of scalars (`inout[i] = scalar_vec[i % n_scalars] * inout[i]`).
 *
 * @param inout      Host pointer to `n_scalars * batch_size` points; overwritten with the result.
 * @param scalar_vec Host pointer to `n_scalars` scalars.
 * @param n_scalars  Number of scalars (length of one chunk of `inout`).
 * @param batch_size Number of chunks in `inout`.
 * @param device_id  Currently unused (see TODO below).
 * @return CUDA_SUCCESS on success, -1 on failure.
 */
extern "C" int32_t batch_vector_mult_proj_cuda(projective_t *inout,
                                               scalar_t *scalar_vec,
                                               size_t n_scalars,
                                               size_t batch_size,
                                               size_t device_id)
{
  try
  {
    // TODO: device_id
    // The templated helper takes 32-bit counts; refuse values that would be
    // silently truncated instead of processing only part of the input.
    const unsigned n_scalars_u = static_cast<unsigned>(n_scalars);
    const unsigned batch_size_u = static_cast<unsigned>(batch_size);
    if (n_scalars_u != n_scalars || batch_size_u != batch_size)
    {
      printf("error %s", "batch dimensions do not fit in 32 bits");
      return -1;
    }
    if (batch_vector_mult_template(inout, scalar_vec, n_scalars_u, batch_size_u) != 0)
      return -1;
    return CUDA_SUCCESS;
  }
  catch (const std::runtime_error &ex)
  {
    printf("error %s", ex.what()); // TODO: error code and message
    return -1;
  }
}
/**
 * C entry point: multiplies, in place, a batch of scalars by a vector of
 * scalars (`inout[i] = mult_vec[i % n_mult] * inout[i]`).
 *
 * @param inout      Host pointer to `n_mult * batch_size` scalars; overwritten with the result.
 * @param mult_vec   Host pointer to `n_mult` multipliers.
 * @param n_mult     Number of multipliers (length of one chunk of `inout`).
 * @param batch_size Number of chunks in `inout`.
 * @param device_id  Currently unused (see TODO below).
 * @return CUDA_SUCCESS on success, -1 on failure.
 */
extern "C" int32_t batch_vector_mult_scalar_cuda(scalar_t *inout,
                                                 scalar_t *mult_vec,
                                                 size_t n_mult,
                                                 size_t batch_size,
                                                 size_t device_id)
{
  try
  {
    // TODO: device_id
    // The templated helper takes 32-bit counts; refuse values that would be
    // silently truncated instead of processing only part of the input.
    const unsigned n_mult_u = static_cast<unsigned>(n_mult);
    const unsigned batch_size_u = static_cast<unsigned>(batch_size);
    if (n_mult_u != n_mult || batch_size_u != batch_size)
    {
      printf("error %s", "batch dimensions do not fit in 32 bits");
      return -1;
    }
    if (batch_vector_mult_template(inout, mult_vec, n_mult_u, batch_size_u) != 0)
      return -1;
    return CUDA_SUCCESS;
  }
  catch (const std::runtime_error &ex)
  {
    printf("error %s", ex.what()); // TODO: error code and message
    return -1;
  }
}
extern "C" int32_t vec_mod_mult_point(projective_t *inout,
scalar_t *scalar_vec,
size_t n_elments,
@@ -114,7 +191,7 @@ extern "C" int32_t vec_mod_mult_point(projective_t *inout,
extern "C" int32_t vec_mod_mult_scalar(scalar_t *inout,
scalar_t *scalar_vec,
size_t n_elments,
size_t device_id)
size_t device_id) //TODO: unify with batch mult as batch_size=1
{
try
{

View File

@@ -50,6 +50,22 @@ extern "C" {
device_id: usize,
) -> c_int;
/// Binding for the CUDA-side `batch_vector_mult_proj_cuda`.
///
/// Multiplies, in place, `n_scalars * batch_size` points at `inout` by the
/// `n_scalars` values at `scalars`; point `i` is multiplied by scalar
/// `i % n_scalars`. Both pointers refer to host memory (the CUDA side copies
/// the data to the device and back). `device_id` is accepted but not yet
/// used by the CUDA implementation.
///
/// Returns `CUDA_SUCCESS` on success, `-1` if the C++ side caught a
/// `std::runtime_error`.
fn batch_vector_mult_proj_cuda(
inout: *mut Point,
scalars: *const ScalarField,
n_scalars: usize,
batch_size: usize,
device_id: usize,
) -> c_int;
/// Binding for the CUDA-side `batch_vector_mult_scalar_cuda`.
///
/// Multiplies, in place, `n_mult * batch_size` scalars at `inout` by the
/// `n_mult` values at `scalars`; element `i` is multiplied by multiplier
/// `i % n_mult`. Both pointers refer to host memory (the CUDA side copies
/// the data to the device and back). `device_id` is accepted but not yet
/// used by the CUDA implementation.
///
/// Returns `CUDA_SUCCESS` on success, `-1` if the C++ side caught a
/// `std::runtime_error`.
fn batch_vector_mult_scalar_cuda(
inout: *mut ScalarField,
scalars: *const ScalarField,
n_mult: usize,
batch_size: usize,
device_id: usize,
) -> c_int;
fn matrix_vec_mod_mult(
matrix_flattened: *const ScalarField,
input: *const ScalarField,
@@ -223,6 +239,32 @@ pub fn mult_sc_vec(a: &mut [Scalar], b: &[Scalar], device_id: usize) {
}
}
/// Multiplies, in place, a batch of points by a vector of scalars on the GPU.
///
/// `a` is treated as `a.len() / b.len()` consecutive chunks of `b.len()`
/// points each; point `a[i]` is replaced by `b[i % b.len()] * a[i]`.
///
/// `device_id` is forwarded to the CUDA side unchanged.
///
/// # Panics
///
/// Panics if `b` is empty, if `a.len()` is not a multiple of `b.len()`, or if
/// the CUDA call reports a non-zero status (in which case `a` must not be
/// assumed to hold the products).
pub fn mult_p_batch_vec(a: &mut [Point], b: &[Scalar], device_id: usize) {
    // Checked explicitly so the failure is not an opaque remainder-by-zero panic.
    assert!(!b.is_empty(), "multiplier vector `b` must not be empty");
    assert_eq!(a.len() % b.len(), 0);
    if a.is_empty() {
        // Nothing to multiply; skip the round trip to the GPU.
        return;
    }
    let batch_size = a.len() / b.len();
    // NOTE(review): the pointer cast assumes `Scalar` and `ScalarField` share
    // the same memory layout - confirm against their definitions.
    let status = unsafe {
        batch_vector_mult_proj_cuda(
            a.as_mut_ptr(),
            b.as_ptr() as *const ScalarField,
            b.len(),
            batch_size,
            device_id,
        )
    };
    assert_eq!(
        status, 0,
        "batch_vector_mult_proj_cuda failed with status {}",
        status
    );
}
/// Multiplies, in place, a batch of scalars by a vector of scalars on the GPU.
///
/// `a` is treated as `a.len() / b.len()` consecutive chunks of `b.len()`
/// scalars each; element `a[i]` is replaced by `b[i % b.len()] * a[i]`.
///
/// `device_id` is forwarded to the CUDA side unchanged.
///
/// # Panics
///
/// Panics if `b` is empty, if `a.len()` is not a multiple of `b.len()`, or if
/// the CUDA call reports a non-zero status (in which case `a` must not be
/// assumed to hold the products).
pub fn mult_sc_batch_vec(a: &mut [Scalar], b: &[Scalar], device_id: usize) {
    // Checked explicitly so the failure is not an opaque remainder-by-zero panic.
    assert!(!b.is_empty(), "multiplier vector `b` must not be empty");
    assert_eq!(a.len() % b.len(), 0);
    if a.is_empty() {
        // Nothing to multiply; skip the round trip to the GPU.
        return;
    }
    let batch_size = a.len() / b.len();
    // NOTE(review): the pointer casts assume `Scalar` and `ScalarField` share
    // the same memory layout - confirm against their definitions.
    let status = unsafe {
        batch_vector_mult_scalar_cuda(
            a.as_mut_ptr() as *mut ScalarField,
            b.as_ptr() as *const ScalarField,
            b.len(),
            batch_size,
            device_id,
        )
    };
    assert_eq!(
        status, 0,
        "batch_vector_mult_scalar_cuda failed with status {}",
        status
    );
}
// Multiply a matrix by a scalar:
// `a` - flattened matrix;
// `b` - vector to multiply `a` by;