mirror of https://github.com/pseXperiments/icicle.git
synced 2026-01-13 09:27:58 -05:00

Compare commits: rust-inter...feat/batch
2 Commits

| Author | SHA1 | Date |
|---|---|---|
| | ae5c0681ad | |
| | 58a2492d60 | |
build.rs (+1)
```diff
@@ -20,6 +20,7 @@ fn main() {
    nvcc.cuda(true);
    nvcc.debug(false);
    nvcc.flag(&arch);
    nvcc.flag("--std=c++17");
    nvcc.files([
        "./icicle/appUtils/ntt/ntt.cu",
        "./icicle/appUtils/msm/msm.cu",
```
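The `nvcc` value configured in this hunk behaves like a `cc::Build` from the `cc` crate (`.cuda(true)`, `.debug(false)`, `.flag(...)`, `.files(...)` are that API). A minimal sketch of such a build script, for orientation only: the `ARCH_TYPE` environment variable handling and the `ingo_icicle` output name are assumptions for illustration and are not taken from this diff.

```rust
// build.rs sketch. Assumptions (not from this diff): the builder is cc::Build,
// the arch flag comes from an ARCH_TYPE env var, and the output library is
// named "ingo_icicle".
fn main() {
    let arch_type = std::env::var("ARCH_TYPE").unwrap_or_else(|_| "native".to_string());
    let arch = format!("-arch={}", arch_type);

    let mut nvcc = cc::Build::new();
    nvcc.cuda(true);           // compile the listed sources with nvcc
    nvcc.debug(false);
    nvcc.flag(&arch);          // e.g. -arch=native or -arch=sm_80
    nvcc.flag("--std=c++17");
    nvcc.files([
        "./icicle/appUtils/ntt/ntt.cu",
        "./icicle/appUtils/msm/msm.cu",
        // ...plus the other .cu sources listed in the real build.rs
    ]);
    nvcc.compile("ingo_icicle"); // assumed output name
}
```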
```diff
@@ -93,6 +93,83 @@ int matrix_mod_mult(E *matrix_elements, E *vector_elements, E *result, size_t di
   return 0;
 }
 
+template <typename E, typename S>
+__global__ void batch_vector_mult_kernel(E *element_vec, S *mult_vec, unsigned n_mult, unsigned batch_size)
+{
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid < n_mult * batch_size)
+  {
+    int mult_id = tid % n_mult;
+    element_vec[tid] = mult_vec[mult_id] * element_vec[tid];
+  }
+}
+
+template <typename E, typename S>
+int batch_vector_mult_template(E *element_vec, S *mult_vec, unsigned n_mult, unsigned batch_size)
+{
+  // Set the grid and block dimensions
+  int NUM_THREADS = MAX_THREADS_PER_BLOCK;
+  int NUM_BLOCKS = (n_mult * batch_size + NUM_THREADS - 1) / NUM_THREADS;
+
+  // Allocate memory on the device for the input vectors, the output vector, and the modulus
+  S *d_mult_vec;
+  E *d_element_vec;
+  size_t n_mult_size = n_mult * sizeof(S);
+  size_t full_size = n_mult * batch_size * sizeof(E);
+  cudaMalloc(&d_mult_vec, n_mult_size);
+  cudaMalloc(&d_element_vec, full_size);
+
+  // Copy the input vectors and the modulus from the host to the device
+  cudaMemcpy(d_mult_vec, mult_vec, n_mult_size, cudaMemcpyHostToDevice);
+  cudaMemcpy(d_element_vec, element_vec, full_size, cudaMemcpyHostToDevice);
+
+  batch_vector_mult_kernel<<<NUM_BLOCKS, NUM_THREADS>>>(d_element_vec, d_mult_vec, n_mult, batch_size);
+
+  cudaMemcpy(element_vec, d_element_vec, full_size, cudaMemcpyDeviceToHost);
+
+  cudaFree(d_mult_vec);
+  cudaFree(d_element_vec);
+  return 0;
+}
+
+extern "C" int32_t batch_vector_mult_proj_cuda(projective_t *inout,
+                                               scalar_t *scalar_vec,
+                                               size_t n_scalars,
+                                               size_t batch_size,
+                                               size_t device_id)
+{
+  try
+  {
+    // TODO: device_id
+    batch_vector_mult_template(inout, scalar_vec, n_scalars, batch_size);
+    return CUDA_SUCCESS;
+  }
+  catch (const std::runtime_error &ex)
+  {
+    printf("error %s", ex.what()); // TODO: error code and message
+    return -1;
+  }
+}
+
+extern "C" int32_t batch_vector_mult_scalar_cuda(scalar_t *inout,
+                                                 scalar_t *mult_vec,
+                                                 size_t n_mult,
+                                                 size_t batch_size,
+                                                 size_t device_id)
+{
+  try
+  {
+    // TODO: device_id
+    batch_vector_mult_template(inout, mult_vec, n_mult, batch_size);
+    return CUDA_SUCCESS;
+  }
+  catch (const std::runtime_error &ex)
+  {
+    printf("error %s", ex.what()); // TODO: error code and message
+    return -1;
+  }
+}
+
 extern "C" int32_t vec_mod_mult_point(projective_t *inout,
                                       scalar_t *scalar_vec,
                                       size_t n_elments,
@@ -114,7 +191,7 @@ extern "C" int32_t vec_mod_mult_point(projective_t *inout,
 extern "C" int32_t vec_mod_mult_scalar(scalar_t *inout,
                                        scalar_t *scalar_vec,
                                        size_t n_elments,
-                                       size_t device_id)
+                                       size_t device_id) //TODO: unify with batch mult as batch_size=1
 {
   try
   {
```
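The new kernel treats `element_vec` as `batch_size` contiguous chunks of length `n_mult` and scales every chunk element-wise by the same `mult_vec`: thread `tid` picks `mult_vec[tid % n_mult]`. A minimal CPU reference of that indexing, sketched in Rust with plain `u64` arithmetic standing in for the `scalar_t`/`projective_t` multiplication (illustrative only, not part of the diff):

```rust
/// CPU reference for the batched multiply: `elements` holds `batch_size`
/// chunks of length `mult.len()`, laid out contiguously; every chunk is
/// scaled element-wise by `mult`. Plain u64 multiplication stands in for
/// the field/point arithmetic of the real kernel.
fn batch_vector_mult_reference(elements: &mut [u64], mult: &[u64]) {
    let n_mult = mult.len();
    assert_eq!(elements.len() % n_mult, 0);
    for (tid, e) in elements.iter_mut().enumerate() {
        let mult_id = tid % n_mult; // same index rule as batch_vector_mult_kernel
        *e = mult[mult_id] * *e;
    }
}

fn main() {
    // Two chunks (batch_size = 2) of length 3, scaled by the same vector.
    let mut elements = vec![1, 2, 3, 4, 5, 6];
    let mult = vec![10, 100, 1000];
    batch_vector_mult_reference(&mut elements, &mult);
    assert_eq!(elements, vec![10, 200, 3000, 40, 500, 6000]);
}
```

On the device side, the host wrapper launches `(n_mult * batch_size + NUM_THREADS - 1) / NUM_THREADS` blocks of `MAX_THREADS_PER_BLOCK` threads, i.e. a ceiling division, so every element gets exactly one thread and the `tid < n_mult * batch_size` guard discards the overshoot.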
src/lib.rs (+42)
```diff
@@ -50,6 +50,22 @@ extern "C" {
         device_id: usize,
     ) -> c_int;
 
+    fn batch_vector_mult_proj_cuda(
+        inout: *mut Point,
+        scalars: *const ScalarField,
+        n_scalars: usize,
+        batch_size: usize,
+        device_id: usize,
+    ) -> c_int;
+
+    fn batch_vector_mult_scalar_cuda(
+        inout: *mut ScalarField,
+        scalars: *const ScalarField,
+        n_mult: usize,
+        batch_size: usize,
+        device_id: usize,
+    ) -> c_int;
+
     fn matrix_vec_mod_mult(
         matrix_flattened: *const ScalarField,
         input: *const ScalarField,
@@ -223,6 +239,32 @@ pub fn mult_sc_vec(a: &mut [Scalar], b: &[Scalar], device_id: usize) {
     }
 }
 
+pub fn mult_p_batch_vec(a: &mut [Point], b: &[Scalar], device_id: usize) {
+    assert_eq!(a.len() % b.len(), 0);
+    unsafe {
+        batch_vector_mult_proj_cuda(
+            a as *mut _ as *mut Point,
+            b as *const _ as *const ScalarField,
+            b.len(),
+            a.len() / b.len(),
+            device_id,
+        );
+    }
+}
+
+pub fn mult_sc_batch_vec(a: &mut [Scalar], b: &[Scalar], device_id: usize) {
+    assert_eq!(a.len() % b.len(), 0);
+    unsafe {
+        batch_vector_mult_scalar_cuda(
+            a as *mut _ as *mut ScalarField,
+            b as *const _ as *const ScalarField,
+            b.len(),
+            a.len() / b.len(),
+            device_id,
+        );
+    }
+}
+
 // Multiply a matrix by a scalar:
 // `a` - flattenned matrix;
 // `b` - vector to multiply `a` by;
```
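A hedged usage sketch of the two new wrappers, assuming the crate's `Point` and `Scalar` types are in scope and that the caller already has values to scale (the helper name below is illustrative, not part of this diff). The input slice length must be a multiple of the multiplier slice length, since `batch_size` is derived as `a.len() / b.len()`, and `device_id` is currently ignored on the CUDA side (see the `// TODO: device_id` above).

```rust
// Illustrative only: `points`, `scalars`, and `weights` are assumed to come
// from the caller; `mult_p_batch_vec` / `mult_sc_batch_vec` are the wrappers
// added in this diff. Device 0 is assumed.
fn scale_batches(points: &mut [Point], scalars: &mut [Scalar], weights: &[Scalar]) {
    // Each length-`weights.len()` chunk of `points` and `scalars` is scaled
    // element-wise by the same `weights` vector.
    assert_eq!(points.len() % weights.len(), 0);
    assert_eq!(scalars.len() % weights.len(), 0);
    mult_p_batch_vec(points, weights, 0);
    mult_sc_batch_vec(scalars, weights, 0);
}
```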