Compare commits

...

19 Commits

Author SHA1 Message Date
Yuval Shekel
919ff42f49 fix: NTT input is const 2024-03-24 16:26:10 +02:00
release-bot
a1ff989740 Bump rust crates' version
icicle-bls12-377@1.9.0
icicle-bls12-381@1.9.0
icicle-bn254@1.9.0
icicle-bw6-761@1.9.0
icicle-core@1.9.0
icicle-cuda-runtime@1.9.0
icicle-grumpkin@1.9.0

Generated by cargo-workspaces
2024-03-21 07:11:47 +00:00
Otsar
1f2144a57c Removed "machines using ICICLE" static badge (#442) 2024-03-21 09:04:19 +02:00
Jeremy Felder
db4c07dcaf Golang bindings for ECNTT (#433) 2024-03-21 09:04:00 +02:00
ChickenLover
d4f39efea3 Add Keccak hash function (#435)
This PR adds support for Keccak-256 and Keccak-512. It only adds them in
C++. There is no way to add Rust or Golang wrappers right now, as that
requires having an `icicle-common` crate / module.
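
As a hedged illustration of the new C++ entry points (not part of the PR), here is a
minimal host-side sketch that mirrors the test file added in this PR; the include path
and buffer sizes are assumptions:

```cpp
// Minimal host-side sketch of the new Keccak API, following the pattern of the test
// file in this PR (it includes keccak.cu directly, since only the templated
// declaration lives in the header). Paths are assumptions.
#include <cstdint>
#include <cstdlib>
#include <cuda_runtime.h>
#include "keccak.cu" // brings in KeccakConfig, default_keccak_config and Keccak256

using namespace keccak;

int main()
{
  const int number_of_blocks = 1024; // one GPU thread hashes one block
  const int input_block_size = 136;  // bytes per preimage
  uint8_t* input = (uint8_t*)malloc(number_of_blocks * input_block_size);
  uint8_t* output = (uint8_t*)malloc(number_of_blocks * 32); // 32 bytes per Keccak-256 digest
  for (int i = 0; i < number_of_blocks * input_block_size; i++) input[i] = (uint8_t)i;

  KeccakConfig config = default_keccak_config(); // host input/output, blocking call
  cudaError_t err = Keccak256(input, input_block_size, number_of_blocks, output, config);

  free(input);
  free(output);
  return err == cudaSuccess ? 0 : 1;
}
```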
2024-03-20 22:30:19 +02:00
Yuval Shekel
7293058246 fix: (golang) MSM multi device test reset to original device after test is done 2024-03-20 16:27:11 +02:00
Yuval Shekel
03136f1074 fix: (golang) add missing NttAlgorithm field in NTTConfig 2024-03-20 16:27:11 +02:00
Yuval Shekel
3ef0d0c66e MSM scalars and points params are const
- This is required to be able to compute MSM on polynomial coefficients that are only accessible as const.
2024-03-20 16:27:11 +02:00
Stas
0dff1f9302 Use multi-threaded CUDA compilation to speed up compilation (#439)
## Describe the changes

Speed up CUDA C++ compile time using multi-threaded compilation
(the `--split-compile` flag).
Tests on an 8-core machine show a ~2x speedup.

## Linked Issues

Compiling C++ takes a long time.
2024-03-18 16:40:30 -04:00
ChickenLover
0d806d96ca tidy (#437) 2024-03-19 00:59:10 +07:00
release-bot
b6b5011a47 Bump rust crates' version
icicle-bls12-377@1.8.0
icicle-bls12-381@1.8.0
icicle-bn254@1.8.0
icicle-bw6-761@1.8.0
icicle-core@1.8.0
icicle-cuda-runtime@1.8.0
icicle-grumpkin@1.8.0

Generated by cargo-workspaces
2024-03-13 21:38:17 +00:00
DmytroTym
7ac463c3d9 MSM pre-computation (#427)
## Brief description

This PR adds pre-computation to the MSM; for some theory, see
[this](https://youtu.be/KAWlySN7Hm8?si=XeR-htjbnK_ySbUo&t=1734) timecode
of Niall Emmart's talk.
In terms of public APIs, one method is added. It performs the
pre-computation on-device and leaves the resulting data on-device as well. No
extra structures are added; only `precompute_factor` from `MSMConfig` is
now activated.

## Performance

While performance gains are for now often limited by our inflexibility
in the choice of `c` (for example, very large MSMs get basically no speedup
from pre-compute because currently `c` cannot be larger than 16), a
number of MSM sizes still see a noticeable improvement:

| Pre-computation factor | bn254 size `2^20` MSM, ms. | bn254 size `2^12` MSM, size `2^10` batch, ms. | bls12-381 size `2^20` MSM, ms. | bls12-381 size `2^12` MSM, size `2^10` batch, ms. |
| ------------- | ------------- | ------------- | ------------- | ------------- |
| 1 | 14.1 | 82.8 | 25.5 | 136.7 |
| 2 | 11.8 | 76.6 | 20.3 | 123.8 |
| 4 | 10.9 | 73.8 | 18.1 | 117.8 |
| 8 | 10.6 | 73.7 | 17.2 | 116.0 |

Here, for example, a pre-computation factor of 4 means that alongside each
original base point we pre-compute and pass into the MSM 3 of its
"shifted" versions. A pre-computation factor of 1 means no pre-computation.
The GPU used for the benchmarks is a 3090 Ti.
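
As a hedged sketch (not code from this PR) of how the new API is intended to be
combined with MSM, based on the templated declarations shown in the msm.cuh diff
further down; the helper name, include path and buffer handling are assumptions, and
error checking is omitted:

```cpp
// Sketch: pre-compute extended bases once, then run MSM with precompute_factor set.
#include <cuda_runtime.h>
#include "appUtils/msm/msm.cuh" // include path relative to icicle sources (assumption)

using curve_config::scalar_t;
using curve_config::affine_t;
using curve_config::projective_t;

cudaError_t msm_with_precompute(scalar_t* scalars, affine_t* bases, int msm_size, projective_t* result)
{
  const int precompute_factor = 4; // 3 extra "shifted" copies per base
  device_context::DeviceContext ctx = device_context::get_default_device_context();

  // The extended bases live on device: msm_size * precompute_factor affine points.
  affine_t* extended_bases;
  cudaMallocAsync(&extended_bases, sizeof(affine_t) * msm_size * precompute_factor, ctx.stream);

  // One-time pre-computation. `_c` is currently unused, but should match the `c` used
  // in MSMConfig once pre-compute becomes c-dependent.
  msm::PrecomputeMSMBases<affine_t, projective_t>(
    bases, msm_size, precompute_factor, /*_c=*/16, /*are_bases_on_device=*/false, ctx, extended_bases);

  // Default config as used by the extern "C" wrapper, with pre-compute enabled.
  msm::MSMConfig config = msm::DefaultMSMConfig<affine_t>();
  config.precompute_factor = precompute_factor;
  config.are_points_on_device = true; // the extended bases were just written on device

  cudaError_t err = msm::MSM<scalar_t, affine_t, projective_t>(scalars, extended_bases, msm_size, config, result);
  cudaFreeAsync(extended_bases, ctx.stream);
  return err;
}
```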

## TODOs and open questions

- Golang APIs are missing;
- I mentioned that to utilise pre-compute to its full potential we need
an arbitrary choice of `c`. One issue with this is that pre-compute will
become dependent on `c`. For now this is not the case, as `c` can only be
a power of 2 and powers of 2 can always share the same pre-computation.
So apparently we need to make `c` a parameter of the precompute function
to future-proof it against a breaking change. This is pretty unnatural and
counterintuitive, as `c` is typically chosen at runtime after pre-compute
is done, but I don't really see another way; please let me know if you do.
Update: `c` is added to the pre-compute function; for now it is unused, and
it is documented how it will change in the future.

Resolves https://github.com/ingonyama-zk/icicle/issues/147
Co-authored with @ChickenLover

---------

Co-authored-by: ChickenLover <romangg81@gmail.com>
Co-authored-by: nonam3e <timur@ingonyama.com>
Co-authored-by: nonam3e <71525212+nonam3e@users.noreply.github.com>
Co-authored-by: LeonHibnik <leon@ingonyama.com>
2024-03-13 23:25:16 +02:00
HadarIngonyama
287f53ff16 NTT columns batch (#424)
This PR adds the columns batch feature, enabling batch NTT computation
to be performed directly on the columns of a matrix without having to
transpose it beforehand, as requested in issue #264.

Some small fixes to the reordering kernels were also added, and some
unnecessary parameters were removed from function interfaces.
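
As a hedged illustration of the layout this feature targets, the indexing below mirrors
the expression `columns_batch ? next_element * batch_size + batch_idx : next_element +
size * batch_idx` used in the updated reorder kernels; the helper names are
hypothetical:

```cpp
// Plain C++ sketch of the two batch layouts the NTT now accepts. In the regular
// layout each batch element is stored contiguously; in the columns batch layout each
// NTT operates on one column of a row-major (size x batch_size) matrix.
#include <cstddef>

// Element i of batch b in the regular (rows) layout: batches stored back to back.
inline size_t rows_batch_index(size_t i, size_t b, size_t size, size_t /*batch_size*/)
{
  return b * size + i;
}

// Element i of batch b in the columns batch layout: batches interleaved column-wise.
inline size_t columns_batch_index(size_t i, size_t b, size_t /*size*/, size_t batch_size)
{
  return i * batch_size + b;
}
```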

---------

Co-authored-by: DmytroTym <dmytrotym1@gmail.com>
2024-03-13 18:46:47 +02:00
Jeremy Felder
89082fb561 FEAT: MultiGPU for golang bindings (#417)
## Describe the changes

This PR adds multi-GPU support to the Golang bindings.

The main changes are to DeviceSlice, which now includes a `deviceId`
attribute specifying which device the underlying data resides on, and
checks for the correct deviceId and current device when using DeviceSlices
in any operation.

In Go, most concurrency is handled via goroutines (described as
lightweight threads; in reality, more of a threadpool manager).
However, there is no guarantee that a goroutine stays on a specific host
thread. Therefore, a function `RunOnDevice` was added to the
cuda_runtime package which locks a goroutine to a specific host
thread, sets the current GPU device, runs a provided function, and unlocks
the goroutine from the host thread after the provided function finishes.
While the goroutine is locked to the host thread, the Go runtime will
not assign other goroutines to that host thread.
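
As a hedged CUDA sketch (an illustration of the constraint, not the Go binding itself):
the current device is a property of the calling host thread, which is why a goroutine
must stay pinned to one thread while it touches a DeviceSlice. The function name is
hypothetical:

```cpp
// The active CUDA device is per host thread, so any operation on data that lives on
// device `device_id` must run on a thread whose current device is `device_id`.
#include <cuda_runtime.h>

cudaError_t touch_device(int device_id)
{
  cudaError_t err = cudaSetDevice(device_id); // binds device_id to this host thread only
  if (err != cudaSuccess) return err;

  void* buf = nullptr;
  err = cudaMalloc(&buf, 1024); // allocated on this thread's current device
  if (err != cudaSuccess) return err;
  return cudaFree(buf);
}
```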
2024-03-13 16:19:45 +02:00
hhh_QC
08ec0b1ff6 update go install source in Dockerfile (#428) 2024-03-10 10:47:08 +02:00
Jeremy Felder
fa219d9c95 Fix release flow with deploy key and caching (#425)
## Describe the changes

This PR fixes the release flow action
2024-03-10 08:57:35 +02:00
DmytroTym
0e84fb4b76 feat: add warmup for CudaStream (#422)
## Describe the changes

Add a non-blocking `warmup` function to `CudaStream` 

> When you run the benchmark (e.g. the MSM example you have), the first
> instance is always slow, with a constant overhead of 200~300 ms of CUDA
> stream warmup. I want to get rid of that in my application by warming
> the stream up in parallel while my host does something else.
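
As a hedged CUDA sketch of the warmup idea (an illustration of the concept, not
necessarily how the Rust `CudaStream::warmup` is implemented): enqueue a trivial piece
of asynchronous work so the one-time setup cost is paid off the critical path. The
function name is hypothetical:

```cpp
// Issue a tiny stream-ordered allocation and free; this forces lazy driver/stream
// initialization without blocking the calling thread.
#include <cuda_runtime.h>

cudaError_t warmup_stream(cudaStream_t stream)
{
  void* scratch = nullptr;
  cudaError_t err = cudaMallocAsync(&scratch, 1, stream); // triggers lazy initialization
  if (err != cudaSuccess) return err;
  return cudaFreeAsync(scratch, stream); // enqueued on the stream; returns immediately
}
```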
2024-03-07 19:11:34 +02:00
Alex Xiong
d8059a2a4e Merge pull request #1 from ingonyama-zk/feat/warmup
Warmup function added
2024-03-07 18:18:18 +08:00
Jeremy Felder
1abd2ef9c9 Bump rust crates' version
icicle-bls12-377@1.7.0
icicle-bls12-381@1.7.0
icicle-bn254@1.7.0
icicle-bw6-761@1.7.0
icicle-core@1.7.0
icicle-cuda-runtime@1.7.0
icicle-grumpkin@1.7.0

Generated by cargo-workspaces
2024-03-06 22:05:10 +02:00
109 changed files with 2967 additions and 511 deletions

View File

@@ -50,7 +50,7 @@ jobs:
- name: Build
working-directory: ./wrappers/golang
if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true'
run: ./build.sh ${{ matrix.curve }} ON # builds a single curve with G2 enabled
run: ./build.sh ${{ matrix.curve }} ON ON # builds a single curve with G2 and ECNTT enabled
- name: Upload ICICLE lib artifacts
uses: actions/upload-artifact@v4
if: needs.check-changed-files.outputs.golang == 'true' || needs.check-changed-files.outputs.cpp_cuda == 'true'

View File

@@ -20,11 +20,27 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v4
with:
ssh-key: ${{ secrets.DEPLOY_KEY }}
- name: Setup Cache
id: cache
uses: actions/cache@v4
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
key: ${{ runner.os }}-cargo-${{ hashFiles('~/.cargo/bin/cargo-workspaces') }}
- name: Install cargo-workspaces
if: steps.cache.outputs.cache-hit != 'true'
run: cargo install cargo-workspaces
- name: Bump rust crate versions, commit, and tag
working-directory: wrappers/rust
# https://github.com/pksunkara/cargo-workspaces?tab=readme-ov-file#version
run: |
cargo install cargo-workspaces
git config user.name release-bot
git config user.email release-bot@ingonyama.com
cargo workspaces version ${{ inputs.releaseType }} -y --no-individual-tags -m "Bump rust crates' version"
- name: Create draft release
env:

View File

@@ -15,7 +15,7 @@ ENV PATH="/root/.cargo/bin:${PATH}"
# Install Golang
ENV GOLANG_VERSION 1.21.1
RUN curl -L https://golang.org/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local
RUN curl -L https://go.dev/dl/go${GOLANG_VERSION}.linux-amd64.tar.gz | tar -xz -C /usr/local
ENV PATH="/usr/local/go/bin:${PATH}"
# Set the working directory in the container

View File

@@ -11,8 +11,6 @@
</a>
<a href="https://twitter.com/intent/follow?screen_name=Ingo_zk">
<img src="https://img.shields.io/twitter/follow/Ingo_zk?style=social&logo=twitter" alt="Follow us on Twitter">
</a>
<img src="https://img.shields.io/badge/Machines%20running%20ICICLE-544-lightblue" alt="Machines running ICICLE">
<a href="https://github.com/ingonyama-zk/icicle/releases">
<img src="https://img.shields.io/github/v/release/ingonyama-zk/icicle" alt="GitHub Release">
</a>

View File

@@ -2,7 +2,7 @@
[![GitHub Release](https://img.shields.io/github/v/release/ingonyama-zk/icicle)](https://github.com/ingonyama-zk/icicle/releases)
![Static Badge](https://img.shields.io/badge/Machines%20running%20ICICLE-544-blue)

View File

@@ -60,7 +60,13 @@ else()
endif()
project(icicle LANGUAGES CUDA CXX)
# Check CUDA version and, if possible, enable multi-threaded compilation
if(CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL "12.2")
message(STATUS "Using multi-threaded CUDA compilation.")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --split-compile 0")
else()
message(STATUS "Can't use multi-threaded CUDA compilation.")
endif()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr")
set(CMAKE_CUDA_FLAGS_RELEASE "")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -g -G -O0")
@@ -90,6 +96,10 @@ if (G2_DEFINED STREQUAL "ON")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DG2_DEFINED=ON")
endif ()
if (ECNTT_DEFINED STREQUAL "ON")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DECNTT_DEFINED=ON")
endif ()
option(BUILD_TESTS "Build tests" OFF)
if (NOT BUILD_TESTS)
@@ -104,6 +114,9 @@ if (NOT BUILD_TESTS)
if (NOT CURVE IN_LIST SUPPORTED_CURVES_WITHOUT_NTT)
list(APPEND ICICLE_SOURCES appUtils/ntt/ntt.cu)
list(APPEND ICICLE_SOURCES appUtils/ntt/kernel_ntt.cu)
if(ECNTT_DEFINED STREQUAL "ON")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DECNTT_DEFINED=ON")
endif()
endif()
add_library(

View File

@@ -0,0 +1,2 @@
test_keccak: test.cu keccak.cu
nvcc -o test_keccak -I. -I../.. test.cu

View File

@@ -0,0 +1,275 @@
#include "keccak.cuh"
namespace keccak {
#define ROTL64(x, y) (((x) << (y)) | ((x) >> (64 - (y))))
#define TH_ELT(t, c0, c1, c2, c3, c4, d0, d1, d2, d3, d4) \
{ \
t = ROTL64((d0 ^ d1 ^ d2 ^ d3 ^ d4), 1) ^ (c0 ^ c1 ^ c2 ^ c3 ^ c4); \
}
#define THETA( \
s00, s01, s02, s03, s04, s10, s11, s12, s13, s14, s20, s21, s22, s23, s24, s30, s31, s32, s33, s34, s40, s41, s42, \
s43, s44) \
{ \
TH_ELT(t0, s40, s41, s42, s43, s44, s10, s11, s12, s13, s14); \
TH_ELT(t1, s00, s01, s02, s03, s04, s20, s21, s22, s23, s24); \
TH_ELT(t2, s10, s11, s12, s13, s14, s30, s31, s32, s33, s34); \
TH_ELT(t3, s20, s21, s22, s23, s24, s40, s41, s42, s43, s44); \
TH_ELT(t4, s30, s31, s32, s33, s34, s00, s01, s02, s03, s04); \
s00 ^= t0; \
s01 ^= t0; \
s02 ^= t0; \
s03 ^= t0; \
s04 ^= t0; \
\
s10 ^= t1; \
s11 ^= t1; \
s12 ^= t1; \
s13 ^= t1; \
s14 ^= t1; \
\
s20 ^= t2; \
s21 ^= t2; \
s22 ^= t2; \
s23 ^= t2; \
s24 ^= t2; \
\
s30 ^= t3; \
s31 ^= t3; \
s32 ^= t3; \
s33 ^= t3; \
s34 ^= t3; \
\
s40 ^= t4; \
s41 ^= t4; \
s42 ^= t4; \
s43 ^= t4; \
s44 ^= t4; \
}
#define RHOPI( \
s00, s01, s02, s03, s04, s10, s11, s12, s13, s14, s20, s21, s22, s23, s24, s30, s31, s32, s33, s34, s40, s41, s42, \
s43, s44) \
{ \
t0 = ROTL64(s10, (uint64_t)1); \
s10 = ROTL64(s11, (uint64_t)44); \
s11 = ROTL64(s41, (uint64_t)20); \
s41 = ROTL64(s24, (uint64_t)61); \
s24 = ROTL64(s42, (uint64_t)39); \
s42 = ROTL64(s04, (uint64_t)18); \
s04 = ROTL64(s20, (uint64_t)62); \
s20 = ROTL64(s22, (uint64_t)43); \
s22 = ROTL64(s32, (uint64_t)25); \
s32 = ROTL64(s43, (uint64_t)8); \
s43 = ROTL64(s34, (uint64_t)56); \
s34 = ROTL64(s03, (uint64_t)41); \
s03 = ROTL64(s40, (uint64_t)27); \
s40 = ROTL64(s44, (uint64_t)14); \
s44 = ROTL64(s14, (uint64_t)2); \
s14 = ROTL64(s31, (uint64_t)55); \
s31 = ROTL64(s13, (uint64_t)45); \
s13 = ROTL64(s01, (uint64_t)36); \
s01 = ROTL64(s30, (uint64_t)28); \
s30 = ROTL64(s33, (uint64_t)21); \
s33 = ROTL64(s23, (uint64_t)15); \
s23 = ROTL64(s12, (uint64_t)10); \
s12 = ROTL64(s21, (uint64_t)6); \
s21 = ROTL64(s02, (uint64_t)3); \
s02 = t0; \
}
#define KHI( \
s00, s01, s02, s03, s04, s10, s11, s12, s13, s14, s20, s21, s22, s23, s24, s30, s31, s32, s33, s34, s40, s41, s42, \
s43, s44) \
{ \
t0 = s00 ^ (~s10 & s20); \
t1 = s10 ^ (~s20 & s30); \
t2 = s20 ^ (~s30 & s40); \
t3 = s30 ^ (~s40 & s00); \
t4 = s40 ^ (~s00 & s10); \
s00 = t0; \
s10 = t1; \
s20 = t2; \
s30 = t3; \
s40 = t4; \
\
t0 = s01 ^ (~s11 & s21); \
t1 = s11 ^ (~s21 & s31); \
t2 = s21 ^ (~s31 & s41); \
t3 = s31 ^ (~s41 & s01); \
t4 = s41 ^ (~s01 & s11); \
s01 = t0; \
s11 = t1; \
s21 = t2; \
s31 = t3; \
s41 = t4; \
\
t0 = s02 ^ (~s12 & s22); \
t1 = s12 ^ (~s22 & s32); \
t2 = s22 ^ (~s32 & s42); \
t3 = s32 ^ (~s42 & s02); \
t4 = s42 ^ (~s02 & s12); \
s02 = t0; \
s12 = t1; \
s22 = t2; \
s32 = t3; \
s42 = t4; \
\
t0 = s03 ^ (~s13 & s23); \
t1 = s13 ^ (~s23 & s33); \
t2 = s23 ^ (~s33 & s43); \
t3 = s33 ^ (~s43 & s03); \
t4 = s43 ^ (~s03 & s13); \
s03 = t0; \
s13 = t1; \
s23 = t2; \
s33 = t3; \
s43 = t4; \
\
t0 = s04 ^ (~s14 & s24); \
t1 = s14 ^ (~s24 & s34); \
t2 = s24 ^ (~s34 & s44); \
t3 = s34 ^ (~s44 & s04); \
t4 = s44 ^ (~s04 & s14); \
s04 = t0; \
s14 = t1; \
s24 = t2; \
s34 = t3; \
s44 = t4; \
}
#define IOTA(element, rc) \
{ \
element ^= rc; \
}
__device__ const uint64_t RC[24] = {0x0000000000000001, 0x0000000000008082, 0x800000000000808a, 0x8000000080008000,
0x000000000000808b, 0x0000000080000001, 0x8000000080008081, 0x8000000000008009,
0x000000000000008a, 0x0000000000000088, 0x0000000080008009, 0x000000008000000a,
0x000000008000808b, 0x800000000000008b, 0x8000000000008089, 0x8000000000008003,
0x8000000000008002, 0x8000000000000080, 0x000000000000800a, 0x800000008000000a,
0x8000000080008081, 0x8000000000008080, 0x0000000080000001, 0x8000000080008008};
__device__ void keccakf(uint64_t s[25])
{
uint64_t t0, t1, t2, t3, t4;
for (int i = 0; i < 24; i++) {
THETA(
s[0], s[5], s[10], s[15], s[20], s[1], s[6], s[11], s[16], s[21], s[2], s[7], s[12], s[17], s[22], s[3], s[8],
s[13], s[18], s[23], s[4], s[9], s[14], s[19], s[24]);
RHOPI(
s[0], s[5], s[10], s[15], s[20], s[1], s[6], s[11], s[16], s[21], s[2], s[7], s[12], s[17], s[22], s[3], s[8],
s[13], s[18], s[23], s[4], s[9], s[14], s[19], s[24]);
KHI(
s[0], s[5], s[10], s[15], s[20], s[1], s[6], s[11], s[16], s[21], s[2], s[7], s[12], s[17], s[22], s[3], s[8],
s[13], s[18], s[23], s[4], s[9], s[14], s[19], s[24]);
IOTA(s[0], RC[i]);
}
}
template <int C, int D>
__global__ void keccak_hash_blocks(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output)
{
int bid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (bid >= number_of_blocks) { return; }
const int r_bits = 1600 - C;
const int r_bytes = r_bits / 8;
const int d_bytes = D / 8;
uint8_t* b_input = input + bid * input_block_size;
uint8_t* b_output = output + bid * d_bytes;
uint64_t state[25] = {}; // Initialize with zeroes
int input_len = input_block_size;
// absorb
while (input_len >= r_bytes) {
// #pragma unroll
for (int i = 0; i < r_bytes; i += 8) {
state[i / 8] ^= *(uint64_t*)(b_input + i);
}
keccakf(state);
b_input += r_bytes;
input_len -= r_bytes;
}
// last block (if any)
uint8_t last_block[r_bytes];
for (int i = 0; i < input_len; i++) {
last_block[i] = b_input[i];
}
// pad 10*1
last_block[input_len] = 1;
for (int i = 0; i < r_bytes - input_len - 1; i++) {
last_block[input_len + i + 1] = 0;
}
// last bit
last_block[r_bytes - 1] |= 0x80;
// #pragma unroll
for (int i = 0; i < r_bytes; i += 8) {
state[i / 8] ^= *(uint64_t*)(last_block + i);
}
keccakf(state);
#pragma unroll
for (int i = 0; i < d_bytes; i += 8) {
*(uint64_t*)(b_output + i) = state[i / 8];
}
}
template <int C, int D>
cudaError_t
keccak_hash(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config)
{
CHK_INIT_IF_RETURN();
cudaStream_t& stream = config.ctx.stream;
uint8_t* input_device;
if (config.are_inputs_on_device) {
input_device = input;
} else {
CHK_IF_RETURN(cudaMallocAsync(&input_device, number_of_blocks * input_block_size, stream));
CHK_IF_RETURN(
cudaMemcpyAsync(input_device, input, number_of_blocks * input_block_size, cudaMemcpyHostToDevice, stream));
}
uint8_t* output_device;
if (config.are_outputs_on_device) {
output_device = output;
} else {
CHK_IF_RETURN(cudaMallocAsync(&output_device, number_of_blocks * (D / 8), stream));
}
int number_of_threads = 1024;
int number_of_gpu_blocks = (number_of_blocks - 1) / number_of_threads + 1;
keccak_hash_blocks<C, D><<<number_of_gpu_blocks, number_of_threads, 0, stream>>>(
input_device, input_block_size, number_of_blocks, output_device);
if (!config.are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(input_device, stream));
if (!config.are_outputs_on_device) {
CHK_IF_RETURN(cudaMemcpyAsync(output, output_device, number_of_blocks * (D / 8), cudaMemcpyDeviceToHost, stream));
CHK_IF_RETURN(cudaFreeAsync(output_device, stream));
}
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
return CHK_LAST();
}
extern "C" cudaError_t
Keccak256(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config)
{
return keccak_hash<512, 256>(input, input_block_size, number_of_blocks, output, config);
}
extern "C" cudaError_t
Keccak512(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config)
{
return keccak_hash<1024, 512>(input, input_block_size, number_of_blocks, output, config);
}
} // namespace keccak

View File

@@ -0,0 +1,56 @@
#pragma once
#ifndef KECCAK_H
#define KECCAK_H
#include <cstdint>
#include "utils/device_context.cuh"
#include "utils/error_handler.cuh"
namespace keccak {
/**
* @struct KeccakConfig
* Struct that encodes various Keccak parameters.
*/
struct KeccakConfig {
device_context::DeviceContext ctx; /**< Details related to the device such as its id and stream id. */
bool are_inputs_on_device; /**< True if inputs are on device and false if they're on host. Default value: false. */
bool are_outputs_on_device; /**< If true, output is preserved on device, otherwise on host. Default value: false. */
bool is_async; /**< Whether to run the Keccak asynchronously. If set to `true`, the keccak_hash function will be
* non-blocking and you'd need to synchronize it explicitly by running
* `cudaStreamSynchronize` or `cudaDeviceSynchronize`. If set to false, keccak_hash
* function will block the current CPU thread. */
};
KeccakConfig default_keccak_config()
{
device_context::DeviceContext ctx = device_context::get_default_device_context();
KeccakConfig config = {
ctx, // ctx
false, // are_inputs_on_device
false, // are_outputs_on_device
false, // is_async
};
return config;
}
/**
* Compute the keccak hash over a sequence of preimages.
* Takes {number_of_blocks * input_block_size} bytes of input and computes {number_of_blocks} outputs, each of size
* {D / 8} bytes
* @tparam C - number of bits of capacity (c = b - r = 1600 - r). Only multiples of 64 are supported.
* @tparam D - number of bits of output. Only multiples of 64 are supported.
* @param input a pointer to the input data. May be allocated on device or on host, regulated
* by the config. Must be of size [input_block_size](@ref input_block_size) * [number_of_blocks](@ref
* number_of_blocks)}.
* @param input_block_size - size of each input block in bytes. Should be divisible by 8.
* @param number_of_blocks number of input and output blocks. One GPU thread processes one block
* @param output a pointer to the output data. May be allocated on device or on host, regulated
* by the config. Must be of size [output_block_size](@ref output_block_size) * [number_of_blocks](@ref
* number_of_blocks)}
*/
template <int C, int D>
cudaError_t
keccak_hash(uint8_t* input, int input_block_size, int number_of_blocks, uint8_t* output, KeccakConfig config);
} // namespace keccak
#endif

View File

@@ -0,0 +1,67 @@
#include "utils/device_context.cuh"
#include "keccak.cu"
// #define DEBUG
#ifndef __CUDA_ARCH__
#include <cassert>
#include <chrono>
#include <fstream>
#include <iostream>
#include <iomanip>
using namespace keccak;
#define D 256
#define START_TIMER(timer) auto timer##_start = std::chrono::high_resolution_clock::now();
#define END_TIMER(timer, msg) \
printf("%s: %.0f ms\n", msg, FpMilliseconds(std::chrono::high_resolution_clock::now() - timer##_start).count());
void uint8ToHexString(const uint8_t* values, int size)
{
std::stringstream ss;
for (int i = 0; i < size; ++i) {
ss << std::hex << std::setw(2) << std::setfill('0') << (int)values[i];
}
std::string hexString = ss.str();
std::cout << hexString << std::endl;
}
int main(int argc, char* argv[])
{
using FpMilliseconds = std::chrono::duration<float, std::chrono::milliseconds::period>;
using FpMicroseconds = std::chrono::duration<float, std::chrono::microseconds::period>;
START_TIMER(allocation_timer);
// Prepare input data of [0, 1, 2 ... (number_of_blocks * input_block_size) - 1]
int number_of_blocks = argc > 1 ? 1 << atoi(argv[1]) : 1024;
int input_block_size = argc > 2 ? atoi(argv[2]) : 136;
uint8_t* in_ptr = static_cast<uint8_t*>(malloc(number_of_blocks * input_block_size));
for (uint64_t i = 0; i < number_of_blocks * input_block_size; i++) {
in_ptr[i] = (uint8_t)i;
}
END_TIMER(allocation_timer, "Allocate mem and fill input");
uint8_t* out_ptr = static_cast<uint8_t*>(malloc(number_of_blocks * (D / 8)));
START_TIMER(keccak_timer);
KeccakConfig config = default_keccak_config();
Keccak256(in_ptr, input_block_size, number_of_blocks, out_ptr, config);
END_TIMER(keccak_timer, "Keccak")
for (int i = 0; i < number_of_blocks; i++) {
#ifdef DEBUG
uint8ToHexString(out_ptr + i * (D / 8), D / 8);
#endif
}
free(in_ptr);
free(out_ptr);
}
#endif

View File

@@ -1,4 +1,4 @@
test_msm:
mkdir -p work
nvcc -o work/test_msm -std=c++17 -I. -I../.. tests/msm_test.cu
work/test_msm
work/test_msm

View File

@@ -25,9 +25,19 @@ namespace msm {
#define MAX_TH 256
// #define SIGNED_DIG //WIP
// #define BIG_TRIANGLE
// #define SSM_SUM //WIP
template <typename A, typename P>
__global__ void left_shift_kernel(A* points, const unsigned shift, const unsigned count, A* points_out)
{
const unsigned tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= count) return;
P point = P::from_affine(points[tid]);
for (unsigned i = 0; i < shift; i++)
point = P::dbl(point);
points_out[tid] = P::to_affine(point);
}
unsigned get_optimal_c(int bitsize) { return max((unsigned)ceil(log2(bitsize)) - 4, 1U); }
template <typename E>
@@ -148,47 +158,38 @@ namespace msm {
__global__ void split_scalars_kernel(
unsigned* buckets_indices,
unsigned* point_indices,
S* scalars,
const S* scalars,
unsigned nof_scalars,
unsigned points_size,
unsigned msm_size,
unsigned nof_bms,
unsigned bm_bitsize,
unsigned c)
unsigned c,
unsigned precomputed_bms_stride)
{
// constexpr unsigned sign_mask = 0x80000000;
// constexpr unsigned trash_bucket = 0x80000000;
unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid >= nof_scalars) return;
unsigned bucket_index;
// unsigned bucket_index2;
unsigned current_index;
unsigned msm_index = tid / msm_size;
// unsigned borrow = 0;
S& scalar = scalars[tid];
const S& scalar = scalars[tid];
for (unsigned bm = 0; bm < nof_bms; bm++) {
const unsigned precomputed_index = bm / precomputed_bms_stride;
const unsigned target_bm = bm % precomputed_bms_stride;
bucket_index = scalar.get_scalar_digit(bm, c);
#ifdef SIGNED_DIG
bucket_index += borrow;
borrow = 0;
unsigned sign = 0;
if (bucket_index > (1 << (c - 1))) {
bucket_index = (1 << c) - bucket_index;
borrow = 1;
sign = sign_mask;
}
#endif
current_index = bm * nof_scalars + tid;
#ifdef SIGNED_DIG
point_indices[current_index] = sign | tid; // the point index is saved for later
#else
buckets_indices[current_index] =
(msm_index << (c + bm_bitsize)) | (bm << c) |
bucket_index; // the bucket module number and the msm number are appended at the msbs
if (bucket_index == 0) buckets_indices[current_index] = 0; // will be skipped
point_indices[current_index] = tid % points_size; // the point index is saved for later
#endif
if (bucket_index != 0) {
buckets_indices[current_index] =
(msm_index << (c + bm_bitsize)) | (target_bm << c) |
bucket_index; // the bucket module number and the msm number are appended at the msbs
} else {
buckets_indices[current_index] = 0; // will be skipped
}
point_indices[current_index] =
tid % points_size + points_size * precomputed_index; // the point index is saved for later
}
}
@@ -223,19 +224,11 @@ namespace msm {
const unsigned msm_idx_shift,
const unsigned c)
{
// constexpr unsigned sign_mask = 0x80000000;
unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid >= nof_buckets_to_compute) return;
#ifdef SIGNED_DIG // todo - fix
const unsigned msm_index = single_bucket_indices[tid] >> msm_idx_shift;
const unsigned bm_index = (single_bucket_indices[tid] & ((1 << msm_idx_shift) - 1)) >> c;
const unsigned bucket_index =
msm_index * nof_buckets + bm_index * ((1 << (c - 1)) + 1) + (single_bucket_indices[tid] & ((1 << c) - 1));
#else
unsigned msm_index = single_bucket_indices[tid] >> msm_idx_shift;
const unsigned single_bucket_index = (single_bucket_indices[tid] & ((1 << msm_idx_shift) - 1));
unsigned bucket_index = msm_index * nof_buckets + single_bucket_index;
#endif
const unsigned bucket_offset = bucket_offsets[tid];
const unsigned bucket_size = bucket_sizes[tid];
@@ -243,14 +236,7 @@ namespace msm {
for (unsigned i = 0; i < bucket_size;
i++) { // add the relevant points starting from the relevant offset up to the bucket size
unsigned point_ind = point_indices[bucket_offset + i];
#ifdef SIGNED_DIG
unsigned sign = point_ind & sign_mask;
point_ind &= ~sign_mask;
A point = points[point_ind];
if (sign) point = A::neg(point);
#else
A point = points[point_ind];
#endif
bucket =
i ? (point == A::zero() ? bucket : bucket + point) : (point == A::zero() ? P::zero() : P::from_affine(point));
}
@@ -317,11 +303,7 @@ namespace msm {
{
unsigned tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid >= nof_bms) return;
#ifdef SIGNED_DIG
unsigned buckets_in_bm = (1 << c) + 1;
#else
unsigned buckets_in_bm = (1 << c);
#endif
P line_sum = buckets[(tid + 1) * buckets_in_bm - 1];
final_sums[tid] = line_sum;
for (unsigned i = buckets_in_bm - 2; i > 0; i--) {
@@ -378,8 +360,8 @@ namespace msm {
cudaError_t bucket_method_msm(
unsigned bitsize,
unsigned c,
S* scalars,
A* points,
const S* scalars,
const A* points,
unsigned batch_size, // number of MSMs to compute
unsigned single_msm_size, // number of elements per MSM (a.k.a N)
unsigned nof_points, // number of EC points in 'points' array. Must be either (1) single_msm_size if MSMs are
@@ -392,6 +374,7 @@ namespace msm {
bool are_results_on_device,
bool is_big_triangle,
int large_bucket_factor,
int precompute_factor,
bool is_async,
cudaStream_t stream)
{
@@ -403,44 +386,59 @@ namespace msm {
THROW_ICICLE_ERR(
IcicleError_t::InvalidArgument, "bucket_method_msm: #points must be divisible by single_msm_size*batch_size");
}
if ((precompute_factor & (precompute_factor - 1)) != 0) {
THROW_ICICLE_ERR(
IcicleError_t::InvalidArgument,
"bucket_method_msm: precompute factors that are not powers of 2 currently unsupported");
}
S* d_scalars;
const S* d_scalars;
S* d_allocated_scalars = nullptr;
if (!are_scalars_on_device) {
// copy scalars to gpu
CHK_IF_RETURN(cudaMallocAsync(&d_scalars, sizeof(S) * nof_scalars, stream));
CHK_IF_RETURN(cudaMemcpyAsync(d_scalars, scalars, sizeof(S) * nof_scalars, cudaMemcpyHostToDevice, stream));
} else {
d_scalars = scalars;
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_scalars, sizeof(S) * nof_scalars, stream));
CHK_IF_RETURN(
cudaMemcpyAsync(d_allocated_scalars, scalars, sizeof(S) * nof_scalars, cudaMemcpyHostToDevice, stream));
if (are_scalars_montgomery_form) {
CHK_IF_RETURN(mont::FromMontgomery(d_allocated_scalars, nof_scalars, stream, d_allocated_scalars));
}
d_scalars = d_allocated_scalars;
} else { // already on device
if (are_scalars_montgomery_form) {
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_scalars, sizeof(S) * nof_scalars, stream));
CHK_IF_RETURN(mont::FromMontgomery(scalars, nof_scalars, stream, d_allocated_scalars));
d_scalars = d_allocated_scalars;
} else {
d_scalars = scalars;
}
}
if (are_scalars_montgomery_form) {
if (are_scalars_on_device) {
S* d_mont_scalars;
CHK_IF_RETURN(cudaMallocAsync(&d_mont_scalars, sizeof(S) * nof_scalars, stream));
CHK_IF_RETURN(mont::FromMontgomery(d_scalars, nof_scalars, stream, d_mont_scalars));
d_scalars = d_mont_scalars;
} else
CHK_IF_RETURN(mont::FromMontgomery(d_scalars, nof_scalars, stream, d_scalars));
}
unsigned total_bms_per_msm = (bitsize + c - 1) / c;
unsigned nof_bms_per_msm = (total_bms_per_msm - 1) / precompute_factor + 1;
unsigned input_indexes_count = nof_scalars * total_bms_per_msm;
unsigned bm_bitsize = (unsigned)ceil(log2(nof_bms_per_msm));
unsigned nof_bms_per_msm = (bitsize + c - 1) / c;
unsigned* bucket_indices;
unsigned* point_indices;
unsigned* sorted_bucket_indices;
unsigned* sorted_point_indices;
CHK_IF_RETURN(cudaMallocAsync(&bucket_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
CHK_IF_RETURN(cudaMallocAsync(&point_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
CHK_IF_RETURN(cudaMallocAsync(&sorted_bucket_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
CHK_IF_RETURN(cudaMallocAsync(&sorted_point_indices, sizeof(unsigned) * nof_scalars * nof_bms_per_msm, stream));
CHK_IF_RETURN(cudaMallocAsync(&bucket_indices, sizeof(unsigned) * input_indexes_count, stream));
CHK_IF_RETURN(cudaMallocAsync(&point_indices, sizeof(unsigned) * input_indexes_count, stream));
CHK_IF_RETURN(cudaMallocAsync(&sorted_bucket_indices, sizeof(unsigned) * input_indexes_count, stream));
CHK_IF_RETURN(cudaMallocAsync(&sorted_point_indices, sizeof(unsigned) * input_indexes_count, stream));
unsigned bm_bitsize = (unsigned)ceil(log2(nof_bms_per_msm));
// split scalars into digits
unsigned NUM_THREADS = 1 << 10;
unsigned NUM_BLOCKS = (nof_scalars + NUM_THREADS - 1) / NUM_THREADS;
split_scalars_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
bucket_indices, point_indices, d_scalars, nof_scalars, nof_points, single_msm_size, nof_bms_per_msm, bm_bitsize,
c);
split_scalars_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
bucket_indices, point_indices, d_scalars, nof_scalars, nof_points, single_msm_size, total_bms_per_msm,
bm_bitsize, c, nof_bms_per_msm);
nof_points *= precompute_factor;
// ------------------------------ Sorting routines for scalars start here ----------------------------------
// sort indices - the indices are sorted from smallest to largest in order to group together the points that
// belong to each bucket
unsigned* sort_indices_temp_storage{};
@@ -450,26 +448,22 @@ namespace msm {
// more info
CHK_IF_RETURN(cub::DeviceRadixSort::SortPairs(
sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices, sorted_bucket_indices,
point_indices, sorted_point_indices, nof_scalars * nof_bms_per_msm, 0, sizeof(unsigned) * 8, stream));
point_indices, sorted_point_indices, input_indexes_count, 0, sizeof(unsigned) * 8, stream));
CHK_IF_RETURN(cudaMallocAsync(&sort_indices_temp_storage, sort_indices_temp_storage_bytes, stream));
// The second to last parameter is the default value supplied explicitly to allow passing the stream
// See https://nvlabs.github.io/cub/structcub_1_1_device_radix_sort.html#a65e82152de448c6373ed9563aaf8af7e for
// more info
CHK_IF_RETURN(cub::DeviceRadixSort::SortPairs(
sort_indices_temp_storage, sort_indices_temp_storage_bytes, bucket_indices, sorted_bucket_indices,
point_indices, sorted_point_indices, nof_scalars * nof_bms_per_msm, 0, sizeof(unsigned) * 8, stream));
point_indices, sorted_point_indices, input_indexes_count, 0, sizeof(unsigned) * 8, stream));
CHK_IF_RETURN(cudaFreeAsync(sort_indices_temp_storage, stream));
CHK_IF_RETURN(cudaFreeAsync(bucket_indices, stream));
CHK_IF_RETURN(cudaFreeAsync(point_indices, stream));
// compute number of bucket modules and number of buckets in each module
unsigned nof_bms_in_batch = nof_bms_per_msm * batch_size;
#ifdef SIGNED_DIG
const unsigned nof_buckets = nof_bms_per_msm * ((1 << (c - 1)) + 1); // signed digits
#else
// minus nof_bms_per_msm because zero bucket is not included in each bucket module
const unsigned nof_buckets = (nof_bms_per_msm << c) - nof_bms_per_msm;
#endif
const unsigned total_nof_buckets = nof_buckets * batch_size;
// find bucket_sizes
@@ -484,11 +478,11 @@ namespace msm {
size_t encode_temp_storage_bytes = 0;
CHK_IF_RETURN(cub::DeviceRunLengthEncode::Encode(
encode_temp_storage, encode_temp_storage_bytes, sorted_bucket_indices, single_bucket_indices, bucket_sizes,
nof_buckets_to_compute, nof_bms_per_msm * nof_scalars, stream));
nof_buckets_to_compute, input_indexes_count, stream));
CHK_IF_RETURN(cudaMallocAsync(&encode_temp_storage, encode_temp_storage_bytes, stream));
CHK_IF_RETURN(cub::DeviceRunLengthEncode::Encode(
encode_temp_storage, encode_temp_storage_bytes, sorted_bucket_indices, single_bucket_indices, bucket_sizes,
nof_buckets_to_compute, nof_bms_per_msm * nof_scalars, stream));
nof_buckets_to_compute, input_indexes_count, stream));
CHK_IF_RETURN(cudaFreeAsync(encode_temp_storage, stream));
CHK_IF_RETURN(cudaFreeAsync(sorted_bucket_indices, stream));
@@ -504,28 +498,33 @@ namespace msm {
offsets_temp_storage, offsets_temp_storage_bytes, bucket_sizes, bucket_offsets, total_nof_buckets + 1, stream));
CHK_IF_RETURN(cudaFreeAsync(offsets_temp_storage, stream));
A* d_points;
cudaStream_t stream_points;
// ----------- Starting to upload points (if they were on host) in parallel to scalar sorting ----------------
const A* d_points;
A* d_allocated_points = nullptr;
cudaStream_t stream_points = nullptr;
if (!are_points_on_device || are_points_montgomery_form) CHK_IF_RETURN(cudaStreamCreate(&stream_points));
if (!are_points_on_device) {
// copy points to gpu
CHK_IF_RETURN(cudaMallocAsync(&d_points, sizeof(A) * nof_points, stream_points));
CHK_IF_RETURN(cudaMemcpyAsync(d_points, points, sizeof(A) * nof_points, cudaMemcpyHostToDevice, stream_points));
} else {
d_points = points;
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_points, sizeof(A) * nof_points, stream_points));
CHK_IF_RETURN(
cudaMemcpyAsync(d_allocated_points, points, sizeof(A) * nof_points, cudaMemcpyHostToDevice, stream_points));
if (are_points_montgomery_form) {
CHK_IF_RETURN(mont::FromMontgomery(d_allocated_points, nof_points, stream_points, d_allocated_points));
}
d_points = d_allocated_points;
} else { // already on device
if (are_points_montgomery_form) {
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_points, sizeof(A) * nof_points, stream_points));
CHK_IF_RETURN(mont::FromMontgomery(points, nof_points, stream_points, d_allocated_points));
d_points = d_allocated_points;
} else {
d_points = points;
}
}
if (are_points_montgomery_form) {
if (are_points_on_device) {
A* d_mont_points;
CHK_IF_RETURN(cudaMallocAsync(&d_mont_points, sizeof(A) * nof_points, stream_points));
CHK_IF_RETURN(mont::FromMontgomery(d_points, nof_points, stream_points, d_mont_points));
d_points = d_mont_points;
} else
CHK_IF_RETURN(mont::FromMontgomery(d_points, nof_points, stream_points, d_points));
}
cudaEvent_t event_points_uploaded;
if (!are_points_on_device || are_points_montgomery_form) {
if (stream_points) {
CHK_IF_RETURN(cudaEventCreateWithFlags(&event_points_uploaded, cudaEventDisableTiming));
CHK_IF_RETURN(cudaEventRecord(event_points_uploaded, stream_points));
}
@@ -609,7 +608,7 @@ namespace msm {
cudaMemcpyAsync(&h_nof_large_buckets, nof_large_buckets, sizeof(unsigned), cudaMemcpyDeviceToHost, stream));
CHK_IF_RETURN(cudaFreeAsync(nof_large_buckets, stream));
if (!are_points_on_device || are_points_montgomery_form) {
if (stream_points) {
// by this point, points need to be already uploaded and un-Montgomeried
CHK_IF_RETURN(cudaStreamWaitEvent(stream, event_points_uploaded));
CHK_IF_RETURN(cudaEventDestroy(event_points_uploaded));
@@ -618,7 +617,7 @@ namespace msm {
cudaStream_t stream_large_buckets;
cudaEvent_t event_large_buckets_accumulated;
// this is where handling of large buckets happens (if there are any)
// ---------------- This is where handling of large buckets happens (if there are any) -------------
if (h_nof_large_buckets > 0 && bucket_th > 0) {
CHK_IF_RETURN(cudaStreamCreate(&stream_large_buckets));
CHK_IF_RETURN(cudaEventCreateWithFlags(&event_large_buckets_accumulated, cudaEventDisableTiming));
@@ -700,10 +699,11 @@ namespace msm {
CHK_IF_RETURN(cudaEventRecord(event_large_buckets_accumulated, stream_large_buckets));
}
// launch the accumulation kernel with maximum threads
// ------------------------- Accumulation of (non-large) buckets ---------------------------------
if (h_nof_buckets_to_compute > h_nof_large_buckets) {
NUM_THREADS = 1 << 8;
NUM_BLOCKS = (h_nof_buckets_to_compute - h_nof_large_buckets + NUM_THREADS - 1) / NUM_THREADS;
// launch the accumulation kernel with maximum threads
accumulate_buckets_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
buckets, sorted_bucket_offsets + h_nof_large_buckets, sorted_bucket_sizes + h_nof_large_buckets,
sorted_single_bucket_indices + h_nof_large_buckets, sorted_point_indices, d_points,
@@ -719,24 +719,11 @@ namespace msm {
CHK_IF_RETURN(cudaStreamDestroy(stream_large_buckets));
}
#ifdef SSM_SUM
// sum each bucket
NUM_THREADS = 1 << 10;
NUM_BLOCKS = (nof_buckets + NUM_THREADS - 1) / NUM_THREADS;
ssm_buckets_kernel<fake_point, fake_scalar>
<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, single_bucket_indices, nof_buckets, c);
// sum each bucket module
P* final_results;
CHK_IF_RETURN(cudaMallocAsync(&final_results, sizeof(P) * nof_bms_per_msm, stream));
NUM_THREADS = 1 << c;
NUM_BLOCKS = nof_bms_per_msm;
sum_reduction_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results);
#endif
P* d_final_result;
if (!are_results_on_device) CHK_IF_RETURN(cudaMallocAsync(&d_final_result, sizeof(P) * batch_size, stream));
P* d_allocated_final_result = nullptr;
if (!are_results_on_device)
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_final_result, sizeof(P) * batch_size, stream));
// --- Reduction of buckets happens here, after this we'll get a single sum for each bucket module/window ---
unsigned nof_empty_bms_per_batch = 0; // for non-triangle accumulation this may be >0
P* final_results;
if (is_big_triangle || c == 1) {
@@ -744,15 +731,9 @@ namespace msm {
// launch the bucket module sum kernel - a thread for each bucket module
NUM_THREADS = 32;
NUM_BLOCKS = (nof_bms_in_batch + NUM_THREADS - 1) / NUM_THREADS;
#ifdef SIGNED_DIG
big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
buckets, final_results, nof_bms_in_batch, c - 1); // signed digits
#else
big_triangle_sum_kernel<<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(buckets, final_results, nof_bms_in_batch, c);
#endif
} else {
unsigned source_bits_count = c;
// bool odd_source_c = source_bits_count % 2;
unsigned source_windows_count = nof_bms_per_msm;
unsigned source_buckets_count = nof_buckets + nof_bms_per_msm;
unsigned target_windows_count = 0;
@@ -792,12 +773,11 @@ namespace msm {
}
}
if (target_bits_count == 1) {
// Note: the reduction ends up with 'target_windows_count' windows per batch element. Some are guaranteed to
// be empty when target_windows_count>bitsize.
// for example consider bitsize=253 and c=2. The reduction ends with 254 bms but the most significant one is
// guaranteed to be zero since the scalars are 253b.
// Note: the reduction ends up with 'target_windows_count' windows per batch element. Some are guaranteed
// to be empty when target_windows_count>bitsize. for example consider bitsize=253 and c=2. The reduction
// ends with 254 bms but the most significant one is guaranteed to be zero since the scalars are 253b.
nof_bms_per_msm = target_windows_count;
nof_empty_bms_per_batch = target_windows_count - bitsize;
nof_empty_bms_per_batch = target_windows_count > bitsize ? target_windows_count - bitsize : 0;
nof_bms_in_batch = nof_bms_per_msm * batch_size;
CHK_IF_RETURN(cudaMallocAsync(&final_results, sizeof(P) * nof_bms_in_batch, stream));
@@ -819,28 +799,29 @@ namespace msm {
temp_buckets1 = nullptr;
temp_buckets2 = nullptr;
source_bits_count = target_bits_count;
// odd_source_c = source_bits_count % 2;
source_windows_count = target_windows_count;
source_buckets_count = target_buckets_count;
}
}
// launch the double and add kernel, a single thread per batch element
// ------- This is the final stage where bucket modules/window sums get added up with appropriate weights
// -------
NUM_THREADS = 32;
NUM_BLOCKS = (batch_size + NUM_THREADS - 1) / NUM_THREADS;
// launch the double and add kernel, a single thread per batch element
final_accumulation_kernel<P, S><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
final_results, are_results_on_device ? final_result : d_final_result, batch_size, nof_bms_per_msm,
final_results, are_results_on_device ? final_result : d_allocated_final_result, batch_size, nof_bms_per_msm,
nof_empty_bms_per_batch, c);
CHK_IF_RETURN(cudaFreeAsync(final_results, stream));
if (!are_results_on_device)
CHK_IF_RETURN(
cudaMemcpyAsync(final_result, d_final_result, sizeof(P) * batch_size, cudaMemcpyDeviceToHost, stream));
CHK_IF_RETURN(cudaMemcpyAsync(
final_result, d_allocated_final_result, sizeof(P) * batch_size, cudaMemcpyDeviceToHost, stream));
// free memory
if (!are_scalars_on_device || are_scalars_montgomery_form) CHK_IF_RETURN(cudaFreeAsync(d_scalars, stream));
if (!are_points_on_device || are_points_montgomery_form) CHK_IF_RETURN(cudaFreeAsync(d_points, stream));
if (!are_results_on_device) CHK_IF_RETURN(cudaFreeAsync(d_final_result, stream));
if (d_allocated_scalars) CHK_IF_RETURN(cudaFreeAsync(d_allocated_scalars, stream));
if (d_allocated_points) CHK_IF_RETURN(cudaFreeAsync(d_allocated_points, stream));
if (d_allocated_final_result) CHK_IF_RETURN(cudaFreeAsync(d_allocated_final_result, stream));
CHK_IF_RETURN(cudaFreeAsync(buckets, stream));
if (!is_async) CHK_IF_RETURN(cudaStreamSynchronize(stream));
@@ -873,7 +854,7 @@ namespace msm {
}
template <typename S, typename A, typename P>
cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig& config, P* results)
cudaError_t MSM(const S* scalars, const A* points, int msm_size, MSMConfig& config, P* results)
{
const int bitsize = (config.bitsize == 0) ? S::NBITS : config.bitsize;
cudaStream_t& stream = config.ctx.stream;
@@ -890,7 +871,59 @@ namespace msm {
bitsize, c, scalars, points, config.batch_size, msm_size,
(config.points_size == 0) ? msm_size : config.points_size, results, config.are_scalars_on_device,
config.are_scalars_montgomery_form, config.are_points_on_device, config.are_points_montgomery_form,
config.are_results_on_device, config.is_big_triangle, config.large_bucket_factor, config.is_async, stream));
config.are_results_on_device, config.is_big_triangle, config.large_bucket_factor, config.precompute_factor,
config.is_async, stream));
}
template <typename A, typename P>
cudaError_t PrecomputeMSMBases(
A* bases,
int bases_size,
int precompute_factor,
int _c,
bool are_bases_on_device,
device_context::DeviceContext& ctx,
A* output_bases)
{
CHK_INIT_IF_RETURN();
cudaStream_t& stream = ctx.stream;
CHK_IF_RETURN(cudaMemcpyAsync(
output_bases, bases, sizeof(A) * bases_size,
are_bases_on_device ? cudaMemcpyDeviceToDevice : cudaMemcpyHostToDevice, stream));
unsigned c = 16;
unsigned total_nof_bms = (P::SCALAR_FF_NBITS - 1) / c + 1;
unsigned shift = c * ((total_nof_bms - 1) / precompute_factor + 1);
unsigned NUM_THREADS = 1 << 8;
unsigned NUM_BLOCKS = (bases_size + NUM_THREADS - 1) / NUM_THREADS;
for (int i = 1; i < precompute_factor; i++) {
left_shift_kernel<A, P><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(
&output_bases[(i - 1) * bases_size], shift, bases_size, &output_bases[i * bases_size]);
}
return CHK_LAST();
}
/**
* Extern "C" version of [PrecomputeMSMBases](@ref PrecomputeMSMBases) function with the following values of
* template parameters (where the curve is given by `-DCURVE` env variable during build):
* - `A` is the [affine representation](@ref affine_t) of curve points;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, PrecomputeMSMBases)(
curve_config::affine_t* bases,
int bases_size,
int precompute_factor,
int _c,
bool are_bases_on_device,
device_context::DeviceContext& ctx,
curve_config::affine_t* output_bases)
{
return PrecomputeMSMBases<curve_config::affine_t, curve_config::projective_t>(
bases, bases_size, precompute_factor, _c, are_bases_on_device, ctx, output_bases);
}
/**
@@ -902,8 +935,8 @@ namespace msm {
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, MSMCuda)(
curve_config::scalar_t* scalars,
curve_config::affine_t* points,
const curve_config::scalar_t* scalars,
const curve_config::affine_t* points,
int msm_size,
MSMConfig& config,
curve_config::projective_t* out)
@@ -912,13 +945,27 @@ namespace msm {
scalars, points, msm_size, config, out);
}
/**
* Extern "C" version of [DefaultMSMConfig](@ref DefaultMSMConfig) function.
*/
extern "C" MSMConfig CONCAT_EXPAND(CURVE, DefaultMSMConfig)() { return DefaultMSMConfig<curve_config::affine_t>(); }
#if defined(G2_DEFINED)
/**
* Extern "C" version of [PrecomputeMSMBases](@ref PrecomputeMSMBases) function with the following values of
* template parameters (where the curve is given by `-DCURVE` env variable during build):
* - `A` is the [affine representation](@ref g2_affine_t) of G2 curve points;
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, G2PrecomputeMSMBases)(
curve_config::g2_affine_t* bases,
int bases_size,
int precompute_factor,
int _c,
bool are_bases_on_device,
device_context::DeviceContext& ctx,
curve_config::g2_affine_t* output_bases)
{
return PrecomputeMSMBases<curve_config::g2_affine_t, curve_config::g2_projective_t>(
bases, bases_size, precompute_factor, _c, are_bases_on_device, ctx, output_bases);
}
/**
* Extern "C" version of [MSM](@ref MSM) function with the following values of template parameters
* (where the curve is given by `-DCURVE` env variable during build):
@@ -928,8 +975,8 @@ namespace msm {
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, G2MSMCuda)(
curve_config::scalar_t* scalars,
curve_config::g2_affine_t* points,
const curve_config::scalar_t* scalars,
const curve_config::g2_affine_t* points,
int msm_size,
MSMConfig& config,
curve_config::g2_projective_t* out)
@@ -938,15 +985,6 @@ namespace msm {
scalars, points, msm_size, config, out);
}
/**
* Extern "C" version of [DefaultMSMConfig](@ref DefaultMSMConfig) function for the G2 curve
* (functionally no different than the default MSM config function for G1).
*/
extern "C" MSMConfig CONCAT_EXPAND(CURVE, G2DefaultMSMConfig)()
{
return DefaultMSMConfig<curve_config::g2_affine_t>();
}
#endif
} // namespace msm

View File

@@ -43,14 +43,18 @@ namespace msm {
* variable is set equal to the MSM size. And if every MSM uses a distinct set of
* points, it should be set to the product of MSM size and [batch_size](@ref
* batch_size). Default value: 0 (meaning it's equal to the MSM size). */
int precompute_factor; /**< The number of extra points to pre-compute for each point. Larger values decrease the
int precompute_factor; /**< The number of extra points to pre-compute for each point. See the
* [PrecomputeMSMBases](@ref PrecomputeMSMBases) function, `precompute_factor` passed
* there needs to be equal to the one used here. Larger values decrease the
* number of computations to make, on-line memory footprint, but increase the static
* memory footprint. Default value: 1 (i.e. don't pre-compute). */
int c; /**< \f$ c \f$ value, or "window bitsize" which is the main parameter of the "bucket
* method" that we use to solve the MSM problem. As a rule of thumb, larger value
* means more on-line memory footprint but also more parallelism and less computational
* complexity (up to a certain point). Default value: 0 (the optimal value of \f$ c \f$
* is chosen automatically). */
* complexity (up to a certain point). Currently pre-computation is independent of
* \f$ c \f$, however in the future value of \f$ c \f$ here and the one passed into the
* [PrecomputeMSMBases](@ref PrecomputeMSMBases) function will need to be identical.
* Default value: 0 (the optimal value of \f$ c \f$ is chosen automatically). */
int bitsize; /**< Number of bits of the largest scalar. Typically equals the bitsize of scalar field,
* but if a different (better) upper bound is known, it should be reflected in this
* variable. Default value: 0 (set to the bitsize of scalar field). */
@@ -101,12 +105,39 @@ namespace msm {
* Weierstrass](https://hyperelliptic.org/EFD/g1p/auto-shortw-projective.html) point in our codebase.
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*
* **Note:** this function is still WIP and the following [MSMConfig](@ref MSMConfig) members do not yet have any
* effect: `precompute_factor` (always equals 1) and `ctx.device_id` (0 device is always used).
* Also, it's currently better to use `batch_size=1` in most cases (except with dealing with very many MSMs).
*/
template <typename S, typename A, typename P>
cudaError_t MSM(S* scalars, A* points, int msm_size, MSMConfig& config, P* results);
cudaError_t MSM(const S* scalars, const A* points, int msm_size, MSMConfig& config, P* results);
/**
* A function that precomputes MSM bases by extending them with their shifted copies.
* e.g.:
* Original points: \f$ P_0, P_1, P_2, ... P_{size} \f$
* Extended points: \f$ P_0, P_1, P_2, ... P_{size}, 2^{l}P_0, 2^{l}P_1, ..., 2^{l}P_{size},
* 2^{2l}P_0, 2^{2l}P_1, ..., 2^{2l}P_{size}, ... \f$
* @param bases Bases \f$ P_i \f$. In case of batch MSM, all *unique* points are concatenated.
* @param bases_size Number of bases.
* @param precompute_factor The number of total precomputed points for each base (including the base itself).
* @param _c This is currently unused, but in the future precomputation will need to be aware of
* the `c` value used in MSM (see [MSMConfig](@ref MSMConfig)). So to avoid breaking your code with this
* upcoming change, make sure to use the same value of `c` in this function and in respective MSMConfig.
* @param are_bases_on_device Whether the bases are on device.
* @param ctx Device context specifying device id and stream to use.
* @param output_bases Device-allocated buffer of size bases_size * precompute_factor for the extended bases.
* @tparam A The type of points \f$ \{P_i\} \f$ which is typically an [affine
* Weierstrass](https://hyperelliptic.org/EFD/g1p/auto-shortw.html) point.
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*
*/
template <typename A, typename P>
cudaError_t PrecomputeMSMBases(
A* bases,
int bases_size,
int precompute_factor,
int _c,
bool are_bases_on_device,
device_context::DeviceContext& ctx,
A* output_bases);
} // namespace msm

View File

@@ -30,7 +30,7 @@ public:
return os;
}
HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width)
HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width) const
{
return (x >> (digit_num * digit_width)) & ((1 << digit_width) - 1);
}

View File

@@ -56,7 +56,15 @@ namespace ntt {
// Note: the following reorder kernels are fused with normalization for INTT
template <typename E, typename S, uint32_t MAX_GROUP_SIZE = 80>
static __global__ void reorder_digits_inplace_and_normalize_kernel(
E* arr, uint32_t log_size, bool dit, bool fast_tw, eRevType rev_type, bool is_normalize, S inverse_N)
E* arr,
uint32_t log_size,
bool columns_batch,
uint32_t batch_size,
bool dit,
bool fast_tw,
eRevType rev_type,
bool is_normalize,
S inverse_N)
{
// launch N threads (per batch element)
// each thread starts from one index and calculates the corresponding group
@@ -65,19 +73,20 @@ namespace ntt {
const uint32_t size = 1 << log_size;
const uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t idx = tid % size;
const uint32_t batch_idx = tid / size;
const uint32_t idx = columns_batch ? tid / batch_size : tid % size;
const uint32_t batch_idx = columns_batch ? tid % batch_size : tid / size;
if (tid >= size * batch_size) return;
uint32_t next_element = idx;
uint32_t group[MAX_GROUP_SIZE];
group[0] = next_element + size * batch_idx;
group[0] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;
uint32_t i = 1;
for (; i < MAX_GROUP_SIZE;) {
next_element = generalized_rev(next_element, log_size, dit, fast_tw, rev_type);
if (next_element < idx) return; // not handling this group
if (next_element == idx) break; // calculated whole group
group[i++] = next_element + size * batch_idx;
group[i++] = columns_batch ? next_element * batch_size + batch_idx : next_element + size * batch_idx;
}
--i;
@@ -91,9 +100,12 @@ namespace ntt {
template <typename E, typename S>
__launch_bounds__(64) __global__ void reorder_digits_and_normalize_kernel(
E* arr,
const E* arr,
E* arr_reordered,
uint32_t log_size,
bool columns_batch,
uint32_t batch_size,
uint32_t columns_batch_size,
bool dit,
bool fast_tw,
eRevType rev_type,
@@ -101,41 +113,46 @@ namespace ntt {
S inverse_N)
{
uint32_t tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= (1 << log_size) * batch_size) return;
uint32_t rd = tid;
uint32_t wr =
((tid >> log_size) << log_size) + generalized_rev(tid & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr] = is_normalize ? arr[rd] * inverse_N : arr[rd];
uint32_t wr = (columns_batch ? 0 : ((tid >> log_size) << log_size)) +
generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, fast_tw, rev_type);
arr_reordered[wr * columns_batch_size + (tid % columns_batch_size)] = is_normalize ? arr[rd] * inverse_N : arr[rd];
}
template <typename E, typename S>
static __global__ void batch_elementwise_mul_with_reorder(
E* in_vec,
int n_elements,
int batch_size,
static __global__ void batch_elementwise_mul_with_reorder_kernel(
const E* in_vec,
uint32_t size,
bool columns_batch,
uint32_t batch_size,
uint32_t columns_batch_size,
S* scalar_vec,
int step,
int n_scalars,
int logn,
uint32_t log_size,
eRevType rev_type,
bool dit,
E* out_vec)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= n_elements * batch_size) return;
int64_t scalar_id = tid % n_elements;
if (rev_type != eRevType::None) scalar_id = generalized_rev(tid, logn, dit, false, rev_type);
if (tid >= size * batch_size) return;
int64_t scalar_id = (tid / columns_batch_size) % size;
if (rev_type != eRevType::None)
scalar_id = generalized_rev((tid / columns_batch_size) & ((1 << log_size) - 1), log_size, dit, false, rev_type);
out_vec[tid] = *(scalar_vec + ((scalar_id * step) % n_scalars)) * in_vec[tid];
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt64(
E* in,
const E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -153,19 +170,27 @@ namespace ntt {
s_meta.th_stride = 8;
s_meta.ntt_block_size = 64;
s_meta.ntt_block_id = (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 7) / 8)
: (blockIdx.x << 3) + (strided ? (threadIdx.x & 0x7) : (threadIdx.x >> 3));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 3) : (threadIdx.x & 0x7);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0x7) + ((blockIdx.x % ((columns_batch_size + 7) / 8)) << 3) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
if (twiddle_stride && dit) {
if (fast_tw)
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
@@ -189,24 +214,28 @@ namespace ntt {
if (twiddle_stride && !dit) {
if (fast_tw)
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles64(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric64(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
}
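A quick worked example of the column-batch thread mapping in the radix-64 kernel above (the radix-32 and radix-16 variants follow the same pattern with 16 and 32 columns per block): each 64-thread block covers up to 8 columns of a single NTT block, so `ceil(columns_batch_size / 8)` consecutive blocks share one `ntt_block_id`. With `columns_batch_size = 20` that is 3 blocks per NTT block; `blockIdx.x = 7` gives `ntt_block_id = 7 / 3 = 2` and `batch_id = (threadIdx.x & 7) + 8 * (7 % 3) = (threadIdx.x & 7) + 8`, and any thread whose `batch_id` lands at 20 or above simply returns through the guard.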
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt32(
E* in,
const E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -225,16 +254,25 @@ namespace ntt {
s_meta.th_stride = 4;
s_meta.ntt_block_size = 32;
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
if (fast_tw)
engine.loadInternalTwiddles32(internal_twiddles, strided);
else
@@ -247,24 +285,28 @@ namespace ntt {
engine.ntt4_2();
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData32(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalData32ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData32(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt32dit(
E* in,
const E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -283,19 +325,27 @@ namespace ntt {
s_meta.th_stride = 4;
s_meta.ntt_block_size = 32;
s_meta.ntt_block_id = (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_block_id = columns_batch_size ? blockIdx.x / ((columns_batch_size + 15) / 16)
: (blockIdx.x << 4) + (strided ? (threadIdx.x & 0xf) : (threadIdx.x >> 2));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 4) : (threadIdx.x & 0x3);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0xf) + ((blockIdx.x % ((columns_batch_size + 15) / 16)) << 4) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData32(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalData32ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData32(in, data_stride, log_data_stride, strided, s_meta);
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles32(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric32(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
@@ -311,18 +361,22 @@ namespace ntt {
engine.SharedData32Rows8(shmem, false, false, strided); // load
engine.twiddlesInternal();
engine.ntt8win();
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt16(
E* in,
const E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -341,16 +395,26 @@ namespace ntt {
s_meta.th_stride = 2;
s_meta.ntt_block_size = 16;
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_block_id = columns_batch_size
? blockIdx.x / ((columns_batch_size + 31) / 32)
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalDataColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData(in, data_stride, log_data_stride, strided, s_meta);
if (fast_tw)
engine.loadInternalTwiddles16(internal_twiddles, strided);
else
@@ -363,24 +427,28 @@ namespace ntt {
engine.ntt2_4();
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
engine.twiddlesExternal();
}
engine.storeGlobalData16(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalData16ColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData16(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
__launch_bounds__(64) __global__ void ntt16dit(
E* in,
const E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
S* basic_twiddles,
uint32_t log_size,
uint32_t tw_log_size,
uint32_t columns_batch_size,
uint32_t nof_ntt_blocks,
uint32_t data_stride,
uint32_t log_data_stride,
@@ -399,19 +467,29 @@ namespace ntt {
s_meta.th_stride = 2;
s_meta.ntt_block_size = 16;
s_meta.ntt_block_id = (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_block_id = columns_batch_size
? blockIdx.x / ((columns_batch_size + 31) / 32)
: (blockIdx.x << 5) + (strided ? (threadIdx.x & 0x1f) : (threadIdx.x >> 1));
s_meta.ntt_inp_id = strided ? (threadIdx.x >> 5) : (threadIdx.x & 0x1);
if (s_meta.ntt_block_id >= nof_ntt_blocks) return;
s_meta.batch_id =
columns_batch_size ? (threadIdx.x & 0x1f) + ((blockIdx.x % ((columns_batch_size + 31) / 32)) << 5) : 0;
if (s_meta.ntt_block_id >= nof_ntt_blocks || (columns_batch_size > 0 && s_meta.batch_id >= columns_batch_size))
return;
if (fast_tw)
engine.loadBasicTwiddles(basic_twiddles);
else
engine.loadBasicTwiddlesGeneric(basic_twiddles, inv);
engine.loadGlobalData16(in, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.loadGlobalData16ColumnBatch(in, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.loadGlobalData16(in, data_stride, log_data_stride, strided, s_meta);
if (twiddle_stride) {
if (fast_tw)
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, strided, s_meta);
engine.loadExternalTwiddles16(external_twiddles, twiddle_stride, log_data_stride, s_meta);
else
engine.loadExternalTwiddlesGeneric16(
external_twiddles, twiddle_stride, log_data_stride, s_meta, tw_log_size, inv);
@@ -427,13 +505,17 @@ namespace ntt {
engine.SharedData16Rows8(shmem, false, false, strided); // load
engine.twiddlesInternal();
engine.ntt8win();
engine.storeGlobalData(out, data_stride, log_data_stride, log_size, strided, s_meta);
if (columns_batch_size)
engine.storeGlobalDataColumnBatch(out, data_stride, log_data_stride, s_meta, columns_batch_size);
else
engine.storeGlobalData(out, data_stride, log_data_stride, strided, s_meta);
}
template <typename E, typename S>
__global__ void normalize_kernel(E* data, S norm_factor)
__global__ void normalize_kernel(E* data, S norm_factor, uint32_t size)
{
uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid >= size) return;
data[tid] = data[tid] * norm_factor;
}
@@ -658,7 +740,7 @@ namespace ntt {
template <typename E, typename S>
cudaError_t large_ntt(
E* in,
const E* in,
E* out,
S* external_twiddles,
S* internal_twiddles,
@@ -666,6 +748,7 @@ namespace ntt {
uint32_t log_size,
uint32_t tw_log_size,
uint32_t batch_size,
bool columns_batch,
bool inv,
bool normalize,
bool dit,
@@ -679,72 +762,83 @@ namespace ntt {
}
if (log_size == 4) {
const int NOF_THREADS = min(64, 2 * batch_size);
const int NOF_BLOCKS = (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_THREADS = columns_batch ? 64 : min(64, 2 * batch_size);
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 31) / 32) : (2 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
if (dit) {
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
} else { // dif
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4));
if (normalize)
normalize_kernel<<<batch_size, 16, 0, cuda_stream>>>(out, S::inv_log_size(4), (1 << log_size) * batch_size);
return CHK_LAST();
}
if (log_size == 5) {
const int NOF_THREADS = min(64, 4 * batch_size);
const int NOF_BLOCKS = (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_THREADS = columns_batch ? 64 : min(64, 4 * batch_size);
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 15) / 16) : (4 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
if (dit) {
ntt32dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
} else { // dif
ntt32<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5));
if (normalize)
normalize_kernel<<<batch_size, 32, 0, cuda_stream>>>(out, S::inv_log_size(5), (1 << log_size) * batch_size);
return CHK_LAST();
}
if (log_size == 6) {
const int NOF_THREADS = min(64, 8 * batch_size);
const int NOF_BLOCKS = (8 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_THREADS = columns_batch ? 64 : min(64, 8 * batch_size);
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 7) / 8) : ((8 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
ntt64<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size, batch_size, 1, 0, 0,
false, 0, inv, dit, fast_tw);
if (normalize) normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6));
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
columns_batch ? batch_size : 0, columns_batch ? 1 : batch_size, 1, 0, 0, columns_batch, 0, inv, dit, fast_tw);
if (normalize)
normalize_kernel<<<batch_size, 64, 0, cuda_stream>>>(out, S::inv_log_size(6), (1 << log_size) * batch_size);
return CHK_LAST();
}
if (log_size == 8) {
const int NOF_THREADS = 64;
const int NOF_BLOCKS = (32 * batch_size + NOF_THREADS - 1) / NOF_THREADS;
const int NOF_BLOCKS =
columns_batch ? ((batch_size + 31) / 32 * 16) : ((32 * batch_size + NOF_THREADS - 1) / NOF_THREADS);
if (dit) {
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
columns_batch, 0, inv, dit, fast_tw);
ntt16dit<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
inv, dit, fast_tw);
} else { // dif
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 16, 4, 16, true, 1, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 16, 4, 16, true, 1,
inv, dit, fast_tw);
ntt16<<<NOF_BLOCKS, NOF_THREADS, 8 * 64 * sizeof(E), cuda_stream>>>(
out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1, 0, 0, false, 0, inv, dit, fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1, 0, 0,
columns_batch, 0, inv, dit, fast_tw);
}
if (normalize) normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8));
if (normalize)
normalize_kernel<<<batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(8), (1 << log_size) * batch_size);
return CHK_LAST();
}
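The `log_size == 8` branch above is the two-stage Cooley–Tukey split $256 = 16 \cdot 16$. In one standard indexing (the kernels use their own digit order, hence the separate digit-reversal kernels), writing $n = 16 n_2 + n_1$ and $k = 16 k_1 + k_2$:

$$X[16 k_1 + k_2] = \sum_{n_1=0}^{15} \omega_{16}^{n_1 k_1} \left( \omega_{256}^{n_1 k_2} \sum_{n_2=0}^{15} x[16 n_2 + n_1]\, \omega_{16}^{n_2 k_2} \right)$$

One radix-16 pass therefore runs at unit stride with no inter-stage twiddles (`twiddle_stride = 0`), while the other runs at `data_stride = 16` and is launched with `twiddle_stride = 16`, loading the $\omega_{256}^{n_1 k_2}$ factors through `loadExternalTwiddles16`; DIT performs the unit-stride pass first, DIF the strided pass first.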
// general case:
uint32_t nof_blocks = (1 << (log_size - 9)) * batch_size;
uint32_t nof_blocks = (1 << (log_size - 9)) * (columns_batch ? ((batch_size + 31) / 32) * 32 : batch_size);
if (dit) {
for (int i = 0; i < 5; i++) {
uint32_t stage_size = fast_tw ? STAGE_SIZES_HOST_FT[log_size][i] : STAGE_SIZES_HOST[log_size][i];
@@ -754,18 +848,18 @@ namespace ntt {
if (stage_size == 6)
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 5)
ntt32dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 4)
ntt16dit<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
i ? out : in, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
}
} else { // dif
bool first_run = false, prev_stage = false;
@@ -778,30 +872,31 @@ namespace ntt {
if (stage_size == 6)
ntt64<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 6) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 6) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 5)
ntt32<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 5) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 5) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
else if (stage_size == 4)
ntt16<<<nof_blocks, 64, 8 * 64 * sizeof(E), cuda_stream>>>(
first_run ? in : out, out, external_twiddles, internal_twiddles, basic_twiddles, log_size, tw_log_size,
(1 << log_size - 4) * batch_size, 1 << stride_log, stride_log, i ? (1 << stride_log) : 0, i, i, inv, dit,
fast_tw);
columns_batch ? batch_size : 0, (1 << log_size - 4) * (columns_batch ? 1 : batch_size), 1 << stride_log,
stride_log, i ? (1 << stride_log) : 0, i || columns_batch, i, inv, dit, fast_tw);
prev_stage = stage_size;
}
}
if (normalize)
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(out, S::inv_log_size(log_size));
normalize_kernel<<<(1 << (log_size - 8)) * batch_size, 256, 0, cuda_stream>>>(
out, S::inv_log_size(log_size), (1 << log_size) * batch_size);
return CHK_LAST();
}
template <typename E, typename S>
cudaError_t mixed_radix_ntt(
E* d_input,
const E* d_input,
E* d_output,
S* external_twiddles,
S* internal_twiddles,
@@ -809,6 +904,7 @@ namespace ntt {
int ntt_size,
int max_logn,
int batch_size,
bool columns_batch,
bool is_inverse,
bool fast_tw,
Ordering ordering,
@@ -858,9 +954,10 @@ namespace ntt {
}
if (is_on_coset && !is_inverse) {
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles,
arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
arbitrary_coset ? arbitrary_coset : external_twiddles, arbitrary_coset ? 1 : coset_gen_index, n_twiddles, logn,
reverse_coset, dit, d_output);
d_input = d_output;
}
@@ -869,10 +966,11 @@ namespace ntt {
const bool is_reverse_in_place = (d_input == d_output);
if (is_reverse_in_place) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
} else {
reorder_digits_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_input, d_output, logn, dit, fast_tw, reverse_input, is_normalize, S::inv_log_size(logn));
d_input, d_output, logn, columns_batch, batch_size, columns_batch ? batch_size : 1, dit, fast_tw,
reverse_input, is_normalize, S::inv_log_size(logn));
}
is_normalize = false;
d_input = d_output;
@@ -880,18 +978,19 @@ namespace ntt {
// inplace ntt
CHK_IF_RETURN(large_ntt(
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size, is_inverse,
(is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
d_input, d_output, external_twiddles, internal_twiddles, basic_twiddles, logn, max_logn, batch_size,
columns_batch, is_inverse, (is_normalize && reverse_output == eRevType::None), dit, fast_tw, cuda_stream));
if (reverse_output != eRevType::None) {
reorder_digits_inplace_and_normalize_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, logn, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
d_output, logn, columns_batch, batch_size, dit, fast_tw, reverse_output, is_normalize, S::inv_log_size(logn));
}
if (is_on_coset && is_inverse) {
batch_elementwise_mul_with_reorder<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, ntt_size, batch_size, arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles,
arbitrary_coset ? 1 : -coset_gen_index, n_twiddles, logn, reverse_coset, dit, d_output);
batch_elementwise_mul_with_reorder_kernel<<<NOF_BLOCKS, NOF_THREADS, 0, cuda_stream>>>(
d_output, ntt_size, columns_batch, batch_size, columns_batch ? batch_size : 1,
arbitrary_coset ? arbitrary_coset : external_twiddles + n_twiddles, arbitrary_coset ? 1 : -coset_gen_index,
n_twiddles, logn, reverse_coset, dit, d_output);
}
return CHK_LAST();
@@ -915,7 +1014,7 @@ namespace ntt {
cudaStream_t& stream);
template cudaError_t mixed_radix_ntt<curve_config::scalar_t, curve_config::scalar_t>(
curve_config::scalar_t* d_input,
const curve_config::scalar_t* d_input,
curve_config::scalar_t* d_output,
curve_config::scalar_t* external_twiddles,
curve_config::scalar_t* internal_twiddles,
@@ -923,6 +1022,7 @@ namespace ntt {
int ntt_size,
int max_logn,
int batch_size,
bool columns_batch,
bool is_inverse,
bool fast_tw,
Ordering ordering,

View File

@@ -21,7 +21,7 @@ namespace ntt {
const uint32_t MAX_SHARED_MEM = MAX_SHARED_MEM_ELEMENT_SIZE * MAX_NUM_THREADS;
template <typename E>
__global__ void reverse_order_kernel(E* arr, E* arr_reversed, uint32_t n, uint32_t logn, uint32_t batch_size)
__global__ void reverse_order_kernel(const E* arr, E* arr_reversed, uint32_t n, uint32_t logn, uint32_t batch_size)
{
int threadId = (blockIdx.x * blockDim.x) + threadIdx.x;
if (threadId < n * batch_size) {
@@ -46,7 +46,8 @@ namespace ntt {
* @param arr_out buffer of the same size as `arr_in` on the GPU to write the bit-permuted array into.
*/
template <typename E>
void reverse_order_batch(E* arr_in, uint32_t n, uint32_t logn, uint32_t batch_size, cudaStream_t stream, E* arr_out)
void reverse_order_batch(
const E* arr_in, uint32_t n, uint32_t logn, uint32_t batch_size, cudaStream_t stream, E* arr_out)
{
int number_of_threads = MAX_THREADS_BATCH;
int number_of_blocks = (n * batch_size + number_of_threads - 1) / number_of_threads;
@@ -63,7 +64,7 @@ namespace ntt {
* @param arr_out buffer of the same size as `arr_in` on the GPU to write the bit-permuted array into.
*/
template <typename E>
void reverse_order(E* arr_in, uint32_t n, uint32_t logn, cudaStream_t stream, E* arr_out)
void reverse_order(const E* arr_in, uint32_t n, uint32_t logn, cudaStream_t stream, E* arr_out)
{
reverse_order_batch(arr_in, n, logn, 1, stream, arr_out);
}
@@ -81,7 +82,7 @@ namespace ntt {
*/
template <typename E, typename S>
__global__ void ntt_template_kernel_shared_rev(
E* __restrict__ arr_in,
const E* __restrict__ arr_in,
int n,
const S* __restrict__ r_twiddles,
int n_twiddles,
@@ -153,7 +154,7 @@ namespace ntt {
*/
template <typename E, typename S>
__global__ void ntt_template_kernel_shared(
E* __restrict__ arr_in,
const E* __restrict__ arr_in,
int n,
const S* __restrict__ r_twiddles,
int n_twiddles,
@@ -221,7 +222,7 @@ namespace ntt {
*/
template <typename E, typename S>
__global__ void
ntt_template_kernel(E* arr_in, int n, S* twiddles, int n_twiddles, int max_task, int s, bool rev, E* arr_out)
ntt_template_kernel(const E* arr_in, int n, S* twiddles, int n_twiddles, int max_task, int s, bool rev, E* arr_out)
{
int task = blockIdx.x;
int chunks = n / (blockDim.x * 2);
@@ -273,7 +274,7 @@ namespace ntt {
*/
template <typename E, typename S>
cudaError_t ntt_inplace_batch_template(
E* d_input,
const E* d_input,
int n,
S* d_twiddles,
int n_twiddles,
@@ -391,7 +392,7 @@ namespace ntt {
cudaError_t ReleaseDomain(device_context::DeviceContext& ctx);
template <typename U, typename E>
friend cudaError_t NTT<U, E>(E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
friend cudaError_t NTT<U, E>(const E* input, int size, NTTDir dir, NTTConfig<U>& config, E* output);
};
template <typename S>
@@ -516,12 +517,15 @@ namespace ntt {
static bool is_choose_radix2_algorithm(int logn, int batch_size, const NTTConfig<S>& config)
{
const bool is_mixed_radix_alg_supported = (logn > 3 && logn != 7);
if (!is_mixed_radix_alg_supported && config.columns_batch)
throw IcicleError(IcicleError_t::InvalidArgument, "columns batch is not supported for given NTT size");
const bool is_user_selected_radix2_alg = config.ntt_algorithm == NttAlgorithm::Radix2;
const bool is_force_radix2 = !is_mixed_radix_alg_supported || is_user_selected_radix2_alg;
if (is_force_radix2) return true;
const bool is_user_selected_mixed_radix_alg = config.ntt_algorithm == NttAlgorithm::MixedRadix;
if (is_user_selected_mixed_radix_alg) return false;
if (config.columns_batch) return false; // radix2 does not currently support columns batch mode.
// Heuristic to automatically select an algorithm
// Note that generally the decision depends on {logn, batch, ordering, inverse, coset, in-place, coeff-field} and
@@ -537,7 +541,7 @@ namespace ntt {
template <typename S, typename E>
cudaError_t radix2_ntt(
E* d_input,
const E* d_input,
E* d_output,
S* twiddles,
int ntt_size,
@@ -583,7 +587,7 @@ namespace ntt {
}
template <typename S, typename E>
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
cudaError_t NTT(const E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output)
{
CHK_INIT_IF_RETURN();
@@ -610,18 +614,22 @@ namespace ntt {
bool are_inputs_on_device = config.are_inputs_on_device;
bool are_outputs_on_device = config.are_outputs_on_device;
E* d_input;
const E* d_input;
E* d_allocated_input = nullptr;
if (are_inputs_on_device) {
d_input = input;
} else {
CHK_IF_RETURN(cudaMallocAsync(&d_input, input_size_bytes, stream));
CHK_IF_RETURN(cudaMemcpyAsync(d_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_input, input_size_bytes, stream));
CHK_IF_RETURN(cudaMemcpyAsync(d_allocated_input, input, input_size_bytes, cudaMemcpyHostToDevice, stream));
d_input = d_allocated_input;
}
E* d_output;
E* d_allocated_output = nullptr;
if (are_outputs_on_device) {
d_output = output;
} else {
CHK_IF_RETURN(cudaMallocAsync(&d_output, input_size_bytes, stream));
CHK_IF_RETURN(cudaMallocAsync(&d_allocated_output, input_size_bytes, stream));
d_output = d_allocated_output;
}
S* coset = nullptr;
@@ -641,37 +649,42 @@ namespace ntt {
h_coset.clear();
}
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
const bool is_inverse = dir == NTTDir::kInverse;
if (is_radix2_algorithm) {
if constexpr (std::is_same_v<E, curve_config::projective_t>) {
CHK_IF_RETURN(ntt::radix2_ntt(
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
coset_index, stream));
} else {
const bool is_on_coset = (coset_index != 0) || coset;
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
S* twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
: domain.twiddles;
S* internal_twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
: domain.internal_twiddles;
S* basic_twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
: domain.basic_twiddles;
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
const bool is_radix2_algorithm = is_choose_radix2_algorithm(logn, batch_size, config);
if (is_radix2_algorithm) {
CHK_IF_RETURN(ntt::radix2_ntt(
d_input, d_output, domain.twiddles, size, domain.max_size, batch_size, is_inverse, config.ordering, coset,
coset_index, stream));
} else {
const bool is_on_coset = (coset_index != 0) || coset;
const bool is_fast_twiddles_enabled = (domain.fast_external_twiddles != nullptr) && !is_on_coset;
S* twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_external_twiddles_inv : domain.fast_external_twiddles)
: domain.twiddles;
S* internal_twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_internal_twiddles_inv : domain.fast_internal_twiddles)
: domain.internal_twiddles;
S* basic_twiddles = is_fast_twiddles_enabled
? (is_inverse ? domain.fast_basic_twiddles_inv : domain.fast_basic_twiddles)
: domain.basic_twiddles;
CHK_IF_RETURN(ntt::mixed_radix_ntt(
d_input, d_output, twiddles, internal_twiddles, basic_twiddles, size, domain.max_log_size, batch_size,
config.columns_batch, is_inverse, is_fast_twiddles_enabled, config.ordering, coset, coset_index, stream));
}
}
if (!are_outputs_on_device)
CHK_IF_RETURN(cudaMemcpyAsync(output, d_output, input_size_bytes, cudaMemcpyDeviceToHost, stream));
if (coset) CHK_IF_RETURN(cudaFreeAsync(coset, stream));
if (!are_inputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_input, stream));
if (!are_outputs_on_device) CHK_IF_RETURN(cudaFreeAsync(d_output, stream));
if (d_allocated_input) CHK_IF_RETURN(cudaFreeAsync(d_allocated_input, stream));
if (d_allocated_output) CHK_IF_RETURN(cudaFreeAsync(d_allocated_output, stream));
if (!config.is_async) return CHK_STICKY(cudaStreamSynchronize(stream));
return CHK_LAST();
@@ -685,6 +698,7 @@ namespace ntt {
ctx, // ctx
S::one(), // coset_gen
1, // batch_size
false, // columns_batch
Ordering::kNN, // ordering
false, // are_inputs_on_device
false, // are_outputs_on_device
@@ -712,7 +726,7 @@ namespace ntt {
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, NTTCuda)(
curve_config::scalar_t* input,
const curve_config::scalar_t* input,
int size,
NTTDir dir,
NTTConfig<curve_config::scalar_t>& config,
@@ -731,7 +745,7 @@ namespace ntt {
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
extern "C" cudaError_t CONCAT_EXPAND(CURVE, ECNTTCuda)(
curve_config::projective_t* input,
const curve_config::projective_t* input,
int size,
NTTDir dir,
NTTConfig<curve_config::scalar_t>& config,

View File

@@ -95,6 +95,8 @@ namespace ntt {
S coset_gen; /**< Coset generator. Used to perform coset (i)NTTs. Default value: `S::one()`
* (corresponding to no coset being used). */
int batch_size; /**< The number of NTTs to compute. Default value: 1. */
bool columns_batch; /**< True if the batches are the columns of an input matrix
(they are strided in memory with a stride of the NTT size). Default value: false. */
Ordering ordering; /**< Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value:
* `Ordering::kNN`. */
bool are_inputs_on_device; /**< True if inputs are on device and false if they're on host. Default value: false. */
@@ -132,7 +134,7 @@ namespace ntt {
* @return `cudaSuccess` if the execution was successful and an error code otherwise.
*/
template <typename S, typename E>
cudaError_t NTT(E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output);
cudaError_t NTT(const E* input, int size, NTTDir dir, NTTConfig<S>& config, E* output);
} // namespace ntt
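A hedged usage sketch of the extended config (not taken from the diff; the `GetDefaultNTTConfig` helper and the header path are assumptions, while the field and enum names come from the declarations above):

```cpp
#include "ntt/ntt.cuh"  // assumed header name

// Forward-NTT 16 columns of an (ntt_size x 16) device matrix that is only
// reachable through a const pointer, which the new signature accepts.
void columns_batch_ntt(const curve_config::scalar_t* d_columns, int ntt_size,
                       curve_config::scalar_t* d_out)
{
  auto config = ntt::GetDefaultNTTConfig<curve_config::scalar_t>(); // assumed helper
  config.batch_size = 16;
  config.columns_batch = true;            // batches are matrix columns: consecutive elements of a column sit 16 apart
  config.are_inputs_on_device = true;
  config.are_outputs_on_device = true;
  ntt::NTT<curve_config::scalar_t, curve_config::scalar_t>(
    d_columns, ntt_size, ntt::NTTDir::kForward, config, d_out);
}
```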

View File

@@ -27,7 +27,7 @@ namespace ntt {
template <typename E, typename S>
cudaError_t mixed_radix_ntt(
E* d_input,
const E* d_input,
E* d_output,
S* external_twiddles,
S* internal_twiddles,
@@ -35,6 +35,7 @@ namespace ntt {
int ntt_size,
int max_logn,
int batch_size,
bool columns_batch,
bool is_inverse,
bool fast_tw,
Ordering ordering,

View File

@@ -29,6 +29,13 @@ void incremental_values(test_scalar* res, uint32_t count)
}
}
__global__ void transpose_batch(test_scalar* in, test_scalar* out, int row_size, int column_size)
{
int tid = blockDim.x * blockIdx.x + threadIdx.x;
if (tid >= row_size * column_size) return;
out[(tid % row_size) * column_size + (tid / row_size)] = in[tid];
}
int main(int argc, char** argv)
{
cudaEvent_t icicle_start, icicle_stop, new_start, new_stop;
@@ -37,11 +44,12 @@ int main(int argc, char** argv)
int NTT_LOG_SIZE = (argc > 1) ? atoi(argv[1]) : 19;
int NTT_SIZE = 1 << NTT_LOG_SIZE;
bool INPLACE = (argc > 2) ? atoi(argv[2]) : false;
int INV = (argc > 3) ? atoi(argv[3]) : true;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 1;
int COSET_IDX = (argc > 5) ? atoi(argv[5]) : 0;
const ntt::Ordering ordering = (argc > 6) ? ntt::Ordering(atoi(argv[6])) : ntt::Ordering::kNN;
bool FAST_TW = (argc > 7) ? atoi(argv[7]) : true;
int INV = (argc > 3) ? atoi(argv[3]) : false;
int BATCH_SIZE = (argc > 4) ? atoi(argv[4]) : 150;
bool COLUMNS_BATCH = (argc > 5) ? atoi(argv[5]) : false;
int COSET_IDX = (argc > 6) ? atoi(argv[6]) : 2;
const ntt::Ordering ordering = (argc > 7) ? ntt::Ordering(atoi(argv[7])) : ntt::Ordering::kNN;
bool FAST_TW = (argc > 8) ? atoi(argv[8]) : true;
// Note: NM, MN are not expected to be equal when comparing mixed-radix and radix-2 NTTs
const char* ordering_str = ordering == ntt::Ordering::kNN ? "NN"
@@ -52,8 +60,8 @@ int main(int argc, char** argv)
: "MN";
printf(
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, coset-idx=%d, ordering=%s, fast_tw=%d\n", NTT_LOG_SIZE,
INPLACE, INV, BATCH_SIZE, COSET_IDX, ordering_str, FAST_TW);
"running ntt 2^%d, inplace=%d, inverse=%d, batch_size=%d, columns_batch=%d coset-idx=%d, ordering=%s, fast_tw=%d\n",
NTT_LOG_SIZE, INPLACE, INV, BATCH_SIZE, COLUMNS_BATCH, COSET_IDX, ordering_str, FAST_TW);
CHK_IF_RETURN(cudaFree(nullptr)); // init GPU context (warmup)
@@ -63,6 +71,7 @@ int main(int argc, char** argv)
ntt_config.are_inputs_on_device = true;
ntt_config.are_outputs_on_device = true;
ntt_config.batch_size = BATCH_SIZE;
ntt_config.columns_batch = COLUMNS_BATCH;
CHK_IF_RETURN(cudaEventCreate(&icicle_start));
CHK_IF_RETURN(cudaEventCreate(&icicle_stop));
@@ -83,7 +92,9 @@ int main(int argc, char** argv)
// gpu allocation
test_data *GpuScalars, *GpuOutputOld, *GpuOutputNew;
test_data* GpuScalarsTransposed;
CHK_IF_RETURN(cudaMalloc(&GpuScalars, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuScalarsTransposed, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputOld, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
CHK_IF_RETURN(cudaMalloc(&GpuOutputNew, sizeof(test_data) * NTT_SIZE * BATCH_SIZE));
@@ -93,10 +104,16 @@ int main(int argc, char** argv)
CHK_IF_RETURN(
cudaMemcpy(GpuScalars, CpuScalars.get(), NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyHostToDevice));
if (COLUMNS_BATCH) {
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
GpuScalars, GpuScalarsTransposed, NTT_SIZE, BATCH_SIZE);
}
// inplace
if (INPLACE) {
CHK_IF_RETURN(
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
CHK_IF_RETURN(cudaMemcpy(
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
cudaMemcpyDeviceToDevice));
}
for (int coset_idx = 0; coset_idx < COSET_IDX; ++coset_idx) {
@@ -109,13 +126,14 @@ int main(int argc, char** argv)
ntt_config.ntt_algorithm = ntt::NttAlgorithm::MixedRadix;
for (size_t i = 0; i < iterations; i++) {
CHK_IF_RETURN(ntt::NTT(
INPLACE ? GpuOutputNew : GpuScalars, NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config,
GpuOutputNew));
INPLACE ? GpuOutputNew
: COLUMNS_BATCH ? GpuScalarsTransposed
: GpuScalars,
NTT_SIZE, INV ? ntt::NTTDir::kInverse : ntt::NTTDir::kForward, ntt_config, GpuOutputNew));
}
CHK_IF_RETURN(cudaEventRecord(new_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
CHK_IF_RETURN(cudaEventElapsedTime(&new_time, new_start, new_stop));
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
// OLD
CHK_IF_RETURN(cudaEventRecord(icicle_start, ntt_config.ctx.stream));
@@ -127,7 +145,6 @@ int main(int argc, char** argv)
CHK_IF_RETURN(cudaEventRecord(icicle_stop, ntt_config.ctx.stream));
CHK_IF_RETURN(cudaStreamSynchronize(ntt_config.ctx.stream));
CHK_IF_RETURN(cudaEventElapsedTime(&icicle_time, icicle_start, icicle_stop));
if (is_print) { fprintf(stderr, "cuda err %d\n", cudaGetLastError()); }
if (is_print) {
printf("Old Runtime=%0.3f MS\n", icicle_time / iterations);
@@ -140,11 +157,19 @@ int main(int argc, char** argv)
CHK_IF_RETURN(benchmark(false /*=print*/, 1)); // warmup
int count = INPLACE ? 1 : 10;
if (INPLACE) {
CHK_IF_RETURN(
cudaMemcpy(GpuOutputNew, GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
CHK_IF_RETURN(cudaMemcpy(
GpuOutputNew, COLUMNS_BATCH ? GpuScalarsTransposed : GpuScalars, NTT_SIZE * BATCH_SIZE * sizeof(test_data),
cudaMemcpyDeviceToDevice));
}
CHK_IF_RETURN(benchmark(true /*=print*/, count));
if (COLUMNS_BATCH) {
transpose_batch<<<(NTT_SIZE * BATCH_SIZE + 256 - 1) / 256, 256>>>(
GpuOutputNew, GpuScalarsTransposed, BATCH_SIZE, NTT_SIZE);
CHK_IF_RETURN(cudaMemcpy(
GpuOutputNew, GpuScalarsTransposed, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToDevice));
}
// verify
CHK_IF_RETURN(
cudaMemcpy(CpuOutputNew.get(), GpuOutputNew, NTT_SIZE * BATCH_SIZE * sizeof(test_data), cudaMemcpyDeviceToHost));
@@ -153,10 +178,11 @@ int main(int argc, char** argv)
bool success = true;
for (int i = 0; i < NTT_SIZE * BATCH_SIZE; i++) {
// if (i%64==0) printf("\n");
if (CpuOutputNew[i] != CpuOutputOld[i]) {
success = false;
// std::cout << i << " ref " << CpuOutputOld[i] << " != " << CpuOutputNew[i] << std::endl;
break;
// break;
} else {
// std::cout << i << " ref " << CpuOutputOld[i] << " == " << CpuOutputNew[i] << std::endl;
// break;

View File

@@ -9,6 +9,7 @@
struct stage_metadata {
uint32_t th_stride;
uint32_t ntt_block_size;
uint32_t batch_id;
uint32_t ntt_block_id;
uint32_t ntt_inp_id;
};
@@ -118,7 +119,7 @@ public:
}
__device__ __forceinline__ void
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
loadExternalTwiddles64(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id + (s_meta.ntt_block_id & (tw_order - 1));
@@ -129,7 +130,7 @@ public:
}
__device__ __forceinline__ void
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
loadExternalTwiddles32(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id * 2 + (s_meta.ntt_block_id & (tw_order - 1));
@@ -143,7 +144,7 @@ public:
}
__device__ __forceinline__ void
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, bool strided, stage_metadata s_meta)
loadExternalTwiddles16(S* data, uint32_t tw_order, uint32_t tw_log_order, stage_metadata s_meta)
{
data += tw_order * s_meta.ntt_inp_id * 4 + (s_meta.ntt_block_id & (tw_order - 1));
@@ -195,8 +196,8 @@ public:
}
}
__device__ __forceinline__ void loadGlobalData(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void
loadGlobalData(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
@@ -211,8 +212,22 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void loadGlobalDataColumnBatch(
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
X[i] = data[s_meta.th_stride * i * data_stride * batch_size];
}
}
__device__ __forceinline__ void
storeGlobalData(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
@@ -227,8 +242,22 @@ public:
}
}
__device__ __forceinline__ void loadGlobalData32(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void storeGlobalDataColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t i = 0; i < 8; i++) {
data[s_meta.th_stride * i * data_stride * batch_size] = X[i];
}
}
__device__ __forceinline__ void
loadGlobalData32(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
@@ -246,8 +275,25 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData32(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void loadGlobalData32ColumnBatch(
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
X[4 * j + i] = data[(8 * i + j) * data_stride * batch_size];
}
}
}
__device__ __forceinline__ void
storeGlobalData32(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
@@ -265,8 +311,25 @@ public:
}
}
__device__ __forceinline__ void loadGlobalData16(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void storeGlobalData32ColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 2 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 2; j++) {
#pragma unroll
for (uint32_t i = 0; i < 4; i++) {
data[(8 * i + j) * data_stride * batch_size] = X[4 * j + i];
}
}
}
__device__ __forceinline__ void
loadGlobalData16(const E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
@@ -284,8 +347,25 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData16(
E* data, uint32_t data_stride, uint32_t log_data_stride, uint32_t log_size, bool strided, stage_metadata s_meta)
__device__ __forceinline__ void loadGlobalData16ColumnBatch(
const E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
X[2 * j + i] = data[(8 * i + j) * data_stride * batch_size];
}
}
}
__device__ __forceinline__ void
storeGlobalData16(E* data, uint32_t data_stride, uint32_t log_data_stride, bool strided, stage_metadata s_meta)
{
if (strided) {
data += (s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
@@ -303,6 +383,23 @@ public:
}
}
__device__ __forceinline__ void storeGlobalData16ColumnBatch(
E* data, uint32_t data_stride, uint32_t log_data_stride, stage_metadata s_meta, uint32_t batch_size)
{
data += ((s_meta.ntt_block_id & (data_stride - 1)) + data_stride * s_meta.ntt_inp_id * 4 +
(s_meta.ntt_block_id >> log_data_stride) * data_stride * s_meta.ntt_block_size) *
batch_size +
s_meta.batch_id;
#pragma unroll
for (uint32_t j = 0; j < 4; j++) {
#pragma unroll
for (uint32_t i = 0; i < 2; i++) {
data[(8 * i + j) * data_stride * batch_size] = X[2 * j + i];
}
}
}
__device__ __forceinline__ void ntt4_2()
{
#pragma unroll

View File

@@ -3,8 +3,11 @@
#define BW6_761_PARAMS_H
#include "utils/storage.cuh"
#include "bls12_377_params.cuh"
namespace bw6_761 {
typedef bls12_377::fq_config fp_config;
struct fq_config {
static constexpr unsigned limbs_count = 24;
static constexpr unsigned modulus_bit_count = 761;

View File

@@ -24,7 +24,6 @@ using namespace bls12_381;
#include "bls12_377_params.cuh"
using namespace bls12_377;
#elif CURVE_ID == BW6_761
#include "bls12_377_params.cuh"
#include "bw6_761_params.cuh"
using namespace bw6_761;
#elif CURVE_ID == GRUMPKIN
@@ -39,10 +38,6 @@ using namespace grumpkin;
* with the `-DCURVE` env variable passed during build.
*/
namespace curve_config {
#if CURVE_ID == BW6_761
typedef bls12_377::fq_config fp_config;
#endif
/**
* Scalar field of the curve. Is always a prime field.
*/

View File

@@ -33,4 +33,4 @@ public:
os << "x: " << point.x << "; y: " << point.y;
return os;
}
};
};

View File

@@ -2,7 +2,7 @@
#include "field.cuh"
#include "utils/utils.h"
#define scalar_t curve_config::scalar_t
using namespace curve_config;
extern "C" void CONCAT_EXPAND(CURVE, GenerateScalars)(scalar_t* scalars, int size)
{

View File

@@ -680,7 +680,7 @@ public:
HOST_DEVICE_INLINE uint32_t* export_limbs() { return (uint32_t*)limbs_storage.limbs; }
HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width)
HOST_DEVICE_INLINE unsigned get_scalar_digit(unsigned digit_num, unsigned digit_width) const
{
const uint32_t limb_lsb_idx = (digit_num * digit_width) / 32;
const uint32_t shift_bits = (digit_num * digit_width) % 32;

View File

@@ -8,6 +8,9 @@ class Projective
friend Affine<FF>;
public:
static constexpr unsigned SCALAR_FF_NBITS = SCALAR_FF::NBITS;
static constexpr unsigned FF_NBITS = FF::NBITS;
FF x;
FF y;
FF z;
@@ -36,6 +39,34 @@ public:
static HOST_DEVICE_INLINE Projective neg(const Projective& point) { return {point.x, FF::neg(point.y), point.z}; }
static HOST_DEVICE_INLINE Projective dbl(const Projective& point)
{
const FF X = point.x;
const FF Y = point.y;
const FF Z = point.z;
// TODO: Change to efficient dbl once implemented for field.cuh
FF t0 = FF::sqr(Y); // 1. t0 ← Y · Y
FF Z3 = t0 + t0; // 2. Z3 ← t0 + t0
Z3 = Z3 + Z3; // 3. Z3 ← Z3 + Z3
Z3 = Z3 + Z3; // 4. Z3 ← Z3 + Z3
FF t1 = Y * Z; // 5. t1 ← Y · Z
FF t2 = FF::sqr(Z); // 6. t2 ← Z · Z
t2 = FF::template mul_unsigned<3>(FF::template mul_const<B_VALUE>(t2)); // 7. t2 ← b3 · t2
FF X3 = t2 * Z3; // 8. X3 ← t2 · Z3
FF Y3 = t0 + t2; // 9. Y3 ← t0 + t2
Z3 = t1 * Z3; // 10. Z3 ← t1 · Z3
t1 = t2 + t2; // 11. t1 ← t2 + t2
t2 = t1 + t2; // 12. t2 ← t1 + t2
t0 = t0 - t2; // 13. t0 ← t0 − t2
Y3 = t0 * Y3; // 14. Y3 ← t0 · Y3
Y3 = X3 + Y3; // 15. Y3 ← X3 + Y3
t1 = X * Y; // 16. t1 ← X · Y
X3 = t0 * t1; // 17. X3 ← t0 · t1
X3 = X3 + X3; // 18. X3 ← X3 + X3
return {X3, Y3, Z3};
}
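Unrolling this sequence (a hedged summary, reading `B_VALUE` as the constant $b$ of $y^2 = x^3 + b$ so that step 7 computes $b_3 = 3b$) gives the usual complete-doubling coordinates:

$$X_3 = 2XY\,(Y^2 - 3 b_3 Z^2), \qquad Y_3 = (Y^2 - 3 b_3 Z^2)(Y^2 + b_3 Z^2) + 8 b_3 Y^2 Z^2, \qquad Z_3 = 8 Y^3 Z.$$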
friend HOST_DEVICE_INLINE Projective operator+(Projective p1, const Projective& p2)
{
const FF X1 = p1.x; // < 2

View File

@@ -9,14 +9,14 @@ namespace mont {
#define MAX_THREADS_PER_BLOCK 256
template <typename E, bool is_into>
__global__ void MontgomeryKernel(E* input, int n, E* output)
__global__ void MontgomeryKernel(const E* input, int n, E* output)
{
int tid = blockIdx.x * blockDim.x + threadIdx.x;
if (tid < n) { output[tid] = is_into ? E::ToMontgomery(input[tid]) : E::FromMontgomery(input[tid]); }
}
template <typename E, bool is_into>
cudaError_t ConvertMontgomery(E* d_input, int n, cudaStream_t stream, E* d_output)
cudaError_t ConvertMontgomery(const E* d_input, int n, cudaStream_t stream, E* d_output)
{
// Set the grid and block dimensions
int num_threads = MAX_THREADS_PER_BLOCK;
@@ -29,13 +29,13 @@ namespace mont {
} // namespace
template <typename E>
cudaError_t ToMontgomery(E* d_input, int n, cudaStream_t stream, E* d_output)
cudaError_t ToMontgomery(const E* d_input, int n, cudaStream_t stream, E* d_output)
{
return ConvertMontgomery<E, true>(d_input, n, stream, d_output);
}
template <typename E>
cudaError_t FromMontgomery(E* d_input, int n, cudaStream_t stream, E* d_output)
cudaError_t FromMontgomery(const E* d_input, int n, cudaStream_t stream, E* d_output)
{
return ConvertMontgomery<E, false>(d_input, n, stream, d_output);
}

View File

@@ -12,7 +12,7 @@ namespace utils_internal {
template <typename E, typename S>
__global__ void BatchMulKernel(
E* in_vec,
const E* in_vec,
int n_elements,
int batch_size,
S* scalar_vec,

View File

@@ -95,7 +95,6 @@ namespace vec_ops {
*/
template <typename E>
cudaError_t Sub(E* vec_a, E* vec_b, int n, VecOpsConfig<E>& config, E* result);
} // namespace vec_ops
#endif

View File

@@ -1,12 +1,18 @@
#!/bin/bash
G2_DEFINED=OFF
ECNTT_DEFINED=OFF
if [[ $2 ]]
if [[ $2 == "ON" ]]
then
G2_DEFINED=ON
fi
if [[ $3 ]]
then
ECNTT_DEFINED=ON
fi
BUILD_DIR=$(realpath "$PWD/../../icicle/build")
SUPPORTED_CURVES=("bn254" "bls12_377" "bls12_381" "bw6_761")
@@ -22,6 +28,6 @@ mkdir -p build
for CURVE in "${BUILD_CURVES[@]}"
do
cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build
cmake --build build
cmake -DCURVE=$CURVE -DG2_DEFINED=$G2_DEFINED -DECNTT_DEFINED=$ECNTT_DEFINED -DCMAKE_BUILD_TYPE=Release -S . -B build
cmake --build build -j8
done
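A hedged invocation example (the script name and the meaning of the first argument are assumptions; the checks above require `ON` for the second argument and any non-empty value for the third): `./build.sh bn254 ON ON` would build the bn254 bindings with G2 enabled and the new EC-NTT kernels compiled in.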

View File

@@ -76,7 +76,7 @@ func GetDefaultMSMConfig() MSMConfig {
}
func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfig, results HostOrDeviceSlice) {
scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len(), results.Len()
scalarsLength, pointsLength, resultsLength := scalars.Len(), points.Len()/int(cfg.PrecomputeFactor), results.Len()
if scalarsLength%pointsLength != 0 {
errorString := fmt.Sprintf(
"Number of points %d does not divide the number of scalars %d",
@@ -99,3 +99,15 @@ func MsmCheck(scalars HostOrDeviceSlice, points HostOrDeviceSlice, cfg *MSMConfi
cfg.arePointsOnDevice = points.IsOnDevice()
cfg.areResultsOnDevice = results.IsOnDevice()
}
func PrecomputeBasesCheck(points HostOrDeviceSlice, precomputeFactor int32, outputBases DeviceSlice) {
outputBasesLength, pointsLength := outputBases.Len(), points.Len()
if outputBasesLength != pointsLength*int(precomputeFactor) {
errorString := fmt.Sprintf(
"Precompute factor is probably incorrect: expected %d but got %d",
outputBasesLength/pointsLength,
precomputeFactor,
)
panic(errorString)
}
}
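A worked example of the new bookkeeping: with `cfg.PrecomputeFactor = 4` and 1024 scalars, the points slice passed to `MsmCheck` is expected to hold 4·1024 = 4096 entries, which the division brings back to an effective `pointsLength` of 1024 before the divisibility check; `PrecomputeBasesCheck` verifies the same relation from the other side, panicking unless `outputBases.Len() == points.Len() * 4`.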

View File

@@ -10,18 +10,26 @@ type NTTDir int8
const (
KForward NTTDir = iota
KInverse NTTDir = 1
KInverse
)
type Ordering uint32
const (
KNN Ordering = iota
KNR Ordering = 1
KRN Ordering = 2
KRR Ordering = 3
KNM Ordering = 4
KMN Ordering = 5
KNR
KRN
KRR
KNM
KMN
)
type NttAlgorithm uint32
const (
Auto NttAlgorithm = iota
Radix2
MixedRadix
)
type NTTConfig[T any] struct {
@@ -31,13 +39,17 @@ type NTTConfig[T any] struct {
CosetGen T
/// The number of NTTs to compute. Default value: 1.
BatchSize int32
/// If true, the function will compute the NTTs over the columns of the input matrix and not over the rows.
ColumnsBatch bool
/// Ordering of inputs and outputs. See [Ordering](@ref Ordering). Default value: `Ordering::kNN`.
Ordering Ordering
areInputsOnDevice bool
areOutputsOnDevice bool
/// Whether to run the NTT asynchronously. If set to `true`, the NTT function will be non-blocking and you'd need to synchronize
/// it explicitly by running `stream.synchronize()`. If set to false, the NTT function will block the current CPU thread.
IsAsync bool
IsAsync bool
NttAlgorithm NttAlgorithm /**< Explicitly select the NTT algorithm. Default value: Auto (the implementation
selects radix-2 or mixed-radix algorithm based on heuristics). */
}
func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
@@ -46,10 +58,12 @@ func GetDefaultNTTConfig[T any](cosetGen T) NTTConfig[T] {
ctx, // Ctx
cosetGen, // CosetGen
1, // BatchSize
false, // ColumnsBatch
KNN, // Ordering
false, // areInputsOnDevice
false, // areOutputsOnDevice
false, // IsAsync
Auto,
}
}
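A hedged Go sketch of enabling the new fields on top of the default config (field and constant names are those defined above; the helper itself is illustrative):

```go
// columnsBatchConfig returns an NTTConfig set up for 16 column-major batches.
func columnsBatchConfig[T any](cosetGen T) NTTConfig[T] {
	cfg := GetDefaultNTTConfig(cosetGen)
	cfg.BatchSize = 16
	cfg.ColumnsBatch = true       // batches are the columns of an (ntt_size x 16) matrix
	cfg.NttAlgorithm = MixedRadix // columns batch is only served by the mixed-radix path
	return cfg
}
```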

View File

@@ -19,10 +19,12 @@ func TestNTTDefaultConfig(t *testing.T) {
ctx, // Ctx
cosetGen, // CosetGen
1, // BatchSize
false, // ColumnsBatch
KNN, // Ordering
false, // areInputsOnDevice
false, // areOutputsOnDevice
false, // IsAsync
Auto, // NttAlgorithm
}
actual := GetDefaultNTTConfig(cosetGen)

View File

@@ -43,8 +43,17 @@ func (d DeviceSlice) IsOnDevice() bool {
return true
}
// TODO: change signature to be Malloc(element, numElements)
// calc size internally
func (d DeviceSlice) GetDeviceId() int {
return cr.GetDeviceFromPointer(d.inner)
}
// CheckDevice is used to ensure that the DeviceSlice about to be used resides on the currently set device
func (d DeviceSlice) CheckDevice() {
if currentDeviceId, err := cr.GetDevice(); err != cr.CudaSuccess || d.GetDeviceId() != currentDeviceId {
panic("Attempt to use DeviceSlice on a different device")
}
}
func (d *DeviceSlice) Malloc(size, sizeOfElement int) (DeviceSlice, cr.CudaError) {
dp, err := cr.Malloc(uint(size))
d.inner = dp
@@ -62,6 +71,7 @@ func (d *DeviceSlice) MallocAsync(size, sizeOfElement int, stream cr.CudaStream)
}
func (d *DeviceSlice) Free() cr.CudaError {
d.CheckDevice()
err := cr.Free(d.inner)
if err == cr.CudaSuccess {
d.length, d.capacity = 0, 0
@@ -70,6 +80,16 @@ func (d *DeviceSlice) Free() cr.CudaError {
return err
}
func (d *DeviceSlice) FreeAsync(stream cr.Stream) cr.CudaError {
d.CheckDevice()
err := cr.FreeAsync(d.inner, stream)
if err == cr.CudaSuccess {
d.length, d.capacity = 0, 0
d.inner = nil
}
return err
}
type HostSliceInterface interface {
Size() int
}
@@ -117,6 +137,7 @@ func (h HostSlice[T]) CopyToDevice(dst *DeviceSlice, shouldAllocate bool) *Devic
if shouldAllocate {
dst.Malloc(size, h.SizeOfElement())
}
dst.CheckDevice()
if size > dst.Cap() {
panic("Number of bytes to copy is too large for destination")
}
@@ -133,6 +154,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream,
if shouldAllocate {
dst.MallocAsync(size, h.SizeOfElement(), stream)
}
dst.CheckDevice()
if size > dst.Cap() {
panic("Number of bytes to copy is too large for destination")
}
@@ -144,6 +166,7 @@ func (h HostSlice[T]) CopyToDeviceAsync(dst *DeviceSlice, stream cr.CudaStream,
}
func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) {
src.CheckDevice()
if h.Len() != src.Len() {
panic("destination and source slices have different lengths")
}
@@ -152,6 +175,7 @@ func (h HostSlice[T]) CopyFromDevice(src *DeviceSlice) {
}
func (h HostSlice[T]) CopyFromDeviceAsync(src *DeviceSlice, stream cr.Stream) {
src.CheckDevice()
if h.Len() != src.Len() {
panic("destination and source slices have different lengths")
}
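
The CheckDevice calls threaded through the allocation, free and copy paths above mean a DeviceSlice is only usable from the device it was allocated on; touching it after switching devices panics instead of silently corrupting memory. A sketch of the intended single-device round trip (package and the GenerateScalars/ScalarField names are the bls12-377 ones later in this diff; the helper is hypothetical):

package bls12377

import (
	"github.com/ingonyama-zk/icicle/wrappers/golang/core"
)

// roundTripThroughDevice is a hypothetical helper, not part of the bindings.
func roundTripThroughDevice() {
	scalars := GenerateScalars(1 << 10)

	// CopyToDevice with shouldAllocate=true mallocs on the currently set device;
	// from then on the slice passes CheckDevice only while that device is active.
	var onDevice core.DeviceSlice
	scalars.CopyToDevice(&onDevice, true)

	back := make(core.HostSlice[ScalarField], scalars.Len())
	back.CopyFromDevice(&onDevice) // would panic if the current device changed in between

	onDevice.Free()
}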

View File

@@ -9,6 +9,8 @@ package cuda_runtime
*/
import "C"
import (
"fmt"
"runtime"
"unsafe"
)
@@ -17,20 +19,28 @@ type DeviceContext struct {
Stream *Stream // Assuming the type is provided by a CUDA binding library
/// Index of the currently used GPU. Default value: 0.
DeviceId uint
deviceId uint
/// Mempool to use. Default value: 0.
// TODO: use cuda_bindings.CudaMemPool as type
Mempool uint // Assuming the type is provided by a CUDA binding crate
Mempool MemPool // Assuming the type is provided by a CUDA binding library
}
func (d DeviceContext) GetDeviceId() int {
return int(d.deviceId)
}
func GetDefaultDeviceContext() (DeviceContext, CudaError) {
device, err := GetDevice()
if err != CudaSuccess {
panic(fmt.Sprintf("Could not get current device due to %v", err))
}
var defaultStream Stream
var defaultMempool MemPool
return DeviceContext{
&defaultStream,
0,
0,
uint(device),
defaultMempool,
}, CudaSuccess
}
@@ -47,3 +57,78 @@ func GetDeviceCount() (int, CudaError) {
err := C.cudaGetDeviceCount(cCount)
return count, (CudaError)(err)
}
func GetDevice() (int, CudaError) {
var device int
cDevice := (*C.int)(unsafe.Pointer(&device))
err := C.cudaGetDevice(cDevice)
return device, (CudaError)(err)
}
func GetDeviceFromPointer(ptr unsafe.Pointer) int {
var cCudaPointerAttributes CudaPointerAttributes
err := C.cudaPointerGetAttributes(&cCudaPointerAttributes, ptr)
if (CudaError)(err) != CudaSuccess {
panic("Could not get attributes of pointer")
}
return int(cCudaPointerAttributes.device)
}
// RunOnDevice forces all GPU-related calls within the provided function to run
// on the same host thread and therefore on the same GPU device.
//
// NOTE: Goroutines launched within funcToRun are not bound to the
// same host thread as funcToRun and therefore not to the same GPU device.
// If that is a requirement, RunOnDevice should be called for each such
// goroutine with the same deviceId as the original call.
//
// As an example:
//
// cr.RunOnDevice(i, func(args ...any) {
// defer wg.Done()
// cfg := GetDefaultMSMConfig()
// stream, _ := cr.CreateStream()
// for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
// size := 1 << power
//
// // This will always print "Inner goroutine device: 0"
// // go func () {
// // device, _ := cr.GetDevice()
// // fmt.Println("Inner goroutine device: ", device)
// // }()
// // To force the above goroutine onto the same device as the wrapping function:
// // RunOnDevice(i, func(arg ...any) {
// // device, _ := cr.GetDevice()
// // fmt.Println("Inner goroutine device: ", device)
// // })
//
// scalars := GenerateScalars(size)
// points := GenerateAffinePoints(size)
//
// var p Projective
// var out core.DeviceSlice
// _, e := out.MallocAsync(p.Size(), p.Size(), stream)
// assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
// cfg.Ctx.Stream = &stream
// cfg.IsAsync = true
//
// e = Msm(scalars, points, &cfg, out)
// assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
//
// outHost := make(core.HostSlice[Projective], 1)
//
// cr.SynchronizeStream(&stream)
// outHost.CopyFromDevice(&out)
// out.Free()
// // Check with gnark-crypto
// assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
// }
// }, i)
func RunOnDevice(deviceId int, funcToRun func(args ...any), args ...any) {
go func(id int) {
defer runtime.UnlockOSThread()
runtime.LockOSThread()
SetDevice(id)
funcToRun(args...)
}(deviceId)
}
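
Because RunOnDevice only launches a goroutine and returns immediately, callers that need to know when the work is finished have to add their own synchronization; the multi-device tests later in this diff do it with a sync.WaitGroup. A condensed sketch of that pattern, doing nothing GPU-specific beyond reporting the active device:

package main

import (
	"fmt"
	"sync"

	cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
)

func main() {
	numDevices, _ := cr.GetDeviceCount()
	wg := sync.WaitGroup{}
	for i := 0; i < numDevices; i++ {
		wg.Add(1)
		cr.RunOnDevice(i, func(args ...any) {
			defer wg.Done()
			// Every CUDA call made here runs on the device passed as args[0].
			device, _ := cr.GetDevice()
			fmt.Println("running on device", device, "requested", args[0])
		}, i)
	}
	wg.Wait() // RunOnDevice itself does not block
}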

View File

@@ -12,6 +12,8 @@ import (
"unsafe"
)
type MemPool = CudaMemPool
func Malloc(size uint) (unsafe.Pointer, CudaError) {
if size == 0 {
return nil, CudaErrorMemoryAllocation

View File

@@ -17,3 +17,6 @@ type CudaEvent C.cudaEvent_t
// CudaMemPool as declared in include/driver_types.h:2928
type CudaMemPool C.cudaMemPool_t
// CudaPointerAttributes as declared in include/driver_types.h
type CudaPointerAttributes = C.struct_cudaPointerAttributes

View File

@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
}
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, true)
}
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, false)
}
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
}
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, true)
}
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, false)
}

View File

@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
}
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, true)
}
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, false)
}
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
}
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, true)
}
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, false)
}

View File

@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
}
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
__ret := C.bls12_377G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
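
Putting the two new entry points together, the intended flow is: allocate the enlarged base buffer, run G2PrecomputeBases once, then hand the pre-computed bases to the MSM with cfg.PrecomputeFactor set to the same factor. A sketch mirroring the TestPrecomputeBaseG2 test further down (single result, hypothetical helper, error handling trimmed):

package bls12377

import (
	"github.com/ingonyama-zk/icicle/wrappers/golang/core"
)

// msmWithPrecomputedBases is a hypothetical helper, not part of the bindings.
func msmWithPrecomputedBases(scalars core.HostSlice[ScalarField], points core.HostSlice[G2Affine]) G2Projective {
	const precomputeFactor = 8
	cfg := GetDefaultMSMConfig()

	var precomputeOut core.DeviceSlice
	precomputeOut.Malloc(points[0].Size()*points.Len()*precomputeFactor, points[0].Size())
	G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)

	var p G2Projective
	var out core.DeviceSlice
	out.Malloc(p.Size(), p.Size())

	cfg.PrecomputeFactor = precomputeFactor // must match the factor used above
	G2Msm(scalars, precomputeOut, &cfg, out)

	outHost := make(core.HostSlice[G2Projective], 1)
	outHost.CopyFromDevice(&out)
	out.Free()
	precomputeOut.Free()
	return outHost[0]
}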

View File

@@ -3,9 +3,12 @@
package bls12377
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
func TestMSMG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
@@ -136,6 +142,48 @@ func TestMSMG2Batch(t *testing.T) {
}
}
func TestPrecomputeBaseG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := G2GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p G2Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = G2Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMG2SkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -165,3 +213,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
}
func TestMSMG2MultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := G2GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p G2Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BLS12_377_G2MSM_H
#define _BLS12_377_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bls12_377G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bls12_377G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bls12_377G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);
#ifdef __cplusplus
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BLS12_377_MSM_H
#define _BLS12_377_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bls12_377MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bls12_377MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bls12_377PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);
#ifdef __cplusplus
}

View File

@@ -9,11 +9,12 @@
extern "C" {
#endif
cudaError_t bls12_377NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_377NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_377ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bls12_377InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
#ifdef __cplusplus
}
#endif
#endif
#endif

View File

@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}
func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.affine_t)(outputBasesPointer)
__ret := C.bls12_377PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}

View File

@@ -1,9 +1,12 @@
package bls12377
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
func TestMSM(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
}
}
func TestPrecomputeBase(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMSkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
}
func TestMSMMultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

View File

@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)
__ret := C.bls12_377ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
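
Note that ECNtt derives the per-NTT size by dividing points.Len() by cfg.BatchSize, so the input length must be an exact multiple of the batch size. A sketch of a batched call (domain initialization is assumed to have been done already, as the TestECNtt test below does via its initDomain helper; the helper here is hypothetical):

package bls12377

import (
	"github.com/ingonyama-zk/icicle/wrappers/golang/core"
)

// batchedECNtt is a hypothetical helper, not part of the bindings.
func batchedECNtt() core.IcicleError {
	cfg := GetDefaultNttConfig()
	cfg.BatchSize = 4
	cfg.NttAlgorithm = core.Radix2 // the algorithm the ECNTT tests in this diff select

	const nttSize = 1 << 6
	// The total number of points must be a multiple of BatchSize.
	points := GenerateProjectivePoints(nttSize * int(cfg.BatchSize))
	pointsHost := core.HostSliceFromElements[Projective](points)

	output := make(core.HostSlice[Projective], pointsHost.Len())
	return ECNtt(pointsHost, core.KForward, &cfg, output)
}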
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

View File

@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}
func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)
for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2
output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}
func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

View File

@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}

View File

@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t
if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}
if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}
if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}

View File

@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
}
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, true)
}
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, false)
}
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
}
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, true)
}
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, false)
}

View File

@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
}
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, true)
}
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, false)
}
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
}
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, true)
}
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, false)
}

View File

@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
}
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
__ret := C.bls12_381G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}

View File

@@ -3,9 +3,12 @@
package bls12381
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
func TestMSMG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
@@ -136,6 +142,48 @@ func TestMSMG2Batch(t *testing.T) {
}
}
func TestPrecomputeBaseG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := G2GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p G2Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = G2Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMG2SkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -165,3 +213,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
}
func TestMSMG2MultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := G2GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p G2Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BLS12_381_G2MSM_H
#define _BLS12_381_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bls12_381G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bls12_381G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bls12_381G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);
#ifdef __cplusplus
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BLS12_381_MSM_H
#define _BLS12_381_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bls12_381MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bls12_381MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bls12_381PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);
#ifdef __cplusplus
}

View File

@@ -9,11 +9,12 @@
extern "C" {
#endif
cudaError_t bls12_381NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_381NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bls12_381ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bls12_381InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
#ifdef __cplusplus
}
#endif
#endif
#endif

View File

@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}
func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.affine_t)(outputBasesPointer)
__ret := C.bls12_381PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}

View File

@@ -1,9 +1,12 @@
package bls12381
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
func TestMSM(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
}
}
func TestPrecomputeBase(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMSkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
}
func TestMSMMultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

View File

@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)
__ret := C.bls12_381ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

View File

@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}
func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)
for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2
output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}
func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

View File

@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}

View File

@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t
if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}
if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}
if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}

View File

@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
}
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, true)
}
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, false)
}
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
}
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, true)
}
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, false)
}

View File

@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
}
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, true)
}
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, false)
}
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
}
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, true)
}
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, false)
}

View File

@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
}
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
__ret := C.bn254G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}

View File

@@ -3,9 +3,12 @@
package bn254
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -82,6 +85,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
func TestMSMG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -94,12 +98,14 @@ func TestMSMG2(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
@@ -136,6 +142,48 @@ func TestMSMG2Batch(t *testing.T) {
}
}
func TestPrecomputeBaseG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := G2GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p G2Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = G2Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMG2SkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -165,3 +213,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
}
func TestMSMG2MultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := G2GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p G2Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BN254_G2MSM_H
#define _BN254_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bn254G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bn254G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bn254G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);
#ifdef __cplusplus
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BN254_MSM_H
#define _BN254_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bn254MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bn254MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bn254PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);
#ifdef __cplusplus
}

View File

@@ -9,11 +9,12 @@
extern "C" {
#endif
cudaError_t bn254NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bn254NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bn254ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bn254InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
#ifdef __cplusplus
}
#endif
#endif
#endif

View File

@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}
func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.affine_t)(outputBasesPointer)
__ret := C.bn254PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}
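
A minimal, hedged sketch of how the new PrecomputeBases call composes with Msm. It is written as if it lived inside the bn254 package (mirroring the tests below) and reuses only helpers that appear elsewhere in this diff (GenerateScalars, GenerateAffinePoints, GetDefaultMSMConfig); the one non-obvious detail, visible in the test, is the sizing rule: the device output buffer must hold points.Len() * precomputeFactor affine points.

```go
package bn254

import (
	"github.com/ingonyama-zk/icicle/wrappers/golang/core"
	cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
)

// precomputeMsmSketch is an illustration only, not part of this PR.
func precomputeMsmSketch() cr.CudaError {
	const precomputeFactor = 8
	size := 1 << 10
	scalars := GenerateScalars(size)
	points := GenerateAffinePoints(size)
	cfg := GetDefaultMSMConfig()

	// Expanded bases live on device: precomputeFactor "shifted" copies per base.
	var precomputeOut core.DeviceSlice
	if _, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size()); e != cr.CudaSuccess {
		return e
	}
	defer precomputeOut.Free()

	// c = 0 leaves the window-size choice to the backend.
	if e := PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut); e != cr.CudaSuccess {
		return e
	}

	// The MSM then consumes the expanded bases together with the matching
	// factor; the scalars keep their original length.
	var p Projective
	var out core.DeviceSlice
	if _, e := out.Malloc(p.Size(), p.Size()); e != cr.CudaSuccess {
		return e
	}
	defer out.Free()
	cfg.PrecomputeFactor = precomputeFactor
	return Msm(scalars, precomputeOut, &cfg, out)
	// A real caller would copy the result back (outHost.CopyFromDevice(&out))
	// before freeing, as the tests below do.
}
```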

View File

@@ -1,9 +1,12 @@
package bn254
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
func TestMSM(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
}
}
func TestPrecomputeBase(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMSkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
}
func TestMSMMultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}
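
The test changes above swap the synchronous copy-back and free for their stream-ordered variants. A condensed, hedged sketch of the lifecycle the updated tests now follow (same package and helpers as the test, illustration only):

```go
package bn254

import (
	"github.com/ingonyama-zk/icicle/wrappers/golang/core"
	cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
)

// asyncMsmSketch condenses the stream lifecycle used by the updated tests.
// Everything queued on the stream (allocation, MSM, copy-back, free) is only
// guaranteed to have completed after SynchronizeStream, so the host result
// must not be read before that call.
func asyncMsmSketch(size int) (Projective, cr.CudaError) {
	scalars := GenerateScalars(size)
	points := GenerateAffinePoints(size)

	cfg := GetDefaultMSMConfig()
	cfg.IsAsync = true
	stream, _ := cr.CreateStream()
	cfg.Ctx.Stream = &stream

	var p Projective
	var out core.DeviceSlice
	if _, e := out.MallocAsync(p.Size(), p.Size(), stream); e != cr.CudaSuccess {
		return p, e
	}
	if e := Msm(scalars, points, &cfg, out); e != cr.CudaSuccess {
		return p, e
	}

	outHost := make(core.HostSlice[Projective], 1)
	outHost.CopyFromDeviceAsync(&out, stream) // queued, not yet complete
	out.FreeAsync(stream)                     // safe: stream-ordered after the copy
	cr.SynchronizeStream(&stream)             // only now is outHost[0] valid
	return outHost[0], cr.CudaSuccess
}
```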

View File

@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)
__ret := C.bn254ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))
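
A hedged illustration of how the new ECNtt wrapper treats batched input: it derives the per-transform size as points.Len() / cfg.BatchSize, so a batch is passed as one flat slice. Whether a particular backend build accepts batched ECNTT is not established by this diff; the sketch only shows what the wrapper itself computes, and it assumes the NTT domain was already initialized (via InitDomain above, or the initDomain helper the tests use).

```go
package bn254

import "github.com/ingonyama-zk/icicle/wrappers/golang/core"

// ecnttBatchSketch is an illustration only, not part of this PR.
func ecnttBatchSketch() core.IcicleError {
	const batchSize, size = 2, 1 << 6

	cfg := GetDefaultNttConfig()
	cfg.Ordering = core.KNN
	cfg.NttAlgorithm = core.Radix2 // the ECNTT tests below pin the radix-2 path
	cfg.BatchSize = batchSize

	points := GenerateProjectivePoints(batchSize * size)
	input := core.HostSliceFromElements[Projective](points[:batchSize*size])
	output := make(core.HostSlice[Projective], batchSize*size)

	// The wrapper passes points.Len() / cfg.BatchSize (= size) to the kernel.
	return ECNtt(input, core.KForward, &cfg, output)
}
```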

View File

@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}
func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)
for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2
output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}
func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

View File

@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}
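
The CheckDevice calls added here (and throughout the other wrappers in this diff) appear to assert that a device slice belongs to the currently active GPU before its raw pointer is handed to CUDA. A hedged sketch of the guarded round-trip, assuming the caller already holds a device-resident scalar slice (how it was populated is outside this diff):

```go
package bn254

import (
	"github.com/ingonyama-zk/icicle/wrappers/golang/core"
	cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
)

// montgomeryRoundTrip is an illustration only. Both wrappers now call
// scalars.CheckDevice() first, so invoking this while a different GPU is
// active than the one that allocated `scalars` is expected to fail fast
// rather than pass a foreign pointer to the kernel.
func montgomeryRoundTrip(scalars *core.DeviceSlice) cr.CudaError {
	if e := ToMontgomery(scalars); e != cr.CudaSuccess {
		return e
	}
	return FromMontgomery(scalars)
}
```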

View File

@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t
if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}
if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}
if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}

View File

@@ -146,10 +146,12 @@ func convertAffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.Cud
}
func AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, true)
}
func AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertAffinePointsMontgomery(points, false)
}
@@ -165,9 +167,11 @@ func convertProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool) cr
}
func ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, true)
}
func ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertProjectivePointsMontgomery(points, false)
}

View File

@@ -148,10 +148,12 @@ func convertG2AffinePointsMontgomery(points *core.DeviceSlice, isInto bool) cr.C
}
func G2AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, true)
}
func G2AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2AffinePointsMontgomery(points, false)
}
@@ -167,9 +169,11 @@ func convertG2ProjectivePointsMontgomery(points *core.DeviceSlice, isInto bool)
}
func G2ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, true)
}
func G2ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convertG2ProjectivePointsMontgomery(points, false)
}

View File

@@ -20,7 +20,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -28,7 +30,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
@@ -36,7 +40,9 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[G2Projective])[0])
}
@@ -49,3 +55,28 @@ func G2Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *c
err := (cr.CudaError)(__ret)
return err
}
func G2PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[G2Affine])[0])
}
cPoints := (*C.g2_affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.g2_affine_t)(outputBasesPointer)
__ret := C.bw6_761G2PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}

View File

@@ -3,9 +3,12 @@
package bw6761
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -55,6 +58,7 @@ func testAgainstGnarkCryptoMsmG2(scalars core.HostSlice[ScalarField], points cor
func TestMSMG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -67,12 +71,14 @@ func TestMSMG2(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
@@ -109,6 +115,48 @@ func TestMSMG2Batch(t *testing.T) {
}
}
func TestPrecomputeBaseG2(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := G2GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = G2PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p G2Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = G2Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsmG2(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMG2SkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -138,3 +186,43 @@ func TestMSMG2SkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
}
func TestMSMG2MultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := G2GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p G2Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = G2Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[G2Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsmG2(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}
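
The multi-device test above follows a fan-out-and-restore pattern that is easy to lose in the diff noise. A stripped-down, hedged skeleton of just that pattern (illustration only, using the cuda_runtime helpers the test itself uses):

```go
package bw6761

import (
	"sync"

	cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
)

// multiDeviceSkeleton runs one closure per visible GPU and restores the
// caller's device afterwards, mirroring TestMSMG2MultiDevice above.
func multiDeviceSkeleton(work func(device int)) {
	numDevices, _ := cr.GetDeviceCount()
	origDevice, _ := cr.GetDevice()
	defer cr.SetDevice(origDevice) // restore whatever device the caller had active

	wg := sync.WaitGroup{}
	for i := 0; i < numDevices; i++ {
		wg.Add(1)
		device := i // keep the per-iteration value for the closure
		cr.RunOnDevice(i, func(args ...any) {
			defer wg.Done()
			// RunOnDevice pins this closure to one GPU; allocations, streams
			// and kernels issued inside it target that device.
			work(device)
		})
	}
	wg.Wait()
}
```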

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BW6_761_G2MSM_H
#define _BW6_761_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bw6_761G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bw6_761G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t bw6_761G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);
#ifdef __cplusplus
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _BW6_761_MSM_H
#define _BW6_761_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t bw6_761MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bw6_761MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t bw6_761PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);
#ifdef __cplusplus
}

View File

@@ -9,11 +9,12 @@
extern "C" {
#endif
cudaError_t bw6_761NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bw6_761NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t bw6_761ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t bw6_761InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
#ifdef __cplusplus
}
#endif
#endif
#endif

View File

@@ -18,7 +18,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -26,7 +28,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
@@ -34,7 +38,9 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
@@ -47,3 +53,28 @@ func Msm(scalars core.HostOrDeviceSlice, points core.HostOrDeviceSlice, cfg *cor
err := (cr.CudaError)(__ret)
return err
}
func PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Affine])[0])
}
cPoints := (*C.affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.affine_t)(outputBasesPointer)
__ret := C.bw6_761PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}

View File

@@ -1,9 +1,12 @@
package bw6761
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -53,6 +56,7 @@ func testAgainstGnarkCryptoMsm(scalars core.HostSlice[ScalarField], points core.
func TestMSM(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -65,12 +69,14 @@ func TestMSM(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
@@ -107,6 +113,48 @@ func TestMSMBatch(t *testing.T) {
}
}
func TestPrecomputeBase(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsm(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSMSkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -136,3 +184,43 @@ func TestMSMSkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
}
func TestMSMMultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

View File

@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)
__ret := C.bw6_761ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

View File

@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}
func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)
for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2
output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}
func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

View File

@@ -106,9 +106,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}

View File

@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t
if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}
if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}
if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}

View File

@@ -150,10 +150,12 @@ func convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points *core.DeviceSlice
}
func {{if .IsG2}}G2{{end}}AffineToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, true)
}
func {{if .IsG2}}G2{{end}}AffineFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}AffinePointsMontgomery(points, false)
}
@@ -169,10 +171,12 @@ func convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points *core.DeviceS
}
func {{if .IsG2}}G2{{end}}ProjectiveToMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, true)
}
func {{if .IsG2}}G2{{end}}ProjectiveFromMontgomery(points *core.DeviceSlice) cr.CudaError {
points.CheckDevice()
return convert{{if .IsG2}}G2{{end}}ProjectivePointsMontgomery(points, false)
}
{{end}}
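
The {{if .IsG2}} / {{.Curve}} / {{toUpper .Curve}} placeholders in these .tmpl files are ordinary Go text/template actions. The generator that renders them is not part of this diff, so the following is only a hedged, self-contained sketch of the mechanism; the curveData struct and the inline template string are hypothetical stand-ins, and toUpper is mapped to strings.ToUpper on the assumption that it uppercases the curve name (consistent with the _BW6_761_G2MSM_H style guards elsewhere in this diff).

```go
package main

import (
	"os"
	"strings"
	"text/template"
)

// curveData mirrors the fields the .tmpl files reference ({{.Curve}},
// {{.PackageName}}, {{.IsG2}}); the struct itself is a hypothetical stand-in.
type curveData struct {
	Curve       string
	PackageName string
	IsG2        bool
}

const header = `cudaError_t {{.Curve}}{{if .IsG2}}G2{{end}}MSMCuda(const scalar_t* scalars, const {{if .IsG2}}g2_{{end}}affine_t* points, int count, MSMConfig* config, {{if .IsG2}}g2_{{end}}projective_t* out);
#ifndef _{{toUpper .Curve}}_{{if .IsG2}}G2{{end}}MSM_H
`

func main() {
	tmpl := template.Must(template.New("msm").
		Funcs(template.FuncMap{"toUpper": strings.ToUpper}).
		Parse(header))
	for _, c := range []curveData{
		{Curve: "bn254", PackageName: "bn254", IsG2: false},
		{Curve: "bw6_761", PackageName: "bw6761", IsG2: true},
	} {
		_ = tmpl.Execute(os.Stdout, c) // prints the curve-specialized declarations
	}
}
```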

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _{{toUpper .Curve}}_G2MSM_H
#define _{{toUpper .Curve}}_G2MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t {{.Curve}}G2MSMCuda(scalar_t* scalars, g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t {{.Curve}}G2MSMCuda(const scalar_t* scalars,const g2_affine_t* points, int count, MSMConfig* config, g2_projective_t* out);
cudaError_t {{.Curve}}G2PrecomputeMSMBases(g2_affine_t* points, int count, int precompute_factor, int _c, bool bases_on_device, DeviceContext* ctx, g2_affine_t* out);
#ifdef __cplusplus
}

View File

@@ -1,5 +1,6 @@
#include <cuda_runtime.h>
#include "../../include/types.h"
#include <stdbool.h>
#ifndef _{{toUpper .Curve}}_MSM_H
#define _{{toUpper .Curve}}_MSM_H
@@ -8,7 +9,8 @@
extern "C" {
#endif
cudaError_t {{.Curve}}MSMCuda(scalar_t* scalars, affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t {{.Curve}}MSMCuda(const scalar_t* scalars, const affine_t* points, int count, MSMConfig* config, projective_t* out);
cudaError_t {{.Curve}}PrecomputeMSMBases(affine_t* points, int bases_size, int precompute_factor, int _c, bool are_bases_on_device, DeviceContext* ctx, affine_t* output_bases);
#ifdef __cplusplus
}

View File

@@ -9,11 +9,12 @@
extern "C" {
#endif
cudaError_t {{.Curve}}NTTCuda(scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t {{.Curve}}NTTCuda(const scalar_t* input, int size, int dir, NTTConfig* config, scalar_t* output);
cudaError_t {{.Curve}}ECNTTCuda(const projective_t* input, int size, int dir, NTTConfig* config, projective_t* output);
cudaError_t {{.Curve}}InitializeDomain(scalar_t* primitive_root, DeviceContext* ctx, bool fast_twiddles);
#ifdef __cplusplus
}
#endif
#endif
#endif

View File

@@ -22,7 +22,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
core.MsmCheck(scalars, points, cfg, results)
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -30,7 +32,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
pointsDevice := points.(core.DeviceSlice)
pointsDevice.CheckDevice()
pointsPointer = pointsDevice.AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0])
}
@@ -38,7 +42,9 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[{{if .IsG2}}G2{{end}}Projective])[0])
}
@@ -51,3 +57,28 @@ func {{if .IsG2}}G2{{end}}Msm(scalars core.HostOrDeviceSlice, points core.HostOr
err := (cr.CudaError)(__ret)
return err
}
func {{if .IsG2}}G2{{end}}PrecomputeBases(points core.HostOrDeviceSlice, precomputeFactor int32, c int32, ctx *cr.DeviceContext, outputBases core.DeviceSlice) cr.CudaError {
core.PrecomputeBasesCheck(points, precomputeFactor, outputBases)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[{{if .IsG2}}G2{{end}}Affine])[0])
}
cPoints := (*C.{{if .IsG2}}g2_{{end}}affine_t)(pointsPointer)
cPointsLen := (C.int)(points.Len())
cPrecomputeFactor := (C.int)(precomputeFactor)
cC := (C.int)(c)
cPointsIsOnDevice := (C._Bool)(points.IsOnDevice())
cCtx := (*C.DeviceContext)(unsafe.Pointer(ctx))
outputBasesPointer := outputBases.AsPointer()
cOutputBases := (*C.{{if .IsG2}}g2_{{end}}affine_t)(outputBasesPointer)
__ret := C.{{.Curve}}{{if .IsG2}}G2{{end}}PrecomputeMSMBases(cPoints, cPointsLen, cPrecomputeFactor, cC, cPointsIsOnDevice, cCtx, cOutputBases)
err := (cr.CudaError)(__ret)
return err
}

View File

@@ -5,9 +5,12 @@
package {{.PackageName}}
import (
"github.com/stretchr/testify/assert"
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/ingonyama-zk/icicle/wrappers/golang/core"
cr "github.com/ingonyama-zk/icicle/wrappers/golang/cuda_runtime"
@@ -102,6 +105,7 @@ func testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars core.HostSlice[Scala
func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) {
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
@@ -114,12 +118,14 @@ func TestMSM{{if .IsG2}}G2{{end}}(t *testing.T) {
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1)
outHost.CopyFromDevice(&out)
out.Free()
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
}
@@ -156,6 +162,48 @@ func TestMSM{{if .IsG2}}G2{{end}}Batch(t *testing.T) {
}
}
func TestPrecomputeBase{{if .IsG2}}G2{{end}}(t *testing.T) {
cfg := GetDefaultMSMConfig()
const precomputeFactor = 8
for _, power := range []int{10, 16} {
for _, batchSize := range []int{1, 3, 16} {
size := 1 << power
totalSize := size * batchSize
scalars := GenerateScalars(totalSize)
points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(totalSize)
var precomputeOut core.DeviceSlice
_, e := precomputeOut.Malloc(points[0].Size()*points.Len()*int(precomputeFactor), points[0].Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for PrecomputeBases results failed")
e = {{if .IsG2}}G2{{end}}PrecomputeBases(points, precomputeFactor, 0, &cfg.Ctx, precomputeOut)
assert.Equal(t, e, cr.CudaSuccess, "PrecomputeBases failed")
var p {{if .IsG2}}G2{{end}}Projective
var out core.DeviceSlice
_, e = out.Malloc(batchSize*p.Size(), p.Size())
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.PrecomputeFactor = precomputeFactor
e = {{if .IsG2}}G2{{end}}Msm(scalars, precomputeOut, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], batchSize)
outHost.CopyFromDevice(&out)
out.Free()
precomputeOut.Free()
// Check with gnark-crypto
for i := 0; i < batchSize; i++ {
scalarsSlice := scalars[i*size : (i+1)*size]
pointsSlice := points[i*size : (i+1)*size]
out := outHost[i]
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalarsSlice, pointsSlice, out))
}
}
}
}
func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) {
cfg := GetDefaultMSMConfig()
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
@@ -185,3 +233,43 @@ func TestMSM{{if .IsG2}}G2{{end}}SkewedDistribution(t *testing.T) {
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
}
}
func TestMSM{{if .IsG2}}G2{{end}}MultiDevice(t *testing.T) {
numDevices, _ := cr.GetDeviceCount()
fmt.Println("There are ", numDevices, " devices available")
orig_device, _ := cr.GetDevice()
wg := sync.WaitGroup{}
for i := 0; i < numDevices; i++ {
wg.Add(1)
cr.RunOnDevice(i, func(args ...any) {
defer wg.Done()
cfg := GetDefaultMSMConfig()
cfg.IsAsync = true
for _, power := range []int{2, 3, 4, 5, 6, 7, 8, 10, 18} {
size := 1 << power
scalars := GenerateScalars(size)
points := {{if .IsG2}}G2{{end}}GenerateAffinePoints(size)
stream, _ := cr.CreateStream()
var p {{if .IsG2}}G2{{end}}Projective
var out core.DeviceSlice
_, e := out.MallocAsync(p.Size(), p.Size(), stream)
assert.Equal(t, e, cr.CudaSuccess, "Allocating bytes on device for Projective results failed")
cfg.Ctx.Stream = &stream
e = {{if .IsG2}}G2{{end}}Msm(scalars, points, &cfg, out)
assert.Equal(t, e, cr.CudaSuccess, "Msm failed")
outHost := make(core.HostSlice[{{if .IsG2}}G2{{end}}Projective], 1)
outHost.CopyFromDeviceAsync(&out, stream)
out.FreeAsync(stream)
cr.SynchronizeStream(&stream)
// Check with gnark-crypto
assert.True(t, testAgainstGnarkCryptoMsm{{if .IsG2}}G2{{end}}(scalars, points, outHost[0]))
}
})
}
wg.Wait()
cr.SetDevice(orig_device)
}

View File

@@ -27,7 +27,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var scalarsPointer unsafe.Pointer
if scalars.IsOnDevice() {
scalarsPointer = scalars.(core.DeviceSlice).AsPointer()
scalarsDevice := scalars.(core.DeviceSlice)
scalarsDevice.CheckDevice()
scalarsPointer = scalarsDevice.AsPointer()
} else {
scalarsPointer = unsafe.Pointer(&scalars.(core.HostSlice[ScalarField])[0])
}
@@ -38,7 +40,9 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
resultsDevice := results.(core.DeviceSlice)
resultsDevice.CheckDevice()
resultsPointer = resultsDevice.AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[ScalarField])[0])
}
@@ -49,6 +53,33 @@ func Ntt[T any](scalars core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTCo
return core.FromCudaError(err)
}
func ECNtt[T any](points core.HostOrDeviceSlice, dir core.NTTDir, cfg *core.NTTConfig[T], results core.HostOrDeviceSlice) core.IcicleError {
core.NttCheck[T](points, cfg, results)
var pointsPointer unsafe.Pointer
if points.IsOnDevice() {
pointsPointer = points.(core.DeviceSlice).AsPointer()
} else {
pointsPointer = unsafe.Pointer(&points.(core.HostSlice[Projective])[0])
}
cPoints := (*C.projective_t)(pointsPointer)
cSize := (C.int)(points.Len() / int(cfg.BatchSize))
cDir := (C.int)(dir)
cCfg := (*C.NTTConfig)(unsafe.Pointer(cfg))
var resultsPointer unsafe.Pointer
if results.IsOnDevice() {
resultsPointer = results.(core.DeviceSlice).AsPointer()
} else {
resultsPointer = unsafe.Pointer(&results.(core.HostSlice[Projective])[0])
}
cResults := (*C.projective_t)(resultsPointer)
__ret := C.{{.Curve}}ECNTTCuda(cPoints, cSize, cDir, cCfg, cResults)
err := (cr.CudaError)(__ret)
return core.FromCudaError(err)
}
func InitDomain(primitiveRoot ScalarField, ctx cr.DeviceContext, fastTwiddles bool) core.IcicleError {
cPrimitiveRoot := (*C.scalar_t)(unsafe.Pointer(primitiveRoot.AsPointer()))
cCtx := (*C.DeviceContext)(unsafe.Pointer(&ctx))

View File

@@ -98,6 +98,26 @@ func TestNtt(t *testing.T) {
}
}
func TestECNtt(t *testing.T) {
cfg := GetDefaultNttConfig()
initDomain(largestTestSize, cfg)
points := GenerateProjectivePoints(1 << largestTestSize)
for _, size := range []int{4, 5, 6, 7, 8} {
for _, v := range [4]core.Ordering{core.KNN, core.KNR, core.KRN, core.KRR} {
testSize := 1 << size
pointsCopy := core.HostSliceFromElements[Projective](points[:testSize])
cfg.Ordering = v
cfg.NttAlgorithm = core.Radix2
output := make(core.HostSlice[Projective], testSize)
e := ECNtt(pointsCopy, core.KForward, &cfg, output)
assert.Equal(t, core.IcicleErrorCode(0), e.IcicleErrorCode, "ECNtt failed")
}
}
}
func TestNttDeviceAsync(t *testing.T) {
cfg := GetDefaultNttConfig()
scalars := GenerateScalars(1 << largestTestSize)

View File

@@ -33,9 +33,11 @@ func convertScalarsMontgomery(scalars *core.DeviceSlice, isInto bool) cr.CudaErr
}
func ToMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, true)
}
func FromMontgomery(scalars *core.DeviceSlice) cr.CudaError {
scalars.CheckDevice()
return convertScalarsMontgomery(scalars, false)
}{{- end}}

View File

@@ -16,19 +16,25 @@ func VecOp(a, b, out core.HostOrDeviceSlice, config core.VecOpsConfig, op core.V
var cA, cB, cOut *C.scalar_t
if a.IsOnDevice() {
cA = (*C.scalar_t)(a.(core.DeviceSlice).AsPointer())
aDevice := a.(core.DeviceSlice)
aDevice.CheckDevice()
cA = (*C.scalar_t)(aDevice.AsPointer())
} else {
cA = (*C.scalar_t)(unsafe.Pointer(&a.(core.HostSlice[ScalarField])[0]))
}
if b.IsOnDevice() {
cB = (*C.scalar_t)(b.(core.DeviceSlice).AsPointer())
bDevice := b.(core.DeviceSlice)
bDevice.CheckDevice()
cB = (*C.scalar_t)(bDevice.AsPointer())
} else {
cB = (*C.scalar_t)(unsafe.Pointer(&b.(core.HostSlice[ScalarField])[0]))
}
if out.IsOnDevice() {
cOut = (*C.scalar_t)(out.(core.DeviceSlice).AsPointer())
outDevice := out.(core.DeviceSlice)
outDevice.CheckDevice()
cOut = (*C.scalar_t)(outDevice.AsPointer())
} else {
cOut = (*C.scalar_t)(unsafe.Pointer(&out.(core.HostSlice[ScalarField])[0]))
}

View File

@@ -14,7 +14,7 @@ exclude = [
]
[workspace.package]
version = "1.6.0"
version = "1.9.0"
edition = "2021"
authors = [ "Ingonyama" ]
homepage = "https://www.ingonyama.com"

Some files were not shown because too many files have changed in this diff.