mirror of
https://github.com/pseXperiments/icicle.git
synced 2026-01-08 23:17:54 -05:00
minor update to rust poly example
This commit is contained in:
@@ -27,7 +27,7 @@ struct Args {
|
||||
device_type: String,
|
||||
|
||||
/// Backend installation directory
|
||||
#[arg(short, long, default_value = "")]
|
||||
#[arg(short, long, default_value = "/opt/icicle/backend")]
|
||||
backend_install_dir: String,
|
||||
}
|
||||
|
||||
@@ -41,10 +41,11 @@ fn try_load_and_set_backend_device(args: &Args) {
|
||||
icicle_runtime::runtime::load_backend(&args.backend_install_dir).unwrap();
|
||||
}
|
||||
println!("Setting device {}", args.device_type);
|
||||
icicle_runtime::set_device(&icicle_runtime::Device::new(&args.device_type, 0 /* =device_id*/)).unwrap();
|
||||
let device = icicle_runtime::Device::new(&args.device_type, 0 /* =device_id*/);
|
||||
icicle_runtime::set_device(&device).unwrap();
|
||||
}
|
||||
|
||||
fn init(max_ntt_size: u64) {
|
||||
fn init_ntt_domain(max_ntt_size: u64) {
|
||||
// Initialize NTT domain for all fields. Polynomial operations rely on NTT.
|
||||
println!(
|
||||
"Initializing NTT domain for max size 2^{}",
|
||||
@@ -79,7 +80,7 @@ fn main() {
|
||||
|
||||
try_load_and_set_backend_device(&args);
|
||||
|
||||
init(1 << args.max_ntt_log_size);
|
||||
init_ntt_domain(1 << args.max_ntt_log_size);
|
||||
|
||||
let poly_size = 1 << args.poly_log_size;
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ TYPED_TEST(FieldApiTest, FieldSanityTest)
|
||||
|
||||
TYPED_TEST(FieldApiTest, vectorOps)
|
||||
{
|
||||
const uint64_t N = 1 << 15;
|
||||
const uint64_t N = 1 << 22;
|
||||
auto in_a = std::make_unique<TypeParam[]>(N);
|
||||
auto in_b = std::make_unique<TypeParam[]>(N);
|
||||
FieldApiTest<TypeParam>::random_samples(in_a.get(), N);
|
||||
@@ -167,7 +167,7 @@ TYPED_TEST(FieldApiTest, matrixAPIsAsync)
|
||||
|
||||
config.is_a_on_device = true;
|
||||
config.is_result_on_device = true;
|
||||
config.is_async = true;
|
||||
config.is_async = false;
|
||||
}
|
||||
|
||||
TypeParam* in = device_props.using_host_memory ? h_in.get() : d_in;
|
||||
@@ -310,15 +310,6 @@ TYPED_TEST(FieldApiTest, ntt)
|
||||
coset_gen = scalar_t::one();
|
||||
}
|
||||
|
||||
// TODO Yuval : remove those once the bug is fixed
|
||||
ICICLE_LOG_INFO << "NTT test: logn=" << logn;
|
||||
ICICLE_LOG_INFO << "NTT test: log_batch_size=" << log_batch_size;
|
||||
ICICLE_LOG_INFO << "NTT test: columns_batch=" << columns_batch;
|
||||
ICICLE_LOG_INFO << "NTT test: ordering=" << int(ordering);
|
||||
ICICLE_LOG_INFO << "NTT test: dir=" << (dir == NTTDir::kForward ? "forward" : "inverse");
|
||||
ICICLE_LOG_INFO << "NTT test: log_coset_stride=" << log_coset_stride;
|
||||
ICICLE_LOG_INFO << "NTT test: coset_gen=" << coset_gen;
|
||||
|
||||
const int total_size = N * batch_size;
|
||||
auto scalars = std::make_unique<TypeParam[]>(total_size);
|
||||
FieldApiTest<TypeParam>::random_samples(scalars.get(), total_size);
|
||||
@@ -375,8 +366,8 @@ TYPED_TEST(FieldApiTest, ntt)
|
||||
|
||||
run(s_main_target, out_main.get(), "ntt", false /*=measure*/, 1 /*=iters*/); // warmup
|
||||
|
||||
run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/);
|
||||
run(s_reference_target, out_ref.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/);
|
||||
run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/);
|
||||
|
||||
ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), total_size * sizeof(scalar_t)));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user