Mirror of https://github.com/pseXperiments/icicle.git (synced 2026-01-14 01:47:59 -05:00)

Compare commits: 19 commits (feat/warmu ... Otsar)
| Author | SHA1 | Date |
|---|---|---|
| | 919ff42f49 | |
| | a1ff989740 | |
| | 1f2144a57c | |
| | db4c07dcaf | |
| | d4f39efea3 | |
| | 7293058246 | |
| | 03136f1074 | |
| | 3ef0d0c66e | |
| | 0dff1f9302 | |
| | 0d806d96ca | |
| | b6b5011a47 | |
| | 7ac463c3d9 | |
| | 287f53ff16 | |
| | 89082fb561 | |
| | 08ec0b1ff6 | |
| | fa219d9c95 | |
| | 0e84fb4b76 | |
| | d8059a2a4e | |
| | 1abd2ef9c9 | |
.github/workflows/golang.yml (vendored, 2 changed lines)

@@ -50,7 +50,7 @@ jobs:
    - name: Build
      working-directory: ./wrappers/golang
      if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true'
      run: ./build.sh ${{ matrix.curve }} ON # builds a single curve with G2 enabled
      run: ./build.sh ${{ matrix.curve }} ON ON # builds a single curve with G2 and ECNTT enabled
    - name: Upload ICICLE lib artifacts
      uses: actions/upload-artifact@v4
      if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true'
.github/workflows/release.yml (vendored, 18 changed lines)

@@ -20,11 +20,27 @@ jobs:
    steps:
    - name: Checkout
      uses: actions/checkout@v4
      with:
        ssh-key: ${{ secrets.DEPLOY_KEY }}
    - name: Setup Cache
      id: cache
      uses: actions/cache@v4
      with:
        path: |
          ~/.cargo/bin/
          ~/.cargo/registry/index/
          ~/.cargo/registry/cache/
          ~/.cargo/git/db/
        key: ${{ runner.os }}-cargo-${{ hashFiles('~/.cargo/bin/cargo-workspaces') }}
    - name: Install cargo-workspaces
      if: steps.cache.outputs.cache-hit != 'true'
      run: cargo install cargo-workspaces
    - name: Bump rust crate versions, commit, and tag
      working-directory: wrappers/rust
      # https://github.com/pksunkara/cargo-workspaces?tab=readme-ov-file#version
      run: |
        cargo install cargo-workspaces
        git config user.name release-bot
        git config user.email release-bot@ingonyama.com
        cargo workspaces version ${{ inputs.releaseType }} -y --no-individual-tags -m "Bump rust crates' version"
    - name: Create draft release
      env:
@@ -15,7 +15,7 @@ ENV PATH="/root/.cargo/bin:${PATH}"

# Install Golang
ENV GOLANG_VERSION 1.21.1
RUN curl -L https://golang.org/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local
RUN curl -L https://go.dev/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local
ENV PATH="/usr/local/go/bin:${PATH}"

# Set the working directory in the container
@@ -11,8 +11,6 @@
  </a>
  <a href="https://twitter.com/intent/follow?screen_name=Ingo_zk">
    <img src="https://img.shields.io/twitter/follow/Ingo_zk?style=social&logo=twitter" alt="Follow us on Twitter">
  </a>
  <img src="https://img.shields.io/badge/Machines%20running%20ICICLE-544-lightblue" alt="Machines running ICICLE">
  <a href="https://github.com/ingonyama-zk/icicle/releases">
    <img src="https://img.shields.io/github/v/release/ingonyama-zk/icicle" alt="GitHub Release">
  </a>
@@ -2,7 +2,7 @@

[](https://github.com/ingonyama-zk/icicle/releases)
@@ -60,7 +60,13 @@ else()
endif()

project(icicle LANGUAGES CUDA CXX)

# Check CUDA version and, if possible, enable multi-threaded compilation
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.2")
  message(STATUS "Using multi-threaded CUDA compilation.")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --split-compile 0")
else()
  message(STATUS "Can't use multi-threaded CUDA compilation.")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS_RELEASE "")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")

@@ -90,6 +96,10 @@ if (G2_DEFINED STREQUAL "ON")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DG2_DEFINED=ON")
endif ()

if (ECNTT_DEFINED STREQUAL "ON")
  set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DECNTT_DEFINED=ON")
endif ()

option(BUILD_TESTS "Build tests" OFF)

if (NOT BUILD_TESTS)

@@ -104,6 +114,9 @@ if (NOT BUILD_TESTS)
  if (NOT CURVE IN_LIST SUPPORTED_CURVES_WITHOUT_NTT)
    list(APPEND ICICLE_SOURCES appUtils/ntt/ntt.cu)
    list(APPEND ICICLE_SOURCES appUtils/ntt/kernel_ntt.cu)
    if(ECNTT_DEFINED STREQUAL "ON")
      set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DECNTT_DEFINED=ON")
    endif()
  endif()

add_library(
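A hedged illustration (not part of the diff) of what the new ECNTT_DEFINED branch buys: it only appends a compile definition, so curve source files are expected to fence EC-NTT-specific code behind it, mirroring how the MSM sources later in this changeset fence G2 entry points behind G2_DEFINED. The symbol below is hypothetical and only stands in for the real EC-NTT entry points under appUtils/ntt.

```cpp
// Hypothetical guard pattern enabled by -DECNTT_DEFINED=ON (illustration only).
#if defined(ECNTT_DEFINED)
extern "C" cudaError_t SomeECNTTEntryPoint(); // placeholder name, not a real icicle symbol
#endif
```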
icicle/appUtils/keccak/Makefile (new file, 2 lines)

@@ -0,0 +1,2 @@
test_keccak: test.cu keccak.cu
	nvcc -o test_keccak -I. -I../.. test.cu
icicle/appUtils/keccak/keccak.cu (new file, 275 lines)

@@ -0,0 +1,275 @@
#include "keccak.cuh"

namespace keccak {
#define ROTL64(x, y) (((x) << (y)) | ((x) >> (64 - (y))))

#define TH_ELT(t, c0, c1, c2, c3, c4, d0, d1, d2, d3, d4) \
  { \
    t = ROTL64((d0 ^ d1 ^ d2 ^ d3 ^ d4), 1) ^ (c0 ^ c1 ^ c2 ^ c3 ^ c4); \
  }

#define THETA( \
  s00, s01, s02, s03, s04, s10, s11, s12, s13, s14, s20, s21, s22, s23, s24, s30, s31, s32, s33, s34, s40, s41, s42, \
  s43, s44) \
  { \
    TH_ELT(t0, s40, s41, s42, s43, s44, s10, s11, s12, s13, s14); \
    TH_ELT(t1, s00, s01, s02, s03, s04, s20, s21, s22, s23, s24); \
    TH_ELT(t2, s10, s11, s12, s13, s14, s30, s31, s32, s33, s34); \
    TH_ELT(t3, s20, s21, s22, s23, s24, s40, s41, s42, s43, s44); \
    TH_ELT(t4, s30, s31, s32, s33, s34, s00, s01, s02, s03, s04); \
    s00 ^= t0; \
    s01 ^= t0; \
    s02 ^= t0; \
    s03 ^= t0; \
    s04 ^= t0; \
\
    s10 ^= t1; \
    s11 ^= t1; \
    s12 ^= t1; \
    s13 ^= t1; \
    s14 ^= t1; \
\
    s20 ^= t2; \
    s21 ^= t2; \
    s22 ^= t2; \
    s23 ^= t2; \
    s24 ^= t2; \
\
    s30 ^= t3; \
    s31 ^= t3; \
    s32 ^= t3; \
    s33 ^= t3; \
    s34 ^= t3; \
\
    s40 ^= t4; \
    s41 ^= t4; \
    s42 ^= t4; \
    s43 ^= t4; \
    s44 ^= t4; \
  }

#define RHOPI( \
  s00, s01, s02, s03, s04, s10, s11, s12, s13, s14, s20, s21, s22, s23, s24, s30, s31, s32, s33, s34, s40, s41, s42, \
  s43, s44) \
  { \
    t0 = ROTL64(s10, (uint64_t)1); \
    s10 = ROTL64(s11, (uint64_t)44); \
    s11 = ROTL64(s41, (uint64_t)20); \
    s41 = ROTL64(s24, (uint64_t)61); \
    s24 = ROTL64(s42, (uint64_t)39); \
    s42 = ROTL64(s04, (uint64_t)18); \
    s04 = ROTL64(s20, (uint64_t)62); \
    s20 = ROTL64(s22, (uint64_t)43); \
    s22 = ROTL64(s32, (uint64_t)25); \
    s32 = ROTL64(s43, (uint64_t)8); \
    s43 = ROTL64(s34, (uint64_t)56); \
    s34 = ROTL64(s03, (uint64_t)41); \
    s03 = ROTL64(s40, (uint64_t)27); \
    s40 = ROTL64(s44, (uint64_t)14); \
    s44 = ROTL64(s14, (uint64_t)2); \
    s14 = ROTL64(s31, (uint64_t)55); \
    s31 = ROTL64(s13, (uint64_t)45); \
    s13 = ROTL64(s01, (uint64_t)36); \
    s01 = ROTL64(s30, (uint64_t)28); \
    s30 = ROTL64(s33, (uint64_t)21); \
    s33 = ROTL64(s23, (uint64_t)15); \
    s23 = ROTL64(s12, (uint64_t)10); \
    s12 = ROTL64(s21, (uint64_t)6); \
    s21 = ROTL64(s02, (uint64_t)3); \
    s02 = t0; \
  }

#define KHI( \
  s00, s01, s02, s03, s04, s10, s11, s12, s13, s14, s20, s21, s22, s23, s24, s30, s31, s32, s33, s34, s40, s41, s42, \
  s43, s44) \
  { \
    t0 = s00 ^ (~s10 & s20); \
    t1 = s10 ^ (~s20 & s30); \
    t2 = s20 ^ (~s30 & s40); \
    t3 = s30 ^ (~s40 & s00); \
    t4 = s40 ^ (~s00 & s10); \
    s00 = t0; \
    s10 = t1; \
    s20 = t2; \
    s30 = t3; \
    s40 = t4; \
\
    t0 = s01 ^ (~s11 & s21); \
    t1 = s11 ^ (~s21 & s31); \
    t2 = s21 ^ (~s31 & s41); \
    t3 = s31 ^ (~s41 & s01); \
    t4 = s41 ^ (~s01 & s11); \
    s01 = t0; \
    s11 = t1; \
    s21 = t2; \
    s31 = t3; \
    s41 = t4; \
\
    t0 = s02 ^ (~s12 & s22); \
    t1 = s12 ^ (~s22 & s32); \
    t2 = s22 ^ (~s32 & s42); \
    t3 = s32 ^ (~s42 & s02); \
    t4 = s42 ^ (~s02 & s12); \
    s02 = t0; \
    s12 = t1; \
    s22 = t2; \
    s32 = t3; \
    s42 = t4; \
\
    t0 = s03 ^ (~s13 & s23); \
    t1 = s13 ^ (~s23 & s33); \
    t2 = s23 ^ (~s33 & s43); \
    t3 = s33 ^ (~s43 & s03); \
    t4 = s43 ^ (~s03 & s13); \
    s03 = t0; \
    s13 = t1; \
    s23 = t2; \
    s33 = t3; \
    s43 = t4; \
\
    t0 = s04 ^ (~s14 & s24); \
    t1 = s14 ^ (~s24 & s34); \
    t2 = s24 ^ (~s34 & s44); \
    t3 = s34 ^ (~s44 & s04); \
    t4 = s44 ^ (~s04 & s14); \
    s04 = t0; \
    s14 = t1; \
    s24 = t2; \
    s34 = t3; \
    s44 = t4; \
  }

#define IOTA(element, rc) \
  { \
    element ^= rc; \
  }

__device__ const uint64_t RC[24] = {0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000,
                                    0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009,
                                    0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a,
                                    0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
                                    0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a,
                                    0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008};

__device__ void keccakf(uint64_t s[25])
{
  uint64_t t0, t1, t2, t3, t4;

  for (int i = 0; i < 24; i++) {
    THETA(
      s[0], s[5], s[10], s[15], s[20], s[1], s[6], s[11], s[16], s[21], s[2], s[7], s[12], s[17], s[22], s[3], s[8],
      s[13], s[18], s[23], s[4], s[9], s[14], s[19], s[24]);
    RHOPI(
      s[0], s[5], s[10], s[15], s[20], s[1], s[6], s[11], s[16], s[21], s[2], s[7], s[12], s[17], s[22], s[3], s[8],
      s[13], s[18], s[23], s[4], s[9], s[14], s[19], s[24]);
    KHI(
      s[0], s[5], s[10], s[15], s[20], s[1], s[6], s[11], s[16], s[21], s[2], s[7], s[12], s[17], s[22], s[3], s[8],
      s[13], s[18], s[23], s[4], s[9], s[14], s[19], s[24]);
    IOTA(s[0], RC[i]);
  }
}

template <int C, int D>
__global__ void keccak_hash_blocks(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output)
{
  int bid = (blockIdx.x * blockDim.x) + threadIdx.x;
  if (bid >= number_of_blocks) { return; }

  const int r_bits = 1600 - C;
  const int r_bytes = r_bits / 8;
  const int d_bytes = D / 8;

  uint8_t* b_input = input + bid * input_block_size;
  uint8_t* b_output = output + bid * d_bytes;
  uint64_t state[25] = {}; // Initialize with zeroes

  int input_len = input_block_size;

  // absorb
  while (input_len >= r_bytes) {
    // #pragma unroll
    for (int i = 0; i < r_bytes; i += 8) {
      state[i / 8] ^= *(uint64_t*)(b_input + i);
    }
    keccakf(state);
    b_input += r_bytes;
    input_len -= r_bytes;
  }

  // last block (if any)
  uint8_t last_block[r_bytes];
  for (int i = 0; i < input_len; i++) {
    last_block[i] = b_input[i];
  }

  // pad 10*1
  last_block[input_len] = 1;
  for (int i = 0; i < r_bytes - input_len - 1; i++) {
    last_block[input_len + i + 1] = 0;
  }
  // last bit
  last_block[r_bytes - 1] |= 0x80;

  // #pragma unroll
  for (int i = 0; i < r_bytes; i += 8) {
    state[i / 8] ^= *(uint64_t*)(last_block + i);
  }
  keccakf(state);

#pragma unroll
  for (int i = 0; i < d_bytes; i += 8) {
    *(uint64_t*)(b_output + i) = state[i / 8];
  }
}

template <int C, int D>
cudaError_t
keccak_hash(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config)
{
  CHK_INIT_IF_RETURN();
  cudaStream_t& stream = config.ctx.stream;

  uint8_t* input_device;
  if (config.are_inputs_on_device) {
    input_device = input;
  } else {
    CHK_IF_RETURN(cudaMallocAsync(&input_device, number_of_blocks * input_block_size, stream));
    CHK_IF_RETURN(
      cudaMemcpyAsync(input_device, input, number_of_blocks * input_block_size, cudaMemcpyHostToDevice, stream));
  }

  uint8_t* output_device;
  if (config.are_outputs_on_device) {
    output_device = output;
  } else {
    CHK_IF_RETURN(cudaMallocAsync(&output_device, number_of_blocks * (D / 8), stream));
  }

  int number_of_threads = 1024;
  int number_of_gpu_blocks = (number_of_blocks - 1) / number_of_threads + 1;
  keccak_hash_blocks<C, D><<<number_of_gpu_blocks, number_of_threads, 0, stream>>>(
    input_device, input_block_size, number_of_blocks, output_device);

  if (!config.are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(input_device, stream));

  if (!config.are_outputs_on_device) {
    CHK_IF_RETURN(cudaMemcpyAsync(output, output_device, number_of_blocks * (D / 8), cudaMemcpyDeviceToHost, stream));
    CHK_IF_RETURN(cudaFreeAsync(output_device, stream));
  }

  if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
  return CHK_LAST();
}

extern "C" cudaError_t
Keccak256(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config)
{
  return keccak_hash<512, 256>(input, input_block_size, number_of_blocks, output, config);
}

extern "C" cudaError_t
Keccak512(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config)
{
  return keccak_hash<1024, 512>(input, input_block_size, number_of_blocks, output, config);
}
} // namespace keccak
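The kernel above derives every buffer size from the two template parameters (capacity C and digest width D). A minimal host-side sketch, not part of the new file, spelling out that arithmetic for the two instantiations exported here:

```cpp
#include <cstdio>

// r = 1600 - C bits of rate are absorbed per permutation call; D bits are squeezed out.
constexpr int rate_bytes(int c_bits) { return (1600 - c_bits) / 8; }
constexpr int digest_bytes(int d_bits) { return d_bits / 8; }

static_assert(rate_bytes(512) == 136, "Keccak256 absorbs 136-byte blocks");
static_assert(digest_bytes(256) == 32, "Keccak256 writes 32-byte digests");
static_assert(rate_bytes(1024) == 72, "Keccak512 absorbs 72-byte blocks");
static_assert(digest_bytes(512) == 64, "Keccak512 writes 64-byte digests");

int main()
{
  std::printf("Keccak256: r_bytes=%d, d_bytes=%d\n", rate_bytes(512), digest_bytes(256));
  std::printf("Keccak512: r_bytes=%d, d_bytes=%d\n", rate_bytes(1024), digest_bytes(512));
  return 0;
}
```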
icicle/appUtils/keccak/keccak.cuh (new file, 56 lines)

@@ -0,0 +1,56 @@
#pragma once
#ifndef KECCAK_H
#define KECCAK_H

#include <cstdint>
#include "utils/device_context.cuh"
#include "utils/error_handler.cuh"

namespace keccak {
/**
 * @struct KeccakConfig
 * Struct that encodes various Keccak parameters.
 */
struct KeccakConfig {
  device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. */
  bool are_inputs_on_device;  /**< True if inputs are on device and false if they're on host. Default value: false. */
  bool are_outputs_on_device; /**< If true, output is preserved on device, otherwise on host. Default value: false. */
  bool is_async; /**< Whether to run the Keccak asynchronously. If set to `true`, the keccak_hash function will be
                  * non-blocking and you'd need to synchronize it explicitly by running
                  * `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, keccak_hash
                  * function will block the current CPU thread. */
};

KeccakConfig default_keccak_config()
{
  device_context::DeviceContext ctx = device_context::get_default_device_context();
  KeccakConfig config = {
    ctx,   // ctx
    false, // are_inputs_on_device
    false, // are_outputs_on_device
    false, // is_async
  };
  return config;
}

/**
 * Compute the keccak hash over a sequence of preimages.
 * Takes {number_of_blocks * input_block_size} bytes of input and computes {number_of_blocks} outputs, each of size
 * {D / 8} bytes.
 * @tparam C - number of bits of capacity (c = b - r = 1600 - r). Only multiples of 64 are supported.
 * @tparam D - number of bits of output. Only multiples of 64 are supported.
 * @param input a pointer to the input data. May be allocated on device or on host, regulated
 * by the config. Must be of size [input_block_size](@ref input_block_size) * [number_of_blocks](@ref
 * number_of_blocks).
 * @param input_block_size - size of each input block in bytes. Should be divisible by 8.
 * @param number_of_blocks number of input and output blocks. One GPU thread processes one block.
 * @param output a pointer to the output data. May be allocated on device or on host, regulated
 * by the config. Must be of size [output_block_size](@ref output_block_size) * [number_of_blocks](@ref
 * number_of_blocks).
 */
template <int C, int D>
cudaError_t
keccak_hash(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config);
} // namespace keccak

#endif
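A hedged usage sketch (not part of the changeset) for the configuration flags documented above, hashing buffers that already live on the GPU. The include path and the forward declaration of Keccak256 are assumptions: the extern "C" symbol is defined in keccak.cu rather than declared in this header.

```cpp
#include <cstdint>
#include <cuda_runtime.h>
#include "appUtils/keccak/keccak.cuh" // assumed include path relative to the icicle/ source root

// Keccak256 is defined with C linkage in keccak.cu; declared here only for this sketch.
extern "C" cudaError_t
Keccak256(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, keccak::KeccakConfig config);

cudaError_t hash_device_buffers(uint8_t* d_in, uint8_t* d_out, int block_size, int n_blocks)
{
  keccak::KeccakConfig cfg = keccak::default_keccak_config();
  cfg.are_inputs_on_device = true;  // skip the host-to-device copy inside keccak_hash
  cfg.are_outputs_on_device = true; // keep the 32-byte digests on the device
  cfg.is_async = false;             // keccak_hash synchronizes the stream before returning
  return Keccak256(d_in, block_size, n_blocks, d_out, cfg);
}
```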
icicle/appUtils/keccak/test.cu (new file, 67 lines)

@@ -0,0 +1,67 @@
#include "utils/device_context.cuh"
#include "keccak.cu"

// #define DEBUG

#ifndef __CUDA_ARCH__
#include <cassert>
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <iomanip>
#include <sstream>

using namespace keccak;

#define D 256

#define START_TIMER(timer) auto timer##_start = std::chrono::high_resolution_clock::now();
#define END_TIMER(timer, msg) \
  printf("%s: %.0f ms\n", msg, FpMilliseconds(std::chrono::high_resolution_clock::now() - timer##_start).count());

void uint8ToHexString(const uint8_t* values, int size)
{
  std::stringstream ss;

  for (int i = 0; i < size; ++i) {
    ss << std::hex << std::setw(2) << std::setfill('0') << (int)values[i];
  }

  std::string hexString = ss.str();
  std::cout << hexString << std::endl;
}

int main(int argc, char* argv[])
{
  using FpMilliseconds = std::chrono::duration<float, std::chrono::milliseconds::period>;
  using FpMicroseconds = std::chrono::duration<float, std::chrono::microseconds::period>;

  START_TIMER(allocation_timer);
  // Prepare input data of [0, 1, 2 ... (number_of_blocks * input_block_size) - 1]
  int number_of_blocks = argc > 1 ? 1 << atoi(argv[1]) : 1024;
  int input_block_size = argc > 2 ? atoi(argv[2]) : 136;

  uint8_t* in_ptr = static_cast<uint8_t*>(malloc(number_of_blocks * input_block_size));
  for (uint64_t i = 0; i < number_of_blocks * input_block_size; i++) {
    in_ptr[i] = (uint8_t)i;
  }

  END_TIMER(allocation_timer, "Allocate mem and fill input");

  uint8_t* out_ptr = static_cast<uint8_t*>(malloc(number_of_blocks * (D / 8)));

  START_TIMER(keccak_timer);
  KeccakConfig config = default_keccak_config();
  Keccak256(in_ptr, input_block_size, number_of_blocks, out_ptr, config);
  END_TIMER(keccak_timer, "Keccak")

  for (int i = 0; i < number_of_blocks; i++) {
#ifdef DEBUG
    uint8ToHexString(out_ptr + i * (D / 8), D / 8);
#endif
  }

  free(in_ptr);
  free(out_ptr);
}

#endif
@@ -1,4 +1,4 @@
test_msm:
	mkdir -p work
	nvcc -o work/test_msm -std=c++17 -I. -I../.. tests/msm_test.cu
	work/test_msm
	work/test_msm
@@ -25,9 +25,19 @@ namespace msm {
#define MAX_TH 256

// #define SIGNED_DIG //WIP
// #define BIG_TRIANGLE
// #define SSM_SUM //WIP

  template <typename A, typename P>
  __global__ void left_shift_kernel(A* points, const unsigned shift, const unsigned count, A* points_out)
  {
    const unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
    if (tid >= count) return;
    P point = P::from_affine(points[tid]);
    for (unsigned i = 0; i < shift; i++)
      point = P::dbl(point);
    points_out[tid] = P::to_affine(point);
  }

  unsigned get_optimal_c(int bitsize) { return max((unsigned)ceil(log2(bitsize)) - 4, 1U); }

  template <typename E>
@@ -148,47 +158,38 @@ namespace msm {
  __global__ void split_scalars_kernel(
    unsigned* buckets_indices,
    unsigned* point_indices,
    S* scalars,
    const S* scalars,
    unsigned nof_scalars,
    unsigned points_size,
    unsigned msm_size,
    unsigned nof_bms,
    unsigned bm_bitsize,
    unsigned c)
    unsigned c,
    unsigned precomputed_bms_stride)
  {
    // constexpr unsigned sign_mask = 0x80000000;
    // constexpr unsigned trash_bucket = 0x80000000;
    unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (tid >= nof_scalars) return;

    unsigned bucket_index;
    // unsigned bucket_index2;
    unsigned current_index;
    unsigned msm_index = tid / msm_size;
    // unsigned borrow = 0;
    S& scalar = scalars[tid];
    const S& scalar = scalars[tid];
    for (unsigned bm = 0; bm < nof_bms; bm++) {
      const unsigned precomputed_index = bm / precomputed_bms_stride;
      const unsigned target_bm = bm % precomputed_bms_stride;

      bucket_index = scalar.get_scalar_digit(bm, c);
#ifdef SIGNED_DIG
      bucket_index += borrow;
      borrow = 0;
      unsigned sign = 0;
      if (bucket_index > (1 << (c - 1))) {
        bucket_index = (1 << c) - bucket_index;
        borrow = 1;
        sign = sign_mask;
      }
#endif
      current_index = bm * nof_scalars + tid;
#ifdef SIGNED_DIG
      point_indices[current_index] = sign | tid; // the point index is saved for later
#else
      buckets_indices[current_index] =
        (msm_index << (c + bm_bitsize)) | (bm << c) |
        bucket_index; // the bucket module number and the msm number are appended at the msbs
      if (bucket_index == 0) buckets_indices[current_index] = 0; // will be skipped
      point_indices[current_index] = tid % points_size; // the point index is saved for later
#endif

      if (bucket_index != 0) {
        buckets_indices[current_index] =
          (msm_index << (c + bm_bitsize)) | (target_bm << c) |
          bucket_index; // the bucket module number and the msm number are appended at the msbs
      } else {
        buckets_indices[current_index] = 0; // will be skipped
      }
      point_indices[current_index] =
        tid % points_size + points_size * precomputed_index; // the point index is saved for later
    }
  }
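A small stand-alone sketch (not taken from the diff) of the key packing used by split_scalars_kernel above: the MSM index and the target bucket-module index are placed above the c-bit bucket index, so one radix sort over these keys groups all contributions to the same bucket together.

```cpp
#include <cassert>

unsigned pack_bucket_key(unsigned msm_index, unsigned target_bm, unsigned bucket_index,
                         unsigned c, unsigned bm_bitsize)
{
  // Mirrors: (msm_index << (c + bm_bitsize)) | (target_bm << c) | bucket_index
  return (msm_index << (c + bm_bitsize)) | (target_bm << c) | bucket_index;
}

int main()
{
  // Example: c = 16 bits per window, 16 windows per MSM (bm_bitsize = 4),
  // batch element 1, bucket module 3, bucket 0x00AB.
  assert(pack_bucket_key(1, 3, 0x00AB, 16, 4) == 0x001300ABu);
  return 0;
}
```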
@@ -223,19 +224,11 @@ namespace msm {
    const unsigned msm_idx_shift,
    const unsigned c)
  {
    // constexpr unsigned sign_mask = 0x80000000;
    unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (tid >= nof_buckets_to_compute) return;
#ifdef SIGNED_DIG // todo - fix
    const unsigned msm_index = single_bucket_indices[tid] >> msm_idx_shift;
    const unsigned bm_index = (single_bucket_indices[tid] & ((1 << msm_idx_shift) - 1)) >> c;
    const unsigned bucket_index =
      msm_index * nof_buckets + bm_index * ((1 << (c - 1)) + 1) + (single_bucket_indices[tid] & ((1 << c) - 1));
#else
    unsigned msm_index = single_bucket_indices[tid] >> msm_idx_shift;
    const unsigned single_bucket_index = (single_bucket_indices[tid] & ((1 << msm_idx_shift) - 1));
    unsigned bucket_index = msm_index * nof_buckets + single_bucket_index;
#endif
    const unsigned bucket_offset = bucket_offsets[tid];
    const unsigned bucket_size = bucket_sizes[tid];

@@ -243,14 +236,7 @@ namespace msm {
    for (unsigned i = 0; i < bucket_size;
         i++) { // add the relevant points starting from the relevant offset up to the bucket size
      unsigned point_ind = point_indices[bucket_offset + i];
#ifdef SIGNED_DIG
      unsigned sign = point_ind & sign_mask;
      point_ind &= ~sign_mask;
      A point = points[point_ind];
      if (sign) point = A::neg(point);
#else
      A point = points[point_ind];
#endif
      bucket =
        i ? (point == A::zero() ? bucket : bucket + point) : (point == A::zero() ? P::zero() : P::from_affine(point));
    }

@@ -317,11 +303,7 @@ namespace msm {
  {
    unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (tid >= nof_bms) return;
#ifdef SIGNED_DIG
    unsigned buckets_in_bm = (1 << c) + 1;
#else
    unsigned buckets_in_bm = (1 << c);
#endif
    P line_sum = buckets[(tid + 1) * buckets_in_bm - 1];
    final_sums[tid] = line_sum;
    for (unsigned i = buckets_in_bm - 2; i > 0; i--) {

@@ -378,8 +360,8 @@ namespace msm {
  cudaError_t bucket_method_msm(
    unsigned bitsize,
    unsigned c,
    S* scalars,
    A* points,
    const S* scalars,
    const A* points,
    unsigned batch_size,      // number of MSMs to compute
    unsigned single_msm_size, // number of elements per MSM (a.k.a N)
    unsigned nof_points, // number of EC points in 'points' array. Must be either (1) single_msm_size if MSMs are

@@ -392,6 +374,7 @@ namespace msm {
    bool are_results_on_device,
    bool is_big_triangle,
    int large_bucket_factor,
    int precompute_factor,
    bool is_async,
    cudaStream_t stream)
  {

@@ -403,44 +386,59 @@ namespace msm {
      THROW_ICICLE_ERR(
        IcicleError_t::InvalidArgument, "bucket_method_msm: #points must be divisible by single_msm_size*batch_size");
    }
    if ((precompute_factor & (precompute_factor - 1)) != 0) {
      THROW_ICICLE_ERR(
        IcicleError_t::InvalidArgument,
        "bucket_method_msm: precompute factors that are not powers of 2 currently unsupported");
    }

    S* d_scalars;
    const S* d_scalars;
    S* d_allocated_scalars = nullptr;
    if (!are_scalars_on_device) {
      // copy scalars to gpu
      CHK_IF_RETURN(cudaMallocAsync(&d_scalars, sizeof(S) * nof_scalars, stream));
      CHK_IF_RETURN(cudaMemcpyAsync(d_scalars, scalars, sizeof(S) * nof_scalars, cudaMemcpyHostToDevice, stream));
    } else {
      d_scalars = scalars;
      CHK_IF_RETURN(cudaMallocAsync(&d_allocated_scalars, sizeof(S) * nof_scalars, stream));
      CHK_IF_RETURN(
        cudaMemcpyAsync(d_allocated_scalars, scalars, sizeof(S) * nof_scalars, cudaMemcpyHostToDevice, stream));

      if (are_scalars_montgomery_form) {
        CHK_IF_RETURN(mont::FromMontgomery(d_allocated_scalars, nof_scalars, stream, d_allocated_scalars));
      }
      d_scalars = d_allocated_scalars;
    } else { // already on device
      if (are_scalars_montgomery_form) {
        CHK_IF_RETURN(cudaMallocAsync(&d_allocated_scalars, sizeof(S) * nof_scalars, stream));
        CHK_IF_RETURN(mont::FromMontgomery(scalars, nof_scalars, stream, d_allocated_scalars));
        d_scalars = d_allocated_scalars;
      } else {
        d_scalars = scalars;
      }
    }

    if (are_scalars_montgomery_form) {
      if (are_scalars_on_device) {
        S* d_mont_scalars;
        CHK_IF_RETURN(cudaMallocAsync(&d_mont_scalars, sizeof(S) * nof_scalars, stream));
        CHK_IF_RETURN(mont::FromMontgomery(d_scalars, nof_scalars, stream, d_mont_scalars));
        d_scalars = d_mont_scalars;
      } else
        CHK_IF_RETURN(mont::FromMontgomery(d_scalars, nof_scalars, stream, d_scalars));
    }
    unsigned total_bms_per_msm = (bitsize + c - 1) / c;
    unsigned nof_bms_per_msm = (total_bms_per_msm - 1) / precompute_factor + 1;
    unsigned input_indexes_count = nof_scalars * total_bms_per_msm;

    unsigned bm_bitsize = (unsigned)ceil(log2(nof_bms_per_msm));

    unsigned nof_bms_per_msm = (bitsize + c - 1) / c;
    unsigned* bucket_indices;
    unsigned* point_indices;
    unsigned* sorted_bucket_indices;
    unsigned* sorted_point_indices;
    CHK_IF_RETURN(cudaMallocAsync(&bucket_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
    CHK_IF_RETURN(cudaMallocAsync(&point_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
    CHK_IF_RETURN(cudaMallocAsync(&sorted_bucket_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
    CHK_IF_RETURN(cudaMallocAsync(&sorted_point_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
    CHK_IF_RETURN(cudaMallocAsync(&bucket_indices, sizeof(unsigned) * input_indexes_count, stream));
    CHK_IF_RETURN(cudaMallocAsync(&point_indices, sizeof(unsigned) * input_indexes_count, stream));
    CHK_IF_RETURN(cudaMallocAsync(&sorted_bucket_indices, sizeof(unsigned) * input_indexes_count, stream));
    CHK_IF_RETURN(cudaMallocAsync(&sorted_point_indices, sizeof(unsigned) * input_indexes_count, stream));

    unsigned bm_bitsize = (unsigned)ceil(log2(nof_bms_per_msm));
    // split scalars into digits
    unsigned NUM_THREADS = 1 << 10;
    unsigned NUM_BLOCKS = (nof_scalars + NUM_THREADS - 1) / NUM_THREADS;
    split_scalars_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
      bucket_indices, point_indices, d_scalars, nof_scalars, nof_points, single_msm_size, nof_bms_per_msm, bm_bitsize,
      c);

    split_scalars_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
      bucket_indices, point_indices, d_scalars, nof_scalars, nof_points, single_msm_size, total_bms_per_msm,
      bm_bitsize, c, nof_bms_per_msm);
    nof_points *= precompute_factor;

    // ------------------------------ Sorting routines for scalars start here ----------------------------------
    // sort indices - the indices are sorted from smallest to largest in order to group together the points that
    // belong to each bucket
    unsigned* sort_indices_temp_storage{};
@@ -450,26 +448,22 @@ namespace msm {
    // more info
    CHK_IF_RETURN(cub::DeviceRadixSort::SortPairs(
      sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices, sorted_bucket_indices,
      point_indices, sorted_point_indices, nof_scalars * nof_bms_per_msm, 0, sizeof(unsigned) * 8, stream));
      point_indices, sorted_point_indices, input_indexes_count, 0, sizeof(unsigned) * 8, stream));
    CHK_IF_RETURN(cudaMallocAsync(&sort_indices_temp_storage, sort_indices_temp_storage_bytes, stream));
    // The second to last parameter is the default value supplied explicitly to allow passing the stream
    // See https://nvlabs.github.io/cub/structcub_1_1_device_radix_sort.html#a65e82152de448c6373ed9563aaf8af7e for
    // more info
    CHK_IF_RETURN(cub::DeviceRadixSort::SortPairs(
      sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices, sorted_bucket_indices,
      point_indices, sorted_point_indices, nof_scalars * nof_bms_per_msm, 0, sizeof(unsigned) * 8, stream));
      point_indices, sorted_point_indices, input_indexes_count, 0, sizeof(unsigned) * 8, stream));
    CHK_IF_RETURN(cudaFreeAsync(sort_indices_temp_storage, stream));
    CHK_IF_RETURN(cudaFreeAsync(bucket_indices, stream));
    CHK_IF_RETURN(cudaFreeAsync(point_indices, stream));

    // compute number of bucket modules and number of buckets in each module
    unsigned nof_bms_in_batch = nof_bms_per_msm * batch_size;
#ifdef SIGNED_DIG
    const unsigned nof_buckets = nof_bms_per_msm * ((1 << (c - 1)) + 1); // signed digits
#else
    // minus nof_bms_per_msm because zero bucket is not included in each bucket module
    const unsigned nof_buckets = (nof_bms_per_msm << c) - nof_bms_per_msm;
#endif
    const unsigned total_nof_buckets = nof_buckets * batch_size;

    // find bucket_sizes

@@ -484,11 +478,11 @@ namespace msm {
    size_t encode_temp_storage_bytes = 0;
    CHK_IF_RETURN(cub::DeviceRunLengthEncode::Encode(
      encode_temp_storage, encode_temp_storage_bytes, sorted_bucket_indices, single_bucket_indices, bucket_sizes,
      nof_buckets_to_compute, nof_bms_per_msm * nof_scalars, stream));
      nof_buckets_to_compute, input_indexes_count, stream));
    CHK_IF_RETURN(cudaMallocAsync(&encode_temp_storage, encode_temp_storage_bytes, stream));
    CHK_IF_RETURN(cub::DeviceRunLengthEncode::Encode(
      encode_temp_storage, encode_temp_storage_bytes, sorted_bucket_indices, single_bucket_indices, bucket_sizes,
      nof_buckets_to_compute, nof_bms_per_msm * nof_scalars, stream));
      nof_buckets_to_compute, input_indexes_count, stream));
    CHK_IF_RETURN(cudaFreeAsync(encode_temp_storage, stream));
    CHK_IF_RETURN(cudaFreeAsync(sorted_bucket_indices, stream));

@@ -504,28 +498,33 @@ namespace msm {
      offsets_temp_storage, offsets_temp_storage_bytes, bucket_sizes, bucket_offsets, total_nof_buckets + 1, stream));
    CHK_IF_RETURN(cudaFreeAsync(offsets_temp_storage, stream));

    A* d_points;
    cudaStream_t stream_points;
    // ----------- Starting to upload points (if they were on host) in parallel to scalar sorting ----------------
    const A* d_points;
    A* d_allocated_points = nullptr;
    cudaStream_t stream_points = nullptr;
    if (!are_points_on_device || are_points_montgomery_form) CHK_IF_RETURN(cudaStreamCreate(&stream_points));
    if (!are_points_on_device) {
      // copy points to gpu
      CHK_IF_RETURN(cudaMallocAsync(&d_points, sizeof(A) * nof_points, stream_points));
      CHK_IF_RETURN(cudaMemcpyAsync(d_points, points, sizeof(A) * nof_points, cudaMemcpyHostToDevice, stream_points));
    } else {
      d_points = points;
      CHK_IF_RETURN(cudaMallocAsync(&d_allocated_points, sizeof(A) * nof_points, stream_points));
      CHK_IF_RETURN(
        cudaMemcpyAsync(d_allocated_points, points, sizeof(A) * nof_points, cudaMemcpyHostToDevice, stream_points));

      if (are_points_montgomery_form) {
        CHK_IF_RETURN(mont::FromMontgomery(d_allocated_points, nof_points, stream_points, d_allocated_points));
      }
      d_points = d_allocated_points;
    } else { // already on device
      if (are_points_montgomery_form) {
        CHK_IF_RETURN(cudaMallocAsync(&d_allocated_points, sizeof(A) * nof_points, stream_points));
        CHK_IF_RETURN(mont::FromMontgomery(points, nof_points, stream_points, d_allocated_points));
        d_points = d_allocated_points;
      } else {
        d_points = points;
      }
    }

    if (are_points_montgomery_form) {
      if (are_points_on_device) {
        A* d_mont_points;
        CHK_IF_RETURN(cudaMallocAsync(&d_mont_points, sizeof(A) * nof_points, stream_points));
        CHK_IF_RETURN(mont::FromMontgomery(d_points, nof_points, stream_points, d_mont_points));
        d_points = d_mont_points;
      } else
        CHK_IF_RETURN(mont::FromMontgomery(d_points, nof_points, stream_points, d_points));
    }
    cudaEvent_t event_points_uploaded;
    if (!are_points_on_device || are_points_montgomery_form) {
    if (stream_points) {
      CHK_IF_RETURN(cudaEventCreateWithFlags(&event_points_uploaded, cudaEventDisableTiming));
      CHK_IF_RETURN(cudaEventRecord(event_points_uploaded, stream_points));
    }
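The block above starts the point upload (and the optional Montgomery conversion) on a dedicated stream and records an event once it finishes. A minimal stand-alone sketch of that overlap pattern, with hypothetical buffer names and no error checking:

```cpp
#include <cuda_runtime.h>

// Upload on a side stream while the main stream keeps working, then make the main
// stream wait only at the point where the uploaded data is first needed.
void overlap_upload(cudaStream_t main_stream, cudaStream_t side_stream,
                    void* d_points, const void* h_points, size_t bytes)
{
  cudaEvent_t points_ready;
  cudaEventCreateWithFlags(&points_ready, cudaEventDisableTiming);
  cudaMemcpyAsync(d_points, h_points, bytes, cudaMemcpyHostToDevice, side_stream);
  cudaEventRecord(points_ready, side_stream);        // marks "points uploaded"
  // ... index sorting and bucket bookkeeping continue on main_stream here ...
  cudaStreamWaitEvent(main_stream, points_ready, 0); // main stream blocks only now
  cudaEventDestroy(points_ready);
}
```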
@@ -609,7 +608,7 @@ namespace msm {
      cudaMemcpyAsync(&h_nof_large_buckets, nof_large_buckets, sizeof(unsigned), cudaMemcpyDeviceToHost, stream));
    CHK_IF_RETURN(cudaFreeAsync(nof_large_buckets, stream));

    if (!are_points_on_device || are_points_montgomery_form) {
    if (stream_points) {
      // by this point, points need to be already uploaded and un-Montgomeried
      CHK_IF_RETURN(cudaStreamWaitEvent(stream, event_points_uploaded));
      CHK_IF_RETURN(cudaEventDestroy(event_points_uploaded));

@@ -618,7 +617,7 @@ namespace msm {

    cudaStream_t stream_large_buckets;
    cudaEvent_t event_large_buckets_accumulated;
    // this is where handling of large buckets happens (if there are any)
    // ---------------- This is where handling of large buckets happens (if there are any) -------------
    if (h_nof_large_buckets > 0 && bucket_th > 0) {
      CHK_IF_RETURN(cudaStreamCreate(&stream_large_buckets));
      CHK_IF_RETURN(cudaEventCreateWithFlags(&event_large_buckets_accumulated, cudaEventDisableTiming));

@@ -700,10 +699,11 @@ namespace msm {
      CHK_IF_RETURN(cudaEventRecord(event_large_buckets_accumulated, stream_large_buckets));
    }

    // launch the accumulation kernel with maximum threads
    // ------------------------- Accumulation of (non-large) buckets ---------------------------------
    if (h_nof_buckets_to_compute > h_nof_large_buckets) {
      NUM_THREADS = 1 << 8;
      NUM_BLOCKS = (h_nof_buckets_to_compute - h_nof_large_buckets + NUM_THREADS - 1) / NUM_THREADS;
      // launch the accumulation kernel with maximum threads
      accumulate_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
        buckets, sorted_bucket_offsets + h_nof_large_buckets, sorted_bucket_sizes + h_nof_large_buckets,
        sorted_single_bucket_indices + h_nof_large_buckets, sorted_point_indices, d_points,

@@ -719,24 +719,11 @@ namespace msm {
      CHK_IF_RETURN(cudaStreamDestroy(stream_large_buckets));
    }

#ifdef SSM_SUM
    // sum each bucket
    NUM_THREADS = 1 << 10;
    NUM_BLOCKS = (nof_buckets + NUM_THREADS - 1) / NUM_THREADS;
    ssm_buckets_kernel<fake_point, fake_scalar>
      <<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, single_bucket_indices, nof_buckets, c);

    // sum each bucket module
    P* final_results;
    CHK_IF_RETURN(cudaMallocAsync(&final_results, sizeof(P) * nof_bms_per_msm, stream));
    NUM_THREADS = 1 << c;
    NUM_BLOCKS = nof_bms_per_msm;
    sum_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results);
#endif

    P* d_final_result;
    if (!are_results_on_device) CHK_IF_RETURN(cudaMallocAsync(&d_final_result, sizeof(P) * batch_size, stream));
    P* d_allocated_final_result = nullptr;
    if (!are_results_on_device)
      CHK_IF_RETURN(cudaMallocAsync(&d_allocated_final_result, sizeof(P) * batch_size, stream));

    // --- Reduction of buckets happens here, after this we'll get a single sum for each bucket module/window ---
    unsigned nof_empty_bms_per_batch = 0; // for non-triangle accumulation this may be >0
    P* final_results;
    if (is_big_triangle || c == 1) {

@@ -744,15 +731,9 @@ namespace msm {
      // launch the bucket module sum kernel - a thread for each bucket module
      NUM_THREADS = 32;
      NUM_BLOCKS = (nof_bms_in_batch + NUM_THREADS - 1) / NUM_THREADS;
#ifdef SIGNED_DIG
      big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
        buckets, final_results, nof_bms_in_batch, c - 1); // signed digits
#else
      big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results, nof_bms_in_batch, c);
#endif
    } else {
      unsigned source_bits_count = c;
      // bool odd_source_c = source_bits_count % 2;
      unsigned source_windows_count = nof_bms_per_msm;
      unsigned source_buckets_count = nof_buckets + nof_bms_per_msm;
      unsigned target_windows_count = 0;

@@ -792,12 +773,11 @@ namespace msm {
        }
      }
      if (target_bits_count == 1) {
        // Note: the reduction ends up with 'target_windows_count' windows per batch element. Some are guaranteed to
        // be empty when target_windows_count>bitsize.
        // for example consider bitsize=253 and c=2. The reduction ends with 254 bms but the most significant one is
        // guaranteed to be zero since the scalars are 253b.
        // Note: the reduction ends up with 'target_windows_count' windows per batch element. Some are guaranteed
        // to be empty when target_windows_count>bitsize. for example consider bitsize=253 and c=2. The reduction
        // ends with 254 bms but the most significant one is guaranteed to be zero since the scalars are 253b.
        nof_bms_per_msm = target_windows_count;
        nof_empty_bms_per_batch = target_windows_count - bitsize;
        nof_empty_bms_per_batch = target_windows_count > bitsize ? target_windows_count - bitsize : 0;
        nof_bms_in_batch = nof_bms_per_msm * batch_size;

        CHK_IF_RETURN(cudaMallocAsync(&final_results, sizeof(P) * nof_bms_in_batch, stream));

@@ -819,28 +799,29 @@ namespace msm {
        temp_buckets1 = nullptr;
        temp_buckets2 = nullptr;
        source_bits_count = target_bits_count;
        // odd_source_c = source_bits_count % 2;
        source_windows_count = target_windows_count;
        source_buckets_count = target_buckets_count;
      }
    }

    // launch the double and add kernel, a single thread per batch element
    // ------- This is the final stage where bucket modules/window sums get added up with appropriate weights
    // -------
    NUM_THREADS = 32;
    NUM_BLOCKS = (batch_size + NUM_THREADS - 1) / NUM_THREADS;
    // launch the double and add kernel, a single thread per batch element
    final_accumulation_kernel<P, S><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
      final_results, are_results_on_device ? final_result : d_final_result, batch_size, nof_bms_per_msm,
      final_results, are_results_on_device ? final_result : d_allocated_final_result, batch_size, nof_bms_per_msm,
      nof_empty_bms_per_batch, c);
    CHK_IF_RETURN(cudaFreeAsync(final_results, stream));

    if (!are_results_on_device)
      CHK_IF_RETURN(
        cudaMemcpyAsync(final_result, d_final_result, sizeof(P) * batch_size, cudaMemcpyDeviceToHost, stream));
      CHK_IF_RETURN(cudaMemcpyAsync(
        final_result, d_allocated_final_result, sizeof(P) * batch_size, cudaMemcpyDeviceToHost, stream));

    // free memory
    if (!are_scalars_on_device || are_scalars_montgomery_form) CHK_IF_RETURN(cudaFreeAsync(d_scalars, stream));
    if (!are_points_on_device || are_points_montgomery_form) CHK_IF_RETURN(cudaFreeAsync(d_points, stream));
    if (!are_results_on_device) CHK_IF_RETURN(cudaFreeAsync(d_final_result, stream));
    if (d_allocated_scalars) CHK_IF_RETURN(cudaFreeAsync(d_allocated_scalars, stream));
    if (d_allocated_points) CHK_IF_RETURN(cudaFreeAsync(d_allocated_points, stream));
    if (d_allocated_final_result) CHK_IF_RETURN(cudaFreeAsync(d_allocated_final_result, stream));
    CHK_IF_RETURN(cudaFreeAsync(buckets, stream));

    if (!is_async) CHK_IF_RETURN(cudaStreamSynchronize(stream));
@@ -873,7 +854,7 @@ namespace msm {
  }

  template <typename S, typename A, typename P>
  cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig& config, P* results)
  cudaError_t MSM(const S* scalars, const A* points, int msm_size, MSMConfig& config, P* results)
  {
    const int bitsize = (config.bitsize == 0) ? S::NBITS : config.bitsize;
    cudaStream_t& stream = config.ctx.stream;

@@ -890,7 +871,59 @@ namespace msm {
      bitsize, c, scalars, points, config.batch_size, msm_size,
      (config.points_size == 0) ? msm_size : config.points_size, results, config.are_scalars_on_device,
      config.are_scalars_montgomery_form, config.are_points_on_device, config.are_points_montgomery_form,
      config.are_results_on_device, config.is_big_triangle, config.large_bucket_factor, config.is_async, stream));
      config.are_results_on_device, config.is_big_triangle, config.large_bucket_factor, config.precompute_factor,
      config.is_async, stream));
  }

  template <typename A, typename P>
  cudaError_t PrecomputeMSMBases(
    A* bases,
    int bases_size,
    int precompute_factor,
    int _c,
    bool are_bases_on_device,
    device_context::DeviceContext& ctx,
    A* output_bases)
  {
    CHK_INIT_IF_RETURN();

    cudaStream_t& stream = ctx.stream;

    CHK_IF_RETURN(cudaMemcpyAsync(
      output_bases, bases, sizeof(A) * bases_size,
      are_bases_on_device ? cudaMemcpyDeviceToDevice : cudaMemcpyHostToDevice, stream));

    unsigned c = 16;
    unsigned total_nof_bms = (P::SCALAR_FF_NBITS - 1) / c + 1;
    unsigned shift = c * ((total_nof_bms - 1) / precompute_factor + 1);

    unsigned NUM_THREADS = 1 << 8;
    unsigned NUM_BLOCKS = (bases_size + NUM_THREADS - 1) / NUM_THREADS;
    for (int i = 1; i < precompute_factor; i++) {
      left_shift_kernel<A, P><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
        &output_bases[(i - 1) * bases_size], shift, bases_size, &output_bases[i * bases_size]);
    }

    return CHK_LAST();
  }

  /**
   * Extern "C" version of [PrecomputeMSMBases](@ref PrecomputeMSMBases) function with the following values of
   * template parameters (where the curve is given by `-DCURVE` env variable during build):
   * - `A` is the [affine representation](@ref affine_t) of curve points;
   * @return `cudaSuccess` if the execution was successful and an error code otherwise.
   */
  extern "C" cudaError_t CONCAT_EXPAND(CURVE, PrecomputeMSMBases)(
    curve_config::affine_t* bases,
    int bases_size,
    int precompute_factor,
    int _c,
    bool are_bases_on_device,
    device_context::DeviceContext& ctx,
    curve_config::affine_t* output_bases)
  {
    return PrecomputeMSMBases<curve_config::affine_t, curve_config::projective_t>(
      bases, bases_size, precompute_factor, _c, are_bases_on_device, ctx, output_bases);
  }

  /**

@@ -902,8 +935,8 @@ namespace msm {
   * @return `cudaSuccess` if the execution was successful and an error code otherwise.
   */
  extern "C" cudaError_t CONCAT_EXPAND(CURVE, MSMCuda)(
    curve_config::scalar_t* scalars,
    curve_config::affine_t* points,
    const curve_config::scalar_t* scalars,
    const curve_config::affine_t* points,
    int msm_size,
    MSMConfig& config,
    curve_config::projective_t* out)

@@ -912,13 +945,27 @@ namespace msm {
      scalars, points, msm_size, config, out);
  }

  /**
   * Extern "C" version of [DefaultMSMConfig](@ref DefaultMSMConfig) function.
   */
  extern "C" MSMConfig CONCAT_EXPAND(CURVE, DefaultMSMConfig)() { return DefaultMSMConfig<curve_config::affine_t>(); }

#if defined(G2_DEFINED)

  /**
   * Extern "C" version of [PrecomputeMSMBases](@ref PrecomputeMSMBases) function with the following values of
   * template parameters (where the curve is given by `-DCURVE` env variable during build):
   * - `A` is the [affine representation](@ref g2_affine_t) of G2 curve points;
   * @return `cudaSuccess` if the execution was successful and an error code otherwise.
   */
  extern "C" cudaError_t CONCAT_EXPAND(CURVE, G2PrecomputeMSMBases)(
    curve_config::g2_affine_t* bases,
    int bases_size,
    int precompute_factor,
    int _c,
    bool are_bases_on_device,
    device_context::DeviceContext& ctx,
    curve_config::g2_affine_t* output_bases)
  {
    return PrecomputeMSMBases<curve_config::g2_affine_t, curve_config::g2_projective_t>(
      bases, bases_size, precompute_factor, _c, are_bases_on_device, ctx, output_bases);
  }

  /**
   * Extern "C" version of [MSM](@ref MSM) function with the following values of template parameters
   * (where the curve is given by `-DCURVE` env variable during build):

@@ -928,8 +975,8 @@ namespace msm {
   * @return `cudaSuccess` if the execution was successful and an error code otherwise.
   */
  extern "C" cudaError_t CONCAT_EXPAND(CURVE, G2MSMCuda)(
    curve_config::scalar_t* scalars,
    curve_config::g2_affine_t* points,
    const curve_config::scalar_t* scalars,
    const curve_config::g2_affine_t* points,
    int msm_size,
    MSMConfig& config,
    curve_config::g2_projective_t* out)

@@ -938,15 +985,6 @@ namespace msm {
      scalars, points, msm_size, config, out);
  }

  /**
   * Extern "C" version of [DefaultMSMConfig](@ref DefaultMSMConfig) function for the G2 curve
   * (functionally no different than the default MSM config function for G1).
   */
  extern "C" MSMConfig CONCAT_EXPAND(CURVE, G2DefaultMSMConfig)()
  {
    return DefaultMSMConfig<curve_config::g2_affine_t>();
  }

#endif

} // namespace msm
@@ -43,14 +43,18 @@ namespace msm {
   * variable is set equal to the MSM size. And if every MSM uses a distinct set of
   * points, it should be set to the product of MSM size and [batch_size](@ref
   * batch_size). Default value: 0 (meaning it's equal to the MSM size). */
  int precompute_factor; /**< The number of extra points to pre-compute for each point. Larger values decrease the
  int precompute_factor; /**< The number of extra points to pre-compute for each point. See the
                          * [PrecomputeMSMBases](@ref PrecomputeMSMBases) function, `precompute_factor` passed
                          * there needs to be equal to the one used here. Larger values decrease the
                          * number of computations to make, on-line memory footprint, but increase the static
                          * memory footprint. Default value: 1 (i.e. don't pre-compute). */
  int c; /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket
          * method" that we use to solve the MSM problem. As a rule of thumb, larger value
          * means more on-line memory footprint but also more parallelism and less computational
          * complexity (up to a certain point). Default value: 0 (the optimal value of \f$ c \f$
          * is chosen automatically). */
          * complexity (up to a certain point). Currently pre-computation is independent of
          * \f$ c \f$, however in the future value of \f$ c \f$ here and the one passed into the
          * [PrecomputeMSMBases](@ref PrecomputeMSMBases) function will need to be identical.
          * Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
  int bitsize; /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field,
                * but if a different (better) upper bound is known, it should be reflected in this
                * variable. Default value: 0 (set to the bitsize of scalar field). */

@@ -101,12 +105,39 @@ namespace msm {
   * Weierstrass](https://hyperelliptic.org/EFD/g1p/auto-shortw-projective.html) point in our codebase.
   * @return `cudaSuccess` if the execution was successful and an error code otherwise.
   *
   * **Note:** this function is still WIP and the following [MSMConfig](@ref MSMConfig) members do not yet have any
   * effect: `precompute_factor` (always equals 1) and `ctx.device_id` (0 device is always used).
   * Also, it's currently better to use `batch_size=1` in most cases (except when dealing with very many MSMs).
   */
  template <typename S, typename A, typename P>
  cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig& config, P* results);
  cudaError_t MSM(const S* scalars, const A* points, int msm_size, MSMConfig& config, P* results);

  /**
   * A function that precomputes MSM bases by extending them with their shifted copies.
   * e.g.:
   * Original points: \f$ P_0, P_1, P_2, ... P_{size} \f$
   * Extended points: \f$ P_0, P_1, P_2, ... P_{size}, 2^{l}P_0, 2^{l}P_1, ..., 2^{l}P_{size},
   * 2^{2l}P_0, 2^{2l}P_1, ..., 2^{2l}P_{size}, ... \f$
   * @param bases Bases \f$ P_i \f$. In case of batch MSM, all *unique* points are concatenated.
   * @param bases_size Number of bases.
   * @param precompute_factor The number of total precomputed points for each base (including the base itself).
   * @param _c This is currently unused, but in the future precomputation will need to be aware of
   * the `c` value used in MSM (see [MSMConfig](@ref MSMConfig)). So to avoid breaking your code with this
   * upcoming change, make sure to use the same value of `c` in this function and in respective MSMConfig.
   * @param are_bases_on_device Whether the bases are on device.
   * @param ctx Device context specifying device id and stream to use.
   * @param output_bases Device-allocated buffer of size bases_size * precompute_factor for the extended bases.
   * @tparam A The type of points \f$ \{P_i\} \f$ which is typically an [affine
   * Weierstrass](https://hyperelliptic.org/EFD/g1p/auto-shortw.html) point.
   * @return `cudaSuccess` if the execution was successful and an error code otherwise.
   *
   */
  template <typename A, typename P>
  cudaError_t PrecomputeMSMBases(
    A* bases,
    int bases_size,
    int precompute_factor,
    int _c,
    bool are_bases_on_device,
    device_context::DeviceContext& ctx,
    A* output_bases);

} // namespace msm
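A hedged end-to-end sketch (not part of the changeset) of how the two declarations above are meant to be combined: extend the bases once with PrecomputeMSMBases, then run MSM over the extended buffer with the same precompute_factor in the config. The factor value, the include path, and the plain cudaMalloc call are illustrative choices; the function and config field names follow the templates referenced in this changeset (including msm::DefaultMSMConfig, used by the extern "C" wrappers in msm.cu).

```cpp
#include <cuda_runtime.h>
#include "appUtils/msm/msm.cuh" // assumed include path for the declarations above

template <typename S, typename A, typename P>
cudaError_t msm_with_precompute(const S* scalars, A* bases, int msm_size, P* result,
                                device_context::DeviceContext& ctx)
{
  const int factor = 4; // example power-of-two precompute factor

  // PrecomputeMSMBases expects a device-allocated output of size msm_size * factor.
  A* d_extended = nullptr;
  cudaError_t err = cudaMalloc(&d_extended, sizeof(A) * msm_size * factor);
  if (err != cudaSuccess) return err;

  err = msm::PrecomputeMSMBases<A, P>(
    bases, msm_size, factor, /*_c=*/0, /*are_bases_on_device=*/false, ctx, d_extended);
  if (err != cudaSuccess) return err;

  msm::MSMConfig config = msm::DefaultMSMConfig<A>();
  config.precompute_factor = factor;  // must match the factor used above
  config.are_points_on_device = true; // the extended bases already live on the GPU
  // In real code config.ctx should refer to the same device and stream as `ctx`.
  err = msm::MSM<S, A, P>(scalars, d_extended, msm_size, config, result);

  cudaFree(d_extended);
  return err;
}
```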
@@ -30,7 +30,7 @@ public:
    return os;
  }

  HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width)
  HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width) const
  {
    return (x >> (digit_num * digit_width)) & ((1 << digit_width) - 1);
  }
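A tiny worked example (not in the diff) of what get_scalar_digit above computes: it slices digit_width-bit windows out of the limb, which is exactly how split_scalars_kernel picks bucket indices.

```cpp
#include <cassert>

// Same expression as get_scalar_digit above, over a plain 32-bit limb.
unsigned scalar_digit(unsigned x, unsigned digit_num, unsigned digit_width)
{
  return (x >> (digit_num * digit_width)) & ((1u << digit_width) - 1);
}

int main()
{
  assert(scalar_digit(0xABCD, 0, 4) == 0xD);
  assert(scalar_digit(0xABCD, 1, 4) == 0xC);
  assert(scalar_digit(0xABCD, 2, 4) == 0xB);
  assert(scalar_digit(0xABCD, 3, 4) == 0xA);
  return 0;
}
```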
@@ -56,7 +56,15 @@ namespace ntt {
  // Note: the following reorder kernels are fused with normalization for INTT
  template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
  static __global__ void reorder_digits_inplace_and_normalize_kernel(
    E* arr, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type, bool is_normalize, S inverse_N)
    E* arr,
    uint32_t log_size,
    bool columns_batch,
    uint32_t batch_size,
    bool dit,
    bool fast_tw,
    eRevType rev_type,
    bool is_normalize,
    S inverse_N)
  {
    // launch N threads (per batch element)
    // each thread starts from one index and calculates the corresponding group

@@ -65,19 +73,20 @@ namespace ntt {

    const uint32_t size = 1 << log_size;
    const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
    const uint32_t idx = tid % size;
    const uint32_t batch_idx = tid / size;
    const uint32_t idx = columns_batch ? tid / batch_size : tid % size;
    const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size;
    if (tid >= size * batch_size) return;

    uint32_t next_element = idx;
    uint32_t group[MAX_GROUP_SIZE];
    group[0] = next_element + size * batch_idx;
    group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;

    uint32_t i = 1;
    for (; i < MAX_GROUP_SIZE;) {
      next_element = generalized_rev(next_element, log_size, dit, fast_tw, rev_type);
      if (next_element < idx) return;  // not handling this group
      if (next_element == idx) break;  // calculated whole group
      group[i++] = next_element + size * batch_idx;
      group[i++] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;
    }

    --i;
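A stand-alone sketch (not in the diff) of the two batch layouts the reordering kernel above now distinguishes: in the default layout element i of batch b is stored at b * size + i, while in the columns_batch layout batches are interleaved and the same element sits at i * batch_size + b.

```cpp
#include <cstdint>

// Default ("rows") layout: matches next_element + size * batch_idx in the kernel above.
inline uint32_t rows_layout_index(uint32_t i, uint32_t b, uint32_t size) { return b * size + i; }

// columns_batch layout: matches next_element * batch_size + batch_idx in the kernel above.
inline uint32_t columns_layout_index(uint32_t i, uint32_t b, uint32_t batch_size) { return i * batch_size + b; }
```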
@@ -91,9 +100,12 @@ namespace ntt {

  template <typename E, typename S>
  __launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
    E* arr,
    const E* arr,
    E* arr_reordered,
    uint32_t log_size,
    bool columns_batch,
    uint32_t batch_size,
    uint32_t columns_batch_size,
    bool dit,
    bool fast_tw,
    eRevType rev_type,

@@ -101,41 +113,46 @@ namespace ntt {
    S inverse_N)
  {
    uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= (1 << log_size) * batch_size) return;
    uint32_t rd = tid;
    uint32_t wr =
      ((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
    arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
    uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
                  generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
    arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd];
  }

  template <typename E, typename S>
  static __global__ void batch_elementwise_mul_with_reorder(
    E* in_vec,
    int n_elements,
    int batch_size,
  static __global__ void batch_elementwise_mul_with_reorder_kernel(
    const E* in_vec,
    uint32_t size,
    bool columns_batch,
    uint32_t batch_size,
    uint32_t columns_batch_size,
    S* scalar_vec,
    int step,
    int n_scalars,
    int logn,
    uint32_t log_size,
    eRevType rev_type,
    bool dit,
    E* out_vec)
  {
    int tid = blockDim.x * blockIdx.x + threadIdx.x;
    if (tid >= n_elements * batch_size) return;
    int64_t scalar_id = tid % n_elements;
    if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, false, rev_type);
    if (tid >= size * batch_size) return;
    int64_t scalar_id = (tid / columns_batch_size) % size;
    if (rev_type != eRevType::None)
      scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type);
    out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
  }

  template <typename E, typename S>
  __launch_bounds__(64) __global__ void ntt64(
    E* in,
    const E* in,
    E* out,
    S* external_twiddles,
    S* internal_twiddles,
    S* basic_twiddles,
    uint32_t log_size,
    uint32_t tw_log_size,
    uint32_t columns_batch_size,
    uint32_t nof_ntt_blocks,
    uint32_t data_stride,
    uint32_t log_data_stride,

@@ -153,19 +170,27 @@ namespace ntt {

    s_meta.th_stride = 8;
    s_meta.ntt_block_size = 64;
    s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
    s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 7) / 8)
                                             : (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
    s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);

    if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
    s_meta.batch_id =
      columns_batch_size ? (threadIdx.x & 0x7) + ((blockIdx.x % ((columns_batch_size + 7) / 8)) << 3) : 0;
    if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
      return;

    if (fast_tw)
      engine.loadBasicTwiddles(basic_twiddles);
    else
      engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
    engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
    if (columns_batch_size)
      engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
    else
      engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);

    if (twiddle_stride && dit) {
      if (fast_tw)
        engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
        engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
      else
        engine.loadExternalTwiddlesGeneric64(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
@@ -189,24 +214,28 @@ namespace ntt {
|
||||
|
||||
if (twiddle_stride && !dit) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric64(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt32(
|
||||
E* in,
|
||||
const E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -225,16 +254,25 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 4;
|
||||
s_meta.ntt_block_size = 32;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
|
||||
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadInternalTwiddles32(internal_twiddles, strided);
|
||||
else
|
||||
@@ -247,24 +285,28 @@ namespace ntt {
|
||||
engine.ntt4_2();
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric32(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalData32ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData32(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt32dit(
|
||||
E* in,
|
||||
const E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -283,19 +325,27 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 4;
|
||||
s_meta.ntt_block_size = 32;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
|
||||
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalData32ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData32(in, data_stride, log_data_stride, strided, s_meta);
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric32(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
@@ -311,18 +361,22 @@ namespace ntt {
|
||||
engine.SharedData32Rows8(shmem, false, false, strided); // load
|
||||
engine.twiddlesInternal();
|
||||
engine.ntt8win();
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt16(
|
||||
E* in,
|
||||
const E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -341,16 +395,26 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 2;
|
||||
s_meta.ntt_block_size = 16;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_block_id = columns_batch_size
|
||||
? blockIdx.x / ((columns_batch_size + 31) / 32)
|
||||
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadInternalTwiddles16(internal_twiddles, strided);
|
||||
else
|
||||
@@ -363,24 +427,28 @@ namespace ntt {
|
||||
engine.ntt2_4();
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric16(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
engine.twiddlesExternal();
|
||||
}
|
||||
engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalData16ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData16(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__launch_bounds__(64) __global__ void ntt16dit(
|
||||
E* in,
|
||||
const E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
S* basic_twiddles,
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t columns_batch_size,
|
||||
uint32_t nof_ntt_blocks,
|
||||
uint32_t data_stride,
|
||||
uint32_t log_data_stride,
|
||||
@@ -399,19 +467,29 @@ namespace ntt {
|
||||
|
||||
s_meta.th_stride = 2;
|
||||
s_meta.ntt_block_size = 16;
|
||||
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_block_id = columns_batch_size
|
||||
? blockIdx.x / ((columns_batch_size + 31) / 32)
|
||||
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
|
||||
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
|
||||
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
|
||||
s_meta.batch_id =
|
||||
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
|
||||
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
|
||||
return;
|
||||
|
||||
if (fast_tw)
|
||||
engine.loadBasicTwiddles(basic_twiddles);
|
||||
else
|
||||
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
|
||||
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
|
||||
if (columns_batch_size)
|
||||
engine.loadGlobalData16ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.loadGlobalData16(in, data_stride, log_data_stride, strided, s_meta);
|
||||
|
||||
if (twiddle_stride) {
|
||||
if (fast_tw)
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
|
||||
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
|
||||
else
|
||||
engine.loadExternalTwiddlesGeneric16(
|
||||
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
|
||||
@@ -427,13 +505,17 @@ namespace ntt {
|
||||
engine.SharedData16Rows8(shmem, false, false, strided); // load
|
||||
engine.twiddlesInternal();
|
||||
engine.ntt8win();
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
|
||||
if (columns_batch_size)
|
||||
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
|
||||
else
|
||||
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
__global__ void normalize_kernel(E* data, S norm_factor)
|
||||
__global__ void normalize_kernel(E* data, S norm_factor, uint32_t size)
|
||||
{
|
||||
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid >= size) return;
|
||||
data[tid] = data[tid] * norm_factor;
|
||||
}
|
||||
|
||||
@@ -658,7 +740,7 @@ namespace ntt {
|
||||
|
||||
template <typename E, typename S>
|
||||
cudaError_t large_ntt(
|
||||
E* in,
|
||||
const E* in,
|
||||
E* out,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
@@ -666,6 +748,7 @@ namespace ntt {
|
||||
uint32_t log_size,
|
||||
uint32_t tw_log_size,
|
||||
uint32_t batch_size,
|
||||
bool columns_batch,
|
||||
bool inv,
|
||||
bool normalize,
|
||||
bool dit,
|
||||
@@ -679,72 +762,83 @@ namespace ntt {
|
||||
}
|
||||
|
||||
if (log_size == 4) {
|
||||
const int NOF_THREADS = min(64, 2 * batch_size);
|
||||
const int NOF_BLOCKS = (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
|
||||
const int NOF_THREADS = columns_batch ? 64 : min(64, 2 * batch_size);
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 31) / 32) : (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
if (dit) {
|
||||
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
} else { // dif
|
||||
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 5) {
|
||||
const int NOF_THREADS = min(64, 4 * batch_size);
|
||||
const int NOF_BLOCKS = (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
const int NOF_THREADS = columns_batch ? 64 : min(64, 4 * batch_size);
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 15) / 16) : (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
if (dit) {
|
||||
ntt32dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
} else { // dif
|
||||
ntt32<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 6) {
|
||||
const int NOF_THREADS = min(64, 8 * batch_size);
|
||||
const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
const int NOF_THREADS = columns_batch ? 64 : min(64, 8 * batch_size);
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 7) / 8) : ((8 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
|
||||
ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
|
||||
false, 0, inv, dit, fast_tw);
|
||||
if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
if (log_size == 8) {
|
||||
const int NOF_THREADS = 64;
|
||||
const int NOF_BLOCKS = (32 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
|
||||
const int NOF_BLOCKS =
|
||||
columns_batch ? ((batch_size + 31) / 32 * 16) : ((32 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
|
||||
if (dit) {
|
||||
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
|
||||
columns_batch, 0, inv, dit, fast_tw);
|
||||
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
|
||||
inv, dit, fast_tw);
|
||||
} else { // dif
|
||||
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
|
||||
inv, dit, fast_tw);
|
||||
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
|
||||
columns_batch, 0, inv, dit, fast_tw);
|
||||
}
|
||||
if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
|
||||
if (normalize)
|
||||
normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1 << log_size) * batch_size);
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
// general case:
|
||||
uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
|
||||
uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
|
||||
if (dit) {
|
||||
for (int i = 0; i < 5; i++) {
|
||||
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
|
||||
@@ -754,18 +848,18 @@ namespace ntt {
|
||||
if (stage_size == 6)
|
||||
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 5)
|
||||
ntt32dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 4)
|
||||
ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
}
|
||||
} else { // dif
|
||||
bool first_run = false, prev_stage = false;
|
||||
@@ -778,30 +872,31 @@ namespace ntt {
|
||||
if (stage_size == 6)
|
||||
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 5)
|
||||
ntt32<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
else if (stage_size == 4)
|
||||
ntt16<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
|
||||
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
|
||||
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
|
||||
fast_tw);
|
||||
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
|
||||
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
|
||||
prev_stage = stage_size;
|
||||
}
|
||||
}
|
||||
if (normalize)
|
||||
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
|
||||
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(
|
||||
out, S::inv_log_size(log_size), (1 << log_size) * batch_size);
|
||||
|
||||
return CHK_LAST();
|
||||
}
|
||||
|
||||
template <typename E, typename S>
|
||||
cudaError_t mixed_radix_ntt(
|
||||
E* d_input,
|
||||
const E* d_input,
|
||||
E* d_output,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
@@ -809,6 +904,7 @@ namespace ntt {
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
bool columns_batch,
|
||||
bool is_inverse,
|
||||
bool fast_tw,
|
||||
Ordering ordering,
|
||||
@@ -858,9 +954,10 @@ namespace ntt {
|
||||
}
|
||||
|
||||
if (is_on_coset && !is_inverse) {
|
||||
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
|
||||
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
|
||||
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
|
||||
arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
|
||||
reverse_coset, dit, d_output);
|
||||
|
||||
d_input = d_output;
|
||||
}
|
||||
@@ -869,10 +966,11 @@ namespace ntt {
|
||||
const bool is_reverse_in_place = (d_input == d_output);
|
||||
if (is_reverse_in_place) {
|
||||
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
} else {
|
||||
reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_input, d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
d_input, d_output, logn, columns_batch, batch_size, columns_batch ? batch_size : 1, dit, fast_tw,
|
||||
reverse_input, is_normalize, S::inv_log_size(logn));
|
||||
}
|
||||
is_normalize = false;
|
||||
d_input = d_output;
|
||||
@@ -880,18 +978,19 @@ namespace ntt {
|
||||
|
||||
// inplace ntt
|
||||
CHK_IF_RETURN(large_ntt(
|
||||
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
|
||||
(is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
|
||||
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size,
|
||||
columns_batch, is_inverse, (is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
|
||||
|
||||
if (reverse_output != eRevType::None) {
|
||||
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, logn, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
|
||||
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
|
||||
}
|
||||
|
||||
if (is_on_coset && is_inverse) {
|
||||
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
|
||||
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
|
||||
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
|
||||
d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
|
||||
arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
|
||||
n_twiddles, logn, reverse_coset, dit, d_output);
|
||||
}
|
||||
|
||||
return CHK_LAST();
|
||||
@@ -915,7 +1014,7 @@ namespace ntt {
|
||||
cudaStream_t& stream);
|
||||
|
||||
template cudaError_t mixed_radix_ntt<curve_config::scalar_t, curve_config::scalar_t>(
|
||||
curve_config::scalar_t* d_input,
|
||||
const curve_config::scalar_t* d_input,
|
||||
curve_config::scalar_t* d_output,
|
||||
curve_config::scalar_t* external_twiddles,
|
||||
curve_config::scalar_t* internal_twiddles,
|
||||
@@ -923,6 +1022,7 @@ namespace ntt {
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
bool columns_batch,
|
||||
bool is_inverse,
|
||||
bool fast_tw,
|
||||
Ordering ordering,
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace ntt {
  const uint32_t MAX_SHARED_MEM = MAX_SHARED_MEM_ELEMENT_SIZE * MAX_NUM_THREADS;

  template <typename E>
  __global__ void reverse_order_kernel(E* arr, E* arr_reversed, uint32_t n, uint32_t logn, uint32_t batch_size)
  __global__ void reverse_order_kernel(const E* arr, E* arr_reversed, uint32_t n, uint32_t logn, uint32_t batch_size)
  {
    int threadId = (blockIdx.x * blockDim.x) + threadIdx.x;
    if (threadId < n * batch_size) {
@@ -46,7 +46,8 @@ namespace ntt {
   * @param arr_out buffer of the same size as `arr_in` on the GPU to write the bit-permuted array into.
   */
  template <typename E>
  void reverse_order_batch(E* arr_in, uint32_t n, uint32_t logn, uint32_t batch_size, cudaStream_t stream, E* arr_out)
  void reverse_order_batch(
    const E* arr_in, uint32_t n, uint32_t logn, uint32_t batch_size, cudaStream_t stream, E* arr_out)
  {
    int number_of_threads = MAX_THREADS_BATCH;
    int number_of_blocks = (n * batch_size + number_of_threads - 1) / number_of_threads;
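
A minimal usage sketch for the batched bit-reversal above (the device buffers `d_in`/`d_out` and `stream` are assumed to be prepared by the caller; illustrative only, not part of this diff):

  // reorder batch_size = 4 device arrays of length n = 1 << logn in a single launch
  uint32_t logn = 10;
  ntt::reverse_order_batch(d_in, 1u << logn, logn, 4, stream, d_out);
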
@@ -63,7 +64,7 @@ namespace ntt {
|
||||
* @param arr_out buffer of the same size as `arr_in` on the GPU to write the bit-permuted array into.
|
||||
*/
|
||||
template <typename E>
|
||||
void reverse_order(E* arr_in, uint32_t n, uint32_t logn, cudaStream_t stream, E* arr_out)
|
||||
void reverse_order(const E* arr_in, uint32_t n, uint32_t logn, cudaStream_t stream, E* arr_out)
|
||||
{
|
||||
reverse_order_batch(arr_in, n, logn, 1, stream, arr_out);
|
||||
}
|
||||
@@ -81,7 +82,7 @@ namespace ntt {
|
||||
*/
|
||||
template <typename E, typename S>
|
||||
__global__ void ntt_template_kernel_shared_rev(
|
||||
E* __restrict__ arr_in,
|
||||
const E* __restrict__ arr_in,
|
||||
int n,
|
||||
const S* __restrict__ r_twiddles,
|
||||
int n_twiddles,
|
||||
@@ -153,7 +154,7 @@ namespace ntt {
|
||||
*/
|
||||
template <typename E, typename S>
|
||||
__global__ void ntt_template_kernel_shared(
|
||||
E* __restrict__ arr_in,
|
||||
const E* __restrict__ arr_in,
|
||||
int n,
|
||||
const S* __restrict__ r_twiddles,
|
||||
int n_twiddles,
|
||||
@@ -221,7 +222,7 @@ namespace ntt {
|
||||
*/
|
||||
template <typename E, typename S>
|
||||
__global__ void
|
||||
ntt_template_kernel(E* arr_in, int n, S* twiddles, int n_twiddles, int max_task, int s, bool rev, E* arr_out)
|
||||
ntt_template_kernel(const E* arr_in, int n, S* twiddles, int n_twiddles, int max_task, int s, bool rev, E* arr_out)
|
||||
{
|
||||
int task = blockIdx.x;
|
||||
int chunks = n / (blockDim.x * 2);
|
||||
@@ -273,7 +274,7 @@ namespace ntt {
|
||||
*/
|
||||
template <typename E, typename S>
|
||||
cudaError_t ntt_inplace_batch_template(
|
||||
E* d_input,
|
||||
const E* d_input,
|
||||
int n,
|
||||
S* d_twiddles,
|
||||
int n_twiddles,
|
||||
@@ -391,7 +392,7 @@ namespace ntt {
|
||||
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
|
||||
|
||||
template <typename U, typename E>
|
||||
friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
|
||||
friend cudaError_t NTT<U, E>(const E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
|
||||
};
|
||||
|
||||
template <typename S>
|
||||
@@ -516,12 +517,15 @@ namespace ntt {
|
||||
static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
|
||||
{
|
||||
const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7);
|
||||
if (!is_mixed_radix_alg_supported && config.columns_batch)
|
||||
throw IcicleError(IcicleError_t::InvalidArgument, "columns batch is not supported for given NTT size");
|
||||
const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2;
|
||||
const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg;
|
||||
if (is_force_radix2) return true;
|
||||
|
||||
const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix;
|
||||
if (is_user_selected_mixed_radix_alg) return false;
|
||||
if (config.columns_batch) return false; // radix2 does not currently support columns batch mode.
|
||||
|
||||
// Heuristic to automatically select an algorithm
|
||||
// Note that generally the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and
|
||||
@@ -537,7 +541,7 @@ namespace ntt {
|
||||
|
||||
template <typename S, typename E>
|
||||
cudaError_t radix2_ntt(
|
||||
E* d_input,
|
||||
const E* d_input,
|
||||
E* d_output,
|
||||
S* twiddles,
|
||||
int ntt_size,
|
||||
@@ -583,7 +587,7 @@ namespace ntt {
|
||||
}
|
||||
|
||||
template <typename S, typename E>
|
||||
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
|
||||
cudaError_t NTT(const E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
|
||||
{
|
||||
CHK_INIT_IF_RETURN();
|
||||
|
||||
@@ -610,18 +614,22 @@ namespace ntt {
|
||||
bool are_inputs_on_device = config.are_inputs_on_device;
|
||||
bool are_outputs_on_device = config.are_outputs_on_device;
|
||||
|
||||
E* d_input;
|
||||
const E* d_input;
|
||||
E* d_allocated_input = nullptr;
|
||||
if (are_inputs_on_device) {
|
||||
d_input = input;
|
||||
} else {
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
|
||||
CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_input, input_size_bytes, stream));
|
||||
CHK_IF_RETURN(cudaMemcpyAsync(d_allocated_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
|
||||
d_input = d_allocated_input;
|
||||
}
|
||||
E* d_output;
|
||||
E* d_allocated_output = nullptr;
|
||||
if (are_outputs_on_device) {
|
||||
d_output = output;
|
||||
} else {
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
|
||||
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_output, input_size_bytes, stream));
|
||||
d_output = d_allocated_output;
|
||||
}
|
||||
|
||||
S* coset = nullptr;
|
||||
@@ -641,37 +649,42 @@ namespace ntt {
|
||||
h_coset.clear();
|
||||
}
|
||||
|
||||
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
|
||||
const bool is_inverse = dir == NTTDir::kInverse;
|
||||
|
||||
if (is_radix2_algorithm) {
|
||||
if constexpr (std::is_same_v<E, curve_config::projective_t>) {
|
||||
CHK_IF_RETURN(ntt::radix2_ntt(
|
||||
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
|
||||
coset_index, stream));
|
||||
} else {
|
||||
const bool is_on_coset = (coset_index != 0) || coset;
|
||||
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
|
||||
S* twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
|
||||
: domain.twiddles;
|
||||
S* internal_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
|
||||
: domain.internal_twiddles;
|
||||
S* basic_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
|
||||
: domain.basic_twiddles;
|
||||
|
||||
CHK_IF_RETURN(ntt::mixed_radix_ntt(
|
||||
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
|
||||
is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
|
||||
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
|
||||
if (is_radix2_algorithm) {
|
||||
CHK_IF_RETURN(ntt::radix2_ntt(
|
||||
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
|
||||
coset_index, stream));
|
||||
} else {
|
||||
const bool is_on_coset = (coset_index != 0) || coset;
|
||||
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
|
||||
S* twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
|
||||
: domain.twiddles;
|
||||
S* internal_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
|
||||
: domain.internal_twiddles;
|
||||
S* basic_twiddles = is_fast_twiddles_enabled
|
||||
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
|
||||
: domain.basic_twiddles;
|
||||
CHK_IF_RETURN(ntt::mixed_radix_ntt(
|
||||
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
|
||||
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
|
||||
}
|
||||
}
|
||||
|
||||
if (!are_outputs_on_device)
|
||||
CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));
|
||||
|
||||
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
|
||||
if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
|
||||
if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
|
||||
if (d_allocated_input) CHK_IF_RETURN(cudaFreeAsync(d_allocated_input, stream));
|
||||
if (d_allocated_output) CHK_IF_RETURN(cudaFreeAsync(d_allocated_output, stream));
|
||||
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
|
||||
|
||||
return CHK_LAST();
|
||||
@@ -685,6 +698,7 @@ namespace ntt {
|
||||
ctx, // ctx
|
||||
S::one(), // coset_gen
|
||||
1, // batch_size
|
||||
false, // columns_batch
|
||||
Ordering::kNN, // ordering
|
||||
false, // are_inputs_on_device
|
||||
false, // are_outputs_on_device
|
||||
@@ -712,7 +726,7 @@ namespace ntt {
|
||||
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
|
||||
*/
|
||||
extern "C" cudaError_t CONCAT_EXPAND(CURVE, NTTCuda)(
|
||||
curve_config::scalar_t* input,
|
||||
const curve_config::scalar_t* input,
|
||||
int size,
|
||||
NTTDir dir,
|
||||
NTTConfig<curve_config::scalar_t>& config,
|
||||
@@ -731,7 +745,7 @@ namespace ntt {
|
||||
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
|
||||
*/
|
||||
extern "C" cudaError_t CONCAT_EXPAND(CURVE, ECNTTCuda)(
|
||||
curve_config::projective_t* input,
|
||||
const curve_config::projective_t* input,
|
||||
int size,
|
||||
NTTDir dir,
|
||||
NTTConfig<curve_config::scalar_t>& config,
|
||||
|
||||
@@ -95,6 +95,8 @@ namespace ntt {
    S coset_gen;        /**< Coset generator. Used to perform coset (i)NTTs. Default value: `S::one()`
                         * (corresponding to no coset being used). */
    int batch_size;     /**< The number of NTTs to compute. Default value: 1. */
    bool columns_batch; /**< True if the batches are the columns of an input matrix
                         (they are strided in memory with a stride of ntt size) Default value: false. */
    Ordering ordering;  /**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value:
                         * `Ordering::kNN`. */
    bool are_inputs_on_device; /**< True if inputs are on device and false if they're on host. Default value: false. */
@@ -132,7 +134,7 @@ namespace ntt {
   * @return `cudaSuccess` if the execution was successful and an error code otherwise.
   */
  template <typename S, typename E>
  cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output);
  cudaError_t NTT(const E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output);

} // namespace ntt
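
A minimal sketch of driving the new columns-batch mode through the public NTT entry point (it mirrors the benchmark further below; how `ntt_config` is first obtained is assumed here, and `d_in`/`d_out` are device buffers holding NTT_SIZE * BATCH_SIZE elements):

  ntt_config.batch_size = BATCH_SIZE;   // number of NTTs in the batch
  ntt_config.columns_batch = true;      // batch elements are interleaved as matrix columns
  ntt_config.are_inputs_on_device = true;
  ntt_config.are_outputs_on_device = true;
  CHK_IF_RETURN(ntt::NTT(d_in, NTT_SIZE, ntt::NTTDir::kForward, ntt_config, d_out));
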
@@ -27,7 +27,7 @@ namespace ntt {
|
||||
|
||||
template <typename E, typename S>
|
||||
cudaError_t mixed_radix_ntt(
|
||||
E* d_input,
|
||||
const E* d_input,
|
||||
E* d_output,
|
||||
S* external_twiddles,
|
||||
S* internal_twiddles,
|
||||
@@ -35,6 +35,7 @@ namespace ntt {
|
||||
int ntt_size,
|
||||
int max_logn,
|
||||
int batch_size,
|
||||
bool columns_batch,
|
||||
bool is_inverse,
|
||||
bool fast_tw,
|
||||
Ordering ordering,
|
||||
|
||||
@@ -29,6 +29,13 @@ void incremental_values(test_scalar* res, uint32_t count)
  }
}

__global__ void transpose_batch(test_scalar* in, test_scalar* out, int row_size, int column_size)
{
  int tid = blockDim.x * blockIdx.x + threadIdx.x;
  if (tid >= row_size * column_size) return;
  out[(tid % row_size) * column_size + (tid / row_size)] = in[tid];
}
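
For intuition (illustration only): with row_size = NTT_SIZE = 4 and column_size = BATCH_SIZE = 2, input element in[5] sits at batch row 1, NTT index 1 of the row-major [batch][ntt_size] layout and is written to out[1 * 2 + 1] = out[3], so the transposed buffer keeps all batch entries of a given NTT index contiguous, which is exactly the layout the columns_batch mode expects.
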
int main(int argc, char** argv)
|
||||
{
|
||||
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
|
||||
@@ -37,11 +44,12 @@ int main(int argc, char** argv)
|
||||
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19;
|
||||
int NTT_SIZE = 1 << NTT_LOG_SIZE;
|
||||
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
|
||||
int INV = (argc > 3) ? atoi(argv[3]) : true;
|
||||
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
|
||||
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
|
||||
const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;
|
||||
bool FAST_TW = (argc > 7) ? atoi(argv[7]) : true;
|
||||
int INV = (argc > 3) ? atoi(argv[3]) : false;
|
||||
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 150;
|
||||
bool COLUMNS_BATCH = (argc > 5) ? atoi(argv[5]) : false;
|
||||
int COSET_IDX = (argc > 6) ? atoi(argv[6]) : 2;
|
||||
const ntt::Ordering ordering = (argc > 7) ? ntt::Ordering(atoi(argv[7])) : ntt::Ordering::kNN;
|
||||
bool FAST_TW = (argc > 8) ? atoi(argv[8]) : true;
|
||||
|
||||
// Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
|
||||
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
|
||||
@@ -52,8 +60,8 @@ int main(int argc, char** argv)
|
||||
: "MN";
|
||||
|
||||
printf(
|
||||
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s, fast_tw=%d\n", NTT_LOG_SIZE,
|
||||
INPLACE, INV, BATCH_SIZE, COSET_IDX, ordering_str, FAST_TW);
|
||||
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, columns_batch=%d coset-idx=%d, ordering=%s, fast_tw=%d\n",
|
||||
NTT_LOG_SIZE, INPLACE, INV, BATCH_SIZE, COLUMNS_BATCH, COSET_IDX, ordering_str, FAST_TW);
|
||||
|
||||
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
|
||||
|
||||
@@ -63,6 +71,7 @@ int main(int argc, char** argv)
|
||||
ntt_config.are_inputs_on_device = true;
|
||||
ntt_config.are_outputs_on_device = true;
|
||||
ntt_config.batch_size = BATCH_SIZE;
|
||||
ntt_config.columns_batch = COLUMNS_BATCH;
|
||||
|
||||
CHK_IF_RETURN(cudaEventCreate(&icicle_start));
|
||||
CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
|
||||
@@ -83,7 +92,9 @@ int main(int argc, char** argv)
|
||||
|
||||
// gpu allocation
|
||||
test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
|
||||
test_data* GpuScalarsTransposed;
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuScalarsTransposed, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
|
||||
|
||||
@@ -93,10 +104,16 @@ int main(int argc, char** argv)
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyHostToDevice));
|
||||
|
||||
if (COLUMNS_BATCH) {
|
||||
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
|
||||
GpuScalars, GpuScalarsTransposed, NTT_SIZE, BATCH_SIZE);
|
||||
}
|
||||
|
||||
// inplace
|
||||
if (INPLACE) {
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
CHK_IF_RETURN(cudaMemcpy(
|
||||
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) {
|
||||
@@ -109,13 +126,14 @@ int main(int argc, char** argv)
|
||||
ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix;
|
||||
for (size_t i = 0; i < iterations; i++) {
|
||||
CHK_IF_RETURN(ntt::NTT(
|
||||
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
|
||||
GpuOutputNew));
|
||||
INPLACE ? GpuOutputNew
|
||||
: COLUMNS_BATCH ? GpuScalarsTransposed
|
||||
: GpuScalars,
|
||||
NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputNew));
|
||||
}
|
||||
CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop));
|
||||
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
|
||||
|
||||
// OLD
|
||||
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
|
||||
@@ -127,7 +145,6 @@ int main(int argc, char** argv)
|
||||
CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
|
||||
CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop));
|
||||
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
|
||||
|
||||
if (is_print) {
|
||||
printf("Old Runtime=%0.3f MS\n", icicle_time / iterations);
|
||||
@@ -140,11 +157,19 @@ int main(int argc, char** argv)
|
||||
CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
|
||||
int count = INPLACE ? 1 : 10;
|
||||
if (INPLACE) {
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
CHK_IF_RETURN(cudaMemcpy(
|
||||
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
|
||||
cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
CHK_IF_RETURN(benchmark(true /*=print*/, count));
|
||||
|
||||
if (COLUMNS_BATCH) {
|
||||
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
|
||||
GpuOutputNew, GpuScalarsTransposed, BATCH_SIZE, NTT_SIZE);
|
||||
CHK_IF_RETURN(cudaMemcpy(
|
||||
GpuOutputNew, GpuScalarsTransposed, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
// verify
|
||||
CHK_IF_RETURN(
|
||||
cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
|
||||
@@ -153,10 +178,11 @@ int main(int argc, char** argv)
|
||||
|
||||
bool success = true;
|
||||
for (int i = 0; i < NTT_SIZE * BATCH_SIZE; i++) {
|
||||
// if (i%64==0) printf("\n");
|
||||
if (CpuOutputNew[i] != CpuOutputOld[i]) {
|
||||
success = false;
|
||||
// std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;
|
||||
break;
|
||||
// break;
|
||||
} else {
|
||||
// std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl;
|
||||
// break;
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
struct stage_metadata {
|
||||
uint32_t th_stride;
|
||||
uint32_t ntt_block_size;
|
||||
uint32_t batch_id;
|
||||
uint32_t ntt_block_id;
|
||||
uint32_t ntt_inp_id;
|
||||
};
|
||||
@@ -118,7 +119,7 @@ public:
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
|
||||
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
|
||||
{
|
||||
data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1));
|
||||
|
||||
@@ -129,7 +130,7 @@ public:
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
|
||||
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
|
||||
{
|
||||
data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1));
|
||||
|
||||
@@ -143,7 +144,7 @@ public:
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
|
||||
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
|
||||
{
|
||||
data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1));
|
||||
|
||||
@@ -195,8 +196,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadGlobalData(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void
|
||||
loadGlobalData(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
@@ -211,8 +212,22 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void loadGlobalDataColumnBatch(
|
||||
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 8; i++) {
|
||||
X[i] = data[s_meta.th_stride * i * data_stride * batch_size];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
@@ -227,8 +242,22 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadGlobalData32(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void storeGlobalDataColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 8; i++) {
|
||||
data[s_meta.th_stride * i * data_stride * batch_size] = X[i];
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
@@ -246,8 +275,25 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData32(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void loadGlobalData32ColumnBatch(
|
||||
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; i++) {
|
||||
X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
@@ -265,8 +311,25 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void loadGlobalData16(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void storeGlobalData32ColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 2; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 4; i++) {
|
||||
data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
loadGlobalData16(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
@@ -284,8 +347,25 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData16(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
|
||||
__device__ __forceinline__ void loadGlobalData16ColumnBatch(
|
||||
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 2; i++) {
|
||||
X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void
|
||||
storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
|
||||
{
|
||||
if (strided) {
|
||||
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
@@ -303,6 +383,23 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void storeGlobalData16ColumnBatch(
|
||||
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
|
||||
{
|
||||
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
|
||||
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
|
||||
batch_size +
|
||||
s_meta.batch_id;
|
||||
|
||||
#pragma unroll
|
||||
for (uint32_t j = 0; j < 4; j++) {
|
||||
#pragma unroll
|
||||
for (uint32_t i = 0; i < 2; i++) {
|
||||
data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void ntt4_2()
|
||||
{
|
||||
#pragma unroll
|
||||
|
||||
@@ -3,8 +3,11 @@
|
||||
#define BW6_761_PARAMS_H
|
||||
|
||||
#include "utils/storage.cuh"
|
||||
#include "bls12_377_params.cuh"
|
||||
|
||||
namespace bw6_761 {
|
||||
typedef bls12_377::fq_config fp_config;
|
||||
|
||||
struct fq_config {
|
||||
static constexpr unsigned limbs_count = 24;
|
||||
static constexpr unsigned modulus_bit_count = 761;
|
||||
|
||||
@@ -24,7 +24,6 @@ using namespace bls12_381;
|
||||
#include "bls12_377_params.cuh"
|
||||
using namespace bls12_377;
|
||||
#elif CURVE_ID == BW6_761
|
||||
#include "bls12_377_params.cuh"
|
||||
#include "bw6_761_params.cuh"
|
||||
using namespace bw6_761;
|
||||
#elif CURVE_ID == GRUMPKIN
|
||||
@@ -39,10 +38,6 @@ using namespace grumpkin;
|
||||
* with the `-DCURVE` env variable passed during build.
|
||||
*/
|
||||
namespace curve_config {
|
||||
|
||||
#if CURVE_ID == BW6_761
|
||||
typedef bls12_377::fq_config fp_config;
|
||||
#endif
|
||||
/**
|
||||
* Scalar field of the curve. Is always a prime field.
|
||||
*/
|
||||
|
||||
@@ -33,4 +33,4 @@ public:
|
||||
os << "x: " << point.x << "; y: " << point.y;
|
||||
return os;
|
||||
}
|
||||
};
|
||||
};
|
||||
@@ -2,7 +2,7 @@
|
||||
#include "field.cuh"
|
||||
#include "utils/utils.h"
|
||||
|
||||
#define scalar_t curve_config::scalar_t
|
||||
using namespace curve_config;
|
||||
|
||||
extern "C" void CONCAT_EXPAND(CURVE, GenerateScalars)(scalar_t* scalars, int size)
|
||||
{
|
||||
|
||||
@@ -680,7 +680,7 @@ public:
|
||||
|
||||
HOST_DEVICE_INLINE uint32_t* export_limbs() { return (uint32_t*)limbs_storage.limbs; }
|
||||
|
||||
HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width)
|
||||
HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width) const
|
||||
{
|
||||
const uint32_t limb_lsb_idx = (digit_num * digit_width) / 32;
|
||||
const uint32_t shift_bits = (digit_num * digit_width) % 32;
|
||||
|
||||
@@ -8,6 +8,9 @@ class Projective
friend Affine<FF>;

public:
static constexpr unsigned SCALAR_FF_NBITS = SCALAR_FF::NBITS;
static constexpr unsigned FF_NBITS = FF::NBITS;

FF x;
FF y;
FF z;
@@ -36,6 +39,34 @@ public:

static HOST_DEVICE_INLINE Projective neg(const Projective& point) { return {point.x, FF::neg(point.y), point.z}; }

static HOST_DEVICE_INLINE Projective dbl(const Projective& point)
{
const FF X = point.x;
const FF Y = point.y;
const FF Z = point.z;

// TODO: Change to efficient dbl once implemented for field.cuh
FF t0 = FF::sqr(Y); // 1. t0 ← Y · Y
FF Z3 = t0 + t0; // 2. Z3 ← t0 + t0
Z3 = Z3 + Z3; // 3. Z3 ← Z3 + Z3
Z3 = Z3 + Z3; // 4. Z3 ← Z3 + Z3
FF t1 = Y * Z; // 5. t1 ← Y · Z
FF t2 = FF::sqr(Z); // 6. t2 ← Z · Z
t2 = FF::template mul_unsigned<3>(FF::template mul_const<B_VALUE>(t2)); // 7. t2 ← b3 · t2
FF X3 = t2 * Z3; // 8. X3 ← t2 · Z3
FF Y3 = t0 + t2; // 9. Y3 ← t0 + t2
Z3 = t1 * Z3; // 10. Z3 ← t1 · Z3
t1 = t2 + t2; // 11. t1 ← t2 + t2
t2 = t1 + t2; // 12. t2 ← t1 + t2
t0 = t0 - t2; // 13. t0 ← t0 − t2
Y3 = t0 * Y3; // 14. Y3 ← t0 · Y3
Y3 = X3 + Y3; // 15. Y3 ← X3 + Y3
t1 = X * Y; // 16. t1 ← X · Y
X3 = t0 * t1; // 17. X3 ← t0 · t1
X3 = X3 + X3; // 18. X3 ← X3 + X3
return {X3, Y3, Z3};
}

friend HOST_DEVICE_INLINE Projective operator+(Projective p1, const Projective& p2)
{
const FF X1 = p1.x; // < 2

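For reference, the numbered steps in dbl above appear to be the complete doubling formulas for a short-Weierstrass curve with a = 0 (Renes, Costello and Batina's Algorithm 9, with b3 = 3b): two squarings, six general multiplications and one multiplication by 3b, computing

X_3 = 2XY\,(Y^2 - 9bZ^2), \qquad Y_3 = (Y^2 - 9bZ^2)(Y^2 + 3bZ^2) + 24\,bY^2Z^2, \qquad Z_3 = 8Y^3Z.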
@@ -9,14 +9,14 @@ namespace mont {
|
||||
#define MAX_THREADS_PER_BLOCK 256
|
||||
|
||||
template <typename E, bool is_into>
|
||||
__global__ void MontgomeryKernel(E* input, int n, E* output)
|
||||
__global__ void MontgomeryKernel(const E* input, int n, E* output)
|
||||
{
|
||||
int tid = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (tid < n) { output[tid] = is_into ? E::ToMontgomery(input[tid]) : E::FromMontgomery(input[tid]); }
|
||||
}
|
||||
|
||||
template <typename E, bool is_into>
|
||||
cudaError_t ConvertMontgomery(E* d_input, int n, cudaStream_t stream, E* d_output)
|
||||
cudaError_t ConvertMontgomery(const E* d_input, int n, cudaStream_t stream, E* d_output)
|
||||
{
|
||||
// Set the grid and block dimensions
|
||||
int num_threads = MAX_THREADS_PER_BLOCK;
|
||||
@@ -29,13 +29,13 @@ namespace mont {
|
||||
} // namespace
|
||||
|
||||
template <typename E>
|
||||
cudaError_t ToMontgomery(E* d_input, int n, cudaStream_t stream, E* d_output)
|
||||
cudaError_t ToMontgomery(const E* d_input, int n, cudaStream_t stream, E* d_output)
|
||||
{
|
||||
return ConvertMontgomery<E, true>(d_input, n, stream, d_output);
|
||||
}
|
||||
|
||||
template <typename E>
|
||||
cudaError_t FromMontgomery(E* d_input, int n, cudaStream_t stream, E* d_output)
|
||||
cudaError_t FromMontgomery(const E* d_input, int n, cudaStream_t stream, E* d_output)
|
||||
{
|
||||
return ConvertMontgomery<E, false>(d_input, n, stream, d_output);
|
||||
}
|
||||
|
||||
@@ -12,7 +12,7 @@ namespace utils_internal {
|
||||
|
||||
template <typename E, typename S>
|
||||
__global__ void BatchMulKernel(
|
||||
E* in_vec,
|
||||
const E* in_vec,
|
||||
int n_elements,
|
||||
int batch_size,
|
||||
S* scalar_vec,
|
||||
|
||||
@@ -95,7 +95,6 @@ namespace vec_ops {
|
||||
*/
|
||||
template <typename E>
|
||||
cudaError_t Sub(E* vec_a, E* vec_b, int n, VecOpsConfig<E>& config, E* result);
|
||||
|
||||
} // namespace vec_ops
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,12 +1,18 @@
#!/bin/bash

G2_DEFINED=OFF
ECNTT_DEFINED=OFF

if [[ $2 ]]
if [[ $2 == "ON" ]]
then
G2_DEFINED=ON
fi

if [[ $3 ]]
then
ECNTT_DEFINED=ON
fi

BUILD_DIR=$(realpath "$PWD/../../icicle/build")
SUPPORTED_CURVES=("bn254" "bls12_377" "bls12_381" "bw6_761")

@@ -22,6 +28,6 @@ mkdir -p build

for CURVE in "${BUILD_CURVES[@]}"
do
cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build
cmake --build build
cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DECNTT_DEFINED=$ECNTT_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build
cmake --build build -j8
done
@@ -76,7 +76,7 @@ func GetDefaultMSMConfig() MSMConfig {
}

func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfig, results HostOrDeviceSlice) {
scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len(), results.Len()
scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len()/int(cfg.PrecomputeFactor), results.Len()
if scalarsLength%pointsLength != 0 {
errorString := fmt.Sprintf(
"Number of points %d does not divide the number of scalars %d",
@@ -99,3 +99,15 @@ func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfi
cfg.arePointsOnDevice = points.IsOnDevice()
cfg.areResultsOnDevice = results.IsOnDevice()
}

func PrecomputeBasesCheck(points HostOrDeviceSlice, precomputeFactor int32, outputBases DeviceSlice) {
outputBasesLength, pointsLength := outputBases.Len(), points.Len()
if outputBasesLength != pointsLength*int(precomputeFactor) {
errorString := fmt.Sprintf(
"Precompute factor is probably incorrect: expected %d but got %d",
outputBasesLength/pointsLength,
precomputeFactor,
)
panic(errorString)
}
}

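PrecomputeBasesCheck ties the size of the precomputed table to points.Len() * precomputeFactor, and the updated MsmCheck divides the point count by cfg.PrecomputeFactor. A minimal Go sketch of the intended flow, written as if it sat next to the bls12-377 MSM tests further down in this diff (GetDefaultMSMConfig, GenerateScalars, GenerateAffinePoints, PrecomputeBases, Msm and Projective all come from that wrapper package); error handling is omitted.

	func precomputedMsmSketch() {
		const precomputeFactor = 8
		size := 1 << 10

		cfg := GetDefaultMSMConfig()
		scalars := GenerateScalars(size)
		points := GenerateAffinePoints(size)

		// The precomputed table must hold points.Len() * precomputeFactor bases,
		// which is exactly what PrecomputeBasesCheck verifies.
		var precomputeOut core.DeviceSlice
		precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
		PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)

		var p Projective
		var out core.DeviceSlice
		out.Malloc(p.Size(), p.Size())

		// MsmCheck now sees size == precomputeOut.Len() / cfg.PrecomputeFactor.
		cfg.PrecomputeFactor = precomputeFactor
		Msm(scalars, precomputeOut, &cfg, out)

		outHost := make(core.HostSlice[Projective], 1)
		outHost.CopyFromDevice(&out)
		out.Free()
		precomputeOut.Free()
	}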
@@ -10,18 +10,26 @@ type NTTDir int8
|
||||
|
||||
const (
|
||||
KForward NTTDir = iota
|
||||
KInverse NTTDir = 1
|
||||
KInverse
|
||||
)
|
||||
|
||||
type Ordering uint32
|
||||
|
||||
const (
|
||||
KNN Ordering = iota
|
||||
KNR Ordering = 1
|
||||
KRN Ordering = 2
|
||||
KRR Ordering = 3
|
||||
KNM Ordering = 4
|
||||
KMN Ordering = 5
|
||||
KNR
|
||||
KRN
|
||||
KRR
|
||||
KNM
|
||||
KMN
|
||||
)
|
||||
|
||||
type NttAlgorithm uint32
|
||||
|
||||
const (
|
||||
Auto NttAlgorithm = iota
|
||||
Radix2
|
||||
MixedRadix
|
||||
)
|
||||
|
||||
type NTTConfig[T any] struct {
@@ -31,13 +39,17 @@ type NTTConfig[T any] struct {
CosetGen T
/// The number of NTTs to compute. Default value: 1.
BatchSize int32
/// If true the function will compute the NTTs over the columns of the input matrix and not over the rows.
ColumnsBatch bool
/// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
Ordering Ordering
areInputsOnDevice bool
areOutputsOnDevice bool
/// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
/// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
IsAsync bool
IsAsync bool
NttAlgorithm NttAlgorithm /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation
selects radix-2 or mixed-radix algorithm based on heuristics). */
}

func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
@@ -46,10 +58,12 @@ func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
ctx, // Ctx
cosetGen, // CosetGen
1, // BatchSize
false, // ColumnsBatch
KNN, // Ordering
false, // areInputsOnDevice
false, // areOutputsOnDevice
false, // IsAsync
Auto,
}
}

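The new ColumnsBatch and NttAlgorithm fields are plain knobs on this config. A minimal sketch of a column-major batch, written as if it sat inside one of the curve packages further down in this diff (GetDefaultNttConfig, GenerateScalars, ScalarField and Ntt come from there); it assumes the NTT domain was already initialized via InitDomain.

	func columnsBatchNttSketch() {
		batchSize := 16
		size := 1 << 10

		cfg := GetDefaultNttConfig()
		cfg.BatchSize = int32(batchSize)
		cfg.ColumnsBatch = true            // transform the columns of the input matrix rather than its rows
		cfg.NttAlgorithm = core.MixedRadix // or leave core.Auto and let the library decide

		scalars := GenerateScalars(size * batchSize)
		output := make(core.HostSlice[ScalarField], size*batchSize)

		// One call computes batchSize NTTs of length size.
		Ntt(scalars, core.KForward, &cfg, output)
	}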
@@ -19,10 +19,12 @@ func TestNTTDefaultConfig(t *testing.T) {
|
||||
ctx, // Ctx
|
||||
cosetGen, // CosetGen
|
||||
1, // BatchSize
|
||||
false, // ColumnsBatch
|
||||
KNN, // Ordering
|
||||
false, // areInputsOnDevice
|
||||
false, // areOutputsOnDevice
|
||||
false, // IsAsync
|
||||
Auto, // NttAlgorithm
|
||||
}
|
||||
|
||||
actual := GetDefaultNTTConfig(cosetGen)
|
||||
|
||||
@@ -43,8 +43,17 @@ func (d DeviceSlice) IsOnDevice() bool {
return true
}

// TODO: change signature to be Malloc(element, numElements)
// calc size internally
func (d DeviceSlice) GetDeviceId() int {
return cr.GetDeviceFromPointer(d.inner)
}

// CheckDevice is used to ensure that the DeviceSlice about to be used resides on the currently set device
func (d DeviceSlice) CheckDevice() {
if currentDeviceId, err := cr.GetDevice(); err != cr.CudaSuccess || d.GetDeviceId() != currentDeviceId {
panic("Attempt to use DeviceSlice on a different device")
}
}

func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cr.CudaError) {
dp, err := cr.Malloc(uint(size))
d.inner = dp
@@ -62,6 +71,7 @@ func (d *DeviceSlice) MallocAsync(size, sizeOfElement int, stream cr.CudaStream)
}

func (d *DeviceSlice) Free() cr.CudaError {
d.CheckDevice()
err := cr.Free(d.inner)
if err == cr.CudaSuccess {
d.length, d.capacity = 0, 0
@@ -70,6 +80,16 @@ func (d *DeviceSlice) Free() cr.CudaError {
return err
}

func (d *DeviceSlice) FreeAsync(stream cr.Stream) cr.CudaError {
d.CheckDevice()
err := cr.FreeAsync(d.inner, stream)
if err == cr.CudaSuccess {
d.length, d.capacity = 0, 0
d.inner = nil
}
return err
}

type HostSliceInterface interface {
Size() int
}
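FreeAsync completes the stream-ordered lifecycle for device memory. A short sketch of the round trip, written as if it lived in this core package next to the slice types (cr is the cuda_runtime package, as above); it follows the same shape as the async MSM tests later in this diff.

	func asyncRoundTripSketch[T any](h HostSlice[T]) {
		stream, _ := cr.CreateStream()

		var d DeviceSlice
		h.CopyToDeviceAsync(&d, stream, true) // allocates on the stream, then copies

		// ... enqueue kernels that read or write d on this stream ...

		h.CopyFromDeviceAsync(&d, stream)
		d.FreeAsync(stream) // stream-ordered free, runs after the copy above

		cr.SynchronizeStream(&stream) // only now is h guaranteed to hold the results
	}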
@@ -117,6 +137,7 @@ func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *Devic
|
||||
if shouldAllocate {
|
||||
dst.Malloc(size, h.SizeOfElement())
|
||||
}
|
||||
dst.CheckDevice()
|
||||
if size > dst.Cap() {
|
||||
panic("Number of bytes to copy is too large for destination")
|
||||
}
|
||||
@@ -133,6 +154,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream,
|
||||
if shouldAllocate {
|
||||
dst.MallocAsync(size, h.SizeOfElement(), stream)
|
||||
}
|
||||
dst.CheckDevice()
|
||||
if size > dst.Cap() {
|
||||
panic("Number of bytes to copy is too large for destination")
|
||||
}
|
||||
@@ -144,6 +166,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream,
|
||||
}
|
||||
|
||||
func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) {
|
||||
src.CheckDevice()
|
||||
if h.Len() != src.Len() {
|
||||
panic("destination and source slices have different lengths")
|
||||
}
|
||||
@@ -152,6 +175,7 @@ func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) {
|
||||
}
|
||||
|
||||
func (h HostSlice[T]) CopyFromDeviceAsync(src *DeviceSlice, stream cr.Stream) {
|
||||
src.CheckDevice()
|
||||
if h.Len() != src.Len() {
|
||||
panic("destination and source slices have different lengths")
|
||||
}
|
||||
|
||||
@@ -9,6 +9,8 @@ package cuda_runtime
|
||||
*/
|
||||
import "C"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
@@ -17,20 +19,28 @@ type DeviceContext struct {
|
||||
Stream *Stream // Assuming the type is provided by a CUDA binding crate
|
||||
|
||||
/// Index of the currently used GPU. Default value: 0.
|
||||
DeviceId uint
|
||||
deviceId uint
|
||||
|
||||
/// Mempool to use. Default value: 0.
|
||||
// TODO: use cuda_bindings.CudaMemPool as type
|
||||
Mempool uint // Assuming the type is provided by a CUDA binding crate
|
||||
Mempool MemPool // Assuming the type is provided by a CUDA binding crate
|
||||
}
|
||||
|
||||
func (d DeviceContext) GetDeviceId() int {
|
||||
return int(d.deviceId)
|
||||
}
|
||||
|
||||
func GetDefaultDeviceContext() (DeviceContext, CudaError) {
|
||||
device, err := GetDevice()
|
||||
if err != CudaSuccess {
|
||||
panic(fmt.Sprintf("Could not get current device due to %v", err))
|
||||
}
|
||||
var defaultStream Stream
|
||||
var defaultMempool MemPool
|
||||
|
||||
return DeviceContext{
|
||||
&defaultStream,
|
||||
0,
|
||||
0,
|
||||
uint(device),
|
||||
defaultMempool,
|
||||
}, CudaSuccess
|
||||
}
|
||||
|
||||
@@ -47,3 +57,78 @@ func GetDeviceCount() (int, CudaError) {
|
||||
err := C.cudaGetDeviceCount(cCount)
|
||||
return count, (CudaError)(err)
|
||||
}
|
||||
|
||||
func GetDevice() (int, CudaError) {
|
||||
var device int
|
||||
cDevice := (*C.int)(unsafe.Pointer(&device))
|
||||
err := C.cudaGetDevice(cDevice)
|
||||
return device, (CudaError)(err)
|
||||
}
|
||||
|
||||
func GetDeviceFromPointer(ptr unsafe.Pointer) int {
|
||||
var cCudaPointerAttributes CudaPointerAttributes
|
||||
err := C.cudaPointerGetAttributes(&cCudaPointerAttributes, ptr)
|
||||
if (CudaError)(err) != CudaSuccess {
|
||||
panic("Could not get attributes of pointer")
|
||||
}
|
||||
return int(cCudaPointerAttributes.device)
|
||||
}
|
||||
|
||||
// RunOnDevice forces the provided function to run all GPU related calls within it
|
||||
// on the same host thread and therefore the same GPU device.
|
||||
//
|
||||
// NOTE: Goroutines launched within funcToRun are not bound to the
|
||||
// same host thread as funcToRun and therefore not to the same GPU device.
|
||||
// If that is a requirement, RunOnDevice should be called for each with the
|
||||
// same deviceId as the original call.
|
||||
//
|
||||
// As an example:
|
||||
//
|
||||
// cr.RunOnDevice(i, func(args ...any) {
|
||||
// defer wg.Done()
|
||||
// cfg := GetDefaultMSMConfig()
|
||||
// stream, _ := cr.CreateStream()
|
||||
// for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
// size := 1 << power
|
||||
//
|
||||
// // This will always print "Inner goroutine device: 0"
|
||||
// // go func () {
|
||||
// // device, _ := cr.GetDevice()
|
||||
// // fmt.Println("Inner goroutine device: ", device)
|
||||
// // }()
|
||||
// // To force the above goroutine to same device as the wrapping function:
|
||||
// // RunOnDevice(i, func(arg ...any) {
|
||||
// // device, _ := cr.GetDevice()
|
||||
// // fmt.Println("Inner goroutine device: ", device)
|
||||
// // })
|
||||
//
|
||||
// scalars := GenerateScalars(size)
|
||||
// points := GenerateAffinePoints(size)
|
||||
//
|
||||
// var p Projective
|
||||
// var out core.DeviceSlice
|
||||
// _, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
// assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
// cfg.Ctx.Stream = &stream
|
||||
// cfg.IsAsync = true
|
||||
//
|
||||
// e = Msm(scalars, points, &cfg, out)
|
||||
// assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
//
|
||||
// outHost := make(core.HostSlice[Projective], 1)
|
||||
//
|
||||
// cr.SynchronizeStream(&stream)
|
||||
// outHost.CopyFromDevice(&out)
|
||||
// out.Free()
|
||||
// // Check with gnark-crypto
|
||||
// assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
// }
|
||||
// }, i)
|
||||
func RunOnDevice(deviceId int, funcToRun func(args ...any), args ...any) {
go func(id int) {
defer runtime.UnlockOSThread()
runtime.LockOSThread()
SetDevice(id)
funcToRun(args...)
}(deviceId)
}

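A short sketch of the fan-out pattern this enables, one goroutine pinned per GPU; it assumes a caller outside this package that imports cuda_runtime as cr and the standard library sync package, mirroring the multi-device MSM tests later in this diff.

	func runOnAllDevicesSketch(work func(deviceId int)) {
		numDevices, _ := cr.GetDeviceCount()
		wg := sync.WaitGroup{}
		for i := 0; i < numDevices; i++ {
			wg.Add(1)
			cr.RunOnDevice(i, func(args ...any) {
				defer wg.Done()
				work(args[0].(int)) // every CUDA call made here targets device args[0]
			}, i)
		}
		wg.Wait()
	}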
@@ -12,6 +12,8 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type MemPool = CudaMemPool
|
||||
|
||||
func Malloc(size uint) (unsafe.Pointer, CudaError) {
|
||||
if size == 0 {
|
||||
return nil, CudaErrorMemoryAllocation
|
||||
|
||||
@@ -17,3 +17,6 @@ type CudaEvent C.cudaEvent_t
|
||||
|
||||
// CudaMemPool as declared in include/driver_types.h:2928
|
||||
type CudaMemPool C.cudaMemPool_t
|
||||
|
||||
// CudaPointerAttributes as declared in include/driver_types.h
|
||||
type CudaPointerAttributes = C.struct_cudaPointerAttributes
|
||||
|
||||
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
|
||||
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
cPoints := (*C.g2_affine_t)(pointsPointer)
|
||||
|
||||
cPointsLen := (C.int)(points.Len())
|
||||
cPrecomputeFactor := (C.int)(precomputeFactor)
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
|
||||
|
||||
__ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bls12377
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -136,6 +142,48 @@ func TestMSMG2Batch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecomputeBaseG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
const precomputeFactor = 8
|
||||
for _, power := range []int{10, 16} {
|
||||
for _, batchSize := range []int{1, 3, 16} {
|
||||
size := 1 << power
|
||||
totalSize := size * batchSize
|
||||
scalars := GenerateScalars(totalSize)
|
||||
points := G2GenerateAffinePoints(totalSize)
|
||||
|
||||
var precomputeOut core.DeviceSlice
|
||||
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
|
||||
|
||||
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
|
||||
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e = out.Malloc(batchSize*p.Size(), p.Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
|
||||
cfg.PrecomputeFactor = precomputeFactor
|
||||
|
||||
e = G2Msm(scalars, precomputeOut, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], batchSize)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
precomputeOut.Free()
|
||||
|
||||
// Check with gnark-crypto
|
||||
for i := 0; i < batchSize; i++ {
|
||||
scalarsSlice := scalars[i*size : (i+1)*size]
|
||||
pointsSlice := points[i*size : (i+1)*size]
|
||||
out := outHost[i]
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -165,3 +213,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
cr.SetDevice(orig_device)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef _BLS12_377_G2MSM_H
|
||||
#define _BLS12_377_G2MSM_H
|
||||
@@ -8,7 +9,8 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_377G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
|
||||
cudaError_t bls12_377G2MSMCuda(const scalar_t* scalars, const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
|
||||
cudaError_t bls12_377G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef _BLS12_377_MSM_H
|
||||
#define _BLS12_377_MSM_H
|
||||
@@ -8,7 +9,8 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_377MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
|
||||
cudaError_t bls12_377MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
|
||||
cudaError_t bls12_377PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -9,11 +9,12 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_377NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_377NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_377ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bls12_377InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
|
||||
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
cPoints := (*C.affine_t)(pointsPointer)
|
||||
|
||||
cPointsLen := (C.int)(points.Len())
|
||||
cPrecomputeFactor := (C.int)(precomputeFactor)
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.affine_t)(outputBasesPointer)
|
||||
|
||||
__ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package bls12377
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
|
||||
|
||||
func TestMSM(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecomputeBase(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
const precomputeFactor = 8
|
||||
for _, power := range []int{10, 16} {
|
||||
for _, batchSize := range []int{1, 3, 16} {
|
||||
size := 1 << power
|
||||
totalSize := size * batchSize
|
||||
scalars := GenerateScalars(totalSize)
|
||||
points := GenerateAffinePoints(totalSize)
|
||||
|
||||
var precomputeOut core.DeviceSlice
|
||||
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
|
||||
|
||||
e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
|
||||
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e = out.Malloc(batchSize*p.Size(), p.Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
|
||||
cfg.PrecomputeFactor = precomputeFactor
|
||||
|
||||
e = Msm(scalars, precomputeOut, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], batchSize)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
precomputeOut.Free()
|
||||
|
||||
// Check with gnark-crypto
|
||||
for i := 0; i < batchSize; i++ {
|
||||
scalarsSlice := scalars[i*size : (i+1)*size]
|
||||
pointsSlice := points[i*size : (i+1)*size]
|
||||
out := outHost[i]
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMSkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMMultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
cr.SetDevice(orig_device)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
|
||||
core.NttCheck[T](points, cfg, results)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cPoints := (*C.projective_t)(pointsPointer)
|
||||
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
|
||||
cDir := (C.int)(dir)
|
||||
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cResults := (*C.projective_t)(resultsPointer)
|
||||
|
||||
__ret := C.bls12_377ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
|
||||
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
|
||||
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestECNtt(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
initDomain(largestTestSize, cfg)
|
||||
points := GenerateProjectivePoints(1 << largestTestSize)
|
||||
|
||||
for _, size := range []int{4, 5, 6, 7, 8} {
|
||||
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
testSize := 1 << size
|
||||
|
||||
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
|
||||
cfg.Ordering = v
|
||||
cfg.NttAlgorithm = core.Radix2
|
||||
|
||||
output := make(core.HostSlice[Projective], testSize)
|
||||
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNttDeviceAsync(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
scalars := GenerateScalars(1 << largestTestSize)
|
||||
|
||||
@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}

func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}

func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}

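These conversions now guard against cross-device use via CheckDevice and operate in place on a device slice. A minimal sketch of the round trip, written as if it sat inside the same curve package (GenerateScalars, ScalarField, core and cr as in the tests elsewhere in this diff).

	func montgomeryRoundTripSketch() {
		scalars := GenerateScalars(1 << 10)

		var d core.DeviceSlice
		scalars.CopyToDevice(&d, true) // allocate and copy to the current device

		ToMontgomery(&d)   // in place; panics via CheckDevice if d lives on another GPU
		FromMontgomery(&d) // back to the standard representation

		back := make(core.HostSlice[ScalarField], scalars.Len())
		back.CopyFromDevice(&d)
		d.Free()
	}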
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
aDevice := a.(core.DeviceSlice)
|
||||
aDevice.CheckDevice()
|
||||
cA = (*C.scalar_t)(aDevice.AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
bDevice := b.(core.DeviceSlice)
|
||||
bDevice.CheckDevice()
|
||||
cB = (*C.scalar_t)(bDevice.AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
outDevice := out.(core.DeviceSlice)
|
||||
outDevice.CheckDevice()
|
||||
cOut = (*C.scalar_t)(outDevice.AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
|
||||
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
cPoints := (*C.g2_affine_t)(pointsPointer)
|
||||
|
||||
cPointsLen := (C.int)(points.Len())
|
||||
cPrecomputeFactor := (C.int)(precomputeFactor)
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
|
||||
|
||||
__ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bls12381
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -136,6 +142,48 @@ func TestMSMG2Batch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecomputeBaseG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
const precomputeFactor = 8
|
||||
for _, power := range []int{10, 16} {
|
||||
for _, batchSize := range []int{1, 3, 16} {
|
||||
size := 1 << power
|
||||
totalSize := size * batchSize
|
||||
scalars := GenerateScalars(totalSize)
|
||||
points := G2GenerateAffinePoints(totalSize)
|
||||
|
||||
var precomputeOut core.DeviceSlice
|
||||
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
|
||||
|
||||
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
|
||||
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e = out.Malloc(batchSize*p.Size(), p.Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
|
||||
cfg.PrecomputeFactor = precomputeFactor
|
||||
|
||||
e = G2Msm(scalars, precomputeOut, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], batchSize)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
precomputeOut.Free()
|
||||
|
||||
// Check with gnark-crypto
|
||||
for i := 0; i < batchSize; i++ {
|
||||
scalarsSlice := scalars[i*size : (i+1)*size]
|
||||
pointsSlice := points[i*size : (i+1)*size]
|
||||
out := outHost[i]
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -165,3 +213,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
cr.SetDevice(orig_device)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef _BLS12_381_G2MSM_H
|
||||
#define _BLS12_381_G2MSM_H
|
||||
@@ -8,7 +9,8 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_381G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
|
||||
cudaError_t bls12_381G2MSMCuda(const scalar_t* scalars, const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
|
||||
cudaError_t bls12_381G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include "../../include/types.h"
|
||||
#include <stdbool.h>
|
||||
|
||||
#ifndef _BLS12_381_MSM_H
|
||||
#define _BLS12_381_MSM_H
|
||||
@@ -8,7 +9,8 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_381MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
|
||||
cudaError_t bls12_381MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
|
||||
cudaError_t bls12_381PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
||||
@@ -9,11 +9,12 @@
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
cudaError_t bls12_381NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_381NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
|
||||
cudaError_t bls12_381ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
|
||||
cudaError_t bls12_381InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
||||
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}

func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))

outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.affine_t)(outputBasesPointer)

__ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
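The hunk above gives the Go wrapper a PrecomputeBases entry point alongside Msm. Read together with the new TestPrecomputeBase further down, the intended call order is: allocate a device buffer precomputeFactor times the size of the input bases, fill it with PrecomputeBases, then run Msm against that buffer with cfg.PrecomputeFactor set to the same factor. The sketch below condenses that flow; it is illustrative only, written as if it sat next to the tests inside the bls12381 package, and msmWithPrecompute is a name invented here rather than part of the wrapper.

package bls12381

import (
    "github.com/ingonyama-zk/icicle/wrappers/golang/core"
    cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
)

// msmWithPrecompute runs one MSM of `size` pairs over precomputed bases and
// copies the single Projective result back to the host.
func msmWithPrecompute(size int, precomputeFactor int32) (core.HostSlice[Projective], cr.CudaError) {
    scalars := GenerateScalars(size)
    points := GenerateAffinePoints(size)
    cfg := GetDefaultMSMConfig()

    // Device buffer holding precomputeFactor expanded copies of the bases.
    var precomputeOut core.DeviceSlice
    if _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()); e != cr.CudaSuccess {
        return nil, e
    }
    defer precomputeOut.Free()

    if e := PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut); e != cr.CudaSuccess {
        return nil, e
    }

    var p Projective
    var out core.DeviceSlice
    if _, e := out.Malloc(p.Size(), p.Size()); e != cr.CudaSuccess {
        return nil, e
    }
    defer out.Free()

    // The MSM consumes the precomputed bases instead of the raw points.
    cfg.PrecomputeFactor = precomputeFactor
    if e := Msm(scalars, precomputeOut, &cfg, out); e != cr.CudaSuccess {
        return nil, e
    }

    result := make(core.HostSlice[Projective], 1)
    result.CopyFromDevice(&out)
    return result, cr.CudaSuccess
}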
@@ -1,9 +1,12 @@
package bls12381

import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"

"github.com/stretchr/testify/assert"

"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"

@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.

func TestMSM(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power

@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream

e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)

cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
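With cfg.IsAsync set, the Msm call above only enqueues work on cfg.Ctx.Stream, which is why the test now copies the result back with CopyFromDeviceAsync on the same stream, frees the device buffer with FreeAsync, and reads outHost only after cr.SynchronizeStream returns. A stripped-down sketch of that ordering follows; it assumes the same package clause and imports as the sketch above, asyncMsm is a name invented here, and error handling is omitted for brevity.

// asyncMsm enqueues a single MSM on its own stream and returns the result
// only after the stream has been synchronized on the host.
func asyncMsm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice) Projective {
    cfg := GetDefaultMSMConfig()
    cfg.IsAsync = true

    stream, _ := cr.CreateStream()
    cfg.Ctx.Stream = &stream

    var p Projective
    var out core.DeviceSlice
    out.MallocAsync(p.Size(), p.Size(), stream) // allocation is ordered on the stream

    Msm(scalars, points, &cfg, out) // returns as soon as the work is enqueued

    outHost := make(core.HostSlice[Projective], 1)
    outHost.CopyFromDeviceAsync(&out, stream) // copy runs after the MSM on the same stream
    out.FreeAsync(stream)

    cr.SynchronizeStream(&stream) // outHost is only safe to read after this point
    return outHost[0]
}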
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
}
}

func TestPrecomputeBase(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := GenerateAffinePoints(totalSize)

var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")

e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")

var p Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")

cfg.PrecomputeFactor = precomputeFactor

e = Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()

// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
}
}
}
}

func TestMSMSkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
}

func TestMSMMultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}

for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := GenerateAffinePoints(size)

stream, _ := cr.CreateStream()
var p Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream

e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)

cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}
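TestMSMMultiDevice above also documents how the wrapper is meant to be driven on a multi-GPU host: cr.RunOnDevice pins the spawned goroutine to device i, so every per-device resource (config, stream, device slices) has to be created inside that closure, and the original device is restored with cr.SetDevice once all goroutines have finished. The helper below is a distilled sketch of that pattern under the same package and imports as the test file; runOnAllDevices is a name invented here.

// runOnAllDevices runs work(i) once per visible GPU, each invocation pinned
// to its own device by cr.RunOnDevice.
func runOnAllDevices(work func(deviceId int)) {
    numDevices, _ := cr.GetDeviceCount()
    wg := sync.WaitGroup{}
    for i := 0; i < numDevices; i++ {
        i := i // per-iteration copy for the closure (needed before Go 1.22)
        wg.Add(1)
        cr.RunOnDevice(i, func(args ...any) {
            defer wg.Done()
            // Streams, device slices and configs created in here are tied to
            // the device this goroutine is bound to.
            work(i)
        })
    }
    wg.Wait()
}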
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo

var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}

func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)

__ret := C.bls12_381ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

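ECNtt above mirrors Ntt but moves Projective points through the new bls12_381ECNTTCuda entry point; note that the size handed to CUDA is points.Len() divided by cfg.BatchSize, so batched inputs are packed into a single slice. A minimal forward-transform sketch, modeled on TestECNtt below and assuming the NTT domain has already been initialized through InitDomain (domain setup is omitted, and ecNttForward is a name invented here):

// ecNttForward runs a forward EC-NTT of length 1<<logSize over host-resident
// projective points and returns the transformed points on the host.
func ecNttForward(logSize int) (core.HostSlice[Projective], core.IcicleError) {
    size := 1 << logSize
    points := GenerateProjectivePoints(size)
    input := core.HostSliceFromElements[Projective](points[:size])

    cfg := GetDefaultNttConfig()
    cfg.Ordering = core.KNN        // natural order in, natural order out
    cfg.NttAlgorithm = core.Radix2 // the algorithm TestECNtt pins for EC-NTT

    output := make(core.HostSlice[Projective], size)
    err := ECNtt(input, core.KForward, &cfg, output)
    return output, err
}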
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}

func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)

for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size

pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2

output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}

func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}

func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}

func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}

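ToMontgomery and FromMontgomery now call CheckDevice() before converting, so the *core.DeviceSlice they receive must live on the GPU the calling goroutine is currently bound to. A small round-trip sketch (illustrative; montgomeryRoundTrip is a name invented here, and the slice is assumed to already hold scalars on the active device):

// montgomeryRoundTrip converts device-resident scalars into Montgomery form
// and straight back, stopping at the first CUDA error.
func montgomeryRoundTrip(scalars *core.DeviceSlice) cr.CudaError {
    // Both conversions run CheckDevice() first, so a slice allocated on a
    // different GPU than the active one is rejected before any kernel launch.
    if e := ToMontgomery(scalars); e != cr.CudaSuccess {
        return e
    }
    return FromMontgomery(scalars)
}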
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t

if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}

if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}

if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}

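VecOp gets the same CheckDevice guards for its three operands. The sketch below shows the simplest host-only call shape; note that core.DefaultVecOpsConfig and the core.Add op constant do not appear in this diff and are assumed names that may differ in the actual core package.

// addVectors element-wise adds two host-resident scalar vectors via VecOp.
func addVectors(size int) core.HostSlice[ScalarField] {
    a := GenerateScalars(size)
    b := GenerateScalars(size)
    out := make(core.HostSlice[ScalarField], size)

    cfg := core.DefaultVecOpsConfig() // assumed helper, not shown in this diff
    VecOp(a, b, out, cfg, core.Add)   // core.Add is likewise an assumed constant
    return out
}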
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
|
||||
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
cPoints := (*C.g2_affine_t)(pointsPointer)
|
||||
|
||||
cPointsLen := (C.int)(points.Len())
|
||||
cPrecomputeFactor := (C.int)(precomputeFactor)
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
|
||||
|
||||
__ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bn254
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -136,6 +142,48 @@ func TestMSMG2Batch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecomputeBaseG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
const precomputeFactor = 8
|
||||
for _, power := range []int{10, 16} {
|
||||
for _, batchSize := range []int{1, 3, 16} {
|
||||
size := 1 << power
|
||||
totalSize := size * batchSize
|
||||
scalars := GenerateScalars(totalSize)
|
||||
points := G2GenerateAffinePoints(totalSize)
|
||||
|
||||
var precomputeOut core.DeviceSlice
|
||||
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
|
||||
|
||||
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
|
||||
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e = out.Malloc(batchSize*p.Size(), p.Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
|
||||
cfg.PrecomputeFactor = precomputeFactor
|
||||
|
||||
e = G2Msm(scalars, precomputeOut, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], batchSize)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
precomputeOut.Free()
|
||||
|
||||
// Check with gnark-crypto
|
||||
for i := 0; i < batchSize; i++ {
|
||||
scalarsSlice := scalars[i*size : (i+1)*size]
|
||||
pointsSlice := points[i*size : (i+1)*size]
|
||||
out := outHost[i]
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -165,3 +213,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
cr.SetDevice(orig_device)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BN254_G2MSM_H
#define _BN254_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif

cudaError_t bn254G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bn254G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bn254G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);

#ifdef __cplusplus
}

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BN254_MSM_H
#define _BN254_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif

cudaError_t bn254MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bn254MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bn254PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);

#ifdef __cplusplus
}

@@ -9,11 +9,12 @@
extern "C" {
#endif

cudaError_t bn254NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bn254NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bn254ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bn254InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);

#ifdef __cplusplus
}
#endif

#endif
#endif
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
|
||||
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
cPoints := (*C.affine_t)(pointsPointer)
|
||||
|
||||
cPointsLen := (C.int)(points.Len())
|
||||
cPrecomputeFactor := (C.int)(precomputeFactor)
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.affine_t)(outputBasesPointer)
|
||||
|
||||
__ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package bn254
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
|
||||
|
||||
func TestMSM(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecomputeBase(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
const precomputeFactor = 8
|
||||
for _, power := range []int{10, 16} {
|
||||
for _, batchSize := range []int{1, 3, 16} {
|
||||
size := 1 << power
|
||||
totalSize := size * batchSize
|
||||
scalars := GenerateScalars(totalSize)
|
||||
points := GenerateAffinePoints(totalSize)
|
||||
|
||||
var precomputeOut core.DeviceSlice
|
||||
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
|
||||
|
||||
e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
|
||||
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e = out.Malloc(batchSize*p.Size(), p.Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
|
||||
cfg.PrecomputeFactor = precomputeFactor
|
||||
|
||||
e = Msm(scalars, precomputeOut, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], batchSize)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
precomputeOut.Free()
|
||||
|
||||
// Check with gnark-crypto
|
||||
for i := 0; i < batchSize; i++ {
|
||||
scalarsSlice := scalars[i*size : (i+1)*size]
|
||||
pointsSlice := points[i*size : (i+1)*size]
|
||||
out := outHost[i]
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMSkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMMultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
cr.SetDevice(orig_device)
|
||||
}
|
||||
|
||||
@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
|
||||
core.NttCheck[T](points, cfg, results)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cPoints := (*C.projective_t)(pointsPointer)
|
||||
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
|
||||
cDir := (C.int)(dir)
|
||||
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
cResults := (*C.projective_t)(resultsPointer)
|
||||
|
||||
__ret := C.bn254ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return core.FromCudaError(err)
|
||||
}
|
||||
|
||||
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
|
||||
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
|
||||
|
||||
@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestECNtt(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
initDomain(largestTestSize, cfg)
|
||||
points := GenerateProjectivePoints(1 << largestTestSize)
|
||||
|
||||
for _, size := range []int{4, 5, 6, 7, 8} {
|
||||
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
|
||||
testSize := 1 << size
|
||||
|
||||
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
|
||||
cfg.Ordering = v
|
||||
cfg.NttAlgorithm = core.Radix2
|
||||
|
||||
output := make(core.HostSlice[Projective], testSize)
|
||||
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
|
||||
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestNttDeviceAsync(t *testing.T) {
|
||||
cfg := GetDefaultNttConfig()
|
||||
scalars := GenerateScalars(1 << largestTestSize)
|
||||
|
||||
@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
|
||||
}
|
||||
|
||||
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, true)
|
||||
}
|
||||
|
||||
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
|
||||
scalars.CheckDevice()
|
||||
return convertScalarsMontgomery(scalars, false)
|
||||
}
|
||||
|
||||
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
|
||||
var cA, cB, cOut *C.scalar_t
|
||||
|
||||
if a.IsOnDevice() {
|
||||
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
|
||||
aDevice := a.(core.DeviceSlice)
|
||||
aDevice.CheckDevice()
|
||||
cA = (*C.scalar_t)(aDevice.AsPointer())
|
||||
} else {
|
||||
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if b.IsOnDevice() {
|
||||
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
|
||||
bDevice := b.(core.DeviceSlice)
|
||||
bDevice.CheckDevice()
|
||||
cB = (*C.scalar_t)(bDevice.AsPointer())
|
||||
} else {
|
||||
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
if out.IsOnDevice() {
|
||||
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
|
||||
outDevice := out.(core.DeviceSlice)
|
||||
outDevice.CheckDevice()
|
||||
cOut = (*C.scalar_t)(outDevice.AsPointer())
|
||||
} else {
|
||||
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
|
||||
}
|
||||
|
||||
@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
|
||||
}
|
||||
|
||||
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertAffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
|
||||
}
|
||||
|
||||
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
|
||||
}
|
||||
|
||||
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2AffinePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
|
||||
}
|
||||
|
||||
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, true)
|
||||
}
|
||||
|
||||
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
|
||||
points.CheckDevice()
|
||||
return convertG2ProjectivePointsMontgomery(points, false)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
|
||||
}
|
||||
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
|
||||
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
|
||||
}
|
||||
cPoints := (*C.g2_affine_t)(pointsPointer)
|
||||
|
||||
cPointsLen := (C.int)(points.Len())
|
||||
cPrecomputeFactor := (C.int)(precomputeFactor)
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
|
||||
|
||||
__ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,9 +3,12 @@
|
||||
package bw6761
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -55,6 +58,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
|
||||
|
||||
func TestMSMG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -67,12 +71,14 @@ func TestMSMG2(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -109,6 +115,48 @@ func TestMSMG2Batch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecomputeBaseG2(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
const precomputeFactor = 8
|
||||
for _, power := range []int{10, 16} {
|
||||
for _, batchSize := range []int{1, 3, 16} {
|
||||
size := 1 << power
|
||||
totalSize := size * batchSize
|
||||
scalars := GenerateScalars(totalSize)
|
||||
points := G2GenerateAffinePoints(totalSize)
|
||||
|
||||
var precomputeOut core.DeviceSlice
|
||||
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
|
||||
|
||||
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
|
||||
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e = out.Malloc(batchSize*p.Size(), p.Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
|
||||
cfg.PrecomputeFactor = precomputeFactor
|
||||
|
||||
e = G2Msm(scalars, precomputeOut, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], batchSize)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
precomputeOut.Free()
|
||||
|
||||
// Check with gnark-crypto
|
||||
for i := 0; i < batchSize; i++ {
|
||||
scalarsSlice := scalars[i*size : (i+1)*size]
|
||||
pointsSlice := points[i*size : (i+1)*size]
|
||||
out := outHost[i]
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -138,3 +186,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMG2MultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := G2GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p G2Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = G2Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[G2Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
cr.SetDevice(orig_device)
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BW6_761_G2MSM_H
#define _BW6_761_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif

cudaError_t bw6_761G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bw6_761G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bw6_761G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);

#ifdef __cplusplus
}

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _BW6_761_MSM_H
#define _BW6_761_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif

cudaError_t bw6_761MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bw6_761MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bw6_761PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);

#ifdef __cplusplus
}

@@ -9,11 +9,12 @@
extern "C" {
#endif

cudaError_t bw6_761NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bw6_761NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bw6_761ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bw6_761InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);

#ifdef __cplusplus
}
#endif

#endif
#endif
@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
core.MsmCheck(scalars, points, cfg, results)
|
||||
var scalarsPointer unsafe.Pointer
|
||||
if scalars.IsOnDevice() {
|
||||
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
|
||||
scalarsDevice := scalars.(core.DeviceSlice)
|
||||
scalarsDevice.CheckDevice()
|
||||
scalarsPointer = scalarsDevice.AsPointer()
|
||||
} else {
|
||||
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
|
||||
}
|
||||
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
pointsDevice := points.(core.DeviceSlice)
|
||||
pointsDevice.CheckDevice()
|
||||
pointsPointer = pointsDevice.AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
|
||||
var resultsPointer unsafe.Pointer
|
||||
if results.IsOnDevice() {
|
||||
resultsPointer = results.(core.DeviceSlice).AsPointer()
|
||||
resultsDevice := results.(core.DeviceSlice)
|
||||
resultsDevice.CheckDevice()
|
||||
resultsPointer = resultsDevice.AsPointer()
|
||||
} else {
|
||||
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
|
||||
}
|
||||
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
|
||||
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
|
||||
|
||||
var pointsPointer unsafe.Pointer
|
||||
if points.IsOnDevice() {
|
||||
pointsPointer = points.(core.DeviceSlice).AsPointer()
|
||||
} else {
|
||||
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
|
||||
}
|
||||
cPoints := (*C.affine_t)(pointsPointer)
|
||||
|
||||
cPointsLen := (C.int)(points.Len())
|
||||
cPrecomputeFactor := (C.int)(precomputeFactor)
|
||||
cC := (C.int)(c)
|
||||
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
|
||||
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
|
||||
|
||||
outputBasesPointer := outputBases.AsPointer()
|
||||
cOutputBases := (*C.affine_t)(outputBasesPointer)
|
||||
|
||||
__ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
|
||||
err := (cr.CudaError)(__ret)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package bw6761
|
||||
|
||||
import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
|
||||
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
|
||||
|
||||
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
|
||||
|
||||
func TestMSM(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
|
||||
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrecomputeBase(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
const precomputeFactor = 8
|
||||
for _, power := range []int{10, 16} {
|
||||
for _, batchSize := range []int{1, 3, 16} {
|
||||
size := 1 << power
|
||||
totalSize := size * batchSize
|
||||
scalars := GenerateScalars(totalSize)
|
||||
points := GenerateAffinePoints(totalSize)
|
||||
|
||||
var precomputeOut core.DeviceSlice
|
||||
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
|
||||
|
||||
e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
|
||||
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e = out.Malloc(batchSize*p.Size(), p.Size())
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
|
||||
cfg.PrecomputeFactor = precomputeFactor
|
||||
|
||||
e = Msm(scalars, precomputeOut, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], batchSize)
|
||||
outHost.CopyFromDevice(&out)
|
||||
out.Free()
|
||||
precomputeOut.Free()
|
||||
|
||||
// Check with gnark-crypto
|
||||
for i := 0; i < batchSize; i++ {
|
||||
scalarsSlice := scalars[i*size : (i+1)*size]
|
||||
pointsSlice := points[i*size : (i+1)*size]
|
||||
out := outHost[i]
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMSkewedDistribution(t *testing.T) {
|
||||
cfg := GetDefaultMSMConfig()
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
}
|
||||
|
||||
func TestMSMMultiDevice(t *testing.T) {
|
||||
numDevices, _ := cr.GetDeviceCount()
|
||||
fmt.Println("There are ", numDevices, " devices available")
|
||||
orig_device, _ := cr.GetDevice()
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
for i := 0; i < numDevices; i++ {
|
||||
wg.Add(1)
|
||||
cr.RunOnDevice(i, func(args ...any) {
|
||||
defer wg.Done()
|
||||
cfg := GetDefaultMSMConfig()
|
||||
cfg.IsAsync = true
|
||||
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
|
||||
size := 1 << power
|
||||
scalars := GenerateScalars(size)
|
||||
points := GenerateAffinePoints(size)
|
||||
|
||||
stream, _ := cr.CreateStream()
|
||||
var p Projective
|
||||
var out core.DeviceSlice
|
||||
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
|
||||
cfg.Ctx.Stream = &stream
|
||||
|
||||
e = Msm(scalars, points, &cfg, out)
|
||||
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
|
||||
outHost := make(core.HostSlice[Projective], 1)
|
||||
outHost.CopyFromDeviceAsync(&out, stream)
|
||||
out.FreeAsync(stream)
|
||||
|
||||
cr.SynchronizeStream(&stream)
|
||||
// Check with gnark-crypto
|
||||
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
|
||||
}
|
||||
})
|
||||
}
|
||||
wg.Wait()
|
||||
cr.SetDevice(orig_device)
|
||||
}

@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo

var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}

func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)

__ret := C.bw6_761ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
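The new ECNtt wrapper exposes the bw6_761ECNTTCuda binding with the same calling convention as Ntt, but over projective points. A minimal host-in, host-out sketch using only names that appear in this diff, and assuming the NTT domain has already been initialized via InitDomain as the tests do:

// Forward EC-NTT over 2^10 projective points on the host.
cfg := GetDefaultNttConfig()
points := GenerateProjectivePoints(1 << 10)
input := core.HostSliceFromElements[Projective](points)
output := make(core.HostSlice[Projective], 1<<10)

e := ECNtt(input, core.KForward, &cfg, output)
// e.IcicleErrorCode is core.IcicleErrorCode(0) on success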

@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}

func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)

for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size

pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2

output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}

func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}

func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}

func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}
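ToMontgomery and FromMontgomery now call CheckDevice on the slice before converting it in place on the GPU. A minimal round-trip sketch; the CopyToDevice helper is assumed from the surrounding wrapper (it does not appear in this diff), while CopyFromDevice and Free are used exactly as in the tests above:

scalars := GenerateScalars(1 << 10)
var scalarsDevice core.DeviceSlice
scalars.CopyToDevice(&scalarsDevice, true) // assumed helper: allocate and copy to the active device

// Convert in place on the device, then undo it.
ToMontgomery(&scalarsDevice)   // into Montgomery form
FromMontgomery(&scalarsDevice) // back to canonical form

roundTrip := make(core.HostSlice[ScalarField], 1<<10)
roundTrip.CopyFromDevice(&scalarsDevice)
scalarsDevice.Free()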

@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t

if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}

if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}

if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}
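VecOp now calls CheckDevice on every device-resident operand before taking its raw pointer; host slices are untouched. A minimal host-only sketch of calling it; the default-config helper core.DefaultVecOpsConfig and the core.Add op constant are assumptions taken from the rest of the wrapper, not from this hunk:

a := GenerateScalars(1 << 10)
b := GenerateScalars(1 << 10)
out := make(core.HostSlice[ScalarField], 1<<10)

// Element-wise a + b into out, all three slices on the host.
cfg := core.DefaultVecOpsConfig() // assumed helper
VecOp(a, b, out, cfg, core.Add)   // assumed op constant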

@@ -150,10 +150,12 @@ func convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points *core.DeviceSlice
}

func {{if .IsG2}}G2{{end}}AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, true)
}

func {{if .IsG2}}G2{{end}}AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, false)
}

@@ -169,10 +171,12 @@ func convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points *core.DeviceS
}

func {{if .IsG2}}G2{{end}}ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, true)
}

func {{if .IsG2}}G2{{end}}ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, false)
}
{{end}}
@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _{{toUpper .Curve}}_G2MSM_H
#define _{{toUpper .Curve}}_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif

cudaError_t {{.Curve}}G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t {{.Curve}}G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t {{.Curve}}G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);

#ifdef __cplusplus
}

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>

#ifndef _{{toUpper .Curve}}_MSM_H
#define _{{toUpper .Curve}}_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif

cudaError_t {{.Curve}}MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t {{.Curve}}MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t {{.Curve}}PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);

#ifdef __cplusplus
}

@@ -9,11 +9,12 @@
extern "C" {
#endif

cudaError_t {{.Curve}}NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t {{.Curve}}NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t {{.Curve}}ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t {{.Curve}}InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);

#ifdef __cplusplus
}
#endif

#endif
#endif
@@ -22,7 +22,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -30,7 +32,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0])
}
@@ -38,7 +42,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[{{if .IsG2}}G2{{end}}Projective])[0])
}
@@ -51,3 +57,28 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
err := (cr.CudaError)(__ret)
return err
}

func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0])
}
cPoints := (*C.{{if .IsG2}}g2_{{end}}affine_t)(pointsPointer)

cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))

outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer)

__ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
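The new PrecomputeBases entry point mirrors the TestPrecomputeBase test further down: expand the base points into a device-side table, then point the MSM at that table via cfg.PrecomputeFactor. A minimal sketch for the non-G2 variant of a concrete curve package, using only calls that appear in this diff; error returns are elided:

const precomputeFactor = 8
cfg := GetDefaultMSMConfig()
scalars := GenerateScalars(1 << 10)
points := GenerateAffinePoints(1 << 10)

// Device buffer holding precomputeFactor copies of every base point.
var precomputeOut core.DeviceSlice
precomputeOut.Malloc(points[0].Size()*points.Len()*precomputeFactor, points[0].Size())
PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)

// The MSM reads the expanded bases instead of the original points.
var p Projective
var out core.DeviceSlice
out.Malloc(p.Size(), p.Size())
cfg.PrecomputeFactor = precomputeFactor
Msm(scalars, precomputeOut, &cfg, out)

result := make(core.HostSlice[Projective], 1)
result.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()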

@@ -5,9 +5,12 @@
package {{.PackageName}}

import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"

"github.com/stretchr/testify/assert"

"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"

@@ -102,6 +105,7 @@ func testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars core.HostSlice[Scala

func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power

@@ -114,12 +118,14 @@ func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream

e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)

cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
}
@@ -156,6 +162,48 @@ func TestMSM{{if .IsG2}}G2{{end}}Batch(t *testing.T) {
}
}

func TestPrecomputeBase{{if .IsG2}}G2{{end}}(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(totalSize)

var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")

e = {{if .IsG2}}G2{{end}}PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")

var p {{if .IsG2}}G2{{end}}Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")

cfg.PrecomputeFactor = precomputeFactor

e = {{if .IsG2}}G2{{end}}Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()

// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalarsSlice, pointsSlice, out))
}
}
}
}

func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -185,3 +233,43 @@ func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
}
}

func TestMSM{{if .IsG2}}G2{{end}}MultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}

for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(size)

stream, _ := cr.CreateStream()
var p {{if .IsG2}}G2{{end}}Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream

e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)

cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo

var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}

func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)

var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))

var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)

__ret := C.{{.Curve}}ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}

func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}

func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)

for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size

pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2

output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}

func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

@@ -33,9 +33,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}

func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}

func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}{{- end}}
@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t

if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}

if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}

if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}

@@ -14,7 +14,7 @@ exclude = [
]

[workspace.package]
version = "1.6.0"
version = "1.9.0"
edition = "2021"
authors = [ "Ingonyama" ]
homepage = "https://www.ingonyama.com"
Some files were not shown because too many files have changed in this diff.