Compare commits

...

29 Commits

Author SHA1 Message Date
dante
67b9a5163d Update rust.yml 2025-05-26 10:04:04 -04:00
dante
55df0aa58d bump 2025-05-25 14:00:45 -04:00
dante
a37193e405 Update Cargo.toml 2025-05-24 21:07:07 -04:00
dante
281a1f82ef mimalloc default 2025-05-24 21:05:37 -04:00
dante
351595cd8b Update Cargo.toml 2025-05-24 15:37:24 -04:00
dante
7fa7832149 switch to jemalloc 2025-05-24 15:32:49 -04:00
dante
09f810bf8a Update mod.rs 2025-05-24 14:41:20 -04:00
dante
387d8ed884 rv 2025-05-23 20:20:34 -04:00
dante
f8272dc98a patch 2025-05-23 17:40:19 -04:00
dante
247b35e89d swap 2025-05-23 12:41:40 -04:00
dante
5c6c75a379 bump 2025-05-23 10:27:40 -04:00
dante
0a7ef97255 mem savings 2025-05-23 09:54:42 -04:00
dante
a9edb46433 rv 2025-05-23 09:47:58 -04:00
dante
abfbf70175 Update mod.rs 2025-05-22 23:21:16 -04:00
dante
82ed09a366 Update model.rs 2025-05-22 23:13:06 -04:00
dante
a4ff6d7331 Update val.rs 2025-05-22 23:04:28 -04:00
dante
7682149f3c fixes 2025-05-22 22:34:41 -04:00
dante
5b317b8b29 Revert "Reapply "chore: tensorview""
This reverts commit 72e6c1daa2.
2025-05-22 18:52:07 -04:00
dante
72e6c1daa2 Reapply "chore: tensorview"
This reverts commit 08bbc0b283.
2025-05-22 16:49:25 -04:00
dante
71d340f6bb Update model.rs 2025-05-22 16:43:19 -04:00
dante
065240b288 better errs 2025-05-22 16:40:00 -04:00
dante
596cb31e67 Update model.rs 2025-05-22 16:34:31 -04:00
dante
08bbc0b283 Revert "chore: tensorview"
This reverts commit ff24b8326b.
2025-05-22 16:34:26 -04:00
dante
ff24b8326b chore: tensorview 2025-05-22 14:35:25 -04:00
dante
8b57d54638 Update rust.yml 2025-05-20 18:39:05 -04:00
dante
c2354e519a fix 2025-05-20 16:17:02 -04:00
dante
0269b1e316 Update rust.yml 2025-05-20 15:52:19 -04:00
dante
0c20656238 oops 2025-05-20 15:41:20 -04:00
dante
2fd0931e5e rm lots of clones 2025-05-20 14:44:59 -04:00
46 changed files with 38862 additions and 1443 deletions

View File

@@ -198,15 +198,15 @@ jobs:
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal,mimalloc
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'

View File

@@ -245,25 +245,6 @@ jobs:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
foudry-solidity-tests:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
persist-credentials: false
submodules: recursive
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@3b74dacdda3c0b763089addb99ed86bc3800e68b
- name: Run tests
run: |
cd tests/foundry
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
forge test -vvvv --fuzz-runs 64
mock-proving-tests:
permissions:
contents: read
@@ -372,9 +353,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
@@ -497,9 +475,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
@@ -744,24 +719,6 @@ jobs:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:

41
Cargo.lock generated
View File

@@ -929,7 +929,7 @@ dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools 0.11.0",
"itertools 0.12.1",
"lazy_static",
"lazycell",
"log",
@@ -1316,7 +1316,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "117725a109d387c937a1533ce01b450cbde6b88abceea8473c4d7a85853cda3c"
dependencies = [
"lazy_static",
"windows-sys 0.48.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -1939,7 +1939,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "976dd42dc7e85965fe702eb8164f21f450704bdde31faefd6471dba214cb594e"
dependencies = [
"libc",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -2014,6 +2014,7 @@ dependencies = [
"indicatif",
"instant",
"itertools 0.10.5",
"jemallocator",
"lazy_static",
"log",
"maybe-rayon",
@@ -3201,7 +3202,7 @@ checksum = "e04d7f318608d35d4b61ddd75cbdaee86b023ebe2bd5a66ee0915f0bf93095a9"
dependencies = [
"hermit-abi 0.5.0",
"libc",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -3261,6 +3262,26 @@ version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "jemalloc-sys"
version = "0.5.4+5.3.0-patched"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6c1946e1cea1788cbfde01c993b52a10e2da07f4bac608228d1bed20bfebf2"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "jemallocator"
version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0de374a9f8e63150e6f5e8a60cc14c668226d7a347d8aee1a45766e3c4dd3bc"
dependencies = [
"jemalloc-sys",
"libc",
]
[[package]]
name = "jiff"
version = "0.2.10"
@@ -3395,7 +3416,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc2f4eb4bc735547cfed7c0a4922cbd04a4655978c09b54f1f7b228750664c34"
dependencies = [
"cfg-if",
"windows-targets 0.48.5",
"windows-targets 0.52.6",
]
[[package]]
@@ -4602,7 +4623,7 @@ dependencies = [
"once_cell",
"socket2",
"tracing",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -5084,7 +5105,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.4.15",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -5097,7 +5118,7 @@ dependencies = [
"errno",
"libc",
"linux-raw-sys 0.9.4",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -5801,7 +5822,7 @@ dependencies = [
"getrandom 0.3.2",
"once_cell",
"rustix 1.0.5",
"windows-sys 0.52.0",
"windows-sys 0.59.0",
]
[[package]]
@@ -6930,7 +6951,7 @@ version = "0.1.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb"
dependencies = [
"windows-sys 0.48.0",
"windows-sys 0.59.0",
]
[[package]]

View File

@@ -88,9 +88,9 @@ pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
@@ -243,7 +243,6 @@ ezkl = [
"dep:lazy_static",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:clap_complete",
@@ -270,12 +269,15 @@ no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
jemalloc = ["dep:jemallocator"]
mimalloc = ["dep:mimalloc"]
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1

View File

@@ -68,7 +68,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
&[&self.image, &self.kernel, &self.bias],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -61,7 +62,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -86,13 +87,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -87,13 +88,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.unwrap();

View File

@@ -63,7 +63,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone()],
&[&self.image],
Box::new(HybridOp::SumPool {
padding: vec![(0, 0); 2],
stride: vec![1, 1],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -57,7 +58,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.unwrap();
Ok(())
},

View File

@@ -16,6 +16,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -58,7 +59,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(4)),
)
.unwrap();
Ok(())
},

View File

@@ -70,7 +70,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,

View File

@@ -67,7 +67,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -32,7 +32,6 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -216,11 +215,7 @@ where
.layer_config
.layout(
&mut region,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
&[&self.input, &self.l0_params[0], &self.l0_params[1]],
Box::new(op),
)
.unwrap();
@@ -229,7 +224,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -241,7 +236,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div { denom: 32.0.into() }),
)
.unwrap()
@@ -253,7 +248,7 @@ where
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone(), x],
&[&self.l2_params[0], &x],
Box::new(PolyOp::Einsum {
equation: "ij,j->ik".to_string(),
}),
@@ -265,7 +260,7 @@ where
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone()],
&[&x, &self.l2_params[1]],
Box::new(PolyOp::Add),
)
.unwrap()

View File

@@ -117,10 +117,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[
self.l0_params[0].clone().try_into().unwrap(),
self.input.clone(),
],
&[&self.l0_params[0].clone().try_into().unwrap(), &self.input],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -135,7 +132,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l0_params[1].clone().try_into().unwrap()],
&[&x, &self.l0_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -147,7 +144,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -163,7 +160,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone().try_into().unwrap(), x],
&[&self.l2_params[0].clone().try_into().unwrap(), &x],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -178,7 +175,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone().try_into().unwrap()],
&[&x, &self.l2_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -190,7 +187,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -203,7 +200,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div {
denom: ezkl::circuit::utils::F32::from(128.),
}),

Binary file not shown.

37795
examples/onnx/fr_age/lol.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,5 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]

View File

@@ -1,34 +1,34 @@
use crate::Commitments;
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::modules::Module;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::commands::*;
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::TestDataSource;
use crate::graph::{
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction,
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;

View File

@@ -962,7 +962,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn layout(
&mut self,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils},
fieldutils::{IntegerRep, integer_rep_to_felt},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
@@ -109,13 +109,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
HybridOp::Greater
| HybridOp::Less
| HybridOp::Equals
| HybridOp::GreaterEqual
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual { .. } => {
| HybridOp::LessEqual => {
vec![0, 1]
}
_ => vec![],
@@ -213,7 +213,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
@@ -362,10 +362,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Less { .. }
| HybridOp::LessEqual { .. }
HybridOp::Greater
| HybridOp::GreaterEqual
| HybridOp::Less
| HybridOp::LessEqual
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,

File diff suppressed because it is too large Load Diff

View File

@@ -186,7 +186,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,

View File

@@ -49,7 +49,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
@@ -223,12 +223,29 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
true,
)?))
}
_ => Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
_ => {
if self.decomp {
log::debug!("constraining input to be decomp");
Ok(Some(
super::layouts::decompose(
config,
region,
values[..].try_into()?,
&region.base(),
&region.legs(),
false,
)?
.1,
))
} else {
log::debug!("constraining input to be identity");
Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
)?))
}
}
}
} else {
Ok(Some(value))
@@ -263,7 +280,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
&self,
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
@@ -319,8 +336,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
@@ -333,20 +355,20 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
self.quantized_values.clone().try_into()?
};
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
Ok(Some(if self.decomp {
log::debug!("constraining constant to be decomp");
super::layouts::decompose(config, region, &[&value], &region.base(), &region.legs(), false)?.1
} else {
log::debug!("constraining constant to be identity");
super::layouts::identity(config, region, &[&value])?
}))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -108,8 +108,13 @@ pub enum PolyOp {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -203,11 +208,11 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?, true)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
@@ -335,9 +340,7 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -416,14 +419,14 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign { .. } => 0,
PolyOp::Sign => 0,
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
if matches!(self, PolyOp::Add | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]

View File

@@ -10,7 +10,6 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use std::{
cell::RefCell,
@@ -462,15 +461,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
}
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
let int_eval = inputs.int_evals()?;
let max = int_eval.iter().max().unwrap_or(&0);
let min = int_eval.iter().min().unwrap_or(&0);
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(*max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(*min);
Ok(())
}
@@ -505,10 +503,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// add used lookup
pub fn add_used_lookup(
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
lookup: &LookupOp,
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup);
self.statistics.used_lookups.insert(lookup.clone());
self.update_max_min_lookup_inputs(inputs)
}
@@ -642,34 +640,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)?)
} else {
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
let values_map = values.create_constants_map();
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,

View File

@@ -9,6 +9,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
use itertools::Itertools;
#[cfg(not(any(
all(target_arch = "wasm32", target_os = "unknown"),
not(feature = "ezkl")
@@ -64,7 +65,7 @@ mod matmul {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -141,7 +142,7 @@ mod matmul_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -215,7 +216,7 @@ mod matmul_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -302,7 +303,7 @@ mod matmul_col_ultra_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -380,6 +381,7 @@ mod matmul_col_ultra_overflow {
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -422,7 +424,7 @@ mod matmul_col_ultra_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -533,7 +535,7 @@ mod dot {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -610,7 +612,7 @@ mod dot_col_overflow_triple_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -683,7 +685,7 @@ mod dot_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -756,7 +758,7 @@ mod sum {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -826,7 +828,7 @@ mod sum_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -895,7 +897,7 @@ mod sum_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -966,7 +968,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -975,7 +977,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -984,7 +986,7 @@ mod composition {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -1061,7 +1063,7 @@ mod conv {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1218,7 +1220,7 @@ mod conv_col_ultra_overflow {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1377,7 +1379,7 @@ mod conv_relu_col_ultra_overflow {
let output = config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1390,7 +1392,7 @@ mod conv_relu_col_ultra_overflow {
let _output = config
.layout(
&mut region,
&[output.unwrap().unwrap()],
&[&output.unwrap().unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -1517,7 +1519,11 @@ mod add_w_shape_casting {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1584,7 +1590,11 @@ mod add {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1671,8 +1681,8 @@ mod dynamic_lookup {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
&self.lookups[i].iter().collect_vec().try_into().unwrap(),
&self.tables[i].iter().collect_vec().try_into().unwrap(),
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1767,8 +1777,8 @@ mod shuffle {
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
inputs: [ValTensor<F>; NUM_LOOP],
references: [ValTensor<F>; NUM_LOOP],
_marker: PhantomData<F>,
}
@@ -1818,15 +1828,15 @@ mod shuffle {
layouts::shuffles(
&config,
&mut region,
&self.inputs[i],
&self.references[i],
&[&self.inputs[i]],
&[&self.references[i]],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0][0].len()
NUM_LOOP * self.references[0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
@@ -1843,17 +1853,19 @@ mod shuffle {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1873,9 +1885,11 @@ mod shuffle {
} else {
loop_idx - 1
};
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1931,7 +1945,11 @@ mod add_with_overflow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2026,7 +2044,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
let inputs = vec![assigned_inputs_a, assigned_inputs_b];
let inputs = vec![&assigned_inputs_a, &assigned_inputs_b];
layouter.assign_region(
|| "model",
@@ -2135,7 +2153,11 @@ mod sub {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sub),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2202,7 +2224,11 @@ mod mult {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Mult),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2269,7 +2295,11 @@ mod pow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(5)),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2360,13 +2390,13 @@ mod matmul_relu {
};
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2465,7 +2495,7 @@ mod relu {
Ok(config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2563,7 +2593,7 @@ mod lookup_ultra_overflow {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.map_err(|_| Error::Synthesis)

View File

@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{Generator, Shell, generate};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -391,10 +391,8 @@ impl FromStr for DataField {
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the input starts with '@'
if s.starts_with('@') {
if let Some(file_path) = s.strip_prefix('@') {
// Extract the file path (remove the '@' prefix)
let file_path = &s[1..];
// Read the file content
let content = std::fs::read_to_string(file_path)
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;

View File

@@ -1,24 +1,24 @@
use crate::graph::DataSource;
use crate::graph::GraphSettings;
use crate::graph::input::{CallToAccount, CallsToAccount, FileSourceInner, GraphData};
use crate::graph::modules::POSEIDON_INSTANCES;
use crate::pfsys::Snark;
use crate::graph::DataSource;
use crate::graph::GraphSettings;
use crate::pfsys::evm::EvmVerificationError;
use crate::pfsys::Snark;
use alloy::contract::CallBuilder;
use alloy::core::primitives::Address as H160;
use alloy::core::primitives::Bytes;
use alloy::core::primitives::U256;
use alloy::dyn_abi::abi::TokenSeq;
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
use alloy::dyn_abi::abi::TokenSeq;
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{B256, I256, ParseSignedError};
use alloy::providers::ProviderBuilder;
use alloy::primitives::{ParseSignedError, B256, I256};
use alloy::providers::fillers::{
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
};
use alloy::providers::network::{Ethereum, EthereumSigner};
use alloy::providers::ProviderBuilder;
use alloy::providers::{Identity, Provider, RootProvider};
use alloy::rpc::types::eth::TransactionInput;
use alloy::rpc::types::eth::TransactionRequest;
@@ -27,9 +27,9 @@ use alloy::signers::wallet::{LocalWallet, WalletError};
use alloy::sol as abigen;
use alloy::transports::http::Http;
use alloy::transports::{RpcError, TransportErrorKind};
use foundry_compilers::Solc;
use foundry_compilers::artifacts::Settings as SolcSettings;
use foundry_compilers::error::{SolcError, SolcIoError};
use foundry_compilers::Solc;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
use halo2curves::group::ff::PrimeField;
@@ -711,7 +711,7 @@ pub async fn verify_proof_with_data_attestation(
};
let encoded = func.encode_input(&[
Token::Address(addr_verifier.0.0.into()),
Token::Address(addr_verifier.0 .0.into()),
Token::Bytes(encoded_verifier),
])?;
@@ -757,7 +757,7 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
let call_to_account = CallToAccount {
call_data: hex::encode(call),
decimals,
address: hex::encode(contract.address().0.0),
address: hex::encode(contract.address().0 .0),
};
info!("call_to_account: {:#?}", call_to_account);
Ok(call_to_account)
@@ -913,9 +913,9 @@ pub async fn get_contract_artifacts(
runs: usize,
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
use foundry_compilers::{
SHANGHAI_SOLC, SolcInput,
artifacts::{Optimizer, output_selection::OutputSelection},
artifacts::{output_selection::OutputSelection, Optimizer},
compilers::CompilerInput,
SolcInput, SHANGHAI_SOLC,
};
if !sol_code_path.exists() {

View File

@@ -1,6 +1,5 @@
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
#[allow(unused_imports)]
@@ -10,21 +9,21 @@ use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
};
use crate::pfsys::{
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -37,6 +36,7 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -48,12 +48,12 @@ use itertools::Itertools;
use lazy_static::lazy_static;
use log::debug;
use log::{info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde::Serialize;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -1163,15 +1163,9 @@ pub(crate) async fn calibrate(
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
@@ -1329,7 +1323,7 @@ pub(crate) async fn calibrate(
.clone()
}
CalibrationTarget::Accuracy => {
let param_iterator = found_params.iter().sorted_by_key(|p| {
let mut param_iterator = found_params.iter().sorted_by_key(|p| {
(
p.run_args.input_scale,
p.run_args.param_scale,
@@ -1338,7 +1332,7 @@ pub(crate) async fn calibrate(
)
});
let last = param_iterator.last().ok_or("no params found")?;
let last = param_iterator.next_back().ok_or("no params found")?;
let max_scale = (
last.run_args.input_scale,
last.run_args.param_scale,

View File

@@ -67,8 +67,11 @@ pub enum GraphError {
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
#[error("missing result for node {0}")]
MissingResults(usize),
/// Missing input
#[error("missing input {0}")]
MissingInputForNode(usize),
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),

View File

@@ -1,15 +1,15 @@
use super::errors::GraphError;
use super::quantize_float;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
@@ -17,7 +17,7 @@ use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::{
tract_data::{TVec, prelude::Tensor as TractTensor},
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
};
@@ -425,9 +425,7 @@ impl GraphData {
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => {
return Ok(graph_input);
}
Ok(graph_input) => Ok(graph_input),
Err(_) => {
let path = std::path::PathBuf::from(data);
GraphData::from_path(path)

View File

@@ -33,12 +33,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::{IntegerRep, felt_to_f64};
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{EZKL_BUF_CAPACITY, RunArgs};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use halo2_proofs::{
circuit::Layouter,
@@ -53,13 +53,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
@@ -571,7 +571,7 @@ impl GraphSettings {
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})
}
/// load params from file
@@ -581,7 +581,7 @@ impl GraphSettings {
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})?;
crate::check_version_string_matches(&settings.version);
@@ -1232,15 +1232,9 @@ impl GraphCircuit {
let mut cs = ConstraintSystem::default();
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
Self::configure_with_params(&mut cs, settings);
@@ -1576,13 +1570,13 @@ impl Circuit<Fp> for GraphCircuit {
let mut module_configs = ModuleConfigs::from_visibility(
cs,
params.module_sizes.clone(),
&params.module_sizes,
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
module_configs.configure_complex_modules(cs, &visibility, &params.module_sizes);
vars.instantiate_instance(
cs,

View File

@@ -200,7 +200,7 @@ fn number_of_iterations(mappings: &[InputMapping], dims: Vec<&[usize]>) -> usize
InputMapping::Stacked { axis, chunk } => Some(
// number of iterations given the dim size along the axis
// and the chunk size
(dims[*axis] + chunk - 1) / chunk,
dims[*axis].div_ceil(*chunk), // (dims[*axis] + chunk - 1) / chunk,
),
_ => None,
});
@@ -589,10 +589,7 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
input_types: self.get_input_types().ok(),
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
@@ -650,10 +647,13 @@ impl Model {
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
for (i, id) in model.clone().inputs.iter().enumerate() {
let inputs = model.inputs.clone();
let outputs = model.outputs.clone();
for (i, id) in inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.len() == 0 {
if input.outputs.is_empty() {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
@@ -672,7 +672,7 @@ impl Model {
model.set_input_fact(i, fact)?;
}
for (i, _) in model.clone().outputs.iter().enumerate() {
for (i, _) in outputs.iter().enumerate() {
model.set_output_fact(i, InferenceFact::default())?;
}
@@ -1196,7 +1196,7 @@ impl Model {
.base
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
&[output, &comparators],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
@@ -1257,12 +1257,27 @@ impl Model {
node.inputs()
.iter()
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
// check node is not an output
let is_output = self.graph.outputs.iter().any(|(o_idx, _)| *idx == *o_idx);
let res = if self.graph.nodes[idx].num_uses() == 1 && !is_output {
let res = results.remove(idx);
res.ok_or(GraphError::MissingResults(*idx))?[*outlet].clone()
} else {
results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet]
.clone()
};
Ok(res)
})
.collect::<Result<Vec<_>, GraphError>>()?
} else {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
if self.graph.nodes[idx].num_uses() == 1 {
let res = results.remove(idx);
vec![res.ok_or(GraphError::MissingInput(*idx))?[0].clone()]
} else {
vec![results.get(idx).ok_or(GraphError::MissingInput(*idx))?[0].clone()]
}
};
trace!("output dims: {:?}", node.out_dims());
trace!(
@@ -1273,7 +1288,7 @@ impl Model {
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
let mut res = if node.is_constant() && node.num_uses() == 1 {
log::debug!("node {} is a constant with 1 use", n.idx);
let mut node = n.clone();
let c = node
@@ -1284,19 +1299,19 @@ impl Model {
} else {
config
.base
.layout(region, &values, n.opkind.clone_dyn())
.layout(region, &values.iter().collect_vec(), n.opkind.clone_dyn())
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?
};
if let Some(mut vt) = res {
if let Some(vt) = &mut res {
vt.reshape(&node.out_dims()[0])?;
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
//only use with mock prover
debug!("------------ output node {:?}: {:?}", idx, vt.show());
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
}
}
NodeType::SubGraph {
@@ -1340,7 +1355,7 @@ impl Model {
.inputs
.clone()
.into_iter()
.zip(values.clone().into_iter().map(|v| vec![v])),
.zip(values.iter().map(|v| vec![v.clone()])),
);
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
@@ -1421,7 +1436,7 @@ impl Model {
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
Ok(results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
@@ -1476,7 +1491,7 @@ impl Model {
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
&[output, &comparator],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),

View File

@@ -37,15 +37,15 @@ impl ModuleConfigs {
/// Create new module configs from visibility of each variable
pub fn from_visibility(
cs: &mut ConstraintSystem<Fp>,
module_size: ModuleSizes,
module_size: &ModuleSizes,
logrows: usize,
) -> Self {
let mut config = Self::default();
for size in module_size.polycommit {
for size in &module_size.polycommit {
config
.polycommit
.push(PolyCommitChip::configure(cs, (logrows, size)));
.push(PolyCommitChip::configure(cs, (logrows, *size)));
}
config
@@ -55,8 +55,8 @@ impl ModuleConfigs {
pub fn configure_complex_modules(
&mut self,
cs: &mut ConstraintSystem<Fp>,
visibility: VarVisibility,
module_size: ModuleSizes,
visibility: &VarVisibility,
module_size: &ModuleSizes,
) {
if (visibility.input.is_hashed()
|| visibility.output.is_hashed()

View File

@@ -37,6 +37,7 @@ use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
use itertools::Itertools;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
@@ -118,16 +119,15 @@ impl Op<Fp> for Rescaled {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
if self.scale.len() != values.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
}
let res =
&crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?
[..];
self.inner.layout(config, region, res)
crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?;
self.inner.layout(config, region, &res.iter().collect_vec())
}
/// Create a cloned boxed copy of this operation
@@ -274,13 +274,13 @@ impl Op<Fp> for RebaseScale {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
self.rebase_op.layout(config, region, &[original_res])
self.rebase_op.layout(config, region, &[&original_res])
}
/// Create a cloned boxed copy of this operation
@@ -472,7 +472,7 @@ impl Op<Fp> for SupportedOp {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
self.as_op().layout(config, region, values)
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,7 +22,6 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{

View File

@@ -29,8 +29,14 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
#[global_allocator]
#[cfg(all(feature = "jemalloc", not(target_arch = "wasm32")))]
static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[global_allocator]
#[cfg(all(feature = "mimalloc", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -97,11 +103,11 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{CheckMode, table::Range};
use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
use fieldutils::IntegerRep;
use graph::{MAX_PUBLIC_SRS, Visibility};
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};

View File

@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::CurveAffine;
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
@@ -324,7 +324,7 @@ where
}
#[cfg(feature = "python-bindings")]
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
#[cfg(feature = "python-bindings")]
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
where
@@ -348,9 +348,9 @@ where
}
impl<
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,

View File

@@ -27,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{IntegerRep, integer_rep_to_felt},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::Visibility,
};
@@ -42,11 +42,11 @@ use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::Rem;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
pub trait TensorType: Clone + Debug {
/// Returns the zero value.
fn zero() -> Option<Self> {
None
@@ -55,10 +55,6 @@ pub trait TensorType: Clone + Debug + 'static {
fn one() -> Option<Self> {
None
}
/// Max operator for ordering values.
fn tmax(&self, _: &Self) -> Option<Self> {
None
}
}
macro_rules! tensor_type {
@@ -70,10 +66,6 @@ macro_rules! tensor_type {
fn one() -> Option<Self> {
Some($one)
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(max(*self, *other))
}
}
};
}
@@ -82,46 +74,12 @@ impl TensorType for f32 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f32::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
impl TensorType for f64 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f64::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
tensor_type!(bool, Bool, false, true);
@@ -147,14 +105,6 @@ impl<T: TensorType> TensorType for Value<T> {
fn one() -> Option<Self> {
Some(Value::known(T::one().unwrap()))
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(
(self.clone())
.zip(other.clone())
.map(|(a, b)| a.tmax(&b).unwrap()),
)
}
}
impl<F: PrimeField + PartialOrd> TensorType for Assigned<F>
@@ -168,14 +118,6 @@ where
fn one() -> Option<Self> {
Some(F::ONE.into())
}
fn tmax(&self, other: &Self) -> Option<Self> {
if self.evaluate() >= other.evaluate() {
Some(*self)
} else {
Some(*other)
}
}
}
impl<F: PrimeField> TensorType for Expression<F>
@@ -189,42 +131,14 @@ where
fn one() -> Option<Self> {
Some(Expression::Constant(F::ONE))
}
fn tmax(&self, _: &Self) -> Option<Self> {
todo!()
}
}
impl TensorType for Column<Advice> {}
impl TensorType for Column<Fixed> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value_field().zip(other.value_field()).map(|(a, b)| {
if a.evaluate() >= b.evaluate() {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value().zip(other.value()).map(|(a, b)| {
if a >= b {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {}
// specific types
impl TensorType for halo2curves::pasta::Fp {
@@ -235,10 +149,6 @@ impl TensorType for halo2curves::pasta::Fp {
fn one() -> Option<Self> {
Some(halo2curves::pasta::Fp::one())
}
    /// Maximum of two Pasta `Fp` elements via the type's `Ord`-based `max`.
    fn tmax(&self, other: &Self) -> Option<Self> {
        Some((*self).max(*other))
    }
}
impl TensorType for halo2curves::bn256::Fr {
@@ -249,9 +159,15 @@ impl TensorType for halo2curves::bn256::Fr {
fn one() -> Option<Self> {
Some(halo2curves::bn256::Fr::one())
}
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
/// Blanket impl for references: a `&F` cannot hand out an owned
/// zero/one value, so both constructors return `None`.
impl<F: TensorType> TensorType for &F {
    // No owned zero can be materialized behind a shared reference.
    fn zero() -> Option<Self> {
        None
    }
    // No owned one can be materialized behind a shared reference.
    fn one() -> Option<Self> {
        None
    }
}
@@ -374,7 +290,7 @@ impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
}
}
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync + 'data>
maybe_rayon::iter::IntoParallelRefMutIterator<'data> for Tensor<T>
{
type Iter = maybe_rayon::slice::IterMut<'data, T>;
@@ -423,6 +339,14 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
}
}
impl<T: Clone + TensorType> Tensor<&T> {
    /// Clones the referenced values into a new owned `Tensor<T>` with the
    /// same dimensions.
    pub fn cloned(&self) -> Tensor<T> {
        // Clone each referenced element directly instead of first cloning
        // the entire backing vector of references.
        let inner = self.inner.iter().copied().cloned().collect::<Vec<T>>();
        // `inner` has exactly `self.dims` elements, so construction cannot fail.
        Tensor::new(Some(&inner), &self.dims).unwrap()
    }
}
impl<T: Clone + TensorType> Tensor<T> {
/// Sets (copies) the tensor values to the provided ones.
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -554,7 +478,6 @@ impl<T: Clone + TensorType> Tensor<T> {
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
@@ -631,23 +554,23 @@ impl<T: Clone + TensorType> Tensor<T> {
// Fill remaining dimensions
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
// Pre-calculate total size and allocate result vector
let total_size: usize = full_indices
.iter()
.map(|range| range.end - range.start)
.product();
let mut res = Vec::with_capacity(total_size);
// Calculate new dimensions once
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
// Use iterator directly without collecting into intermediate Vec
for coord in full_indices.iter().cloned().multi_cartesian_product() {
let index = self.get_index(&coord);
res.push(self[index].clone());
}
let mut output = Tensor::new(None, &dims)?;
Tensor::new(Some(&res), &dims)
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
output.par_iter_mut().enumerate().for_each(|(i, e)| {
let coord = &cartesian_coord[i];
*e = self.get(coord);
});
Ok(output)
}
/// Set a slice of the Tensor.
@@ -753,7 +676,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n == 0 {
inner.push(elem.clone());
}
@@ -776,7 +699,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n != 0 {
inner.push(elem.clone());
}
@@ -812,9 +735,9 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if (i + offset + 1) % n == 0 {
inner.extend(vec![elem; 1 + num_repeats]);
inner.extend(vec![elem.clone(); 1 + num_repeats]);
offset += num_repeats;
} else {
inner.push(elem.clone());
@@ -871,16 +794,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
/// let mut indices = vec![3, 4];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), expected);
///
///
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
/// let mut indices = vec![63, 67];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), b);
/// ```
pub fn remove_indices(
&self,
@@ -927,7 +850,7 @@ impl<T: Clone + TensorType> Tensor<T> {
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
if self.dims() == [0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
@@ -1292,33 +1215,6 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&[res]), &[1])
}
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map_mut_filtered<
F: Fn(usize) -> Result<T, E> + std::marker::Send + std::marker::Sync,
E: Error + std::marker::Send + std::marker::Sync,
>(
&mut self,
filter_indices: &std::collections::HashSet<usize>,
f: F,
) -> Result<(), E>
where
T: std::marker::Send + std::marker::Sync,
{
self.inner
.par_iter_mut()
.enumerate()
.filter(|(i, _)| filter_indices.contains(i))
.for_each(move |(i, e)| *e = f(i).unwrap());
Ok(())
}
}
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
@@ -1335,9 +1231,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
pub fn combine(&self) -> Result<Tensor<T>, TensorError> {
let mut dims = 0;
let mut inner = Vec::new();
for t in self.inner.clone().into_iter() {
for t in self.inner.iter() {
dims += t.len();
inner.extend(t.inner);
inner.extend(t.inner.clone());
}
Tensor::new(Some(&inner), &[dims])
}
@@ -1808,7 +1704,7 @@ impl DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
_ => *self,
}
}
@@ -1908,7 +1804,7 @@ impl KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
_ => *self,
}
}
@@ -2030,6 +1926,9 @@ mod tests {
fn tensor_slice() {
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
assert_eq!(
a.get_slice(&[0..2, 0..1]).unwrap(),
b.get_slice(&[0..2, 0..1]).unwrap()
);
}
}

View File

@@ -916,11 +916,14 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
pub fn gather_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &'a Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1108,11 +1111,14 @@ pub fn gather_nd<T: TensorType + Send + Sync>(
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
pub fn scatter_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
src: &'a Tensor<T>,
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1183,12 +1189,12 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// use ezkl::tensor::ops::intercalate_values;
///
/// let tensor = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
/// let result = intercalate_values(&tensor, 0, 2, 1).unwrap();
/// let result = intercalate_values(&tensor, &0, 2, 1).unwrap();
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = intercalate_values(&expected, 0, 2, 0).unwrap();
/// let result = intercalate_values(&expected, &0, 2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap();
///
/// assert_eq!(result, expected);
@@ -1196,7 +1202,7 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// ```
pub fn intercalate_values<T: TensorType>(
tensor: &Tensor<T>,
value: T,
value: &T,
stride: usize,
axis: usize,
) -> Result<Tensor<T>, TensorError> {
@@ -1494,7 +1500,7 @@ pub fn slice<T: TensorType + Send + Sync>(
}
}
t.get_slice(&slice)
Ok(t.get_slice(&slice)?)
}
// ---------------------------------------------------------------------------------------------------------
@@ -2414,20 +2420,20 @@ pub mod accumulated {
/// Some(&[25, 35]),
/// &[2],
/// ).unwrap();
/// assert_eq!(dot(&[x, y], 1).unwrap(), expected);
/// assert_eq!(dot(&x, &y, 1).unwrap(), expected);
/// ```
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &[Tensor<T>; 2],
a: &Tensor<T>,
b: &Tensor<T>,
chunk_size: usize,
) -> Result<Tensor<T>, TensorError> {
if inputs[0].clone().len() != inputs[1].clone().len() {
if a.len() != b.len() {
return Err(TensorError::DimMismatch("dot".to_string()));
}
let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
let transcript: Tensor<T> = a
.iter()
.zip(b)
.zip(b.iter())
.chunks(chunk_size)
.into_iter()
.scan(T::zero().unwrap(), |acc, chunk| {

View File

@@ -280,9 +280,17 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
fn from(t: Vec<ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().into(),
dims: vec![t.len()],
scale: 1,
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> From<Vec<&ValType<F>>> for ValTensor<F> {
fn from(t: Vec<&ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().cloned().into(),
dims: vec![t.len()],
scale: 1,
}
}
@@ -640,6 +648,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// # Returns
/// A tensor containing the base-n decomposition of each value
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let res = self
.get_inner()?
.par_iter()
@@ -665,9 +676,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
})
.collect::<Result<Vec<_>, _>>();
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let mut tensor = res?.into_iter().flatten().collect::<Tensor<_>>();
tensor.reshape(&dims)?;
@@ -1043,119 +1052,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
    /// Removes constant-zero values from the tensor, compacting it in place.
    ///
    /// Drops every element that is a `ValType::Constant` or
    /// `ValType::AssignedConstant` equal to `F::ZERO`; all other elements
    /// (including non-constant zeros) are kept. `dims` is refreshed from the
    /// filtered inner tensor, so the tensor's shape/length may shrink.
    /// Instance tensors are left untouched.
    ///
    /// Uses parallel processing for tensors above a size threshold.
    pub fn remove_const_zero_values(&mut self) {
        let size_threshold = 1_000_000; // Tuned using benchmarks
        if self.len() < size_threshold {
            // Single-threaded for small tensors: filter into a fresh tensor.
            match self {
                ValTensor::Value { inner: v, dims, .. } => {
                    *v = v
                        .clone()
                        .into_iter()
                        .filter_map(|e| {
                            // Drop only compile-time/assigned constants equal to zero.
                            if let ValType::Constant(r) = e {
                                if r == F::ZERO {
                                    return None;
                                }
                            } else if let ValType::AssignedConstant(_, r) = e {
                                if r == F::ZERO {
                                    return None;
                                }
                            }
                            Some(e)
                        })
                        .collect();
                    // Length may have changed; re-derive dims from the new inner tensor.
                    *dims = v.dims().to_vec();
                }
                ValTensor::Instance { .. } => {}
            }
        } else {
            // Parallel processing for large tensors: split into per-core chunks.
            let num_cores = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(1);
            // Floor of 100k elements per chunk keeps scheduling overhead bounded.
            let chunk_size = (self.len() / num_cores).max(100_000);
            match self {
                ValTensor::Value { inner: v, dims, .. } => {
                    // NOTE(review): correctness relies on the parallel collect
                    // preserving element order across chunks — confirm the
                    // rayon/maybe_rayon backend guarantees this here.
                    *v = v
                        .par_chunks_mut(chunk_size)
                        .flat_map(|chunk| {
                            chunk.par_iter_mut().filter_map(|e| {
                                if let ValType::Constant(r) = e {
                                    if *r == F::ZERO {
                                        return None;
                                    }
                                } else if let ValType::AssignedConstant(_, r) = e {
                                    if *r == F::ZERO {
                                        return None;
                                    }
                                }
                                Some(e.clone())
                            })
                        })
                        .collect();
                    *dims = v.dims().to_vec();
                }
                ValTensor::Instance { .. } => {}
            }
        }
    }
    /// Gets the flat indices of all constant-zero values
    /// (`ValType::Constant` or `ValType::AssignedConstant` equal to
    /// `F::ZERO`). Instance tensors yield an empty vector.
    ///
    /// Uses parallel processing for tensors above a size threshold.
    ///
    /// # Returns
    /// A vector of flat indices where constant zero values are located.
    pub fn get_const_zero_indices(&self) -> Vec<usize> {
        let size_threshold = 1_000_000;
        if self.len() < size_threshold {
            // Single-threaded for small tensors.
            match &self {
                ValTensor::Value { inner: v, .. } => v
                    .iter()
                    .enumerate()
                    .filter_map(|(i, e)| match e {
                        ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
                            (*r == F::ZERO).then_some(i)
                        }
                        _ => None,
                    })
                    .collect(),
                ValTensor::Instance { .. } => vec![],
            }
        } else {
            // Parallel processing for large tensors: split into per-core chunks.
            let num_cores = std::thread::available_parallelism()
                .map(|n| n.get())
                .unwrap_or(1);
            let chunk_size = (self.len() / num_cores).max(100_000);
            match &self {
                ValTensor::Value { inner: v, .. } => v
                    .par_chunks(chunk_size)
                    .enumerate()
                    .flat_map(|(chunk_idx, chunk)| {
                        chunk
                            .par_iter()
                            .enumerate()
                            // Translate the within-chunk index back to a
                            // global flat index via the chunk offset.
                            .filter_map(move |(i, e)| match e {
                                ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
                                    (*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
                                }
                                _ => None,
                            })
                    })
                    .collect::<Vec<_>>(),
                ValTensor::Instance { .. } => vec![],
            }
        }
    }
/// Gets the indices of all constant values
/// Uses parallel processing for large tensors
///
@@ -1276,7 +1172,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Returns an error if called on an Instance tensor
pub fn intercalate_values(
&mut self,
value: ValType<F>,
value: &ValType<F>,
stride: usize,
axis: usize,
) -> Result<(), TensorError> {
@@ -1368,10 +1264,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Concatenates two tensors along the first dimension
pub fn concat(&self, other: Self) -> Result<Self, TensorError> {
pub fn concat(&self, other: &Self) -> Result<Self, TensorError> {
let res = match (self, other) {
(ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. }) => {
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2]), &[2])?.combine()?)
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2.clone()]), &[2])?.combine()?)
}
_ => {
return Err(TensorError::WrongMethod);

View File

@@ -1,8 +1,6 @@
use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::{CheckMode, region::ConstantsMap};
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
@@ -381,50 +379,6 @@ impl VarTensor {
}
}
    /// Assigns values from a ValTensor to this tensor, excluding specified positions
    ///
    /// Omitted positions are passed through unchanged and do not consume an
    /// assignment slot: `assigned_coord` only advances for values actually
    /// assigned, so the assigned values are packed contiguously.
    ///
    /// # Arguments
    /// * `region` - The region to assign values in
    /// * `offset` - Base offset for assignments
    /// * `values` - The ValTensor containing values to assign
    /// * `omissions` - Set of flat positions to skip during assignment
    /// * `constants` - Map for tracking constant assignments
    ///
    /// # Returns
    /// The assigned ValTensor (with `values`' scale preserved) or an error if
    /// assignment fails. Instance tensors are rejected with a synthesis error.
    pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &self,
        region: &mut Region<F>,
        offset: usize,
        values: &ValTensor<F>,
        omissions: &HashSet<usize>,
        constants: &mut ConstantsMap<F>,
    ) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
        // Tracks the compacted coordinate of the next assignment; omitted
        // entries intentionally do not advance it.
        let mut assigned_coord = 0;
        let mut res: ValTensor<F> = match values {
            ValTensor::Instance { .. } => {
                error!(
                    "assignment with omissions is not supported on instance columns. increase K if you require more rows."
                );
                Err(halo2_proofs::plonk::Error::Synthesis)
            }
            ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
                v.enum_map(|coord, k| {
                    // Skipped positions keep their original (unassigned) value.
                    if omissions.contains(&coord) {
                        return Ok::<_, halo2_proofs::plonk::Error>(k);
                    }
                    let cell =
                        self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
                    assigned_coord += 1;
                    Ok::<_, halo2_proofs::plonk::Error>(cell)
                })?
                .into(),
            ),
        }?;
        // Preserve the input tensor's fixed-point scale on the result.
        res.set_scale(values.scale());
        Ok(res)
    }
/// Assigns values from a ValTensor to this tensor
///
/// # Arguments

View File

@@ -3,10 +3,10 @@
mod native_tests {
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::Commitments;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
use ezkl::pfsys::Snark;
use ezkl::Commitments;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::Bn256;
use lazy_static::lazy_static;