mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-09 15:37:58 -05:00
updates example compares Radix2 and MixedRadix NTTs (#383)
## Describe the changes Update to cover new NTT algorithms
This commit is contained in:
@@ -16,10 +16,11 @@ We recommend to run our examples in [ZK-containers](../../ZK-containers.md) to s
|
||||
// Include NTT template
|
||||
#include "appUtils/ntt/ntt.cu"
|
||||
using namespace curve_config;
|
||||
using namespace ntt;
|
||||
// Configure NTT
|
||||
ntt::NTTConfig<S> config=ntt::DefaultNTTConfig<S>();
|
||||
NTTConfig<S> config=DefaultNTTConfig<S>();
|
||||
// Call NTT
|
||||
ntt::NTT<S, E>(input, ntt_size, ntt::NTTDir::kForward, config, output);
|
||||
NTT<S, E>(input, ntt_size, NTTDir::kForward, config, output);
|
||||
```
|
||||
|
||||
## Running the example
|
||||
@@ -28,5 +29,10 @@ ntt::NTT<S, E>(input, ntt_size, ntt::NTTDir::kForward, config, output);
|
||||
- compile with `./compile.sh`
|
||||
- run with `./run.sh`
|
||||
|
||||
## What's in the example
|
||||
|
||||
|
||||
1. Define the size of the example
|
||||
2. Initialize input
|
||||
3. Run Radix2 NTT
|
||||
4. Run MixedRadix NTT
|
||||
5. Validate the data output
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "appUtils/ntt/ntt.cu"
|
||||
#include "appUtils/ntt/kernel_ntt.cu"
|
||||
using namespace curve_config;
|
||||
using namespace ntt;
|
||||
|
||||
// Operate on scalars
|
||||
typedef scalar_t S;
|
||||
@@ -58,6 +59,11 @@ int validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* element
|
||||
return nof_errors;
|
||||
}
|
||||
|
||||
using FpMilliseconds = std::chrono::duration<float, std::chrono::milliseconds::period>;
|
||||
#define START_TIMER(timer) auto timer##_start = std::chrono::high_resolution_clock::now();
|
||||
#define END_TIMER(timer, msg) printf("%s: %.0f ms\n", msg, FpMilliseconds(std::chrono::high_resolution_clock::now() - timer##_start).count());
|
||||
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
std::cout << "Icicle Examples: Number Theoretical Transform (NTT)" << std::endl;
|
||||
@@ -78,24 +84,30 @@ int main(int argc, char* argv[])
|
||||
output = (E*)malloc(sizeof(E) * batch_size);
|
||||
|
||||
std::cout << "Running NTT with on-host data" << std::endl;
|
||||
cudaStream_t stream;
|
||||
cudaStreamCreate(&stream);
|
||||
// Create a device context
|
||||
auto ctx = device_context::get_default_device_context();
|
||||
// the next line is valid only for CURVE_ID 1 (will add support for other curves soon)
|
||||
S rou = S{{0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1}};
|
||||
ntt::InitDomain(rou, ctx);
|
||||
const S basic_root = S::omega(log_ntt_size /*NTT_LOG_SIZE*/);
|
||||
InitDomain(basic_root, ctx);
|
||||
// Create an NTTConfig instance
|
||||
ntt::NTTConfig<S> config = ntt::DefaultNTTConfig<S>();
|
||||
NTTConfig<S> config = DefaultNTTConfig<S>();
|
||||
config.ntt_algorithm = NttAlgorithm::MixedRadix;
|
||||
config.batch_size = nof_ntts;
|
||||
config.ctx.stream = stream;
|
||||
auto begin0 = std::chrono::high_resolution_clock::now();
|
||||
cudaError_t err = ntt::NTT<S, E>(input, ntt_size, ntt::NTTDir::kForward, config, output);
|
||||
auto end0 = std::chrono::high_resolution_clock::now();
|
||||
auto elapsed0 = std::chrono::duration_cast<std::chrono::nanoseconds>(end0 - begin0);
|
||||
printf("On-device runtime: %.3f seconds\n", elapsed0.count() * 1e-9);
|
||||
START_TIMER(MixedRadix);
|
||||
cudaError_t err = NTT<S, E>(input, ntt_size, NTTDir::kForward, config, output);
|
||||
END_TIMER(MixedRadix, "MixedRadix NTT");
|
||||
|
||||
std::cout << "Validating output" << std::endl;
|
||||
validate_output(ntt_size, nof_ntts, output);
|
||||
cudaStreamDestroy(stream);
|
||||
|
||||
config.ntt_algorithm = NttAlgorithm::Radix2;
|
||||
START_TIMER(Radix2);
|
||||
err = NTT<S, E>(input, ntt_size, NTTDir::kForward, config, output);
|
||||
END_TIMER(Radix2, "Radix2 NTT");
|
||||
|
||||
std::cout << "Validating output" << std::endl;
|
||||
validate_output(ntt_size, nof_ntts, output);
|
||||
|
||||
std::cout << "Cleaning-up memory" << std::endl;
|
||||
free(input);
|
||||
free(output);
|
||||
return 0;
|
||||
|
||||
Reference in New Issue
Block a user