Mirror of https://github.com/pseXperiments/icicle.git, synced 2026-01-08 23:17:54 -05:00
accumulate stwo (#535)
Adds in-place vector addition and exposes it as an `accumulate` API.
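The new entry points compute `vec_a[i] += vec_b[i]` directly into `vec_a`, so no separate result buffer is needed. A minimal host-side sketch of the intended usage, assuming the Rust wrappers introduced below; the `icicle_bn254` field type and `VecOpsConfig::default()` are illustrative stand-ins, not taken from this diff:

    use icicle_bn254::curve::ScalarField;
    use icicle_core::traits::{FieldImpl, GenerateRandom};
    use icicle_core::vec_ops::{accumulate_scalars, VecOpsConfig};
    use icicle_cuda_runtime::memory::HostSlice;

    fn main() {
        let n = 1 << 10;
        let mut a = <ScalarField as FieldImpl>::Config::generate_random(n);
        let b = <ScalarField as FieldImpl>::Config::generate_random(n);

        // `a` must be mutable: accumulate writes the sums back into it.
        let a = HostSlice::from_mut_slice(&mut a);
        let b = HostSlice::from_slice(&b);

        let cfg = VecOpsConfig::default(); // assumed default config, as in the crate's tests
        accumulate_scalars(a, b, &cfg).unwrap(); // a[i] += b[i], in place
    }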
@@ -56,6 +56,9 @@ extern "C" cudaError_t babybear_mul_cuda(
 extern "C" cudaError_t babybear_add_cuda(
   babybear::scalar_t* vec_a, babybear::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, babybear::scalar_t* result);

+extern "C" cudaError_t babybear_accumulate_cuda(
+  babybear::scalar_t* vec_a, babybear::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t babybear_sub_cuda(
   babybear::scalar_t* vec_a, babybear::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, babybear::scalar_t* result);
@@ -104,6 +104,9 @@ extern "C" cudaError_t bls12_377_mul_cuda(
 extern "C" cudaError_t bls12_377_add_cuda(
   bls12_377::scalar_t* vec_a, bls12_377::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bls12_377::scalar_t* result);

+extern "C" cudaError_t bls12_377_accumulate_cuda(
+  bls12_377::scalar_t* vec_a, bls12_377::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t bls12_377_sub_cuda(
   bls12_377::scalar_t* vec_a, bls12_377::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bls12_377::scalar_t* result);
@@ -104,6 +104,9 @@ extern "C" cudaError_t bls12_381_mul_cuda(
 extern "C" cudaError_t bls12_381_add_cuda(
   bls12_381::scalar_t* vec_a, bls12_381::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bls12_381::scalar_t* result);

+extern "C" cudaError_t bls12_381_accumulate_cuda(
+  bls12_381::scalar_t* vec_a, bls12_381::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t bls12_381_sub_cuda(
   bls12_381::scalar_t* vec_a, bls12_381::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bls12_381::scalar_t* result);
@@ -136,6 +136,9 @@ extern "C" cudaError_t bn254_mul_cuda(
 extern "C" cudaError_t bn254_add_cuda(
   bn254::scalar_t* vec_a, bn254::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bn254::scalar_t* result);

+extern "C" cudaError_t bn254_accumulate_cuda(
+  bn254::scalar_t* vec_a, bn254::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t bn254_sub_cuda(
   bn254::scalar_t* vec_a, bn254::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bn254::scalar_t* result);
@@ -104,6 +104,9 @@ extern "C" cudaError_t bw6_761_mul_cuda(
 extern "C" cudaError_t bw6_761_add_cuda(
   bw6_761::scalar_t* vec_a, bw6_761::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bw6_761::scalar_t* result);

+extern "C" cudaError_t bw6_761_accumulate_cuda(
+  bw6_761::scalar_t* vec_a, bw6_761::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t bw6_761_sub_cuda(
   bw6_761::scalar_t* vec_a, bw6_761::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, bw6_761::scalar_t* result);
@@ -74,6 +74,9 @@ extern "C" cudaError_t grumpkin_mul_cuda(
 extern "C" cudaError_t grumpkin_add_cuda(
   grumpkin::scalar_t* vec_a, grumpkin::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, grumpkin::scalar_t* result);

+extern "C" cudaError_t grumpkin_accumulate_cuda(
+  grumpkin::scalar_t* vec_a, grumpkin::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t grumpkin_sub_cuda(
   grumpkin::scalar_t* vec_a, grumpkin::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, grumpkin::scalar_t* result);
@@ -19,6 +19,9 @@ extern "C" cudaError_t stark252_mul_cuda(
 extern "C" cudaError_t stark252_add_cuda(
   stark252::scalar_t* vec_a, stark252::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, stark252::scalar_t* result);

+extern "C" cudaError_t stark252_accumulate_cuda(
+  stark252::scalar_t* vec_a, stark252::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t stark252_sub_cuda(
   stark252::scalar_t* vec_a, stark252::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, stark252::scalar_t* result);
@@ -4,6 +4,9 @@ extern "C" cudaError_t ${FIELD}_mul_cuda(
 extern "C" cudaError_t ${FIELD}_add_cuda(
   ${FIELD}::scalar_t* vec_a, ${FIELD}::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, ${FIELD}::scalar_t* result);

+extern "C" cudaError_t ${FIELD}_accumulate_cuda(
+  ${FIELD}::scalar_t* vec_a, ${FIELD}::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t ${FIELD}_sub_cuda(
   ${FIELD}::scalar_t* vec_a, ${FIELD}::scalar_t* vec_b, int n, vec_ops::VecOpsConfig& config, ${FIELD}::scalar_t* result);
@@ -4,6 +4,9 @@ extern "C" cudaError_t ${FIELD}_extension_mul_cuda(
 extern "C" cudaError_t ${FIELD}_extension_add_cuda(
   ${FIELD}::extension_t* vec_a, ${FIELD}::extension_t* vec_b, int n, vec_ops::VecOpsConfig& config, ${FIELD}::extension_t* result);

+extern "C" cudaError_t ${FIELD}_extension_accumulate_cuda(
+  ${FIELD}::extension_t* vec_a, ${FIELD}::extension_t* vec_b, int n, vec_ops::VecOpsConfig& config);
+
 extern "C" cudaError_t ${FIELD}_extension_sub_cuda(
   ${FIELD}::extension_t* vec_a, ${FIELD}::extension_t* vec_b, int n, vec_ops::VecOpsConfig& config, ${FIELD}::extension_t* result);
@@ -30,6 +30,18 @@ namespace vec_ops {
   return add<scalar_t>(vec_a, vec_b, n, config, result);
 }

+/**
+ * Accumulate (as vec_a[i] += vec_b[i]) function with the template parameter
+ * `E` being the [field](@ref scalar_t) (either scalar field of the curve given by `-DCURVE`
+ * or standalone "STARK field" given by `-DFIELD`).
+ * @return `cudaSuccess` if the execution was successful and an error code otherwise.
+ */
+extern "C" cudaError_t
+CONCAT_EXPAND(FIELD, accumulate_cuda)(scalar_t* vec_a, scalar_t* vec_b, int n, VecOpsConfig& config)
+{
+  return add<scalar_t>(vec_a, vec_b, n, config, vec_a);
+}
+
 /**
  * Extern version of [Sub](@ref Sub) function with the template parameter
  * `E` being the [field](@ref scalar_t) (either scalar field of the curve given by `-DCURVE`
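Since the extern body above is literally `add` with `result = vec_a`, accumulate should agree with an out-of-place add on every element. A sketch of that equivalence, using only the Rust wrappers added later in this commit; `accumulate_matches_add` is a hypothetical helper and the `cfg` handling mirrors the crate's tests:

    use icicle_core::traits::FieldImpl;
    use icicle_core::vec_ops::{accumulate_scalars, add_scalars, VecOps, VecOpsConfig};
    use icicle_cuda_runtime::memory::HostSlice;

    // Hypothetical check: after accumulate, `a` must equal the out-of-place sum.
    fn accumulate_matches_add<F: FieldImpl>(a: &mut [F], b: &[F], cfg: &VecOpsConfig)
    where
        <F as FieldImpl>::Config: VecOps<F>,
    {
        // out[i] = a[i] + b[i], computed before `a` is overwritten
        let mut out = vec![F::zero(); a.len()];
        add_scalars(
            HostSlice::from_slice(a),
            HostSlice::from_slice(b),
            HostSlice::from_mut_slice(&mut out),
            cfg,
        )
        .unwrap();

        // a[i] += b[i], in place
        accumulate_scalars(HostSlice::from_mut_slice(a), HostSlice::from_slice(b), cfg).unwrap();

        assert!(a.iter().zip(out.iter()).all(|(x, y)| x == y));
    }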
@@ -29,6 +29,17 @@ namespace vec_ops {
   return add<extension_t>(vec_a, vec_b, n, config, result);
 }

+/**
+ * Accumulate (as vec_a[i] += vec_b[i]) function with the template parameter
+ * `E` being the [extension field](@ref extension_t) of the base field given by `-DFIELD` env variable during build.
+ * @return `cudaSuccess` if the execution was successful and an error code otherwise.
+ */
+extern "C" cudaError_t
+CONCAT_EXPAND(FIELD, extension_accumulate_cuda)(extension_t* vec_a, extension_t* vec_b, int n, VecOpsConfig& config)
+{
+  return add<extension_t>(vec_a, vec_b, n, config, vec_a);
+}
+
 /**
  * Extern version of [Sub](@ref Sub) function with the template parameter
  * `E` being the [extension field](@ref extension_t) of the base field given by `-DFIELD` env variable during build.
@@ -82,16 +82,19 @@ namespace vec_ops {
 } // namespace

 template <typename E, void (*Kernel)(const E*, const E*, int, E*)>
-cudaError_t vec_op(const E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
+cudaError_t vec_op(E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
 {
   CHK_INIT_IF_RETURN();

+  bool is_in_place = vec_a == result;
+
   // Set the grid and block dimensions
   int num_threads = MAX_THREADS_PER_BLOCK;
   int num_blocks = (n + num_threads - 1) / num_threads;

   E *d_result, *d_alloc_vec_a, *d_alloc_vec_b;
-  const E *d_vec_a, *d_vec_b;
+  E* d_vec_a;
+  const E* d_vec_b;
   if (!config.is_a_on_device) {
     CHK_IF_RETURN(cudaMallocAsync(&d_alloc_vec_a, n * sizeof(E), config.ctx.stream));
     CHK_IF_RETURN(cudaMemcpyAsync(d_alloc_vec_a, vec_a, n * sizeof(E), cudaMemcpyHostToDevice, config.ctx.stream));
@@ -109,41 +112,49 @@ namespace vec_ops {
   }

   if (!config.is_result_on_device) {
-    CHK_IF_RETURN(cudaMallocAsync(&d_result, n * sizeof(E), config.ctx.stream));
+    if (!is_in_place) {
+      CHK_IF_RETURN(cudaMallocAsync(&d_result, n * sizeof(E), config.ctx.stream));
+    } else {
+      d_result = d_vec_a;
+    }
   } else {
-    d_result = result;
+    if (!is_in_place) {
+      d_result = result;
+    } else {
+      d_result = result = d_vec_a;
+    }
   }

   // Call the kernel to perform element-wise operation
   Kernel<<<num_blocks, num_threads, 0, config.ctx.stream>>>(d_vec_a, d_vec_b, n, d_result);

-  if (!config.is_a_on_device) { CHK_IF_RETURN(cudaFreeAsync(d_alloc_vec_a, config.ctx.stream)); }
-  if (!config.is_b_on_device) { CHK_IF_RETURN(cudaFreeAsync(d_alloc_vec_b, config.ctx.stream)); }
-
   if (!config.is_result_on_device) {
     CHK_IF_RETURN(cudaMemcpyAsync(result, d_result, n * sizeof(E), cudaMemcpyDeviceToHost, config.ctx.stream));
     CHK_IF_RETURN(cudaFreeAsync(d_result, config.ctx.stream));
   }

+  if (!config.is_a_on_device && !is_in_place) { CHK_IF_RETURN(cudaFreeAsync(d_alloc_vec_a, config.ctx.stream)); }
+  if (!config.is_b_on_device) { CHK_IF_RETURN(cudaFreeAsync(d_alloc_vec_b, config.ctx.stream)); }
+
   if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(config.ctx.stream));

   return CHK_LAST();
 }

 template <typename E>
-cudaError_t mul(const E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
+cudaError_t mul(E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
 {
   return vec_op<E, mul_kernel>(vec_a, vec_b, n, config, result);
 }

 template <typename E>
-cudaError_t add(const E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
+cudaError_t add(E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
 {
   return vec_op<E, add_kernel>(vec_a, vec_b, n, config, result);
 }

 template <typename E>
-cudaError_t sub(const E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
+cudaError_t sub(E* vec_a, const E* vec_b, int n, VecOpsConfig& config, E* result)
 {
   return vec_op<E, sub_kernel>(vec_a, vec_b, n, config, result);
 }
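The `!is_in_place` guard on the free matters for host-resident operands: in the in-place case the result is copied back out of `d_vec_a` itself and that buffer is freed by the result branch, so freeing `d_alloc_vec_a` again would be a double free. A host-path walkthrough; the comments trace the CUDA host code above, and `VecOpsConfig::default()` is an assumption as elsewhere:

    use icicle_bn254::curve::ScalarField;
    use icicle_core::traits::{FieldImpl, GenerateRandom};
    use icicle_core::vec_ops::{accumulate_scalars, VecOpsConfig};
    use icicle_cuda_runtime::memory::HostSlice;

    fn host_accumulate() {
        let n = 1 << 12;
        let mut a = <ScalarField as FieldImpl>::Config::generate_random(n);
        let b = <ScalarField as FieldImpl>::Config::generate_random(n);
        let cfg = VecOpsConfig::default();

        // Inside vec_op this takes the in-place path:
        //  1. `a` is copied into a fresh device buffer d_alloc_vec_a (host input);
        //  2. result == vec_a, so is_in_place is true and d_result aliases d_vec_a;
        //  3. the kernel overwrites d_vec_a with the element-wise sums;
        //  4. the result branch copies d_result back into `a` and frees it once --
        //     the `!is_in_place` condition skips the second free of d_alloc_vec_a.
        accumulate_scalars(
            HostSlice::from_mut_slice(&mut a),
            HostSlice::from_slice(&b),
            &cfg,
        )
        .unwrap();
    }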
@@ -3,12 +3,13 @@
 #include <iostream>

 // include list of test files
-// Ensure the device_error_test.cu is last to prevent aborting mid-test run
 #include "field_test.cu"
 #ifdef CURVE_ID
 #include "curve_test.cu"
 #endif
 #include "error_handler_test.cu"

+// Ensure the device_error_test.cu is last to prevent aborting mid-test run
 #include "device_error_test.cu"

 int main(int argc, char** argv)
@@ -83,6 +83,12 @@ pub trait VecOps<F> {
         cfg: &VecOpsConfig,
     ) -> IcicleResult<()>;

+    fn accumulate(
+        a: &mut (impl HostOrDeviceSlice<F> + ?Sized),
+        b: &(impl HostOrDeviceSlice<F> + ?Sized),
+        cfg: &VecOpsConfig,
+    ) -> IcicleResult<()>;
+
     fn sub(
         a: &(impl HostOrDeviceSlice<F> + ?Sized),
         b: &(impl HostOrDeviceSlice<F> + ?Sized),
@@ -207,6 +213,19 @@ where
     <<F as FieldImpl>::Config as VecOps<F>>::add(a, b, result, &cfg)
 }

+pub fn accumulate_scalars<F>(
+    a: &mut (impl HostOrDeviceSlice<F> + ?Sized),
+    b: &(impl HostOrDeviceSlice<F> + ?Sized),
+    cfg: &VecOpsConfig,
+) -> IcicleResult<()>
+where
+    F: FieldImpl,
+    <F as FieldImpl>::Config: VecOps<F>,
+{
+    let cfg = check_vec_ops_args(a, b, a, cfg);
+    <<F as FieldImpl>::Config as VecOps<F>>::accumulate(a, b, &cfg)
+}
+
 pub fn sub_scalars<F>(
     a: &(impl HostOrDeviceSlice<F> + ?Sized),
     b: &(impl HostOrDeviceSlice<F> + ?Sized),
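Because `accumulate` is exposed through the `VecOps<F>` trait, callers can stay generic over the field. A hypothetical fold that sums several vectors into one accumulator without allocating a result buffer per addition, assuming `HostOrDeviceSlice` is importable from `icicle_cuda_runtime::memory` as in `vec_ops.rs`:

    use icicle_core::traits::FieldImpl;
    use icicle_core::vec_ops::{accumulate_scalars, VecOps, VecOpsConfig};
    use icicle_cuda_runtime::memory::HostOrDeviceSlice;

    // Hypothetical helper: acc[i] += v[i] for each input vector, reusing `acc`
    // instead of the scratch buffer that add_scalars would need per addition.
    fn accumulate_many<F, A, V>(acc: &mut A, inputs: &[&V], cfg: &VecOpsConfig)
    where
        F: FieldImpl,
        <F as FieldImpl>::Config: VecOps<F>,
        A: HostOrDeviceSlice<F> + ?Sized,
        V: HostOrDeviceSlice<F> + ?Sized,
    {
        for v in inputs {
            accumulate_scalars(&mut *acc, *v, cfg).unwrap();
        }
    }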
@@ -299,6 +318,14 @@ macro_rules! impl_vec_ops_field {
             result: *mut $field,
         ) -> CudaError;

+        #[link_name = concat!($field_prefix, "_accumulate_cuda")]
+        pub(crate) fn accumulate_scalars_cuda(
+            a: *mut $field,
+            b: *const $field,
+            size: u32,
+            cfg: *const VecOpsConfig,
+        ) -> CudaError;
+
         #[link_name = concat!($field_prefix, "_sub_cuda")]
         pub(crate) fn sub_scalars_cuda(
             a: *const $field,
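For a concrete backend, the `concat!` in `link_name` resolves to the per-field symbols declared in the headers at the top of this commit. An illustrative hand-written expansion for the bn254 prefix (not generator output; the `CudaError` import path is an assumption):

    use icicle_bn254::curve::ScalarField;
    use icicle_core::vec_ops::VecOpsConfig;
    use icicle_cuda_runtime::error::CudaError;

    extern "C" {
        // Binds to `extern "C" cudaError_t bn254_accumulate_cuda(...)`; note there is
        // no `result` pointer -- the sums are written back through `a`.
        #[link_name = "bn254_accumulate_cuda"]
        fn accumulate_scalars_cuda(
            a: *mut ScalarField,
            b: *const ScalarField,
            size: u32,
            cfg: *const VecOpsConfig,
        ) -> CudaError;
    }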
@@ -357,6 +384,22 @@ macro_rules! impl_vec_ops_field {
                 }
             }

+            fn accumulate(
+                a: &mut (impl HostOrDeviceSlice<$field> + ?Sized),
+                b: &(impl HostOrDeviceSlice<$field> + ?Sized),
+                cfg: &VecOpsConfig,
+            ) -> IcicleResult<()> {
+                unsafe {
+                    $field_prefix_ident::accumulate_scalars_cuda(
+                        a.as_mut_ptr(),
+                        b.as_ptr(),
+                        a.len() as u32,
+                        cfg as *const VecOpsConfig,
+                    )
+                    .wrap()
+                }
+            }
+
             fn sub(
                 a: &(impl HostOrDeviceSlice<$field> + ?Sized),
                 b: &(impl HostOrDeviceSlice<$field> + ?Sized),
@@ -457,7 +500,7 @@ macro_rules! impl_vec_add_tests {
     ) => {
         #[test]
         pub fn test_vec_add_scalars() {
-            check_vec_ops_scalars::<$field>()
+            check_vec_ops_scalars::<$field>();
         }

         #[test]
@@ -5,19 +5,21 @@ use crate::vec_ops::{
 };
 use icicle_cuda_runtime::memory::{DeviceVec, HostSlice};

+use super::accumulate_scalars;
+
 pub fn check_vec_ops_scalars<F: FieldImpl>()
 where
     <F as FieldImpl>::Config: VecOps<F> + GenerateRandom<F>,
 {
     let test_size = 1 << 14;

-    let a = F::Config::generate_random(test_size);
+    let mut a = F::Config::generate_random(test_size);
     let b = F::Config::generate_random(test_size);
     let ones = vec![F::one(); test_size];
     let mut result = vec![F::zero(); test_size];
     let mut result2 = vec![F::zero(); test_size];
     let mut result3 = vec![F::zero(); test_size];
-    let a = HostSlice::from_slice(&a);
+    let a = HostSlice::from_mut_slice(&mut a);
     let b = HostSlice::from_slice(&b);
     let ones = HostSlice::from_slice(&ones);
     let result = HostSlice::from_mut_slice(&mut result);
@@ -34,6 +36,12 @@ where
     mul_scalars(a, ones, result3, &cfg).unwrap();

     assert_eq!(a[0], result3[0]);
+
+    add_scalars(a, b, result, &cfg).unwrap();
+
+    accumulate_scalars(a, b, &cfg).unwrap();
+
+    assert_eq!(a[0], result[0]);
 }

 pub fn check_bit_reverse<F: FieldImpl>()