Compare commits


1 Commit

Author: github-actions[bot]
SHA1: 3640c9aa6d
Message: ci: update version string in docs
Date: 2024-06-10 16:14:16 +00:00
33 changed files with 638 additions and 1007 deletions

View File

@@ -236,8 +236,6 @@ jobs:
with:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10
@@ -347,8 +345,6 @@ jobs:
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain all kzg)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
- name: KZG prove and verify tests (EVM)

Cargo.lock (generated, 501 changed lines)

File diff suppressed because it is too large.

View File

@@ -39,6 +39,7 @@ snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch =
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
@@ -62,13 +63,16 @@ reqwest = { version = "0.12.4", default-features = false, features = [
openssl = { version = "0.10.55", features = ["vendored"] }
tokio-postgres = "0.7.10"
pg_bigdecimal = "0.1.5"
futures-util = "0.3.30"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.35", default_features = false, features = [
"macros",
"rt-multi-thread"
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.21.2", features = [
"extension-module",
"abi3-py37",
@@ -79,8 +83,9 @@ pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="m
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
@@ -103,10 +108,8 @@ console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
[dev-dependencies]
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -177,7 +180,7 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup", "no-banner", "parallel-poly-read"]
default = ["ezkl", "mv-lookup", "no-banner"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = [
@@ -191,7 +194,6 @@ ezkl = [
"colored_json",
"halo2_proofs/circuit-params",
]
parallel-poly-read = ["halo2_proofs/parallel-poly-read"]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
@@ -201,7 +203,6 @@ det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
no-banner = []
no-update = []
metal = ["dep:metal", "dep:objc"]
# icicle patch to 0.1.0 if feature icicle is enabled
@@ -209,7 +210,7 @@ metal = ["dep:metal", "dep:objc"]
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#8cfca221f53069a0374687654882b99e729041d7", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
[profile.release]

View File

@@ -93,6 +93,9 @@ contract LoadInstances {
}
}
// Contract that checks that the COMMITMENT_KZG bytes is equal to the first part of the proof.
pragma solidity ^0.8.0;
// The kzg commitments of a given model, all aggregated into a single bytes array.
// At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
// It will be used to check that the proof commitments match the expected commitments.
@@ -160,7 +163,7 @@ contract SwapProofCommitments {
}
return equal; // Return true if the commitment comparison passed
} /// end checkKzgCommits
}
}
// This contract serves as a Data Attestation Verifier for the EZKL model.
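The hunk above describes the check in comments only. As a minimal illustration (written in Rust rather than Solidity, and not the contract's actual code), the swap-commitment test is a prefix comparison between the hardcoded COMMITMENT_KZG bytes and the submitted proof:

```rust
/// Illustrative sketch only: the COMMITMENT_KZG bytes must equal the first
/// part of the proof, i.e. be a prefix of the proof bytes.
fn check_kzg_commits(commitment_kzg: &[u8], proof: &[u8]) -> bool {
    proof.len() >= commitment_kzg.len() && proof[..commitment_kzg.len()] == *commitment_kzg
}
```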

View File

@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl==11.3.3
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '11.3.3'
version = release

View File

@@ -39,7 +39,7 @@
"import json\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"from hummingbird.ml import convert\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
@@ -59,11 +59,11 @@
"# Train an SVM on the data and wrap it in PyTorch.\n",
"sk_model = SVC(probability=True)\n",
"sk_model.fit(xs, ys)\n",
"model = convert(sk_model, \"torch\").model\n",
"model = sk2torch.wrap(sk_model)\n",
"\n",
"\n",
"\n",
"\n",
"model\n",
"\n"
]
},
@@ -84,6 +84,33 @@
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0ca328",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"# Create a coordinate grid to compute a vector field on.\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"\n",
"# Compute the gradients of the SVM output.\n",
"outputs = model.predict_proba(grid_xs)[:, 1]\n",
"(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n",
"\n",
"\n",
"# Create a quiver plot of the vector field.\n",
"plt.quiver(\n",
" grid_xs[:, 0].detach().numpy(),\n",
" grid_xs[:, 1].detach().numpy(),\n",
" input_grads[:, 0].detach().numpy(),\n",
" input_grads[:, 1].detach().numpy(),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -92,14 +119,14 @@
"outputs": [],
"source": [
"\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = xs.shape[1:]\n",
"x = grid_xs[0:1]\n",
"torch_out = model.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
@@ -116,7 +143,9 @@
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_data=[d])\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -138,7 +167,6 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0bee4d7f",
"metadata": {},
"outputs": [],
"source": [
@@ -192,7 +220,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 10,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
@@ -413,9 +441,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.9.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -1 +0,0 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"lookup_range":[0,0],"logrows":13,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":5619,"total_const_size":513,"model_instance_shapes":[[1,3,10,10]],"model_output_scales":[14],"model_input_scales":[7],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -1 +0,0 @@
network.onnx filter=lfs diff=lfs merge=lfs -text

View File

@@ -1,47 +0,0 @@
## The worm
This is an onnx file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, which is a VAE / latent-space representation of the C. elegans connectome.
The model "is a large-scale latent variable model with a very high-dimensional latent space
consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency
of 160 Hz. The generative model for these latent variables is described by stochastic differential
equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).
In effect this is a generative model for a worm's voltage dynamics, which can be used to generate new worm-like voltage dynamics given previous connectome state.
Using ezkl you can create a zk circuit equivalent to the WormVAE model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm state that can be verified on chain.
To do so you'll first want to fetch the files using git-lfs (as the onnx file is too large to be stored in git).
```bash
git lfs fetch --all
```
You'll then want to use the usual ezkl loop to generate the zk circuit. We recommend using fixed visibility for the model parameters, as the model is quite large and this will prune the circuit significantly.
```bash
ezkl gen-settings --param-visibility=fixed
cp input.json calibration.json
ezkl calibrate-settings
ezkl compile-circuit
ezkl gen-witness
ezkl prove
```
You might also need to aggregate the proof to get it to fit on chain.
```bash
ezkl aggregate
```
You can then create a smart contract that verifies this aggregate proof
```bash
ezkl create-evm-verifier-aggr
```
This can then be deployed on the chain of your choice.
> Note: the model is large and thus we recommend a machine with at least 512GB of RAM to run the above commands. If you're ever compute-constrained you can always use the lilith service to generate the zk circuit. Message us on Discord or Telegram for more details :)

File diff suppressed because one or more lines are too long

View File

@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f88c5901d3768ec21e3cf2f2840d255e84fa13c364df86b24d960cca3333769
size 82095882

View File

@@ -1 +0,0 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":6,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Fixed"},"num_constraints":367422820,"total_const_size":365577160,"model_instance_shapes":[[1,300,1200]],"model_output_scales":[6],"model_input_scales":[0,0,0],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":64.0}},"ReLU",{"Ln":{"scale":64.0}},{"Exp":{"scale":64.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

View File

@@ -250,10 +250,6 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();
@@ -261,17 +257,12 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let mut values = values.clone();
// this section has been optimized to death, don't mess with it
let mut removal_indices = values[0].get_const_zero_indices();
let second_zero_indices = values[1].get_const_zero_indices();
let mut removal_indices = values[0].get_const_zero_indices()?;
let second_zero_indices = values[1].get_const_zero_indices()?;
removal_indices.extend(second_zero_indices);
removal_indices.par_sort_unstable();
removal_indices.dedup();
// if empty return a const
if removal_indices.len() == values[0].len() {
return Ok(create_zero_tensor(1));
}
// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
values[1].remove_indices(&mut removal_indices, true)?;
@@ -279,6 +270,15 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
// if empty return a const
if values[0].is_empty() && values[1].is_empty() {
return Ok(create_zero_tensor(1));
}
let start = instant::Instant::now();
let mut inputs = vec![];
let block_width = config.custom_gates.output.num_inner_cols();
@@ -343,7 +343,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.last()?;
let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
region.increment(assigned_len);
@@ -1779,7 +1779,12 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let mut values = values.clone();
// this section has been optimized to death, don't mess with it
values[0].remove_const_zero_values();
let mut removal_indices = values[0].get_const_zero_indices()?;
removal_indices.par_sort_unstable();
removal_indices.dedup();
// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
@@ -1836,7 +1841,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
}
}
let last_elem = output.last()?;
let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
region.increment(assigned_len);
@@ -1879,7 +1884,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>
let global_start = instant::Instant::now();
// this section has been optimized to death, don't mess with it
let removal_indices = values[0].get_const_zero_indices();
let removal_indices = values[0].get_const_zero_indices()?;
let elapsed = global_start.elapsed();
trace!("finding const zero indices took: {:?}", elapsed);
@@ -1940,7 +1945,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.last()?;
let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
region.increment(assigned_len);
@@ -2251,22 +2256,22 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let orig_lhs = lhs.clone();
let orig_rhs = rhs.clone();
let start = instant::Instant::now();
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());
let removal_indices = match op {
// get indices of zeros
let first_zero_indices = lhs.get_const_zero_indices()?;
let second_zero_indices = rhs.get_const_zero_indices()?;
let mut removal_indices = match op {
BaseOp::Add | BaseOp::Mult => {
// join the zero indices
first_zero_indices
.union(&second_zero_indices)
.cloned()
.collect()
let mut removal_indices = first_zero_indices.clone();
removal_indices.extend(second_zero_indices.clone());
removal_indices
}
BaseOp::Sub => second_zero_indices.clone(),
_ => return Err(CircuitError::UnsupportedOp),
};
trace!("setting up indices took {:?}", start.elapsed());
removal_indices.dedup();
let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
@@ -2275,19 +2280,20 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
)));
}
let inputs = [lhs.clone(), rhs.clone()]
.iter()
.enumerate()
.map(|(i, input)| {
let mut inputs = vec![];
for (i, input) in [lhs.clone(), rhs.clone()].iter().enumerate() {
let inp = {
let res = region.assign_with_omissions(
&config.custom_gates.inputs[i],
input,
&removal_indices,
removal_indices_ptr,
)?;
Ok(res.get_inner()?)
})
.collect::<Result<Vec<_>, CircuitError>>()?;
res.get_inner()?
};
inputs.push(inp);
}
// Now we can assign the dot product
// time the calc
@@ -2302,20 +2308,15 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
trace!("pairwise {} calc took {:?}", op.as_str(), start.elapsed());
let elapsed = start.elapsed();
let start = instant::Instant::now();
let assigned_len = op_result.len() - removal_indices.len();
let assigned_len = inputs[0].len() - removal_indices.len();
let mut output = region.assign_with_omissions(
&config.custom_gates.output,
&op_result.into(),
&removal_indices,
removal_indices_ptr,
)?;
trace!(
"pairwise {} input assign took {:?}",
op.as_str(),
start.elapsed()
);
trace!("pairwise {} calc took {:?}", op.as_str(), elapsed);
// Enable the selectors
if !region.is_dummy() {
@@ -2336,11 +2337,16 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let a_tensor = orig_lhs.get_inner_tensor()?;
let b_tensor = orig_rhs.get_inner_tensor()?;
let first_zero_indices: HashSet<&usize> = HashSet::from_iter(first_zero_indices.iter());
let second_zero_indices: HashSet<&usize> = HashSet::from_iter(second_zero_indices.iter());
trace!("setting up indices took {:?}", start.elapsed());
// infill the zero indices with the correct values from values[0] or values[1]
if !removal_indices.is_empty() {
if !removal_indices_ptr.is_empty() {
output
.get_inner_tensor_mut()?
.par_enum_map_mut_filtered(&removal_indices, |i| {
.par_enum_map_mut_filtered(removal_indices_ptr, |i| {
let val = match op {
BaseOp::Add => {
let a_is_null = first_zero_indices.contains(&i);
@@ -2380,7 +2386,6 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
end,
region.row()
);
trace!("----------------------------");
Ok(output)
}
@@ -3772,7 +3777,7 @@ pub(crate) fn boolean_identity<
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, CircuitError> {
let output = if assign || !values[0].get_const_indices().is_empty() {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
// get zero constants indices
let output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
@@ -3937,10 +3942,11 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
let x = values[0].clone();
let removal_indices = values[0].get_const_indices();
let removal_indices: HashSet<usize> = HashSet::from_iter(removal_indices);
let removal_indices = values[0].get_const_indices()?;
let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
let w = region.assign_with_omissions(&config.static_lookups.input, &x, &removal_indices)?;
let w = region.assign_with_omissions(&config.static_lookups.input, &x, removal_indices_ptr)?;
let output = w.get_inner_tensor()?.par_enum_map(|i, e| {
Ok::<_, TensorError>(if let Some(f) = e.get_felt_eval() {
@@ -3958,7 +3964,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
let mut output = region.assign_with_omissions(
&config.static_lookups.output,
&output.into(),
&removal_indices,
removal_indices_ptr,
)?;
let is_dummy = region.is_dummy();
@@ -3988,7 +3994,11 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
})?
.into();
region.assign_with_omissions(&config.static_lookups.index, &table_index, &removal_indices)?;
region.assign_with_omissions(
&config.static_lookups.index,
&table_index,
removal_indices_ptr,
)?;
if !is_dummy {
(0..assigned_len)
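Taken together, the dot hunks above reorder the pipeline: constant-zero indices are collected from both operands (via the now-fallible get_const_zero_indices), deduplicated, and pruned, and only then are the length check and the all-zero early return applied. A self-contained sketch of that ordering, using plain slices instead of ValTensor (illustrative only):

```rust
/// Illustrative sketch of the reordered zero-pruning in `dot`.
fn pruned_dot(a: &[i64], b: &[i64]) -> Result<i64, String> {
    // collect indices that are const-zero in either operand, then dedup
    let mut removal: Vec<usize> = a
        .iter()
        .enumerate()
        .filter(|(_, v)| **v == 0)
        .map(|(i, _)| i)
        .chain(b.iter().enumerate().filter(|(_, v)| **v == 0).map(|(i, _)| i))
        .collect();
    removal.sort_unstable();
    removal.dedup();
    // prune both operands at the shared zero indices
    let keep = |s: &[i64]| -> Vec<i64> {
        s.iter()
            .enumerate()
            .filter(|(i, _)| removal.binary_search(i).is_err())
            .map(|(_, v)| *v)
            .collect()
    };
    let (a, b) = (keep(a), keep(b));
    // the dimension check now runs *after* pruning, as in the hunk above
    if a.len() != b.len() {
        return Err("dot: dim mismatch".to_string());
    }
    // all-zero inputs collapse to a zero constant
    if a.is_empty() && b.is_empty() {
        return Ok(0);
    }
    Ok(a.iter().zip(b.iter()).map(|(x, y)| x * y).sum())
}
```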

View File

@@ -9,8 +9,6 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use portable_atomic::AtomicI64 as AtomicInt;
use std::{
cell::RefCell,
@@ -517,18 +515,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, CircuitError> {
) -> Result<ValTensor<F>, Error> {
if let Some(region) = &self.region {
Ok(var.assign(
var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)?)
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.par_extend(values_map);
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
@@ -544,18 +542,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, CircuitError> {
) -> Result<ValTensor<F>, Error> {
if let Some(region) = &self.region {
Ok(var.assign(
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)?)
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.par_extend(values_map);
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
@@ -566,7 +564,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, CircuitError> {
) -> Result<ValTensor<F>, Error> {
self.assign_dynamic_lookup(var, values)
}
@@ -575,24 +573,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
ommissions: &HashSet<&usize>,
) -> Result<ValTensor<F>, Error> {
if let Some(region) = &self.region {
Ok(var.assign_with_omissions(
var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)?)
)
} else {
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
let values_map = values.create_constants_map();
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
self.assigned_constants.par_extend(values_map);
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
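The omission set changes type here from HashSet<usize> to HashSet<&usize>, so borrowed indices can be threaded through assignment without cloning. What an omission-filtered assignment does, stripped of the VarTensor machinery (illustrative only):

```rust
use std::collections::HashSet;

/// Illustrative sketch only: keep every value except those whose flat index
/// is in the omission set, mirroring the `HashSet<&usize>` signature above.
fn assign_with_omissions<T: Copy>(values: &[T], omissions: &HashSet<&usize>) -> Vec<T> {
    values
        .iter()
        .enumerate()
        .filter(|(i, _)| !omissions.contains(i))
        .map(|(_, v)| *v)
        .collect()
}
```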

View File

@@ -379,9 +379,9 @@ pub enum Commands {
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET, value_hint = clap::ValueHint::Other)]
/// Target for calibration. Set to "resources" to optimize for computational resource. Otherwise, set to "accuracy" to optimize for accuracy.
target: CalibrationTarget,
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be ceil(2^k * lookup_safety_margin). larger = safer but slower
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN, value_hint = clap::ValueHint::Other)]
lookup_safety_margin: f64,
lookup_safety_margin: i64,
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::Other)]
scales: Option<Vec<crate::Scale>>,
@@ -868,7 +868,6 @@ pub enum Commands {
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_vk: Option<H160Flag>,
},
#[cfg(not(feature = "no-update"))]
/// Updates ezkl binary to version specified (or latest if not specified)
Update {
/// The version to update to

View File

@@ -327,7 +327,11 @@ pub async fn setup_eth_backend(
ProviderBuilder::new()
.with_recommended_fillers()
.signer(EthereumSigner::from(wallet))
.on_http(endpoint.parse().map_err(|_| EthError::UrlParse(endpoint))?),
.on_http(
endpoint
.parse()
.map_err(|_| EthError::UrlParse(endpoint.clone()))?,
),
);
let chain_id = client.get_chain_id().await?;
@@ -350,7 +354,8 @@ pub async fn deploy_contract_via_solidity(
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, contract_name, runs).await?;
let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client, None::<()>)?;
let factory =
get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone(), None::<()>)?;
let contract = factory.deploy().await?;
Ok(contract)
@@ -447,30 +452,20 @@ pub async fn deploy_da_verifier_via_solidity(
}
}
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
parse_calls_to_accounts(calls_to_accounts)?
} else {
// if calls to accounts is empty then we need to check that there is at least kzg visibility in the settings file
let kzg_visibility = settings.run_args.input_visibility.is_polycommit()
|| settings.run_args.output_visibility.is_polycommit()
|| settings.run_args.param_visibility.is_polycommit();
if !kzg_visibility {
return Err(EthError::OnChainDataSource);
}
let factory =
get_sol_contract_factory::<_, ()>(abi, bytecode, runtime_bytecode, client, None)?;
let contract = factory.deploy().await?;
return Ok(contract);
return Err(EthError::OnChainDataSource);
};
let (abi, bytecode, runtime_bytecode) =
get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
let factory = get_sol_contract_factory(
abi,
bytecode,
runtime_bytecode,
client,
client.clone(),
Some((
// address[] memory _contractAddresses,
DynSeqToken(
@@ -511,7 +506,7 @@ pub async fn deploy_da_verifier_via_solidity(
),
// uint8 _instanceOffset,
WordToken(U256::from(contract_instance_offset as u32).into()),
// address _admin
//address _admin
WordToken(client_address.into_word()),
)),
)?;
@@ -534,7 +529,7 @@ fn parse_calls_to_accounts(
let mut call_data = vec![];
let mut decimals: Vec<Vec<U256>> = vec![];
for (i, val) in calls_to_accounts.iter().enumerate() {
let contract_address_bytes = hex::decode(&val.address)?;
let contract_address_bytes = hex::decode(val.address.clone())?;
let contract_address = H160::from_slice(&contract_address_bytes);
contract_addresses.push(contract_address);
call_data.push(vec![]);
@@ -578,7 +573,7 @@ pub async fn update_account_calls(
let (client, client_address) = setup_eth_backend(rpc_url, None).await?;
let contract = DataAttestation::new(addr, &client);
let contract = DataAttestation::new(addr, client.clone());
info!("contract_addresses: {:#?}", contract_addresses);
@@ -809,7 +804,7 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
client: Arc<M>,
data: &[Vec<FileSourceInner>],
) -> Result<Vec<CallsToAccount>, EthError> {
let (contract, decimals) = setup_test_contract(client, data).await?;
let (contract, decimals) = setup_test_contract(client.clone(), data).await?;
// Get the encoded call data for each input
let mut calldata = vec![];
@@ -841,10 +836,10 @@ pub async fn read_on_chain_inputs<M: 'static + Provider<Http<Client>, Ethereum>>
let mut decimals = vec![];
for on_chain_data in data {
// Construct the address
let contract_address_bytes = hex::decode(&on_chain_data.address)?;
let contract_address_bytes = hex::decode(on_chain_data.address.clone())?;
let contract_address = H160::from_slice(&contract_address_bytes);
for (call_data, decimal) in &on_chain_data.call_data {
let call_data_bytes = hex::decode(call_data)?;
let call_data_bytes = hex::decode(call_data.clone())?;
let input: TransactionInput = call_data_bytes.into();
let tx = TransactionRequest::default()
@@ -871,8 +866,8 @@ pub async fn evm_quantize<M: 'static + Provider<Http<Client>, Ethereum>>(
) -> Result<Vec<Fr>, EthError> {
let contract = QuantizeData::deploy(&client).await?;
let fetched_inputs = &data.0;
let decimals = &data.1;
let fetched_inputs = data.0.clone();
let decimals = data.1.clone();
let fetched_inputs = fetched_inputs
.iter()
@@ -948,7 +943,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
(None, false) => {
return Err(EthError::NoConstructor);
}
(None, true) => bytecode,
(None, true) => bytecode.clone(),
(Some(_), _) => {
let mut data = bytecode.to_vec();
@@ -960,7 +955,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
}
};
Ok(CallBuilder::new_raw_deploy(client, data))
Ok(CallBuilder::new_raw_deploy(client.clone(), data))
}
/// Compiles a solidity verifier contract and returns the abi, bytecode, and runtime bytecode
@@ -1035,7 +1030,7 @@ pub fn fix_da_sol(
// fill in the quantization params and total calls
// as constants to the contract to save on gas
if let Some(input_data) = &input_data {
if let Some(input_data) = input_data {
let input_calls: usize = input_data.iter().map(|v| v.call_data.len()).sum();
accounts_len = input_data.len();
contract = contract.replace(
@@ -1043,7 +1038,7 @@ pub fn fix_da_sol(
&format!("uint256 constant INPUT_CALLS = {};", input_calls),
);
}
if let Some(output_data) = &output_data {
if let Some(output_data) = output_data {
let output_calls: usize = output_data.iter().map(|v| v.call_data.len()).sum();
accounts_len += output_data.len();
contract = contract.replace(
@@ -1053,9 +1048,8 @@ pub fn fix_da_sol(
}
contract = contract.replace("AccountCall[]", &format!("AccountCall[{}]", accounts_len));
// The case where a combination of on-chain data source + kzg commit is provided.
if commitment_bytes.is_some() && !commitment_bytes.as_ref().unwrap().is_empty() {
let commitment_bytes = commitment_bytes.as_ref().unwrap();
if commitment_bytes.clone().is_some() && !commitment_bytes.clone().unwrap().is_empty() {
let commitment_bytes = commitment_bytes.unwrap();
let hex_string = hex::encode(commitment_bytes);
contract = contract.replace(
"bytes constant COMMITMENT_KZG = hex\"\";",
@@ -1070,44 +1064,5 @@ pub fn fix_da_sol(
);
}
// if both input and output data are none then we will only deploy the DataAttest contract, adding in the verifyWithDataAttestation function
if input_data.is_none()
&& output_data.is_none()
&& commitment_bytes.as_ref().is_some()
&& !commitment_bytes.as_ref().unwrap().is_empty()
{
contract = contract.replace(
"contract SwapProofCommitments {",
"contract DataAttestation {",
);
// Remove everything past the end of the checkKzgCommits function
if let Some(pos) = contract.find(" } /// end checkKzgCommits") {
contract.truncate(pos);
contract.push('}');
}
// Add the Solidity function below checkKzgCommits
contract.push_str(
r#"
function verifyWithDataAttestation(
address verifier,
bytes calldata encoded
) public view returns (bool) {
require(verifier.code.length > 0, "Address: call to non-contract");
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {
revert("low-level call to verifier failed");
}
}
}"#,
);
}
Ok(contract)
}
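One pattern worth flagging in the fix_da_sol hunk above: `commitment_bytes.clone().is_some()` clones the whole byte buffer just to test it, while the `as_ref()` form borrows. A minimal sketch of the borrowing style (illustrative only, not the ezkl function):

```rust
/// Illustrative sketch only: test an optional byte buffer without cloning it.
fn has_commitments(commitment_bytes: &Option<Vec<u8>>) -> bool {
    commitment_bytes.as_ref().map_or(false, |b| !b.is_empty())
}
```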

View File

@@ -506,12 +506,10 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
)
.await
}
#[cfg(not(feature = "no-update"))]
Commands::Update { version } => update_ezkl_binary(&version).map(|e| e.to_string()),
}
}
#[cfg(not(feature = "no-update"))]
/// Assert that the version is valid
fn assert_version_is_valid(version: &str) -> Result<(), EZKLError> {
let err_string = "Invalid version string. Must be in the format v0.0.0";
@@ -529,10 +527,8 @@ fn assert_version_is_valid(version: &str) -> Result<(), EZKLError> {
Ok(())
}
#[cfg(not(feature = "no-update"))]
const INSTALL_BYTES: &[u8] = include_bytes!("../install_ezkl_cli.sh");
#[cfg(not(feature = "no-update"))]
fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
// run the install script with the version
let install_script = std::str::from_utf8(INSTALL_BYTES)?;
@@ -1013,7 +1009,7 @@ pub(crate) async fn calibrate(
data: PathBuf,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: f64,
lookup_safety_margin: i64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
@@ -1502,10 +1498,10 @@ pub(crate) async fn create_evm_vk(
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn create_evm_data_attestation(
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
input: PathBuf,
witness: Option<PathBuf>,
_sol_code_path: PathBuf,
_abi_path: PathBuf,
_input: PathBuf,
_witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
use crate::graph::{DataSource, VarVisibility};
@@ -1517,7 +1513,7 @@ pub(crate) async fn create_evm_data_attestation(
trace!("params computed");
// if input is not provided, we just instantiate dummy input data
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let data = GraphData::from_path(_input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
if visibility.output.is_private() {
@@ -1552,7 +1548,7 @@ pub(crate) async fn create_evm_data_attestation(
|| settings.run_args.output_visibility == Visibility::KZGCommit
|| settings.run_args.param_visibility == Visibility::KZGCommit
{
let witness = GraphWitness::from_path(witness.unwrap_or(DEFAULT_WITNESS.into()))?;
let witness = GraphWitness::from_path(_witness.unwrap_or(DEFAULT_WITNESS.into()))?;
let commitments = witness.get_polycommitments();
let proof_first_bytes = get_proof_commitments::<
KZGCommitmentScheme<Bn256>,
@@ -1566,12 +1562,12 @@ pub(crate) async fn create_evm_data_attestation(
};
let output = fix_da_sol(input_data, output_data, commitment_bytes)?;
let mut f = File::create(sol_code_path.clone())?;
let mut f = File::create(_sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
Ok(String::new())
}

View File

@@ -16,7 +16,7 @@ pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
-F::from_u128(x.saturating_neg() as u128)
-F::from_u128((-x) as u128)
}
}
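The two negation styles above agree everywhere except x == i64::MIN, where `-x` overflows (a panic in debug builds, wraparound in release) while `saturating_neg` clamps to i64::MAX. A sketch of the distinction:

```rust
/// Illustrative sketch only: magnitude of a negative i64, widened to u128.
fn magnitude_saturating(x: i64) -> u128 {
    // i64::MIN saturates to i64::MAX instead of overflowing
    x.saturating_neg() as u128
}

fn magnitude_plain(x: i64) -> u128 {
    // panics in debug builds when x == i64::MIN
    (-x) as u128
}
```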

View File

@@ -41,16 +41,16 @@ pub enum GraphError {
/// Error when attempting to rescale an operation
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Reading a file failed
#[error("[io] ({0}) {1}")]
ReadWriteFileError(String, String),
/// Error when attempting to load a model from a file
#[error("failed to load model")]
ModelLoad(#[from] std::io::Error),
/// Model serialization error
#[error("failed to ser/deser model: {0}")]
ModelSerialize(#[from] bincode::Error),
/// Tract error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[tract] {0}")]
TractError(#[from] tract_onnx::prelude::TractError),
TractError(#[from] tract_onnx::tract_core::anyhow::Error),
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,

View File

@@ -485,25 +485,18 @@ impl GraphData {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let reader = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let reader = std::fs::File::open(path)?;
let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
let mut buf = String::new();
reader.read_to_string(&mut buf).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
reader.read_to_string(&mut buf)?;
let graph_input = serde_json::from_str(&buf)?;
Ok(graph_input)
}
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// buf writer
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, self)?;
Ok(())
}
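The from_path/save hunks above swap the path-carrying ReadWriteFileError mapping for a bare `?` through `#[from] std::io::Error`, which drops the offending path from the error message. A sketch of the path-preserving style, using a plain String error purely for illustration:

```rust
use std::io::Read;

/// Illustrative sketch only: keep the failing path in the error, in the
/// spirit of the removed `ReadWriteFileError(path, err)` mapping.
fn read_to_string_with_path(path: &std::path::Path) -> Result<String, String> {
    let mut buf = String::new();
    std::fs::File::open(path)
        .and_then(|mut f| f.read_to_string(&mut buf))
        .map_err(|e| format!("[io] ({}) {}", path.display(), e))?;
    Ok(buf)
}
```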

View File

@@ -267,9 +267,7 @@ impl GraphWitness {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
let file = std::fs::File::open(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let file = std::fs::File::open(path.clone())?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
@@ -277,11 +275,9 @@ impl GraphWitness {
/// Save the model input to a file
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let file = std::fs::File::create(path.clone()).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
// use buf writer
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
let writer =
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| e.into())
}
@@ -644,9 +640,7 @@ impl GraphCircuit {
}
///
pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let f = std::fs::File::create(path)?;
let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
bincode::serialize_into(writer, &self)?;
Ok(())
@@ -655,9 +649,7 @@ impl GraphCircuit {
///
pub fn load(path: std::path::PathBuf) -> Result<Self, GraphError> {
// read bytes from file
let f = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let f = std::fs::File::open(path)?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
let result: GraphCircuit = bincode::deserialize_from(reader)?;
@@ -1034,10 +1026,10 @@ impl GraphCircuit {
Ok(data)
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range {
(
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as i64,
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as i64,
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
)
}
@@ -1070,7 +1062,7 @@ impl GraphCircuit {
min_max_lookup: Range,
max_range_size: i64,
max_logrows: Option<u32>,
lookup_safety_margin: f64,
lookup_safety_margin: i64,
) -> Result<(), GraphError> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
@@ -1080,13 +1072,9 @@ impl GraphCircuit {
let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);
// check if subtraction overflows
let lookup_size =
(safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
let lookup_size = (safe_lookup_range.1 - safe_lookup_range.0).abs();
// check if has overflowed max lookup input
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as i64 {
if lookup_size > MAX_LOOKUP_ABS / lookup_safety_margin {
return Err(GraphError::LookupRangeTooLarge(
lookup_size.unsigned_abs() as usize
));

View File

@@ -483,9 +483,7 @@ impl Model {
///
pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
let f = std::fs::File::create(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let f = std::fs::File::create(path)?;
let writer = std::io::BufWriter::new(f);
bincode::serialize_into(writer, &self)?;
Ok(())
@@ -494,16 +492,10 @@ impl Model {
///
pub fn load(path: PathBuf) -> Result<Self, GraphError> {
// read bytes from file
let mut f = std::fs::File::open(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let metadata = fs::metadata(&path).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
let mut f = std::fs::File::open(&path)?;
let metadata = fs::metadata(&path)?;
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer).map_err(|e| {
GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
})?;
f.read_exact(&mut buffer)?;
let result = bincode::deserialize(&buffer)?;
Ok(result)
}
@@ -609,7 +601,9 @@ impl Model {
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<TractResult, GraphError> {
use tract_onnx::tract_hir::internal::GenericFactoid;
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
let mut model = tract_onnx::onnx().model_for_read(reader)?;
@@ -646,11 +640,29 @@ impl Model {
}
// Note: do not optimize the model, as the layout will depend on underlying hardware
let typed_model = model
let mut typed_model = model
.into_typed()?
.concretize_dims(&symbol_values)?
.into_decluttered()?;
// concretize constants
for node in typed_model.eval_order()? {
let node = typed_model.node_mut(node);
if let Some(op) = node.op_as_mut::<tract_onnx::tract_core::ops::konst::Const>() {
if op.0.datum_type() == DatumType::TDim {
// get inner value to Arc<Tensor>
let mut constant = op.0.as_ref().clone();
// Generally a shape or hyperparam
constant
.as_slice_mut::<tract_onnx::prelude::TDim>()?
.iter_mut()
.for_each(|x| *x = x.eval(&symbol_values));
op.0 = constant.into_arc_tensor();
}
}
}
Ok((typed_model, symbol_values))
}
@@ -963,11 +975,8 @@ impl Model {
) -> Result<Vec<Vec<Tensor<f32>>>, GraphError> {
use tract_onnx::tract_core::internal::IntoArcTensor;
let mut file = std::fs::File::open(model_path).map_err(|e| {
GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
})?;
let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;
let (model, _) =
Model::load_onnx_using_tract(&mut std::fs::File::open(model_path)?, run_args)?;
let datum_types: Vec<DatumType> = model
.input_outlets()?
@@ -996,10 +1005,7 @@ impl Model {
/// * `params` - A [GraphSettings] struct holding parsed CLI arguments.
#[cfg(not(target_arch = "wasm32"))]
pub fn from_run_args(run_args: &RunArgs, model: &std::path::Path) -> Result<Self, GraphError> {
let mut file = std::fs::File::open(model).map_err(|e| {
GraphError::ReadWriteFileError(model.display().to_string(), e.to_string())
})?;
Model::new(&mut file, run_args)
Model::new(&mut std::fs::File::open(model)?, run_args)
}
/// Configures a model for the circuit
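The concretize-constants loop added earlier in this file's diff evaluates TDim constants (typically shapes or hyperparameters) against the concrete symbol values. Stripped of the tract types, the operation is just mapping each still-symbolic entry through an evaluator; a generic sketch under that simplification (not the tract API):

```rust
use std::collections::HashMap;

/// Illustrative sketch only: resolve symbolic dimensions (e.g. "batch_size")
/// to concrete values, in the spirit of `x.eval(&symbol_values)` above.
fn concretize_dims(dims: &mut [Result<usize, String>], symbols: &HashMap<String, usize>) {
    for d in dims.iter_mut() {
        if let Err(name) = d {
            if let Some(v) = symbols.get(name.as_str()) {
                *d = Ok(*v);
            }
        }
    }
}
```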

View File

@@ -85,34 +85,6 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
mult.log2().round() as crate::Scale
}
#[cfg(not(target_arch = "wasm32"))]
/// extract padding from a onnx node.
pub fn extract_padding(
pool_spec: &PoolSpec,
num_dims: usize,
) -> Result<Vec<(usize, usize)>, GraphError> {
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
PaddingSpec::Valid => vec![(0, 0); num_dims],
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
Ok(padding)
}
#[cfg(not(target_arch = "wasm32"))]
/// Extracts the strides from a onnx node.
pub fn extract_strides(pool_spec: &PoolSpec) -> Result<Vec<usize>, GraphError> {
Ok(pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?
.to_vec())
}
/// Gets the shape of a onnx node's outlets.
#[cfg(not(target_arch = "wasm32"))]
pub fn node_output_shapes(
@@ -283,11 +255,6 @@ pub fn new_op_from_onnx(
.flat_map(|x| x.out_scales())
.collect::<Vec<_>>();
let input_dims = inputs
.iter()
.flat_map(|x| x.out_dims())
.collect::<Vec<_>>();
let mut replace_const = |scale: crate::Scale,
index: usize,
default_op: SupportedOp|
@@ -342,9 +309,12 @@ pub fn new_op_from_onnx(
}
}
"MultiBroadcastTo" => {
let _op = load_op::<MultiBroadcastTo>(node.op(), idx, node.op().name().to_string())?;
let shapes = node_output_shapes(&node, symbol_values)?;
let shape = shapes[0].clone();
let op = load_op::<MultiBroadcastTo>(node.op(), idx, node.op().name().to_string())?;
let shape = op.shape.clone();
let shape = shape
.iter()
.map(|x| x.to_usize())
.collect::<Result<Vec<_>, _>>()?;
SupportedOp::Linear(PolyOp::MultiBroadcastTo { shape })
}
@@ -1103,8 +1073,18 @@ pub fn new_op_from_onnx(
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
let stride = pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let kernel_shape = &pool_spec.kernel_shape;
SupportedOp::Hybrid(HybridOp::MaxPool {
@@ -1171,10 +1151,21 @@ pub fn new_op_from_onnx(
));
}
let pool_spec = &conv_node.pool_spec;
let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
let padding = match &conv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
@@ -1223,10 +1214,21 @@ pub fn new_op_from_onnx(
));
}
let pool_spec = &deconv_node.pool_spec;
let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let padding = match &deconv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
let bias_scale = input_scales[2];
@@ -1337,8 +1339,18 @@ pub fn new_op_from_onnx(
));
}
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
let stride = pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
SupportedOp::Hybrid(HybridOp::SumPool {
padding,
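The inlined stride/padding matches above (repeated for MaxPool, Conv, Deconv, and SumPool) replace the removed extract_strides/extract_padding helpers. The core of the padding arms is just pairing explicit before/after padding per spatial dimension; a standalone sketch:

```rust
/// Illustrative sketch only: zip explicit before/after padding into
/// per-dimension pairs, as the Explicit / ExplicitOnnxPool arms above do.
fn zip_padding(before: &[usize], after: &[usize]) -> Vec<(usize, usize)> {
    before.iter().zip(after.iter()).map(|(b, a)| (*b, *a)).collect()
}
```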

View File

@@ -887,7 +887,7 @@ fn calibrate_settings(
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: f64,
lookup_safety_margin: i64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
@@ -1491,7 +1491,7 @@ fn encode_evm_calldata<'a>(
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create_evm_vk command
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
@@ -1533,56 +1533,6 @@ fn create_evm_verifier(
})
}
/// Creates an EVM verifier key. This command should be called after create_evm_verifier with the render_vk_separately arg set to true. By rendering a verification key separately you can reuse the same verifier for similar circuit setups with different verifying keys, helping to reduce the amount of state our verifiers store on the blockchain.
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to the create the solidity verifying key.
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
))]
fn create_evm_vk(
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.await
.map_err(|e| {
let err_str = format!("Failed to run create_evm_verifier: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
@@ -1812,7 +1762,7 @@ fn deploy_da_evm(
/// Arguments
/// ---------
/// addr_verifier: str
/// The verifier contract's address as a hex string
/// The path to verifier contract's address
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
@@ -1824,7 +1774,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
///
/// Returns
/// -------
/// bool
@@ -1975,7 +1925,6 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_vk, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;

View File

@@ -1281,30 +1281,6 @@ impl<T: Clone + TensorType> Tensor<T> {
Ok(t)
}
/// Get last elem from Tensor
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<i32>::new(Some(&[3]), &[1]).unwrap();
///
/// assert_eq!(a.last().unwrap(), b);
/// ```
pub fn last(&self) -> Result<Tensor<T>, TensorError>
where
T: Send + Sync,
{
let res = match self.inner.last() {
Some(e) => e.clone(),
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
))
}
};
Tensor::new(Some(&[res]), &[1])
}
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
@@ -1317,7 +1293,7 @@ impl<T: Clone + TensorType> Tensor<T> {
E: Error + std::marker::Send + std::marker::Sync,
>(
&mut self,
filter_indices: &std::collections::HashSet<usize>,
filter_indices: &std::collections::HashSet<&usize>,
f: F,
) -> Result<(), E>
where
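The removed Tensor::last above is equivalent to taking a one-element slice of the final position, which is what the call sites switch to (`output.get_slice(&[output.len() - 1..output.len()])`). A sketch over a plain slice:

```rust
/// Illustrative sketch only: the last element as a one-element slice,
/// matching the `get_slice(&[len - 1..len])` call sites above.
fn last_as_slice<T>(v: &[T]) -> Result<&[T], String> {
    if v.is_empty() {
        return Err("cannot get last element of empty tensor".to_string());
    }
    Ok(&v[v.len() - 1..])
}
```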

View File

@@ -1,12 +1,12 @@
use core::{iter::FilterMap, slice::Iter};
use crate::circuit::region::ConstantsMap;
use maybe_rayon::slice::Iter;
use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
use maybe_rayon::iter::{FilterMap, IntoParallelIterator, ParallelIterator};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -460,7 +460,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.par_iter().filter_map(|x| {
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
@@ -573,27 +573,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Calls `get_slice` on the inner tensor.
pub fn last(&self) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
inner: v,
dims: _,
scale,
} => {
let inner = v.last()?;
let dims = inner.dims().to_vec();
ValTensor::Value {
inner,
dims,
scale: *scale,
}
}
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}
/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
@@ -774,72 +753,43 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// remove constant zero values constants
pub fn remove_const_zero_values(&mut self) {
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_par_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
}
/// gets constants
pub fn get_const_zero_indices(&self) -> Vec<usize> {
pub fn get_const_zero_indices(&self) -> Result<Vec<usize>, TensorError> {
match self {
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return Some(i);
indices.push(i);
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return Some(i);
indices.push(i);
}
}
None
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
Ok(indices)
}
ValTensor::Instance { .. } => Ok(vec![]),
}
}
/// gets constants
pub fn get_const_indices(&self) -> Vec<usize> {
pub fn get_const_indices(&self) -> Result<Vec<usize>, TensorError> {
match self {
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(_) = e {
Some(i)
indices.push(i);
} else if let ValType::AssignedConstant(_, _) = e {
Some(i)
} else {
None
indices.push(i);
}
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
Ok(indices)
}
ValTensor::Instance { .. } => Ok(vec![]),
}
}

View File

@@ -319,7 +319,7 @@ impl VarTensor {
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<usize>,
omissions: &HashSet<&usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;

View File

@@ -183,13 +183,12 @@ mod native_tests {
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
const LARGE_TESTS: [&str; 6] = [
const LARGE_TESTS: [&str; 5] = [
"self_attention",
"nanoGPT",
"multihead_attention",
"mobilenet",
"mnist_gan",
"smallworm",
];
const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -201,7 +200,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 94] = [
const TESTS: [&str; 93] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -299,7 +298,6 @@ mod native_tests {
"1l_lppool",
"lstm_large", // 91
"lstm_medium", // 92
"lenet_5", // 93
];
const WASM_TESTS: [&str; 46] = [
@@ -538,7 +536,7 @@ mod native_tests {
}
});
seq!(N in 0..=93 {
seq!(N in 0..=92 {
#(#[test_case(TESTS[N])])*
#[ignore]
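The inclusive `seq!` ranges must track the array lengths exactly: `0..=92` pairs with `TESTS: [&str; 93]` and, further down, `0..=4` with `LARGE_TESTS: [&str; 5]`. Roughly, the pattern in use (a sketch built on the seq_macro and test_case crates already used here; the function name is illustrative):

// Each N in 0..=92 expands to a #[test_case(TESTS[N])] attribute,
// so an off-by-one range would index past the end of the array.
seq!(N in 0..=92 {
    #(#[test_case(TESTS[N])])*
    #[ignore]
    fn run_test_(test: &str) { /* body elided */ }
});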
@@ -942,7 +940,7 @@ mod native_tests {
});
seq!(N in 0..=5 {
seq!(N in 0..=4 {
#(#[test_case(LARGE_TESTS[N])])*
#[ignore]
@@ -1068,15 +1066,6 @@ mod native_tests {
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain", "polycommit", "public", "polycommit");
test_dir.close().unwrap();
}
#(#[test_case(TESTS_ON_CHAIN_INPUT[N])])*
fn kzg_evm_on_chain_all_kzg_params_prove_and_verify_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "file", "polycommit", "polycommit", "polycommit");
test_dir.close().unwrap();
}
});
@@ -2341,6 +2330,7 @@ mod native_tests {
let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);
init_params(settings_path.clone().into());
let data_path = format!("{}/{}/input.json", test_dir, example_name);
@@ -2352,6 +2342,62 @@ mod native_tests {
let test_input_source = format!("--input-source={}", input_source);
let test_output_source = format!("--output-source={}", output_source);
// load witness
let witness: GraphWitness = GraphWitness::from_path(witness_path.clone().into()).unwrap();
let mut input: GraphData = GraphData::from_path(data_path.clone().into()).unwrap();
if input_visibility == "hashed" {
let hashes = witness.processed_inputs.unwrap().poseidon_hash.unwrap();
input.input_data = DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
);
}
if output_visibility == "hashed" {
let hashes = witness.processed_outputs.unwrap().poseidon_hash.unwrap();
input.output_data = Some(DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
));
} else {
input.output_data = Some(DataSource::File(
witness
.pretty_elements
.unwrap()
.rescaled_outputs
.iter()
.map(|o| {
o.iter()
.map(|f| FileSourceInner::Float(f.parse().unwrap()))
.collect()
})
.collect(),
));
}
input.save(data_path.clone().into()).unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup",
@@ -2366,82 +2412,6 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
// generate the witness, passing the vk path to generate the necessary kzg commits
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"gen-witness",
"-D",
&data_path,
"-M",
&model_path,
"-O",
&witness_path,
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
])
.status()
.expect("failed to execute process");
assert!(status.success());
// load witness
let witness: GraphWitness = GraphWitness::from_path(witness_path.clone().into()).unwrap();
// print out the witness
println!("WITNESS: {:?}", witness);
let mut input: GraphData = GraphData::from_path(data_path.clone().into()).unwrap();
if input_source != "file" || output_source != "file" {
println!("on chain input");
if input_visibility == "hashed" {
let hashes = witness.processed_inputs.unwrap().poseidon_hash.unwrap();
input.input_data = DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
);
}
if output_visibility == "hashed" {
let hashes = witness.processed_outputs.unwrap().poseidon_hash.unwrap();
input.output_data = Some(DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
));
} else {
input.output_data = Some(DataSource::File(
witness
.pretty_elements
.unwrap()
.rescaled_outputs
.iter()
.map(|o| {
o.iter()
.map(|f| FileSourceInner::Float(f.parse().unwrap()))
.collect()
})
.collect(),
));
}
input.save(data_path.clone().into()).unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"prove",
@@ -2532,19 +2502,13 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());
let deploy_evm_data_path = if input_source != "file" || output_source != "file" {
test_on_chain_data_path.clone()
} else {
data_path.clone()
};
let addr_path_da_arg = format!("--addr-path={}/{}/addr_da.txt", test_dir, example_name);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"deploy-evm-da",
format!("--settings-path={}", settings_path).as_str(),
"-D",
deploy_evm_data_path.as_str(),
test_on_chain_data_path.as_str(),
"--sol-code-path",
sol_arg.as_str(),
rpc_arg.as_str(),
@@ -2582,42 +2546,40 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());
// Create a new set of test on-chain data only for the on-chain input source
if input_source != "file" || output_source != "file" {
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let deployed_addr_arg = format!("--addr={}", addr_da);
let args: Vec<&str> = vec![
"test-update-account-calls",
deployed_addr_arg.as_str(),
// Create a new set of test on-chain data
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args(&args)
.status()
.expect("failed to execute process");
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
assert!(status.success());
let deployed_addr_arg = format!("--addr={}", addr_da);
let args = vec![
"test-update-account-calls",
deployed_addr_arg.as_str(),
"-D",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args(&args)
.status()
.expect("failed to execute process");
assert!(status.success());
// As a sanity check, add an example that should fail.
let args = vec![
"verify-evm",

View File

@@ -423,74 +423,6 @@ async def test_create_evm_verifier():
assert res == True
assert os.path.isfile(sol_code_path)
async def test_create_evm_verifier_separate_vk():
"""
Create an EVM verifier with Solidity code and a separate vk
In order to run this test you will need to install solc in your environment
"""
vk_path = os.path.join(folder_path, 'test_evm.vk')
settings_path = os.path.join(folder_path, 'settings.json')
sol_code_path = os.path.join(folder_path, 'test_separate.sol')
vk_code_path = os.path.join(folder_path, 'test_vk.sol')
abi_path = os.path.join(folder_path, 'test_separate.abi')
abi_vk_path = os.path.join(folder_path, 'test_vk_separate.abi')
proof_path = os.path.join(folder_path, 'test_evm.pf')
calldata_path = os.path.join(folder_path, 'calldata.bytes')
# # res is now a vector of bytes
# res = ezkl.encode_evm_calldata(proof_path, calldata_path)
# assert os.path.isfile(calldata_path)
# assert len(res) > 0
res = await ezkl.create_evm_verifier(
vk_path,
settings_path,
sol_code_path,
abi_path,
srs_path=srs_path,
render_vk_seperately=True
)
res = await ezkl.create_evm_vk(
vk_path,
settings_path,
vk_code_path,
abi_vk_path,
srs_path=srs_path,
)
assert res == True
assert os.path.isfile(sol_code_path)
async def test_deploy_evm_separate_vk():
"""
Test deployment of the separate verifier smart contract + vk
In order to run this you will need to install solc in your environment
"""
addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
addr_path_vk = os.path.join(folder_path, 'address_vk.json')
sol_code_path = os.path.join(folder_path, 'test_separate.sol')
vk_code_path = os.path.join(folder_path, 'test_vk.sol')
# TODO: without optimization there will be out of gas errors
# sol_code_path = os.path.join(folder_path, 'test.sol')
res = await ezkl.deploy_evm(
addr_path_verifier,
sol_code_path,
rpc_url=anvil_url,
)
res = await ezkl.deploy_vk_evm(
addr_path_vk,
vk_code_path,
rpc_url=anvil_url,
)
assert res == True
async def test_deploy_evm():
"""
@@ -571,47 +503,6 @@ async def test_verify_evm():
assert res == True
async def test_verify_evm_separate_vk():
"""
Verifies an EVM proof
In order to run this you will need to install solc in your environment
"""
proof_path = os.path.join(folder_path, 'test_evm.pf')
addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
addr_path_vk = os.path.join(folder_path, 'address_vk.json')
calldata_path = os.path.join(folder_path, 'calldata_separate.bytes')
with open(addr_path_verifier, 'r') as file:
addr_verifier = file.read().rstrip()
print(addr_verifier)
with open(addr_path_vk, 'r') as file:
addr_vk = file.read().rstrip()
print(addr_vk)
# res is now a vector of bytes
res = ezkl.encode_evm_calldata(proof_path, calldata_path, addr_vk=addr_vk)
assert os.path.isfile(calldata_path)
assert len(res) > 0
# TODO: without optimization there will be out of gas errors
# sol_code_path = os.path.join(folder_path, 'test.sol')
res = await ezkl.verify_evm(
addr_verifier,
proof_path,
rpc_url=anvil_url,
addr_vk=addr_vk,
# sol_code_path
# optimizer_runs
)
assert res == True
async def test_aggregate_and_verify_aggr():
data_path = os.path.join(
@@ -870,7 +761,6 @@ def get_examples():
'accuracy',
'linear_regression',
"mnist_gan",
"smallworm",
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):