minor update to rust poly example

This commit is contained in:
Yuval Shekel
2024-07-30 17:28:55 +03:00
parent a45746fc3b
commit 5332f4c8f8
2 changed files with 8 additions and 16 deletions

View File

@@ -27,7 +27,7 @@ struct Args {
device_type: String,
/// Backend installation directory
#[arg(short, long, default_value = "")]
#[arg(short, long, default_value = "/opt/icicle/backend")]
backend_install_dir: String,
}
@@ -41,10 +41,11 @@ fn try_load_and_set_backend_device(args: &Args) {
icicle_runtime::runtime::load_backend(&args.backend_install_dir).unwrap();
}
println!("Setting device {}", args.device_type);
icicle_runtime::set_device(&icicle_runtime::Device::new(&args.device_type, 0 /* =device_id*/)).unwrap();
let device = icicle_runtime::Device::new(&args.device_type, 0 /* =device_id*/);
icicle_runtime::set_device(&device).unwrap();
}
fn init(max_ntt_size: u64) {
fn init_ntt_domain(max_ntt_size: u64) {
// Initialize NTT domain for all fields. Polynomial operations rely on NTT.
println!(
"Initializing NTT domain for max size 2^{}",
@@ -79,7 +80,7 @@ fn main() {
try_load_and_set_backend_device(&args);
init(1 << args.max_ntt_log_size);
init_ntt_domain(1 << args.max_ntt_log_size);
let poly_size = 1 << args.poly_log_size;

View File

@@ -91,7 +91,7 @@ TYPED_TEST(FieldApiTest, FieldSanityTest)
TYPED_TEST(FieldApiTest, vectorOps)
{
const uint64_t N = 1 << 15;
const uint64_t N = 1 << 22;
auto in_a = std::make_unique<TypeParam[]>(N);
auto in_b = std::make_unique<TypeParam[]>(N);
FieldApiTest<TypeParam>::random_samples(in_a.get(), N);
@@ -167,7 +167,7 @@ TYPED_TEST(FieldApiTest, matrixAPIsAsync)
config.is_a_on_device = true;
config.is_result_on_device = true;
config.is_async = true;
config.is_async = false;
}
TypeParam* in = device_props.using_host_memory ? h_in.get() : d_in;
@@ -310,15 +310,6 @@ TYPED_TEST(FieldApiTest, ntt)
coset_gen = scalar_t::one();
}
// TODO Yuval : remove those once the bug is fixed
ICICLE_LOG_INFO << "NTT test: logn=" << logn;
ICICLE_LOG_INFO << "NTT test: log_batch_size=" << log_batch_size;
ICICLE_LOG_INFO << "NTT test: columns_batch=" << columns_batch;
ICICLE_LOG_INFO << "NTT test: ordering=" << int(ordering);
ICICLE_LOG_INFO << "NTT test: dir=" << (dir == NTTDir::kForward ? "forward" : "inverse");
ICICLE_LOG_INFO << "NTT test: log_coset_stride=" << log_coset_stride;
ICICLE_LOG_INFO << "NTT test: coset_gen=" << coset_gen;
const int total_size = N * batch_size;
auto scalars = std::make_unique<TypeParam[]>(total_size);
FieldApiTest<TypeParam>::random_samples(scalars.get(), total_size);
@@ -375,8 +366,8 @@ TYPED_TEST(FieldApiTest, ntt)
run(s_main_target, out_main.get(), "ntt", false /*=measure*/, 1 /*=iters*/); // warmup
run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/);
run(s_reference_target, out_ref.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/);
run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/);
ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), total_size * sizeof(scalar_t)));
}