From 699bd3992bfd388f6f76e6eb45a8886661dedbbb Mon Sep 17 00:00:00 2001 From: Jseam Date: Wed, 3 May 2023 10:24:34 +0800 Subject: [PATCH] feat: Add python bindings for mock and prove (#197) --- .gitignore | 1 + Cargo.lock | 20 +- Cargo.toml | 2 +- requirements.txt | 2 +- src/circuit/mod.rs | 33 +++ src/commands.rs | 67 ++++++ src/execute.rs | 3 +- src/python.rs | 370 +++++++++++++++++++++++----------- tests/python/binding_tests.py | 63 +++++- 9 files changed, 431 insertions(+), 130 deletions(-) diff --git a/.gitignore b/.gitignore index 218c9968..7699670c 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,4 @@ var/ *.egg-info/ .installed.cfg *.egg +.vscode/ \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 06d23481..2d49477a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3862,9 +3862,9 @@ checksum = "fe7765e19fb2ba6fd4373b8d90399f5321683ea7c11b598c6bbaa3a72e9c83b8" [[package]] name = "pyo3" -version = "0.18.2" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfb848f80438f926a9ebddf0a539ed6065434fd7aae03a89312a9821f81b8501" +checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109" dependencies = [ "cfg-if 1.0.0", "indoc", @@ -3879,9 +3879,9 @@ dependencies = [ [[package]] name = "pyo3-build-config" -version = "0.18.2" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98a42e7f42e917ce6664c832d5eee481ad514c98250c49e0b03b20593e2c7ed0" +checksum = "9cb946f5ac61bb61a5014924910d936ebd2b23b705f7a4a3c40b05c720b079a3" dependencies = [ "once_cell", "target-lexicon", @@ -3889,9 +3889,9 @@ dependencies = [ [[package]] name = "pyo3-ffi" -version = "0.18.2" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a0707f0ab26826fe4ccd59b69106e9df5e12d097457c7b8f9c0fd1d2743eec4d" +checksum = "fd4d7c5337821916ea2a1d21d1092e8443cf34879e53a0ac653fbb98f44ff65c" dependencies = [ "libc", "pyo3-build-config", @@ -3910,9 +3910,9 @@ dependencies = [ [[package]] name = "pyo3-macros" -version = "0.18.2" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "978d18e61465ecd389e1f235ff5a467146dc4e3c3968b90d274fe73a5dd4a438" +checksum = "a9d39c55dab3fc5a4b25bbd1ac10a2da452c4aca13bb450f22818a002e29648d" dependencies = [ "proc-macro2", "pyo3-macros-backend", @@ -3922,9 +3922,9 @@ dependencies = [ [[package]] name = "pyo3-macros-backend" -version = "0.18.2" +version = "0.18.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e0e1128f85ce3fca66e435e08aa2089a2689c1c48ce97803e13f63124058462" +checksum = "97daff08a4c48320587b5224cc98d609e3c27b6d437315bd40b605c98eeb5918" dependencies = [ "proc-macro2", "quote", diff --git a/Cargo.toml b/Cargo.toml index bba64c1b..9b037595 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,7 @@ serde_traitobject = "0.2.8" bincode = "*" # python binding related deps -pyo3 = { version = "0.18.2", features = ["extension-module", "abi3-py37"], optional = true } +pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37"], optional = true } pyo3-log = { version = "0.8.1", optional = true } # Omit for the time being as ndarrays are not being used # numpy = { version = "0.18.0", optional = true } diff --git a/requirements.txt b/requirements.txt index 78d5bc69..523b6cfa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ attrs==22.2.0 exceptiongroup==1.1.1 importlib-metadata==6.1.0 iniconfig==2.0.0 -maturin==0.14.16 +maturin==0.14.17 packaging==23.0 pluggy==1.0.0 pytest==7.2.2 diff --git a/src/circuit/mod.rs b/src/circuit/mod.rs index 77e4cd4b..5970d8cf 100644 --- a/src/circuit/mod.rs +++ b/src/circuit/mod.rs @@ -20,6 +20,13 @@ use halo2_proofs::{ poly::Rotation, }; use log::warn; +#[cfg(feature = "python-bindings")] +use pyo3::{ + exceptions::PyValueError, + prelude::*, + types::PyString, + conversion::{FromPyObject, PyTryFrom} +}; use serde::{Deserialize, Serialize}; use crate::{ @@ -70,6 +77,32 @@ impl From for CheckMode { } } +#[cfg(feature = "python-bindings")] +/// Converts CheckMode into a PyObject (Required for CheckMode to be compatible with Python) +impl IntoPy for CheckMode { + fn into_py(self, py: Python) -> PyObject { + match self { + CheckMode::SAFE => "safe".to_object(py), + CheckMode::UNSAFE => "unsafe".to_object(py), + } + } +} + +#[cfg(feature = "python-bindings")] +/// Obtains CheckMode from PyObject (Required for CheckMode to be compatible with Python) +impl<'source> FromPyObject<'source> for CheckMode { + fn extract(ob: &'source PyAny) -> PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "safe" => Ok(CheckMode::SAFE), + "unsafe" => Ok(CheckMode::UNSAFE), + _ => Err(PyValueError::new_err("Invalid value for CheckMode")) + } + + } +} + /// Configuration for an accumulated arg. #[derive(Clone, Debug, Default)] pub struct BaseConfig { diff --git a/src/commands.rs b/src/commands.rs index 5051f00e..16e8f07c 100644 --- a/src/commands.rs +++ b/src/commands.rs @@ -9,8 +9,16 @@ use std::error::Error; use std::fs::File; use std::io::{stdin, stdout, Read, Write}; use std::path::PathBuf; +#[cfg(feature = "python-bindings")] +use pyo3::{ + exceptions::PyValueError, + prelude::*, + types::PyString, + conversion::{FromPyObject, PyTryFrom} +}; use crate::circuit::CheckMode; +use crate::graph::{VarVisibility, Visibility}; #[allow(missing_docs)] #[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] @@ -27,6 +35,31 @@ impl std::fmt::Display for TranscriptType { .fmt(f) } } +#[cfg(feature = "python-bindings")] +/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python) +impl IntoPy for TranscriptType { + fn into_py(self, py: Python) -> PyObject { + match self { + TranscriptType::Blake => "blake".to_object(py), + TranscriptType::Poseidon => "poseidon".to_object(py), + TranscriptType::EVM => "evm".to_object(py), + } + } +} +#[cfg(feature = "python-bindings")] +/// Obtains TranscriptType from PyObject (Required for TranscriptType to be compatible with Python) +impl<'source> FromPyObject<'source> for TranscriptType { + fn extract(ob: &'source PyAny) -> PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "blake" => Ok(TranscriptType::Blake), + "poseidon" => Ok(TranscriptType::Poseidon), + "evm" => Ok(TranscriptType::EVM), + _ => Err(PyValueError::new_err("Invalid value for TranscriptType")) + } + } +} #[allow(missing_docs)] #[derive(ValueEnum, Copy, Clone, Debug, PartialEq, Eq, Deserialize, Serialize)] @@ -42,6 +75,29 @@ impl std::fmt::Display for StrategyType { .fmt(f) } } +#[cfg(feature = "python-bindings")] +/// Converts StrategyType into a PyObject (Required for StrategyType to be compatible with Python) +impl IntoPy for StrategyType { + fn into_py(self, py: Python) -> PyObject { + match self { + StrategyType::Single => "single".to_object(py), + StrategyType::Accum => "accum".to_object(py), + } + } +} +#[cfg(feature = "python-bindings")] +/// Obtains StrategyType from PyObject (Required for StrategyType to be compatible with Python) +impl<'source> FromPyObject<'source> for StrategyType { + fn extract(ob: &'source PyAny) -> PyResult { + let trystr = ::try_from(ob)?; + let strval = trystr.to_string(); + match strval.to_lowercase().as_str() { + "single" => Ok(StrategyType::Single), + "accum" => Ok(StrategyType::Accum), + _ => Err(PyValueError::new_err("Invalid value for StrategyType")) + } + } +} /// Parameters specific to a proving run #[derive(Debug, Args, Deserialize, Serialize, Clone, Default)] @@ -79,6 +135,17 @@ pub struct RunArgs { pub check_mode: CheckMode, } +#[allow(missing_docs)] +impl RunArgs { + pub fn to_var_visibility(&self) -> VarVisibility { + VarVisibility { + input: if self.public_inputs { Visibility::Public } else { Visibility::Private }, + params: if self.public_params { Visibility::Public } else { Visibility::Private }, + output: if self.public_outputs { Visibility::Public } else { Visibility::Private }, + } + } +} + const EZKLCONF: &str = "EZKLCONF"; const RUNARGS: &str = "RUNARGS"; diff --git a/src/execute.rs b/src/execute.rs index 7985ddac..82ed9653 100644 --- a/src/execute.rs +++ b/src/execute.rs @@ -639,7 +639,8 @@ fn verify_aggr( Ok(()) } -fn load_params_cmd(params_path: PathBuf, logrows: u32) -> Result, Box> { +/// helper function for load_params +pub fn load_params_cmd(params_path: PathBuf, logrows: u32) -> Result, Box> { let mut params: ParamsKZG = load_params::>(params_path)?; info!("downsizing params to {} logrows", logrows); if logrows < params.k() { diff --git a/src/python.rs b/src/python.rs index e55c55bd..f786834d 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1,16 +1,21 @@ use crate::circuit::CheckMode; -use crate::commands::RunArgs; -use crate::graph::{quantize_float, Mode, Model, VarVisibility, Visibility}; -use crate::pfsys::{gen_srs as ezkl_gen_srs, prepare_data, save_params}; -use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme; +use crate::commands::{RunArgs, StrategyType, TranscriptType}; +use crate::execute::{create_proof_circuit_kzg, load_params_cmd}; +use crate::graph::{quantize_float, Mode, Model, ModelCircuit, VarVisibility}; +use crate::pfsys::{gen_srs as ezkl_gen_srs, create_keys, prepare_data, save_params, save_vk}; +use halo2_proofs::poly::kzg::{ + commitment::KZGCommitmentScheme, + strategy::{AccumulatorStrategy, SingleStrategy as KZGSingleStrategy}, +}; +use halo2_proofs::dev::MockProver; use halo2curves::bn256::{Bn256, Fr}; use log::trace; -use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError}; +use pyo3::exceptions::{PyIOError, PyRuntimeError}; use pyo3::prelude::*; use pyo3::wrap_pyfunction; use pyo3_log; -use std::fs::File; -use std::path::PathBuf; +use rayon::prelude::{IntoParallelRefIterator, ParallelIterator}; +use std::{fs::File, path::PathBuf, sync::Arc}; use tabled::Table; // See commands.rs and execute.rs @@ -30,31 +35,82 @@ use tabled::Table; // .render(args.logrows, &circuit, &root)?; // } -// Table +/// Environment variable for EZKLCONF +// const EZKLCONF: &str = "EZKLCONF"; -#[pyfunction] -fn table(model: String) -> Result { - // use default values to initialize model - let run_args = RunArgs { - tolerance: 0, - scale: 7, - bits: 16, - logrows: 17, - public_inputs: true, - public_outputs: true, - public_params: false, - pack_base: 1, - check_mode: CheckMode::SAFE, - allocated_constraints: None, - }; +/// pyclass containing the struct used for run_args +#[pyclass] +#[derive(Clone)] +struct PyRunArgs { + #[pyo3(get, set)] + pub tolerance: usize, + #[pyo3(get, set)] + pub scale: u32, + #[pyo3(get, set)] + pub bits: usize, + #[pyo3(get, set)] + pub logrows: u32, + #[pyo3(get, set)] + pub public_inputs: bool, + #[pyo3(get, set)] + pub public_outputs: bool, + #[pyo3(get, set)] + pub public_params: bool, + #[pyo3(get, set)] + pub pack_base: u32, + #[pyo3(get, set)] + pub allocated_constraints: Option, + #[pyo3(get, set)] + pub check_mode: CheckMode, +} - // use default values to initialize model - let visibility = VarVisibility { - input: Visibility::Public, - params: Visibility::Private, - output: Visibility::Public, - }; +/// default instantiation of PyRunArgs +#[pymethods] +impl PyRunArgs { + #[new] + fn new() -> Self { + PyRunArgs { + tolerance: 0, + scale: 7, + bits: 16, + logrows: 17, + public_inputs: true, + public_outputs: true, + public_params: false, + pack_base: 1, + allocated_constraints: None, + check_mode: CheckMode::SAFE, + } + } +} + +/// Conversion between PyRunArgs and RunArgs +impl From for RunArgs { + fn from(py_run_args: PyRunArgs) -> Self { + RunArgs { + tolerance: py_run_args.tolerance, + scale: py_run_args.scale, + bits: py_run_args.bits, + logrows: py_run_args.logrows, + public_inputs: py_run_args.public_inputs, + public_outputs: py_run_args.public_outputs, + public_params: py_run_args.public_params, + pack_base: py_run_args.pack_base, + allocated_constraints: py_run_args.allocated_constraints, + check_mode: py_run_args.check_mode, + } + } +} + +/// Displays the table as a string in python +#[pyfunction(signature = ( + model, + py_run_args = None +))] +fn table(model: String, py_run_args: Option) -> Result { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let visibility: VarVisibility = run_args.to_var_visibility(); let result = Model::::new(model, run_args, Mode::Mock, visibility); match result { @@ -63,114 +119,196 @@ fn table(model: String) -> Result { } } -#[pyfunction] -fn gen_srs(params_path: PathBuf, logrows: u32) -> PyResult<()> { - let run_args = RunArgs { - tolerance: 0, - scale: 7, - bits: 16, - logrows: logrows, - public_inputs: true, - public_outputs: true, - public_params: false, - pack_base: 1, - check_mode: CheckMode::SAFE, - allocated_constraints: None, - }; +/// generates the srs +#[pyfunction(signature = ( + params_path, + py_run_args = None +))] +fn gen_srs(params_path: PathBuf, py_run_args: Option) -> PyResult<()> { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); let params = ezkl_gen_srs::>(run_args.logrows); save_params::>(¶ms_path, ¶ms)?; Ok(()) } +/// runs the forward pass operation #[pyfunction(signature = ( data, model, output, - tolerance=0, - scale=7, - bits=16, - logrows=17, - public_inputs=true, - public_outputs=true, - public_params=false, - pack_base=1, - check_mode="safe" + py_run_args = None ))] fn forward( data: String, model: String, output: String, - tolerance: usize, - scale: u32, - bits: usize, - logrows: u32, - public_inputs: bool, - public_outputs: bool, - public_params: bool, - pack_base: u32, - check_mode: &str, + py_run_args: Option ) -> PyResult<()> { - let data = prepare_data(data); + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let mut data = prepare_data(data).map_err(|_| PyIOError::new_err("Failed to import data"))?; - match data { - Ok(m) => { - let run_args = RunArgs { - tolerance: tolerance, - scale: scale, - bits: bits, - logrows: logrows, - public_inputs: public_inputs, - public_outputs: public_outputs, - public_params: public_params, - pack_base: pack_base, - check_mode: CheckMode::from(check_mode.to_string()), - allocated_constraints: None, - }; - let mut new_data = m; - let mut model_inputs = vec![]; - // quantize the supplied data using the provided scale. - for v in new_data.input_data.iter() { - let t: Result, _> = v - .iter() - .map(|x| quantize_float(x, 0.0, run_args.scale)) - .collect(); - match t { - Ok(t) => model_inputs.push(t.into_iter().into()), - Err(_) => return Err(PyValueError::new_err("Failed to quantize vector")), - } - } - let res = Model::::forward(model, &model_inputs, run_args); + let mut model_inputs = vec![]; + // quantize the supplied data using the provided scale. + // for v in new_data.input_data.iter() { + // match vector_to_quantized(v, &Vec::from([v.len()]), 0.0, run_args.scale) { + // Ok(t) => model_inputs.push(t), + // Err(_) => return Err(PyValueError::new_err("Failed to quantize vector")), + // } + // } + for v in data.input_data.iter() { + let t: Vec = v + .par_iter() + .map(|x| quantize_float(x, 0.0, run_args.scale).unwrap()) + .collect(); + model_inputs.push(t.into_iter().into()); + } + let res = Model::::forward(model, &model_inputs, run_args) + .map_err(|_| PyRuntimeError::new_err("Failed to compute forward pass"))?; - match res { - Ok(r) => { - let float_res: Vec> = r.iter().map(|t| t.to_vec()).collect(); - trace!("forward pass output: {:?}", float_res); - new_data.output_data = float_res; + let float_res: Vec> = res.iter().map(|t| t.to_vec()).collect(); + trace!("forward pass output: {:?}", float_res); + data.output_data = float_res; - match serde_json::to_writer(&File::create(output)?, &new_data) { - Ok(_) => { - // TODO output a dictionary - // obtain gil - // TODO: Convert to Python::with_gil() when it stabilizes - // let gil = Python::acquire_gil(); - // obtain python instance - // let py = gil.python(); - // return Ok(new_data.to_object(py)) - Ok(()) - } - Err(_) => return Err(PyIOError::new_err("Failed to create output file")), - } - } - Err(_) => Err(PyRuntimeError::new_err("Failed to compute forward pass")), - } + match serde_json::to_writer(&File::create(output)?, &data) { + Ok(_) => { + // TODO output a dictionary + // obtain gil + // TODO: Convert to Python::with_gil() when it stabilizes + // let gil = Python::acquire_gil(); + // obtain python instance + // let py = gil.python(); + // return Ok(new_data.to_object(py)) + Ok(()) } - Err(_) => Err(PyIOError::new_err("Failed to import files")), + Err(_) => return Err(PyIOError::new_err("Failed to create output file")), + } +} + +/// mocks the prover +#[pyfunction(signature = ( + data, + model, + py_run_args = None +))] +fn mock( + data: String, + model: String, + py_run_args: Option +) -> Result { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let logrows = run_args.logrows; + let data = prepare_data(data).map_err(|_| PyIOError::new_err("Failed to import data"))?; + let visibility = run_args.to_var_visibility(); + + let procmodel = Model::::new(model, run_args, Mode::Mock, visibility) + .map_err(|_| PyIOError::new_err("Failed to process model"))?; + + let arcmodel: Arc> = Arc::new(procmodel); + let circuit = ModelCircuit::::new(&data, arcmodel) + .map_err(|_| PyRuntimeError::new_err("Failed to create circuit"))?; + + let public_inputs = circuit.prepare_public_inputs(&data) + .map_err(|_| PyRuntimeError::new_err("Failed to prepare public inputs"))?; + let prover = MockProver::run(logrows, &circuit, public_inputs) + .map_err(|_| PyRuntimeError::new_err("Failed to run prover"))?; + + prover.assert_satisfied(); + + let res = prover.verify(); + match res { + Ok(_) => return Ok(true), + Err(_) => return Ok(false), + } +} + +/// runs the prover on a set of inputs +#[pyfunction(signature = ( + data, + model, + vk_path, + proof_path, + params_path, + transcript, + strategy, + py_run_args = None +))] +fn prove( + data: String, + model: String, + vk_path: PathBuf, + proof_path: PathBuf, + params_path: PathBuf, + transcript: TranscriptType, + strategy: StrategyType, + py_run_args: Option +) -> Result { + let run_args: RunArgs = py_run_args.unwrap_or_else(PyRunArgs::new).into(); + let logrows = run_args.logrows; + let check_mode = run_args.check_mode; + let data = prepare_data(data).map_err(|_| PyIOError::new_err("Failed to import data"))?; + let visibility = run_args.to_var_visibility(); + + let procmodel = Model::::new(model, run_args, Mode::Prove, visibility) + .map_err(|_| PyIOError::new_err("Failed to process model"))?; + + let arcmodel: Arc> = Arc::new(procmodel); + let circuit = ModelCircuit::::new(&data, arcmodel) + .map_err(|_| PyRuntimeError::new_err("Failed to create circuit"))?; + + let public_inputs = circuit.prepare_public_inputs(&data) + .map_err(|_| PyRuntimeError::new_err("Failed to prepare public inputs"))?; + + let params = load_params_cmd(params_path, logrows) + .map_err(|_| PyIOError::new_err("Failed to load params"))?; + + let proving_key = create_keys::, Fr, ModelCircuit>(&circuit, ¶ms) + .map_err(|_| PyRuntimeError::new_err("Failed to create proving key"))?; + + let snark = match strategy { + StrategyType::Single => { + let strategy = KZGSingleStrategy::new(¶ms); + match create_proof_circuit_kzg( + circuit, + ¶ms, + public_inputs, + &proving_key, + transcript, + strategy, + check_mode + ) { + Ok(snark) => Ok(snark), + Err(_) => Err(PyRuntimeError::new_err("Failed to create proof circuit single strategy")), + } + } + StrategyType::Accum => { + let strategy = AccumulatorStrategy::new(¶ms); + match create_proof_circuit_kzg( + circuit, + ¶ms, + public_inputs, + &proving_key, + transcript, + strategy, + check_mode + ) { + Ok(snark) => Ok(snark), + Err(_) => Err(PyRuntimeError::new_err("Failed to create proof circuit using accumulator strategy")), + } + } + }; + + match snark?.save(&proof_path) { + Ok(_) => { + match save_vk::>(&vk_path, proving_key.get_vk()) { + Ok(_) => Ok(true), + Err(_) => Err(PyIOError::new_err("Failed to save vk to vk_path")) + } + } + Err(_) => Err(PyIOError::new_err("Failed to save to proof path")) } } -// TODO: Mock // TODO: Aggregate -// TODO: Prove // TODO: CreateEVMVerifier // TODO: CreateEVMVerifierAggr // TODO: DeployVerifierEVM @@ -184,8 +322,12 @@ fn forward( #[pymodule] fn ezkl_lib(_py: Python<'_>, m: &PyModule) -> PyResult<()> { pyo3_log::init(); + m.add_class::()?; m.add_function(wrap_pyfunction!(table, m)?)?; m.add_function(wrap_pyfunction!(gen_srs, m)?)?; m.add_function(wrap_pyfunction!(forward, m)?)?; + m.add_function(wrap_pyfunction!(mock, m)?)?; + m.add_function(wrap_pyfunction!(prove, m)?)?; + Ok(()) } diff --git a/tests/python/binding_tests.py b/tests/python/binding_tests.py index d676acf9..351add05 100644 --- a/tests/python/binding_tests.py +++ b/tests/python/binding_tests.py @@ -49,7 +49,7 @@ def test_gen_srs(): Test for gen_srs() with 17 logrows. You may want to comment this test as it takes a long time to run """ - ezkl_lib.gen_srs(params_path, 17) + ezkl_lib.gen_srs(params_path) assert os.path.isfile(params_path) @@ -80,7 +80,64 @@ def test_forward(): with open(output_path, "r") as f: data = json.load(f) - assert data == {"input_data": [[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]], "input_shapes": [ - [1, 5, 5]], "output_data": [[0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625, 0.9140625]]} + assert data == {"input_data":[[0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1]],"input_shapes":[[1,5,5]],"output_data":[[0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625,0.9140625]]} os.remove(output_path) + +def test_mock(): + """ + Test for mock + """ + + data_path = os.path.join( + examples_path, + 'onnx', + '1l_average', + 'input.json' + ) + + model_path = os.path.join( + examples_path, + 'onnx', + '1l_average', + 'network.onnx' + ) + + res = ezkl_lib.mock(data_path, model_path) + assert res == True + + +def test_prove(): + """ + Test for prove + """ + + data_path = os.path.join( + examples_path, + 'onnx', + '1l_average', + 'input.json' + ) + + model_path = os.path.join( + examples_path, + 'onnx', + '1l_average', + 'network.onnx' + ) + + vk_path = os.path.join(folder_path, 'test.vk') + proof_path = os.path.join(folder_path, 'test.pf') + + res = ezkl_lib.prove( + data_path, + model_path, + vk_path, + proof_path, + params_path, + "poseidon", + "single", + ) + assert res == True + assert os.path.isfile(vk_path) + assert os.path.isfile(proof_path)