Update example to compare Radix2 and MixedRadix NTTs (#383)

## Describe the changes

Update the NTT example to cover the new NTT algorithms (Radix2 and MixedRadix).
This commit is contained in:
Stas
2024-02-20 18:05:39 -06:00
committed by GitHub
2 changed files with 34 additions and 16 deletions

View File

@@ -16,10 +16,11 @@ We recommend to run our examples in [ZK-containers](../../ZK-containers.md) to s
// Include NTT template
#include "appUtils/ntt/ntt.cu"
using namespace curve_config;
using namespace ntt;
// Configure NTT
ntt::NTTConfig<S> config=ntt::DefaultNTTConfig<S>();
NTTConfig<S> config=DefaultNTTConfig<S>();
// Call NTT
ntt::NTT<S, E>(input, ntt_size, ntt::NTTDir::kForward, config, output);
NTT<S, E>(input, ntt_size, NTTDir::kForward, config, output);
```
## Running the example
@@ -28,5 +29,10 @@ ntt::NTT<S, E>(input, ntt_size, ntt::NTTDir::kForward, config, output);
- compile with `./compile.sh`
- run with `./run.sh`
## What's in the example
1. Define the size of the example
2. Initialize input
3. Run Radix2 NTT
4. Run MixedRadix NTT
5. Validate the data output

View File

@@ -7,6 +7,7 @@
#include "appUtils/ntt/ntt.cu"
#include "appUtils/ntt/kernel_ntt.cu"
using namespace curve_config;
using namespace ntt;
// Operate on scalars
typedef scalar_t S;
@@ -58,6 +59,11 @@ int validate_output(const unsigned ntt_size, const unsigned nof_ntts, E* element
return nof_errors;
}
// Host-side wall-clock timing helpers built on std::chrono.
// FpMilliseconds is a duration with a float representation and a millisecond
// period, so .count() yields elapsed time in (fractional) milliseconds.
using FpMilliseconds = std::chrono::duration<float, std::chrono::milliseconds::period>;
// START_TIMER(timer): declares a local variable named `<timer>_start` holding the
// current high_resolution_clock time point. The expansion already ends with ';'.
#define START_TIMER(timer) auto timer##_start = std::chrono::high_resolution_clock::now();
// END_TIMER(timer, msg): prints "<msg>: <elapsed> ms" using the time elapsed since
// the matching START_TIMER(timer); "%.0f" prints the value with no decimal places.
// `<timer>_start` must be visible in the scope where this macro is expanded.
#define END_TIMER(timer, msg) printf("%s: %.0f ms\n", msg, FpMilliseconds(std::chrono::high_resolution_clock::now() - timer##_start).count());
int main(int argc, char* argv[])
{
std::cout << "Icicle Examples: Number Theoretical Transform (NTT)" << std::endl;
@@ -78,24 +84,30 @@ int main(int argc, char* argv[])
output = (E*)malloc(sizeof(E) * batch_size);
std::cout << "Running NTT with on-host data" << std::endl;
cudaStream_t stream;
cudaStreamCreate(&stream);
// Create a device context
auto ctx = device_context::get_default_device_context();
// the next line is valid only for CURVE_ID 1 (will add support for other curves soon)
S rou = S{{0x53337857, 0x53422da9, 0xdbed349f, 0xac616632, 0x6d1e303, 0x27508aba, 0xa0ed063, 0x26125da1}};
ntt::InitDomain(rou, ctx);
const S basic_root = S::omega(log_ntt_size /*NTT_LOG_SIZE*/);
InitDomain(basic_root, ctx);
// Create an NTTConfig instance
ntt::NTTConfig<S> config = ntt::DefaultNTTConfig<S>();
NTTConfig<S> config = DefaultNTTConfig<S>();
config.ntt_algorithm = NttAlgorithm::MixedRadix;
config.batch_size = nof_ntts;
config.ctx.stream = stream;
auto begin0 = std::chrono::high_resolution_clock::now();
cudaError_t err = ntt::NTT<S, E>(input, ntt_size, ntt::NTTDir::kForward, config, output);
auto end0 = std::chrono::high_resolution_clock::now();
auto elapsed0 = std::chrono::duration_cast<std::chrono::nanoseconds>(end0 - begin0);
printf("On-device runtime: %.3f seconds\n", elapsed0.count() * 1e-9);
START_TIMER(MixedRadix);
cudaError_t err = NTT<S, E>(input, ntt_size, NTTDir::kForward, config, output);
END_TIMER(MixedRadix, "MixedRadix NTT");
std::cout << "Validating output" << std::endl;
validate_output(ntt_size, nof_ntts, output);
cudaStreamDestroy(stream);
config.ntt_algorithm = NttAlgorithm::Radix2;
START_TIMER(Radix2);
err = NTT<S, E>(input, ntt_size, NTTDir::kForward, config, output);
END_TIMER(Radix2, "Radix2 NTT");
std::cout << "Validating output" << std::endl;
validate_output(ntt_size, nof_ntts, output);
std::cout << "Cleaning-up memory" << std::endl;
free(input);
free(output);
return 0;