Compare commits

84a5693b0e..fce07b4e57

2 Commits

Author SHA1 Message Date
dante
fce07b4e57 Update rust.yml 2025-04-28 17:13:04 -04:00
dante
84a5693b0e refactor: rm postgres 2025-04-28 16:26:19 -04:00
47 changed files with 2315 additions and 40438 deletions

View File

@@ -198,15 +198,15 @@ jobs:
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal,mimalloc
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'

View File

@@ -245,6 +245,25 @@ jobs:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
foudry-solidity-tests:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
persist-credentials: false
submodules: recursive
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@3b74dacdda3c0b763089addb99ed86bc3800e68b
- name: Run tests
run: |
cd tests/foundry
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
forge test -vvvv --fuzz-runs 64
mock-proving-tests:
permissions:
contents: read
@@ -353,6 +372,9 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
@@ -475,6 +497,9 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
@@ -719,6 +744,24 @@ jobs:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:

2434
Cargo.lock generated

File diff suppressed because it is too large. [Load Diff]

View File

@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
[package]
name = "ezkl"
version = "0.0.0"
edition = "2021"
edition = "2024"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -75,12 +75,12 @@ tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.24.2", features = [
pyo3 = { version = "0.23.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.24.0", features = [
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
@@ -88,9 +88,9 @@ pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
@@ -243,6 +243,7 @@ ezkl = [
"dep:lazy_static",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:clap_complete",
@@ -269,19 +270,22 @@ no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
jemalloc = ["dep:jemallocator"]
mimalloc = ["dep:mimalloc"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
# panic = "abort"
#panic = "abort"
[profile.test-runs]

View File

@@ -68,7 +68,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[&self.image, &self.kernel, &self.bias],
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],

View File

@@ -15,7 +15,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -60,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs,
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),

View File

@@ -15,7 +15,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -62,7 +61,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs,
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),

View File

@@ -17,7 +17,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -87,13 +86,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[&output.unwrap()],
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -17,7 +17,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -88,13 +87,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[&output.unwrap()],
&[output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -15,7 +15,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -60,7 +59,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs,
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.unwrap();

View File

@@ -63,7 +63,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[&self.image],
&[self.image.clone()],
Box::new(HybridOp::SumPool {
padding: vec![(0, 0); 2],
stride: vec![1, 1],

View File

@@ -15,7 +15,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -58,11 +57,7 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.unwrap();
Ok(())
},

View File

@@ -16,7 +16,6 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,11 +58,7 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(4)),
)
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.unwrap();
Ok(())
},

View File

@@ -70,7 +70,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[&self.input],
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,

View File

@@ -67,7 +67,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[&self.input],
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '22.0.5'
release = '0.0.0'
version = release

View File

@@ -32,6 +32,7 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -215,7 +216,11 @@ where
.layer_config
.layout(
&mut region,
&[&self.input, &self.l0_params[0], &self.l0_params[1]],
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
Box::new(op),
)
.unwrap();
@@ -224,7 +229,7 @@ where
.layer_config
.layout(
&mut region,
&[&x.unwrap()],
&[x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -236,7 +241,7 @@ where
.layer_config
.layout(
&mut region,
&[&x.unwrap()],
&[x.unwrap()],
Box::new(LookupOp::Div { denom: 32.0.into() }),
)
.unwrap()
@@ -248,7 +253,7 @@ where
.layer_config
.layout(
&mut region,
&[&self.l2_params[0], &x],
&[self.l2_params[0].clone(), x],
Box::new(PolyOp::Einsum {
equation: "ij,j->ik".to_string(),
}),
@@ -260,7 +265,7 @@ where
.layer_config
.layout(
&mut region,
&[&x, &self.l2_params[1]],
&[x, self.l2_params[1].clone()],
Box::new(PolyOp::Add),
)
.unwrap()

View File

@@ -117,7 +117,10 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&self.l0_params[0].clone().try_into().unwrap(), &self.input],
&[
self.l0_params[0].clone().try_into().unwrap(),
self.input.clone(),
],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -132,7 +135,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x, &self.l0_params[1].clone().try_into().unwrap()],
&[x, self.l0_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -144,7 +147,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x],
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -160,7 +163,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&self.l2_params[0].clone().try_into().unwrap(), &x],
&[self.l2_params[0].clone().try_into().unwrap(), x],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -175,7 +178,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x, &self.l2_params[1].clone().try_into().unwrap()],
&[x, self.l2_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -187,7 +190,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x],
&[x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -200,7 +203,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[&x.unwrap()],
&[x.unwrap()],
Box::new(LookupOp::Div {
denom: ezkl::circuit::utils::F32::from(128.),
}),

File diff suppressed because it is too large. [Load Diff]

View File

@@ -1,5 +1,12 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]

View File

@@ -1,34 +1,34 @@
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
};
use crate::circuit::modules::Module;
use crate::Commitments;
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::modules::Module;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
PoseidonChip,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
use crate::graph::TestDataSource;
use crate::graph::{
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;

View File

@@ -962,7 +962,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn layout(
&mut self,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils},
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::multiplier_to_scale,
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
@@ -109,13 +109,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater
| HybridOp::Less
| HybridOp::Equals
| HybridOp::GreaterEqual
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual => {
| HybridOp::LessEqual { .. } => {
vec![0, 1]
}
_ => vec![],
@@ -213,7 +213,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
@@ -362,10 +362,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater
| HybridOp::GreaterEqual
| HybridOp::Less
| HybridOp::LessEqual
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Less { .. }
| HybridOp::LessEqual { .. }
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,

File diff suppressed because it is too large. [Load Diff]

View File

@@ -186,7 +186,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,

View File

@@ -49,7 +49,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
@@ -223,29 +223,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
true,
)?))
}
_ => {
if self.decomp {
log::debug!("constraining input to be decomp");
Ok(Some(
super::layouts::decompose(
config,
region,
values[..].try_into()?,
&region.base(),
&region.legs(),
false,
)?
.1,
))
} else {
log::debug!("constraining input to be identity");
Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
)?))
}
}
_ => Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
}
} else {
Ok(Some(value))
@@ -280,7 +263,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
&self,
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[&ValTensor<F>],
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
@@ -336,13 +319,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
@@ -355,20 +333,20 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[&ValTensor<F>],
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
self.quantized_values.clone().try_into()?
};
Ok(Some(if self.decomp {
log::debug!("constraining constant to be decomp");
super::layouts::decompose(config, region, &[&value], &region.base(), &region.legs(), false)?.1
} else {
log::debug!("constraining constant to be identity");
super::layouts::identity(config, region, &[&value])?
}))
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -108,13 +108,8 @@ pub enum PolyOp {
}
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -208,11 +203,11 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[&ValTensor<F>],
values: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?, true)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
@@ -340,7 +335,9 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -419,14 +416,14 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign => 0,
PolyOp::Sign { .. } => 0,
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(self, PolyOp::Add | PolyOp::Sub) {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]

View File

@@ -10,6 +10,7 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use std::{
cell::RefCell,
@@ -461,14 +462,15 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &ValTensor<F>,
inputs: &[ValTensor<F>],
) -> Result<(), CircuitError> {
let int_eval = inputs.int_evals()?;
let max = int_eval.iter().max().unwrap_or(&0);
let min = int_eval.iter().min().unwrap_or(&0);
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(*max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(*min);
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
}
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
Ok(())
}
@@ -503,10 +505,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// add used lookup
pub fn add_used_lookup(
&mut self,
lookup: &LookupOp,
inputs: &ValTensor<F>,
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup.clone());
self.statistics.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
@@ -640,6 +642,34 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)?)
} else {
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
let values_map = values.create_constants_map();
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,

View File

@@ -9,7 +9,6 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
use itertools::Itertools;
#[cfg(not(any(
all(target_arch = "wasm32", target_os = "unknown"),
not(feature = "ezkl")
@@ -65,7 +64,7 @@ mod matmul {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -142,7 +141,7 @@ mod matmul_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -216,7 +215,7 @@ mod matmul_col_overflow {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -303,7 +302,7 @@ mod matmul_col_ultra_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -381,7 +380,6 @@ mod matmul_col_ultra_overflow {
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -424,7 +422,7 @@ mod matmul_col_ultra_overflow {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -535,7 +533,7 @@ mod dot {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -612,7 +610,7 @@ mod dot_col_overflow_triple_col {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -685,7 +683,7 @@ mod dot_col_overflow {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -758,7 +756,7 @@ mod sum {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -828,7 +826,7 @@ mod sum_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -897,7 +895,7 @@ mod sum_col_overflow {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -968,7 +966,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -977,7 +975,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -986,7 +984,7 @@ mod composition {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs.clone(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -1063,7 +1061,7 @@ mod conv {
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
&self.inputs,
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1220,7 +1218,7 @@ mod conv_col_ultra_overflow {
config
.layout(
&mut region,
&[&self.image, &self.kernel],
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1379,7 +1377,7 @@ mod conv_relu_col_ultra_overflow {
let output = config
.layout(
&mut region,
&[&self.image, &self.kernel],
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1392,7 +1390,7 @@ mod conv_relu_col_ultra_overflow {
let _output = config
.layout(
&mut region,
&[&output.unwrap().unwrap()],
&[output.unwrap().unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -1519,11 +1517,7 @@ mod add_w_shape_casting {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
},
)
@@ -1590,11 +1584,7 @@ mod add {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
},
)
@@ -1681,8 +1671,8 @@ mod dynamic_lookup {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i].iter().collect_vec().try_into().unwrap(),
&self.tables[i].iter().collect_vec().try_into().unwrap(),
&self.lookups[i],
&self.tables[i],
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1777,8 +1767,8 @@ mod shuffle {
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [ValTensor<F>; NUM_LOOP],
references: [ValTensor<F>; NUM_LOOP],
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
_marker: PhantomData<F>,
}
@@ -1828,15 +1818,15 @@ mod shuffle {
layouts::shuffles(
&config,
&mut region,
&[&self.inputs[i]],
&[&self.references[i]],
&self.inputs[i],
&self.references[i],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0].len()
NUM_LOOP * self.references[0][0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
@@ -1853,19 +1843,17 @@ mod shuffle {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
@@ -1885,11 +1873,9 @@ mod shuffle {
} else {
loop_idx - 1
};
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
))
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
})
.collect::<Vec<_>>();
@@ -1945,11 +1931,7 @@ mod add_with_overflow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.map_err(|_| Error::Synthesis)
},
)
@@ -2044,7 +2026,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
let inputs = vec![&assigned_inputs_a, &assigned_inputs_b];
let inputs = vec![assigned_inputs_a, assigned_inputs_b];
layouter.assign_region(
|| "model",
@@ -2153,11 +2135,7 @@ mod sub {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sub),
)
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.map_err(|_| Error::Synthesis)
},
)
@@ -2224,11 +2202,7 @@ mod mult {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Mult),
)
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.map_err(|_| Error::Synthesis)
},
)
@@ -2295,11 +2269,7 @@ mod pow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(5)),
)
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.map_err(|_| Error::Synthesis)
},
)
@@ -2390,13 +2360,13 @@ mod matmul_relu {
};
let output = config
.base_config
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.layout(&mut region, &self.inputs, Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[&output.unwrap()],
&[output.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2495,7 +2465,7 @@ mod relu {
Ok(config
.layout(
&mut region,
&[&self.input],
&[self.input.clone()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2593,7 +2563,7 @@ mod lookup_ultra_overflow {
config
.layout(
&mut region,
&[&self.input],
&[self.input.clone()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.map_err(|_| Error::Synthesis)

View File

@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{generate, Generator, Shell};
use clap_complete::{Generator, Shell, generate};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -391,8 +391,10 @@ impl FromStr for DataField {
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the input starts with '@'
if let Some(file_path) = s.strip_prefix('@') {
if s.starts_with('@') {
// Extract the file path (remove the '@' prefix)
let file_path = &s[1..];
// Read the file content
let content = std::fs::read_to_string(file_path)
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;

View File

@@ -1,24 +1,24 @@
use crate::graph::input::{CallToAccount, CallsToAccount, FileSourceInner, GraphData};
use crate::graph::modules::POSEIDON_INSTANCES;
use crate::graph::DataSource;
use crate::graph::GraphSettings;
use crate::pfsys::evm::EvmVerificationError;
use crate::graph::input::{CallToAccount, CallsToAccount, FileSourceInner, GraphData};
use crate::graph::modules::POSEIDON_INSTANCES;
use crate::pfsys::Snark;
use crate::pfsys::evm::EvmVerificationError;
use alloy::contract::CallBuilder;
use alloy::core::primitives::Address as H160;
use alloy::core::primitives::Bytes;
use alloy::core::primitives::U256;
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
use alloy::dyn_abi::abi::TokenSeq;
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{ParseSignedError, B256, I256};
use alloy::primitives::{B256, I256, ParseSignedError};
use alloy::providers::ProviderBuilder;
use alloy::providers::fillers::{
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
};
use alloy::providers::network::{Ethereum, EthereumSigner};
use alloy::providers::ProviderBuilder;
use alloy::providers::{Identity, Provider, RootProvider};
use alloy::rpc::types::eth::TransactionInput;
use alloy::rpc::types::eth::TransactionRequest;
@@ -27,9 +27,9 @@ use alloy::signers::wallet::{LocalWallet, WalletError};
use alloy::sol as abigen;
use alloy::transports::http::Http;
use alloy::transports::{RpcError, TransportErrorKind};
use foundry_compilers::Solc;
use foundry_compilers::artifacts::Settings as SolcSettings;
use foundry_compilers::error::{SolcError, SolcIoError};
use foundry_compilers::Solc;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
use halo2curves::group::ff::PrimeField;
@@ -711,7 +711,7 @@ pub async fn verify_proof_with_data_attestation(
};
let encoded = func.encode_input(&[
Token::Address(addr_verifier.0 .0.into()),
Token::Address(addr_verifier.0.0.into()),
Token::Bytes(encoded_verifier),
])?;
@@ -757,7 +757,7 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
let call_to_account = CallToAccount {
call_data: hex::encode(call),
decimals,
address: hex::encode(contract.address().0 .0),
address: hex::encode(contract.address().0.0),
};
info!("call_to_account: {:#?}", call_to_account);
Ok(call_to_account)
@@ -913,9 +913,9 @@ pub async fn get_contract_artifacts(
runs: usize,
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
use foundry_compilers::{
artifacts::{output_selection::OutputSelection, Optimizer},
SHANGHAI_SOLC, SolcInput,
artifacts::{Optimizer, output_selection::OutputSelection},
compilers::CompilerInput,
SolcInput, SHANGHAI_SOLC,
};
if !sol_code_path.exists() {

View File

@@ -1,5 +1,6 @@
use crate::circuit::region::RegionSettings;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::commands::CalibrationTarget;
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
#[allow(unused_imports)]
@@ -9,21 +10,21 @@ use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
};
use crate::pfsys::{
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -36,7 +37,6 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -48,12 +48,12 @@ use itertools::Itertools;
use lazy_static::lazy_static;
use log::debug;
use log::{info, trace, warn};
use serde::de::DeserializeOwned;
use serde::Serialize;
use serde::de::DeserializeOwned;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -1139,7 +1139,6 @@ pub(crate) async fn calibrate(
let mut num_failed = 0;
let mut num_passed = 0;
let mut failure_reasons = vec![];
for ((input_scale, param_scale), scale_rebase_multiplier) in range_grid {
pb.set_message(format!(
@@ -1164,21 +1163,20 @@ pub(crate) async fn calibrate(
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = Gag::stdout().ok();
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = Gag::stderr().ok();
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
error!("circuit creation from run args failed: {:?}", e);
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
e
));
pb.inc(1);
num_failed += 1;
continue;
@@ -1223,13 +1221,6 @@ pub(crate) async fn calibrate(
error!("forward pass failed: {:?}", e);
pb.inc(1);
num_failed += 1;
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
e
));
continue;
}
}
@@ -1295,14 +1286,7 @@ pub(crate) async fn calibrate(
found_settings.as_json()?.to_colored_json_auto()?
);
num_passed += 1;
} else if let Err(res) = res {
failure_reasons.push(format!(
"i-scale: {}, p-scale: {}, rebase-(x): {}, reason: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().yellow(),
res.to_string().red()
));
} else {
num_failed += 1;
}
@@ -1312,13 +1296,6 @@ pub(crate) async fn calibrate(
pb.finish_with_message("Calibration Done.");
if found_params.is_empty() {
if !failure_reasons.is_empty() {
error!("Calibration failed for the following reasons:");
for reason in failure_reasons {
error!("{}", reason);
}
}
return Err("calibration failed, could not find any suitable parameters given the calibration dataset".into());
}
@@ -1352,7 +1329,7 @@ pub(crate) async fn calibrate(
.clone()
}
CalibrationTarget::Accuracy => {
let mut param_iterator = found_params.iter().sorted_by_key(|p| {
let param_iterator = found_params.iter().sorted_by_key(|p| {
(
p.run_args.input_scale,
p.run_args.param_scale,
@@ -1361,7 +1338,7 @@ pub(crate) async fn calibrate(
)
});
let last = param_iterator.next_back().ok_or("no params found")?;
let last = param_iterator.last().ok_or("no params found")?;
let max_scale = (
last.run_args.input_scale,
last.run_args.param_scale,

View File

@@ -67,11 +67,8 @@ pub enum GraphError {
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing result for node {0}")]
MissingResults(usize),
/// Missing input
#[error("missing input {0}")]
MissingInputForNode(usize),
#[error("missing results")]
MissingResults,
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),

View File

@@ -1,15 +1,15 @@
use super::errors::GraphError;
use super::quantize_float;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
@@ -17,7 +17,7 @@ use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::{
tract_data::{prelude::Tensor as TractTensor, TVec},
tract_data::{TVec, prelude::Tensor as TractTensor},
value::TValue,
};
@@ -425,7 +425,9 @@ impl GraphData {
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => Ok(graph_input),
Ok(graph_input) => {
return Ok(graph_input);
}
Err(_) => {
let path = std::path::PathBuf::from(data);
GraphData::from_path(path)

View File

@@ -33,12 +33,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::fieldutils::{IntegerRep, felt_to_f64};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use crate::{EZKL_BUF_CAPACITY, RunArgs};
use halo2_proofs::{
circuit::Layouter,
@@ -53,13 +53,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
@@ -571,7 +571,7 @@ impl GraphSettings {
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::other(e)
std::io::Error::new(std::io::ErrorKind::Other, e)
})
}
/// load params from file
@@ -581,7 +581,7 @@ impl GraphSettings {
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::other(e)
std::io::Error::new(std::io::ErrorKind::Other, e)
})?;
crate::check_version_string_matches(&settings.version);
@@ -1232,9 +1232,15 @@ impl GraphCircuit {
let mut cs = ConstraintSystem::default();
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = Gag::stdout().ok();
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = Gag::stderr().ok();
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
Self::configure_with_params(&mut cs, settings);
@@ -1570,13 +1576,13 @@ impl Circuit<Fp> for GraphCircuit {
let mut module_configs = ModuleConfigs::from_visibility(
cs,
&params.module_sizes,
params.module_sizes.clone(),
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, &visibility, &params.module_sizes);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
vars.instantiate_instance(
cs,

View File

@@ -200,7 +200,7 @@ fn number_of_iterations(mappings: &[InputMapping], dims: Vec<&[usize]>) -> usize
InputMapping::Stacked { axis, chunk } => Some(
// number of iterations given the dim size along the axis
// and the chunk size
dims[*axis].div_ceil(*chunk), // (dims[*axis] + chunk - 1) / chunk,
(dims[*axis] + chunk - 1) / chunk,
),
_ => None,
});
@@ -589,7 +589,10 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: self.get_input_types().ok(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
@@ -647,13 +650,10 @@ impl Model {
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
let inputs = model.inputs.clone();
let outputs = model.outputs.clone();
for (i, id) in inputs.iter().enumerate() {
for (i, id) in model.clone().inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.is_empty() {
if input.outputs.len() == 0 {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
@@ -672,7 +672,7 @@ impl Model {
model.set_input_fact(i, fact)?;
}
for (i, _) in outputs.iter().enumerate() {
for (i, _) in model.clone().outputs.iter().enumerate() {
model.set_output_fact(i, InferenceFact::default())?;
}
@@ -1196,7 +1196,7 @@ impl Model {
.base
.layout(
&mut thread_safe_region,
&[output, &comparators],
&[output.clone(), comparators],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
@@ -1257,27 +1257,12 @@ impl Model {
node.inputs()
.iter()
.map(|(idx, outlet)| {
// check node is not an output
let is_output = self.graph.outputs.iter().any(|(o_idx, _)| *idx == *o_idx);
let res = if self.graph.nodes[idx].num_uses() == 1 && !is_output {
let res = results.remove(idx);
res.ok_or(GraphError::MissingResults(*idx))?[*outlet].clone()
} else {
results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet]
.clone()
};
Ok(res)
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?
} else {
// we re-assign inputs, always from the 0 outlet
if self.graph.nodes[idx].num_uses() == 1 {
let res = results.remove(idx);
vec![res.ok_or(GraphError::MissingInput(*idx))?[0].clone()]
} else {
vec![results.get(idx).ok_or(GraphError::MissingInput(*idx))?[0].clone()]
}
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
trace!("output dims: {:?}", node.out_dims());
trace!(
@@ -1288,7 +1273,7 @@ impl Model {
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let mut res = if node.is_constant() && node.num_uses() == 1 {
let res = if node.is_constant() && node.num_uses() == 1 {
log::debug!("node {} is a constant with 1 use", n.idx);
let mut node = n.clone();
let c = node
@@ -1299,19 +1284,19 @@ impl Model {
} else {
config
.base
.layout(region, &values.iter().collect_vec(), n.opkind.clone_dyn())
.layout(region, &values, n.opkind.clone_dyn())
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?
};
if let Some(vt) = &mut res {
if let Some(mut vt) = res {
vt.reshape(&node.out_dims()[0])?;
//only use with mock prover
debug!("------------ output node {:?}: {:?}", idx, vt.show());
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
//only use with mock prover
debug!("------------ output node {:?}: {:?}", idx, vt.show());
}
}
NodeType::SubGraph {
@@ -1355,7 +1340,7 @@ impl Model {
.inputs
.clone()
.into_iter()
.zip(values.iter().map(|v| vec![v.clone()])),
.zip(values.clone().into_iter().map(|v| vec![v])),
);
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
@@ -1436,7 +1421,7 @@ impl Model {
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet].clone())
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
@@ -1491,7 +1476,7 @@ impl Model {
dummy_config.layout(
&mut region,
&[output, &comparator],
&[output.clone(), comparator],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),

View File

@@ -37,15 +37,15 @@ impl ModuleConfigs {
/// Create new module configs from visibility of each variable
pub fn from_visibility(
cs: &mut ConstraintSystem<Fp>,
module_size: &ModuleSizes,
module_size: ModuleSizes,
logrows: usize,
) -> Self {
let mut config = Self::default();
for size in &module_size.polycommit {
for size in module_size.polycommit {
config
.polycommit
.push(PolyCommitChip::configure(cs, (logrows, *size)));
.push(PolyCommitChip::configure(cs, (logrows, size)));
}
config
@@ -55,8 +55,8 @@ impl ModuleConfigs {
pub fn configure_complex_modules(
&mut self,
cs: &mut ConstraintSystem<Fp>,
visibility: &VarVisibility,
module_size: &ModuleSizes,
visibility: VarVisibility,
module_size: ModuleSizes,
) {
if (visibility.input.is_hashed()
|| visibility.output.is_hashed()

View File

@@ -37,7 +37,6 @@ use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
use itertools::Itertools;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
@@ -119,15 +118,16 @@ impl Op<Fp> for Rescaled {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[&crate::tensor::ValTensor<Fp>],
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
if self.scale.len() != values.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
}
let res =
crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?;
self.inner.layout(config, region, &res.iter().collect_vec())
&crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?
[..];
self.inner.layout(config, region, res)
}
/// Create a cloned boxed copy of this operation
@@ -274,13 +274,13 @@ impl Op<Fp> for RebaseScale {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[&crate::tensor::ValTensor<Fp>],
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
self.rebase_op.layout(config, region, &[&original_res])
self.rebase_op.layout(config, region, &[original_res])
}
/// Create a cloned boxed copy of this operation
@@ -472,7 +472,7 @@ impl Op<Fp> for SupportedOp {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[&crate::tensor::ValTensor<Fp>],
values: &[crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
self.as_op().layout(config, region, values)
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,6 +22,7 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -31,7 +32,6 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{

View File

@@ -29,14 +29,8 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[global_allocator]
#[cfg(all(feature = "jemalloc", not(target_arch = "wasm32")))]
static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[global_allocator]
#[cfg(all(feature = "mimalloc", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -103,11 +97,11 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{table::Range, CheckMode};
use circuit::{CheckMode, table::Range};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
use fieldutils::IntegerRep;
use graph::{Visibility, MAX_PUBLIC_SRS};
use graph::{MAX_PUBLIC_SRS, Visibility};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};

View File

@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::CurveAffine;
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
@@ -324,7 +324,7 @@ where
}
#[cfg(feature = "python-bindings")]
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
#[cfg(feature = "python-bindings")]
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
where
@@ -348,9 +348,9 @@ where
}
impl<
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,

View File

@@ -27,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{IntegerRep, integer_rep_to_felt},
graph::Visibility,
};
@@ -42,11 +42,11 @@ use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::Rem;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug {
pub trait TensorType: Clone + Debug + 'static {
/// Returns the zero value.
fn zero() -> Option<Self> {
None
@@ -55,6 +55,10 @@ pub trait TensorType: Clone + Debug {
fn one() -> Option<Self> {
None
}
/// Max operator for ordering values.
fn tmax(&self, _: &Self) -> Option<Self> {
None
}
}
macro_rules! tensor_type {
@@ -66,6 +70,10 @@ macro_rules! tensor_type {
fn one() -> Option<Self> {
Some($one)
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(max(*self, *other))
}
}
};
}
@@ -74,12 +82,46 @@ impl TensorType for f32 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f32::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
impl TensorType for f64 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f64::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
tensor_type!(bool, Bool, false, true);
@@ -105,6 +147,14 @@ impl<T: TensorType> TensorType for Value<T> {
fn one() -> Option<Self> {
Some(Value::known(T::one().unwrap()))
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(
(self.clone())
.zip(other.clone())
.map(|(a, b)| a.tmax(&b).unwrap()),
)
}
}
impl<F: PrimeField + PartialOrd> TensorType for Assigned<F>
@@ -118,6 +168,14 @@ where
fn one() -> Option<Self> {
Some(F::ONE.into())
}
fn tmax(&self, other: &Self) -> Option<Self> {
if self.evaluate() >= other.evaluate() {
Some(*self)
} else {
Some(*other)
}
}
}
impl<F: PrimeField> TensorType for Expression<F>
@@ -131,14 +189,42 @@ where
fn one() -> Option<Self> {
Some(Expression::Constant(F::ONE))
}
fn tmax(&self, _: &Self) -> Option<Self> {
todo!()
}
}
impl TensorType for Column<Advice> {}
impl TensorType for Column<Fixed> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value_field().zip(other.value_field()).map(|(a, b)| {
if a.evaluate() >= b.evaluate() {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value().zip(other.value()).map(|(a, b)| {
if a >= b {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
// specific types
impl TensorType for halo2curves::pasta::Fp {
@@ -149,6 +235,10 @@ impl TensorType for halo2curves::pasta::Fp {
fn one() -> Option<Self> {
Some(halo2curves::pasta::Fp::one())
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
}
}
impl TensorType for halo2curves::bn256::Fr {
@@ -159,15 +249,9 @@ impl TensorType for halo2curves::bn256::Fr {
fn one() -> Option<Self> {
Some(halo2curves::bn256::Fr::one())
}
}
impl<F: TensorType> TensorType for &F {
fn zero() -> Option<Self> {
None
}
fn one() -> Option<Self> {
None
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
}
}
@@ -290,7 +374,7 @@ impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
}
}
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync + 'data>
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::IntoParallelRefMutIterator<'data> for Tensor<T>
{
type Iter = maybe_rayon::slice::IterMut<'data, T>;
@@ -339,14 +423,6 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
}
}
impl<T: Clone + TensorType> Tensor<&T> {
/// Clones the tensor values into a new tensor.
pub fn cloned(&self) -> Tensor<T> {
let inner = self.inner.clone().into_iter().cloned().collect::<Vec<T>>();
Tensor::new(Some(&inner), &self.dims).unwrap()
}
}
impl<T: Clone + TensorType> Tensor<T> {
/// Sets (copies) the tensor values to the provided ones.
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -478,6 +554,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
@@ -554,23 +631,23 @@ impl<T: Clone + TensorType> Tensor<T> {
// Fill remaining dimensions
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
// Pre-calculate total size and allocate result vector
let total_size: usize = full_indices
.iter()
.map(|range| range.end - range.start)
.product();
let mut res = Vec::with_capacity(total_size);
// Calculate new dimensions once
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
let mut output = Tensor::new(None, &dims)?;
// Use iterator directly without collecting into intermediate Vec
for coord in full_indices.iter().cloned().multi_cartesian_product() {
let index = self.get_index(&coord);
res.push(self[index].clone());
}
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
output.par_iter_mut().enumerate().for_each(|(i, e)| {
let coord = &cartesian_coord[i];
*e = self.get(coord);
});
Ok(output)
Tensor::new(Some(&res), &dims)
}
/// Set a slice of the Tensor.
@@ -676,7 +753,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.iter().enumerate() {
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if i % n == 0 {
inner.push(elem.clone());
}
@@ -699,7 +776,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.iter().enumerate() {
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if i % n != 0 {
inner.push(elem.clone());
}
@@ -735,9 +812,9 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.iter().enumerate() {
for (i, elem) in self.inner.clone().into_iter().enumerate() {
if (i + offset + 1) % n == 0 {
inner.extend(vec![elem.clone(); 1 + num_repeats]);
inner.extend(vec![elem; 1 + num_repeats]);
offset += num_repeats;
} else {
inner.push(elem.clone());
@@ -794,16 +871,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
/// let mut indices = vec![3, 4];
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), expected);
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
///
///
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
/// let mut indices = vec![63, 67];
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), b);
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
/// ```
pub fn remove_indices(
&self,
@@ -850,7 +927,7 @@ impl<T: Clone + TensorType> Tensor<T> {
}
self.dims = vec![];
}
if self.dims() == [0] && new_dims.iter().product::<usize>() == 1 {
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
@@ -1215,6 +1292,33 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&[res]), &[1])
}
/// Applies `f` in parallel to the elements at the given flat indices,
/// writing each result back in place; all other elements are untouched.
///
/// # Arguments
/// * `filter_indices` - set of flat indices whose elements are rewritten
/// * `f` - fallible function mapping a flat index to the new element value
///
/// # Errors
/// Returns an error produced by `f` (no ordering guarantee across parallel
/// executions); elements whose closures already ran remain overwritten.
pub fn par_enum_map_mut_filtered<
    F: Fn(usize) -> Result<T, E> + std::marker::Send + std::marker::Sync,
    E: Error + std::marker::Send + std::marker::Sync,
>(
    &mut self,
    filter_indices: &std::collections::HashSet<usize>,
    f: F,
) -> Result<(), E>
where
    T: std::marker::Send + std::marker::Sync,
{
    // Propagate failures from `f` instead of panicking inside the parallel
    // loop (the previous implementation `unwrap`ped each result, so the
    // declared `Result<(), E>` could never actually be `Err`).
    self.inner
        .par_iter_mut()
        .enumerate()
        .filter(|(i, _)| filter_indices.contains(i))
        .try_for_each(|(i, e)| {
            *e = f(i)?;
            Ok(())
        })
}
}
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
@@ -1231,9 +1335,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
pub fn combine(&self) -> Result<Tensor<T>, TensorError> {
let mut dims = 0;
let mut inner = Vec::new();
for t in self.inner.iter() {
for t in self.inner.clone().into_iter() {
dims += t.len();
inner.extend(t.inner.clone());
inner.extend(t.inner);
}
Tensor::new(Some(&inner), &[dims])
}
@@ -1704,7 +1808,7 @@ impl DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => *self,
_ => self.clone(),
}
}
@@ -1804,7 +1908,7 @@ impl KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => *self,
_ => self.clone(),
}
}
@@ -1926,9 +2030,6 @@ mod tests {
fn tensor_slice() {
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(
a.get_slice(&[0..2, 0..1]).unwrap(),
b.get_slice(&[0..2, 0..1]).unwrap()
);
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
}
}

View File

@@ -916,14 +916,11 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &'a Tensor<T>,
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1111,14 +1108,11 @@ where
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<'a, T: TensorType + Send + Sync + 'a>(
pub fn scatter_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &'a Tensor<T>,
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1189,12 +1183,12 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// use ezkl::tensor::ops::intercalate_values;
///
/// let tensor = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
/// let result = intercalate_values(&tensor, &0, 2, 1).unwrap();
/// let result = intercalate_values(&tensor, 0, 2, 1).unwrap();
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = intercalate_values(&expected, &0, 2, 0).unwrap();
/// let result = intercalate_values(&expected, 0, 2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap();
///
/// assert_eq!(result, expected);
@@ -1202,7 +1196,7 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// ```
pub fn intercalate_values<T: TensorType>(
tensor: &Tensor<T>,
value: &T,
value: T,
stride: usize,
axis: usize,
) -> Result<Tensor<T>, TensorError> {
@@ -1500,7 +1494,7 @@ pub fn slice<T: TensorType + Send + Sync>(
}
}
Ok(t.get_slice(&slice)?)
t.get_slice(&slice)
}
// ---------------------------------------------------------------------------------------------------------
@@ -2420,20 +2414,20 @@ pub mod accumulated {
/// Some(&[25, 35]),
/// &[2],
/// ).unwrap();
/// assert_eq!(dot(&x, &y, 1).unwrap(), expected);
/// assert_eq!(dot(&[x, y], 1).unwrap(), expected);
/// ```
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
a: &Tensor<T>,
b: &Tensor<T>,
inputs: &[Tensor<T>; 2],
chunk_size: usize,
) -> Result<Tensor<T>, TensorError> {
if a.len() != b.len() {
if inputs[0].clone().len() != inputs[1].clone().len() {
return Err(TensorError::DimMismatch("dot".to_string()));
}
let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
let transcript: Tensor<T> = a
.iter()
.zip(b.iter())
.zip(b)
.chunks(chunk_size)
.into_iter()
.scan(T::zero().unwrap(), |acc, chunk| {

View File

@@ -280,17 +280,9 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
fn from(t: Vec<ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().into(),
dims: vec![t.len()],
scale: 1,
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> From<Vec<&ValType<F>>> for ValTensor<F> {
fn from(t: Vec<&ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().cloned().into(),
dims: vec![t.len()],
scale: 1,
}
}
@@ -648,9 +640,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// # Returns
/// A tensor containing the base-n decomposition of each value
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let res = self
.get_inner()?
.par_iter()
@@ -676,7 +665,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
})
.collect::<Result<Vec<_>, _>>();
let mut tensor = res?.into_iter().flatten().collect::<Tensor<_>>();
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
let mut dims = self.dims().to_vec();
dims.push(n + 1);
tensor.reshape(&dims)?;
@@ -1052,6 +1043,119 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Removes constant zero values from the tensor
/// Uses parallel processing for tensors larger than a threshold
pub fn remove_const_zero_values(&mut self) {
let size_threshold = 1_000_000; // Tuned using benchmarks
if self.len() < size_threshold {
// Single-threaded for small tensors
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
} else {
// Parallel processing for large tensors
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.par_chunks_mut(chunk_size)
.flat_map(|chunk| {
chunk.par_iter_mut().filter_map(|e| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return None;
}
}
Some(e.clone())
})
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
}
}
/// Gets the indices of all constant zero values.
///
/// Switches to a chunked parallel scan once the tensor exceeds a size
/// threshold; `Instance` tensors always yield an empty vector.
///
/// # Returns
/// A vector of flat indices where constant zero values are located
pub fn get_const_zero_indices(&self) -> Vec<usize> {
    let size_threshold = 1_000_000;

    // Instance tensors carry no inline values to scan.
    let v = match &self {
        ValTensor::Value { inner, .. } => inner,
        ValTensor::Instance { .. } => return vec![],
    };

    if self.len() < size_threshold {
        // Single-threaded scan for small tensors.
        v.iter()
            .enumerate()
            .filter_map(|(idx, val)| match val {
                ValType::Constant(c) | ValType::AssignedConstant(_, c) => {
                    (*c == F::ZERO).then_some(idx)
                }
                _ => None,
            })
            .collect()
    } else {
        // Chunked parallel scan for large tensors; flat indices are
        // reconstructed from the chunk number and in-chunk offset.
        let workers = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(1);
        let chunk_len = (self.len() / workers).max(100_000);
        v.par_chunks(chunk_len)
            .enumerate()
            .flat_map(|(chunk_no, chunk)| {
                chunk
                    .par_iter()
                    .enumerate()
                    .filter_map(move |(offset, val)| match val {
                        ValType::Constant(c) | ValType::AssignedConstant(_, c) => {
                            (*c == F::ZERO).then_some(chunk_no * chunk_len + offset)
                        }
                        _ => None,
                    })
            })
            .collect::<Vec<_>>()
    }
}
/// Gets the indices of all constant values
/// Uses parallel processing for large tensors
///
@@ -1172,7 +1276,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Returns an error if called on an Instance tensor
pub fn intercalate_values(
&mut self,
value: &ValType<F>,
value: ValType<F>,
stride: usize,
axis: usize,
) -> Result<(), TensorError> {
@@ -1264,10 +1368,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Concatenates two tensors along the first dimension
pub fn concat(&self, other: &Self) -> Result<Self, TensorError> {
pub fn concat(&self, other: Self) -> Result<Self, TensorError> {
let res = match (self, other) {
(ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. }) => {
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2.clone()]), &[2])?.combine()?)
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2]), &[2])?.combine()?)
}
_ => {
return Err(TensorError::WrongMethod);

View File

@@ -1,6 +1,8 @@
use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::{region::ConstantsMap, CheckMode};
use crate::circuit::{CheckMode, region::ConstantsMap};
use super::*;
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
@@ -379,6 +381,50 @@ impl VarTensor {
}
}
/// Assigns values from a ValTensor to this tensor, excluding specified positions
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `omissions` - Set of positions to skip during assignment
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
    &self,
    region: &mut Region<F>,
    offset: usize,
    values: &ValTensor<F>,
    omissions: &HashSet<usize>,
    constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
    // Counts only the cells actually assigned so far: omitted positions do
    // not advance it, so assigned cells stay densely packed in the column.
    let mut assigned_coord = 0;
    let mut res: ValTensor<F> = match values {
        // Instance columns cannot be sparsely assigned by this method.
        ValTensor::Instance { .. } => {
            error!(
                "assignment with omissions is not supported on instance columns. increase K if you require more rows."
            );
            Err(halo2_proofs::plonk::Error::Synthesis)
        }
        ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
            v.enum_map(|coord, k| {
                // Omitted positions are passed through unassigned, unchanged.
                if omissions.contains(&coord) {
                    return Ok::<_, halo2_proofs::plonk::Error>(k);
                }
                // Note: the cell's position is `assigned_coord` (dense index
                // among assigned cells), not `coord` (index in the tensor).
                let cell =
                    self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
                assigned_coord += 1;
                Ok::<_, halo2_proofs::plonk::Error>(cell)
            })?
            .into(),
        ),
    }?;
    // Preserve the input tensor's scale on the assigned output.
    res.set_scale(values.scale());
    Ok(res)
}
/// Assigns values from a ValTensor to this tensor
///
/// # Arguments

View File

@@ -3,10 +3,10 @@
mod native_tests {
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::Commitments;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
use ezkl::pfsys::Snark;
use ezkl::Commitments;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::Bn256;
use lazy_static::lazy_static;