chore: tests with > 2000 columns (#478)

This commit is contained in:
dante
2023-09-12 22:06:28 +01:00
committed by GitHub
parent fd0c1250a9
commit 7c8c2db9c9
9 changed files with 480 additions and 23 deletions

View File

@@ -61,10 +61,43 @@ jobs:
# nextest doesn't support --doc tests
run: cargo test --doc --verbose
- name: Library tests
run: cargo nextest run --lib --verbose -- --include-ignored
run: cargo nextest run --lib --verbose
ultra-overflow-tests:
runs-on: self-hosted
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Install cargo-wasi
run: cargo install cargo-wasi
# - name: Matmul overflow (wasi)
# run: cargo wasi test matmul_col_ultra_overflow -- --include-ignored --nocapture
# - name: Conv overflow (wasi)
# run: cargo wasi test conv_col_ultra_overflow -- --include-ignored --nocapture
- name: Conv + relu overflow (wasi)
run: cargo wasi test conv_relu_col_ultra_overflow -- --include-ignored --nocapture
# - name: Matmul overflow
# run: RUST_LOG=debug cargo nextest run matmul_col_ultra_overflow --no-capture -- --include-ignored
# - name: Conv overflow
# run: RUST_LOG=debug cargo nextest run conv_col_ultra_overflow --no-capture -- --include-ignored
- name: Conv + relu overflow
run: RUST_LOG=debug cargo nextest run conv_relu_col_ultra_overflow --no-capture -- --include-ignored
model-serialization:
runs-on: 256gb
runs-on: ubuntu-latest-16-cores
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
@@ -79,7 +112,7 @@ jobs:
- name: Model serialization
run: cargo nextest run native_tests::tests::model_serialization_::t
- name: Model serialization different binary ID
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_::t --test-threads 1
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_ --test-threads 1
wasm32-tests:
runs-on: ubuntu-latest

11
Cargo.lock generated
View File

@@ -1631,6 +1631,7 @@ dependencies = [
"tract-onnx",
"unzip-n",
"wasm-bindgen",
"wasm-bindgen-console-logger",
"wasm-bindgen-rayon",
"wasm-bindgen-test",
]
@@ -5393,6 +5394,16 @@ dependencies = [
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-console-logger"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7530a275e7faf7b5b83aabdf78244fb8d9a68a2ec4b26935a05ecc0c9b0185ed"
dependencies = [
"log",
"wasm-bindgen",
]
[[package]]
name = "wasm-bindgen-futures"
version = "0.4.37"

View File

@@ -42,8 +42,6 @@ openssl = { version = "0.10.55", features = ["vendored"] }
postgres = "0.19.5"
pg_bigdecimal = "0.1.5"
lazy_static = "1.4.0"
env_logger = { version = "0.10.0", default_features = false, optional = true}
colored = { version = "2.0.0", default_features = false, optional = true}
colored_json = { version = "3.0.1", default_features = false, optional = true}
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
@@ -54,6 +52,10 @@ pyo3-log = { version = "0.8.1", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "2ea76c09678f092d00713ebbe6fdb046c0a9ad0f", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true}
env_logger = { version = "0.10.0", default_features = false, optional = true}
[target.'cfg(target_arch = "wasm32")'.dependencies]
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "ac/parallel-lookup-permute", default_features = false, features = ["thread-safe-region"]}
@@ -66,6 +68,7 @@ wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"]}
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[dev-dependencies]
@@ -145,3 +148,4 @@ render = ["halo2_proofs/dev-graph", "plotters"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled/color", "colored_json", "halo2_proofs/circuit-params"]
det-prove = []

View File

@@ -168,6 +168,119 @@ mod matmul_col_overflow {
}
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use super::*;
const K: usize = 4;
const LEN: usize = 20;
#[derive(Clone)]
struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; 2],
_marker: PhantomData<F>,
}
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0);
config
.layout(
&mut region,
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
#[ignore]
fn matmulcircuit() {
// get some logs fam
crate::logger::init_logger();
// parameters
let mut a = Tensor::from((0..LEN * LEN).map(|i| Value::known(F::from((i + 1) as u64))));
a.reshape(&[LEN, LEN]);
let mut w = Tensor::from((0..LEN).map(|i| Value::known(F::from((i + 1) as u64))));
w.reshape(&[LEN, 1]);
let circuit = MatmulCircuit::<F> {
inputs: [ValTensor::from(a), ValTensor::from(w)],
_marker: PhantomData,
};
let params = crate::pfsys::srs::gen_srs::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>,
>(K as u32);
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
circuit.clone(),
&params,
vec![],
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
#[cfg(test)]
mod dot {
use ops::poly::PolyOp;
@@ -672,6 +785,283 @@ mod conv {
}
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use super::*;
const K: usize = 4;
const LEN: usize = 28;
#[derive(Clone)]
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
image: ValTensor<F>,
kernel: Tensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for ConvCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
Self::Config::configure(cs, &[a, b], &output, CheckMode::SAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0);
config
.layout(
&mut region,
&[self.image.clone()],
Box::new(PolyOp::Conv {
kernel: self.kernel.clone(),
bias: None,
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis)
},
)
.unwrap();
Ok(())
}
}
#[test]
#[ignore]
fn conv_circuit() {
// parameters
let kernel_height = 2;
let kernel_width = 2;
let image_height = LEN;
let image_width = LEN;
let in_channels = 3;
let out_channels = 2;
// get some logs fam
crate::logger::init_logger();
let mut image =
Tensor::from((0..in_channels * image_height * image_width).map(|i| F::from(i as u64)));
image.reshape(&[1, in_channels, image_height, image_width]);
image.set_visibility(crate::graph::Visibility::Private);
let mut kernels = Tensor::from(
(0..{ out_channels * in_channels * kernel_height * kernel_width })
.map(|i| F::from(i as u64)),
);
kernels.reshape(&[out_channels, in_channels, kernel_height, kernel_width]);
kernels.set_visibility(crate::graph::Visibility::Private);
let circuit = ConvCircuit::<F> {
image: ValTensor::from(image),
kernel: kernels,
_marker: PhantomData,
};
let params = crate::pfsys::srs::gen_srs::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>,
>(K as u32);
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
circuit.clone(),
&params,
vec![],
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
#[cfg(test)]
// not wasm 32 unknown
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_relu_col_ultra_overflow {
use halo2_proofs::poly::commitment::ParamsProver;
use super::*;
const K: usize = 4;
const LEN: usize = 28;
#[derive(Clone)]
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
image: ValTensor<F>,
kernel: Tensor<F>,
_marker: PhantomData<F>,
}
impl Circuit<F> for ConvCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
let a = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
let b = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
let output = VarTensor::new_advice(cs, K, LEN * LEN * LEN);
let mut base_config =
Self::Config::configure(cs, &[a, b.clone()], &output, CheckMode::SAFE);
// sets up a new relu table
base_config
.configure_lookup(cs, &b, &output, K, &LookupOp::ReLU)
.unwrap();
base_config.clone()
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<F>,
) -> Result<(), Error> {
layouter
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0);
let output = config
.layout(
&mut region,
&[self.image.clone()],
Box::new(PolyOp::Conv {
kernel: self.kernel.clone(),
bias: None,
padding: [(1, 1); 2],
stride: (2, 2),
}),
)
.map_err(|_| Error::Synthesis);
let _output = config
.layout(
&mut region,
&[output.unwrap().unwrap()],
Box::new(LookupOp::ReLU),
)
.unwrap();
Ok(())
},
)
.unwrap();
Ok(())
}
}
#[test]
#[ignore]
fn conv_relu_circuit() {
// parameters
let kernel_height = 2;
let kernel_width = 2;
let image_height = LEN;
let image_width = LEN;
let in_channels = 3;
let out_channels = 2;
// get some logs fam
crate::logger::init_logger();
let mut image =
Tensor::from((0..in_channels * image_height * image_width).map(|_| F::from(0)));
image.reshape(&[1, in_channels, image_height, image_width]);
image.set_visibility(crate::graph::Visibility::Private);
let mut kernels = Tensor::from(
(0..{ out_channels * in_channels * kernel_height * kernel_width }).map(|_| F::from(0)),
);
kernels.reshape(&[out_channels, in_channels, kernel_height, kernel_width]);
kernels.set_visibility(crate::graph::Visibility::Private);
let circuit = ConvCircuit::<F> {
image: ValTensor::from(image),
kernel: kernels,
_marker: PhantomData,
};
let params = crate::pfsys::srs::gen_srs::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<_>,
>(K as u32);
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
circuit.clone(),
&params,
vec![],
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result =
crate::pfsys::verify_proof_circuit_kzg(params.verifier_params(), proof, &vk, strategy);
assert!(result.is_ok());
println!("done.");
}
}
#[cfg(test)]
mod sumpool {

View File

@@ -1205,8 +1205,6 @@ pub(crate) async fn prove(
trace!("params computed");
let now = Instant::now();
// creates and verifies the proof
let snark = match strategy {
StrategyType::Single => {
@@ -1234,12 +1232,6 @@ pub(crate) async fn prove(
)?
}
};
let elapsed = now.elapsed();
info!(
"proof took {}.{}",
elapsed.as_secs(),
elapsed.subsec_millis()
);
if let Some(proof_path) = proof_path {
snark.save(&proof_path)?;

View File

@@ -52,7 +52,7 @@ pub mod fieldutils;
#[cfg(feature = "onnx")]
pub mod graph;
/// beautiful logging
#[cfg(not(target_arch = "wasm32"))]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
pub mod logger;
/// Tools for proofs and verification used by cli
pub mod pfsys;

View File

@@ -25,7 +25,10 @@ use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
use rand::rngs::OsRng;
#[cfg(feature = "det-prove")]
use rand::rngs::StdRng;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use snark_verifier::loader::native::NativeLoader;
@@ -274,6 +277,9 @@ where
Scheme::Curve: Serialize + DeserializeOwned,
{
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
#[cfg(feature = "det-prove")]
let mut rng = <StdRng as rand::SeedableRng>::from_seed([0u8; 32]);
#[cfg(not(feature = "det-prove"))]
let mut rng = OsRng;
let number_instance = instances.iter().map(|x| x.len()).collect();
trace!("number_instance {:?}", number_instance);
@@ -291,6 +297,8 @@ where
trace!("instances {:?}", instances);
info!("proof started...");
// not wasm32 unknown
let now = Instant::now();
create_proof::<Scheme, P, _, _, TW, _>(
params,
pk,
@@ -314,6 +322,12 @@ where
strategy,
)?;
}
let elapsed = now.elapsed();
info!(
"proof took {}.{}",
elapsed.as_secs(),
elapsed.subsec_millis()
);
Ok(checkable_pf)
}
@@ -524,6 +538,7 @@ pub fn create_proof_circuit_kzg<
}
}
#[allow(unused)]
/// helper function
pub(crate) fn verify_proof_circuit_kzg<
'params,

View File

@@ -19,6 +19,7 @@ use rand::SeedableRng;
use crate::tensor::TensorType;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use console_error_panic_hook;
@@ -279,6 +280,11 @@ pub fn prove(
settings: wasm_bindgen::Clamped<Vec<u8>>,
srs: wasm_bindgen::Clamped<Vec<u8>>,
) -> Result<Vec<u8>, JsError> {
log::set_logger(&DEFAULT_LOGGER).unwrap();
#[cfg(feature = "det-prove")]
log::set_max_level(log::LevelFilter::Debug);
#[cfg(not(feature = "det-prove"))]
log::set_max_level(log::LevelFilter::Info);
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
let params: ParamsKZG<Bn256> =

View File

@@ -359,6 +359,17 @@ mod native_tests {
use crate::native_tests::tutorial as run_tutorial;
use tempdir::TempDir;
#[test]
fn model_serialization_different_binaries_() {
let test = "1l_mlp";
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap();
crate::native_tests::mv_test_(path, test);
// percent tolerance test
model_serialization_different_binaries(path, test.to_string());
test_dir.close().unwrap();
}
#[test]
fn tutorial_() {
let test_dir = TempDir::new("tutorial").unwrap();
@@ -380,6 +391,9 @@ mod native_tests {
}
});
seq!(N in 0..=51 {
#(#[test_case(TESTS[N])])*
@@ -392,15 +406,7 @@ mod native_tests {
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn model_serialization_different_binaries_(test: &str) {
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap();
crate::native_tests::mv_test_(path, test);
// percent tolerance test
model_serialization_different_binaries(path, test.to_string());
test_dir.close().unwrap();
}