mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-10 06:48:01 -05:00
chore: tests with > 2000 columns (#478)
This commit is contained in:
39
.github/workflows/rust.yml
vendored
39
.github/workflows/rust.yml
vendored
@@ -61,10 +61,43 @@ jobs:
|
||||
# nextest doesn't support --doc tests
|
||||
run: cargo test --doc --verbose
|
||||
- name: Library tests
|
||||
run: cargo nextest run --lib --verbose -- --include-ignored
|
||||
run: cargo nextest run --lib --verbose
|
||||
|
||||
ultra-overflow-tests:
|
||||
runs-on: self-hosted
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- uses: mwilliamson/setup-wasmtime-action@v2
|
||||
with:
|
||||
wasmtime-version: "3.0.1"
|
||||
- name: Install wasm32-wasi
|
||||
run: rustup target add wasm32-wasi
|
||||
- name: Install cargo-wasi
|
||||
run: cargo install cargo-wasi
|
||||
# - name: Matmul overflow (wasi)
|
||||
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# - name: Conv overflow (wasi)
|
||||
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
|
||||
- name: Conv + relu overflow (wasi)
|
||||
run: cargo wasi test conv_relu_col_ultra_overflow -- --include-ignored --nocapture
|
||||
# - name: Matmul overflow
|
||||
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture -- --include-ignored
|
||||
# - name: Conv overflow
|
||||
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture -- --include-ignored
|
||||
- name: Conv + relu overflow
|
||||
run: RUST_LOG=debug cargo nextest run conv_relu_col_ultra_overflow --no-capture -- --include-ignored
|
||||
|
||||
model-serialization:
|
||||
runs-on: 256gb
|
||||
runs-on: ubuntu-latest-16-cores
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
@@ -79,7 +112,7 @@ jobs:
|
||||
- name: Model serialization
|
||||
run: cargo nextest run native_tests::tests::model_serialization_::t
|
||||
- name: Model serialization different binary ID
|
||||
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_::t --test-threads 1
|
||||
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
|
||||
|
||||
wasm32-tests:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -1631,6 +1631,7 @@ dependencies = [
|
||||
"tract-onnx",
|
||||
"unzip-n",
|
||||
"wasm-bindgen",
|
||||
"wasm-bindgen-console-logger",
|
||||
"wasm-bindgen-rayon",
|
||||
"wasm-bindgen-test",
|
||||
]
|
||||
@@ -5393,6 +5394,16 @@ dependencies = [
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-console-logger"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7530a275e7faf7b5b83aabdf78244fb8d9a68a2ec4b26935a05ecc0c9b0185ed"
|
||||
dependencies = [
|
||||
"log",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-futures"
|
||||
version = "0.4.37"
|
||||
|
||||
@@ -42,8 +42,6 @@ openssl = { version = "0.10.55", features = ["vendored"] }
|
||||
postgres = "0.19.5"
|
||||
pg_bigdecimal = "0.1.5"
|
||||
lazy_static = "1.4.0"
|
||||
env_logger = { version = "0.10.0", default_features = false, optional = true}
|
||||
colored = { version = "2.0.0", default_features = false, optional = true}
|
||||
colored_json = { version = "3.0.1", default_features = false, optional = true}
|
||||
plotters = { version = "0.3.0", default_features = false, optional = true }
|
||||
regex = { version = "1", default_features = false }
|
||||
@@ -54,6 +52,10 @@ pyo3-log = { version = "0.8.1", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "2ea76c09678f092d00713ebbe6fdb046c0a9ad0f", default_features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
|
||||
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
|
||||
colored = { version = "2.0.0", default_features = false, optional = true}
|
||||
env_logger = { version = "0.10.0", default_features = false, optional = true}
|
||||
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "ac/parallel-lookup-permute", default_features = false, features = ["thread-safe-region"]}
|
||||
@@ -66,6 +68,7 @@ wasm-bindgen-test = "0.3.34"
|
||||
serde-wasm-bindgen = "0.4"
|
||||
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"]}
|
||||
console_error_panic_hook = "0.1.7"
|
||||
wasm-bindgen-console-logger = "0.1.1"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
@@ -145,3 +148,4 @@ render = ["halo2_proofs/dev-graph", "plotters"]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled/color", "colored_json", "halo2_proofs/circuit-params"]
|
||||
det-prove = []
|
||||
|
||||
@@ -168,6 +168,119 @@ mod matmul_col_overflow {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod matmul_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 20;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 2],
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MatmulCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn matmulcircuit() {
|
||||
// get some logs fam
|
||||
crate::logger::init_logger();
|
||||
// parameters
|
||||
let mut a = Tensor::from((0..LEN * LEN).map(|i| Value::known(F::from((i + 1) as u64))));
|
||||
a.reshape(&[LEN, LEN]);
|
||||
|
||||
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
|
||||
w.reshape(&[LEN, 1]);
|
||||
|
||||
let circuit = MatmulCircuit::<F> {
|
||||
inputs: [ValTensor::from(a), ValTensor::from(w)],
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let params = crate::pfsys::srs::gen_srs::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>,
|
||||
>(K as u32);
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
MatmulCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
circuit.clone(),
|
||||
¶ms,
|
||||
vec![],
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
println!("done.");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod dot {
|
||||
use ops::poly::PolyOp;
|
||||
@@ -672,6 +785,283 @@ mod conv {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 28;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
image: ValTensor<F>,
|
||||
kernel: Tensor<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for ConvCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.image.clone()],
|
||||
Box::new(PolyOp::Conv {
|
||||
kernel: self.kernel.clone(),
|
||||
bias: None,
|
||||
padding: [(1, 1); 2],
|
||||
stride: (2, 2),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn conv_circuit() {
|
||||
// parameters
|
||||
let kernel_height = 2;
|
||||
let kernel_width = 2;
|
||||
let image_height = LEN;
|
||||
let image_width = LEN;
|
||||
let in_channels = 3;
|
||||
let out_channels = 2;
|
||||
|
||||
// get some logs fam
|
||||
crate::logger::init_logger();
|
||||
let mut image =
|
||||
Tensor::from((0..in_channels * image_height * image_width).map(|i| F::from(i as u64)));
|
||||
image.reshape(&[1, in_channels, image_height, image_width]);
|
||||
image.set_visibility(crate::graph::Visibility::Private);
|
||||
|
||||
let mut kernels = Tensor::from(
|
||||
(0..{ out_channels * in_channels * kernel_height * kernel_width })
|
||||
.map(|i| F::from(i as u64)),
|
||||
);
|
||||
kernels.reshape(&[out_channels, in_channels, kernel_height, kernel_width]);
|
||||
kernels.set_visibility(crate::graph::Visibility::Private);
|
||||
|
||||
let circuit = ConvCircuit::<F> {
|
||||
image: ValTensor::from(image),
|
||||
kernel: kernels,
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let params = crate::pfsys::srs::gen_srs::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>,
|
||||
>(K as u32);
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ConvCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
circuit.clone(),
|
||||
¶ms,
|
||||
vec![],
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
println!("done.");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
// not wasm 32 unknown
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_relu_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::ParamsProver;
|
||||
|
||||
use super::*;
|
||||
|
||||
const K: usize = 4;
|
||||
const LEN: usize = 28;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
image: ValTensor<F>,
|
||||
kernel: Tensor<F>,
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for ConvCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
let b = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
let output = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
|
||||
let mut base_config =
|
||||
Self::Config::configure(cs, &[a, b.clone()], &output, CheckMode::SAFE);
|
||||
// sets up a new relu table
|
||||
base_config
|
||||
.configure_lookup(cs, &b, &output, K, &LookupOp::ReLU)
|
||||
.unwrap();
|
||||
base_config.clone()
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0);
|
||||
let output = config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.image.clone()],
|
||||
Box::new(PolyOp::Conv {
|
||||
kernel: self.kernel.clone(),
|
||||
bias: None,
|
||||
padding: [(1, 1); 2],
|
||||
stride: (2, 2),
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis);
|
||||
let _output = config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap().unwrap()],
|
||||
Box::new(LookupOp::ReLU),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn conv_relu_circuit() {
|
||||
// parameters
|
||||
let kernel_height = 2;
|
||||
let kernel_width = 2;
|
||||
let image_height = LEN;
|
||||
let image_width = LEN;
|
||||
let in_channels = 3;
|
||||
let out_channels = 2;
|
||||
|
||||
// get some logs fam
|
||||
crate::logger::init_logger();
|
||||
let mut image =
|
||||
Tensor::from((0..in_channels * image_height * image_width).map(|_| F::from(0)));
|
||||
image.reshape(&[1, in_channels, image_height, image_width]);
|
||||
image.set_visibility(crate::graph::Visibility::Private);
|
||||
|
||||
let mut kernels = Tensor::from(
|
||||
(0..{ out_channels * in_channels * kernel_height * kernel_width }).map(|_| F::from(0)),
|
||||
);
|
||||
kernels.reshape(&[out_channels, in_channels, kernel_height, kernel_width]);
|
||||
kernels.set_visibility(crate::graph::Visibility::Private);
|
||||
|
||||
let circuit = ConvCircuit::<F> {
|
||||
image: ValTensor::from(image),
|
||||
kernel: kernels,
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let params = crate::pfsys::srs::gen_srs::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>,
|
||||
>(K as u32);
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ConvCircuit<F>,
|
||||
>(&circuit, ¶ms)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
circuit.clone(),
|
||||
¶ms,
|
||||
vec![],
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result =
|
||||
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
|
||||
|
||||
assert!(result.is_ok());
|
||||
|
||||
println!("done.");
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod sumpool {
|
||||
|
||||
|
||||
@@ -1205,8 +1205,6 @@ pub(crate) async fn prove(
|
||||
|
||||
trace!("params computed");
|
||||
|
||||
let now = Instant::now();
|
||||
|
||||
// creates and verifies the proof
|
||||
let snark = match strategy {
|
||||
StrategyType::Single => {
|
||||
@@ -1234,12 +1232,6 @@ pub(crate) async fn prove(
|
||||
)?
|
||||
}
|
||||
};
|
||||
let elapsed = now.elapsed();
|
||||
info!(
|
||||
"proof took {}.{}",
|
||||
elapsed.as_secs(),
|
||||
elapsed.subsec_millis()
|
||||
);
|
||||
|
||||
if let Some(proof_path) = proof_path {
|
||||
snark.save(&proof_path)?;
|
||||
|
||||
@@ -52,7 +52,7 @@ pub mod fieldutils;
|
||||
#[cfg(feature = "onnx")]
|
||||
pub mod graph;
|
||||
/// beautiful logging
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
pub mod logger;
|
||||
/// Tools for proofs and verification used by cli
|
||||
pub mod pfsys;
|
||||
|
||||
@@ -25,7 +25,10 @@ use halo2curves::serde::SerdeObject;
|
||||
use halo2curves::CurveAffine;
|
||||
use instant::Instant;
|
||||
use log::{debug, info, trace};
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
use rand::rngs::OsRng;
|
||||
#[cfg(feature = "det-prove")]
|
||||
use rand::rngs::StdRng;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
@@ -274,6 +277,9 @@ where
|
||||
Scheme::Curve: Serialize + DeserializeOwned,
|
||||
{
|
||||
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
|
||||
#[cfg(feature = "det-prove")]
|
||||
let mut rng = <StdRng as rand::SeedableRng>::from_seed([0u8; 32]);
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
let mut rng = OsRng;
|
||||
let number_instance = instances.iter().map(|x| x.len()).collect();
|
||||
trace!("number_instance {:?}", number_instance);
|
||||
@@ -291,6 +297,8 @@ where
|
||||
trace!("instances {:?}", instances);
|
||||
|
||||
info!("proof started...");
|
||||
// not wasm32 unknown
|
||||
let now = Instant::now();
|
||||
create_proof::<Scheme, P, _, _, TW, _>(
|
||||
params,
|
||||
pk,
|
||||
@@ -314,6 +322,12 @@ where
|
||||
strategy,
|
||||
)?;
|
||||
}
|
||||
let elapsed = now.elapsed();
|
||||
info!(
|
||||
"proof took {}.{}",
|
||||
elapsed.as_secs(),
|
||||
elapsed.subsec_millis()
|
||||
);
|
||||
|
||||
Ok(checkable_pf)
|
||||
}
|
||||
@@ -524,6 +538,7 @@ pub fn create_proof_circuit_kzg<
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
/// helper function
|
||||
pub(crate) fn verify_proof_circuit_kzg<
|
||||
'params,
|
||||
|
||||
@@ -19,6 +19,7 @@ use rand::SeedableRng;
|
||||
|
||||
use crate::tensor::TensorType;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
|
||||
|
||||
use console_error_panic_hook;
|
||||
|
||||
@@ -279,6 +280,11 @@ pub fn prove(
|
||||
settings: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
srs: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<Vec<u8>, JsError> {
|
||||
log::set_logger(&DEFAULT_LOGGER).unwrap();
|
||||
#[cfg(feature = "det-prove")]
|
||||
log::set_max_level(log::LevelFilter::Debug);
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
log::set_max_level(log::LevelFilter::Info);
|
||||
// read in kzg params
|
||||
let mut reader = std::io::BufReader::new(&srs[..]);
|
||||
let params: ParamsKZG<Bn256> =
|
||||
|
||||
@@ -359,6 +359,17 @@ mod native_tests {
|
||||
use crate::native_tests::tutorial as run_tutorial;
|
||||
use tempdir::TempDir;
|
||||
|
||||
#[test]
|
||||
fn model_serialization_different_binaries_() {
|
||||
let test = "1l_mlp";
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::native_tests::mv_test_(path, test);
|
||||
// percent tolerance test
|
||||
model_serialization_different_binaries(path, test.to_string());
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tutorial_() {
|
||||
let test_dir = TempDir::new("tutorial").unwrap();
|
||||
@@ -380,6 +391,9 @@ mod native_tests {
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
|
||||
|
||||
seq!(N in 0..=51 {
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
@@ -392,15 +406,7 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn model_serialization_different_binaries_(test: &str) {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::native_tests::mv_test_(path, test);
|
||||
// percent tolerance test
|
||||
model_serialization_different_binaries(path, test.to_string());
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user