mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
5 Commits
ac/f32-eps
...
v22.0.4
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7b955e8fd7 | ||
|
|
40ce9dfde9 | ||
|
|
839030ce10 | ||
|
|
cfccc5460c | ||
|
|
0de0682bfa |
6
.github/workflows/release.yml
vendored
6
.github/workflows/release.yml
vendored
@@ -198,15 +198,15 @@ jobs:
|
||||
|
||||
- name: Build release binary (no asm or metal)
|
||||
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
|
||||
|
||||
- name: Build release binary (asm)
|
||||
if: matrix.build == 'linux-gnu'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
|
||||
|
||||
- name: Build release binary (metal)
|
||||
if: matrix.build == 'macos-aarch64'
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
|
||||
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal,mimalloc
|
||||
|
||||
- name: Strip release binary
|
||||
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
|
||||
|
||||
45
.github/workflows/rust.yml
vendored
45
.github/workflows/rust.yml
vendored
@@ -245,25 +245,6 @@ jobs:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
|
||||
|
||||
foudry-solidity-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: non-gpu
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
|
||||
with:
|
||||
persist-credentials: false
|
||||
submodules: recursive
|
||||
|
||||
- name: Install Foundry
|
||||
uses: foundry-rs/foundry-toolchain@3b74dacdda3c0b763089addb99ed86bc3800e68b
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
cd tests/foundry
|
||||
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
|
||||
forge test -vvvv --fuzz-runs 64
|
||||
|
||||
mock-proving-tests:
|
||||
permissions:
|
||||
contents: read
|
||||
@@ -372,9 +353,6 @@ jobs:
|
||||
- name: Build wasm package for nodejs target.
|
||||
run: |
|
||||
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
|
||||
- name: Build @ezkljs/verify package
|
||||
run: |
|
||||
cd in-browser-evm-verifier
|
||||
@@ -497,9 +475,6 @@ jobs:
|
||||
- name: Build wasm package for nodejs target.
|
||||
run: |
|
||||
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
|
||||
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
|
||||
@@ -744,24 +719,6 @@ jobs:
|
||||
permissions:
|
||||
contents: read
|
||||
runs-on: large-self-hosted
|
||||
services:
|
||||
# Label used to access the service container
|
||||
postgres:
|
||||
# Docker Hub image
|
||||
image: postgres
|
||||
env:
|
||||
POSTGRES_USER: ubuntu
|
||||
POSTGRES_HOST_AUTH_METHOD: trust
|
||||
# Set health checks to wait until postgres has started
|
||||
options: >-
|
||||
--health-cmd pg_isready
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
-v /var/run/postgresql:/var/run/postgresql
|
||||
ports:
|
||||
# Maps tcp port 5432 on service container to the host
|
||||
- 5432:5432
|
||||
steps:
|
||||
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
|
||||
with:
|
||||
@@ -798,8 +755,6 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
|
||||
- name: Felt conversion
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
|
||||
- name: Postgres tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
|
||||
- name: Tictactoe tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
|
||||
# - name: authenticate-kaggle-cli
|
||||
|
||||
2589
Cargo.lock
generated
2589
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
26
Cargo.toml
26
Cargo.toml
@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
|
||||
[package]
|
||||
name = "ezkl"
|
||||
version = "0.0.0"
|
||||
edition = "2024"
|
||||
edition = "2021"
|
||||
default-run = "ezkl"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
@@ -69,20 +69,18 @@ reqwest = { version = "0.12.4", default-features = false, features = [
|
||||
"stream",
|
||||
], optional = true }
|
||||
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
|
||||
tokio-postgres = { version = "0.7.10", optional = true }
|
||||
pg_bigdecimal = { version = "0.1.5", optional = true }
|
||||
lazy_static = { version = "1.4.0", optional = true }
|
||||
colored_json = { version = "3.0.1", default-features = false, optional = true }
|
||||
tokio = { version = "1.35.0", default-features = false, features = [
|
||||
"macros",
|
||||
"rt-multi-thread",
|
||||
], optional = true }
|
||||
pyo3 = { version = "0.23.2", features = [
|
||||
pyo3 = { version = "0.24.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
"macros",
|
||||
], default-features = false, optional = true }
|
||||
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
|
||||
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.24.0", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
], default-features = false, optional = true }
|
||||
@@ -90,9 +88,9 @@ pyo3-log = { version = "0.12.0", default-features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
objc = { version = "0.2.4", optional = true }
|
||||
mimalloc = { version = "0.1", optional = true }
|
||||
pyo3-stub-gen = { version = "0.6.0", optional = true }
|
||||
|
||||
jemallocator = { version = "0.5", optional = true }
|
||||
mimalloc = { version = "0.1", optional = true }
|
||||
# universal bindings
|
||||
uniffi = { version = "=0.28.0", optional = true }
|
||||
getrandom = { version = "0.2.8", optional = true }
|
||||
@@ -242,12 +240,9 @@ ezkl = [
|
||||
"dep:indicatif",
|
||||
"dep:gag",
|
||||
"dep:reqwest",
|
||||
"dep:tokio-postgres",
|
||||
"dep:pg_bigdecimal",
|
||||
"dep:lazy_static",
|
||||
"dep:tokio",
|
||||
"dep:openssl",
|
||||
"dep:mimalloc",
|
||||
"dep:chrono",
|
||||
"dep:sha256",
|
||||
"dep:clap_complete",
|
||||
@@ -274,22 +269,19 @@ no-banner = []
|
||||
no-update = []
|
||||
macos-metal = ["halo2_proofs/macos"]
|
||||
ios-metal = ["halo2_proofs/ios"]
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
|
||||
jemalloc = ["dep:jemallocator"]
|
||||
mimalloc = ["dep:mimalloc"]
|
||||
|
||||
|
||||
[patch.crates-io]
|
||||
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
|
||||
|
||||
[profile.release]
|
||||
# debug = true
|
||||
rustflags = ["-C", "relocation-model=pic"]
|
||||
lto = "fat"
|
||||
codegen-units = 1
|
||||
#panic = "abort"
|
||||
# panic = "abort"
|
||||
|
||||
|
||||
[profile.test-runs]
|
||||
|
||||
@@ -68,7 +68,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
|
||||
&[&self.image, &self.kernel, &self.bias],
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(0, 0)],
|
||||
stride: vec![1; 2],
|
||||
|
||||
@@ -15,6 +15,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "i,i->".to_string(),
|
||||
}),
|
||||
|
||||
@@ -15,6 +15,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
@@ -61,7 +62,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ab,bc->ac".to_string(),
|
||||
}),
|
||||
|
||||
@@ -17,6 +17,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use itertools::Itertools;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -86,13 +87,13 @@ impl Circuit<Fr> for MyCircuit {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let output = config
|
||||
.base_config
|
||||
.layout(&mut region, &self.inputs, Box::new(op))
|
||||
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
&[&output.unwrap()],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -17,6 +17,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use itertools::Itertools;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
@@ -87,13 +88,13 @@ impl Circuit<Fr> for MyCircuit {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
let output = config
|
||||
.base_config
|
||||
.layout(&mut region, &self.inputs, Box::new(op))
|
||||
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
&[&output.unwrap()],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -15,6 +15,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Sum { axes: vec![0] }),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -63,7 +63,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.image.clone()],
|
||||
&[&self.image],
|
||||
Box::new(HybridOp::SumPool {
|
||||
padding: vec![(0, 0); 2],
|
||||
stride: vec![1, 1],
|
||||
|
||||
@@ -15,6 +15,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
@@ -57,7 +58,11 @@ impl Circuit<Fr> for MyCircuit {
|
||||
|region| {
|
||||
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Add),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
@@ -16,6 +16,7 @@ use halo2_proofs::{
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use itertools::Itertools;
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
@@ -58,7 +59,11 @@ impl Circuit<Fr> for MyCircuit {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Pow(4)),
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
},
|
||||
|
||||
@@ -70,7 +70,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
&[&self.input],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
|
||||
@@ -67,7 +67,7 @@ impl Circuit<Fr> for NLCircuit {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
&[&self.input],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '22.0.4'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ use mnist::*;
|
||||
use rand::rngs::OsRng;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
|
||||
mod params;
|
||||
|
||||
const K: usize = 20;
|
||||
@@ -216,11 +215,7 @@ where
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[
|
||||
self.input.clone(),
|
||||
self.l0_params[0].clone(),
|
||||
self.l0_params[1].clone(),
|
||||
],
|
||||
&[&self.input, &self.l0_params[0], &self.l0_params[1]],
|
||||
Box::new(op),
|
||||
)
|
||||
.unwrap();
|
||||
@@ -229,7 +224,7 @@ where
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x.unwrap()],
|
||||
&[&x.unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
@@ -241,7 +236,7 @@ where
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x.unwrap()],
|
||||
&[&x.unwrap()],
|
||||
Box::new(LookupOp::Div { denom: 32.0.into() }),
|
||||
)
|
||||
.unwrap()
|
||||
@@ -253,7 +248,7 @@ where
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.l2_params[0].clone(), x],
|
||||
&[&self.l2_params[0], &x],
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,j->ik".to_string(),
|
||||
}),
|
||||
@@ -265,7 +260,7 @@ where
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x, self.l2_params[1].clone()],
|
||||
&[&x, &self.l2_params[1]],
|
||||
Box::new(PolyOp::Add),
|
||||
)
|
||||
.unwrap()
|
||||
|
||||
@@ -117,10 +117,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[
|
||||
self.l0_params[0].clone().try_into().unwrap(),
|
||||
self.input.clone(),
|
||||
],
|
||||
&[&self.l0_params[0].clone().try_into().unwrap(), &self.input],
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ab,bc->ac".to_string(),
|
||||
}),
|
||||
@@ -135,7 +132,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x, self.l0_params[1].clone().try_into().unwrap()],
|
||||
&[&x, &self.l0_params[1].clone().try_into().unwrap()],
|
||||
Box::new(PolyOp::Add),
|
||||
)
|
||||
.unwrap()
|
||||
@@ -147,7 +144,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
&[&x],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
@@ -163,7 +160,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.l2_params[0].clone().try_into().unwrap(), x],
|
||||
&[&self.l2_params[0].clone().try_into().unwrap(), &x],
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ab,bc->ac".to_string(),
|
||||
}),
|
||||
@@ -178,7 +175,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x, self.l2_params[1].clone().try_into().unwrap()],
|
||||
&[&x, &self.l2_params[1].clone().try_into().unwrap()],
|
||||
Box::new(PolyOp::Add),
|
||||
)
|
||||
.unwrap()
|
||||
@@ -190,7 +187,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x],
|
||||
&[&x],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: 0.0.into(),
|
||||
@@ -203,7 +200,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
|
||||
.layer_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[x.unwrap()],
|
||||
&[&x.unwrap()],
|
||||
Box::new(LookupOp::Div {
|
||||
denom: ezkl::circuit::utils::F32::from(128.),
|
||||
}),
|
||||
|
||||
@@ -1,462 +0,0 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Mean of ERC20 transfer amounts\n",
|
||||
"\n",
|
||||
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
|
||||
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
|
||||
"\n",
|
||||
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"import getpass\n",
|
||||
"import json\n",
|
||||
"import time\n",
|
||||
"import subprocess\n",
|
||||
"\n",
|
||||
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
|
||||
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
|
||||
"os.system(\"chmod +x shovel\")\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
|
||||
"\n",
|
||||
"# create a config.json file with the following contents\n",
|
||||
"config = {\n",
|
||||
" \"pg_url\": \"$PG_URL\",\n",
|
||||
" \"eth_sources\": [\n",
|
||||
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
|
||||
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
|
||||
" ],\n",
|
||||
" \"integrations\": [{\n",
|
||||
" \"name\": \"usdc_transfer\",\n",
|
||||
" \"enabled\": True,\n",
|
||||
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
|
||||
" \"table\": {\n",
|
||||
" \"name\": \"usdc\",\n",
|
||||
" \"columns\": [\n",
|
||||
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
|
||||
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
|
||||
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
|
||||
" ]\n",
|
||||
" },\n",
|
||||
" \"block\": [\n",
|
||||
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
|
||||
" {\n",
|
||||
" \"name\": \"log_addr\",\n",
|
||||
" \"column\": \"log_addr\",\n",
|
||||
" \"filter_op\": \"contains\",\n",
|
||||
" \"filter_arg\": [\n",
|
||||
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
|
||||
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" ],\n",
|
||||
" \"event\": {\n",
|
||||
" \"name\": \"Transfer\",\n",
|
||||
" \"type\": \"event\",\n",
|
||||
" \"anonymous\": False,\n",
|
||||
" \"inputs\": [\n",
|
||||
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
|
||||
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
|
||||
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
|
||||
" ]\n",
|
||||
" }\n",
|
||||
" }]\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"# write the config to a file\n",
|
||||
"with open(\"config.json\", \"w\") as f:\n",
|
||||
" f.write(json.dumps(config))\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# print the two env variables\n",
|
||||
"os.system(\"echo $PG_URL\")\n",
|
||||
"\n",
|
||||
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel is now installed. starting:\")\n",
|
||||
"\n",
|
||||
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
|
||||
"proc = subprocess.Popen(command)\n",
|
||||
"\n",
|
||||
"os.system(\"echo shovel started.\")\n",
|
||||
"\n",
|
||||
"time.sleep(10)\n",
|
||||
"\n",
|
||||
"# after we've fetched some data -- kill the process\n",
|
||||
"proc.terminate()\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "2wIAHwqH2_mo"
|
||||
},
|
||||
"source": [
|
||||
"**Import Dependencies**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "9Byiv2Nc2MsK"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# check if notebook is in colab\n",
|
||||
"try:\n",
|
||||
" # install ezkl\n",
|
||||
" import google.colab\n",
|
||||
" import subprocess\n",
|
||||
" import sys\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
|
||||
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
|
||||
"\n",
|
||||
"# rely on local installation of ezkl if the notebook is not in colab\n",
|
||||
"except:\n",
|
||||
" pass\n",
|
||||
"\n",
|
||||
"import ezkl\n",
|
||||
"import torch\n",
|
||||
"import datetime\n",
|
||||
"import pandas as pd\n",
|
||||
"import requests\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"import logging\n",
|
||||
"# # uncomment for more descriptive logging \n",
|
||||
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
|
||||
"logging.basicConfig(format=FORMAT)\n",
|
||||
"logging.getLogger().setLevel(logging.DEBUG)\n",
|
||||
"\n",
|
||||
"print(\"ezkl version: \", ezkl.__version__)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "osjj-0Ta3E8O"
|
||||
},
|
||||
"source": [
|
||||
"**Create Computational Graph**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "x1vl9ZXF3EEW",
|
||||
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from torch import nn\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"class Model(nn.Module):\n",
|
||||
" def __init__(self):\n",
|
||||
" super(Model, self).__init__()\n",
|
||||
"\n",
|
||||
" # x is a time series \n",
|
||||
" def forward(self, x):\n",
|
||||
" return [torch.mean(x)]\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"circuit = Model()\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
|
||||
"\n",
|
||||
"# # print(torch.__version__)\n",
|
||||
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
|
||||
"\n",
|
||||
"print(device)\n",
|
||||
"\n",
|
||||
"circuit.to(device)\n",
|
||||
"\n",
|
||||
"# Flips the neural net into inference mode\n",
|
||||
"circuit.eval()\n",
|
||||
"\n",
|
||||
"# Export the model\n",
|
||||
"torch.onnx.export(circuit, # model being run\n",
|
||||
" x, # model input (or a tuple for multiple inputs)\n",
|
||||
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
|
||||
" export_params=True, # store the trained parameter weights inside the model file\n",
|
||||
" opset_version=11, # the ONNX version to export the model to\n",
|
||||
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
|
||||
" input_names = ['input'], # the model's input names\n",
|
||||
" output_names = ['output'], # the model's output names\n",
|
||||
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
|
||||
" 'output' : {0 : 'batch_size'}})\n",
|
||||
"\n",
|
||||
"# export(circuit, input_shape=[1, 20])\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "E3qCeX-X5xqd"
|
||||
},
|
||||
"source": [
|
||||
"**Set Data Source and Get Data**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "6RAMplxk5xPk",
|
||||
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import getpass\n",
|
||||
"# make an input.json file from the df above\n",
|
||||
"input_filename = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"pg_input_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
|
||||
"print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# this corresponds to 4 batches\n",
|
||||
"calibration_filename = os.path.join('calibration.json')\n",
|
||||
"\n",
|
||||
"pg_cal_file = dict(input_data = {\n",
|
||||
" \"host\": \"localhost\",\n",
|
||||
" # make sure you replace this with your own username\n",
|
||||
" \"user\": getpass.getuser(),\n",
|
||||
" \"dbname\": \"shovel\",\n",
|
||||
" \"password\": \"\",\n",
|
||||
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
|
||||
" \"port\": \"5432\",\n",
|
||||
"})\n",
|
||||
"\n",
|
||||
" # Serialize data into file:\n",
|
||||
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "eLJ7oirQ_HQR"
|
||||
},
|
||||
"source": [
|
||||
"**EZKL Workflow**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "rNw0C9QL6W88"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import subprocess\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"onnx_filename = os.path.join('lol.onnx')\n",
|
||||
"compiled_filename = os.path.join('lol.compiled')\n",
|
||||
"settings_filename = os.path.join('settings.json')\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.decomp_legs = 4\n",
|
||||
"\n",
|
||||
"# Generate settings using ezkl\n",
|
||||
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"await ezkl.get_srs(settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"\n",
|
||||
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "4MmE9SX66_Il",
|
||||
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# generate settings\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# show the settings.json\n",
|
||||
"with open(\"settings.json\") as f:\n",
|
||||
" data = json.load(f)\n",
|
||||
" json_formatted_str = json.dumps(data, indent=2)\n",
|
||||
"\n",
|
||||
" print(json_formatted_str)\n",
|
||||
"\n",
|
||||
"assert os.path.exists(\"settings.json\")\n",
|
||||
"assert os.path.exists(\"input.json\")\n",
|
||||
"assert os.path.exists(\"lol.onnx\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "fULvvnK7_CMb"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"pk_path = os.path.join('test.pk')\n",
|
||||
"vk_path = os.path.join('test.vk')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# setup the proof\n",
|
||||
"res = ezkl.setup(\n",
|
||||
" compiled_filename,\n",
|
||||
" vk_path,\n",
|
||||
" pk_path\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
"assert os.path.isfile(vk_path)\n",
|
||||
"assert os.path.isfile(pk_path)\n",
|
||||
"assert os.path.isfile(settings_filename)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"witness_path = \"witness.json\"\n",
|
||||
"\n",
|
||||
"# generate the witness\n",
|
||||
"res = await ezkl.gen_witness(\n",
|
||||
" input_filename,\n",
|
||||
" compiled_filename,\n",
|
||||
" witness_path\n",
|
||||
" )\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"base_uri": "https://localhost:8080/"
|
||||
},
|
||||
"id": "Oog3j6Kd-Wed",
|
||||
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# prove the zk circuit\n",
|
||||
"# GENERATE A PROOF\n",
|
||||
"proof_path = os.path.join('test.pf')\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"proof = ezkl.prove(\n",
|
||||
" witness_path,\n",
|
||||
" compiled_filename,\n",
|
||||
" pk_path,\n",
|
||||
" proof_path,\n",
|
||||
" \"single\"\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"print(\"proved\")\n",
|
||||
"\n",
|
||||
"assert os.path.isfile(proof_path)\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": ".env",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.12.7"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
BIN
examples/onnx/fr_age/heaptrack.ezkl.3356626.gz
Normal file
BIN
examples/onnx/fr_age/heaptrack.ezkl.3356626.gz
Normal file
Binary file not shown.
37795
examples/onnx/fr_age/lol.txt
Normal file
37795
examples/onnx/fr_age/lol.txt
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,12 +1,5 @@
|
||||
// ignore file if compiling for wasm
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc::MiMalloc;
|
||||
|
||||
#[global_allocator]
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
static GLOBAL: MiMalloc = MiMalloc;
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use clap::{CommandFactory, Parser};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
|
||||
@@ -1,34 +1,34 @@
|
||||
use crate::Commitments;
|
||||
use crate::RunArgs;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
|
||||
};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::InputType;
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
|
||||
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
|
||||
};
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::{
|
||||
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
|
||||
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
|
||||
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
|
||||
ProofType, TranscriptType,
|
||||
};
|
||||
use crate::Commitments;
|
||||
use crate::RunArgs;
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
|
||||
use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use pyo3_stub_gen::{
|
||||
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction,
|
||||
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
|
||||
derive::gen_stub_pyfunction, TypeInfo,
|
||||
};
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::collections::HashSet;
|
||||
@@ -206,6 +206,9 @@ struct PyRunArgs {
|
||||
/// bool: Should the circuit use range checks for inputs and outputs (set to false if the input is a felt)
|
||||
#[pyo3(get, set)]
|
||||
pub ignore_range_check_inputs_outputs: bool,
|
||||
/// float: epsilon used for arguments that use division
|
||||
#[pyo3(get, set)]
|
||||
pub epsilon: f64,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -238,12 +241,14 @@ impl From<PyRunArgs> for RunArgs {
|
||||
decomp_base: py_run_args.decomp_base,
|
||||
decomp_legs: py_run_args.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: py_run_args.ignore_range_check_inputs_outputs,
|
||||
epsilon: Some(py_run_args.epsilon),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<PyRunArgs> for RunArgs {
|
||||
fn into(self) -> PyRunArgs {
|
||||
let eps = self.get_epsilon();
|
||||
PyRunArgs {
|
||||
bounded_log_lookup: self.bounded_log_lookup,
|
||||
input_scale: self.input_scale,
|
||||
@@ -262,6 +267,7 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
decomp_base: self.decomp_base,
|
||||
decomp_legs: self.decomp_legs,
|
||||
ignore_range_check_inputs_outputs: self.ignore_range_check_inputs_outputs,
|
||||
epsilon: eps,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -962,7 +962,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
pub fn layout(
|
||||
&mut self,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
values: &[&ValTensor<F>],
|
||||
op: Box<dyn Op<F>>,
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
op.layout(self, region, values)
|
||||
|
||||
@@ -15,10 +15,12 @@ use serde::{Deserialize, Serialize};
|
||||
pub enum HybridOp {
|
||||
Ln {
|
||||
scale: utils::F32,
|
||||
eps: f64,
|
||||
},
|
||||
Rsqrt {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
eps: f64,
|
||||
},
|
||||
Sqrt {
|
||||
scale: utils::F32,
|
||||
@@ -42,6 +44,7 @@ pub enum HybridOp {
|
||||
Recip {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
eps: f64,
|
||||
},
|
||||
Div {
|
||||
denom: utils::F32,
|
||||
@@ -77,6 +80,7 @@ pub enum HybridOp {
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
axes: Vec<usize>,
|
||||
eps: f64,
|
||||
},
|
||||
Output {
|
||||
decomp: bool,
|
||||
@@ -105,13 +109,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
///
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
HybridOp::Greater { .. }
|
||||
| HybridOp::Less { .. }
|
||||
| HybridOp::Equals { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
HybridOp::Greater
|
||||
| HybridOp::Less
|
||||
| HybridOp::Equals
|
||||
| HybridOp::GreaterEqual
|
||||
| HybridOp::Max
|
||||
| HybridOp::Min
|
||||
| HybridOp::LessEqual { .. } => {
|
||||
| HybridOp::LessEqual => {
|
||||
vec![0, 1]
|
||||
}
|
||||
_ => vec![],
|
||||
@@ -128,12 +132,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => format!(
|
||||
"RSQRT (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
"RSQRT (input_scale={}, output_scale={}, eps={})",
|
||||
input_scale, output_scale, eps
|
||||
),
|
||||
HybridOp::Sqrt { scale } => format!("SQRT(scale={})", scale),
|
||||
HybridOp::Ln { scale } => format!("LN(scale={})", scale),
|
||||
HybridOp::Ln { scale, eps } => format!("LN(scale={}, eps={})", scale, eps),
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
format!("ROUND_HALF_TO_EVEN(scale={}, legs={})", scale, legs)
|
||||
}
|
||||
@@ -146,16 +151,18 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => format!(
|
||||
"RECIP (input_scale={}, output_scale={})",
|
||||
input_scale, output_scale
|
||||
"RECIP (input_scale={}, output_scale={}, eps={})",
|
||||
input_scale, output_scale, eps
|
||||
),
|
||||
HybridOp::Div { denom } => format!("DIV (denom={})", denom),
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
kernel_shape,
|
||||
normalized, data_format
|
||||
normalized,
|
||||
data_format,
|
||||
} => format!(
|
||||
"SUMPOOL (padding={:?}, stride={:?}, kernel_shape={:?}, normalized={}, data_format={:?})",
|
||||
padding, stride, kernel_shape, normalized, data_format
|
||||
@@ -177,10 +184,11 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
eps,
|
||||
} => {
|
||||
format!(
|
||||
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
|
||||
input_scale, output_scale, axes
|
||||
"SOFTMAX (input_scale={}, output_scale={}, axes={:?}, eps={})",
|
||||
input_scale, output_scale, axes, eps
|
||||
)
|
||||
}
|
||||
HybridOp::Output { decomp } => {
|
||||
@@ -205,23 +213,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
values: &[&ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
HybridOp::Rsqrt {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => layouts::rsqrt(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
*eps,
|
||||
)?,
|
||||
HybridOp::Sqrt { scale } => {
|
||||
layouts::sqrt(config, region, values[..].try_into()?, *scale)?
|
||||
}
|
||||
HybridOp::Ln { scale } => layouts::ln(config, region, values[..].try_into()?, *scale)?,
|
||||
HybridOp::Ln { scale, eps } => {
|
||||
layouts::ln(config, region, values[..].try_into()?, *scale, *eps)?
|
||||
}
|
||||
HybridOp::RoundHalfToEven { scale, legs } => {
|
||||
layouts::round_half_to_even(config, region, values[..].try_into()?, *scale, *legs)?
|
||||
}
|
||||
@@ -255,12 +267,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
eps,
|
||||
} => layouts::recip(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
integer_rep_to_felt(input_scale.0 as IntegerRep),
|
||||
integer_rep_to_felt(output_scale.0 as IntegerRep),
|
||||
*eps,
|
||||
)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
if denom.0.fract() == 0.0 {
|
||||
@@ -317,6 +331,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
eps,
|
||||
} => layouts::softmax_axes(
|
||||
config,
|
||||
region,
|
||||
@@ -324,6 +339,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
axes,
|
||||
*eps,
|
||||
)?,
|
||||
HybridOp::Output { decomp } => {
|
||||
layouts::output(config, region, values[..].try_into()?, *decomp)?
|
||||
@@ -346,10 +362,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
let scale = match self {
|
||||
HybridOp::Greater { .. }
|
||||
| HybridOp::GreaterEqual { .. }
|
||||
| HybridOp::Less { .. }
|
||||
| HybridOp::LessEqual { .. }
|
||||
HybridOp::Greater
|
||||
| HybridOp::GreaterEqual
|
||||
| HybridOp::Less
|
||||
| HybridOp::LessEqual
|
||||
| HybridOp::ReduceArgMax { .. }
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
@@ -364,6 +380,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
|
||||
} => multiplier_to_scale((output_scale.0 * input_scale.0) as f64),
|
||||
HybridOp::Ln {
|
||||
scale: output_scale,
|
||||
eps: _,
|
||||
} => 4 * multiplier_to_scale(output_scale.0 as f64),
|
||||
_ => in_scales[0],
|
||||
};
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -186,7 +186,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
values: &[&ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(layouts::nonlinearity(
|
||||
config,
|
||||
|
||||
@@ -49,7 +49,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
values: &[&ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError>;
|
||||
|
||||
/// Returns the scale of the output of the operation.
|
||||
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
values: &[&ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
let value = values[0].clone();
|
||||
if !value.all_prev_assigned() {
|
||||
@@ -223,12 +223,29 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
|
||||
true,
|
||||
)?))
|
||||
}
|
||||
_ => Ok(Some(super::layouts::identity(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
self.decomp,
|
||||
)?)),
|
||||
_ => {
|
||||
if self.decomp {
|
||||
log::debug!("constraining input to be decomp");
|
||||
Ok(Some(
|
||||
super::layouts::decompose(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
®ion.base(),
|
||||
®ion.legs(),
|
||||
false,
|
||||
)?
|
||||
.1,
|
||||
))
|
||||
} else {
|
||||
log::debug!("constraining input to be identity");
|
||||
Ok(Some(super::layouts::identity(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
)?))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok(Some(value))
|
||||
@@ -263,7 +280,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
|
||||
&self,
|
||||
_: &mut crate::circuit::BaseConfig<F>,
|
||||
_: &mut RegionCtx<F>,
|
||||
_: &[ValTensor<F>],
|
||||
_: &[&ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Err(super::CircuitError::UnsupportedOp)
|
||||
}
|
||||
@@ -319,8 +336,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
@@ -333,20 +355,20 @@ impl<
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
_: &[ValTensor<F>],
|
||||
_: &[&ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
let value = if let Some(value) = &self.pre_assigned_val {
|
||||
value.clone()
|
||||
} else {
|
||||
self.quantized_values.clone().try_into()?
|
||||
};
|
||||
// we gotta constrain it once if its used multiple times
|
||||
Ok(Some(layouts::identity(
|
||||
config,
|
||||
region,
|
||||
&[value],
|
||||
self.decomp,
|
||||
)?))
|
||||
Ok(Some(if self.decomp {
|
||||
log::debug!("constraining constant to be decomp");
|
||||
super::layouts::decompose(config, region, &[&value], ®ion.base(), ®ion.legs(), false)?.1
|
||||
} else {
|
||||
log::debug!("constraining constant to be identity");
|
||||
super::layouts::identity(config, region, &[&value])?
|
||||
}))
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
|
||||
@@ -108,8 +108,13 @@ pub enum PolyOp {
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
@@ -203,11 +208,11 @@ impl<
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
values: &[&ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, CircuitError> {
|
||||
Ok(Some(match self {
|
||||
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?, true)?,
|
||||
PolyOp::LeakyReLU { slope, scale } => {
|
||||
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
|
||||
}
|
||||
@@ -335,9 +340,7 @@ impl<
|
||||
PolyOp::Mult => {
|
||||
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
|
||||
}
|
||||
PolyOp::Identity { .. } => {
|
||||
layouts::identity(config, region, values[..].try_into()?, false)?
|
||||
}
|
||||
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
|
||||
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
|
||||
PolyOp::Pad(p) => {
|
||||
if values.len() != 1 {
|
||||
@@ -416,14 +419,14 @@ impl<
|
||||
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
|
||||
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
|
||||
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
|
||||
PolyOp::Sign { .. } => 0,
|
||||
PolyOp::Sign => 0,
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
}
|
||||
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
|
||||
if matches!(self, PolyOp::Add | PolyOp::Sub) {
|
||||
vec![0, 1]
|
||||
} else if matches!(self, PolyOp::Iff) {
|
||||
vec![1, 2]
|
||||
|
||||
@@ -10,7 +10,6 @@ use halo2_proofs::{
|
||||
plonk::{Error, Selector},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use maybe_rayon::iter::ParallelExtend;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
@@ -462,15 +461,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_inputs(
|
||||
&mut self,
|
||||
inputs: &[ValTensor<F>],
|
||||
inputs: &ValTensor<F>,
|
||||
) -> Result<(), CircuitError> {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in inputs {
|
||||
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
|
||||
}
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
let int_eval = inputs.int_evals()?;
|
||||
let max = int_eval.iter().max().unwrap_or(&0);
|
||||
let min = int_eval.iter().min().unwrap_or(&0);
|
||||
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(*max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(*min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -505,10 +503,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
/// add used lookup
|
||||
pub fn add_used_lookup(
|
||||
&mut self,
|
||||
lookup: LookupOp,
|
||||
inputs: &[ValTensor<F>],
|
||||
lookup: &LookupOp,
|
||||
inputs: &ValTensor<F>,
|
||||
) -> Result<(), CircuitError> {
|
||||
self.statistics.used_lookups.insert(lookup);
|
||||
self.statistics.used_lookups.insert(lookup.clone());
|
||||
self.update_max_min_lookup_inputs(inputs)
|
||||
}
|
||||
|
||||
@@ -642,34 +640,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
self.assign_dynamic_lookup(var, values)
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor
|
||||
pub fn assign_with_omissions(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
ommissions: &HashSet<usize>,
|
||||
) -> Result<ValTensor<F>, CircuitError> {
|
||||
if let Some(region) = &self.region {
|
||||
Ok(var.assign_with_omissions(
|
||||
&mut region.borrow_mut(),
|
||||
self.linear_coord,
|
||||
values,
|
||||
ommissions,
|
||||
&mut self.assigned_constants,
|
||||
)?)
|
||||
} else {
|
||||
let mut values_clone = values.clone();
|
||||
let mut indices = ommissions.clone().into_iter().collect_vec();
|
||||
values_clone.remove_indices(&mut indices, false)?;
|
||||
|
||||
let values_map = values.create_constants_map();
|
||||
|
||||
self.assigned_constants.par_extend(values_map);
|
||||
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor with duplication
|
||||
pub fn assign_with_duplication_unconstrained(
|
||||
&mut self,
|
||||
|
||||
@@ -9,6 +9,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::Fr as F;
|
||||
use halo2curves::ff::{Field, PrimeField};
|
||||
use itertools::Itertools;
|
||||
#[cfg(not(any(
|
||||
all(target_arch = "wasm32", target_os = "unknown"),
|
||||
not(feature = "ezkl")
|
||||
@@ -64,7 +65,7 @@ mod matmul {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
}),
|
||||
@@ -141,7 +142,7 @@ mod matmul_col_overflow_double_col {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
}),
|
||||
@@ -215,7 +216,7 @@ mod matmul_col_overflow {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
}),
|
||||
@@ -302,7 +303,7 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
}),
|
||||
@@ -380,6 +381,7 @@ mod matmul_col_ultra_overflow {
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
use super::*;
|
||||
@@ -422,7 +424,7 @@ mod matmul_col_ultra_overflow {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "ij,jk->ik".to_string(),
|
||||
}),
|
||||
@@ -533,7 +535,7 @@ mod dot {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "i,i->".to_string(),
|
||||
}),
|
||||
@@ -610,7 +612,7 @@ mod dot_col_overflow_triple_col {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "i,i->".to_string(),
|
||||
}),
|
||||
@@ -683,7 +685,7 @@ mod dot_col_overflow {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "i,i->".to_string(),
|
||||
}),
|
||||
@@ -756,7 +758,7 @@ mod sum {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Sum { axes: vec![0] }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -826,7 +828,7 @@ mod sum_col_overflow_double_col {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Sum { axes: vec![0] }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -895,7 +897,7 @@ mod sum_col_overflow {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Sum { axes: vec![0] }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -966,7 +968,7 @@ mod composition {
|
||||
let _ = config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "i,i->".to_string(),
|
||||
}),
|
||||
@@ -975,7 +977,7 @@ mod composition {
|
||||
let _ = config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "i,i->".to_string(),
|
||||
}),
|
||||
@@ -984,7 +986,7 @@ mod composition {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Einsum {
|
||||
equation: "i,i->".to_string(),
|
||||
}),
|
||||
@@ -1061,7 +1063,7 @@ mod conv {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
@@ -1218,7 +1220,7 @@ mod conv_col_ultra_overflow {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.image.clone(), self.kernel.clone()],
|
||||
&[&self.image, &self.kernel],
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
@@ -1377,7 +1379,7 @@ mod conv_relu_col_ultra_overflow {
|
||||
let output = config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.image.clone(), self.kernel.clone()],
|
||||
&[&self.image, &self.kernel],
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
@@ -1390,7 +1392,7 @@ mod conv_relu_col_ultra_overflow {
|
||||
let _output = config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap().unwrap()],
|
||||
&[&output.unwrap().unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
@@ -1517,7 +1519,11 @@ mod add_w_shape_casting {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Add),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
@@ -1584,7 +1590,11 @@ mod add {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Add),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
@@ -1671,8 +1681,8 @@ mod dynamic_lookup {
|
||||
layouts::dynamic_lookup(
|
||||
&config,
|
||||
&mut region,
|
||||
&self.lookups[i],
|
||||
&self.tables[i],
|
||||
&self.lookups[i].iter().collect_vec().try_into().unwrap(),
|
||||
&self.tables[i].iter().collect_vec().try_into().unwrap(),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)?;
|
||||
}
|
||||
@@ -1767,8 +1777,8 @@ mod shuffle {
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
|
||||
references: [[ValTensor<F>; 1]; NUM_LOOP],
|
||||
inputs: [ValTensor<F>; NUM_LOOP],
|
||||
references: [ValTensor<F>; NUM_LOOP],
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
@@ -1818,15 +1828,15 @@ mod shuffle {
|
||||
layouts::shuffles(
|
||||
&config,
|
||||
&mut region,
|
||||
&self.inputs[i],
|
||||
&self.references[i],
|
||||
&[&self.inputs[i]],
|
||||
&[&self.references[i]],
|
||||
layouts::SortCollisionMode::Unsorted,
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)?;
|
||||
}
|
||||
assert_eq!(
|
||||
region.shuffle_col_coord(),
|
||||
NUM_LOOP * self.references[0][0].len()
|
||||
NUM_LOOP * self.references[0].len()
|
||||
);
|
||||
assert_eq!(region.shuffle_index(), NUM_LOOP);
|
||||
|
||||
@@ -1843,17 +1853,19 @@ mod shuffle {
|
||||
// parameters
|
||||
let references = (0..NUM_LOOP)
|
||||
.map(|loop_idx| {
|
||||
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
|
||||
Value::known(F::from((i * loop_idx) as u64 + 1))
|
||||
})))]
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let inputs = (0..NUM_LOOP)
|
||||
.map(|loop_idx| {
|
||||
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
|
||||
Value::known(F::from((i * loop_idx) as u64 + 1))
|
||||
})))]
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..LEN)
|
||||
.rev()
|
||||
.map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -1873,9 +1885,11 @@ mod shuffle {
|
||||
} else {
|
||||
loop_idx - 1
|
||||
};
|
||||
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
|
||||
Value::known(F::from((i * prev_idx) as u64 + 1))
|
||||
})))]
|
||||
ValTensor::from(Tensor::from(
|
||||
(0..LEN)
|
||||
.rev()
|
||||
.map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
|
||||
))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
@@ -1931,7 +1945,11 @@ mod add_with_overflow {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Add),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
@@ -2026,7 +2044,7 @@ mod add_with_overflow_and_poseidon {
|
||||
|
||||
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
|
||||
|
||||
let inputs = vec![assigned_inputs_a, assigned_inputs_b];
|
||||
let inputs = vec![&assigned_inputs_a, &assigned_inputs_b];
|
||||
|
||||
layouter.assign_region(
|
||||
|| "model",
|
||||
@@ -2135,7 +2153,11 @@ mod sub {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Sub),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
@@ -2202,7 +2224,11 @@ mod mult {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Mult),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
@@ -2269,7 +2295,11 @@ mod pow {
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
|
||||
config
|
||||
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.iter().collect_vec(),
|
||||
Box::new(PolyOp::Pow(5)),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
@@ -2360,13 +2390,13 @@ mod matmul_relu {
|
||||
};
|
||||
let output = config
|
||||
.base_config
|
||||
.layout(&mut region, &self.inputs, Box::new(op))
|
||||
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
|
||||
.unwrap();
|
||||
let _output = config
|
||||
.base_config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[output.unwrap()],
|
||||
&[&output.unwrap()],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
@@ -2465,7 +2495,7 @@ mod relu {
|
||||
Ok(config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
&[&self.input],
|
||||
Box::new(PolyOp::LeakyReLU {
|
||||
slope: 0.0.into(),
|
||||
scale: 1,
|
||||
@@ -2563,7 +2593,7 @@ mod lookup_ultra_overflow {
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&[self.input.clone()],
|
||||
&[&self.input],
|
||||
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use alloy::primitives::Address as H160;
|
||||
use clap::{Command, Parser, Subcommand};
|
||||
use clap_complete::{Generator, Shell, generate};
|
||||
use clap_complete::{generate, Generator, Shell};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
|
||||
use serde::{Deserialize, Serialize};
|
||||
@@ -8,7 +8,7 @@ use std::path::PathBuf;
|
||||
use std::str::FromStr;
|
||||
use tosubcommand::{ToFlags, ToSubcommand};
|
||||
|
||||
use crate::{Commitments, RunArgs, pfsys::ProofType};
|
||||
use crate::{pfsys::ProofType, Commitments, RunArgs};
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::graph::TestDataSource;
|
||||
@@ -391,10 +391,8 @@ impl FromStr for DataField {
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// Check if the input starts with '@'
|
||||
if s.starts_with('@') {
|
||||
if let Some(file_path) = s.strip_prefix('@') {
|
||||
// Extract the file path (remove the '@' prefix)
|
||||
let file_path = &s[1..];
|
||||
|
||||
// Read the file content
|
||||
let content = std::fs::read_to_string(file_path)
|
||||
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;
|
||||
|
||||
22
src/eth.rs
22
src/eth.rs
@@ -1,24 +1,24 @@
|
||||
use crate::graph::DataSource;
|
||||
use crate::graph::GraphSettings;
|
||||
use crate::graph::input::{CallToAccount, CallsToAccount, FileSourceInner, GraphData};
|
||||
use crate::graph::modules::POSEIDON_INSTANCES;
|
||||
use crate::pfsys::Snark;
|
||||
use crate::graph::DataSource;
|
||||
use crate::graph::GraphSettings;
|
||||
use crate::pfsys::evm::EvmVerificationError;
|
||||
use crate::pfsys::Snark;
|
||||
use alloy::contract::CallBuilder;
|
||||
use alloy::core::primitives::Address as H160;
|
||||
use alloy::core::primitives::Bytes;
|
||||
use alloy::core::primitives::U256;
|
||||
use alloy::dyn_abi::abi::TokenSeq;
|
||||
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
|
||||
use alloy::dyn_abi::abi::TokenSeq;
|
||||
// use alloy::providers::Middleware;
|
||||
use alloy::json_abi::JsonAbi;
|
||||
use alloy::primitives::ruint::ParseError;
|
||||
use alloy::primitives::{B256, I256, ParseSignedError};
|
||||
use alloy::providers::ProviderBuilder;
|
||||
use alloy::primitives::{ParseSignedError, B256, I256};
|
||||
use alloy::providers::fillers::{
|
||||
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
|
||||
};
|
||||
use alloy::providers::network::{Ethereum, EthereumSigner};
|
||||
use alloy::providers::ProviderBuilder;
|
||||
use alloy::providers::{Identity, Provider, RootProvider};
|
||||
use alloy::rpc::types::eth::TransactionInput;
|
||||
use alloy::rpc::types::eth::TransactionRequest;
|
||||
@@ -27,9 +27,9 @@ use alloy::signers::wallet::{LocalWallet, WalletError};
|
||||
use alloy::sol as abigen;
|
||||
use alloy::transports::http::Http;
|
||||
use alloy::transports::{RpcError, TransportErrorKind};
|
||||
use foundry_compilers::Solc;
|
||||
use foundry_compilers::artifacts::Settings as SolcSettings;
|
||||
use foundry_compilers::error::{SolcError, SolcIoError};
|
||||
use foundry_compilers::Solc;
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::{Fr, G1Affine};
|
||||
use halo2curves::group::ff::PrimeField;
|
||||
@@ -711,7 +711,7 @@ pub async fn verify_proof_with_data_attestation(
|
||||
};
|
||||
|
||||
let encoded = func.encode_input(&[
|
||||
Token::Address(addr_verifier.0.0.into()),
|
||||
Token::Address(addr_verifier.0 .0.into()),
|
||||
Token::Bytes(encoded_verifier),
|
||||
])?;
|
||||
|
||||
@@ -757,7 +757,7 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
|
||||
let call_to_account = CallToAccount {
|
||||
call_data: hex::encode(call),
|
||||
decimals,
|
||||
address: hex::encode(contract.address().0.0),
|
||||
address: hex::encode(contract.address().0 .0),
|
||||
};
|
||||
info!("call_to_account: {:#?}", call_to_account);
|
||||
Ok(call_to_account)
|
||||
@@ -913,9 +913,9 @@ pub async fn get_contract_artifacts(
|
||||
runs: usize,
|
||||
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
|
||||
use foundry_compilers::{
|
||||
SHANGHAI_SOLC, SolcInput,
|
||||
artifacts::{Optimizer, output_selection::OutputSelection},
|
||||
artifacts::{output_selection::OutputSelection, Optimizer},
|
||||
compilers::CompilerInput,
|
||||
SolcInput, SHANGHAI_SOLC,
|
||||
};
|
||||
|
||||
if !sol_code_path.exists() {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::region::RegionSettings;
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::commands::CalibrationTarget;
|
||||
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
|
||||
#[allow(unused_imports)]
|
||||
@@ -10,21 +9,21 @@ use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
|
||||
use crate::graph::{TestDataSource, TestSources};
|
||||
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
|
||||
use crate::pfsys::{
|
||||
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
|
||||
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
|
||||
};
|
||||
use crate::pfsys::{
|
||||
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
|
||||
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
|
||||
};
|
||||
use crate::pfsys::{save_vk, srs::*};
|
||||
use crate::tensor::TensorError;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::{commands::*, EZKLError};
|
||||
use crate::{Commitments, RunArgs};
|
||||
use crate::{EZKLError, commands::*};
|
||||
use colored::Colorize;
|
||||
#[cfg(unix)]
|
||||
use gag::Gag;
|
||||
use halo2_proofs::dev::VerifyFailure;
|
||||
use halo2_proofs::plonk::{self, Circuit};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
|
||||
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
|
||||
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
|
||||
@@ -37,6 +36,7 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
|
||||
use halo2_solidity_verifier;
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
@@ -48,12 +48,12 @@ use itertools::Itertools;
|
||||
use lazy_static::lazy_static;
|
||||
use log::debug;
|
||||
use log::{info, trace, warn};
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde::Serialize;
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::Config;
|
||||
use snark_verifier::system::halo2::compile;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use snark_verifier::system::halo2::Config;
|
||||
use std::fs::File;
|
||||
use std::io::BufWriter;
|
||||
use std::io::{Cursor, Write};
|
||||
@@ -1163,15 +1163,9 @@ pub(crate) async fn calibrate(
|
||||
|
||||
// if unix get a gag
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
let _r = match Gag::stdout() {
|
||||
Ok(g) => Some(g),
|
||||
_ => None,
|
||||
};
|
||||
let _r = Gag::stdout().ok();
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
let _g = match Gag::stderr() {
|
||||
Ok(g) => Some(g),
|
||||
_ => None,
|
||||
};
|
||||
let _g = Gag::stderr().ok();
|
||||
|
||||
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
|
||||
Ok(c) => c,
|
||||
@@ -1329,7 +1323,7 @@ pub(crate) async fn calibrate(
|
||||
.clone()
|
||||
}
|
||||
CalibrationTarget::Accuracy => {
|
||||
let param_iterator = found_params.iter().sorted_by_key(|p| {
|
||||
let mut param_iterator = found_params.iter().sorted_by_key(|p| {
|
||||
(
|
||||
p.run_args.input_scale,
|
||||
p.run_args.param_scale,
|
||||
@@ -1338,7 +1332,7 @@ pub(crate) async fn calibrate(
|
||||
)
|
||||
});
|
||||
|
||||
let last = param_iterator.last().ok_or("no params found")?;
|
||||
let last = param_iterator.next_back().ok_or("no params found")?;
|
||||
let max_scale = (
|
||||
last.run_args.input_scale,
|
||||
last.run_args.param_scale,
|
||||
|
||||
@@ -67,8 +67,11 @@ pub enum GraphError {
|
||||
#[error("invalid input types")]
|
||||
InvalidInputTypes,
|
||||
/// Missing results
|
||||
#[error("missing results")]
|
||||
MissingResults,
|
||||
#[error("missing result for node {0}")]
|
||||
MissingResults(usize),
|
||||
/// Missing input
|
||||
#[error("missing input {0}")]
|
||||
MissingInputForNode(usize),
|
||||
/// Tensor error
|
||||
#[error("[tensor] {0}")]
|
||||
TensorError(#[from] crate::tensor::TensorError),
|
||||
@@ -98,8 +101,6 @@ pub enum GraphError {
|
||||
feature = "ezkl",
|
||||
not(all(target_arch = "wasm32", target_os = "unknown"))
|
||||
))]
|
||||
#[error("[tokio postgres] {0}")]
|
||||
TokioPostgresError(#[from] tokio_postgres::Error),
|
||||
/// Eth error
|
||||
#[cfg(all(
|
||||
feature = "ezkl",
|
||||
@@ -141,7 +142,9 @@ pub enum GraphError {
|
||||
#[error("range check {0} is too large")]
|
||||
RangeCheckTooLarge(usize),
|
||||
///Cannot use on-chain data source as private data
|
||||
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
|
||||
#[error(
|
||||
"cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm."
|
||||
)]
|
||||
OnChainDataSource,
|
||||
/// Missing data source
|
||||
#[error("missing data source")]
|
||||
|
||||
@@ -1,17 +1,15 @@
|
||||
use super::errors::GraphError;
|
||||
use super::quantize_float;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::integer_rep_to_felt;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::graph::postgres::Client;
|
||||
use crate::EZKL_BUF_CAPACITY;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
use std::io::BufReader;
|
||||
use std::io::BufWriter;
|
||||
@@ -19,13 +17,10 @@ use std::io::Read;
|
||||
use std::panic::UnwindSafe;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::{
|
||||
tract_data::{TVec, prelude::Tensor as TractTensor},
|
||||
tract_data::{prelude::Tensor as TractTensor, TVec},
|
||||
value::TValue,
|
||||
};
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
|
||||
|
||||
type Decimals = u8;
|
||||
type Call = String;
|
||||
type RPCUrl = String;
|
||||
@@ -260,9 +255,6 @@ pub enum DataSource {
|
||||
File(FileSource),
|
||||
/// Data fetched from blockchain contracts
|
||||
OnChain(OnChainSource),
|
||||
/// Data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DB(PostgresSource),
|
||||
}
|
||||
|
||||
impl Default for DataSource {
|
||||
@@ -323,15 +315,6 @@ impl<'de> Deserialize<'de> for DataSource {
|
||||
return Ok(DataSource::OnChain(t));
|
||||
}
|
||||
|
||||
// Try deserializing as PostgresSource if feature enabled
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
|
||||
if let Ok(t) = third_try {
|
||||
return Ok(DataSource::DB(t));
|
||||
}
|
||||
}
|
||||
|
||||
Err(serde::de::Error::custom("failed to deserialize DataSource"))
|
||||
}
|
||||
}
|
||||
@@ -442,9 +425,7 @@ impl GraphData {
|
||||
pub fn from_str(data: &str) -> Result<Self, GraphError> {
|
||||
let graph_input = serde_json::from_str(data);
|
||||
match graph_input {
|
||||
Ok(graph_input) => {
|
||||
return Ok(graph_input);
|
||||
}
|
||||
Ok(graph_input) => Ok(graph_input),
|
||||
Err(_) => {
|
||||
let path = std::path::PathBuf::from(data);
|
||||
GraphData::from_path(path)
|
||||
@@ -517,11 +498,6 @@ impl GraphData {
|
||||
"on-chain data cannot be split into batches".to_string(),
|
||||
));
|
||||
}
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
GraphData {
|
||||
input_data: DataSource::DB(data),
|
||||
output_data: _,
|
||||
} => data.fetch_and_format_as_file().await?,
|
||||
};
|
||||
|
||||
// Process each input tensor according to its shape
|
||||
@@ -591,28 +567,6 @@ impl GraphData {
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_postgres_source_new() {
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
{
|
||||
let source = PostgresSource::new(
|
||||
"localhost".to_string(),
|
||||
"5432".to_string(),
|
||||
"user".to_string(),
|
||||
"SELECT * FROM table".to_string(),
|
||||
"database".to_string(),
|
||||
"password".to_string(),
|
||||
);
|
||||
|
||||
assert_eq!(source.host, "localhost");
|
||||
assert_eq!(source.port, "5432");
|
||||
assert_eq!(source.user, "user");
|
||||
assert_eq!(source.query, "SELECT * FROM table");
|
||||
assert_eq!(source.dbname, "database");
|
||||
assert_eq!(source.password, "password");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_data_source_serialization_round_trip() {
|
||||
// Test backwards compatibility with old format
|
||||
@@ -655,95 +609,6 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
/// Source data from a PostgreSQL database
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
|
||||
pub struct PostgresSource {
|
||||
/// Database host address
|
||||
pub host: RPCUrl,
|
||||
/// Database user name
|
||||
pub user: String,
|
||||
/// Database password
|
||||
pub password: String,
|
||||
/// SQL query to execute
|
||||
pub query: String,
|
||||
/// Database name
|
||||
pub dbname: String,
|
||||
/// Database port
|
||||
pub port: String,
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
impl PostgresSource {
|
||||
/// Creates a new PostgreSQL data source
|
||||
pub fn new(
|
||||
host: RPCUrl,
|
||||
port: String,
|
||||
user: String,
|
||||
query: String,
|
||||
dbname: String,
|
||||
password: String,
|
||||
) -> Self {
|
||||
PostgresSource {
|
||||
host,
|
||||
user,
|
||||
password,
|
||||
query,
|
||||
dbname,
|
||||
port,
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetches data from the PostgreSQL database
|
||||
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
|
||||
// Configuration string
|
||||
let config = if self.password.is_empty() {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={}",
|
||||
self.host, self.user, self.dbname, self.port
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"host={} user={} dbname={} port={} password={}",
|
||||
self.host, self.user, self.dbname, self.port, self.password
|
||||
)
|
||||
};
|
||||
|
||||
let mut client = Client::connect(&config).await?;
|
||||
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
|
||||
|
||||
// Extract rows from query
|
||||
for row in client.query(&self.query, &[]).await? {
|
||||
for i in 0..row.len() {
|
||||
res.push(row.get(i));
|
||||
}
|
||||
}
|
||||
Ok(vec![res])
|
||||
}
|
||||
|
||||
/// Fetches and formats data as FileSource
|
||||
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
|
||||
Ok(self
|
||||
.fetch()
|
||||
.await?
|
||||
.iter()
|
||||
.map(|d| {
|
||||
d.iter()
|
||||
.map(|d| {
|
||||
FileSourceInner::Float(
|
||||
d.n.as_ref()
|
||||
.unwrap()
|
||||
.to_f64()
|
||||
.ok_or("could not convert decimal to f64")
|
||||
.unwrap(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl ToPyObject for CallToAccount {
|
||||
fn to_object(&self, py: Python) -> PyObject {
|
||||
@@ -767,14 +632,6 @@ impl ToPyObject for DataSource {
|
||||
.unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
DataSource::DB(source) => {
|
||||
let dict = PyDict::new(py);
|
||||
dict.set_item("host", &source.host).unwrap();
|
||||
dict.set_item("user", &source.user).unwrap();
|
||||
dict.set_item("query", &source.query).unwrap();
|
||||
dict.to_object(py)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,9 +6,6 @@ pub mod model;
|
||||
pub mod modules;
|
||||
/// Inner elements of a computational graph that represent a single operation / constraints.
|
||||
pub mod node;
|
||||
/// postgres helper functions
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
pub mod postgres;
|
||||
/// Helper functions
|
||||
pub mod utilities;
|
||||
/// Representations of a computational graph's variables.
|
||||
@@ -36,12 +33,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::region::{ConstantsMap, RegionSettings};
|
||||
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
|
||||
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::{IntegerRep, felt_to_f64};
|
||||
use crate::fieldutils::{felt_to_f64, IntegerRep};
|
||||
use crate::pfsys::PrettyElements;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use crate::{EZKL_BUF_CAPACITY, RunArgs};
|
||||
use crate::{RunArgs, EZKL_BUF_CAPACITY};
|
||||
|
||||
use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
@@ -56,13 +53,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
|
||||
pub use model::*;
|
||||
pub use node::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::prelude::*;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDict;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::types::PyDictMethods;
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::ToPyObject;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::ops::Deref;
|
||||
@@ -574,7 +571,7 @@ impl GraphSettings {
|
||||
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
|
||||
serde_json::to_writer(writer, &self).map_err(|e| {
|
||||
error!("failed to save settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
std::io::Error::other(e)
|
||||
})
|
||||
}
|
||||
/// load params from file
|
||||
@@ -584,7 +581,7 @@ impl GraphSettings {
|
||||
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
|
||||
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
|
||||
error!("failed to load settings file at {}", e);
|
||||
std::io::Error::new(std::io::ErrorKind::Other, e)
|
||||
std::io::Error::other(e)
|
||||
})?;
|
||||
|
||||
crate::check_version_string_matches(&settings.version);
|
||||
@@ -1011,10 +1008,6 @@ impl GraphCircuit {
|
||||
DataSource::File(file_data) => {
|
||||
self.load_file_data(file_data, &shapes, scales, input_types)
|
||||
}
|
||||
DataSource::DB(pg) => {
|
||||
let data = pg.fetch_and_format_as_file().await?;
|
||||
self.load_file_data(&data, &shapes, scales, input_types)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1239,15 +1232,9 @@ impl GraphCircuit {
|
||||
let mut cs = ConstraintSystem::default();
|
||||
// if unix get a gag
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
let _r = match Gag::stdout() {
|
||||
Ok(g) => Some(g),
|
||||
_ => None,
|
||||
};
|
||||
let _r = Gag::stdout().ok();
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
let _g = match Gag::stderr() {
|
||||
Ok(g) => Some(g),
|
||||
_ => None,
|
||||
};
|
||||
let _g = Gag::stderr().ok();
|
||||
|
||||
Self::configure_with_params(&mut cs, settings);
|
||||
|
||||
@@ -1583,13 +1570,13 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
|
||||
let mut module_configs = ModuleConfigs::from_visibility(
|
||||
cs,
|
||||
params.module_sizes.clone(),
|
||||
¶ms.module_sizes,
|
||||
params.run_args.logrows as usize,
|
||||
);
|
||||
|
||||
let mut vars = ModelVars::new(cs, ¶ms);
|
||||
|
||||
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
|
||||
module_configs.configure_complex_modules(cs, &visibility, ¶ms.module_sizes);
|
||||
|
||||
vars.instantiate_instance(
|
||||
cs,
|
||||
|
||||
@@ -200,7 +200,7 @@ fn number_of_iterations(mappings: &[InputMapping], dims: Vec<&[usize]>) -> usize
|
||||
InputMapping::Stacked { axis, chunk } => Some(
|
||||
// number of iterations given the dim size along the axis
|
||||
// and the chunk size
|
||||
(dims[*axis] + chunk - 1) / chunk,
|
||||
dims[*axis].div_ceil(*chunk), // (dims[*axis] + chunk - 1) / chunk,
|
||||
),
|
||||
_ => None,
|
||||
});
|
||||
@@ -589,10 +589,7 @@ impl Model {
|
||||
required_range_checks: res.range_checks.into_iter().collect(),
|
||||
model_output_scales: self.graph.get_output_scales()?,
|
||||
model_input_scales: self.graph.get_input_scales(),
|
||||
input_types: match self.get_input_types() {
|
||||
Ok(x) => Some(x),
|
||||
Err(_) => None,
|
||||
},
|
||||
input_types: self.get_input_types().ok(),
|
||||
output_types: Some(self.get_output_types()),
|
||||
num_dynamic_lookups: res.num_dynamic_lookups,
|
||||
total_dynamic_col_size: res.dynamic_lookup_col_coord,
|
||||
@@ -650,10 +647,13 @@ impl Model {
|
||||
let variables: std::collections::HashMap<String, usize> =
|
||||
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
|
||||
|
||||
for (i, id) in model.clone().inputs.iter().enumerate() {
|
||||
let inputs = model.inputs.clone();
|
||||
let outputs = model.outputs.clone();
|
||||
|
||||
for (i, id) in inputs.iter().enumerate() {
|
||||
let input = model.node_mut(id.node);
|
||||
|
||||
if input.outputs.len() == 0 {
|
||||
if input.outputs.is_empty() {
|
||||
return Err(GraphError::MissingOutput(id.node));
|
||||
}
|
||||
let mut fact: InferenceFact = input.outputs[0].fact.clone();
|
||||
@@ -672,7 +672,7 @@ impl Model {
|
||||
model.set_input_fact(i, fact)?;
|
||||
}
|
||||
|
||||
for (i, _) in model.clone().outputs.iter().enumerate() {
|
||||
for (i, _) in outputs.iter().enumerate() {
|
||||
model.set_output_fact(i, InferenceFact::default())?;
|
||||
}
|
||||
|
||||
@@ -1196,7 +1196,7 @@ impl Model {
|
||||
.base
|
||||
.layout(
|
||||
&mut thread_safe_region,
|
||||
&[output.clone(), comparators],
|
||||
&[output, &comparators],
|
||||
Box::new(HybridOp::Output {
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}),
|
||||
@@ -1257,12 +1257,27 @@ impl Model {
|
||||
node.inputs()
|
||||
.iter()
|
||||
.map(|(idx, outlet)| {
|
||||
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
|
||||
// check node is not an output
|
||||
let is_output = self.graph.outputs.iter().any(|(o_idx, _)| *idx == *o_idx);
|
||||
|
||||
let res = if self.graph.nodes[idx].num_uses() == 1 && !is_output {
|
||||
let res = results.remove(idx);
|
||||
res.ok_or(GraphError::MissingResults(*idx))?[*outlet].clone()
|
||||
} else {
|
||||
results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet]
|
||||
.clone()
|
||||
};
|
||||
Ok(res)
|
||||
})
|
||||
.collect::<Result<Vec<_>, GraphError>>()?
|
||||
} else {
|
||||
// we re-assign inputs, always from the 0 outlet
|
||||
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
|
||||
if self.graph.nodes[idx].num_uses() == 1 {
|
||||
let res = results.remove(idx);
|
||||
vec![res.ok_or(GraphError::MissingInput(*idx))?[0].clone()]
|
||||
} else {
|
||||
vec![results.get(idx).ok_or(GraphError::MissingInput(*idx))?[0].clone()]
|
||||
}
|
||||
};
|
||||
trace!("output dims: {:?}", node.out_dims());
|
||||
trace!(
|
||||
@@ -1273,7 +1288,7 @@ impl Model {
|
||||
let start = instant::Instant::now();
|
||||
match &node {
|
||||
NodeType::Node(n) => {
|
||||
let res = if node.is_constant() && node.num_uses() == 1 {
|
||||
let mut res = if node.is_constant() && node.num_uses() == 1 {
|
||||
log::debug!("node {} is a constant with 1 use", n.idx);
|
||||
let mut node = n.clone();
|
||||
let c = node
|
||||
@@ -1284,19 +1299,19 @@ impl Model {
|
||||
} else {
|
||||
config
|
||||
.base
|
||||
.layout(region, &values, n.opkind.clone_dyn())
|
||||
.layout(region, &values.iter().collect_vec(), n.opkind.clone_dyn())
|
||||
.map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?
|
||||
};
|
||||
|
||||
if let Some(mut vt) = res {
|
||||
if let Some(vt) = &mut res {
|
||||
vt.reshape(&node.out_dims()[0])?;
|
||||
// we get the max as for fused nodes this corresponds to the node output
|
||||
results.insert(*idx, vec![vt.clone()]);
|
||||
//only use with mock prover
|
||||
debug!("------------ output node {:?}: {:?}", idx, vt.show());
|
||||
// we get the max as for fused nodes this corresponds to the node output
|
||||
results.insert(*idx, vec![vt.clone()]);
|
||||
}
|
||||
}
|
||||
NodeType::SubGraph {
|
||||
@@ -1340,7 +1355,7 @@ impl Model {
|
||||
.inputs
|
||||
.clone()
|
||||
.into_iter()
|
||||
.zip(values.clone().into_iter().map(|v| vec![v])),
|
||||
.zip(values.iter().map(|v| vec![v.clone()])),
|
||||
);
|
||||
|
||||
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
|
||||
@@ -1421,7 +1436,7 @@ impl Model {
|
||||
);
|
||||
let outputs = output_nodes
|
||||
.map(|(idx, outlet)| {
|
||||
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
|
||||
Ok(results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet].clone())
|
||||
})
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
@@ -1476,7 +1491,7 @@ impl Model {
|
||||
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[output.clone(), comparator],
|
||||
&[output, &comparator],
|
||||
Box::new(HybridOp::Output {
|
||||
decomp: !run_args.ignore_range_check_inputs_outputs,
|
||||
}),
|
||||
|
||||
@@ -37,15 +37,15 @@ impl ModuleConfigs {
|
||||
/// Create new module configs from visibility of each variable
|
||||
pub fn from_visibility(
|
||||
cs: &mut ConstraintSystem<Fp>,
|
||||
module_size: ModuleSizes,
|
||||
module_size: &ModuleSizes,
|
||||
logrows: usize,
|
||||
) -> Self {
|
||||
let mut config = Self::default();
|
||||
|
||||
for size in module_size.polycommit {
|
||||
for size in &module_size.polycommit {
|
||||
config
|
||||
.polycommit
|
||||
.push(PolyCommitChip::configure(cs, (logrows, size)));
|
||||
.push(PolyCommitChip::configure(cs, (logrows, *size)));
|
||||
}
|
||||
|
||||
config
|
||||
@@ -55,8 +55,8 @@ impl ModuleConfigs {
|
||||
pub fn configure_complex_modules(
|
||||
&mut self,
|
||||
cs: &mut ConstraintSystem<Fp>,
|
||||
visibility: VarVisibility,
|
||||
module_size: ModuleSizes,
|
||||
visibility: &VarVisibility,
|
||||
module_size: &ModuleSizes,
|
||||
) {
|
||||
if (visibility.input.is_hashed()
|
||||
|| visibility.output.is_hashed()
|
||||
|
||||
@@ -37,6 +37,7 @@ use crate::tensor::TensorError;
|
||||
// Import curve-specific field type
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
|
||||
use itertools::Itertools;
|
||||
// Import logging for EZKL
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use log::trace;
|
||||
@@ -118,16 +119,15 @@ impl Op<Fp> for Rescaled {
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
values: &[&crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
if self.scale.len() != values.len() {
|
||||
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
|
||||
}
|
||||
|
||||
let res =
|
||||
&crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?
|
||||
[..];
|
||||
self.inner.layout(config, region, res)
|
||||
crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?;
|
||||
self.inner.layout(config, region, &res.iter().collect_vec())
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
@@ -274,13 +274,13 @@ impl Op<Fp> for RebaseScale {
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
values: &[&crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
let original_res = self
|
||||
.inner
|
||||
.layout(config, region, values)?
|
||||
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
|
||||
self.rebase_op.layout(config, region, &[original_res])
|
||||
self.rebase_op.layout(config, region, &[&original_res])
|
||||
}
|
||||
|
||||
/// Create a cloned boxed copy of this operation
|
||||
@@ -472,7 +472,7 @@ impl Op<Fp> for SupportedOp {
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
values: &[&crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
|
||||
self.as_op().layout(config, region, values)
|
||||
}
|
||||
|
||||
@@ -1,493 +0,0 @@
|
||||
use log::{debug, error, info};
|
||||
use std::fmt::Debug;
|
||||
use std::net::IpAddr;
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::{fmt, pin::Pin};
|
||||
use tokio::task::JoinHandle;
|
||||
#[doc(inline)]
|
||||
pub use tokio_postgres::config::{
|
||||
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
|
||||
};
|
||||
use tokio_postgres::tls::NoTlsStream;
|
||||
use tokio_postgres::NoTls;
|
||||
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
|
||||
|
||||
/// Connection configuration.
|
||||
///
|
||||
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
|
||||
///
|
||||
///
|
||||
#[derive(Clone)]
|
||||
pub struct Config {
|
||||
config: tokio_postgres::Config,
|
||||
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
|
||||
}
|
||||
|
||||
impl fmt::Debug for Config {
|
||||
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fmt.debug_struct("Config")
|
||||
.field("config", &self.config)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Config {
|
||||
Config::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
/// Creates a new configuration.
|
||||
pub fn new() -> Config {
|
||||
tokio_postgres::Config::new().into()
|
||||
}
|
||||
|
||||
/// Sets the user to authenticate with.
|
||||
///
|
||||
/// If the user is not set, then this defaults to the user executing this process.
|
||||
pub fn user(&mut self, user: &str) -> &mut Config {
|
||||
self.config.user(user);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the user to authenticate with, if one has been configured with
|
||||
/// the `user` method.
|
||||
pub fn get_user(&self) -> Option<&str> {
|
||||
self.config.get_user()
|
||||
}
|
||||
|
||||
/// Sets the password to authenticate with.
|
||||
pub fn password<T>(&mut self, password: T) -> &mut Config
|
||||
where
|
||||
T: AsRef<[u8]>,
|
||||
{
|
||||
self.config.password(password);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the password to authenticate with, if one has been configured with
|
||||
/// the `password` method.
|
||||
pub fn get_password(&self) -> Option<&[u8]> {
|
||||
self.config.get_password()
|
||||
}
|
||||
|
||||
/// Sets the name of the database to connect to.
|
||||
///
|
||||
/// Defaults to the user.
|
||||
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
|
||||
self.config.dbname(dbname);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the name of the database to connect to, if one has been configured
|
||||
/// with the `dbname` method.
|
||||
pub fn get_dbname(&self) -> Option<&str> {
|
||||
self.config.get_dbname()
|
||||
}
|
||||
|
||||
/// Sets command line options used to configure the server.
|
||||
pub fn options(&mut self, options: &str) -> &mut Config {
|
||||
self.config.options(options);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the command line options used to configure the server, if the
|
||||
/// options have been set with the `options` method.
|
||||
pub fn get_options(&self) -> Option<&str> {
|
||||
self.config.get_options()
|
||||
}
|
||||
|
||||
/// Sets the value of the `application_name` runtime parameter.
|
||||
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
|
||||
self.config.application_name(application_name);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the value of the `application_name` runtime parameter, if it has
|
||||
/// been set with the `application_name` method.
|
||||
pub fn get_application_name(&self) -> Option<&str> {
|
||||
self.config.get_application_name()
|
||||
}
|
||||
|
||||
/// Sets the SSL configuration.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
|
||||
self.config.ssl_mode(ssl_mode);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the SSL configuration.
|
||||
pub fn get_ssl_mode(&self) -> SslMode {
|
||||
self.config.get_ssl_mode()
|
||||
}
|
||||
|
||||
/// Adds a host to the configuration.
|
||||
///
|
||||
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
|
||||
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
|
||||
/// There must be either no hosts, or the same number of hosts as hostaddrs.
|
||||
pub fn host(&mut self, host: &str) -> &mut Config {
|
||||
self.config.host(host);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the hosts that have been added to the configuration with `host`.
|
||||
pub fn get_hosts(&self) -> &[Host] {
|
||||
self.config.get_hosts()
|
||||
}
|
||||
|
||||
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
|
||||
pub fn get_hostaddrs(&self) -> &[IpAddr] {
|
||||
self.config.get_hostaddrs()
|
||||
}
|
||||
|
||||
/// Adds a Unix socket host to the configuration.
|
||||
///
|
||||
/// Unlike `host`, this method allows non-UTF8 paths.
|
||||
#[cfg(all(not(not(feature = "ezkl")), unix))]
|
||||
pub fn host_path<T>(&mut self, host: T) -> &mut Config
|
||||
where
|
||||
T: AsRef<Path>,
|
||||
{
|
||||
self.config.host_path(host);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a hostaddr to the configuration.
|
||||
///
|
||||
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
|
||||
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
|
||||
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
|
||||
self.config.hostaddr(hostaddr);
|
||||
self
|
||||
}
|
||||
|
||||
/// Adds a port to the configuration.
|
||||
///
|
||||
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
|
||||
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
|
||||
/// as hosts.
|
||||
pub fn port(&mut self, port: u16) -> &mut Config {
|
||||
self.config.port(port);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the ports that have been added to the configuration with `port`.
|
||||
pub fn get_ports(&self) -> &[u16] {
|
||||
self.config.get_ports()
|
||||
}
|
||||
|
||||
/// Sets the timeout applied to socket-level connection attempts.
|
||||
///
|
||||
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
|
||||
/// host separately. Defaults to no limit.
|
||||
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
|
||||
self.config.connect_timeout(connect_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the connection timeout, if one has been set with the
|
||||
/// `connect_timeout` method.
|
||||
pub fn get_connect_timeout(&self) -> Option<&Duration> {
|
||||
self.config.get_connect_timeout()
|
||||
}
|
||||
|
||||
/// Sets the TCP user timeout.
|
||||
///
|
||||
/// This is ignored for Unix domain socket connections. It is only supported on systems where
|
||||
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
|
||||
/// on other systems, it has no effect.
|
||||
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
|
||||
self.config.tcp_user_timeout(tcp_user_timeout);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the TCP user timeout, if one has been set with the
|
||||
/// `user_timeout` method.
|
||||
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
|
||||
self.config.get_tcp_user_timeout()
|
||||
}
|
||||
|
||||
/// Controls the use of TCP keepalive.
|
||||
///
|
||||
/// This is ignored for Unix domain socket connections. Defaults to `true`.
|
||||
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
|
||||
self.config.keepalives(keepalives);
|
||||
self
|
||||
}
|
||||
|
||||
/// Reports whether TCP keepalives will be used.
|
||||
pub fn get_keepalives(&self) -> bool {
|
||||
self.config.get_keepalives()
|
||||
}
|
||||
|
||||
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
|
||||
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
|
||||
self.config.keepalives_idle(keepalives_idle);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the configured amount of idle time before a keepalive packet will
|
||||
/// be sent on the connection.
|
||||
pub fn get_keepalives_idle(&self) -> Duration {
|
||||
self.config.get_keepalives_idle()
|
||||
}
|
||||
|
||||
/// Sets the time interval between TCP keepalive probes.
|
||||
/// On Windows, this sets the value of the tcp_keepalive struct’s keepaliveinterval field.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
|
||||
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
|
||||
self.config.keepalives_interval(keepalives_interval);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the time interval between TCP keepalive probes.
|
||||
pub fn get_keepalives_interval(&self) -> Option<Duration> {
|
||||
self.config.get_keepalives_interval()
|
||||
}
|
||||
|
||||
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
|
||||
///
|
||||
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
|
||||
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
|
||||
self.config.keepalives_retries(keepalives_retries);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
|
||||
pub fn get_keepalives_retries(&self) -> Option<u32> {
|
||||
self.config.get_keepalives_retries()
|
||||
}
|
||||
|
||||
/// Sets the requirements of the session.
|
||||
///
|
||||
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
|
||||
/// secondary servers. Defaults to `Any`.
|
||||
pub fn target_session_attrs(
|
||||
&mut self,
|
||||
target_session_attrs: TargetSessionAttrs,
|
||||
) -> &mut Config {
|
||||
self.config.target_session_attrs(target_session_attrs);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the requirements of the session.
|
||||
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
|
||||
self.config.get_target_session_attrs()
|
||||
}
|
||||
|
||||
/// Sets the channel binding behavior.
|
||||
///
|
||||
/// Defaults to `prefer`.
|
||||
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
|
||||
self.config.channel_binding(channel_binding);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the channel binding behavior.
|
||||
pub fn get_channel_binding(&self) -> ChannelBinding {
|
||||
self.config.get_channel_binding()
|
||||
}
|
||||
|
||||
/// Sets the host load balancing behavior.
|
||||
///
|
||||
/// Defaults to `disable`.
|
||||
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
|
||||
self.config.load_balance_hosts(load_balance_hosts);
|
||||
self
|
||||
}
|
||||
|
||||
/// Gets the host load balancing behavior.
|
||||
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
|
||||
self.config.get_load_balance_hosts()
|
||||
}
|
||||
|
||||
/// Sets the notice callback.
|
||||
///
|
||||
/// This callback will be invoked with the contents of every
|
||||
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
|
||||
/// the same structure as errors, but they are not "errors" per-se.
|
||||
///
|
||||
/// Notices are distinct from notifications, which are instead accessible
|
||||
/// via the [`Notifications`] API.
|
||||
///
|
||||
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
|
||||
/// [`Notifications`]: crate::Notifications
|
||||
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
|
||||
where
|
||||
F: Fn(DbError) + Send + Sync + 'static,
|
||||
{
|
||||
self.notice_callback = Arc::new(f);
|
||||
self
|
||||
}
|
||||
|
||||
/// Opens a connection to a PostgreSQL database.
|
||||
pub async fn connect(&self) -> Result<Client, Error> {
|
||||
let (client, connection) = self.config.connect(NoTls).await?;
|
||||
|
||||
let connection = Connection::new(connection);
|
||||
|
||||
Ok(Client::new(client, connection))
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for Config {
|
||||
type Err = Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Config, Error> {
|
||||
s.parse::<tokio_postgres::Config>().map(Config::from)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<tokio_postgres::Config> for Config {
|
||||
fn from(config: tokio_postgres::Config) -> Config {
|
||||
Config {
|
||||
config,
|
||||
notice_callback: Arc::new(|notice| {
|
||||
info!("{}: {}", notice.severity(), notice.message())
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_debug_implementations, dead_code)]
|
||||
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
|
||||
/// get dropped.
|
||||
pub struct Connection {
|
||||
/// The underlying connection stream.
|
||||
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
/// Creates a new connection.
|
||||
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
|
||||
Connection {
|
||||
connection: Box::pin(connection),
|
||||
}
|
||||
}
|
||||
|
||||
/// start the connection
|
||||
pub async fn start(self) {
|
||||
if let Err(e) = self.connection.await {
|
||||
error!("connection error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_debug_implementations, dead_code)]
|
||||
/// An asynchronous PostgreSQL client.
|
||||
pub struct Client {
|
||||
connection: JoinHandle<()>,
|
||||
client: tokio_postgres::Client,
|
||||
}
|
||||
|
||||
impl Drop for Client {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.close_inner();
|
||||
}
|
||||
}
|
||||
|
||||
impl Client {
|
||||
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
|
||||
// The connection object performs the actual communication with the database,
|
||||
// so spawn it off to run on its own.
|
||||
let thread = tokio::spawn(async move {
|
||||
connection.start().await;
|
||||
});
|
||||
|
||||
Client {
|
||||
client,
|
||||
connection: thread,
|
||||
}
|
||||
}
|
||||
|
||||
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
|
||||
///
|
||||
/// See the documentation for [`Config`] for information about the connection syntax.
|
||||
///
|
||||
/// [`Config`]: config/struct.Config.html
|
||||
pub async fn connect(params: &str) -> Result<Client, Error> {
|
||||
debug!("Connecting to database with params: {}", params);
|
||||
params.parse::<Config>()?.connect().await
|
||||
}
|
||||
|
||||
/// Returns a new `Config` object which can be used to configure and connect to a database.
|
||||
pub fn configure() -> Config {
|
||||
Config::new()
|
||||
}
|
||||
|
||||
/// Executes a statement, returning the number of rows modified.
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
|
||||
///
|
||||
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
pub async fn execute<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<u64, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement + Debug,
|
||||
{
|
||||
debug!("Executing query: {:?}", query);
|
||||
self.client.execute(query, params).await
|
||||
}
|
||||
|
||||
/// Executes a statement, returning the resulting rows.
|
||||
///
|
||||
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
|
||||
/// provided, 1-indexed.
|
||||
///
|
||||
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
|
||||
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
|
||||
/// with the `prepare` method.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
pub async fn query<T>(
|
||||
&mut self,
|
||||
query: &T,
|
||||
params: &[&(dyn ToSql + Sync)],
|
||||
) -> Result<Vec<Row>, Error>
|
||||
where
|
||||
T: ?Sized + ToStatement + Debug,
|
||||
{
|
||||
debug!("Executing query: {:?}", query);
|
||||
self.client.query(query, params).await
|
||||
}
|
||||
|
||||
/// Determines if the client's connection has already closed.
|
||||
///
|
||||
/// If this returns `true`, the client is no longer usable.
|
||||
pub fn is_closed(&self) -> bool {
|
||||
self.client.is_closed()
|
||||
}
|
||||
|
||||
/// Closes the client's connection to the server.
|
||||
///
|
||||
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
|
||||
/// caller.
|
||||
pub fn close(mut self) -> Result<(), Error> {
|
||||
self.close_inner()
|
||||
}
|
||||
|
||||
fn close_inner(&mut self) -> Result<(), Error> {
|
||||
self.client.__private_api_close();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -1,14 +1,14 @@
|
||||
use super::errors::GraphError;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use super::VarScales;
|
||||
use super::errors::GraphError;
|
||||
use super::{Rescaled, SupportedOp, Visibility};
|
||||
use crate::circuit::Op;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
@@ -22,7 +22,6 @@ use std::sync::Arc;
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_core::ops::{
|
||||
Downsample,
|
||||
array::{
|
||||
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
|
||||
Slice, Topk,
|
||||
@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
|
||||
einsum::EinSum,
|
||||
element_wise::ElementWiseOp,
|
||||
nn::{LeakyRelu, Reduce, Softmax},
|
||||
Downsample,
|
||||
};
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use tract_onnx::tract_hir::{
|
||||
@@ -858,6 +858,7 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Hybrid(HybridOp::Recip {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -903,6 +904,7 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Hybrid(HybridOp::Rsqrt {
|
||||
input_scale: (scale_to_multiplier(in_scale) as f32).into(),
|
||||
output_scale: (scale_to_multiplier(max_scale) as f32).into(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
}
|
||||
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp {
|
||||
@@ -913,6 +915,7 @@ pub fn new_op_from_onnx(
|
||||
if run_args.bounded_log_lookup {
|
||||
SupportedOp::Hybrid(HybridOp::Ln {
|
||||
scale: scale_to_multiplier(input_scales[0]).into(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
} else {
|
||||
SupportedOp::Nonlinear(LookupOp::Ln {
|
||||
@@ -1131,6 +1134,7 @@ pub fn new_op_from_onnx(
|
||||
input_scale: scale_to_multiplier(in_scale).into(),
|
||||
output_scale: scale_to_multiplier(max_scale).into(),
|
||||
axes: softmax_op.axes.to_vec(),
|
||||
eps: run_args.get_epsilon(),
|
||||
})
|
||||
}
|
||||
"MaxPool" => {
|
||||
|
||||
21
src/lib.rs
21
src/lib.rs
@@ -29,8 +29,14 @@
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
use log::warn;
|
||||
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
|
||||
use mimalloc as _;
|
||||
|
||||
#[global_allocator]
|
||||
#[cfg(all(feature = "jemalloc", not(target_arch = "wasm32")))]
|
||||
static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;
|
||||
|
||||
#[global_allocator]
|
||||
#[cfg(all(feature = "mimalloc", not(target_arch = "wasm32")))]
|
||||
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
|
||||
|
||||
/// Error type
|
||||
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
|
||||
@@ -350,6 +356,16 @@ pub struct RunArgs {
|
||||
arg(long, default_value = "false")
|
||||
)]
|
||||
pub ignore_range_check_inputs_outputs: bool,
|
||||
/// Optional override for epsilon value
|
||||
#[cfg_attr(all(feature = "ezkl", not(target_arch = "wasm32")), arg(long))]
|
||||
pub epsilon: Option<f64>,
|
||||
}
|
||||
|
||||
impl RunArgs {
|
||||
/// Returns the epsilon value
|
||||
pub fn get_epsilon(&self) -> f64 {
|
||||
self.epsilon.unwrap_or(f64::EPSILON)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
@@ -376,6 +392,7 @@ impl Default for RunArgs {
|
||||
decomp_base: 16384,
|
||||
decomp_legs: 2,
|
||||
ignore_range_check_inputs_outputs: false,
|
||||
epsilon: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use clap::ValueEnum;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
|
||||
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
|
||||
use halo2curves::CurveAffine;
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
|
||||
use halo2curves::serde::SerdeObject;
|
||||
use halo2curves::CurveAffine;
|
||||
use instant::Instant;
|
||||
use log::{debug, info, trace};
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
@@ -324,7 +324,7 @@ where
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
|
||||
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
|
||||
#[cfg(feature = "python-bindings")]
|
||||
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
|
||||
where
|
||||
@@ -348,9 +348,9 @@ where
|
||||
}
|
||||
|
||||
impl<
|
||||
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
|
||||
C: CurveAffine + Serialize + DeserializeOwned,
|
||||
> Snark<F, C>
|
||||
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
|
||||
C: CurveAffine + Serialize + DeserializeOwned,
|
||||
> Snark<F, C>
|
||||
where
|
||||
C::Scalar: Serialize + DeserializeOwned,
|
||||
C::ScalarExt: Serialize + DeserializeOwned,
|
||||
|
||||
@@ -27,7 +27,7 @@ pub use var::*;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{IntegerRep, integer_rep_to_felt},
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -42,11 +42,11 @@ use std::error::Error;
|
||||
use std::fmt::Debug;
|
||||
use std::io::Read;
|
||||
use std::iter::Iterator;
|
||||
use std::ops::Rem;
|
||||
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
|
||||
use std::{cmp::max, ops::Rem};
|
||||
|
||||
/// The (inner) type of tensor elements.
|
||||
pub trait TensorType: Clone + Debug + 'static {
|
||||
pub trait TensorType: Clone + Debug {
|
||||
/// Returns the zero value.
|
||||
fn zero() -> Option<Self> {
|
||||
None
|
||||
@@ -55,10 +55,6 @@ pub trait TensorType: Clone + Debug + 'static {
|
||||
fn one() -> Option<Self> {
|
||||
None
|
||||
}
|
||||
/// Max operator for ordering values.
|
||||
fn tmax(&self, _: &Self) -> Option<Self> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! tensor_type {
|
||||
@@ -70,10 +66,6 @@ macro_rules! tensor_type {
|
||||
fn one() -> Option<Self> {
|
||||
Some($one)
|
||||
}
|
||||
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
Some(max(*self, *other))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
@@ -82,46 +74,12 @@ impl TensorType for f32 {
|
||||
fn zero() -> Option<Self> {
|
||||
Some(0.0)
|
||||
}
|
||||
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
|
||||
// A comparison between f32s needs to handle NAN values.
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
match (self.is_nan(), other.is_nan()) {
|
||||
(true, true) => Some(f32::NAN),
|
||||
(true, false) => Some(*other),
|
||||
(false, true) => Some(*self),
|
||||
(false, false) => {
|
||||
if self >= other {
|
||||
Some(*self)
|
||||
} else {
|
||||
Some(*other)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorType for f64 {
|
||||
fn zero() -> Option<Self> {
|
||||
Some(0.0)
|
||||
}
|
||||
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
|
||||
// A comparison between f32s needs to handle NAN values.
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
match (self.is_nan(), other.is_nan()) {
|
||||
(true, true) => Some(f64::NAN),
|
||||
(true, false) => Some(*other),
|
||||
(false, true) => Some(*self),
|
||||
(false, false) => {
|
||||
if self >= other {
|
||||
Some(*self)
|
||||
} else {
|
||||
Some(*other)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor_type!(bool, Bool, false, true);
|
||||
@@ -147,14 +105,6 @@ impl<T: TensorType> TensorType for Value<T> {
|
||||
fn one() -> Option<Self> {
|
||||
Some(Value::known(T::one().unwrap()))
|
||||
}
|
||||
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
Some(
|
||||
(self.clone())
|
||||
.zip(other.clone())
|
||||
.map(|(a, b)| a.tmax(&b).unwrap()),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + PartialOrd> TensorType for Assigned<F>
|
||||
@@ -168,14 +118,6 @@ where
|
||||
fn one() -> Option<Self> {
|
||||
Some(F::ONE.into())
|
||||
}
|
||||
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
if self.evaluate() >= other.evaluate() {
|
||||
Some(*self)
|
||||
} else {
|
||||
Some(*other)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField> TensorType for Expression<F>
|
||||
@@ -189,42 +131,14 @@ where
|
||||
fn one() -> Option<Self> {
|
||||
Some(Expression::Constant(F::ONE))
|
||||
}
|
||||
|
||||
fn tmax(&self, _: &Self) -> Option<Self> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorType for Column<Advice> {}
|
||||
impl TensorType for Column<Fixed> {}
|
||||
|
||||
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
let mut output: Option<Self> = None;
|
||||
self.value_field().zip(other.value_field()).map(|(a, b)| {
|
||||
if a.evaluate() >= b.evaluate() {
|
||||
output = Some(self.clone());
|
||||
} else {
|
||||
output = Some(other.clone());
|
||||
}
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {}
|
||||
|
||||
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
let mut output: Option<Self> = None;
|
||||
self.value().zip(other.value()).map(|(a, b)| {
|
||||
if a >= b {
|
||||
output = Some(self.clone());
|
||||
} else {
|
||||
output = Some(other.clone());
|
||||
}
|
||||
});
|
||||
output
|
||||
}
|
||||
}
|
||||
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {}
|
||||
|
||||
// specific types
|
||||
impl TensorType for halo2curves::pasta::Fp {
|
||||
@@ -235,10 +149,6 @@ impl TensorType for halo2curves::pasta::Fp {
|
||||
fn one() -> Option<Self> {
|
||||
Some(halo2curves::pasta::Fp::one())
|
||||
}
|
||||
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
Some((*self).max(*other))
|
||||
}
|
||||
}
|
||||
|
||||
impl TensorType for halo2curves::bn256::Fr {
|
||||
@@ -249,9 +159,15 @@ impl TensorType for halo2curves::bn256::Fr {
|
||||
fn one() -> Option<Self> {
|
||||
Some(halo2curves::bn256::Fr::one())
|
||||
}
|
||||
}
|
||||
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
Some((*self).max(*other))
|
||||
impl<F: TensorType> TensorType for &F {
|
||||
fn zero() -> Option<Self> {
|
||||
None
|
||||
}
|
||||
|
||||
fn one() -> Option<Self> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -374,7 +290,7 @@ impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
|
||||
}
|
||||
}
|
||||
|
||||
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
|
||||
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync + 'data>
|
||||
maybe_rayon::iter::IntoParallelRefMutIterator<'data> for Tensor<T>
|
||||
{
|
||||
type Iter = maybe_rayon::slice::IterMut<'data, T>;
|
||||
@@ -423,6 +339,14 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType> Tensor<&T> {
|
||||
/// Clones the tensor values into a new tensor.
|
||||
pub fn cloned(&self) -> Tensor<T> {
|
||||
let inner = self.inner.clone().into_iter().cloned().collect::<Vec<T>>();
|
||||
Tensor::new(Some(&inner), &self.dims).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Sets (copies) the tensor values to the provided ones.
|
||||
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
|
||||
@@ -554,7 +478,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
|
||||
///
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
|
||||
/// ```
|
||||
@@ -631,23 +554,23 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
// Fill remaining dimensions
|
||||
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
|
||||
|
||||
// Pre-calculate total size and allocate result vector
|
||||
let total_size: usize = full_indices
|
||||
.iter()
|
||||
.map(|range| range.end - range.start)
|
||||
.product();
|
||||
let mut res = Vec::with_capacity(total_size);
|
||||
|
||||
// Calculate new dimensions once
|
||||
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
|
||||
|
||||
// Use iterator directly without collecting into intermediate Vec
|
||||
for coord in full_indices.iter().cloned().multi_cartesian_product() {
|
||||
let index = self.get_index(&coord);
|
||||
res.push(self[index].clone());
|
||||
}
|
||||
let mut output = Tensor::new(None, &dims)?;
|
||||
|
||||
Tensor::new(Some(&res), &dims)
|
||||
let cartesian_coord: Vec<Vec<usize>> = full_indices
|
||||
.iter()
|
||||
.cloned()
|
||||
.multi_cartesian_product()
|
||||
.collect();
|
||||
|
||||
output.par_iter_mut().enumerate().for_each(|(i, e)| {
|
||||
let coord = &cartesian_coord[i];
|
||||
*e = self.get(coord);
|
||||
});
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Set a slice of the Tensor.
|
||||
@@ -753,7 +676,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// ```
|
||||
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner: Vec<T> = vec![];
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
for (i, elem) in self.inner.iter().enumerate() {
|
||||
if i % n == 0 {
|
||||
inner.push(elem.clone());
|
||||
}
|
||||
@@ -776,7 +699,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// ```
|
||||
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner: Vec<T> = vec![];
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
for (i, elem) in self.inner.iter().enumerate() {
|
||||
if i % n != 0 {
|
||||
inner.push(elem.clone());
|
||||
}
|
||||
@@ -812,9 +735,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
|
||||
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
|
||||
let mut offset = initial_offset;
|
||||
for (i, elem) in self.inner.clone().into_iter().enumerate() {
|
||||
for (i, elem) in self.inner.iter().enumerate() {
|
||||
if (i + offset + 1) % n == 0 {
|
||||
inner.extend(vec![elem; 1 + num_repeats]);
|
||||
inner.extend(vec![elem.clone(); 1 + num_repeats]);
|
||||
offset += num_repeats;
|
||||
} else {
|
||||
inner.push(elem.clone());
|
||||
@@ -871,16 +794,16 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
|
||||
/// let mut indices = vec![3, 4];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
|
||||
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), expected);
|
||||
///
|
||||
///
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
|
||||
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
|
||||
/// let mut indices = vec![63, 67];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
|
||||
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), b);
|
||||
/// ```
|
||||
pub fn remove_indices(
|
||||
&self,
|
||||
@@ -927,7 +850,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
self.dims = vec![];
|
||||
}
|
||||
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
|
||||
if self.dims() == [0] && new_dims.iter().product::<usize>() == 1 {
|
||||
self.dims = Vec::from(new_dims);
|
||||
} else {
|
||||
let product = if new_dims != [0] {
|
||||
@@ -1292,33 +1215,6 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
|
||||
Tensor::new(Some(&[res]), &[1])
|
||||
}
|
||||
|
||||
/// Maps a function to tensors and enumerates in parallel
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
|
||||
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
|
||||
/// ```
|
||||
pub fn par_enum_map_mut_filtered<
|
||||
F: Fn(usize) -> Result<T, E> + std::marker::Send + std::marker::Sync,
|
||||
E: Error + std::marker::Send + std::marker::Sync,
|
||||
>(
|
||||
&mut self,
|
||||
filter_indices: &std::collections::HashSet<usize>,
|
||||
f: F,
|
||||
) -> Result<(), E>
|
||||
where
|
||||
T: std::marker::Send + std::marker::Sync,
|
||||
{
|
||||
self.inner
|
||||
.par_iter_mut()
|
||||
.enumerate()
|
||||
.filter(|(i, _)| filter_indices.contains(i))
|
||||
.for_each(move |(i, e)| *e = f(i).unwrap());
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
@@ -1335,9 +1231,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
pub fn combine(&self) -> Result<Tensor<T>, TensorError> {
|
||||
let mut dims = 0;
|
||||
let mut inner = Vec::new();
|
||||
for t in self.inner.clone().into_iter() {
|
||||
for t in self.inner.iter() {
|
||||
dims += t.len();
|
||||
inner.extend(t.inner);
|
||||
inner.extend(t.inner.clone());
|
||||
}
|
||||
Tensor::new(Some(&inner), &[dims])
|
||||
}
|
||||
@@ -1808,7 +1704,7 @@ impl DataFormat {
|
||||
match self {
|
||||
DataFormat::NHWC => DataFormat::NCHW,
|
||||
DataFormat::HWC => DataFormat::CHW,
|
||||
_ => self.clone(),
|
||||
_ => *self,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1908,7 +1804,7 @@ impl KernelFormat {
|
||||
match self {
|
||||
KernelFormat::HWIO => KernelFormat::OIHW,
|
||||
KernelFormat::OHWI => KernelFormat::OIHW,
|
||||
_ => self.clone(),
|
||||
_ => *self,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2030,6 +1926,9 @@ mod tests {
|
||||
fn tensor_slice() {
|
||||
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
|
||||
assert_eq!(
|
||||
a.get_slice(&[0..2, 0..1]).unwrap(),
|
||||
b.get_slice(&[0..2, 0..1]).unwrap()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap();
|
||||
@@ -168,7 +168,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let a = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
@@ -188,7 +188,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap();
|
||||
@@ -196,7 +196,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let a = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]),
|
||||
@@ -216,7 +216,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap();
|
||||
@@ -224,7 +224,7 @@ pub fn decompose(
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn trilu<T: TensorType + std::marker::Send + std::marker::Sync>(
|
||||
a: &Tensor<T>,
|
||||
@@ -916,11 +916,14 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 7]), &[2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
pub fn gather_nd<T: TensorType + Send + Sync>(
|
||||
input: &Tensor<T>,
|
||||
pub fn gather_nd<'a, T: TensorType + Send + Sync + 'a>(
|
||||
input: &'a Tensor<T>,
|
||||
index: &Tensor<usize>,
|
||||
batch_dims: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
) -> Result<Tensor<T>, TensorError>
|
||||
where
|
||||
&'a T: TensorType,
|
||||
{
|
||||
// Calculate the output tensor size
|
||||
let index_dims = index.dims().to_vec();
|
||||
let input_dims = input.dims().to_vec();
|
||||
@@ -1108,11 +1111,14 @@ pub fn gather_nd<T: TensorType + Send + Sync>(
|
||||
/// assert_eq!(result, expected);
|
||||
/// ````
|
||||
///
|
||||
pub fn scatter_nd<T: TensorType + Send + Sync>(
|
||||
pub fn scatter_nd<'a, T: TensorType + Send + Sync + 'a>(
|
||||
input: &Tensor<T>,
|
||||
index: &Tensor<usize>,
|
||||
src: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
src: &'a Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError>
|
||||
where
|
||||
&'a T: TensorType,
|
||||
{
|
||||
// Calculate the output tensor size
|
||||
let index_dims = index.dims().to_vec();
|
||||
let input_dims = input.dims().to_vec();
|
||||
@@ -1183,12 +1189,12 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
|
||||
/// use ezkl::tensor::ops::intercalate_values;
|
||||
///
|
||||
/// let tensor = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
|
||||
/// let result = intercalate_values(&tensor, 0, 2, 1).unwrap();
|
||||
/// let result = intercalate_values(&tensor, &0, 2, 1).unwrap();
|
||||
///
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = intercalate_values(&expected, 0, 2, 0).unwrap();
|
||||
/// let result = intercalate_values(&expected, &0, 2, 0).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap();
|
||||
///
|
||||
/// assert_eq!(result, expected);
|
||||
@@ -1196,7 +1202,7 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
|
||||
/// ```
|
||||
pub fn intercalate_values<T: TensorType>(
|
||||
tensor: &Tensor<T>,
|
||||
value: T,
|
||||
value: &T,
|
||||
stride: usize,
|
||||
axis: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
@@ -1494,7 +1500,7 @@ pub fn slice<T: TensorType + Send + Sync>(
|
||||
}
|
||||
}
|
||||
|
||||
t.get_slice(&slice)
|
||||
Ok(t.get_slice(&slice)?)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------------------------------------
|
||||
@@ -1859,14 +1865,14 @@ pub mod nonlinearities {
|
||||
/// Some(&[4, 25, 8, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = rsqrt(&x, 1.0);
|
||||
/// let result = rsqrt(&x, 1.0, f64::EPSILON);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 0, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64) -> Tensor<IntegerRep> {
|
||||
pub fn rsqrt(a: &Tensor<IntegerRep>, scale_input: f64, eps: f64) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale_input;
|
||||
let fout = scale_input / (kix.sqrt() + f64::EPSILON);
|
||||
let fout = scale_input / (kix.sqrt() + eps);
|
||||
let rounded = fout.round();
|
||||
Ok::<_, TensorError>(rounded as IntegerRep)
|
||||
})
|
||||
@@ -2339,15 +2345,20 @@ pub mod nonlinearities {
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = 2_f64;
|
||||
/// let result = recip(&x, 1.0, k);
|
||||
/// let result = recip(&x, 1.0, k, f64::EPSILON);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 1, 0, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn recip(a: &Tensor<IntegerRep>, input_scale: f64, out_scale: f64) -> Tensor<IntegerRep> {
|
||||
pub fn recip(
|
||||
a: &Tensor<IntegerRep>,
|
||||
input_scale: f64,
|
||||
out_scale: f64,
|
||||
eps: f64,
|
||||
) -> Tensor<IntegerRep> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rescaled = (a_i as f64) / input_scale;
|
||||
let denom = if rescaled == 0_f64 {
|
||||
(1_f64) / (rescaled + f64::EPSILON)
|
||||
(1_f64) / (rescaled + eps)
|
||||
} else {
|
||||
(1_f64) / (rescaled)
|
||||
};
|
||||
@@ -2366,16 +2377,16 @@ pub mod nonlinearities {
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use ezkl::tensor::ops::nonlinearities::zero_recip;
|
||||
/// let k = 2_f64;
|
||||
/// let result = zero_recip(1.0);
|
||||
/// let result = zero_recip(1.0, f64::EPSILON);
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4503599627370496]), &[1]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn zero_recip(out_scale: f64) -> Tensor<IntegerRep> {
|
||||
pub fn zero_recip(out_scale: f64, eps: f64) -> Tensor<IntegerRep> {
|
||||
let a = Tensor::<IntegerRep>::new(Some(&[0]), &[1]).unwrap();
|
||||
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let rescaled = a_i as f64;
|
||||
let denom = (1_f64) / (rescaled + f64::EPSILON);
|
||||
let denom = (1_f64) / (rescaled + eps);
|
||||
let d_inv_x = out_scale * denom;
|
||||
Ok::<_, TensorError>(d_inv_x.round() as IntegerRep)
|
||||
})
|
||||
@@ -2409,20 +2420,20 @@ pub mod accumulated {
|
||||
/// Some(&[25, 35]),
|
||||
/// &[2],
|
||||
/// ).unwrap();
|
||||
/// assert_eq!(dot(&[x, y], 1).unwrap(), expected);
|
||||
/// assert_eq!(dot(&x, &y, 1).unwrap(), expected);
|
||||
/// ```
|
||||
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
|
||||
inputs: &[Tensor<T>; 2],
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
chunk_size: usize,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
if inputs[0].clone().len() != inputs[1].clone().len() {
|
||||
if a.len() != b.len() {
|
||||
return Err(TensorError::DimMismatch("dot".to_string()));
|
||||
}
|
||||
let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
|
||||
|
||||
let transcript: Tensor<T> = a
|
||||
.iter()
|
||||
.zip(b)
|
||||
.zip(b.iter())
|
||||
.chunks(chunk_size)
|
||||
.into_iter()
|
||||
.scan(T::zero().unwrap(), |acc, chunk| {
|
||||
|
||||
@@ -280,9 +280,17 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
|
||||
fn from(t: Vec<ValType<F>>) -> ValTensor<F> {
|
||||
ValTensor::Value {
|
||||
inner: t.clone().into_iter().into(),
|
||||
|
||||
dims: vec![t.len()],
|
||||
scale: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> From<Vec<&ValType<F>>> for ValTensor<F> {
|
||||
fn from(t: Vec<&ValType<F>>) -> ValTensor<F> {
|
||||
ValTensor::Value {
|
||||
inner: t.clone().into_iter().cloned().into(),
|
||||
dims: vec![t.len()],
|
||||
scale: 1,
|
||||
}
|
||||
}
|
||||
@@ -640,6 +648,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// # Returns
|
||||
/// A tensor containing the base-n decomposition of each value
|
||||
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims.push(n + 1);
|
||||
|
||||
let res = self
|
||||
.get_inner()?
|
||||
.par_iter()
|
||||
@@ -665,9 +676,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>();
|
||||
|
||||
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
|
||||
let mut dims = self.dims().to_vec();
|
||||
dims.push(n + 1);
|
||||
let mut tensor = res?.into_iter().flatten().collect::<Tensor<_>>();
|
||||
|
||||
tensor.reshape(&dims)?;
|
||||
|
||||
@@ -1043,119 +1052,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes constant zero values from the tensor
|
||||
/// Uses parallel processing for tensors larger than a threshold
|
||||
pub fn remove_const_zero_values(&mut self) {
|
||||
let size_threshold = 1_000_000; // Tuned using benchmarks
|
||||
|
||||
if self.len() < size_threshold {
|
||||
// Single-threaded for small tensors
|
||||
match self {
|
||||
ValTensor::Value { inner: v, dims, .. } => {
|
||||
*v = v
|
||||
.clone()
|
||||
.into_iter()
|
||||
.filter_map(|e| {
|
||||
if let ValType::Constant(r) = e {
|
||||
if r == F::ZERO {
|
||||
return None;
|
||||
}
|
||||
} else if let ValType::AssignedConstant(_, r) = e {
|
||||
if r == F::ZERO {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
Some(e)
|
||||
})
|
||||
.collect();
|
||||
*dims = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {}
|
||||
}
|
||||
} else {
|
||||
// Parallel processing for large tensors
|
||||
let num_cores = std::thread::available_parallelism()
|
||||
.map(|n| n.get())
|
||||
.unwrap_or(1);
|
||||
let chunk_size = (self.len() / num_cores).max(100_000);
|
||||
|
||||
match self {
|
||||
ValTensor::Value { inner: v, dims, .. } => {
|
||||
*v = v
|
||||
.par_chunks_mut(chunk_size)
|
||||
.flat_map(|chunk| {
|
||||
chunk.par_iter_mut().filter_map(|e| {
|
||||
if let ValType::Constant(r) = e {
|
||||
if *r == F::ZERO {
|
||||
return None;
|
||||
}
|
||||
} else if let ValType::AssignedConstant(_, r) = e {
|
||||
if *r == F::ZERO {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
Some(e.clone())
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
*dims = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the indices of all constant zero values
|
||||
/// Uses parallel processing for large tensors
|
||||
///
|
||||
/// # Returns
|
||||
/// A vector of indices where constant zero values are located
|
||||
pub fn get_const_zero_indices(&self) -> Vec<usize> {
|
||||
let size_threshold = 1_000_000;
|
||||
|
||||
if self.len() < size_threshold {
|
||||
// Single-threaded for small tensors
|
||||
match &self {
|
||||
ValTensor::Value { inner: v, .. } => v
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
.collect(),
|
||||
ValTensor::Instance { .. } => vec![],
|
||||
}
|
||||
} else {
|
||||
// Parallel processing for large tensors
|
||||
let num_cores = std::thread::available_parallelism()
|
||||
.map(|n| n.get())
|
||||
.unwrap_or(1);
|
||||
let chunk_size = (self.len() / num_cores).max(100_000);
|
||||
|
||||
match &self {
|
||||
ValTensor::Value { inner: v, .. } => v
|
||||
.par_chunks(chunk_size)
|
||||
.enumerate()
|
||||
.flat_map(|(chunk_idx, chunk)| {
|
||||
chunk
|
||||
.par_iter()
|
||||
.enumerate()
|
||||
.filter_map(move |(i, e)| match e {
|
||||
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
|
||||
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
|
||||
}
|
||||
_ => None,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>(),
|
||||
ValTensor::Instance { .. } => vec![],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Gets the indices of all constant values
|
||||
/// Uses parallel processing for large tensors
|
||||
///
|
||||
@@ -1276,7 +1172,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// Returns an error if called on an Instance tensor
|
||||
pub fn intercalate_values(
|
||||
&mut self,
|
||||
value: ValType<F>,
|
||||
value: &ValType<F>,
|
||||
stride: usize,
|
||||
axis: usize,
|
||||
) -> Result<(), TensorError> {
|
||||
@@ -1368,10 +1264,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Concatenates two tensors along the first dimension
|
||||
pub fn concat(&self, other: Self) -> Result<Self, TensorError> {
|
||||
pub fn concat(&self, other: &Self) -> Result<Self, TensorError> {
|
||||
let res = match (self, other) {
|
||||
(ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. }) => {
|
||||
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2]), &[2])?.combine()?)
|
||||
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2.clone()]), &[2])?.combine()?)
|
||||
}
|
||||
_ => {
|
||||
return Err(TensorError::WrongMethod);
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use log::{debug, error, warn};
|
||||
|
||||
use crate::circuit::{CheckMode, region::ConstantsMap};
|
||||
use crate::circuit::{region::ConstantsMap, CheckMode};
|
||||
|
||||
use super::*;
|
||||
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
|
||||
@@ -381,50 +379,6 @@ impl VarTensor {
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns values from a ValTensor to this tensor, excluding specified positions
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `region` - The region to assign values in
|
||||
/// * `offset` - Base offset for assignments
|
||||
/// * `values` - The ValTensor containing values to assign
|
||||
/// * `omissions` - Set of positions to skip during assignment
|
||||
/// * `constants` - Map for tracking constant assignments
|
||||
///
|
||||
/// # Returns
|
||||
/// The assigned ValTensor or an error if assignment fails
|
||||
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
omissions: &HashSet<usize>,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
|
||||
let mut assigned_coord = 0;
|
||||
let mut res: ValTensor<F> = match values {
|
||||
ValTensor::Instance { .. } => {
|
||||
error!(
|
||||
"assignment with omissions is not supported on instance columns. increase K if you require more rows."
|
||||
);
|
||||
Err(halo2_proofs::plonk::Error::Synthesis)
|
||||
}
|
||||
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
|
||||
v.enum_map(|coord, k| {
|
||||
if omissions.contains(&coord) {
|
||||
return Ok::<_, halo2_proofs::plonk::Error>(k);
|
||||
}
|
||||
let cell =
|
||||
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
|
||||
assigned_coord += 1;
|
||||
Ok::<_, halo2_proofs::plonk::Error>(cell)
|
||||
})?
|
||||
.into(),
|
||||
),
|
||||
}?;
|
||||
res.set_scale(values.scale());
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Assigns values from a ValTensor to this tensor
|
||||
///
|
||||
/// # Arguments
|
||||
|
||||
Binary file not shown.
@@ -3,10 +3,10 @@
|
||||
mod native_tests {
|
||||
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::Commitments;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
|
||||
use ezkl::pfsys::Snark;
|
||||
use ezkl::Commitments;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::Bn256;
|
||||
use lazy_static::lazy_static;
|
||||
|
||||
@@ -272,16 +272,6 @@ mod py_tests {
|
||||
anvil_child.kill().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn postgres_notebook_() {
|
||||
crate::py_tests::init_binary();
|
||||
let test_dir: TempDir = TempDir::new("mean_postgres").unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::py_tests::mv_test_(path, "mean_postgres.ipynb");
|
||||
run_notebook(path, "mean_postgres.ipynb");
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tictactoe_autoencoder_notebook_() {
|
||||
crate::py_tests::init_binary();
|
||||
|
||||
Reference in New Issue
Block a user