Compare commits

...

8 Commits

Author SHA1 Message Date
github-actions[bot]
1066b1ebcf ci: update version string in docs 2025-06-21 17:24:45 +00:00
dante
e81d93a73a chore: display calibration fail reasons (#983) 2025-06-21 19:24:30 +02:00
dante
40ce9dfde9 chore: rm lots of clones (#980) 2025-05-26 10:54:09 -04:00
dante
839030ce10 chore: rm halo2proofs patches (#976) 2025-04-29 10:58:35 -04:00
dante
cfccc5460c refactor: rm postgres (#977) 2025-04-29 08:59:14 -04:00
dante
0de0682bfa refactor: configurable div epsilon (#968) 2025-04-23 09:12:24 +01:00
dante
bf9cf14ab7 refactor!: rpc url should be required (#965)
BREAKING CHANGE: in python the order of arguments for evm related functions has changed
2025-04-22 12:45:36 +01:00
dante
6818962ac2 chore: pass in raw data for gen-witness from file (#964) 2025-04-06 14:08:11 -04:00
65 changed files with 40659 additions and 3709 deletions

View File

@@ -258,7 +258,7 @@ jobs:
- name: Install built wheel
if: matrix.target == 'x86_64-unknown-linux-musl'
uses: addnab/docker-run-action@v3
uses: addnab/docker-run-action@3e77f186b7a929ef010f183a9e24c0f9955ea609
with:
image: alpine:latest
options: -v ${{ github.workspace }}:/io -w /io
@@ -380,7 +380,7 @@ jobs:
with:
persist-credentials: false
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
uses: dfm/rtds-action@618148c547f4b56cdf4fa4dcf3a94c91ce025f2d
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}

View File

@@ -198,15 +198,15 @@ jobs:
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal,mimalloc
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'

View File

@@ -33,7 +33,7 @@ jobs:
toolchain: nightly-2025-02-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
- uses: baptiste0928/cargo-install@91c5da15570085bcde6f4d7aed98cb82d6769fd3
with:
crate: cargo-nextest
locked: true
@@ -233,7 +233,7 @@ jobs:
with:
# Pin to version 0.12.1
version: "v0.12.1"
- uses: nanasess/setup-chromedriver@e93e57b843c0c92788f22483f1a31af8ee48db25 #v2.3.0
- uses: nanasess/setup-chromedriver@affb1ea8848cbb080be372c1e8d7a5c173e9298f #v2.3.0
# with:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
@@ -245,25 +245,6 @@ jobs:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
foudry-solidity-tests:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
persist-credentials: false
submodules: recursive
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@v1
- name: Run tests
run: |
cd tests/foundry
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
forge test -vvvv --fuzz-runs 64
mock-proving-tests:
permissions:
contents: read
@@ -372,9 +353,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
@@ -497,9 +475,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
@@ -744,24 +719,6 @@ jobs:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
@@ -798,8 +755,6 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli

2589
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
[package]
name = "ezkl"
version = "0.0.0"
edition = "2024"
edition = "2021"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -69,20 +69,18 @@ reqwest = { version = "0.12.4", default-features = false, features = [
"stream",
], optional = true }
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.23.2", features = [
pyo3 = { version = "0.24.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.24.0", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
@@ -90,9 +88,9 @@ pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
@@ -242,12 +240,9 @@ ezkl = [
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:clap_complete",
@@ -274,22 +269,19 @@ no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
jemalloc = ["dep:jemallocator"]
mimalloc = ["dep:mimalloc"]
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
#panic = "abort"
# panic = "abort"
[profile.test-runs]

View File

@@ -68,7 +68,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
&[&self.image, &self.kernel, &self.bias],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -61,7 +62,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -86,13 +87,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -87,13 +88,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.unwrap();

View File

@@ -63,7 +63,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone()],
&[&self.image],
Box::new(HybridOp::SumPool {
padding: vec![(0, 0); 2],
stride: vec![1, 1],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -57,7 +58,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.unwrap();
Ok(())
},

View File

@@ -16,6 +16,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -58,7 +59,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(4)),
)
.unwrap();
Ok(())
},

View File

@@ -70,7 +70,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,

View File

@@ -67,7 +67,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '22.0.5'
version = release

View File

@@ -32,7 +32,6 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -216,11 +215,7 @@ where
.layer_config
.layout(
&mut region,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
&[&self.input, &self.l0_params[0], &self.l0_params[1]],
Box::new(op),
)
.unwrap();
@@ -229,7 +224,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -241,7 +236,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div { denom: 32.0.into() }),
)
.unwrap()
@@ -253,7 +248,7 @@ where
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone(), x],
&[&self.l2_params[0], &x],
Box::new(PolyOp::Einsum {
equation: "ij,j->ik".to_string(),
}),
@@ -265,7 +260,7 @@ where
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone()],
&[&x, &self.l2_params[1]],
Box::new(PolyOp::Add),
)
.unwrap()

View File

@@ -117,10 +117,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[
self.l0_params[0].clone().try_into().unwrap(),
self.input.clone(),
],
&[&self.l0_params[0].clone().try_into().unwrap(), &self.input],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -135,7 +132,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l0_params[1].clone().try_into().unwrap()],
&[&x, &self.l0_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -147,7 +144,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -163,7 +160,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone().try_into().unwrap(), x],
&[&self.l2_params[0].clone().try_into().unwrap(), &x],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -178,7 +175,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone().try_into().unwrap()],
&[&x, &self.l2_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -190,7 +187,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -203,7 +200,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div {
denom: ezkl::circuit::utils::F32::from(128.),
}),

View File

@@ -1088,7 +1088,7 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" rpc_url='http://127.0.0.1:3030'\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",

View File

@@ -472,8 +472,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -526,9 +526,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -557,8 +557,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -566,7 +566,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "ezkl",
"display_name": ".env",
"language": "python",
"name": "python3"
},
@@ -580,7 +580,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.9"
},
"orig_nbformat": 4
},

View File

@@ -543,8 +543,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -597,9 +597,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -628,8 +628,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -651,7 +651,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.9"
},
"orig_nbformat": 4
},

View File

@@ -474,8 +474,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -529,9 +529,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -560,8 +560,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]

View File

@@ -453,8 +453,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -474,8 +474,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -510,4 +510,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -462,8 +462,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -483,8 +483,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -512,4 +512,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -1,462 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.decomp_legs = 4\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True\n",
"\n",
"await ezkl.get_srs(settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -504,8 +504,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -527,8 +527,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -558,4 +558,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
}
}

View File

@@ -261,8 +261,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" \"verifier/reusable\"\n",
")\n",
"\n",
@@ -288,7 +288,7 @@
"for name in names:\n",
" addr_path_vk = \"addr_vk.txt\"\n",
" sol_key_code_path = os.path.join(name, 'test_key.sol')\n",
" res = await ezkl.deploy_evm(addr_path_vk, sol_key_code_path, 'http://127.0.0.1:3030', \"vka\")\n",
" res = await ezkl.deploy_evm(addr_path_vk, 'http://127.0.0.1:3030', sol_key_code_path, \"vka\")\n",
" assert res == True\n",
"\n",
" with open(addr_path_vk, 'r') as file:\n",
@@ -298,8 +298,8 @@
" sol_code_path = os.path.join(name, 'vk.sol')\n",
" res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" addr_vk = addr_vk\n",
" )\n",
" assert res == True"

View File

@@ -562,8 +562,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -616,9 +616,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )\n"
]
},
@@ -653,8 +653,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]

View File

@@ -666,7 +666,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -689,8 +689,8 @@
"# await\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -701,7 +701,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": null,
"metadata": {},
"outputs": [
{
@@ -722,8 +722,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
")\n",
"assert res == True"
]
@@ -743,7 +743,8 @@
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
@@ -756,7 +757,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.2"
"version": "3.12.9"
}
},
"nbformat": 4,

View File

@@ -849,8 +849,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" address_path,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True\n",
@@ -870,8 +870,8 @@
"\n",
"res = await ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" \"http://127.0.0.1:3030\"\n",
" \"http://127.0.0.1:3030\",\n",
" proof_path\n",
")\n",
"assert res == True"
]
@@ -905,4 +905,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -358,8 +358,8 @@
"\n",
"res = await ezkl.deploy_evm(\n",
" addr_path_verifier,\n",
" 'http://127.0.0.1:3030',\n",
" sol_code_path,\n",
" 'http://127.0.0.1:3030'\n",
")\n",
"\n",
"assert res == True"
@@ -405,9 +405,9 @@
"res = await ezkl.deploy_da_evm(\n",
" addr_path_da,\n",
" input_path,\n",
" RPC_URL,\n",
" settings_path,\n",
" sol_code_path,\n",
" RPC_URL,\n",
" )"
]
},
@@ -470,8 +470,8 @@
"\n",
"res = ezkl.verify_evm(\n",
" addr,\n",
" proof_path,\n",
" RPC_URL,\n",
" proof_path,\n",
" addr_da,\n",
")"
]
@@ -531,7 +531,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.12.9"
}
},
"nbformat": 4,

Binary file not shown.

37795
examples/onnx/fr_age/lol.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,5 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]

View File

@@ -1,34 +1,34 @@
use crate::Commitments;
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::modules::Module;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::commands::*;
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::TestDataSource;
use crate::graph::{
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction,
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;
@@ -206,6 +206,9 @@ struct PyRunArgs {
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
#[pyo3(get, set)]
pub ignore_range_check_inputs_outputs: bool,
/// float: epsilon used for arguments that use division
#[pyo3(get, set)]
pub epsilon: f64,
}
/// default instantiation of PyRunArgs
@@ -238,12 +241,14 @@ impl From<PyRunArgs> for RunArgs {
decomp_base: py_run_args.decomp_base,
decomp_legs: py_run_args.decomp_legs,
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
epsilon: Some(py_run_args.epsilon),
}
}
}
impl Into<PyRunArgs> for RunArgs {
fn into(self) -> PyRunArgs {
let eps = self.get_epsilon();
PyRunArgs {
bounded_log_lookup: self.bounded_log_lookup,
input_scale: self.input_scale,
@@ -262,6 +267,7 @@ impl Into<PyRunArgs> for RunArgs {
decomp_base: self.decomp_base,
decomp_legs: self.decomp_legs,
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
epsilon: eps,
}
}
}
@@ -962,7 +968,7 @@ fn gen_settings(
output=PathBuf::from(DEFAULT_SETTINGS),
variables=Vec::from([("batch_size".to_string(), 1)]),
seed=DEFAULT_SEED.parse().unwrap(),
min=None,
min=None,
max=None
))]
#[gen_stub_pyfunction]
@@ -1823,7 +1829,7 @@ fn create_evm_data_attestation(
test_data,
input_source,
output_source,
rpc_url=None
rpc_url,
))]
#[gen_stub_pyfunction]
fn setup_test_evm_data(
@@ -1833,7 +1839,7 @@ fn setup_test_evm_data(
test_data: PathBuf,
input_source: PyTestDataSource,
output_source: PyTestDataSource,
rpc_url: Option<String>,
rpc_url: String,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
crate::execute::setup_test_evm_data(
@@ -1857,8 +1863,8 @@ fn setup_test_evm_data(
/// deploys the solidity verifier
#[pyfunction(signature = (
addr_path,
rpc_url,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
rpc_url=None,
contract_type=ContractType::default(),
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None,
@@ -1867,8 +1873,8 @@ fn setup_test_evm_data(
fn deploy_evm(
py: Python,
addr_path: PathBuf,
rpc_url: String,
sol_code_path: PathBuf,
rpc_url: Option<String>,
contract_type: ContractType,
optimizer_runs: usize,
private_key: Option<String>,
@@ -1896,9 +1902,9 @@ fn deploy_evm(
#[pyfunction(signature = (
addr_path,
input_data,
rpc_url,
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE_DA),
rpc_url=None,
optimizer_runs=DEFAULT_OPTIMIZER_RUNS.parse().unwrap(),
private_key=None
))]
@@ -1907,9 +1913,9 @@ fn deploy_da_evm(
py: Python,
addr_path: PathBuf,
input_data: String,
rpc_url: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
optimizer_runs: usize,
private_key: Option<String>,
) -> PyResult<Bound<'_, PyAny>> {
@@ -1956,8 +1962,8 @@ fn deploy_da_evm(
///
#[pyfunction(signature = (
addr_verifier,
rpc_url,
proof_path=PathBuf::from(DEFAULT_PROOF),
rpc_url=None,
addr_da = None,
addr_vk = None,
))]
@@ -1965,8 +1971,8 @@ fn deploy_da_evm(
fn verify_evm<'a>(
py: Python<'a>,
addr_verifier: &'a str,
rpc_url: String,
proof_path: PathBuf,
rpc_url: Option<String>,
addr_da: Option<&'a str>,
addr_vk: Option<&'a str>,
) -> PyResult<Bound<'a, PyAny>> {

View File

@@ -962,7 +962,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn layout(
&mut self,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)

View File

@@ -15,10 +15,12 @@ use serde::{Deserialize, Serialize};
pub enum HybridOp {
Ln {
scale: utils::F32,
eps: f64,
},
Rsqrt {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Sqrt {
scale: utils::F32,
@@ -42,6 +44,7 @@ pub enum HybridOp {
Recip {
input_scale: utils::F32,
output_scale: utils::F32,
eps: f64,
},
Div {
denom: utils::F32,
@@ -77,6 +80,7 @@ pub enum HybridOp {
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
eps: f64,
},
Output {
decomp: bool,
@@ -105,13 +109,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
HybridOp::Greater
| HybridOp::Less
| HybridOp::Equals
| HybridOp::GreaterEqual
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual { .. } => {
| HybridOp::LessEqual => {
vec![0, 1]
}
_ => vec![],
@@ -128,12 +132,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => format!(
"RSQRT (input_scale={}, output_scale={})",
input_scale, output_scale
"RSQRT (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
HybridOp::Ln { scale, eps } => format!("LN(scale={}, eps={})", scale, eps),
HybridOp::RoundHalfToEven { scale, legs } => {
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
}
@@ -146,16 +151,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => format!(
"RECIP (input_scale={}, output_scale={})",
input_scale, output_scale
"RECIP (input_scale={}, output_scale={}, eps={})",
input_scale, output_scale, eps
),
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
HybridOp::SumPool {
padding,
stride,
kernel_shape,
normalized, data_format
normalized,
data_format,
} => format!(
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
padding, stride, kernel_shape, normalized, data_format
@@ -177,10 +184,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
"SOFTMAX (input_scale={}, output_scale={}, axes={:?}, eps={})",
input_scale, output_scale, axes, eps
)
}
HybridOp::Output { decomp } => {
@@ -205,23 +213,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
input_scale,
output_scale,
eps,
} => layouts::rsqrt(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
*eps,
)?,
HybridOp::Sqrt { scale } => {
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
}
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
HybridOp::Ln { scale, eps } => {
layouts::ln(config, region, values[..].try_into()?, *scale, *eps)?
}
HybridOp::RoundHalfToEven { scale, legs } => {
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
}
@@ -255,12 +267,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
HybridOp::Recip {
input_scale,
output_scale,
eps,
} => layouts::recip(
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as IntegerRep),
integer_rep_to_felt(output_scale.0 as IntegerRep),
*eps,
)?,
HybridOp::Div { denom, .. } => {
if denom.0.fract() == 0.0 {
@@ -317,6 +331,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
input_scale,
output_scale,
axes,
eps,
} => layouts::softmax_axes(
config,
region,
@@ -324,6 +339,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
*input_scale,
*output_scale,
axes,
*eps,
)?,
HybridOp::Output { decomp } => {
layouts::output(config, region, values[..].try_into()?, *decomp)?
@@ -346,10 +362,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Less { .. }
| HybridOp::LessEqual { .. }
HybridOp::Greater
| HybridOp::GreaterEqual
| HybridOp::Less
| HybridOp::LessEqual
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
@@ -364,6 +380,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
HybridOp::Ln {
scale: output_scale,
eps: _,
} => 4 * multiplier_to_scale(output_scale.0 as f64),
_ => in_scales[0],
};

File diff suppressed because it is too large Load Diff

View File

@@ -186,7 +186,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,

View File

@@ -49,7 +49,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
@@ -223,12 +223,29 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
true,
)?))
}
_ => Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
_ => {
if self.decomp {
log::debug!("constraining input to be decomp");
Ok(Some(
super::layouts::decompose(
config,
region,
values[..].try_into()?,
&region.base(),
&region.legs(),
false,
)?
.1,
))
} else {
log::debug!("constraining input to be identity");
Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
)?))
}
}
}
} else {
Ok(Some(value))
@@ -263,7 +280,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
&self,
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
@@ -319,8 +336,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
@@ -333,20 +355,20 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
self.quantized_values.clone().try_into()?
};
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
Ok(Some(if self.decomp {
log::debug!("constraining constant to be decomp");
super::layouts::decompose(config, region, &[&value], &region.base(), &region.legs(), false)?.1
} else {
log::debug!("constraining constant to be identity");
super::layouts::identity(config, region, &[&value])?
}))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -108,8 +108,13 @@ pub enum PolyOp {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -203,11 +208,11 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?, true)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
@@ -335,9 +340,7 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -416,14 +419,14 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign { .. } => 0,
PolyOp::Sign => 0,
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
if matches!(self, PolyOp::Add | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]

View File

@@ -10,7 +10,6 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use std::{
cell::RefCell,
@@ -462,15 +461,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
}
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
let int_eval = inputs.int_evals()?;
let max = int_eval.iter().max().unwrap_or(&0);
let min = int_eval.iter().min().unwrap_or(&0);
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(*max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(*min);
Ok(())
}
@@ -505,10 +503,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// add used lookup
pub fn add_used_lookup(
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
lookup: &LookupOp,
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup);
self.statistics.used_lookups.insert(lookup.clone());
self.update_max_min_lookup_inputs(inputs)
}
@@ -642,34 +640,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)?)
} else {
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
let values_map = values.create_constants_map();
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,

View File

@@ -9,6 +9,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
use itertools::Itertools;
#[cfg(not(any(
all(target_arch = "wasm32", target_os = "unknown"),
not(feature = "ezkl")
@@ -64,7 +65,7 @@ mod matmul {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -141,7 +142,7 @@ mod matmul_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -215,7 +216,7 @@ mod matmul_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -302,7 +303,7 @@ mod matmul_col_ultra_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -380,6 +381,7 @@ mod matmul_col_ultra_overflow {
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -422,7 +424,7 @@ mod matmul_col_ultra_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -533,7 +535,7 @@ mod dot {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -610,7 +612,7 @@ mod dot_col_overflow_triple_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -683,7 +685,7 @@ mod dot_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -756,7 +758,7 @@ mod sum {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -826,7 +828,7 @@ mod sum_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -895,7 +897,7 @@ mod sum_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -966,7 +968,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -975,7 +977,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -984,7 +986,7 @@ mod composition {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -1061,7 +1063,7 @@ mod conv {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1218,7 +1220,7 @@ mod conv_col_ultra_overflow {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1377,7 +1379,7 @@ mod conv_relu_col_ultra_overflow {
let output = config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1390,7 +1392,7 @@ mod conv_relu_col_ultra_overflow {
let _output = config
.layout(
&mut region,
&[output.unwrap().unwrap()],
&[&output.unwrap().unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -1517,7 +1519,11 @@ mod add_w_shape_casting {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1584,7 +1590,11 @@ mod add {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1671,8 +1681,8 @@ mod dynamic_lookup {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
&self.lookups[i].iter().collect_vec().try_into().unwrap(),
&self.tables[i].iter().collect_vec().try_into().unwrap(),
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1767,8 +1777,8 @@ mod shuffle {
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
inputs: [ValTensor<F>; NUM_LOOP],
references: [ValTensor<F>; NUM_LOOP],
_marker: PhantomData<F>,
}
@@ -1818,15 +1828,15 @@ mod shuffle {
layouts::shuffles(
&config,
&mut region,
&self.inputs[i],
&self.references[i],
&[&self.inputs[i]],
&[&self.references[i]],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0][0].len()
NUM_LOOP * self.references[0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
@@ -1843,17 +1853,19 @@ mod shuffle {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1873,9 +1885,11 @@ mod shuffle {
} else {
loop_idx - 1
};
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1931,7 +1945,11 @@ mod add_with_overflow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2026,7 +2044,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
let inputs = vec![assigned_inputs_a, assigned_inputs_b];
let inputs = vec![&assigned_inputs_a, &assigned_inputs_b];
layouter.assign_region(
|| "model",
@@ -2135,7 +2153,11 @@ mod sub {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sub),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2202,7 +2224,11 @@ mod mult {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Mult),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2269,7 +2295,11 @@ mod pow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(5)),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2360,13 +2390,13 @@ mod matmul_relu {
};
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2465,7 +2495,7 @@ mod relu {
Ok(config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2563,7 +2593,7 @@ mod lookup_ultra_overflow {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.map_err(|_| Error::Synthesis)

View File

@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{Generator, Shell, generate};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -382,6 +382,42 @@ pub struct Cli {
pub command: Option<Commands>,
}
/// Custom parser for data field that handles both direct JSON strings and file paths with '@' prefix
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, PartialOrd)]
pub struct DataField(pub String);
impl FromStr for DataField {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the input starts with '@'
if let Some(file_path) = s.strip_prefix('@') {
// Extract the file path (remove the '@' prefix)
// Read the file content
let content = std::fs::read_to_string(file_path)
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;
// Return the file content as the data field value
Ok(DataField(content))
} else {
// Use the input string directly
Ok(DataField(s.to_string()))
}
}
}
impl ToFlags for DataField {
fn to_flags(&self) -> Vec<String> {
vec![self.0.clone()]
}
}
impl std::fmt::Display for DataField {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
#[allow(missing_docs)]
#[derive(Debug, Subcommand, Clone, Deserialize, Serialize, PartialEq, PartialOrd, ToSubcommand)]
pub enum Commands {
@@ -400,9 +436,9 @@ pub enum Commands {
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_hint = clap::ValueHint::FilePath)]
data: Option<String>,
/// The path to the .json data file (with @ prefix) or a raw data string of the form '{"input_data": [[1, 2, 3]]}'
#[arg(short = 'D', long, default_value = DEFAULT_DATA, value_parser = DataField::from_str)]
data: Option<DataField>,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long, default_value = DEFAULT_COMPILED_CIRCUIT, value_hint = clap::ValueHint::FilePath)]
compiled_circuit: Option<PathBuf>,
@@ -647,9 +683,9 @@ pub enum Commands {
/// Should include both the network input (possibly private) and the network output (public input to the proof)
#[arg(short = 'T', long, value_hint = clap::ValueHint::FilePath)]
test_data: PathBuf,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
/// where the input data come from
#[arg(long, default_value = "on-chain", value_hint = clap::ValueHint::Other)]
input_source: TestDataSource,
@@ -844,9 +880,9 @@ pub enum Commands {
/// The path to the Solidity code (generated using the create-evm-verifier command)
#[arg(long, default_value = DEFAULT_SOL_CODE, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Url)]
rpc_url: String,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
@@ -872,9 +908,9 @@ pub enum Commands {
/// The path to the Solidity code
#[arg(long, default_value = DEFAULT_SOL_CODE_DA, value_hint = clap::ValueHint::FilePath)]
sol_code_path: Option<PathBuf>,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS_DA, value_hint = clap::ValueHint::FilePath)]
/// The path to output the contract address
addr_path: Option<PathBuf>,
@@ -894,9 +930,9 @@ pub enum Commands {
/// The path to verifier contract's address
#[arg(long, default_value = DEFAULT_CONTRACT_ADDRESS, value_hint = clap::ValueHint::Other)]
addr_verifier: H160Flag,
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
/// RPC URL for an Ethereum node
#[arg(short = 'U', long, value_hint = clap::ValueHint::Url)]
rpc_url: Option<String>,
rpc_url: String,
/// does the verifier use data attestation ?
#[arg(long, value_hint = clap::ValueHint::Other)]
addr_da: Option<H160Flag>,

View File

@@ -1,25 +1,24 @@
use crate::graph::DataSource;
use crate::graph::GraphSettings;
use crate::graph::input::{CallToAccount, CallsToAccount, FileSourceInner, GraphData};
use crate::graph::modules::POSEIDON_INSTANCES;
use crate::pfsys::Snark;
use crate::graph::DataSource;
use crate::graph::GraphSettings;
use crate::pfsys::evm::EvmVerificationError;
use crate::pfsys::Snark;
use alloy::contract::CallBuilder;
use alloy::core::primitives::Address as H160;
use alloy::core::primitives::Bytes;
use alloy::core::primitives::U256;
use alloy::dyn_abi::abi::TokenSeq;
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
use alloy::dyn_abi::abi::TokenSeq;
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::node_bindings::Anvil;
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{B256, I256, ParseSignedError};
use alloy::providers::ProviderBuilder;
use alloy::primitives::{ParseSignedError, B256, I256};
use alloy::providers::fillers::{
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
};
use alloy::providers::network::{Ethereum, EthereumSigner};
use alloy::providers::ProviderBuilder;
use alloy::providers::{Identity, Provider, RootProvider};
use alloy::rpc::types::eth::TransactionInput;
use alloy::rpc::types::eth::TransactionRequest;
@@ -28,9 +27,9 @@ use alloy::signers::wallet::{LocalWallet, WalletError};
use alloy::sol as abigen;
use alloy::transports::http::Http;
use alloy::transports::{RpcError, TransportErrorKind};
use foundry_compilers::Solc;
use foundry_compilers::artifacts::Settings as SolcSettings;
use foundry_compilers::error::{SolcError, SolcIoError};
use foundry_compilers::Solc;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
use halo2curves::group::ff::PrimeField;
@@ -313,25 +312,12 @@ pub type ContractFactory<M> = CallBuilder<Http<Client>, Arc<M>, ()>;
/// Return an instance of Anvil and a client for the given RPC URL. If none is provided, a local client is used.
pub async fn setup_eth_backend(
rpc_url: Option<&str>,
rpc_url: &str,
private_key: Option<&str>,
) -> Result<(EthersClient, alloy::primitives::Address), EthError> {
// Launch anvil
let endpoint: String;
if let Some(rpc_url) = rpc_url {
endpoint = rpc_url.to_string();
} else {
let anvil = Anvil::new()
.args([
"--code-size-limit=41943040",
"--disable-block-gas-limit",
"-p",
"8545",
])
.spawn();
endpoint = anvil.endpoint();
}
let endpoint = rpc_url.to_string();
// Instantiate the wallet
let wallet: LocalWallet;
@@ -365,7 +351,7 @@ pub async fn setup_eth_backend(
///
pub async fn deploy_contract_via_solidity(
sol_code_path: PathBuf,
rpc_url: Option<&str>,
rpc_url: &str,
runs: usize,
private_key: Option<&str>,
contract_name: &str,
@@ -387,7 +373,7 @@ pub async fn deploy_da_verifier_via_solidity(
settings_path: PathBuf,
input: String,
sol_code_path: PathBuf,
rpc_url: Option<&str>,
rpc_url: &str,
runs: usize,
private_key: Option<&str>,
) -> Result<H160, EthError> {
@@ -574,7 +560,7 @@ pub async fn verify_proof_via_solidity(
proof: Snark<Fr, G1Affine>,
addr: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
rpc_url: &str,
) -> Result<bool, EthError> {
let flattened_instances = proof.instances.into_iter().flatten();
@@ -676,7 +662,7 @@ pub async fn verify_proof_with_data_attestation(
addr_verifier: H160,
addr_da: H160,
addr_vk: Option<H160>,
rpc_url: Option<&str>,
rpc_url: &str,
) -> Result<bool, EthError> {
use ethabi::{Function, Param, ParamType, StateMutability, Token};
@@ -725,7 +711,7 @@ pub async fn verify_proof_with_data_attestation(
};
let encoded = func.encode_input(&[
Token::Address(addr_verifier.0.0.into()),
Token::Address(addr_verifier.0 .0.into()),
Token::Bytes(encoded_verifier),
])?;
@@ -771,7 +757,7 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
let call_to_account = CallToAccount {
call_data: hex::encode(call),
decimals,
address: hex::encode(contract.address().0.0),
address: hex::encode(contract.address().0 .0),
};
info!("call_to_account: {:#?}", call_to_account);
Ok(call_to_account)
@@ -927,9 +913,9 @@ pub async fn get_contract_artifacts(
runs: usize,
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
use foundry_compilers::{
SHANGHAI_SOLC, SolcInput,
artifacts::{Optimizer, output_selection::OutputSelection},
artifacts::{output_selection::OutputSelection, Optimizer},
compilers::CompilerInput,
SolcInput, SHANGHAI_SOLC,
};
if !sol_code_path.exists() {
@@ -1015,7 +1001,7 @@ pub fn fix_da_sol(commitment_bytes: Option<Vec<u8>>, only_kzg: bool) -> Result<S
require(checkKzgCommits(encoded), "Invalid KZG commitments");
// static call the verifier contract to verify the proof
(bool success, bytes memory returndata) = verifier.staticcall(encoded);
if (success) {
return abi.decode(returndata, (bool));
} else {

View File

@@ -1,6 +1,5 @@
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
#[allow(unused_imports)]
@@ -10,21 +9,21 @@ use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
};
use crate::pfsys::{
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -37,6 +36,7 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -48,12 +48,12 @@ use itertools::Itertools;
use lazy_static::lazy_static;
use log::debug;
use log::{info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde::Serialize;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -176,7 +176,7 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
srs_path,
} => gen_witness(
compiled_circuit.unwrap_or(DEFAULT_COMPILED_CIRCUIT.into()),
data.unwrap_or(DEFAULT_DATA.into()),
data.unwrap_or(DataField(DEFAULT_DATA.into())).to_string(),
Some(output.unwrap_or(DEFAULT_WITNESS.into())),
vk_path,
srs_path,
@@ -1139,6 +1139,7 @@ pub(crate) async fn calibrate(
let mut num_failed = 0;
let mut num_passed = 0;
let mut failure_reasons = vec![];
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
pb.set_message(format!(
@@ -1163,20 +1164,21 @@ pub(crate) async fn calibrate(
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
error!("circuit creation from run args failed: {:?}", e);
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
e
));
pb.inc(1);
num_failed += 1;
continue;
@@ -1221,6 +1223,13 @@ pub(crate) async fn calibrate(
error!("forward pass failed: {:?}", e);
pb.inc(1);
num_failed += 1;
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
e
));
continue;
}
}
@@ -1286,7 +1295,14 @@ pub(crate) async fn calibrate(
found_settings.as_json()?.to_colored_json_auto()?
);
num_passed += 1;
} else {
} else if let Err(res) = res {
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
res.to_string().red()
));
num_failed += 1;
}
@@ -1296,6 +1312,13 @@ pub(crate) async fn calibrate(
pb.finish_with_message("Calibration Done.");
if found_params.is_empty() {
if !failure_reasons.is_empty() {
error!("Calibration failed for the following reasons:");
for reason in failure_reasons {
error!("{}", reason);
}
}
return Err("calibration failed, could not find any suitable parameters given the calibration dataset".into());
}
@@ -1329,7 +1352,7 @@ pub(crate) async fn calibrate(
.clone()
}
CalibrationTarget::Accuracy => {
let param_iterator = found_params.iter().sorted_by_key(|p| {
let mut param_iterator = found_params.iter().sorted_by_key(|p| {
(
p.run_args.input_scale,
p.run_args.param_scale,
@@ -1338,7 +1361,7 @@ pub(crate) async fn calibrate(
)
});
let last = param_iterator.last().ok_or("no params found")?;
let last = param_iterator.next_back().ok_or("no params found")?;
let max_scale = (
last.run_args.input_scale,
last.run_args.param_scale,
@@ -1608,7 +1631,7 @@ pub(crate) async fn deploy_da_evm(
data: String,
settings_path: PathBuf,
sol_code_path: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
@@ -1617,7 +1640,7 @@ pub(crate) async fn deploy_da_evm(
settings_path,
data,
sol_code_path,
rpc_url.as_deref(),
&rpc_url,
runs,
private_key.as_deref(),
)
@@ -1632,7 +1655,7 @@ pub(crate) async fn deploy_da_evm(
pub(crate) async fn deploy_evm(
sol_code_path: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
addr_path: PathBuf,
runs: usize,
private_key: Option<String>,
@@ -1645,7 +1668,7 @@ pub(crate) async fn deploy_evm(
};
let contract_address = deploy_contract_via_solidity(
sol_code_path,
rpc_url.as_deref(),
&rpc_url,
runs,
private_key.as_deref(),
contract_name,
@@ -1688,7 +1711,7 @@ pub(crate) fn encode_evm_calldata(
pub(crate) async fn verify_evm(
proof_path: PathBuf,
addr_verifier: H160Flag,
rpc_url: Option<String>,
rpc_url: String,
addr_da: Option<H160Flag>,
addr_vk: Option<H160Flag>,
) -> Result<String, EZKLError> {
@@ -1702,7 +1725,7 @@ pub(crate) async fn verify_evm(
addr_verifier.into(),
addr_da.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
&rpc_url,
)
.await?
} else {
@@ -1710,7 +1733,7 @@ pub(crate) async fn verify_evm(
proof.clone(),
addr_verifier.into(),
addr_vk.map(|s| s.into()),
rpc_url.as_deref(),
&rpc_url,
)
.await?
};
@@ -1851,7 +1874,7 @@ pub(crate) async fn setup_test_evm_data(
data_path: String,
compiled_circuit_path: PathBuf,
test_data: PathBuf,
rpc_url: Option<String>,
rpc_url: String,
input_source: TestDataSource,
output_source: TestDataSource,
) -> Result<String, EZKLError> {

View File

@@ -67,8 +67,11 @@ pub enum GraphError {
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
#[error("missing result for node {0}")]
MissingResults(usize),
/// Missing input
#[error("missing input {0}")]
MissingInputForNode(usize),
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),
@@ -98,8 +101,6 @@ pub enum GraphError {
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(all(
feature = "ezkl",
@@ -141,7 +142,9 @@ pub enum GraphError {
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
///Cannot use on-chain data source as private data
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
#[error(
"cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm."
)]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]

View File

@@ -2,8 +2,6 @@ use super::errors::GraphError;
use super::quantize_float;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::postgres::Client;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
@@ -23,9 +21,6 @@ use tract_onnx::tract_core::{
value::TValue,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
type Decimals = u8;
type Call = String;
type RPCUrl = String;
@@ -201,9 +196,9 @@ impl OnChainSource {
data: &FileSource,
scales: Vec<crate::Scale>,
mut shapes: Vec<Vec<usize>>,
rpc: Option<&str>,
rpc: &str,
) -> Result<Self, GraphError> {
use crate::eth::{read_on_chain_inputs, test_on_chain_data, DEFAULT_ANVIL_ENDPOINT};
use crate::eth::{read_on_chain_inputs, test_on_chain_data};
use log::debug;
// Set up local anvil instance for reading on-chain data
@@ -217,7 +212,7 @@ impl OnChainSource {
shapes[idx] = vec![i.len()];
}
}
let used_rpc = rpc.unwrap_or(DEFAULT_ANVIL_ENDPOINT).to_string();
let used_rpc = rpc.to_string();
let call_to_account = test_on_chain_data(client.clone(), data).await?;
debug!("Call to account: {:?}", call_to_account);
@@ -260,9 +255,6 @@ pub enum DataSource {
File(FileSource),
/// Data fetched from blockchain contracts
OnChain(OnChainSource),
/// Data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DB(PostgresSource),
}
impl Default for DataSource {
@@ -323,15 +315,6 @@ impl<'de> Deserialize<'de> for DataSource {
return Ok(DataSource::OnChain(t));
}
// Try deserializing as PostgresSource if feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = third_try {
return Ok(DataSource::DB(t));
}
}
Err(serde::de::Error::custom("failed to deserialize DataSource"))
}
}
@@ -381,7 +364,7 @@ impl GraphData {
return Err(GraphError::InvalidDims(
0,
"non file data cannot be split into batches".to_string(),
))
));
}
}
Ok(inputs)
@@ -434,17 +417,15 @@ impl GraphData {
/// Loads graph input data from a string, first seeing if it is a file path or JSON data
/// If it is a file path, it will load the data from the file
/// Otherwise, it will attempt to parse the string as JSON data
///
///
/// # Arguments
/// * `data` - String containing the input data
/// # Returns
/// A new GraphData instance containing the loaded data
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => {
return Ok(graph_input);
}
Ok(graph_input) => Ok(graph_input),
Err(_) => {
let path = std::path::PathBuf::from(data);
GraphData::from_path(path)
@@ -515,13 +496,8 @@ impl GraphData {
return Err(GraphError::InvalidDims(
0,
"on-chain data cannot be split into batches".to_string(),
))
));
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
GraphData {
input_data: DataSource::DB(data),
output_data: _,
} => data.fetch_and_format_as_file().await?,
};
// Process each input tensor according to its shape
@@ -538,7 +514,6 @@ impl GraphData {
input.len(),
input_size
),
));
}
@@ -592,28 +567,6 @@ impl GraphData {
mod tests {
use super::*;
#[test]
fn test_postgres_source_new() {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let source = PostgresSource::new(
"localhost".to_string(),
"5432".to_string(),
"user".to_string(),
"SELECT * FROM table".to_string(),
"database".to_string(),
"password".to_string(),
);
assert_eq!(source.host, "localhost");
assert_eq!(source.port, "5432");
assert_eq!(source.user, "user");
assert_eq!(source.query, "SELECT * FROM table");
assert_eq!(source.dbname, "database");
assert_eq!(source.password, "password");
}
}
#[test]
fn test_data_source_serialization_round_trip() {
// Test backwards compatibility with old format
@@ -656,95 +609,6 @@ mod tests {
}
}
/// Source data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// Database host address
pub host: RPCUrl,
/// Database user name
pub user: String,
/// Database password
pub password: String,
/// SQL query to execute
pub query: String,
/// Database name
pub dbname: String,
/// Database port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Creates a new PostgreSQL data source
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Fetches data from the PostgreSQL database
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// Configuration string
let config = if self.password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
self.host, self.user, self.dbname, self.port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
self.host, self.user, self.dbname, self.port, self.password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// Extract rows from query
for row in client.query(&self.query, &[]).await? {
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetches and formats data as FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallToAccount {
fn to_object(&self, py: Python) -> PyObject {
@@ -768,14 +632,6 @@ impl ToPyObject for DataSource {
.unwrap();
dict.to_object(py)
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DataSource::DB(source) => {
let dict = PyDict::new(py);
dict.set_item("host", &source.host).unwrap();
dict.set_item("user", &source.user).unwrap();
dict.set_item("query", &source.query).unwrap();
dict.to_object(py)
}
}
}
}

View File

@@ -6,9 +6,6 @@ pub mod model;
pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
@@ -574,7 +571,7 @@ impl GraphSettings {
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})
}
/// load params from file
@@ -584,7 +581,7 @@ impl GraphSettings {
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})?;
crate::check_version_string_matches(&settings.version);
@@ -764,7 +761,7 @@ pub struct TestOnChainData {
/// The path to the test witness
pub data: std::path::PathBuf,
/// rpc endpoint
pub rpc: Option<String>,
pub rpc: String,
/// data sources for the on chain data
pub data_sources: TestSources,
}
@@ -1011,10 +1008,6 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::DB(pg) => {
let data = pg.fetch_and_format_as_file().await?;
self.load_file_data(&data, &shapes, scales, input_types)
}
}
}
@@ -1027,7 +1020,7 @@ impl GraphCircuit {
scales: Vec<crate::Scale>,
) -> Result<Vec<Tensor<Fp>>, GraphError> {
use crate::eth::{evm_quantize, read_on_chain_inputs, setup_eth_backend};
let (client, client_address) = setup_eth_backend(Some(&source.rpc), None).await?;
let (client, client_address) = setup_eth_backend(&source.rpc, None).await?;
let input = read_on_chain_inputs(client.clone(), client_address, &source.call).await?;
let quantized_evm_inputs =
evm_quantize(client, scales, &input, &source.call.decimals).await?;
@@ -1239,15 +1232,9 @@ impl GraphCircuit {
let mut cs = ConstraintSystem::default();
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
Self::configure_with_params(&mut cs, settings);
@@ -1481,13 +1468,9 @@ impl GraphCircuit {
// print file data
debug!("file data: {:?}", file_data);
let on_chain_data: OnChainSource = OnChainSource::test_from_file_data(
&file_data,
scales,
shapes,
test_on_chain_data.rpc.as_deref(),
)
.await?;
let on_chain_data: OnChainSource =
OnChainSource::test_from_file_data(&file_data, scales, shapes, &test_on_chain_data.rpc)
.await?;
// Here we update the GraphData struct with the on-chain data
if input_data.is_some() {
data.input_data = on_chain_data.clone().into();
@@ -1587,13 +1570,13 @@ impl Circuit<Fp> for GraphCircuit {
let mut module_configs = ModuleConfigs::from_visibility(
cs,
params.module_sizes.clone(),
&params.module_sizes,
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
module_configs.configure_complex_modules(cs, &visibility, &params.module_sizes);
vars.instantiate_instance(
cs,

View File

@@ -200,7 +200,7 @@ fn number_of_iterations(mappings: &[InputMapping], dims: Vec<&[usize]>) -> usize
InputMapping::Stacked { axis, chunk } => Some(
// number of iterations given the dim size along the axis
// and the chunk size
(dims[*axis] + chunk - 1) / chunk,
dims[*axis].div_ceil(*chunk), // (dims[*axis] + chunk - 1) / chunk,
),
_ => None,
});
@@ -589,10 +589,7 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
input_types: self.get_input_types().ok(),
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
@@ -650,10 +647,13 @@ impl Model {
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
for (i, id) in model.clone().inputs.iter().enumerate() {
let inputs = model.inputs.clone();
let outputs = model.outputs.clone();
for (i, id) in inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.len() == 0 {
if input.outputs.is_empty() {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
@@ -672,7 +672,7 @@ impl Model {
model.set_input_fact(i, fact)?;
}
for (i, _) in model.clone().outputs.iter().enumerate() {
for (i, _) in outputs.iter().enumerate() {
model.set_output_fact(i, InferenceFact::default())?;
}
@@ -1196,7 +1196,7 @@ impl Model {
.base
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
&[output, &comparators],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
@@ -1257,12 +1257,27 @@ impl Model {
node.inputs()
.iter()
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
// check node is not an output
let is_output = self.graph.outputs.iter().any(|(o_idx, _)| *idx == *o_idx);
let res = if self.graph.nodes[idx].num_uses() == 1 && !is_output {
let res = results.remove(idx);
res.ok_or(GraphError::MissingResults(*idx))?[*outlet].clone()
} else {
results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet]
.clone()
};
Ok(res)
})
.collect::<Result<Vec<_>, GraphError>>()?
} else {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
if self.graph.nodes[idx].num_uses() == 1 {
let res = results.remove(idx);
vec![res.ok_or(GraphError::MissingInput(*idx))?[0].clone()]
} else {
vec![results.get(idx).ok_or(GraphError::MissingInput(*idx))?[0].clone()]
}
};
trace!("output dims: {:?}", node.out_dims());
trace!(
@@ -1273,7 +1288,7 @@ impl Model {
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
let mut res = if node.is_constant() && node.num_uses() == 1 {
log::debug!("node {} is a constant with 1 use", n.idx);
let mut node = n.clone();
let c = node
@@ -1284,19 +1299,19 @@ impl Model {
} else {
config
.base
.layout(region, &values, n.opkind.clone_dyn())
.layout(region, &values.iter().collect_vec(), n.opkind.clone_dyn())
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?
};
if let Some(mut vt) = res {
if let Some(vt) = &mut res {
vt.reshape(&node.out_dims()[0])?;
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
//only use with mock prover
debug!("------------ output node {:?}: {:?}", idx, vt.show());
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
}
}
NodeType::SubGraph {
@@ -1340,7 +1355,7 @@ impl Model {
.inputs
.clone()
.into_iter()
.zip(values.clone().into_iter().map(|v| vec![v])),
.zip(values.iter().map(|v| vec![v.clone()])),
);
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
@@ -1421,7 +1436,7 @@ impl Model {
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
Ok(results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
@@ -1476,7 +1491,7 @@ impl Model {
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
&[output, &comparator],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),

View File

@@ -37,15 +37,15 @@ impl ModuleConfigs {
/// Create new module configs from visibility of each variable
pub fn from_visibility(
cs: &mut ConstraintSystem<Fp>,
module_size: ModuleSizes,
module_size: &ModuleSizes,
logrows: usize,
) -> Self {
let mut config = Self::default();
for size in module_size.polycommit {
for size in &module_size.polycommit {
config
.polycommit
.push(PolyCommitChip::configure(cs, (logrows, size)));
.push(PolyCommitChip::configure(cs, (logrows, *size)));
}
config
@@ -55,8 +55,8 @@ impl ModuleConfigs {
pub fn configure_complex_modules(
&mut self,
cs: &mut ConstraintSystem<Fp>,
visibility: VarVisibility,
module_size: ModuleSizes,
visibility: &VarVisibility,
module_size: &ModuleSizes,
) {
if (visibility.input.is_hashed()
|| visibility.output.is_hashed()

View File

@@ -37,6 +37,7 @@ use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
use itertools::Itertools;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
@@ -118,16 +119,15 @@ impl Op<Fp> for Rescaled {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
if self.scale.len() != values.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
}
let res =
&crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?
[..];
self.inner.layout(config, region, res)
crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?;
self.inner.layout(config, region, &res.iter().collect_vec())
}
/// Create a cloned boxed copy of this operation
@@ -274,13 +274,13 @@ impl Op<Fp> for RebaseScale {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
self.rebase_op.layout(config, region, &[original_res])
self.rebase_op.layout(config, region, &[&original_res])
}
/// Create a cloned boxed copy of this operation
@@ -472,7 +472,7 @@ impl Op<Fp> for SupportedOp {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
self.as_op().layout(config, region, values)
}

View File

@@ -1,493 +0,0 @@
use log::{debug, error, info};
use std::fmt::Debug;
use std::net::IpAddr;
#[cfg(all(not(not(feature = "ezkl")), unix))]
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, pin::Pin};
use tokio::task::JoinHandle;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
};
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::NoTls;
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
///
#[derive(Clone)]
pub struct Config {
config: tokio_postgres::Config,
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}
impl fmt::Debug for Config {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Config")
.field("config", &self.config)
.finish()
}
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
tokio_postgres::Config::new().into()
}
/// Sets the user to authenticate with.
///
/// If the user is not set, then this defaults to the user executing this process.
pub fn user(&mut self, user: &str) -> &mut Config {
self.config.user(user);
self
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn get_user(&self) -> Option<&str> {
self.config.get_user()
}
/// Sets the password to authenticate with.
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.config.password(password);
self
}
/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.config.get_password()
}
/// Sets the name of the database to connect to.
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.config.dbname(dbname);
self
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn get_dbname(&self) -> Option<&str> {
self.config.get_dbname()
}
/// Sets command line options used to configure the server.
pub fn options(&mut self, options: &str) -> &mut Config {
self.config.options(options);
self
}
/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.config.get_options()
}
/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.config.application_name(application_name);
self
}
/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.config.get_application_name()
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.config.ssl_mode(ssl_mode);
self
}
/// Gets the SSL configuration.
pub fn get_ssl_mode(&self) -> SslMode {
self.config.get_ssl_mode()
}
/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
/// There must be either no hosts, or the same number of hosts as hostaddrs.
pub fn host(&mut self, host: &str) -> &mut Config {
self.config.host(host);
self
}
/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &[Host] {
self.config.get_hosts()
}
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
pub fn get_hostaddrs(&self) -> &[IpAddr] {
self.config.get_hostaddrs()
}
/// Adds a Unix socket host to the configuration.
///
/// Unlike `host`, this method allows non-UTF8 paths.
#[cfg(all(not(not(feature = "ezkl")), unix))]
pub fn host_path<T>(&mut self, host: T) -> &mut Config
where
T: AsRef<Path>,
{
self.config.host_path(host);
self
}
/// Adds a hostaddr to the configuration.
///
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
self.config.hostaddr(hostaddr);
self
}
/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.config.port(port);
self
}
/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &[u16] {
self.config.get_ports()
}
/// Sets the timeout applied to socket-level connection attempts.
///
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
/// host separately. Defaults to no limit.
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.config.connect_timeout(connect_timeout);
self
}
/// Gets the connection timeout, if one has been set with the
/// `connect_timeout` method.
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.config.get_connect_timeout()
}
/// Sets the TCP user timeout.
///
/// This is ignored for Unix domain socket connections. It is only supported on systems where
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
/// on other systems, it has no effect.
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
self.config.tcp_user_timeout(tcp_user_timeout);
self
}
/// Gets the TCP user timeout, if one has been set with the
/// `user_timeout` method.
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
self.config.get_tcp_user_timeout()
}
/// Controls the use of TCP keepalive.
///
/// This is ignored for Unix domain socket connections. Defaults to `true`.
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
self.config.keepalives(keepalives);
self
}
/// Reports whether TCP keepalives will be used.
pub fn get_keepalives(&self) -> bool {
self.config.get_keepalives()
}
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
self.config.keepalives_idle(keepalives_idle);
self
}
/// Gets the configured amount of idle time before a keepalive packet will
/// be sent on the connection.
pub fn get_keepalives_idle(&self) -> Duration {
self.config.get_keepalives_idle()
}
/// Sets the time interval between TCP keepalive probes.
/// On Windows, this sets the value of the tcp_keepalive structs keepaliveinterval field.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
self.config.keepalives_interval(keepalives_interval);
self
}
/// Gets the time interval between TCP keepalive probes.
pub fn get_keepalives_interval(&self) -> Option<Duration> {
self.config.get_keepalives_interval()
}
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
self.config.keepalives_retries(keepalives_retries);
self
}
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
pub fn get_keepalives_retries(&self) -> Option<u32> {
self.config.get_keepalives_retries()
}
/// Sets the requirements of the session.
///
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
/// secondary servers. Defaults to `Any`.
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.config.target_session_attrs(target_session_attrs);
self
}
/// Gets the requirements of the session.
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.config.get_target_session_attrs()
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.config.channel_binding(channel_binding);
self
}
/// Gets the channel binding behavior.
pub fn get_channel_binding(&self) -> ChannelBinding {
self.config.get_channel_binding()
}
/// Sets the host load balancing behavior.
///
/// Defaults to `disable`.
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
self.config.load_balance_hosts(load_balance_hosts);
self
}
/// Gets the host load balancing behavior.
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
self.config.get_load_balance_hosts()
}
/// Sets the notice callback.
///
/// This callback will be invoked with the contents of every
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
/// the same structure as errors, but they are not "errors" per-se.
///
/// Notices are distinct from notifications, which are instead accessible
/// via the [`Notifications`] API.
///
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
/// [`Notifications`]: crate::Notifications
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}
/// Opens a connection to a PostgreSQL database.
pub async fn connect(&self) -> Result<Client, Error> {
let (client, connection) = self.config.connect(NoTls).await?;
let connection = Connection::new(connection);
Ok(Client::new(client, connection))
}
}
impl FromStr for Config {
type Err = Error;
fn from_str(s: &str) -> Result<Config, Error> {
s.parse::<tokio_postgres::Config>().map(Config::from)
}
}
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
/// get dropped.
pub struct Connection {
/// The underlying connection stream.
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
}
impl Connection {
/// Creates a new connection.
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
Connection {
connection: Box::pin(connection),
}
}
/// start the connection
pub async fn start(self) {
if let Err(e) = self.connection.await {
error!("connection error: {}", e);
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL client.
pub struct Client {
connection: JoinHandle<()>,
client: tokio_postgres::Client,
}
impl Drop for Client {
fn drop(&mut self) {
let _ = self.close_inner();
}
}
impl Client {
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
let thread = tokio::spawn(async move {
connection.start().await;
});
Client {
client,
connection: thread,
}
}
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
///
/// See the documentation for [`Config`] for information about the connection syntax.
///
/// [`Config`]: config/struct.Config.html
pub async fn connect(params: &str) -> Result<Client, Error> {
debug!("Connecting to database with params: {}", params);
params.parse::<Config>()?.connect().await
}
/// Returns a new `Config` object which can be used to configure and connect to a database.
pub fn configure() -> Config {
Config::new()
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
pub async fn execute<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.execute(query, params).await
}
/// Executes a statement, returning the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Examples
///
pub async fn query<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.query(query, params).await
}
/// Determines if the client's connection has already closed.
///
/// If this returns `true`, the client is no longer usable.
pub fn is_closed(&self) -> bool {
self.client.is_closed()
}
/// Closes the client's connection to the server.
///
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
/// caller.
pub fn close(mut self) -> Result<(), Error> {
self.close_inner()
}
fn close_inner(&mut self) -> Result<(), Error> {
self.client.__private_api_close();
Ok(())
}
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,7 +22,6 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{
@@ -858,6 +858,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Recip {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
@@ -903,6 +904,7 @@ pub fn new_op_from_onnx(
SupportedOp::Hybrid(HybridOp::Rsqrt {
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
eps: run_args.get_epsilon(),
})
}
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
@@ -913,6 +915,7 @@ pub fn new_op_from_onnx(
if run_args.bounded_log_lookup {
SupportedOp::Hybrid(HybridOp::Ln {
scale: scale_to_multiplier(input_scales[0]).into(),
eps: run_args.get_epsilon(),
})
} else {
SupportedOp::Nonlinear(LookupOp::Ln {
@@ -1131,6 +1134,7 @@ pub fn new_op_from_onnx(
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
eps: run_args.get_epsilon(),
})
}
"MaxPool" => {

View File

@@ -29,8 +29,14 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
#[global_allocator]
#[cfg(all(feature = "jemalloc", not(target_arch = "wasm32")))]
static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[global_allocator]
#[cfg(all(feature = "mimalloc", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -350,6 +356,16 @@ pub struct RunArgs {
arg(long, default_value = "false")
)]
pub ignore_range_check_inputs_outputs: bool,
/// Optional override for epsilon value
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
pub epsilon: Option<f64>,
}
impl RunArgs {
/// Returns the epsilon value
pub fn get_epsilon(&self) -> f64 {
self.epsilon.unwrap_or(f64::EPSILON)
}
}
impl Default for RunArgs {
@@ -376,6 +392,7 @@ impl Default for RunArgs {
decomp_base: 16384,
decomp_legs: 2,
ignore_range_check_inputs_outputs: false,
epsilon: None,
}
}
}

View File

@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::CurveAffine;
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
@@ -324,7 +324,7 @@ where
}
#[cfg(feature = "python-bindings")]
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
#[cfg(feature = "python-bindings")]
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
where
@@ -348,9 +348,9 @@ where
}
impl<
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,

View File

@@ -27,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{IntegerRep, integer_rep_to_felt},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::Visibility,
};
@@ -42,11 +42,11 @@ use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::Rem;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
pub trait TensorType: Clone + Debug {
/// Returns the zero value.
fn zero() -> Option<Self> {
None
@@ -55,10 +55,6 @@ pub trait TensorType: Clone + Debug + 'static {
fn one() -> Option<Self> {
None
}
/// Max operator for ordering values.
fn tmax(&self, _: &Self) -> Option<Self> {
None
}
}
macro_rules! tensor_type {
@@ -70,10 +66,6 @@ macro_rules! tensor_type {
fn one() -> Option<Self> {
Some($one)
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(max(*self, *other))
}
}
};
}
@@ -82,46 +74,12 @@ impl TensorType for f32 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f32::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
impl TensorType for f64 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f64::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
tensor_type!(bool, Bool, false, true);
@@ -147,14 +105,6 @@ impl<T: TensorType> TensorType for Value<T> {
fn one() -> Option<Self> {
Some(Value::known(T::one().unwrap()))
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(
(self.clone())
.zip(other.clone())
.map(|(a, b)| a.tmax(&b).unwrap()),
)
}
}
impl<F: PrimeField + PartialOrd> TensorType for Assigned<F>
@@ -168,14 +118,6 @@ where
fn one() -> Option<Self> {
Some(F::ONE.into())
}
fn tmax(&self, other: &Self) -> Option<Self> {
if self.evaluate() >= other.evaluate() {
Some(*self)
} else {
Some(*other)
}
}
}
impl<F: PrimeField> TensorType for Expression<F>
@@ -189,42 +131,14 @@ where
fn one() -> Option<Self> {
Some(Expression::Constant(F::ONE))
}
fn tmax(&self, _: &Self) -> Option<Self> {
todo!()
}
}
impl TensorType for Column<Advice> {}
impl TensorType for Column<Fixed> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value_field().zip(other.value_field()).map(|(a, b)| {
if a.evaluate() >= b.evaluate() {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value().zip(other.value()).map(|(a, b)| {
if a >= b {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {}
// specific types
impl TensorType for halo2curves::pasta::Fp {
@@ -235,10 +149,6 @@ impl TensorType for halo2curves::pasta::Fp {
fn one() -> Option<Self> {
Some(halo2curves::pasta::Fp::one())
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
}
}
impl TensorType for halo2curves::bn256::Fr {
@@ -249,9 +159,15 @@ impl TensorType for halo2curves::bn256::Fr {
fn one() -> Option<Self> {
Some(halo2curves::bn256::Fr::one())
}
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
impl<F: TensorType> TensorType for &F {
fn zero() -> Option<Self> {
None
}
fn one() -> Option<Self> {
None
}
}
@@ -374,7 +290,7 @@ impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
}
}
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync + 'data>
maybe_rayon::iter::IntoParallelRefMutIterator<'data> for Tensor<T>
{
type Iter = maybe_rayon::slice::IterMut<'data, T>;
@@ -423,6 +339,14 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
}
}
impl<T: Clone + TensorType> Tensor<&T> {
/// Clones the tensor values into a new tensor.
pub fn cloned(&self) -> Tensor<T> {
let inner = self.inner.clone().into_iter().cloned().collect::<Vec<T>>();
Tensor::new(Some(&inner), &self.dims).unwrap()
}
}
impl<T: Clone + TensorType> Tensor<T> {
/// Sets (copies) the tensor values to the provided ones.
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -554,7 +478,6 @@ impl<T: Clone + TensorType> Tensor<T> {
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
@@ -631,23 +554,23 @@ impl<T: Clone + TensorType> Tensor<T> {
// Fill remaining dimensions
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
// Pre-calculate total size and allocate result vector
let total_size: usize = full_indices
.iter()
.map(|range| range.end - range.start)
.product();
let mut res = Vec::with_capacity(total_size);
// Calculate new dimensions once
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
// Use iterator directly without collecting into intermediate Vec
for coord in full_indices.iter().cloned().multi_cartesian_product() {
let index = self.get_index(&coord);
res.push(self[index].clone());
}
let mut output = Tensor::new(None, &dims)?;
Tensor::new(Some(&res), &dims)
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
output.par_iter_mut().enumerate().for_each(|(i, e)| {
let coord = &cartesian_coord[i];
*e = self.get(coord);
});
Ok(output)
}
/// Set a slice of the Tensor.
@@ -753,7 +676,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n == 0 {
inner.push(elem.clone());
}
@@ -776,7 +699,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n != 0 {
inner.push(elem.clone());
}
@@ -812,9 +735,9 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if (i + offset + 1) % n == 0 {
inner.extend(vec![elem; 1 + num_repeats]);
inner.extend(vec![elem.clone(); 1 + num_repeats]);
offset += num_repeats;
} else {
inner.push(elem.clone());
@@ -871,16 +794,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
/// let mut indices = vec![3, 4];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), expected);
///
///
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
/// let mut indices = vec![63, 67];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), b);
/// ```
pub fn remove_indices(
&self,
@@ -927,7 +850,7 @@ impl<T: Clone + TensorType> Tensor<T> {
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
if self.dims() == [0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
@@ -1292,33 +1215,6 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&[res]), &[1])
}
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map_mut_filtered<
F: Fn(usize) -> Result<T, E> + std::marker::Send + std::marker::Sync,
E: Error + std::marker::Send + std::marker::Sync,
>(
&mut self,
filter_indices: &std::collections::HashSet<usize>,
f: F,
) -> Result<(), E>
where
T: std::marker::Send + std::marker::Sync,
{
self.inner
.par_iter_mut()
.enumerate()
.filter(|(i, _)| filter_indices.contains(i))
.for_each(move |(i, e)| *e = f(i).unwrap());
Ok(())
}
}
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
@@ -1335,9 +1231,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
pub fn combine(&self) -> Result<Tensor<T>, TensorError> {
let mut dims = 0;
let mut inner = Vec::new();
for t in self.inner.clone().into_iter() {
for t in self.inner.iter() {
dims += t.len();
inner.extend(t.inner);
inner.extend(t.inner.clone());
}
Tensor::new(Some(&inner), &[dims])
}
@@ -1808,7 +1704,7 @@ impl DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
_ => *self,
}
}
@@ -1908,7 +1804,7 @@ impl KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
_ => *self,
}
}
@@ -2030,6 +1926,9 @@ mod tests {
fn tensor_slice() {
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
assert_eq!(
a.get_slice(&[0..2, 0..1]).unwrap(),
b.get_slice(&[0..2, 0..1]).unwrap()
);
}
}

View File

@@ -160,7 +160,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap();
@@ -168,7 +168,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6]),
@@ -188,7 +188,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap();
@@ -196,7 +196,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let a = Tensor::<IntegerRep>::new(
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]),
@@ -216,7 +216,7 @@ pub fn decompose(
///
/// let result = trilu(&a, 0, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
///
/// let result = trilu(&a, -1, true).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap();
@@ -224,7 +224,7 @@ pub fn decompose(
///
/// let result = trilu(&a, -1, false).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result, expected);
/// ```
pub fn trilu<T: TensorType + std::marker::Send + std::marker::Sync>(
a: &Tensor<T>,
@@ -916,11 +916,14 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
pub fn gather_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &'a Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1108,11 +1111,14 @@ pub fn gather_nd<T: TensorType + Send + Sync>(
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
pub fn scatter_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
src: &'a Tensor<T>,
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1183,12 +1189,12 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// use ezkl::tensor::ops::intercalate_values;
///
/// let tensor = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
/// let result = intercalate_values(&tensor, 0, 2, 1).unwrap();
/// let result = intercalate_values(&tensor, &0, 2, 1).unwrap();
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = intercalate_values(&expected, 0, 2, 0).unwrap();
/// let result = intercalate_values(&expected, &0, 2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap();
///
/// assert_eq!(result, expected);
@@ -1196,7 +1202,7 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// ```
pub fn intercalate_values<T: TensorType>(
tensor: &Tensor<T>,
value: T,
value: &T,
stride: usize,
axis: usize,
) -> Result<Tensor<T>, TensorError> {
@@ -1494,7 +1500,7 @@ pub fn slice<T: TensorType + Send + Sync>(
}
}
t.get_slice(&slice)
Ok(t.get_slice(&slice)?)
}
// ---------------------------------------------------------------------------------------------------------
@@ -1859,14 +1865,14 @@ pub mod nonlinearities {
/// Some(&[4, 25, 8, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = rsqrt(&x, 1.0);
/// let result = rsqrt(&x, 1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64, eps: f64) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let kix = (a_i as f64) / scale_input;
let fout = scale_input / (kix.sqrt() + f64::EPSILON);
let fout = scale_input / (kix.sqrt() + eps);
let rounded = fout.round();
Ok::<_, TensorError>(rounded as IntegerRep)
})
@@ -2339,15 +2345,20 @@ pub mod nonlinearities {
/// &[2, 3],
/// ).unwrap();
/// let k = 2_f64;
/// let result = recip(&x, 1.0, k);
/// let result = recip(&x, 1.0, k, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
pub fn recip(
a: &Tensor<IntegerRep>,
input_scale: f64,
out_scale: f64,
eps: f64,
) -> Tensor<IntegerRep> {
a.par_enum_map(|_, a_i| {
let rescaled = (a_i as f64) / input_scale;
let denom = if rescaled == 0_f64 {
(1_f64) / (rescaled + f64::EPSILON)
(1_f64) / (rescaled + eps)
} else {
(1_f64) / (rescaled)
};
@@ -2366,16 +2377,16 @@ pub mod nonlinearities {
/// use ezkl::fieldutils::IntegerRep;
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
/// let k = 2_f64;
/// let result = zero_recip(1.0);
/// let result = zero_recip(1.0, f64::EPSILON);
/// let expected = Tensor::<IntegerRep>::new(Some(&[4503599627370496]), &[1]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn zero_recip(out_scale: f64) -> Tensor<IntegerRep> {
pub fn zero_recip(out_scale: f64, eps: f64) -> Tensor<IntegerRep> {
let a = Tensor::<IntegerRep>::new(Some(&[0]), &[1]).unwrap();
a.par_enum_map(|_, a_i| {
let rescaled = a_i as f64;
let denom = (1_f64) / (rescaled + f64::EPSILON);
let denom = (1_f64) / (rescaled + eps);
let d_inv_x = out_scale * denom;
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
})
@@ -2409,20 +2420,20 @@ pub mod accumulated {
/// Some(&[25, 35]),
/// &[2],
/// ).unwrap();
/// assert_eq!(dot(&[x, y], 1).unwrap(), expected);
/// assert_eq!(dot(&x, &y, 1).unwrap(), expected);
/// ```
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &[Tensor<T>; 2],
a: &Tensor<T>,
b: &Tensor<T>,
chunk_size: usize,
) -> Result<Tensor<T>, TensorError> {
if inputs[0].clone().len() != inputs[1].clone().len() {
if a.len() != b.len() {
return Err(TensorError::DimMismatch("dot".to_string()));
}
let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
let transcript: Tensor<T> = a
.iter()
.zip(b)
.zip(b.iter())
.chunks(chunk_size)
.into_iter()
.scan(T::zero().unwrap(), |acc, chunk| {

View File

@@ -280,9 +280,17 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
fn from(t: Vec<ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().into(),
dims: vec![t.len()],
scale: 1,
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> From<Vec<&ValType<F>>> for ValTensor<F> {
fn from(t: Vec<&ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().cloned().into(),
dims: vec![t.len()],
scale: 1,
}
}
@@ -640,6 +648,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// # Returns
/// A tensor containing the base-n decomposition of each value
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let res = self
.get_inner()?
.par_iter()
@@ -665,9 +676,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
})
.collect::<Result<Vec<_>, _>>();
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let mut tensor = res?.into_iter().flatten().collect::<Tensor<_>>();
tensor.reshape(&dims)?;
@@ -1043,119 +1052,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Removes constant zero values from the tensor
/// Uses parallel processing for tensors larger than a threshold
pub fn remove_const_zero_values(&mut self) {
let size_threshold = 1_000_000; // Tuned using benchmarks
if self.len() < size_threshold {
// Single-threaded for small tensors
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
} else {
// Parallel processing for large tensors
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.par_chunks_mut(chunk_size)
.flat_map(|chunk| {
chunk.par_iter_mut().filter_map(|e| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return None;
}
}
Some(e.clone())
})
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
}
}
/// Gets the indices of all constant zero values
/// Uses parallel processing for large tensors
///
/// # Returns
/// A vector of indices where constant zero values are located
pub fn get_const_zero_indices(&self) -> Vec<usize> {
let size_threshold = 1_000_000;
if self.len() < size_threshold {
// Single-threaded for small tensors
match &self {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
} else {
// Parallel processing for large tensors
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match &self {
ValTensor::Value { inner: v, .. } => v
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter()
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>(),
ValTensor::Instance { .. } => vec![],
}
}
}
/// Gets the indices of all constant values
/// Uses parallel processing for large tensors
///
@@ -1276,7 +1172,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Returns an error if called on an Instance tensor
pub fn intercalate_values(
&mut self,
value: ValType<F>,
value: &ValType<F>,
stride: usize,
axis: usize,
) -> Result<(), TensorError> {
@@ -1368,10 +1264,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Concatenates two tensors along the first dimension
pub fn concat(&self, other: Self) -> Result<Self, TensorError> {
pub fn concat(&self, other: &Self) -> Result<Self, TensorError> {
let res = match (self, other) {
(ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. }) => {
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2]), &[2])?.combine()?)
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2.clone()]), &[2])?.combine()?)
}
_ => {
return Err(TensorError::WrongMethod);

View File

@@ -1,8 +1,6 @@
use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::{CheckMode, region::ConstantsMap};
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
@@ -381,50 +379,6 @@ impl VarTensor {
}
}
/// Assigns values from a ValTensor to this tensor, excluding specified positions
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `omissions` - Set of positions to skip during assignment
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
error!(
"assignment with omissions is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
if omissions.contains(&coord) {
return Ok::<_, halo2_proofs::plonk::Error>(k);
}
let cell =
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
assigned_coord += 1;
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into(),
),
}?;
res.set_scale(values.scale());
Ok(res)
}
/// Assigns values from a ValTensor to this tensor
///
/// # Arguments

Binary file not shown.

View File

@@ -3,10 +3,10 @@
mod native_tests {
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::Commitments;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
use ezkl::pfsys::Snark;
use ezkl::Commitments;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::Bn256;
use lazy_static::lazy_static;

View File

@@ -272,16 +272,6 @@ mod py_tests {
anvil_child.kill().unwrap();
}
#[test]
fn postgres_notebook_() {
crate::py_tests::init_binary();
let test_dir: TempDir = TempDir::new("mean_postgres").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "mean_postgres.ipynb");
run_notebook(path, "mean_postgres.ipynb");
test_dir.close().unwrap();
}
#[test]
fn tictactoe_autoencoder_notebook_() {
crate::py_tests::init_binary();

View File

@@ -479,15 +479,15 @@ async def test_deploy_evm_reusable_and_vka():
res = await ezkl.deploy_evm(
addr_path_verifier,
sol_code_path,
anvil_url,
sol_code_path,
"verifier/reusable",
)
res = await ezkl.deploy_evm(
addr_path_vk,
vk_code_path,
anvil_url,
vk_code_path,
"vka",
)
@@ -506,8 +506,8 @@ async def test_deploy_evm():
res = await ezkl.deploy_evm(
addr_path,
sol_code_path,
anvil_url,
sol_code_path,
)
assert res == True
@@ -528,8 +528,8 @@ async def test_deploy_evm_with_private_key():
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
private_key=anvil_default_private_key
)
@@ -540,8 +540,8 @@ async def test_deploy_evm_with_private_key():
with pytest.raises(RuntimeError, match="Failed to run deploy_evm"):
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
private_key=custom_zero_balance_private_key
)
@@ -564,8 +564,8 @@ async def test_verify_evm():
res = await ezkl.verify_evm(
addr,
anvil_url,
proof_path,
rpc_url=anvil_url,
# sol_code_path
# optimizer_runs
)
@@ -604,8 +604,8 @@ async def test_verify_evm_separate_vk():
res = await ezkl.verify_evm(
addr_verifier,
anvil_url,
proof_path,
rpc_url=anvil_url,
addr_vk=addr_vk,
# sol_code_path
# optimizer_runs
@@ -831,8 +831,8 @@ async def test_evm_aggregate_and_verify_aggr():
res = await ezkl.deploy_evm(
addr_path,
anvil_url,
sol_code_path,
rpc_url=anvil_url,
)
# as a sanity check