Compare commits

...

2 Commits

Author SHA1 Message Date
dante
395bf4e9a9 chore: fft tests 2025-03-22 16:59:56 +00:00
dante
a1b531e93b feat: (optional) fft-conv 2025-03-22 16:51:40 +00:00
24 changed files with 1460 additions and 191 deletions

View File

@@ -1,10 +1,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
@@ -64,13 +64,13 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
padding: vec![(0, 0), (0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
@@ -90,7 +90,7 @@ fn runcnvrl(c: &mut Criterion) {
let params = gen_srs::<KZGCommitmentScheme<_>>(K as u32);
for size in [1, 2, 4].iter() {
for size in [1, 2, 4, 32, 128].iter() {
unsafe {
KERNEL_HEIGHT = size * 2;
KERNEL_WIDTH = size * 2;

176
benches/accum_conv_ntt.rs Normal file
View File

@@ -0,0 +1,176 @@
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut KERNEL_HEIGHT: usize = 2;
static mut KERNEL_WIDTH: usize = 2;
static mut OUT_CHANNELS: usize = 1;
static mut IMAGE_HEIGHT: usize = 2;
static mut IMAGE_WIDTH: usize = 2;
static mut IN_CHANNELS: usize = 1;
const K: usize = 17;
#[derive(Clone, Debug)]
struct MyCircuit {
image: ValTensor<Fr>,
kernel: ValTensor<Fr>,
bias: ValTensor<Fr>,
}
impl Circuit<Fr> for MyCircuit {
type Config = BaseConfig<Fr>;
type FloorPlanner = SimpleFloorPlanner;
type Params = ();
fn without_witnesses(&self) -> Self {
self.clone()
}
fn configure(cs: &mut ConstraintSystem<Fr>) -> Self::Config {
let len = 10;
let a = VarTensor::new_advice(cs, K, 1, len * len);
let b = VarTensor::new_advice(cs, K, 1, len * len);
let output = VarTensor::new_advice(cs, K, 1, (len + 1) * len);
let _constant = VarTensor::constant_cols(cs, K, len * len, false);
Self::Config::configure(cs, &[a, b], &output, CheckMode::UNSAFE)
}
fn synthesize(
&self,
mut config: Self::Config,
mut layouter: impl Layouter<Fr>,
) -> Result<(), Error> {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, true);
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: vec![(0, 0), (0, 0)],
stride: vec![1; 2],
group: 1,
data_format: DataFormat::NCHW,
kernel_format: KernelFormat::OIHW,
}),
)
.unwrap();
Ok(())
},
)?;
Ok(())
}
}
fn runcnvrl(c: &mut Criterion) {
let mut group = c.benchmark_group("accum_conv");
let params = gen_srs::<KZGCommitmentScheme<_>>(K as u32);
for size in [1, 2, 4].iter() {
unsafe {
KERNEL_HEIGHT = size * 2;
KERNEL_WIDTH = size * 2;
IMAGE_HEIGHT = size * 4;
IMAGE_WIDTH = size * 4;
IN_CHANNELS = 1;
OUT_CHANNELS = 1;
let mut image = Tensor::from(
(0..IN_CHANNELS * IMAGE_HEIGHT * IMAGE_WIDTH)
.map(|_| Value::known(Fr::random(OsRng))),
);
image
.reshape(&[1, IN_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH])
.unwrap();
let mut kernel = Tensor::from(
(0..{ OUT_CHANNELS * IN_CHANNELS * KERNEL_HEIGHT * KERNEL_WIDTH })
.map(|_| Fr::random(OsRng)),
);
kernel
.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH])
.unwrap();
kernel.set_visibility(&ezkl::graph::Visibility::Private);
let mut bias = Tensor::from((0..{ OUT_CHANNELS }).map(|_| Fr::random(OsRng)));
bias.set_visibility(&ezkl::graph::Visibility::Private);
let circuit = MyCircuit {
image: ValTensor::from(image),
kernel: ValTensor::try_from(kernel).unwrap(),
bias: ValTensor::try_from(bias).unwrap(),
};
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();
});
});
}
}
group.finish();
}
criterion_group! {
name = benches;
config = Criterion::default().with_plots();
targets = runcnvrl
}
criterion_main!(benches);

View File

@@ -1,8 +1,8 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(
&mut region,

View File

@@ -1,8 +1,8 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -57,7 +57,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(
&mut region,

View File

@@ -1,11 +1,11 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::table::Range;
use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -83,7 +83,7 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))

View File

@@ -1,11 +1,11 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -84,7 +84,7 @@ impl Circuit<Fr> for MyCircuit {
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))

View File

@@ -1,8 +1,8 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(
&mut region,

View File

@@ -1,10 +1,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::hybrid::HybridOp;
use ezkl::circuit::*;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
@@ -59,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(
&mut region,

View File

@@ -1,8 +1,8 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -55,7 +55,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.unwrap();

View File

@@ -1,9 +1,9 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -56,7 +56,7 @@ impl Circuit<Fr> for MyCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.unwrap();

View File

@@ -1,10 +1,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -66,7 +66,7 @@ impl Circuit<Fr> for NLCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(
&mut region,

View File

@@ -1,10 +1,10 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use criterion::{BenchmarkId, Criterion, Throughput, criterion_group, criterion_main};
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::circuit::{BaseConfig as Config, CheckMode, ops::lookup::LookupOp};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
@@ -63,7 +63,7 @@ impl Circuit<Fr> for NLCircuit {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2, false);
config
.layout(
&mut region,

View File

@@ -1,8 +1,8 @@
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
BaseConfig as PolyConfig, CheckMode, ops::lookup::LookupOp, ops::poly::PolyOp,
};
use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
use ezkl::fieldutils::{self, IntegerRep, integer_rep_to_felt};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
@@ -10,8 +10,8 @@ use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{
create_proof, keygen_pk, keygen_vk, verify_proof, Circuit, Column, ConstraintSystem, Error,
Instance,
Circuit, Column, ConstraintSystem, Error, Instance, create_proof, keygen_pk, keygen_vk,
verify_proof,
},
poly::{
commitment::ParamsProver,
@@ -32,7 +32,6 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -88,20 +87,20 @@ struct MyCircuit<
}
impl<
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
const OUT_CHANNELS: usize,
const STRIDE: usize,
const IMAGE_HEIGHT: usize,
const IMAGE_WIDTH: usize,
const IN_CHANNELS: usize,
const PADDING: usize,
> Circuit<F>
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
const OUT_CHANNELS: usize,
const STRIDE: usize,
const IMAGE_HEIGHT: usize,
const IMAGE_WIDTH: usize,
const IN_CHANNELS: usize,
const PADDING: usize,
> Circuit<F>
for MyCircuit<
LEN,
CLASSES,
@@ -203,7 +202,7 @@ where
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 1024, 2);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 1024, 2, false);
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],

View File

@@ -1,8 +1,8 @@
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
BaseConfig as PolyConfig, CheckMode, ops::lookup::LookupOp, ops::poly::PolyOp,
};
use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
use ezkl::fieldutils::{IntegerRep, integer_rep_to_felt};
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
@@ -112,7 +112,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.assign_region(
|| "mlp_4d",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2, false);
let x = config
.layer_config
.layout(
@@ -283,10 +283,12 @@ pub fn runmlp() {
let prover = MockProver::run(
K as u32,
&circuit,
vec![public_input
.iter()
.map(|x| integer_rep_to_felt::<F>(*x))
.collect()],
vec![
public_input
.iter()
.map(|x| integer_rep_to_felt::<F>(*x))
.collect(),
],
)
.unwrap();
prover.assert_satisfied();

View File

@@ -1,34 +1,34 @@
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::Commitments;
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::modules::Module;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
PoseidonChip,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;
@@ -203,6 +203,9 @@ struct PyRunArgs {
/// bool: Should the circuit use unbounded lookups for log
#[pyo3(get, set)]
pub bounded_log_lookup: bool,
/// bool: Should the circuit use fft for conv
#[pyo3(get, set)]
pub use_fft_for_conv: bool,
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
@@ -238,6 +241,7 @@ impl From<PyRunArgs> for RunArgs {
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
use_fft_for_conv: py_run_args.use_fft_for_conv,
}
}
}
@@ -262,6 +266,7 @@ impl Into<PyRunArgs> for RunArgs {
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
use_fft_for_conv: self.use_fft_for_conv,
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -16,12 +16,12 @@ use std::{
cell::RefCell,
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
atomic::{AtomicUsize, Ordering},
},
};
use super::{lookup::LookupOp, CircuitError};
use super::{CircuitError, lookup::LookupOp};
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
@@ -97,6 +97,8 @@ pub struct RegionSettings {
pub base: usize,
/// number of legs for decompositions
pub legs: usize,
/// whether we should use fft conv or naive conv
pub use_fft: bool,
}
#[allow(unsafe_code)]
@@ -106,32 +108,41 @@ unsafe impl Send for RegionSettings {}
impl RegionSettings {
/// Create a new region settings
pub fn new(witness_gen: bool, check_range: bool, base: usize, legs: usize) -> RegionSettings {
pub fn new(
witness_gen: bool,
check_range: bool,
base: usize,
legs: usize,
use_fft: bool,
) -> RegionSettings {
RegionSettings {
witness_gen,
check_range,
base,
legs,
use_fft,
}
}
/// Create a new region settings with all true
pub fn all_true(base: usize, legs: usize) -> RegionSettings {
pub fn all_true(base: usize, legs: usize, use_fft: bool) -> RegionSettings {
RegionSettings {
witness_gen: true,
check_range: true,
base,
legs,
use_fft,
}
}
/// Create a new region settings with all false
pub fn all_false(base: usize, legs: usize) -> RegionSettings {
pub fn all_false(base: usize, legs: usize, use_fft: bool) -> RegionSettings {
RegionSettings {
witness_gen: false,
check_range: false,
base,
legs,
use_fft,
}
}
}
@@ -194,6 +205,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.settings.legs
}
/// use fft for conv
pub fn use_fft(&self) -> bool {
self.settings.use_fft
}
/// get the max dynamic input len
pub fn max_dynamic_input_len(&self) -> usize {
self.max_dynamic_input_len
@@ -273,6 +289,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
use_fft: bool,
) -> RegionCtx<'a, F> {
let region = Some(RefCell::new(region));
let linear_coord = row * num_inner_cols;
@@ -285,7 +302,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
statistics: RegionStatistics::default(),
settings: RegionSettings::all_true(decomp_base, decomp_legs),
settings: RegionSettings::all_true(decomp_base, decomp_legs, use_fft),
assigned_constants: HashMap::new(),
max_dynamic_input_len: 0,
}
@@ -298,13 +315,37 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
num_inner_cols: usize,
decomp_base: usize,
decomp_legs: usize,
use_fft: bool,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols, decomp_base, decomp_legs);
let mut new_self = Self::new(
region,
row,
num_inner_cols,
decomp_base,
decomp_legs,
use_fft,
);
new_self.assigned_constants = constants;
new_self
}
/// convert into dummy
pub fn into_dummy(self) -> RegionCtx<'a, F> {
RegionCtx {
region: None,
num_inner_cols: self.num_inner_cols,
row: self.row,
linear_coord: self.linear_coord,
dynamic_lookup_index: self.dynamic_lookup_index,
shuffle_index: self.shuffle_index,
statistics: self.statistics,
settings: self.settings,
assigned_constants: self.assigned_constants,
max_dynamic_input_len: self.max_dynamic_input_len,
}
}
/// Create a new region context
pub fn new_dummy(
row: usize,
@@ -355,8 +396,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
+ Send
+ Sync,
) -> Result<(), CircuitError> {
if self.is_dummy() {
self.dummy_loop(output, inner_loop_function)?;
@@ -391,8 +432,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
output: &mut Tensor<T>,
inner_loop_function: impl Fn(usize, &mut RegionCtx<'a, F>) -> Result<T, CircuitError>
+ Send
+ Sync,
+ Send
+ Sync,
) -> Result<(), CircuitError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());

View File

@@ -60,7 +60,7 @@ mod matmul {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -137,7 +137,7 @@ mod matmul_col_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2, false);
config
.layout(
&mut region,
@@ -211,7 +211,7 @@ mod matmul_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -298,7 +298,7 @@ mod matmul_col_ultra_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2, false);
config
.layout(
&mut region,
@@ -418,7 +418,7 @@ mod matmul_col_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -529,7 +529,7 @@ mod dot {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -606,7 +606,7 @@ mod dot_col_overflow_triple_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 3, 128, 2);
let mut region = RegionCtx::new(region, 0, 3, 128, 2, false);
config
.layout(
&mut region,
@@ -679,7 +679,7 @@ mod dot_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -752,7 +752,7 @@ mod sum {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -822,7 +822,7 @@ mod sum_col_overflow_double_col {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2);
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS, 128, 2, false);
config
.layout(
&mut region,
@@ -891,7 +891,7 @@ mod sum_col_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -962,7 +962,7 @@ mod composition {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
let _ = config
.layout(
&mut region,
@@ -1057,7 +1057,7 @@ mod conv {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, true);
config
.layout(
&mut region,
@@ -1070,7 +1070,10 @@ mod conv {
kernel_format: KernelFormat::default(),
}),
)
.map_err(|_| Error::Synthesis)
.map_err(|e| {
println!("Error in synthesizer: {:?}", e);
Error::Synthesis
})
},
)
.unwrap();
@@ -1214,7 +1217,7 @@ mod conv_col_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,
@@ -1373,7 +1376,7 @@ mod conv_relu_col_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
let mut region = RegionCtx::new(region, 0, 1, 2, 2, true);
let output = config
.layout(
&mut region,
@@ -1515,7 +1518,7 @@ mod add_w_shape_casting {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1582,7 +1585,7 @@ mod add {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1666,7 +1669,7 @@ mod dynamic_lookup {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
for i in 0..NUM_LOOP {
layouts::dynamic_lookup(
&config,
@@ -1813,7 +1816,7 @@ mod shuffle {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
for i in 0..NUM_LOOP {
layouts::shuffles(
&config,
@@ -1929,7 +1932,7 @@ mod add_with_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
@@ -1964,11 +1967,11 @@ mod add_with_overflow_and_poseidon {
use halo2curves::bn256::Fr;
use crate::circuit::modules::{
poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip, PoseidonConfig,
},
Module, ModulePlanner,
poseidon::{
PoseidonChip, PoseidonConfig,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
},
};
use super::*;
@@ -2031,7 +2034,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(
|| "model",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.base
.layout(&mut region, &inputs, Box::new(PolyOp::Add))
@@ -2133,7 +2136,7 @@ mod sub {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.map_err(|_| Error::Synthesis)
@@ -2200,7 +2203,7 @@ mod mult {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.map_err(|_| Error::Synthesis)
@@ -2267,7 +2270,7 @@ mod pow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.map_err(|_| Error::Synthesis)
@@ -2354,7 +2357,7 @@ mod matmul_relu {
layouter.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
let mut region = RegionCtx::new(region, 0, 1, 1024, 2, false);
let op = PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
};
@@ -2461,7 +2464,7 @@ mod relu {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 2, 2);
let mut region = RegionCtx::new(region, 0, 1, 2, 2, false);
Ok(config
.layout(
&mut region,
@@ -2559,7 +2562,7 @@ mod lookup_ultra_overflow {
.assign_region(
|| "",
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
let mut region = RegionCtx::new(region, 0, 1, 128, 2, false);
config
.layout(
&mut region,

View File

@@ -20,8 +20,8 @@ use crate::tensor::TensorError;
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
// #[cfg(unix)]
// use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
@@ -747,8 +747,11 @@ pub(crate) async fn gen_witness(
let commitment: Commitments = settings.run_args.commitment.into();
let region_settings =
RegionSettings::all_true(settings.run_args.decomp_base, settings.run_args.decomp_legs);
let region_settings = RegionSettings::all_true(
settings.run_args.decomp_base,
settings.run_args.decomp_legs,
settings.run_args.use_fft_for_conv,
);
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
@@ -1145,17 +1148,17 @@ pub(crate) async fn calibrate(
..settings.run_args.clone()
};
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
// // if unix get a gag
// #[cfg(all(not(not(feature = "ezkl")), unix))]
// let _r = match Gag::stdout() {
// Ok(g) => Some(g),
// _ => None,
// };
// #[cfg(all(not(not(feature = "ezkl")), unix))]
// let _g = match Gag::stderr() {
// Ok(g) => Some(g),
// _ => None,
// };
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
@@ -1184,6 +1187,7 @@ pub(crate) async fn calibrate(
RegionSettings::all_true(
settings.run_args.decomp_base,
settings.run_args.decomp_legs,
settings.run_args.use_fft_for_conv,
),
)
.map_err(|e| format!("failed to forward: {}", e))?;
@@ -1209,11 +1213,11 @@ pub(crate) async fn calibrate(
}
}
// drop the gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
drop(_r);
#[cfg(all(not(not(feature = "ezkl")), unix))]
drop(_g);
// // drop the gag
// #[cfg(all(not(not(feature = "ezkl")), unix))]
// drop(_r);
// #[cfg(all(not(not(feature = "ezkl")), unix))]
// drop(_g);
let result = forward_pass_res.get(&key).ok_or("key not found")?;

View File

@@ -1,22 +1,22 @@
use super::GraphSettings;
use super::errors::GraphError;
use super::extract_const_quantized_values;
use super::node::*;
use super::vars::*;
use super::GraphSettings;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::region::RegionSettings;
use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::IntegerRep;
use crate::tensor::ValType;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
tensor::{Tensor, ValTensor},
RunArgs,
circuit::{BaseConfig as PolyConfig, CheckMode, Op, lookup::LookupOp},
tensor::{Tensor, ValTensor},
};
use halo2curves::bn256::Fr as Fp;
@@ -573,7 +573,11 @@ impl Model {
let res = self.dummy_layout(
run_args,
&inputs,
RegionSettings::all_false(run_args.decomp_base, run_args.decomp_legs),
RegionSettings::all_false(
run_args.decomp_base,
run_args.decomp_legs,
run_args.use_fft_for_conv,
),
)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -1143,6 +1147,8 @@ impl Model {
}
}
println!("results: {:?}", results);
let instance_idx = vars.get_instance_idx();
config.base.layout_tables(layouter)?;
@@ -1159,6 +1165,7 @@ impl Model {
run_args.num_inner_cols,
run_args.decomp_base,
run_args.decomp_legs,
run_args.use_fft_for_conv,
original_constants.clone(),
);
// we need to do this as this loop is called multiple times

View File

@@ -97,11 +97,11 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{table::Range, CheckMode};
use circuit::{CheckMode, table::Range};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use graph::{MAX_PUBLIC_SRS, Visibility};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};
@@ -344,6 +344,12 @@ pub struct RunArgs {
arg(long, default_value = "false")
)]
pub bounded_log_lookup: bool,
/// Whether to use fft to compute conv
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
arg(long, default_value = "false")
)]
pub use_fft_for_conv: bool,
/// Range check inputs and outputs (turn off if the inputs are felts)
#[cfg_attr(
all(feature = "ezkl", not(target_arch = "wasm32")),
@@ -376,6 +382,7 @@ impl Default for RunArgs {
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
use_fft_for_conv: false,
}
}
}

View File

@@ -737,6 +737,20 @@ impl<T: Clone + TensorType> Tensor<T> {
index
}
/// Flip the order of the inner values of the tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[6, 5, 4, 3, 2, 1]), &[2, 3]).unwrap();
/// assert_eq!(a.flip().unwrap(), b);
/// ```
pub fn flip(&self) -> Result<Tensor<T>, TensorError> {
let mut inner = self.inner.clone();
inner.reverse();
Tensor::new(Some(&inner), &self.dims)
}
/// Fetches every nth element
///
/// ```

View File

@@ -715,6 +715,21 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(tensor)
}
/// Flips the inner tensor's order
pub fn flip(&mut self) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: _, ..
} => {
*v = v.flip()?;
}
ValTensor::Instance { .. } => {
return Err(TensorError::WrongMethod);
}
};
Ok(())
}
/// Pads the tensor with zeros until its size is divisible by n
///
/// # Arguments

View File

@@ -888,6 +888,39 @@ def get_examples():
return examples
def get_conv_examples():
EXAMPLES_OMIT = [
# these are too large
'mobilenet_large',
'mobilenet',
'doodles',
'nanoGPT',
"self_attention",
'multihead_attention',
'large_op_graph',
'1l_instance_norm',
'variable_cnn',
'accuracy',
'linear_regression',
"mnist_gan",
"smallworm",
"fr_age",
"1d_conv",
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
name = subdir.split('/')[-1]
if name in EXAMPLES_OMIT or name == "onnx" or "conv" not in name:
continue
else:
examples.append((
os.path.join(subdir, "network.onnx"),
os.path.join(subdir, "input.json"),
))
return examples
@pytest.mark.parametrize("model_file, input_file", get_examples())
async def test_all_examples(model_file, input_file):
"""Tests all examples in the examples folder"""
@@ -963,3 +996,81 @@ async def test_all_examples(model_file, input_file):
)
assert res == True
@pytest.mark.parametrize("model_file, input_file", get_conv_examples())
async def test_fft_examples(model_file, input_file):
"""Tests all examples in the examples folder"""
# gen settings
settings_path = os.path.join(folder_path, "settings.json")
compiled_model_path = os.path.join(folder_path, 'network.ezkl')
pk_path = os.path.join(folder_path, 'test.pk')
vk_path = os.path.join(folder_path, 'test.vk')
witness_path = os.path.join(folder_path, 'witness.json')
proof_path = os.path.join(folder_path, 'proof.json')
print("Testing example: ", model_file)
run_args = ezkl.PyRunArgs()
run_args.variables = [("batch_size", 1), ("sequence_length", 100), ("<Sym1>", 1)]
run_args.logrows = 22
run_args.use_fft_for_conv = True
res = ezkl.gen_settings(model_file, settings_path, py_run_args=run_args)
assert res
res = await ezkl.calibrate_settings(
input_file, model_file, settings_path, "resources")
assert res
print("Compiling example: ", model_file)
res = ezkl.compile_circuit(model_file, compiled_model_path, settings_path)
assert res
with open(settings_path, 'r') as f:
data = json.load(f)
logrows = data["run_args"]["logrows"]
srs_path = os.path.join(folder_path, f"srs_{logrows}")
# generate the srs file if the path does not exist
if not os.path.exists(srs_path):
print("Generating srs file: ", srs_path)
ezkl.gen_srs(os.path.join(folder_path, srs_path), logrows)
print("Setting up example: ", model_file)
res = ezkl.setup(
compiled_model_path,
vk_path,
pk_path,
srs_path
)
assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
print("Generating witness for example: ", model_file)
res = await ezkl.gen_witness(input_file, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)
print("Proving example: ", model_file)
ezkl.prove(
witness_path,
compiled_model_path,
pk_path,
proof_path,
"single",
srs_path=srs_path,
)
assert os.path.isfile(proof_path)
print("Verifying example: ", model_file)
res = ezkl.verify(
proof_path,
settings_path,
vk_path,
srs_path=srs_path,
)
assert res == True