Compare commits


7 Commits

Author | SHA1 | Message | Date
github-actions[bot] | 6c73ba1bee | ci: update version string in docs | 2024-07-18 13:59:04 +00:00
dante | 5be12b7a54 | fix: num groups for conv operations should be specified at load time (#828) | 2024-07-18 09:58:41 -04:00
dante | 2fd877c716 | chore: small worm example (#568) | 2024-07-15 09:20:37 -04:00
dante | 8197340985 | chore: const filtering optimizations (#825) | 2024-07-12 12:37:02 +01:00
dante | 6855ea1947 | feat: parallel polynomial reads in halo2 (#826) | 2024-07-12 00:33:14 +01:00
dante | 2ca57bde2c | chore: bump tract (#823) | 2024-06-29 01:53:18 +01:00
Ethan Cemer | 390de88194 | feat: create_evm_vk python bindings (#818) | 2024-06-23 22:21:59 -04:00
32 changed files with 817 additions and 530 deletions

View File

@@ -236,6 +236,8 @@ jobs:
with:
crate: cargo-nextest
locked: true
# - name: The Worm Mock
# run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
- name: public outputs and tolerance > 0
run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
- name: public outputs + batch size == 10

Cargo.lock (generated, 501 lines changed)

File diff suppressed because it is too large

View File

@@ -39,7 +39,6 @@ snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch =
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
@@ -63,16 +62,13 @@ reqwest = { version = "0.12.4", default-features = false, features = [
openssl = { version = "0.10.55", features = ["vendored"] }
tokio-postgres = "0.7.10"
pg_bigdecimal = "0.1.5"
futures-util = "0.3.30"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.35", default_features = false, features = [
"macros",
"rt-multi-thread"
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.21.2", features = [
"extension-module",
"abi3-py37",
@@ -83,9 +79,8 @@ pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="m
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
@@ -108,8 +103,10 @@ console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }
[dev-dependencies]
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -180,7 +177,7 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup", "no-banner"]
default = ["ezkl", "mv-lookup", "no-banner", "parallel-poly-read"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = [
@@ -194,6 +191,7 @@ ezkl = [
"colored_json",
"halo2_proofs/circuit-params",
]
parallel-poly-read = ["halo2_proofs/parallel-poly-read"]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
@@ -211,7 +209,7 @@ metal = ["dep:metal", "dep:objc"]
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#8cfca221f53069a0374687654882b99e729041d7", package = "halo2_proofs" }
[profile.release]

View File

@@ -72,6 +72,7 @@ impl Circuit<Fr> for MyCircuit {
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
}),
)
.unwrap();
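
This hunk shows the caller-side impact of #828: `PolyOp::Conv` now carries the group count explicitly rather than deriving it from kernel dimensions at layout time. A minimal sketch of constructing the updated op, assuming the import path below (the variant and field names are taken from this diff):

```rust
// Sketch only: the import path is an assumption; adjust to wherever
// PolyOp is exported in your ezkl version.
use ezkl::circuit::ops::poly::PolyOp;

fn main() {
    // Construct the op as the example circuit above now does.
    let _op = PolyOp::Conv {
        padding: vec![(0, 0); 2], // (before, after) padding per spatial dim
        stride: vec![1; 2],       // unit stride in both spatial dims
        group: 1,                 // 1 = ordinary, ungrouped convolution
    };
}
```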

View File

@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl==11.6.3
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '11.6.3'
version = release

View File

@@ -205,6 +205,7 @@ where
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
};
let x = config
.layer_config

View File

@@ -39,7 +39,7 @@
"import json\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"import sk2torch\n",
"from hummingbird.ml import convert\n",
"import torch\n",
"import ezkl\n",
"import os\n",
@@ -59,11 +59,11 @@
"# Train an SVM on the data and wrap it in PyTorch.\n",
"sk_model = SVC(probability=True)\n",
"sk_model.fit(xs, ys)\n",
"model = sk2torch.wrap(sk_model)\n",
"\n",
"model = convert(sk_model, \"torch\").model\n",
"\n",
"\n",
"\n",
"model\n",
"\n"
]
},
@@ -84,33 +84,6 @@
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0ca328",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"# Create a coordinate grid to compute a vector field on.\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"\n",
"# Compute the gradients of the SVM output.\n",
"outputs = model.predict_proba(grid_xs)[:, 1]\n",
"(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n",
"\n",
"\n",
"# Create a quiver plot of the vector field.\n",
"plt.quiver(\n",
" grid_xs[:, 0].detach().numpy(),\n",
" grid_xs[:, 1].detach().numpy(),\n",
" input_grads[:, 0].detach().numpy(),\n",
" input_grads[:, 1].detach().numpy(),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -119,14 +92,14 @@
"outputs": [],
"source": [
"\n",
"\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = xs.shape[1:]\n",
"x = grid_xs[0:1]\n",
"torch_out = model.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
@@ -143,9 +116,7 @@
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"data = dict(input_data=[d])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -167,6 +138,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0bee4d7f",
"metadata": {},
"outputs": [],
"source": [
@@ -220,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
@@ -441,9 +413,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}

Binary file not shown.

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"lookup_range":[0,0],"logrows":13,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":5619,"total_const_size":513,"model_instance_shapes":[[1,3,10,10]],"model_output_scales":[14],"model_input_scales":[7],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

File diff suppressed because one or more lines are too long

Binary file not shown.

View File

@@ -0,0 +1 @@
network.onnx filter=lfs diff=lfs merge=lfs -text

View File

@@ -0,0 +1,47 @@
## The worm
This is an ONNX file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, a VAE / latent-space representation of the *C. elegans* connectome.
The model "is a large-scale latent variable model with a very high-dimensional latent space
consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency
of 160 Hz. The generative model for these latent variables is described by stochastic differential
equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).
In effect this is a generative model of a worm's voltage dynamics: given a previous connectome state, it can generate new worm-like activity.
Using ezkl you can create a zk circuit equivalent to the WormVAE model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm state that can be verified on chain.
To do so, you'll first want to fetch the files using git-lfs (the ONNX file is too large to be stored in git directly).
```bash
git lfs fetch --all
```
You'll then want to use the usual ezkl loop to generate the zk circuit. We recommend using fixed visibility for the model parameters, as the model is quite large and this will prune the circuit significantly.
```bash
ezkl gen-settings --param-visibility=fixed
cp input.json calibration.json
ezkl calibrate-settings
ezkl compile-circuit
ezkl gen-witness
ezkl prove
```
You might also need to aggregate the proof to get it to fit on chain.
```bash
ezkl aggregate
```
You can then create a smart contract that verifies this aggregate proof.
```bash
ezkl create-evm-verifier-aggr
```
This can then be deployed on the chain of your choice.
> Note: the model is large, so we recommend a machine with at least 512GB of RAM to run the above commands. If you're compute-constrained you can always use the Lilith service to generate the zk circuit. Message us on Discord or Telegram for more details :)

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f88c5901d3768ec21e3cf2f2840d255e84fa13c364df86b24d960cca3333769
size 82095882

View File

@@ -0,0 +1 @@
{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":6,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Fixed"},"num_constraints":367422820,"total_const_size":365577160,"model_instance_shapes":[[1,300,1200]],"model_output_scales":[6],"model_input_scales":[0,0,0],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":64.0}},"ReLU",{"Ln":{"scale":64.0}},{"Exp":{"scale":64.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}

View File

@@ -250,6 +250,10 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();
@@ -257,12 +261,17 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let mut values = values.clone();
// this section has been optimized to death, don't mess with it
let mut removal_indices = values[0].get_const_zero_indices()?;
let second_zero_indices = values[1].get_const_zero_indices()?;
let mut removal_indices = values[0].get_const_zero_indices();
let second_zero_indices = values[1].get_const_zero_indices();
removal_indices.extend(second_zero_indices);
removal_indices.par_sort_unstable();
removal_indices.dedup();
// if empty return a const
if removal_indices.len() == values[0].len() {
return Ok(create_zero_tensor(1));
}
// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
values[1].remove_indices(&mut removal_indices, true)?;
@@ -270,15 +279,6 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}
// if empty return a const
if values[0].is_empty() && values[1].is_empty() {
return Ok(create_zero_tensor(1));
}
let start = instant::Instant::now();
let mut inputs = vec![];
let block_width = config.custom_gates.output.num_inner_cols();
@@ -343,7 +343,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;
region.increment(assigned_len);
@@ -1779,12 +1779,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let mut values = values.clone();
// this section has been optimized to death, don't mess with it
let mut removal_indices = values[0].get_const_zero_indices()?;
removal_indices.par_sort_unstable();
removal_indices.dedup();
// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
values[0].remove_const_zero_values();
let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
@@ -1841,7 +1836,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
}
}
let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;
region.increment(assigned_len);
@@ -1884,7 +1879,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>
let global_start = instant::Instant::now();
// this section has been optimized to death, don't mess with it
let removal_indices = values[0].get_const_zero_indices()?;
let removal_indices = values[0].get_const_zero_indices();
let elapsed = global_start.elapsed();
trace!("finding const zero indices took: {:?}", elapsed);
@@ -1945,7 +1940,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>
.collect::<Result<Vec<_>, CircuitError>>()?;
}
let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;
region.increment(assigned_len);
@@ -2256,22 +2251,22 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let orig_lhs = lhs.clone();
let orig_rhs = rhs.clone();
// get indices of zeros
let first_zero_indices = lhs.get_const_zero_indices()?;
let second_zero_indices = rhs.get_const_zero_indices()?;
let mut removal_indices = match op {
let start = instant::Instant::now();
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());
let removal_indices = match op {
BaseOp::Add | BaseOp::Mult => {
let mut removal_indices = first_zero_indices.clone();
removal_indices.extend(second_zero_indices.clone());
removal_indices
// join the zero indices
first_zero_indices
.union(&second_zero_indices)
.cloned()
.collect()
}
BaseOp::Sub => second_zero_indices.clone(),
_ => return Err(CircuitError::UnsupportedOp),
};
removal_indices.dedup();
let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
trace!("setting up indices took {:?}", start.elapsed());
if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
@@ -2280,20 +2275,19 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
)));
}
let mut inputs = vec![];
for (i, input) in [lhs.clone(), rhs.clone()].iter().enumerate() {
let inp = {
let inputs = [lhs.clone(), rhs.clone()]
.iter()
.enumerate()
.map(|(i, input)| {
let res = region.assign_with_omissions(
&config.custom_gates.inputs[i],
input,
removal_indices_ptr,
&removal_indices,
)?;
res.get_inner()?
};
inputs.push(inp);
}
Ok(res.get_inner()?)
})
.collect::<Result<Vec<_>, CircuitError>>()?;
// Now we can assign the dot product
// time the calc
@@ -2308,15 +2302,20 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let elapsed = start.elapsed();
trace!("pairwise {} calc took {:?}", op.as_str(), start.elapsed());
let assigned_len = inputs[0].len() - removal_indices.len();
let start = instant::Instant::now();
let assigned_len = op_result.len() - removal_indices.len();
let mut output = region.assign_with_omissions(
&config.custom_gates.output,
&op_result.into(),
removal_indices_ptr,
&removal_indices,
)?;
trace!("pairwise {} calc took {:?}", op.as_str(), elapsed);
trace!(
"pairwise {} input assign took {:?}",
op.as_str(),
start.elapsed()
);
// Enable the selectors
if !region.is_dummy() {
@@ -2337,16 +2336,11 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let a_tensor = orig_lhs.get_inner_tensor()?;
let b_tensor = orig_rhs.get_inner_tensor()?;
let first_zero_indices: HashSet<&usize> = HashSet::from_iter(first_zero_indices.iter());
let second_zero_indices: HashSet<&usize> = HashSet::from_iter(second_zero_indices.iter());
trace!("setting up indices took {:?}", start.elapsed());
// infill the zero indices with the correct values from values[0] or values[1]
if !removal_indices_ptr.is_empty() {
if !removal_indices.is_empty() {
output
.get_inner_tensor_mut()?
.par_enum_map_mut_filtered(removal_indices_ptr, |i| {
.par_enum_map_mut_filtered(&removal_indices, |i| {
let val = match op {
BaseOp::Add => {
let a_is_null = first_zero_indices.contains(&i);
@@ -2386,6 +2380,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
end,
region.row()
);
trace!("----------------------------");
Ok(output)
}
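
The rewritten zero-index bookkeeping above swaps `Vec` concatenation plus `par_sort_unstable`/`dedup` for set union over owned indices. A self-contained sketch of the semantics (the index values are made up):

```rust
use std::collections::HashSet;

fn main() {
    // Indices at which each operand is a compile-time constant zero.
    let first_zero_indices: HashSet<usize> = HashSet::from_iter([0, 2, 5]);
    let second_zero_indices: HashSet<usize> = HashSet::from_iter([2, 3]);

    // Add/Mult: a cell can be skipped if it is zero in either input
    // (for Mult the product is a known zero; for Add the infill step
    // above copies the surviving operand back in afterwards).
    let removal_indices: HashSet<usize> = first_zero_indices
        .union(&second_zero_indices)
        .cloned()
        .collect();
    assert_eq!(removal_indices, HashSet::from_iter([0, 2, 3, 5]));

    // Sub: only zeros on the right-hand side are safe to skip.
    let sub_removals: HashSet<usize> = second_zero_indices.clone();
    assert_eq!(sub_removals.len(), 2);
}
```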
@@ -3028,7 +3023,7 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI
.map(|coord| {
let (b, i) = (coord[0], coord[1]);
let input = values[0].get_slice(&[b..b + 1, i..i + 1])?;
let output = conv(config, region, &[input, kernel.clone()], padding, stride)?;
let output = conv(config, region, &[input, kernel.clone()], padding, stride, 1)?;
res.push(output);
Ok(())
})
@@ -3164,7 +3159,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![1;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[0, 32, 0, 32, 0, 6, 0, 12, 0, 4, 0, 8, 0, 4, 0, 8, 0, 0, 0, 3, 0, 0, 0, 2]), &[1, 2, 3, 4]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3176,7 +3171,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[6, 14, 4, 2, 17, 21, 0, 1, 5]), &[1, 1, 3, 3]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3189,7 +3184,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[17]), &[1, 1, 1, 1]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3202,7 +3197,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[10, 4, 0, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3214,7 +3209,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 1, 1, 5]),
/// &[1, 1, 2, 2],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[6, 2, 12, 4, 2, 10, 4, 20, 0, 0, 3, 1, 0, 0, 1, 5]), &[1, 1, 4, 4]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3226,7 +3221,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(1, 1); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[0, 0]), &[1, 1, 2, 1]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3238,7 +3233,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[3, 2]),
/// &[1, 1, 2, 1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k], &vec![(0, 0); 2], &vec![0;2], &vec![2; 2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[6, 0, 12, 4, 0, 8, 0, 0, 3, 0, 0, 2]), &[1, 1, 4, 3]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3249,7 +3244,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// &[1, 1, 2, 2],
/// ).unwrap());
///
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, c], &vec![(1, 1); 2], &vec![0;2], &vec![2;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[0, 32, 0, 0, 6, 0, 0, 4, 0, 0, 0, 0]), &[1, 2, 2, 3]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
/// let x = ValTensor::from_i64_tensor(Tensor::<i64>::new(
@@ -3264,7 +3259,7 @@ pub fn max_pool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Into
/// Some(&[1]),
/// &[1],
/// ).unwrap());
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2]).unwrap();
/// let result = deconv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(1, 1); 2], &vec![0;2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[55, 58, 66, 69]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3284,6 +3279,7 @@ pub fn deconv<
padding: &[(usize, usize)],
output_padding: &[usize],
stride: &[usize],
num_groups: usize,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = inputs.len() == 3;
let (image, kernel) = (&inputs[0], &inputs[1]);
@@ -3369,6 +3365,7 @@ pub fn deconv<
&conv_input,
&vec![(0, 0); conv_dim],
&vec![1; conv_dim],
num_groups,
)?;
Ok(output)
@@ -3400,7 +3397,7 @@ pub fn deconv<
/// Some(&[0]),
/// &[1],
/// ).unwrap());
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[31, 16, 8, 26]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3418,7 +3415,7 @@ pub fn deconv<
/// &[2],
/// ).unwrap());
///
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
/// let result = conv::<Fp>(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 2).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[32, 17, 9, 27, 34, 20, 13, 26]), &[1, 2, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
///
@@ -3436,7 +3433,7 @@ pub fn deconv<
/// &[4],
/// ).unwrap());
///
/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2]).unwrap();
/// let result =conv(&dummy_config, &mut dummy_region, &[x, k, b], &vec![(0, 0); 2], &vec![1;2], 1).unwrap();
/// let expected = Tensor::<i64>::new(Some(&[65, 36, 21, 52, 73, 48, 37, 48, 65, 36, 21, 52, 73, 48, 37, 48]), &[1, 4, 2, 2]).unwrap();
/// assert_eq!(result.get_int_evals().unwrap(), expected);
/// ```
@@ -3455,6 +3452,7 @@ pub fn conv<
values: &[ValTensor<F>],
padding: &[(usize, usize)],
stride: &[usize],
num_groups: usize,
) -> Result<ValTensor<F>, CircuitError> {
let has_bias = values.len() == 3;
let (mut image, mut kernel) = (values[0].clone(), values[1].clone());
@@ -3485,6 +3483,11 @@ pub fn conv<
region.increment(*assigned_len.iter().max().unwrap());
}
// if image is 3d add a dummy batch dimension
if image.dims().len() == kernel.dims().len() - 1 {
image.reshape(&[1, image.dims()[0], image.dims()[1], image.dims()[2]])?;
}
let image_dims = image.dims();
let kernel_dims = kernel.dims();
@@ -3518,10 +3521,17 @@ pub fn conv<
log::debug!("slides: {:?}", slides);
let num_groups = input_channels / kernel_dims[1];
let input_channels_per_group = input_channels / num_groups;
let output_channels_per_group = output_channels / num_groups;
if output_channels_per_group == 0 || input_channels_per_group == 0 {
return Err(TensorError::DimMismatch(format!(
"Given groups={}, expected input channels and output channels to be divisible by groups, but got input_channels={}, output_channels={}",
num_groups, input_channels, output_channels
))
.into());
}
log::debug!(
"num_groups: {}, input_channels_per_group: {}, output_channels_per_group: {}",
num_groups,
@@ -3529,14 +3539,6 @@ pub fn conv<
output_channels_per_group
);
if output_channels_per_group == 0 {
return Err(TensorError::DimMismatch(format!(
"Given groups={}, expected kernel to be at least {} at dimension 0 but got {} instead",
num_groups, num_groups, output_channels_per_group
))
.into());
}
let num_outputs =
batch_size * num_groups * output_channels_per_group * slides.iter().product::<usize>();
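
Because `num_groups` is now passed in by the caller (it was previously derived as `input_channels / kernel_dims[1]`), the layout validates divisibility before computing output sizes. A worked sketch of the arithmetic the new check guards, with hypothetical channel counts:

```rust
fn main() {
    // Hypothetical grouped conv: 4 input channels, 6 output channels, 2 groups.
    let (input_channels, output_channels, num_groups) = (4usize, 6usize, 2usize);

    let input_channels_per_group = input_channels / num_groups; // 2
    let output_channels_per_group = output_channels / num_groups; // 3
    assert!(input_channels_per_group > 0 && output_channels_per_group > 0);

    // A bad configuration now fails fast instead of silently producing
    // a zero-sized group: e.g. 8 groups over 4 channels.
    let bad_groups = 8usize;
    assert_eq!(input_channels / bad_groups, 0); // triggers the DimMismatch error
}
```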
@@ -3777,7 +3779,7 @@ pub(crate) fn boolean_identity<
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, CircuitError> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
let output = if assign || !values[0].get_const_indices().is_empty() {
// get zero constants indices
let output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
@@ -3942,11 +3944,10 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
let x = values[0].clone();
let removal_indices = values[0].get_const_indices()?;
let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
let removal_indices = values[0].get_const_indices();
let removal_indices: HashSet<usize> = HashSet::from_iter(removal_indices);
let w = region.assign_with_omissions(&config.static_lookups.input, &x, removal_indices_ptr)?;
let w = region.assign_with_omissions(&config.static_lookups.input, &x, &removal_indices)?;
let output = w.get_inner_tensor()?.par_enum_map(|i, e| {
Ok::<_, TensorError>(if let Some(f) = e.get_felt_eval() {
@@ -3964,7 +3965,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
let mut output = region.assign_with_omissions(
&config.static_lookups.output,
&output.into(),
removal_indices_ptr,
&removal_indices,
)?;
let is_dummy = region.is_dummy();
@@ -3994,11 +3995,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
})?
.into();
region.assign_with_omissions(
&config.static_lookups.index,
&table_index,
removal_indices_ptr,
)?;
region.assign_with_omissions(&config.static_lookups.index, &table_index, &removal_indices)?;
if !is_dummy {
(0..assigned_len)

View File

@@ -33,6 +33,7 @@ pub enum PolyOp {
Conv {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
},
Downsample {
axis: usize,
@@ -43,6 +44,7 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
},
Add,
Sub,
@@ -148,17 +150,25 @@ impl<
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
PolyOp::Conv {
stride,
padding,
group,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
group,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -212,9 +222,18 @@ impl<
PolyOp::Prod { axes, .. } => {
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
}
PolyOp::Conv {
padding,
stride,
group,
} => layouts::conv(
config,
region,
values[..].try_into()?,
padding,
stride,
*group,
)?,
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -261,6 +280,7 @@ impl<
padding,
output_padding,
stride,
group,
} => layouts::deconv(
config,
region,
@@ -268,6 +288,7 @@ impl<
padding,
output_padding,
stride,
*group,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,

View File

@@ -9,6 +9,8 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use portable_atomic::AtomicI64 as AtomicInt;
use std::{
cell::RefCell,
@@ -515,18 +517,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
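
Every `assign*` signature in this file follows the same widening pattern: the return type moves from halo2's `Error` to the crate's `CircuitError`, and `Ok(var.assign(...)?)` performs the conversion through a `From` impl. A minimal self-contained sketch with stand-in types (not the real ones):

```rust
#[derive(Debug)]
struct Halo2Error;

#[derive(Debug)]
struct CircuitError;

impl From<Halo2Error> for CircuitError {
    fn from(_: Halo2Error) -> Self {
        CircuitError
    }
}

fn assign() -> Result<u32, Halo2Error> {
    Ok(42)
}

// `?` converts Halo2Error into CircuitError via the From impl,
// so the caller-facing error type is widened without a map_err.
fn assign_widened() -> Result<u32, CircuitError> {
    Ok(assign()?)
}

fn main() {
    assert_eq!(assign_widened().unwrap(), 42);
}
```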
@@ -542,18 +544,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
@@ -564,7 +566,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
self.assign_dynamic_lookup(var, values)
}
@@ -573,27 +575,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<&usize>,
) -> Result<ValTensor<F>, Error> {
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign_with_omissions(
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
)?)
} else {
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
let values_map = values.create_constants_map();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}

View File

@@ -1050,6 +1050,7 @@ mod conv {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1200,6 +1201,7 @@ mod conv_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1345,6 +1347,7 @@ mod conv_relu_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -1502,10 +1502,10 @@ pub(crate) async fn create_evm_vk(
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn create_evm_data_attestation(
settings_path: PathBuf,
_sol_code_path: PathBuf,
_abi_path: PathBuf,
_input: PathBuf,
_witness: Option<PathBuf>,
sol_code_path: PathBuf,
abi_path: PathBuf,
input: PathBuf,
witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
use crate::graph::{DataSource, VarVisibility};
@@ -1517,7 +1517,7 @@ pub(crate) async fn create_evm_data_attestation(
trace!("params computed");
// if input is not provided, we just instantiate dummy input data
let data = GraphData::from_path(_input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
if visibility.output.is_private() {
@@ -1552,7 +1552,7 @@ pub(crate) async fn create_evm_data_attestation(
|| settings.run_args.output_visibility == Visibility::KZGCommit
|| settings.run_args.param_visibility == Visibility::KZGCommit
{
let witness = GraphWitness::from_path(_witness.unwrap_or(DEFAULT_WITNESS.into()))?;
let witness = GraphWitness::from_path(witness.unwrap_or(DEFAULT_WITNESS.into()))?;
let commitments = witness.get_polycommitments();
let proof_first_bytes = get_proof_commitments::<
KZGCommitmentScheme<Bn256>,
@@ -1566,12 +1566,12 @@ pub(crate) async fn create_evm_data_attestation(
};
let output = fix_da_sol(input_data, output_data, commitment_bytes)?;
let mut f = File::create(_sol_code_path.clone())?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
Ok(String::new())
}

View File

@@ -16,7 +16,7 @@ pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
-F::from_u128((-x) as u128)
-F::from_u128(x.saturating_neg() as u128)
}
}
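
This fix closes an overflow edge case: when `x == i64::MIN`, `-x` panics in debug builds and wraps back to `i64::MIN` in release, because `i64::MIN` has no positive counterpart. `saturating_neg` clamps to `i64::MAX` instead. A standalone sketch:

```rust
fn main() {
    let x = i64::MIN;
    // let y = -x; // overflow: panics in debug builds, wraps in release

    let y = x.saturating_neg(); // clamps to i64::MAX
    assert_eq!(y, i64::MAX);

    // Safe to widen to u128 afterwards, as i64_to_felt does.
    let as_u128 = y as u128;
    assert_eq!(as_u128, i64::MAX as u128);
}
```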

View File

@@ -50,7 +50,7 @@ pub enum GraphError {
/// Tract error
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
#[error("[tract] {0}")]
TractError(#[from] tract_onnx::tract_core::anyhow::Error),
TractError(#[from] tract_onnx::prelude::TractError),
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
PackingExponent,

View File

@@ -609,9 +609,7 @@ impl Model {
reader: &mut dyn std::io::Read,
run_args: &RunArgs,
) -> Result<TractResult, GraphError> {
use tract_onnx::{
tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
};
use tract_onnx::tract_hir::internal::GenericFactoid;
let mut model = tract_onnx::onnx().model_for_read(reader)?;
@@ -648,29 +646,11 @@ impl Model {
}
// Note: do not optimize the model, as the layout will depend on underlying hardware
let mut typed_model = model
let typed_model = model
.into_typed()?
.concretize_dims(&symbol_values)?
.into_decluttered()?;
// concretize constants
for node in typed_model.eval_order()? {
let node = typed_model.node_mut(node);
if let Some(op) = node.op_as_mut::<tract_onnx::tract_core::ops::konst::Const>() {
if op.0.datum_type() == DatumType::TDim {
// get inner value to Arc<Tensor>
let mut constant = op.0.as_ref().clone();
// Generally a shape or hyperparam
constant
.as_slice_mut::<tract_onnx::prelude::TDim>()?
.iter_mut()
.for_each(|x| *x = x.eval(&symbol_values));
op.0 = constant.into_arc_tensor();
}
}
}
Ok((typed_model, symbol_values))
}

View File

@@ -85,6 +85,34 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
mult.log2().round() as crate::Scale
}
#[cfg(not(target_arch = "wasm32"))]
/// extract padding from a onnx node.
pub fn extract_padding(
pool_spec: &PoolSpec,
num_dims: usize,
) -> Result<Vec<(usize, usize)>, GraphError> {
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
PaddingSpec::Valid => vec![(0, 0); num_dims],
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
Ok(padding)
}
#[cfg(not(target_arch = "wasm32"))]
/// Extracts the strides from a onnx node.
pub fn extract_strides(pool_spec: &PoolSpec) -> Result<Vec<usize>, GraphError> {
Ok(pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?
.to_vec())
}
/// Gets the shape of a onnx node's outlets.
#[cfg(not(target_arch = "wasm32"))]
pub fn node_output_shapes(
@@ -255,6 +283,8 @@ pub fn new_op_from_onnx(
.flat_map(|x| x.out_scales())
.collect::<Vec<_>>();
let input_dims = inputs.iter().flat_map(|x| x.out_dims()).collect::<Vec<_>>();
let mut replace_const = |scale: crate::Scale,
index: usize,
default_op: SupportedOp|
@@ -309,12 +339,9 @@ pub fn new_op_from_onnx(
}
}
"MultiBroadcastTo" => {
let op = load_op::<MultiBroadcastTo>(node.op(), idx, node.op().name().to_string())?;
let shape = op.shape.clone();
let shape = shape
.iter()
.map(|x| x.to_usize())
.collect::<Result<Vec<_>, _>>()?;
let _op = load_op::<MultiBroadcastTo>(node.op(), idx, node.op().name().to_string())?;
let shapes = node_output_shapes(&node, symbol_values)?;
let shape = shapes[0].clone();
SupportedOp::Linear(PolyOp::MultiBroadcastTo { shape })
}
@@ -1073,18 +1100,8 @@ pub fn new_op_from_onnx(
));
}
let stride = pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
let kernel_shape = &pool_spec.kernel_shape;
SupportedOp::Hybrid(HybridOp::MaxPool {
@@ -1151,21 +1168,10 @@ pub fn new_op_from_onnx(
));
}
let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let pool_spec = &conv_node.pool_spec;
let padding = match &conv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
@@ -1183,7 +1189,13 @@ pub fn new_op_from_onnx(
}
}
SupportedOp::Linear(PolyOp::Conv { padding, stride })
let group = conv_node.group;
SupportedOp::Linear(PolyOp::Conv {
padding,
stride,
group,
})
}
"Not" => SupportedOp::Linear(PolyOp::Not),
"And" => SupportedOp::Linear(PolyOp::And),
@@ -1214,21 +1226,10 @@ pub fn new_op_from_onnx(
));
}
let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let padding = match &deconv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let pool_spec = &deconv_node.pool_spec;
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
let bias_scale = input_scales[2];
@@ -1249,6 +1250,7 @@ pub fn new_op_from_onnx(
padding,
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
})
}
"Downsample" => {
@@ -1339,18 +1341,8 @@ pub fn new_op_from_onnx(
));
}
let stride = pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
SupportedOp::Hybrid(HybridOp::SumPool {
padding,

View File

@@ -1491,7 +1491,7 @@ fn encode_evm_calldata<'a>(
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create_evm_vk command
///
/// Returns
/// -------
@@ -1533,6 +1533,56 @@ fn create_evm_verifier(
})
}
/// Creates an EVM verifier key. This command should be called after create_evm_verifier with the render_vk_separately arg set to true. By rendering a verification key separately you can reuse the same verifier for similar circuit setups with different verifying keys, helping to reduce the amount of state our verifiers store on the blockchain.
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to create the solidity verifying key.
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
))]
fn create_evm_vk(
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.await
.map_err(|e| {
let err_str = format!("Failed to run create_evm_vk: {}", e);
PyRuntimeError::new_err(err_str)
})?;
Ok(true)
})
}
/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
@@ -1762,7 +1812,7 @@ fn deploy_da_evm(
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to verifier contract's address
/// The verifier contract's address as a hex string
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
@@ -1774,7 +1824,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
///
/// Returns
/// -------
/// bool
@@ -1925,6 +1975,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_vk, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;

View File

@@ -1281,6 +1281,30 @@ impl<T: Clone + TensorType> Tensor<T> {
Ok(t)
}
/// Gets the last element of the tensor as a new 1-element tensor
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<i32>::new(Some(&[3]), &[1]).unwrap();
///
/// assert_eq!(a.last().unwrap(), b);
/// ```
pub fn last(&self) -> Result<Tensor<T>, TensorError>
where
T: Send + Sync,
{
let res = match self.inner.last() {
Some(e) => e.clone(),
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
))
}
};
Tensor::new(Some(&[res]), &[1])
}
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
@@ -1293,7 +1317,7 @@ impl<T: Clone + TensorType> Tensor<T> {
E: Error + std::marker::Send + std::marker::Sync,
>(
&mut self,
filter_indices: &std::collections::HashSet<&usize>,
filter_indices: &std::collections::HashSet<usize>,
f: F,
) -> Result<(), E>
where

View File

@@ -1,12 +1,12 @@
use core::{iter::FilterMap, slice::Iter};
use crate::circuit::region::ConstantsMap;
use maybe_rayon::slice::Iter;
use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
use maybe_rayon::iter::{FilterMap, IntoParallelIterator, ParallelIterator};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -460,7 +460,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
ValTensor::Value { inner, .. } => inner.par_iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
@@ -573,6 +573,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Calls `last` on the inner tensor.
pub fn last(&self) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
inner: v,
dims: _,
scale,
} => {
let inner = v.last()?;
let dims = inner.dims().to_vec();
ValTensor::Value {
inner,
dims,
scale: *scale,
}
}
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}
/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
@@ -753,43 +774,72 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// gets constants
pub fn get_const_zero_indices(&self) -> Result<Vec<usize>, TensorError> {
/// removes constant zero values
pub fn remove_const_zero_values(&mut self) {
match self {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
indices.push(i);
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_par_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
indices.push(i);
}
}
}
Ok(indices)
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => Ok(vec![]),
ValTensor::Instance { .. } => {}
}
}
/// gets constants
pub fn get_const_indices(&self) -> Result<Vec<usize>, TensorError> {
pub fn get_const_zero_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(_) = e {
indices.push(i);
} else if let ValType::AssignedConstant(_, _) = e {
indices.push(i);
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return Some(i);
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return Some(i);
}
}
}
Ok(indices)
}
ValTensor::Instance { .. } => Ok(vec![]),
None
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
}
/// gets constants
pub fn get_const_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(_) = e {
Some(i)
} else if let ValType::AssignedConstant(_, _) = e {
Some(i)
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
}

View File

@@ -319,7 +319,7 @@ impl VarTensor {
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<&usize>,
omissions: &HashSet<usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
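
Across this change, omission sets move from `HashSet<&usize>` to `HashSet<usize>`. Owned indices cost one machine word each to copy, but they free callers from tying the set's lifetime to the intermediate `Vec` the indices came from. An illustrative sketch:

```rust
use std::collections::HashSet;

fn zero_indices() -> Vec<usize> {
    vec![1, 4, 4, 7]
}

fn main() {
    // Owned set: can outlive the Vec it was built from.
    let omissions: HashSet<usize> = zero_indices().into_iter().collect();
    assert!(omissions.contains(&4));
    assert_eq!(omissions.len(), 3); // duplicates collapse for free

    // With HashSet<&usize>, the set would borrow the Vec, so the Vec
    // had to be kept alive (and passed around) alongside the set.
}
```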

View File

@@ -183,12 +183,13 @@ mod native_tests {
const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";
const LARGE_TESTS: [&str; 5] = [
const LARGE_TESTS: [&str; 6] = [
"self_attention",
"nanoGPT",
"multihead_attention",
"mobilenet",
"mnist_gan",
"smallworm",
];
const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -200,7 +201,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 93] = [
const TESTS: [&str; 94] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -298,6 +299,7 @@ mod native_tests {
"1l_lppool",
"lstm_large", // 91
"lstm_medium", // 92
"lenet_5", // 93
];
const WASM_TESTS: [&str; 46] = [
@@ -536,7 +538,7 @@ mod native_tests {
}
});
seq!(N in 0..=92 {
seq!(N in 0..=93 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -940,7 +942,7 @@ mod native_tests {
});
seq!(N in 0..=4 {
seq!(N in 0..=5 {
#(#[test_case(LARGE_TESTS[N])])*
#[ignore]

View File

@@ -423,6 +423,74 @@ async def test_create_evm_verifier():
assert res == True
assert os.path.isfile(sol_code_path)
async def test_create_evm_verifier_separate_vk():
"""
Create an EVM verifier with solidity code and a separate vk
In order to run this test you will need to install solc in your environment
"""
vk_path = os.path.join(folder_path, 'test_evm.vk')
settings_path = os.path.join(folder_path, 'settings.json')
sol_code_path = os.path.join(folder_path, 'test_separate.sol')
vk_code_path = os.path.join(folder_path, 'test_vk.sol')
abi_path = os.path.join(folder_path, 'test_separate.abi')
abi_vk_path = os.path.join(folder_path, 'test_vk_separate.abi')
proof_path = os.path.join(folder_path, 'test_evm.pf')
calldata_path = os.path.join(folder_path, 'calldata.bytes')
# # res is now a vector of bytes
# res = ezkl.encode_evm_calldata(proof_path, calldata_path)
# assert os.path.isfile(calldata_path)
# assert len(res) > 0
res = await ezkl.create_evm_verifier(
vk_path,
settings_path,
sol_code_path,
abi_path,
srs_path=srs_path,
render_vk_seperately=True
)
res = await ezkl.create_evm_vk(
vk_path,
settings_path,
vk_code_path,
abi_vk_path,
srs_path=srs_path,
)
assert res == True
assert os.path.isfile(sol_code_path)
async def test_deploy_evm_separate_vk():
"""
Test deployment of the separate verifier smart contract + vk
In order to run this you will need to install solc in your environment
"""
addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
addr_path_vk = os.path.join(folder_path, 'address_vk.json')
sol_code_path = os.path.join(folder_path, 'test_separate.sol')
vk_code_path = os.path.join(folder_path, 'test_vk.sol')
# TODO: without optimization there will be out of gas errors
# sol_code_path = os.path.join(folder_path, 'test.sol')
res = await ezkl.deploy_evm(
addr_path_verifier,
sol_code_path,
rpc_url=anvil_url,
)
res = await ezkl.deploy_vk_evm(
addr_path_vk,
vk_code_path,
rpc_url=anvil_url,
)
assert res == True
async def test_deploy_evm():
"""
@@ -503,6 +571,47 @@ async def test_verify_evm():
assert res == True
async def test_verify_evm_separate_vk():
"""
Verifies an EVM proof against a separately deployed VK contract
In order to run this you will need to install solc in your environment
"""
proof_path = os.path.join(folder_path, 'test_evm.pf')
addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
addr_path_vk = os.path.join(folder_path, 'address_vk.json')
proof_path = os.path.join(folder_path, 'test_evm.pf')
calldata_path = os.path.join(folder_path, 'calldata_separate.bytes')
with open(addr_path_verifier, 'r') as file:
addr_verifier = file.read().rstrip()
print(addr_verifier)
with open(addr_path_vk, 'r') as file:
addr_vk = file.read().rstrip()
print(addr_vk)
# res is now a vector of bytes
res = ezkl.encode_evm_calldata(proof_path, calldata_path, addr_vk=addr_vk)
assert os.path.isfile(calldata_path)
assert len(res) > 0
# TODO: without optimization there will be out of gas errors
# sol_code_path = os.path.join(folder_path, 'test.sol')
res = await ezkl.verify_evm(
addr_verifier,
proof_path,
rpc_url=anvil_url,
addr_vk=addr_vk,
# sol_code_path
# optimizer_runs
)
assert res == True
async def test_aggregate_and_verify_aggr():
data_path = os.path.join(
@@ -761,6 +870,7 @@ def get_examples():
'accuracy',
'linear_regression',
"mnist_gan",
"smallworm",
]
examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):