Compare commits

...

4 Commits

Author SHA1 Message Date
github-actions[bot]
7b955e8fd7 ci: update version string in docs 2025-05-26 14:54:28 +00:00
dante
40ce9dfde9 chore: rm lots of clones (#980) 2025-05-26 10:54:09 -04:00
dante
839030ce10 chore: rm halo2proofs patches (#976) 2025-04-29 10:58:35 -04:00
dante
cfccc5460c refactor: rm postgres (#977) 2025-04-29 08:59:14 -04:00
50 changed files with 40414 additions and 3580 deletions

View File

@@ -198,15 +198,15 @@ jobs:
- name: Build release binary (no asm or metal)
if: matrix.build != 'linux-gnu' && matrix.build != 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features mimalloc
- name: Build release binary (asm)
if: matrix.build == 'linux-gnu'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm,mimalloc
- name: Build release binary (metal)
if: matrix.build == 'macos-aarch64'
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal
run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features macos-metal,mimalloc
- name: Strip release binary
if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'

View File

@@ -245,25 +245,6 @@ jobs:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
foudry-solidity-tests:
permissions:
contents: read
runs-on: non-gpu
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
with:
persist-credentials: false
submodules: recursive
- name: Install Foundry
uses: foundry-rs/foundry-toolchain@3b74dacdda3c0b763089addb99ed86bc3800e68b
- name: Run tests
run: |
cd tests/foundry
forge install https://github.com/foundry-rs/forge-std --no-git --no-commit
forge test -vvvv --fuzz-runs 64
mock-proving-tests:
permissions:
contents: read
@@ -372,9 +353,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./in-browser-evm-verifier/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" in-browser-evm-verifier/nodejs/ezkl.js
- name: Build @ezkljs/verify package
run: |
cd in-browser-evm-verifier
@@ -497,9 +475,6 @@ jobs:
- name: Build wasm package for nodejs target.
run: |
wasm-pack build --target nodejs --out-dir ./tests/wasm/nodejs . -- -Z build-std="panic_abort,std"
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests (public outputs + fixed params + column overflow)
@@ -744,24 +719,6 @@ jobs:
permissions:
contents: read
runs-on: large-self-hosted
services:
# Label used to access the service container
postgres:
# Docker Hub image
image: postgres
env:
POSTGRES_USER: ubuntu
POSTGRES_HOST_AUTH_METHOD: trust
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
--health-interval 10s
--health-timeout 5s
--health-retries 5
-v /var/run/postgresql:/var/run/postgresql
ports:
# Maps tcp port 5432 on service container to the host
- 5432:5432
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 #v4.2.2
with:
@@ -798,8 +755,6 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::neural_bag_of_words_ --no-capture
- name: Felt conversion
run: source .env/bin/activate; cargo nextest run py_tests::tests::felt_conversion_test_ --no-capture
- name: Postgres tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --no-capture
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli

2589
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,7 @@ cargo-features = ["profile-rustflags"]
[package]
name = "ezkl"
version = "0.0.0"
edition = "2024"
edition = "2021"
default-run = "ezkl"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -69,20 +69,18 @@ reqwest = { version = "0.12.4", default-features = false, features = [
"stream",
], optional = true }
openssl = { version = "0.10.55", features = ["vendored"], optional = true }
tokio-postgres = { version = "0.7.10", optional = true }
pg_bigdecimal = { version = "0.1.5", optional = true }
lazy_static = { version = "1.4.0", optional = true }
colored_json = { version = "3.0.1", default-features = false, optional = true }
tokio = { version = "1.35.0", default-features = false, features = [
"macros",
"rt-multi-thread",
], optional = true }
pyo3 = { version = "0.23.2", features = [
pyo3 = { version = "0.24.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default-features = false, optional = true }
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.23.0", features = [
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", version = "0.24.0", features = [
"attributes",
"tokio-runtime",
], default-features = false, optional = true }
@@ -90,9 +88,9 @@ pyo3-log = { version = "0.12.0", default-features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "37132e0397d0a73e5bd3a8615d932dabe44f6736", default-features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
objc = { version = "0.2.4", optional = true }
mimalloc = { version = "0.1", optional = true }
pyo3-stub-gen = { version = "0.6.0", optional = true }
jemallocator = { version = "0.5", optional = true }
mimalloc = { version = "0.1", optional = true }
# universal bindings
uniffi = { version = "=0.28.0", optional = true }
getrandom = { version = "0.2.8", optional = true }
@@ -242,12 +240,9 @@ ezkl = [
"dep:indicatif",
"dep:gag",
"dep:reqwest",
"dep:tokio-postgres",
"dep:pg_bigdecimal",
"dep:lazy_static",
"dep:tokio",
"dep:openssl",
"dep:mimalloc",
"dep:chrono",
"dep:sha256",
"dep:clap_complete",
@@ -274,22 +269,19 @@ no-banner = []
no-update = []
macos-metal = ["halo2_proofs/macos"]
ios-metal = ["halo2_proofs/ios"]
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
[patch.'https://github.com/zkonduit/halo2#0654e92bdf725fd44d849bfef3643870a8c7d50b']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#f441c920be45f8f05d2c06a173d82e8885a5ed4d", package = "halo2_proofs" }
jemalloc = ["dep:jemallocator"]
mimalloc = ["dep:mimalloc"]
[patch.crates-io]
uniffi_testing = { git = "https://github.com/ElusAegis/uniffi-rs", branch = "feat/testing-feature-build-fix" }
[profile.release]
# debug = true
rustflags = ["-C", "relocation-model=pic"]
lto = "fat"
codegen-units = 1
#panic = "abort"
# panic = "abort"
[profile.test-runs]

View File

@@ -68,7 +68,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
&[&self.image, &self.kernel, &self.bias],
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -61,7 +62,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -86,13 +87,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -17,6 +17,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -87,13 +88,13 @@ impl Circuit<Fr> for MyCircuit {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -59,7 +60,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.unwrap();

View File

@@ -63,7 +63,7 @@ impl Circuit<Fr> for MyCircuit {
config
.layout(
&mut region,
&[self.image.clone()],
&[&self.image],
Box::new(HybridOp::SumPool {
padding: vec![(0, 0); 2],
stride: vec![1, 1],

View File

@@ -15,6 +15,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -57,7 +58,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = region::RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.unwrap();
Ok(())
},

View File

@@ -16,6 +16,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use itertools::Itertools;
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
@@ -58,7 +59,11 @@ impl Circuit<Fr> for MyCircuit {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 1024, 2);
config
.layout(&mut region, &self.inputs, Box::new(PolyOp::Pow(4)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(4)),
)
.unwrap();
Ok(())
},

View File

@@ -70,7 +70,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,

View File

@@ -67,7 +67,7 @@ impl Circuit<Fr> for NLCircuit {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.unwrap();

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '0.0.0'
release = '22.0.4'
version = release

View File

@@ -32,7 +32,6 @@ use mnist::*;
use rand::rngs::OsRng;
use std::marker::PhantomData;
mod params;
const K: usize = 20;
@@ -216,11 +215,7 @@ where
.layer_config
.layout(
&mut region,
&[
self.input.clone(),
self.l0_params[0].clone(),
self.l0_params[1].clone(),
],
&[&self.input, &self.l0_params[0], &self.l0_params[1]],
Box::new(op),
)
.unwrap();
@@ -229,7 +224,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -241,7 +236,7 @@ where
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div { denom: 32.0.into() }),
)
.unwrap()
@@ -253,7 +248,7 @@ where
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone(), x],
&[&self.l2_params[0], &x],
Box::new(PolyOp::Einsum {
equation: "ij,j->ik".to_string(),
}),
@@ -265,7 +260,7 @@ where
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone()],
&[&x, &self.l2_params[1]],
Box::new(PolyOp::Add),
)
.unwrap()

View File

@@ -117,10 +117,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[
self.l0_params[0].clone().try_into().unwrap(),
self.input.clone(),
],
&[&self.l0_params[0].clone().try_into().unwrap(), &self.input],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -135,7 +132,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l0_params[1].clone().try_into().unwrap()],
&[&x, &self.l0_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -147,7 +144,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -163,7 +160,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone().try_into().unwrap(), x],
&[&self.l2_params[0].clone().try_into().unwrap(), &x],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -178,7 +175,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x, self.l2_params[1].clone().try_into().unwrap()],
&[&x, &self.l2_params[1].clone().try_into().unwrap()],
Box::new(PolyOp::Add),
)
.unwrap()
@@ -190,7 +187,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x],
&[&x],
Box::new(PolyOp::LeakyReLU {
scale: 1,
slope: 0.0.into(),
@@ -203,7 +200,7 @@ impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRe
.layer_config
.layout(
&mut region,
&[x.unwrap()],
&[&x.unwrap()],
Box::new(LookupOp::Div {
denom: ezkl::circuit::utils::F32::from(128.),
}),

View File

@@ -1,462 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Mean of ERC20 transfer amounts\n",
"\n",
"This notebook shows how to calculate the mean of ERC20 transfer amounts, pulling data in from a Postgres database. First we install and get the necessary libraries running. \n",
"The first of which is [shovel](https://indexsupply.com/shovel/docs/#getting-started), which is a library that allows us to pull data from the Ethereum blockchain into a Postgres database.\n",
"\n",
"Make sure you install postgres if needed https://indexsupply.com/shovel/docs/#getting-started. \n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import getpass\n",
"import json\n",
"import time\n",
"import subprocess\n",
"\n",
"# swap out for the relevant linux/amd64, darwin/arm64, darwin/amd64, windows/amd64\n",
"os.system(\"curl -LO https://indexsupply.net/bin/1.0/linux/amd64/shovel\")\n",
"os.system(\"chmod +x shovel\")\n",
"\n",
"\n",
"os.environ[\"PG_URL\"] = \"postgres://\" + getpass.getuser() + \":@localhost:5432/shovel\"\n",
"\n",
"# create a config.json file with the following contents\n",
"config = {\n",
" \"pg_url\": \"$PG_URL\",\n",
" \"eth_sources\": [\n",
" {\"name\": \"mainnet\", \"chain_id\": 1, \"url\": \"https://ethereum-rpc.publicnode.com\"},\n",
" {\"name\": \"base\", \"chain_id\": 8453, \"url\": \"https://base-rpc.publicnode.com\"}\n",
" ],\n",
" \"integrations\": [{\n",
" \"name\": \"usdc_transfer\",\n",
" \"enabled\": True,\n",
" \"sources\": [{\"name\": \"mainnet\"}, {\"name\": \"base\"}],\n",
" \"table\": {\n",
" \"name\": \"usdc\",\n",
" \"columns\": [\n",
" {\"name\": \"log_addr\", \"type\": \"bytea\"},\n",
" {\"name\": \"block_num\", \"type\": \"numeric\"},\n",
" {\"name\": \"f\", \"type\": \"bytea\"},\n",
" {\"name\": \"t\", \"type\": \"bytea\"},\n",
" {\"name\": \"v\", \"type\": \"numeric\"}\n",
" ]\n",
" },\n",
" \"block\": [\n",
" {\"name\": \"block_num\", \"column\": \"block_num\"},\n",
" {\n",
" \"name\": \"log_addr\",\n",
" \"column\": \"log_addr\",\n",
" \"filter_op\": \"contains\",\n",
" \"filter_arg\": [\n",
" \"a0b86991c6218b36c1d19d4a2e9eb0ce3606eb48\",\n",
" \"833589fCD6eDb6E08f4c7C32D4f71b54bdA02913\"\n",
" ]\n",
" }\n",
" ],\n",
" \"event\": {\n",
" \"name\": \"Transfer\",\n",
" \"type\": \"event\",\n",
" \"anonymous\": False,\n",
" \"inputs\": [\n",
" {\"indexed\": True, \"name\": \"from\", \"type\": \"address\", \"column\": \"f\"},\n",
" {\"indexed\": True, \"name\": \"to\", \"type\": \"address\", \"column\": \"t\"},\n",
" {\"indexed\": False, \"name\": \"value\", \"type\": \"uint256\", \"column\": \"v\"}\n",
" ]\n",
" }\n",
" }]\n",
"}\n",
"\n",
"# write the config to a file\n",
"with open(\"config.json\", \"w\") as f:\n",
" f.write(json.dumps(config))\n",
"\n",
"\n",
"# print the two env variables\n",
"os.system(\"echo $PG_URL\")\n",
"\n",
"os.system(\"createdb -h localhost -p 5432 shovel\")\n",
"\n",
"os.system(\"echo shovel is now installed. starting:\")\n",
"\n",
"command = [\"./shovel\", \"-config\", \"config.json\"]\n",
"proc = subprocess.Popen(command)\n",
"\n",
"os.system(\"echo shovel started.\")\n",
"\n",
"time.sleep(10)\n",
"\n",
"# after we've fetched some data -- kill the process\n",
"proc.terminate()\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2wIAHwqH2_mo"
},
"source": [
"**Import Dependencies**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9Byiv2Nc2MsK"
},
"outputs": [],
"source": [
"# check if notebook is in colab\n",
"try:\n",
" # install ezkl\n",
" import google.colab\n",
" import subprocess\n",
" import sys\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"ezkl\"])\n",
" subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"onnx\"])\n",
"\n",
"# rely on local installation of ezkl if the notebook is not in colab\n",
"except:\n",
" pass\n",
"\n",
"import ezkl\n",
"import torch\n",
"import datetime\n",
"import pandas as pd\n",
"import requests\n",
"import json\n",
"import os\n",
"\n",
"import logging\n",
"# # uncomment for more descriptive logging \n",
"FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'\n",
"logging.basicConfig(format=FORMAT)\n",
"logging.getLogger().setLevel(logging.DEBUG)\n",
"\n",
"print(\"ezkl version: \", ezkl.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "osjj-0Ta3E8O"
},
"source": [
"**Create Computational Graph**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "x1vl9ZXF3EEW",
"outputId": "bda21d02-fe5f-4fb2-8106-f51a8e2e67aa"
},
"outputs": [],
"source": [
"from torch import nn\n",
"import torch\n",
"\n",
"\n",
"class Model(nn.Module):\n",
" def __init__(self):\n",
" super(Model, self).__init__()\n",
"\n",
" # x is a time series \n",
" def forward(self, x):\n",
" return [torch.mean(x)]\n",
"\n",
"\n",
"\n",
"\n",
"circuit = Model()\n",
"\n",
"\n",
"\n",
"\n",
"x = 0.1*torch.rand(1,*[1,5], requires_grad=True)\n",
"\n",
"# # print(torch.__version__)\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"print(device)\n",
"\n",
"circuit.to(device)\n",
"\n",
"# Flips the neural net into inference mode\n",
"circuit.eval()\n",
"\n",
"# Export the model\n",
"torch.onnx.export(circuit, # model being run\n",
" x, # model input (or a tuple for multiple inputs)\n",
" \"lol.onnx\", # where to save the model (can be a file or file-like object)\n",
" export_params=True, # store the trained parameter weights inside the model file\n",
" opset_version=11, # the ONNX version to export the model to\n",
" do_constant_folding=True, # whether to execute constant folding for optimization\n",
" input_names = ['input'], # the model's input names\n",
" output_names = ['output'], # the model's output names\n",
" dynamic_axes={'input' : {0 : 'batch_size'}, # variable length axes\n",
" 'output' : {0 : 'batch_size'}})\n",
"\n",
"# export(circuit, input_shape=[1, 20])\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "E3qCeX-X5xqd"
},
"source": [
"**Set Data Source and Get Data**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "6RAMplxk5xPk",
"outputId": "bd2158fe-0c00-44fd-e632-6a3f70cdb7c9"
},
"outputs": [],
"source": [
"import getpass\n",
"# make an input.json file from the df above\n",
"input_filename = os.path.join('input.json')\n",
"\n",
"pg_input_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 5\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
"json_formatted_str = json.dumps(pg_input_file, indent=2)\n",
"print(json_formatted_str)\n",
"\n",
"\n",
" # Serialize data into file:\n",
"json.dump(pg_input_file, open(input_filename, 'w' ))\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# this corresponds to 4 batches\n",
"calibration_filename = os.path.join('calibration.json')\n",
"\n",
"pg_cal_file = dict(input_data = {\n",
" \"host\": \"localhost\",\n",
" # make sure you replace this with your own username\n",
" \"user\": getpass.getuser(),\n",
" \"dbname\": \"shovel\",\n",
" \"password\": \"\",\n",
" \"query\": \"SELECT v FROM usdc ORDER BY block_num DESC LIMIT 20\",\n",
" \"port\": \"5432\",\n",
"})\n",
"\n",
" # Serialize data into file:\n",
"json.dump( pg_cal_file, open(calibration_filename, 'w' ))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "eLJ7oirQ_HQR"
},
"source": [
"**EZKL Workflow**"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "rNw0C9QL6W88"
},
"outputs": [],
"source": [
"import subprocess\n",
"import os\n",
"\n",
"onnx_filename = os.path.join('lol.onnx')\n",
"compiled_filename = os.path.join('lol.compiled')\n",
"settings_filename = os.path.join('settings.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.decomp_legs = 4\n",
"\n",
"# Generate settings using ezkl\n",
"res = ezkl.gen_settings(onnx_filename, settings_filename, py_run_args=run_args)\n",
"\n",
"assert res == True\n",
"\n",
"res = await ezkl.calibrate_settings(input_filename, onnx_filename, settings_filename, \"resources\")\n",
"\n",
"assert res == True\n",
"\n",
"await ezkl.get_srs(settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"ezkl.compile_circuit(onnx_filename, compiled_filename, settings_filename)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4MmE9SX66_Il",
"outputId": "16403639-66a4-4280-ac7f-6966b75de5a3"
},
"outputs": [],
"source": [
"# generate settings\n",
"\n",
"\n",
"# show the settings.json\n",
"with open(\"settings.json\") as f:\n",
" data = json.load(f)\n",
" json_formatted_str = json.dumps(data, indent=2)\n",
"\n",
" print(json_formatted_str)\n",
"\n",
"assert os.path.exists(\"settings.json\")\n",
"assert os.path.exists(\"input.json\")\n",
"assert os.path.exists(\"lol.onnx\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fULvvnK7_CMb"
},
"outputs": [],
"source": [
"pk_path = os.path.join('test.pk')\n",
"vk_path = os.path.join('test.vk')\n",
"\n",
"\n",
"# setup the proof\n",
"res = ezkl.setup(\n",
" compiled_filename,\n",
" vk_path,\n",
" pk_path\n",
" )\n",
"\n",
"assert res == True\n",
"assert os.path.isfile(vk_path)\n",
"assert os.path.isfile(pk_path)\n",
"assert os.path.isfile(settings_filename)\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"witness_path = \"witness.json\"\n",
"\n",
"# generate the witness\n",
"res = await ezkl.gen_witness(\n",
" input_filename,\n",
" compiled_filename,\n",
" witness_path\n",
" )\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Oog3j6Kd-Wed",
"outputId": "5839d0c1-5b43-476e-c2f8-6707de562260"
},
"outputs": [],
"source": [
"# prove the zk circuit\n",
"# GENERATE A PROOF\n",
"proof_path = os.path.join('test.pf')\n",
"\n",
"\n",
"proof = ezkl.prove(\n",
" witness_path,\n",
" compiled_filename,\n",
" pk_path,\n",
" proof_path,\n",
" \"single\"\n",
" )\n",
"\n",
"\n",
"print(\"proved\")\n",
"\n",
"assert os.path.isfile(proof_path)\n",
"\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": ".env",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.7"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

Binary file not shown.

37795
examples/onnx/fr_age/lol.txt Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,12 +1,5 @@
// ignore file if compiling for wasm
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc::MiMalloc;
#[global_allocator]
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
static GLOBAL: MiMalloc = MiMalloc;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::{CommandFactory, Parser};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]

View File

@@ -1,34 +1,34 @@
use crate::Commitments;
use crate::RunArgs;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::circuit::modules::Module;
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
spec::{POSEIDON_RATE, POSEIDON_WIDTH, PoseidonSpec},
};
use crate::circuit::modules::Module;
use crate::circuit::CheckMode;
use crate::circuit::InputType;
use crate::commands::*;
use crate::fieldutils::{IntegerRep, felt_to_integer_rep, integer_rep_to_felt};
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::graph::TestDataSource;
use crate::graph::{
GraphCircuit, GraphSettings, Model, Visibility, quantize_float, scale_to_multiplier,
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
};
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::{
ProofType, TranscriptType, load_pk, load_vk, save_params, save_vk,
srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
ProofType, TranscriptType,
};
use crate::Commitments;
use crate::RunArgs;
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::{Bn256, Fq, Fr, G1, G1Affine};
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
use pyo3::exceptions::{PyIOError, PyRuntimeError};
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;
use pyo3_log;
use pyo3_stub_gen::{
TypeInfo, define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction,
define_stub_info_gatherer, derive::gen_stub_pyclass, derive::gen_stub_pyclass_enum,
derive::gen_stub_pyfunction, TypeInfo,
};
use snark_verifier::util::arithmetic::PrimeField;
use std::collections::HashSet;

View File

@@ -962,7 +962,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
pub fn layout(
&mut self,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, CircuitError> {
op.layout(self, region, values)

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils},
fieldutils::{IntegerRep, integer_rep_to_felt},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::multiplier_to_scale,
tensor::{self, DataFormat, Tensor, TensorType, ValTensor},
};
@@ -109,13 +109,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater { .. }
| HybridOp::Less { .. }
| HybridOp::Equals { .. }
| HybridOp::GreaterEqual { .. }
HybridOp::Greater
| HybridOp::Less
| HybridOp::Equals
| HybridOp::GreaterEqual
| HybridOp::Max
| HybridOp::Min
| HybridOp::LessEqual { .. } => {
| HybridOp::LessEqual => {
vec![0, 1]
}
_ => vec![],
@@ -213,7 +213,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
HybridOp::Rsqrt {
@@ -362,10 +362,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Hybrid
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
let scale = match self {
HybridOp::Greater { .. }
| HybridOp::GreaterEqual { .. }
| HybridOp::Less { .. }
| HybridOp::LessEqual { .. }
HybridOp::Greater
| HybridOp::GreaterEqual
| HybridOp::Less
| HybridOp::LessEqual
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,

File diff suppressed because it is too large Load Diff

View File

@@ -186,7 +186,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Lookup
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(layouts::nonlinearity(
config,

View File

@@ -49,7 +49,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError>;
/// Returns the scale of the output of the operation.
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = values[0].clone();
if !value.all_prev_assigned() {
@@ -223,12 +223,29 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
true,
)?))
}
_ => Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
self.decomp,
)?)),
_ => {
if self.decomp {
log::debug!("constraining input to be decomp");
Ok(Some(
super::layouts::decompose(
config,
region,
values[..].try_into()?,
&region.base(),
&region.legs(),
false,
)?
.1,
))
} else {
log::debug!("constraining input to be identity");
Ok(Some(super::layouts::identity(
config,
region,
values[..].try_into()?,
)?))
}
}
}
} else {
Ok(Some(value))
@@ -263,7 +280,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
&self,
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Err(super::CircuitError::UnsupportedOp)
}
@@ -319,8 +336,13 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
@@ -333,20 +355,20 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
_: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
let value = if let Some(value) = &self.pre_assigned_val {
value.clone()
} else {
self.quantized_values.clone().try_into()?
};
// we gotta constrain it once if its used multiple times
Ok(Some(layouts::identity(
config,
region,
&[value],
self.decomp,
)?))
Ok(Some(if self.decomp {
log::debug!("constraining constant to be decomp");
super::layouts::decompose(config, region, &[&value], &region.base(), &region.legs(), false)?.1
} else {
log::debug!("constraining constant to be identity");
super::layouts::identity(config, region, &[&value])?
}))
}
fn clone_dyn(&self) -> Box<dyn Op<F>> {

View File

@@ -108,8 +108,13 @@ pub enum PolyOp {
}
impl<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + Serialize + for<'de> Deserialize<'de>,
> Op<F> for PolyOp
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -203,11 +208,11 @@ impl<
&self,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
values: &[&ValTensor<F>],
) -> Result<Option<ValTensor<F>>, CircuitError> {
Ok(Some(match self {
PolyOp::Abs => layouts::abs(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?)?,
PolyOp::Sign => layouts::sign(config, region, values[..].try_into()?, true)?,
PolyOp::LeakyReLU { slope, scale } => {
layouts::leaky_relu(config, region, values[..].try_into()?, slope, scale)?
}
@@ -335,9 +340,7 @@ impl<
PolyOp::Mult => {
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
PolyOp::Identity { .. } => {
layouts::identity(config, region, values[..].try_into()?, false)?
}
PolyOp::Identity { .. } => layouts::identity(config, region, values[..].try_into()?)?,
PolyOp::Reshape(d) | PolyOp::Flatten(d) => layouts::reshape(values[..].try_into()?, d)?,
PolyOp::Pad(p) => {
if values.len() != 1 {
@@ -416,14 +419,14 @@ impl<
PolyOp::Reshape(_) | PolyOp::Flatten(_) => in_scales[0],
PolyOp::Pow(pow) => in_scales[0] * (*pow as crate::Scale),
PolyOp::Identity { out_scale } => out_scale.unwrap_or(in_scales[0]),
PolyOp::Sign { .. } => 0,
PolyOp::Sign => 0,
_ => in_scales[0],
};
Ok(scale)
}
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
if matches!(self, PolyOp::Add { .. } | PolyOp::Sub) {
if matches!(self, PolyOp::Add | PolyOp::Sub) {
vec![0, 1]
} else if matches!(self, PolyOp::Iff) {
vec![1, 2]

View File

@@ -10,7 +10,6 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use std::{
cell::RefCell,
@@ -462,15 +461,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
}
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
let int_eval = inputs.int_evals()?;
let max = int_eval.iter().max().unwrap_or(&0);
let min = int_eval.iter().min().unwrap_or(&0);
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(*max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(*min);
Ok(())
}
@@ -505,10 +503,10 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// add used lookup
pub fn add_used_lookup(
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
lookup: &LookupOp,
inputs: &ValTensor<F>,
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup);
self.statistics.used_lookups.insert(lookup.clone());
self.update_max_min_lookup_inputs(inputs)
}
@@ -642,34 +640,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
self.assign_dynamic_lookup(var, values)
}
/// Assign a valtensor to a vartensor
pub fn assign_with_omissions(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)?)
} else {
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;
let values_map = values.create_constants_map();
self.assigned_constants.par_extend(values_map);
Ok(values.clone())
}
}
/// Assign a valtensor to a vartensor with duplication
pub fn assign_with_duplication_unconstrained(
&mut self,

View File

@@ -9,6 +9,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
use itertools::Itertools;
#[cfg(not(any(
all(target_arch = "wasm32", target_os = "unknown"),
not(feature = "ezkl")
@@ -64,7 +65,7 @@ mod matmul {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -141,7 +142,7 @@ mod matmul_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -215,7 +216,7 @@ mod matmul_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -302,7 +303,7 @@ mod matmul_col_ultra_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -380,6 +381,7 @@ mod matmul_col_ultra_overflow {
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use itertools::Itertools;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -422,7 +424,7 @@ mod matmul_col_ultra_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "ij,jk->ik".to_string(),
}),
@@ -533,7 +535,7 @@ mod dot {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -610,7 +612,7 @@ mod dot_col_overflow_triple_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -683,7 +685,7 @@ mod dot_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -756,7 +758,7 @@ mod sum {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -826,7 +828,7 @@ mod sum_col_overflow_double_col {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -895,7 +897,7 @@ mod sum_col_overflow {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sum { axes: vec![0] }),
)
.map_err(|_| Error::Synthesis)
@@ -966,7 +968,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -975,7 +977,7 @@ mod composition {
let _ = config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -984,7 +986,7 @@ mod composition {
config
.layout(
&mut region,
&self.inputs.clone(),
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Einsum {
equation: "i,i->".to_string(),
}),
@@ -1061,7 +1063,7 @@ mod conv {
config
.layout(
&mut region,
&self.inputs,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1218,7 +1220,7 @@ mod conv_col_ultra_overflow {
config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1377,7 +1379,7 @@ mod conv_relu_col_ultra_overflow {
let output = config
.layout(
&mut region,
&[self.image.clone(), self.kernel.clone()],
&[&self.image, &self.kernel],
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
@@ -1390,7 +1392,7 @@ mod conv_relu_col_ultra_overflow {
let _output = config
.layout(
&mut region,
&[output.unwrap().unwrap()],
&[&output.unwrap().unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -1517,7 +1519,11 @@ mod add_w_shape_casting {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1584,7 +1590,11 @@ mod add {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -1671,8 +1681,8 @@ mod dynamic_lookup {
layouts::dynamic_lookup(
&config,
&mut region,
&self.lookups[i],
&self.tables[i],
&self.lookups[i].iter().collect_vec().try_into().unwrap(),
&self.tables[i].iter().collect_vec().try_into().unwrap(),
)
.map_err(|_| Error::Synthesis)?;
}
@@ -1767,8 +1777,8 @@ mod shuffle {
#[derive(Clone)]
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: [[ValTensor<F>; 1]; NUM_LOOP],
references: [[ValTensor<F>; 1]; NUM_LOOP],
inputs: [ValTensor<F>; NUM_LOOP],
references: [ValTensor<F>; NUM_LOOP],
_marker: PhantomData<F>,
}
@@ -1818,15 +1828,15 @@ mod shuffle {
layouts::shuffles(
&config,
&mut region,
&self.inputs[i],
&self.references[i],
&[&self.inputs[i]],
&[&self.references[i]],
layouts::SortCollisionMode::Unsorted,
)
.map_err(|_| Error::Synthesis)?;
}
assert_eq!(
region.shuffle_col_coord(),
NUM_LOOP * self.references[0][0].len()
NUM_LOOP * self.references[0].len()
);
assert_eq!(region.shuffle_index(), NUM_LOOP);
@@ -1843,17 +1853,19 @@ mod shuffle {
// parameters
let references = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN).map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
let inputs = (0..NUM_LOOP)
.map(|loop_idx| {
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * loop_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * loop_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1873,9 +1885,11 @@ mod shuffle {
} else {
loop_idx - 1
};
[ValTensor::from(Tensor::from((0..LEN).rev().map(|i| {
Value::known(F::from((i * prev_idx) as u64 + 1))
})))]
ValTensor::from(Tensor::from(
(0..LEN)
.rev()
.map(|i| Value::known(F::from((i * prev_idx) as u64 + 1))),
))
})
.collect::<Vec<_>>();
@@ -1931,7 +1945,11 @@ mod add_with_overflow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Add))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Add),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2026,7 +2044,7 @@ mod add_with_overflow_and_poseidon {
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
let inputs = vec![assigned_inputs_a, assigned_inputs_b];
let inputs = vec![&assigned_inputs_a, &assigned_inputs_b];
layouter.assign_region(
|| "model",
@@ -2135,7 +2153,11 @@ mod sub {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Sub))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Sub),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2202,7 +2224,11 @@ mod mult {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Mult))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Mult),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2269,7 +2295,11 @@ mod pow {
|region| {
let mut region = RegionCtx::new(region, 0, 1, 128, 2);
config
.layout(&mut region, &self.inputs.clone(), Box::new(PolyOp::Pow(5)))
.layout(
&mut region,
&self.inputs.iter().collect_vec(),
Box::new(PolyOp::Pow(5)),
)
.map_err(|_| Error::Synthesis)
},
)
@@ -2360,13 +2390,13 @@ mod matmul_relu {
};
let output = config
.base_config
.layout(&mut region, &self.inputs, Box::new(op))
.layout(&mut region, &self.inputs.iter().collect_vec(), Box::new(op))
.unwrap();
let _output = config
.base_config
.layout(
&mut region,
&[output.unwrap()],
&[&output.unwrap()],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2465,7 +2495,7 @@ mod relu {
Ok(config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(PolyOp::LeakyReLU {
slope: 0.0.into(),
scale: 1,
@@ -2563,7 +2593,7 @@ mod lookup_ultra_overflow {
config
.layout(
&mut region,
&[self.input.clone()],
&[&self.input],
Box::new(LookupOp::Sigmoid { scale: 1.0.into() }),
)
.map_err(|_| Error::Synthesis)

View File

@@ -1,6 +1,6 @@
use alloy::primitives::Address as H160;
use clap::{Command, Parser, Subcommand};
use clap_complete::{Generator, Shell, generate};
use clap_complete::{generate, Generator, Shell};
#[cfg(feature = "python-bindings")]
use pyo3::{conversion::FromPyObject, exceptions::PyValueError, prelude::*};
use serde::{Deserialize, Serialize};
@@ -8,7 +8,7 @@ use std::path::PathBuf;
use std::str::FromStr;
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{Commitments, RunArgs, pfsys::ProofType};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
use crate::graph::TestDataSource;
@@ -391,10 +391,8 @@ impl FromStr for DataField {
fn from_str(s: &str) -> Result<Self, Self::Err> {
// Check if the input starts with '@'
if s.starts_with('@') {
if let Some(file_path) = s.strip_prefix('@') {
// Extract the file path (remove the '@' prefix)
let file_path = &s[1..];
// Read the file content
let content = std::fs::read_to_string(file_path)
.map_err(|e| format!("Failed to read data file '{}': {}", file_path, e))?;

View File

@@ -1,24 +1,24 @@
use crate::graph::DataSource;
use crate::graph::GraphSettings;
use crate::graph::input::{CallToAccount, CallsToAccount, FileSourceInner, GraphData};
use crate::graph::modules::POSEIDON_INSTANCES;
use crate::pfsys::Snark;
use crate::graph::DataSource;
use crate::graph::GraphSettings;
use crate::pfsys::evm::EvmVerificationError;
use crate::pfsys::Snark;
use alloy::contract::CallBuilder;
use alloy::core::primitives::Address as H160;
use alloy::core::primitives::Bytes;
use alloy::core::primitives::U256;
use alloy::dyn_abi::abi::TokenSeq;
use alloy::dyn_abi::abi::token::{DynSeqToken, PackedSeqToken, WordToken};
use alloy::dyn_abi::abi::TokenSeq;
// use alloy::providers::Middleware;
use alloy::json_abi::JsonAbi;
use alloy::primitives::ruint::ParseError;
use alloy::primitives::{B256, I256, ParseSignedError};
use alloy::providers::ProviderBuilder;
use alloy::primitives::{ParseSignedError, B256, I256};
use alloy::providers::fillers::{
ChainIdFiller, FillProvider, GasFiller, JoinFill, NonceFiller, SignerFiller,
};
use alloy::providers::network::{Ethereum, EthereumSigner};
use alloy::providers::ProviderBuilder;
use alloy::providers::{Identity, Provider, RootProvider};
use alloy::rpc::types::eth::TransactionInput;
use alloy::rpc::types::eth::TransactionRequest;
@@ -27,9 +27,9 @@ use alloy::signers::wallet::{LocalWallet, WalletError};
use alloy::sol as abigen;
use alloy::transports::http::Http;
use alloy::transports::{RpcError, TransportErrorKind};
use foundry_compilers::Solc;
use foundry_compilers::artifacts::Settings as SolcSettings;
use foundry_compilers::error::{SolcError, SolcIoError};
use foundry_compilers::Solc;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
use halo2curves::group::ff::PrimeField;
@@ -711,7 +711,7 @@ pub async fn verify_proof_with_data_attestation(
};
let encoded = func.encode_input(&[
Token::Address(addr_verifier.0.0.into()),
Token::Address(addr_verifier.0 .0.into()),
Token::Bytes(encoded_verifier),
])?;
@@ -757,7 +757,7 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
let call_to_account = CallToAccount {
call_data: hex::encode(call),
decimals,
address: hex::encode(contract.address().0.0),
address: hex::encode(contract.address().0 .0),
};
info!("call_to_account: {:#?}", call_to_account);
Ok(call_to_account)
@@ -913,9 +913,9 @@ pub async fn get_contract_artifacts(
runs: usize,
) -> Result<(JsonAbi, Bytes, Bytes), EthError> {
use foundry_compilers::{
SHANGHAI_SOLC, SolcInput,
artifacts::{Optimizer, output_selection::OutputSelection},
artifacts::{output_selection::OutputSelection, Optimizer},
compilers::CompilerInput,
SolcInput, SHANGHAI_SOLC,
};
if !sol_code_path.exists() {

View File

@@ -1,6 +1,5 @@
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::CheckMode;
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
use crate::commands::CalibrationTarget;
use crate::eth::{deploy_contract_via_solidity, deploy_da_verifier_via_solidity, fix_da_sol};
#[allow(unused_imports)]
@@ -10,21 +9,21 @@ use crate::graph::{GraphCircuit, GraphSettings, GraphWitness, Model};
use crate::graph::{TestDataSource, TestSources};
use crate::pfsys::evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript};
use crate::pfsys::{
ProofSplitCommit, create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit,
create_keys, load_pk, load_vk, save_params, save_pk, Snark, StrategyType, TranscriptType,
};
use crate::pfsys::{
Snark, StrategyType, TranscriptType, create_keys, load_pk, load_vk, save_params, save_pk,
create_proof_circuit, swap_proof_commitments_polycommit, verify_proof_circuit, ProofSplitCommit,
};
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::EZKL_BUF_CAPACITY;
use crate::{commands::*, EZKLError};
use crate::{Commitments, RunArgs};
use crate::{EZKLError, commands::*};
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
use halo2_proofs::plonk::{self, Circuit};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params};
use halo2_proofs::poly::commitment::{ParamsProver, Verifier};
use halo2_proofs::poly::ipa::commitment::{IPACommitmentScheme, ParamsIPA};
@@ -37,6 +36,7 @@ use halo2_proofs::poly::kzg::strategy::AccumulatorStrategy as KZGAccumulatorStra
use halo2_proofs::poly::kzg::{
commitment::ParamsKZG, strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer};
use halo2_solidity_verifier;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
@@ -48,12 +48,12 @@ use itertools::Itertools;
use lazy_static::lazy_static;
use log::debug;
use log::{info, trace, warn};
use serde::Serialize;
use serde::de::DeserializeOwned;
use serde::Serialize;
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::Config;
use snark_verifier::system::halo2::compile;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use snark_verifier::system::halo2::Config;
use std::fs::File;
use std::io::BufWriter;
use std::io::{Cursor, Write};
@@ -1163,15 +1163,9 @@ pub(crate) async fn calibrate(
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
@@ -1329,7 +1323,7 @@ pub(crate) async fn calibrate(
.clone()
}
CalibrationTarget::Accuracy => {
let param_iterator = found_params.iter().sorted_by_key(|p| {
let mut param_iterator = found_params.iter().sorted_by_key(|p| {
(
p.run_args.input_scale,
p.run_args.param_scale,
@@ -1338,7 +1332,7 @@ pub(crate) async fn calibrate(
)
});
let last = param_iterator.last().ok_or("no params found")?;
let last = param_iterator.next_back().ok_or("no params found")?;
let max_scale = (
last.run_args.input_scale,
last.run_args.param_scale,

View File

@@ -67,8 +67,11 @@ pub enum GraphError {
#[error("invalid input types")]
InvalidInputTypes,
/// Missing results
#[error("missing results")]
MissingResults,
#[error("missing result for node {0}")]
MissingResults(usize),
/// Missing input
#[error("missing input {0}")]
MissingInputForNode(usize),
/// Tensor error
#[error("[tensor] {0}")]
TensorError(#[from] crate::tensor::TensorError),
@@ -98,8 +101,6 @@ pub enum GraphError {
feature = "ezkl",
not(all(target_arch = "wasm32", target_os = "unknown"))
))]
#[error("[tokio postgres] {0}")]
TokioPostgresError(#[from] tokio_postgres::Error),
/// Eth error
#[cfg(all(
feature = "ezkl",
@@ -141,7 +142,9 @@ pub enum GraphError {
#[error("range check {0} is too large")]
RangeCheckTooLarge(usize),
///Cannot use on-chain data source as private data
#[error("cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm.")]
#[error(
"cannot use on-chain data source as 1) output for on-chain test 2) as private data 3) as input when using wasm."
)]
OnChainDataSource,
/// Missing data source
#[error("missing data source")]

View File

@@ -1,17 +1,15 @@
use super::errors::GraphError;
use super::quantize_float;
use crate::EZKL_BUF_CAPACITY;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::graph::postgres::Client;
use crate::EZKL_BUF_CAPACITY;
use halo2curves::bn256::Fr as Fp;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use std::io::BufReader;
use std::io::BufWriter;
@@ -19,13 +17,10 @@ use std::io::Read;
use std::panic::UnwindSafe;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::{
tract_data::{TVec, prelude::Tensor as TractTensor},
tract_data::{prelude::Tensor as TractTensor, TVec},
value::TValue,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::tract_num_traits::ToPrimitive;
type Decimals = u8;
type Call = String;
type RPCUrl = String;
@@ -260,9 +255,6 @@ pub enum DataSource {
File(FileSource),
/// Data fetched from blockchain contracts
OnChain(OnChainSource),
/// Data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DB(PostgresSource),
}
impl Default for DataSource {
@@ -323,15 +315,6 @@ impl<'de> Deserialize<'de> for DataSource {
return Ok(DataSource::OnChain(t));
}
// Try deserializing as PostgresSource if feature enabled
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let third_try: Result<PostgresSource, _> = serde_json::from_str(this_json.get());
if let Ok(t) = third_try {
return Ok(DataSource::DB(t));
}
}
Err(serde::de::Error::custom("failed to deserialize DataSource"))
}
}
@@ -442,9 +425,7 @@ impl GraphData {
pub fn from_str(data: &str) -> Result<Self, GraphError> {
let graph_input = serde_json::from_str(data);
match graph_input {
Ok(graph_input) => {
return Ok(graph_input);
}
Ok(graph_input) => Ok(graph_input),
Err(_) => {
let path = std::path::PathBuf::from(data);
GraphData::from_path(path)
@@ -517,11 +498,6 @@ impl GraphData {
"on-chain data cannot be split into batches".to_string(),
));
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
GraphData {
input_data: DataSource::DB(data),
output_data: _,
} => data.fetch_and_format_as_file().await?,
};
// Process each input tensor according to its shape
@@ -591,28 +567,6 @@ impl GraphData {
mod tests {
use super::*;
#[test]
fn test_postgres_source_new() {
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
{
let source = PostgresSource::new(
"localhost".to_string(),
"5432".to_string(),
"user".to_string(),
"SELECT * FROM table".to_string(),
"database".to_string(),
"password".to_string(),
);
assert_eq!(source.host, "localhost");
assert_eq!(source.port, "5432");
assert_eq!(source.user, "user");
assert_eq!(source.query, "SELECT * FROM table");
assert_eq!(source.dbname, "database");
assert_eq!(source.password, "password");
}
}
#[test]
fn test_data_source_serialization_round_trip() {
// Test backwards compatibility with old format
@@ -655,95 +609,6 @@ mod tests {
}
}
/// Source data from a PostgreSQL database
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
#[derive(Clone, Debug, Deserialize, Serialize, Default, PartialOrd, PartialEq)]
pub struct PostgresSource {
/// Database host address
pub host: RPCUrl,
/// Database user name
pub user: String,
/// Database password
pub password: String,
/// SQL query to execute
pub query: String,
/// Database name
pub dbname: String,
/// Database port
pub port: String,
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
impl PostgresSource {
/// Creates a new PostgreSQL data source
pub fn new(
host: RPCUrl,
port: String,
user: String,
query: String,
dbname: String,
password: String,
) -> Self {
PostgresSource {
host,
user,
password,
query,
dbname,
port,
}
}
/// Fetches data from the PostgreSQL database
pub async fn fetch(&self) -> Result<Vec<Vec<pg_bigdecimal::PgNumeric>>, GraphError> {
// Configuration string
let config = if self.password.is_empty() {
format!(
"host={} user={} dbname={} port={}",
self.host, self.user, self.dbname, self.port
)
} else {
format!(
"host={} user={} dbname={} port={} password={}",
self.host, self.user, self.dbname, self.port, self.password
)
};
let mut client = Client::connect(&config).await?;
let mut res: Vec<pg_bigdecimal::PgNumeric> = Vec::new();
// Extract rows from query
for row in client.query(&self.query, &[]).await? {
for i in 0..row.len() {
res.push(row.get(i));
}
}
Ok(vec![res])
}
/// Fetches and formats data as FileSource
pub async fn fetch_and_format_as_file(&self) -> Result<Vec<Vec<FileSourceInner>>, GraphError> {
Ok(self
.fetch()
.await?
.iter()
.map(|d| {
d.iter()
.map(|d| {
FileSourceInner::Float(
d.n.as_ref()
.unwrap()
.to_f64()
.ok_or("could not convert decimal to f64")
.unwrap(),
)
})
.collect()
})
.collect())
}
}
#[cfg(feature = "python-bindings")]
impl ToPyObject for CallToAccount {
fn to_object(&self, py: Python) -> PyObject {
@@ -767,14 +632,6 @@ impl ToPyObject for DataSource {
.unwrap();
dict.to_object(py)
}
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
DataSource::DB(source) => {
let dict = PyDict::new(py);
dict.set_item("host", &source.host).unwrap();
dict.set_item("user", &source.user).unwrap();
dict.set_item("query", &source.query).unwrap();
dict.to_object(py)
}
}
}
}

View File

@@ -6,9 +6,6 @@ pub mod model;
pub mod modules;
/// Inner elements of a computational graph that represent a single operation / constraints.
pub mod node;
/// postgres helper functions
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
pub mod postgres;
/// Helper functions
pub mod utilities;
/// Representations of a computational graph's variables.
@@ -36,12 +33,12 @@ use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSize
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{RESERVED_BLINDING_ROWS_PAD, Range, Table, num_cols_required};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::{IntegerRep, felt_to_f64};
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{EZKL_BUF_CAPACITY, RunArgs};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
use halo2_proofs::{
circuit::Layouter,
@@ -56,13 +53,13 @@ use maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
pub use model::*;
pub use node::*;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
#[cfg(feature = "python-bindings")]
use pyo3::prelude::*;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDict;
#[cfg(feature = "python-bindings")]
use pyo3::types::PyDictMethods;
#[cfg(feature = "python-bindings")]
use pyo3::ToPyObject;
use serde::{Deserialize, Serialize};
use std::ops::Deref;
@@ -574,7 +571,7 @@ impl GraphSettings {
std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
serde_json::to_writer(writer, &self).map_err(|e| {
error!("failed to save settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})
}
/// load params from file
@@ -584,7 +581,7 @@ impl GraphSettings {
std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::open(path)?);
let settings: GraphSettings = serde_json::from_reader(reader).map_err(|e| {
error!("failed to load settings file at {}", e);
std::io::Error::new(std::io::ErrorKind::Other, e)
std::io::Error::other(e)
})?;
crate::check_version_string_matches(&settings.version);
@@ -1011,10 +1008,6 @@ impl GraphCircuit {
DataSource::File(file_data) => {
self.load_file_data(file_data, &shapes, scales, input_types)
}
DataSource::DB(pg) => {
let data = pg.fetch_and_format_as_file().await?;
self.load_file_data(&data, &shapes, scales, input_types)
}
}
}
@@ -1239,15 +1232,9 @@ impl GraphCircuit {
let mut cs = ConstraintSystem::default();
// if unix get a gag
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _r = match Gag::stdout() {
Ok(g) => Some(g),
_ => None,
};
let _r = Gag::stdout().ok();
#[cfg(all(not(not(feature = "ezkl")), unix))]
let _g = match Gag::stderr() {
Ok(g) => Some(g),
_ => None,
};
let _g = Gag::stderr().ok();
Self::configure_with_params(&mut cs, settings);
@@ -1583,13 +1570,13 @@ impl Circuit<Fp> for GraphCircuit {
let mut module_configs = ModuleConfigs::from_visibility(
cs,
params.module_sizes.clone(),
&params.module_sizes,
params.run_args.logrows as usize,
);
let mut vars = ModelVars::new(cs, &params);
module_configs.configure_complex_modules(cs, visibility, params.module_sizes.clone());
module_configs.configure_complex_modules(cs, &visibility, &params.module_sizes);
vars.instantiate_instance(
cs,

View File

@@ -200,7 +200,7 @@ fn number_of_iterations(mappings: &[InputMapping], dims: Vec<&[usize]>) -> usize
InputMapping::Stacked { axis, chunk } => Some(
// number of iterations given the dim size along the axis
// and the chunk size
(dims[*axis] + chunk - 1) / chunk,
dims[*axis].div_ceil(*chunk), // (dims[*axis] + chunk - 1) / chunk,
),
_ => None,
});
@@ -589,10 +589,7 @@ impl Model {
required_range_checks: res.range_checks.into_iter().collect(),
model_output_scales: self.graph.get_output_scales()?,
model_input_scales: self.graph.get_input_scales(),
input_types: match self.get_input_types() {
Ok(x) => Some(x),
Err(_) => None,
},
input_types: self.get_input_types().ok(),
output_types: Some(self.get_output_types()),
num_dynamic_lookups: res.num_dynamic_lookups,
total_dynamic_col_size: res.dynamic_lookup_col_coord,
@@ -650,10 +647,13 @@ impl Model {
let variables: std::collections::HashMap<String, usize> =
std::collections::HashMap::from_iter(variables.iter().map(|(k, v)| (k.clone(), *v)));
for (i, id) in model.clone().inputs.iter().enumerate() {
let inputs = model.inputs.clone();
let outputs = model.outputs.clone();
for (i, id) in inputs.iter().enumerate() {
let input = model.node_mut(id.node);
if input.outputs.len() == 0 {
if input.outputs.is_empty() {
return Err(GraphError::MissingOutput(id.node));
}
let mut fact: InferenceFact = input.outputs[0].fact.clone();
@@ -672,7 +672,7 @@ impl Model {
model.set_input_fact(i, fact)?;
}
for (i, _) in model.clone().outputs.iter().enumerate() {
for (i, _) in outputs.iter().enumerate() {
model.set_output_fact(i, InferenceFact::default())?;
}
@@ -1196,7 +1196,7 @@ impl Model {
.base
.layout(
&mut thread_safe_region,
&[output.clone(), comparators],
&[output, &comparators],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),
@@ -1257,12 +1257,27 @@ impl Model {
node.inputs()
.iter()
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
// check node is not an output
let is_output = self.graph.outputs.iter().any(|(o_idx, _)| *idx == *o_idx);
let res = if self.graph.nodes[idx].num_uses() == 1 && !is_output {
let res = results.remove(idx);
res.ok_or(GraphError::MissingResults(*idx))?[*outlet].clone()
} else {
results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet]
.clone()
};
Ok(res)
})
.collect::<Result<Vec<_>, GraphError>>()?
} else {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
if self.graph.nodes[idx].num_uses() == 1 {
let res = results.remove(idx);
vec![res.ok_or(GraphError::MissingInput(*idx))?[0].clone()]
} else {
vec![results.get(idx).ok_or(GraphError::MissingInput(*idx))?[0].clone()]
}
};
trace!("output dims: {:?}", node.out_dims());
trace!(
@@ -1273,7 +1288,7 @@ impl Model {
let start = instant::Instant::now();
match &node {
NodeType::Node(n) => {
let res = if node.is_constant() && node.num_uses() == 1 {
let mut res = if node.is_constant() && node.num_uses() == 1 {
log::debug!("node {} is a constant with 1 use", n.idx);
let mut node = n.clone();
let c = node
@@ -1284,19 +1299,19 @@ impl Model {
} else {
config
.base
.layout(region, &values, n.opkind.clone_dyn())
.layout(region, &values.iter().collect_vec(), n.opkind.clone_dyn())
.map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?
};
if let Some(mut vt) = res {
if let Some(vt) = &mut res {
vt.reshape(&node.out_dims()[0])?;
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
//only use with mock prover
debug!("------------ output node {:?}: {:?}", idx, vt.show());
// we get the max as for fused nodes this corresponds to the node output
results.insert(*idx, vec![vt.clone()]);
}
}
NodeType::SubGraph {
@@ -1340,7 +1355,7 @@ impl Model {
.inputs
.clone()
.into_iter()
.zip(values.clone().into_iter().map(|v| vec![v])),
.zip(values.iter().map(|v| vec![v.clone()])),
);
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
@@ -1421,7 +1436,7 @@ impl Model {
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
Ok(results.get(idx).ok_or(GraphError::MissingResults(*idx))?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
@@ -1476,7 +1491,7 @@ impl Model {
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
&[output, &comparator],
Box::new(HybridOp::Output {
decomp: !run_args.ignore_range_check_inputs_outputs,
}),

View File

@@ -37,15 +37,15 @@ impl ModuleConfigs {
/// Create new module configs from visibility of each variable
pub fn from_visibility(
cs: &mut ConstraintSystem<Fp>,
module_size: ModuleSizes,
module_size: &ModuleSizes,
logrows: usize,
) -> Self {
let mut config = Self::default();
for size in module_size.polycommit {
for size in &module_size.polycommit {
config
.polycommit
.push(PolyCommitChip::configure(cs, (logrows, size)));
.push(PolyCommitChip::configure(cs, (logrows, *size)));
}
config
@@ -55,8 +55,8 @@ impl ModuleConfigs {
pub fn configure_complex_modules(
&mut self,
cs: &mut ConstraintSystem<Fp>,
visibility: VarVisibility,
module_size: ModuleSizes,
visibility: &VarVisibility,
module_size: &ModuleSizes,
) {
if (visibility.input.is_hashed()
|| visibility.output.is_hashed()

View File

@@ -37,6 +37,7 @@ use crate::tensor::TensorError;
// Import curve-specific field type
use halo2curves::bn256::Fr as Fp;
use itertools::Itertools;
// Import logging for EZKL
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use log::trace;
@@ -118,16 +119,15 @@ impl Op<Fp> for Rescaled {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
if self.scale.len() != values.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()).into());
}
let res =
&crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?
[..];
self.inner.layout(config, region, res)
crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?;
self.inner.layout(config, region, &res.iter().collect_vec())
}
/// Create a cloned boxed copy of this operation
@@ -274,13 +274,13 @@ impl Op<Fp> for RebaseScale {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
let original_res = self
.inner
.layout(config, region, values)?
.ok_or(CircuitError::MissingLayout(self.as_string()))?;
self.rebase_op.layout(config, region, &[original_res])
self.rebase_op.layout(config, region, &[&original_res])
}
/// Create a cloned boxed copy of this operation
@@ -472,7 +472,7 @@ impl Op<Fp> for SupportedOp {
&self,
config: &mut crate::circuit::BaseConfig<Fp>,
region: &mut crate::circuit::region::RegionCtx<Fp>,
values: &[crate::tensor::ValTensor<Fp>],
values: &[&crate::tensor::ValTensor<Fp>],
) -> Result<Option<crate::tensor::ValTensor<Fp>>, CircuitError> {
self.as_op().layout(config, region, values)
}

View File

@@ -1,493 +0,0 @@
use log::{debug, error, info};
use std::fmt::Debug;
use std::net::IpAddr;
#[cfg(all(not(not(feature = "ezkl")), unix))]
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use std::{fmt, pin::Pin};
use tokio::task::JoinHandle;
#[doc(inline)]
pub use tokio_postgres::config::{
ChannelBinding, Host, LoadBalanceHosts, SslMode, TargetSessionAttrs,
};
use tokio_postgres::tls::NoTlsStream;
use tokio_postgres::NoTls;
use tokio_postgres::{error::DbError, types::ToSql, Error, Row, Socket, ToStatement};
/// Connection configuration.
///
/// Configuration can be parsed from libpq-style connection strings. These strings come in two formats:
///
///
#[derive(Clone)]
pub struct Config {
config: tokio_postgres::Config,
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}
impl fmt::Debug for Config {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.debug_struct("Config")
.field("config", &self.config)
.finish()
}
}
impl Default for Config {
fn default() -> Config {
Config::new()
}
}
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
tokio_postgres::Config::new().into()
}
/// Sets the user to authenticate with.
///
/// If the user is not set, then this defaults to the user executing this process.
pub fn user(&mut self, user: &str) -> &mut Config {
self.config.user(user);
self
}
/// Gets the user to authenticate with, if one has been configured with
/// the `user` method.
pub fn get_user(&self) -> Option<&str> {
self.config.get_user()
}
/// Sets the password to authenticate with.
pub fn password<T>(&mut self, password: T) -> &mut Config
where
T: AsRef<[u8]>,
{
self.config.password(password);
self
}
/// Gets the password to authenticate with, if one has been configured with
/// the `password` method.
pub fn get_password(&self) -> Option<&[u8]> {
self.config.get_password()
}
/// Sets the name of the database to connect to.
///
/// Defaults to the user.
pub fn dbname(&mut self, dbname: &str) -> &mut Config {
self.config.dbname(dbname);
self
}
/// Gets the name of the database to connect to, if one has been configured
/// with the `dbname` method.
pub fn get_dbname(&self) -> Option<&str> {
self.config.get_dbname()
}
/// Sets command line options used to configure the server.
pub fn options(&mut self, options: &str) -> &mut Config {
self.config.options(options);
self
}
/// Gets the command line options used to configure the server, if the
/// options have been set with the `options` method.
pub fn get_options(&self) -> Option<&str> {
self.config.get_options()
}
/// Sets the value of the `application_name` runtime parameter.
pub fn application_name(&mut self, application_name: &str) -> &mut Config {
self.config.application_name(application_name);
self
}
/// Gets the value of the `application_name` runtime parameter, if it has
/// been set with the `application_name` method.
pub fn get_application_name(&self) -> Option<&str> {
self.config.get_application_name()
}
/// Sets the SSL configuration.
///
/// Defaults to `prefer`.
pub fn ssl_mode(&mut self, ssl_mode: SslMode) -> &mut Config {
self.config.ssl_mode(ssl_mode);
self
}
/// Gets the SSL configuration.
pub fn get_ssl_mode(&self) -> SslMode {
self.config.get_ssl_mode()
}
/// Adds a host to the configuration.
///
/// Multiple hosts can be specified by calling this method multiple times, and each will be tried in order. On Unix
/// systems, a host starting with a `/` is interpreted as a path to a directory containing Unix domain sockets.
/// There must be either no hosts, or the same number of hosts as hostaddrs.
pub fn host(&mut self, host: &str) -> &mut Config {
self.config.host(host);
self
}
/// Gets the hosts that have been added to the configuration with `host`.
pub fn get_hosts(&self) -> &[Host] {
self.config.get_hosts()
}
/// Gets the hostaddrs that have been added to the configuration with `hostaddr`.
pub fn get_hostaddrs(&self) -> &[IpAddr] {
self.config.get_hostaddrs()
}
/// Adds a Unix socket host to the configuration.
///
/// Unlike `host`, this method allows non-UTF8 paths.
#[cfg(all(not(not(feature = "ezkl")), unix))]
pub fn host_path<T>(&mut self, host: T) -> &mut Config
where
T: AsRef<Path>,
{
self.config.host_path(host);
self
}
/// Adds a hostaddr to the configuration.
///
/// Multiple hostaddrs can be specified by calling this method multiple times, and each will be tried in order.
/// There must be either no hostaddrs, or the same number of hostaddrs as hosts.
pub fn hostaddr(&mut self, hostaddr: IpAddr) -> &mut Config {
self.config.hostaddr(hostaddr);
self
}
/// Adds a port to the configuration.
///
/// Multiple ports can be specified by calling this method multiple times. There must either be no ports, in which
/// case the default of 5432 is used, a single port, in which it is used for all hosts, or the same number of ports
/// as hosts.
pub fn port(&mut self, port: u16) -> &mut Config {
self.config.port(port);
self
}
/// Gets the ports that have been added to the configuration with `port`.
pub fn get_ports(&self) -> &[u16] {
self.config.get_ports()
}
/// Sets the timeout applied to socket-level connection attempts.
///
/// Note that hostnames can resolve to multiple IP addresses, and this timeout will apply to each address of each
/// host separately. Defaults to no limit.
pub fn connect_timeout(&mut self, connect_timeout: Duration) -> &mut Config {
self.config.connect_timeout(connect_timeout);
self
}
/// Gets the connection timeout, if one has been set with the
/// `connect_timeout` method.
pub fn get_connect_timeout(&self) -> Option<&Duration> {
self.config.get_connect_timeout()
}
/// Sets the TCP user timeout.
///
/// This is ignored for Unix domain socket connections. It is only supported on systems where
/// TCP_USER_TIMEOUT is available and will default to the system default if omitted or set to 0;
/// on other systems, it has no effect.
pub fn tcp_user_timeout(&mut self, tcp_user_timeout: Duration) -> &mut Config {
self.config.tcp_user_timeout(tcp_user_timeout);
self
}
/// Gets the TCP user timeout, if one has been set with the
/// `user_timeout` method.
pub fn get_tcp_user_timeout(&self) -> Option<&Duration> {
self.config.get_tcp_user_timeout()
}
/// Controls the use of TCP keepalive.
///
/// This is ignored for Unix domain socket connections. Defaults to `true`.
pub fn keepalives(&mut self, keepalives: bool) -> &mut Config {
self.config.keepalives(keepalives);
self
}
/// Reports whether TCP keepalives will be used.
pub fn get_keepalives(&self) -> bool {
self.config.get_keepalives()
}
/// Sets the amount of idle time before a keepalive packet is sent on the connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled. Defaults to 2 hours.
pub fn keepalives_idle(&mut self, keepalives_idle: Duration) -> &mut Config {
self.config.keepalives_idle(keepalives_idle);
self
}
/// Gets the configured amount of idle time before a keepalive packet will
/// be sent on the connection.
pub fn get_keepalives_idle(&self) -> Duration {
self.config.get_keepalives_idle()
}
/// Sets the time interval between TCP keepalive probes.
/// On Windows, this sets the value of the tcp_keepalive structs keepaliveinterval field.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_interval(&mut self, keepalives_interval: Duration) -> &mut Config {
self.config.keepalives_interval(keepalives_interval);
self
}
/// Gets the time interval between TCP keepalive probes.
pub fn get_keepalives_interval(&self) -> Option<Duration> {
self.config.get_keepalives_interval()
}
/// Sets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
///
/// This is ignored for Unix domain sockets, or if the `keepalives` option is disabled.
pub fn keepalives_retries(&mut self, keepalives_retries: u32) -> &mut Config {
self.config.keepalives_retries(keepalives_retries);
self
}
/// Gets the maximum number of TCP keepalive probes that will be sent before dropping a connection.
pub fn get_keepalives_retries(&self) -> Option<u32> {
self.config.get_keepalives_retries()
}
/// Sets the requirements of the session.
///
/// This can be used to connect to the primary server in a clustered database rather than one of the read-only
/// secondary servers. Defaults to `Any`.
pub fn target_session_attrs(
&mut self,
target_session_attrs: TargetSessionAttrs,
) -> &mut Config {
self.config.target_session_attrs(target_session_attrs);
self
}
/// Gets the requirements of the session.
pub fn get_target_session_attrs(&self) -> TargetSessionAttrs {
self.config.get_target_session_attrs()
}
/// Sets the channel binding behavior.
///
/// Defaults to `prefer`.
pub fn channel_binding(&mut self, channel_binding: ChannelBinding) -> &mut Config {
self.config.channel_binding(channel_binding);
self
}
/// Gets the channel binding behavior.
pub fn get_channel_binding(&self) -> ChannelBinding {
self.config.get_channel_binding()
}
/// Sets the host load balancing behavior.
///
/// Defaults to `disable`.
pub fn load_balance_hosts(&mut self, load_balance_hosts: LoadBalanceHosts) -> &mut Config {
self.config.load_balance_hosts(load_balance_hosts);
self
}
/// Gets the host load balancing behavior.
pub fn get_load_balance_hosts(&self) -> LoadBalanceHosts {
self.config.get_load_balance_hosts()
}
/// Sets the notice callback.
///
/// This callback will be invoked with the contents of every
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
/// the same structure as errors, but they are not "errors" per-se.
///
/// Notices are distinct from notifications, which are instead accessible
/// via the [`Notifications`] API.
///
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
/// [`Notifications`]: crate::Notifications
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}
/// Opens a connection to a PostgreSQL database.
pub async fn connect(&self) -> Result<Client, Error> {
let (client, connection) = self.config.connect(NoTls).await?;
let connection = Connection::new(connection);
Ok(Client::new(client, connection))
}
}
impl FromStr for Config {
type Err = Error;
fn from_str(s: &str) -> Result<Config, Error> {
s.parse::<tokio_postgres::Config>().map(Config::from)
}
}
impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL connection. We use this to keep the connection alive / keep it pinned so that it doesn't
/// get dropped.
pub struct Connection {
/// The underlying connection stream.
connection: Pin<Box<tokio_postgres::Connection<Socket, NoTlsStream>>>,
}
impl Connection {
/// Creates a new connection.
pub fn new(connection: tokio_postgres::Connection<Socket, NoTlsStream>) -> Self {
Connection {
connection: Box::pin(connection),
}
}
/// start the connection
pub async fn start(self) {
if let Err(e) = self.connection.await {
error!("connection error: {}", e);
}
}
}
#[allow(missing_debug_implementations, dead_code)]
/// An asynchronous PostgreSQL client.
pub struct Client {
connection: JoinHandle<()>,
client: tokio_postgres::Client,
}
impl Drop for Client {
fn drop(&mut self) {
let _ = self.close_inner();
}
}
impl Client {
pub(crate) fn new(client: tokio_postgres::Client, connection: Connection) -> Client {
// The connection object performs the actual communication with the database,
// so spawn it off to run on its own.
let thread = tokio::spawn(async move {
connection.start().await;
});
Client {
client,
connection: thread,
}
}
/// A convenience function which parses a configuration string into a `Config` and then connects to the database.
///
/// See the documentation for [`Config`] for information about the connection syntax.
///
/// [`Config`]: config/struct.Config.html
pub async fn connect(params: &str) -> Result<Client, Error> {
debug!("Connecting to database with params: {}", params);
params.parse::<Config>()?.connect().await
}
/// Returns a new `Config` object which can be used to configure and connect to a database.
pub fn configure() -> Config {
Config::new()
}
/// Executes a statement, returning the number of rows modified.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// If the statement does not modify any rows (e.g. `SELECT`), 0 is returned.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
pub async fn execute<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<u64, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.execute(query, params).await
}
/// Executes a statement, returning the resulting rows.
///
/// A statement may contain parameters, specified by `$n`, where `n` is the index of the parameter of the list
/// provided, 1-indexed.
///
/// The `query` argument can either be a `Statement`, or a raw query string. If the same statement will be
/// repeatedly executed (perhaps with different query parameters), consider preparing the statement up front
/// with the `prepare` method.
///
/// # Examples
///
pub async fn query<T>(
&mut self,
query: &T,
params: &[&(dyn ToSql + Sync)],
) -> Result<Vec<Row>, Error>
where
T: ?Sized + ToStatement + Debug,
{
debug!("Executing query: {:?}", query);
self.client.query(query, params).await
}
/// Determines if the client's connection has already closed.
///
/// If this returns `true`, the client is no longer usable.
pub fn is_closed(&self) -> bool {
self.client.is_closed()
}
/// Closes the client's connection to the server.
///
/// This is equivalent to `Client`'s `Drop` implementation, except that it returns any error encountered to the
/// caller.
pub fn close(mut self) -> Result<(), Error> {
self.close_inner()
}
fn close_inner(&mut self) -> Result<(), Error> {
self.client.__private_api_close();
Ok(())
}
}

View File

@@ -1,14 +1,14 @@
use super::errors::GraphError;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use super::VarScales;
use super::errors::GraphError;
use super::{Rescaled, SupportedOp, Visibility};
use crate::circuit::Op;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::hybrid::HybridOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::lookup::LookupOp;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
@@ -22,7 +22,6 @@ use std::sync::Arc;
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_core::ops::{
Downsample,
array::{
Gather, GatherElements, GatherNd, MultiBroadcastTo, OneHot, ScatterElements, ScatterNd,
Slice, Topk,
@@ -32,6 +31,7 @@ use tract_onnx::tract_core::ops::{
einsum::EinSum,
element_wise::ElementWiseOp,
nn::{LeakyRelu, Reduce, Softmax},
Downsample,
};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use tract_onnx::tract_hir::{

View File

@@ -29,8 +29,14 @@
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
use log::warn;
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use mimalloc as _;
#[global_allocator]
#[cfg(all(feature = "jemalloc", not(target_arch = "wasm32")))]
static GLOBAL: jemallocator::Jemalloc = jemallocator::Jemalloc;
#[global_allocator]
#[cfg(all(feature = "mimalloc", not(target_arch = "wasm32")))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
/// Error type
// #[cfg_attr(not(feature = "ezkl"), derive(uniffi::Error))]
@@ -97,11 +103,11 @@ impl From<String> for EZKLError {
use std::str::FromStr;
use circuit::{CheckMode, table::Range};
use circuit::{table::Range, CheckMode};
#[cfg(all(feature = "ezkl", not(target_arch = "wasm32")))]
use clap::Args;
use fieldutils::IntegerRep;
use graph::{MAX_PUBLIC_SRS, Visibility};
use graph::{Visibility, MAX_PUBLIC_SRS};
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
};

View File

@@ -17,16 +17,16 @@ use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
use clap::ValueEnum;
use halo2_proofs::circuit::Value;
use halo2_proofs::plonk::{
Circuit, ProvingKey, VerifyingKey, create_proof, keygen_pk, keygen_vk_custom, verify_proof,
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::VerificationStrategy;
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
use halo2curves::CurveAffine;
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
use halo2curves::serde::SerdeObject;
use halo2curves::CurveAffine;
use instant::Instant;
use log::{debug, info, trace};
#[cfg(not(feature = "det-prove"))]
@@ -324,7 +324,7 @@ where
}
#[cfg(feature = "python-bindings")]
use pyo3::{PyObject, Python, ToPyObject, types::PyDict};
use pyo3::{types::PyDict, PyObject, Python, ToPyObject};
#[cfg(feature = "python-bindings")]
impl<F: PrimeField + SerdeObject + Serialize, C: CurveAffine + Serialize> ToPyObject for Snark<F, C>
where
@@ -348,9 +348,9 @@ where
}
impl<
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
F: PrimeField + SerdeObject + Serialize + FromUniformBytes<64> + DeserializeOwned,
C: CurveAffine + Serialize + DeserializeOwned,
> Snark<F, C>
where
C::Scalar: Serialize + DeserializeOwned,
C::ScalarExt: Serialize + DeserializeOwned,

View File

@@ -27,7 +27,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{IntegerRep, integer_rep_to_felt},
fieldutils::{integer_rep_to_felt, IntegerRep},
graph::Visibility,
};
@@ -42,11 +42,11 @@ use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::Rem;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
/// The (inner) type of tensor elements.
pub trait TensorType: Clone + Debug + 'static {
pub trait TensorType: Clone + Debug {
/// Returns the zero value.
fn zero() -> Option<Self> {
None
@@ -55,10 +55,6 @@ pub trait TensorType: Clone + Debug + 'static {
fn one() -> Option<Self> {
None
}
/// Max operator for ordering values.
fn tmax(&self, _: &Self) -> Option<Self> {
None
}
}
macro_rules! tensor_type {
@@ -70,10 +66,6 @@ macro_rules! tensor_type {
fn one() -> Option<Self> {
Some($one)
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(max(*self, *other))
}
}
};
}
@@ -82,46 +74,12 @@ impl TensorType for f32 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f32::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
impl TensorType for f64 {
fn zero() -> Option<Self> {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
(true, true) => Some(f64::NAN),
(true, false) => Some(*other),
(false, true) => Some(*self),
(false, false) => {
if self >= other {
Some(*self)
} else {
Some(*other)
}
}
}
}
}
tensor_type!(bool, Bool, false, true);
@@ -147,14 +105,6 @@ impl<T: TensorType> TensorType for Value<T> {
fn one() -> Option<Self> {
Some(Value::known(T::one().unwrap()))
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some(
(self.clone())
.zip(other.clone())
.map(|(a, b)| a.tmax(&b).unwrap()),
)
}
}
impl<F: PrimeField + PartialOrd> TensorType for Assigned<F>
@@ -168,14 +118,6 @@ where
fn one() -> Option<Self> {
Some(F::ONE.into())
}
fn tmax(&self, other: &Self) -> Option<Self> {
if self.evaluate() >= other.evaluate() {
Some(*self)
} else {
Some(*other)
}
}
}
impl<F: PrimeField> TensorType for Expression<F>
@@ -189,42 +131,14 @@ where
fn one() -> Option<Self> {
Some(Expression::Constant(F::ONE))
}
fn tmax(&self, _: &Self) -> Option<Self> {
todo!()
}
}
impl TensorType for Column<Advice> {}
impl TensorType for Column<Fixed> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value_field().zip(other.value_field()).map(|(a, b)| {
if a.evaluate() >= b.evaluate() {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<Assigned<F>, F> {}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {
fn tmax(&self, other: &Self) -> Option<Self> {
let mut output: Option<Self> = None;
self.value().zip(other.value()).map(|(a, b)| {
if a >= b {
output = Some(self.clone());
} else {
output = Some(other.clone());
}
});
output
}
}
impl<F: PrimeField + PartialOrd> TensorType for AssignedCell<F, F> {}
// specific types
impl TensorType for halo2curves::pasta::Fp {
@@ -235,10 +149,6 @@ impl TensorType for halo2curves::pasta::Fp {
fn one() -> Option<Self> {
Some(halo2curves::pasta::Fp::one())
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
}
}
impl TensorType for halo2curves::bn256::Fr {
@@ -249,9 +159,15 @@ impl TensorType for halo2curves::bn256::Fr {
fn one() -> Option<Self> {
Some(halo2curves::bn256::Fr::one())
}
}
fn tmax(&self, other: &Self) -> Option<Self> {
Some((*self).max(*other))
impl<F: TensorType> TensorType for &F {
fn zero() -> Option<Self> {
None
}
fn one() -> Option<Self> {
None
}
}
@@ -374,7 +290,7 @@ impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
}
}
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync + 'data>
maybe_rayon::iter::IntoParallelRefMutIterator<'data> for Tensor<T>
{
type Iter = maybe_rayon::slice::IterMut<'data, T>;
@@ -423,6 +339,14 @@ impl<T: Clone + TensorType + PrimeField> Tensor<T> {
}
}
impl<T: Clone + TensorType> Tensor<&T> {
/// Clones the tensor values into a new tensor.
pub fn cloned(&self) -> Tensor<T> {
let inner = self.inner.clone().into_iter().cloned().collect::<Vec<T>>();
Tensor::new(Some(&inner), &self.dims).unwrap()
}
}
impl<T: Clone + TensorType> Tensor<T> {
/// Sets (copies) the tensor values to the provided ones.
pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -554,7 +478,6 @@ impl<T: Clone + TensorType> Tensor<T> {
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
@@ -631,23 +554,23 @@ impl<T: Clone + TensorType> Tensor<T> {
// Fill remaining dimensions
full_indices.extend((indices.len()..self.dims.len()).map(|i| 0..self.dims[i]));
// Pre-calculate total size and allocate result vector
let total_size: usize = full_indices
.iter()
.map(|range| range.end - range.start)
.product();
let mut res = Vec::with_capacity(total_size);
// Calculate new dimensions once
let dims: Vec<usize> = full_indices.iter().map(|e| e.end - e.start).collect();
// Use iterator directly without collecting into intermediate Vec
for coord in full_indices.iter().cloned().multi_cartesian_product() {
let index = self.get_index(&coord);
res.push(self[index].clone());
}
let mut output = Tensor::new(None, &dims)?;
Tensor::new(Some(&res), &dims)
let cartesian_coord: Vec<Vec<usize>> = full_indices
.iter()
.cloned()
.multi_cartesian_product()
.collect();
output.par_iter_mut().enumerate().for_each(|(i, e)| {
let coord = &cartesian_coord[i];
*e = self.get(coord);
});
Ok(output)
}
/// Set a slice of the Tensor.
@@ -753,7 +676,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn get_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n == 0 {
inner.push(elem.clone());
}
@@ -776,7 +699,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
pub fn exclude_every_n(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner: Vec<T> = vec![];
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if i % n != 0 {
inner.push(elem.clone());
}
@@ -812,9 +735,9 @@ impl<T: Clone + TensorType> Tensor<T> {
let mut inner: Vec<T> = Vec::with_capacity(self.inner.len());
let mut offset = initial_offset;
for (i, elem) in self.inner.clone().into_iter().enumerate() {
for (i, elem) in self.inner.iter().enumerate() {
if (i + offset + 1) % n == 0 {
inner.extend(vec![elem; 1 + num_repeats]);
inner.extend(vec![elem.clone(); 1 + num_repeats]);
offset += num_repeats;
} else {
inner.push(elem.clone());
@@ -871,16 +794,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
/// let mut indices = vec![3, 4];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), expected);
///
///
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
/// let mut indices = vec![63, 67];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
/// assert_eq!(a.remove_indices(&mut indices, false).unwrap(), b);
/// ```
pub fn remove_indices(
&self,
@@ -927,7 +850,7 @@ impl<T: Clone + TensorType> Tensor<T> {
}
self.dims = vec![];
}
if self.dims() == &[0] && new_dims.iter().product::<usize>() == 1 {
if self.dims() == [0] && new_dims.iter().product::<usize>() == 1 {
self.dims = Vec::from(new_dims);
} else {
let product = if new_dims != [0] {
@@ -1292,33 +1215,6 @@ impl<T: Clone + TensorType> Tensor<T> {
Tensor::new(Some(&[res]), &[1])
}
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map_mut_filtered<
F: Fn(usize) -> Result<T, E> + std::marker::Send + std::marker::Sync,
E: Error + std::marker::Send + std::marker::Sync,
>(
&mut self,
filter_indices: &std::collections::HashSet<usize>,
f: F,
) -> Result<(), E>
where
T: std::marker::Send + std::marker::Sync,
{
self.inner
.par_iter_mut()
.enumerate()
.filter(|(i, _)| filter_indices.contains(i))
.for_each(move |(i, e)| *e = f(i).unwrap());
Ok(())
}
}
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
@@ -1335,9 +1231,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
pub fn combine(&self) -> Result<Tensor<T>, TensorError> {
let mut dims = 0;
let mut inner = Vec::new();
for t in self.inner.clone().into_iter() {
for t in self.inner.iter() {
dims += t.len();
inner.extend(t.inner);
inner.extend(t.inner.clone());
}
Tensor::new(Some(&inner), &[dims])
}
@@ -1808,7 +1704,7 @@ impl DataFormat {
match self {
DataFormat::NHWC => DataFormat::NCHW,
DataFormat::HWC => DataFormat::CHW,
_ => self.clone(),
_ => *self,
}
}
@@ -1908,7 +1804,7 @@ impl KernelFormat {
match self {
KernelFormat::HWIO => KernelFormat::OIHW,
KernelFormat::OHWI => KernelFormat::OIHW,
_ => self.clone(),
_ => *self,
}
}
@@ -2030,6 +1926,9 @@ mod tests {
fn tensor_slice() {
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
assert_eq!(
a.get_slice(&[0..2, 0..1]).unwrap(),
b.get_slice(&[0..2, 0..1]).unwrap()
);
}
}

View File

@@ -916,11 +916,14 @@ pub fn gather_elements<T: TensorType + Send + Sync>(
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 7]), &[2]).unwrap();
/// assert_eq!(result, expected);
///
pub fn gather_nd<T: TensorType + Send + Sync>(
input: &Tensor<T>,
pub fn gather_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &'a Tensor<T>,
index: &Tensor<usize>,
batch_dims: usize,
) -> Result<Tensor<T>, TensorError> {
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1108,11 +1111,14 @@ pub fn gather_nd<T: TensorType + Send + Sync>(
/// assert_eq!(result, expected);
/// ````
///
pub fn scatter_nd<T: TensorType + Send + Sync>(
pub fn scatter_nd<'a, T: TensorType + Send + Sync + 'a>(
input: &Tensor<T>,
index: &Tensor<usize>,
src: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
src: &'a Tensor<T>,
) -> Result<Tensor<T>, TensorError>
where
&'a T: TensorType,
{
// Calculate the output tensor size
let index_dims = index.dims().to_vec();
let input_dims = input.dims().to_vec();
@@ -1183,12 +1189,12 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// use ezkl::tensor::ops::intercalate_values;
///
/// let tensor = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4]), &[2, 2]).unwrap();
/// let result = intercalate_values(&tensor, 0, 2, 1).unwrap();
/// let result = intercalate_values(&tensor, &0, 2, 1).unwrap();
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 3, 0, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// let result = intercalate_values(&expected, 0, 2, 0).unwrap();
/// let result = intercalate_values(&expected, &0, 2, 0).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 0, 2, 0, 0, 0, 3, 0, 4]), &[3, 3]).unwrap();
///
/// assert_eq!(result, expected);
@@ -1196,7 +1202,7 @@ pub fn abs<T: TensorType + Add<Output = T> + std::cmp::Ord + Neg<Output = T>>(
/// ```
pub fn intercalate_values<T: TensorType>(
tensor: &Tensor<T>,
value: T,
value: &T,
stride: usize,
axis: usize,
) -> Result<Tensor<T>, TensorError> {
@@ -1494,7 +1500,7 @@ pub fn slice<T: TensorType + Send + Sync>(
}
}
t.get_slice(&slice)
Ok(t.get_slice(&slice)?)
}
// ---------------------------------------------------------------------------------------------------------
@@ -2414,20 +2420,20 @@ pub mod accumulated {
/// Some(&[25, 35]),
/// &[2],
/// ).unwrap();
/// assert_eq!(dot(&[x, y], 1).unwrap(), expected);
/// assert_eq!(dot(&x, &y, 1).unwrap(), expected);
/// ```
pub fn dot<T: TensorType + Mul<Output = T> + Add<Output = T>>(
inputs: &[Tensor<T>; 2],
a: &Tensor<T>,
b: &Tensor<T>,
chunk_size: usize,
) -> Result<Tensor<T>, TensorError> {
if inputs[0].clone().len() != inputs[1].clone().len() {
if a.len() != b.len() {
return Err(TensorError::DimMismatch("dot".to_string()));
}
let (a, b): (Tensor<T>, Tensor<T>) = (inputs[0].clone(), inputs[1].clone());
let transcript: Tensor<T> = a
.iter()
.zip(b)
.zip(b.iter())
.chunks(chunk_size)
.into_iter()
.scan(T::zero().unwrap(), |acc, chunk| {

View File

@@ -280,9 +280,17 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Vec<ValType<F>>> for ValTenso
fn from(t: Vec<ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().into(),
dims: vec![t.len()],
scale: 1,
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> From<Vec<&ValType<F>>> for ValTensor<F> {
fn from(t: Vec<&ValType<F>>) -> ValTensor<F> {
ValTensor::Value {
inner: t.clone().into_iter().cloned().into(),
dims: vec![t.len()],
scale: 1,
}
}
@@ -640,6 +648,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// # Returns
/// A tensor containing the base-n decomposition of each value
pub fn decompose(&self, base: usize, n: usize) -> Result<Self, TensorError> {
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let res = self
.get_inner()?
.par_iter()
@@ -665,9 +676,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
})
.collect::<Result<Vec<_>, _>>();
let mut tensor = Tensor::from(res?.into_iter().flatten().collect::<Vec<_>>().into_iter());
let mut dims = self.dims().to_vec();
dims.push(n + 1);
let mut tensor = res?.into_iter().flatten().collect::<Tensor<_>>();
tensor.reshape(&dims)?;
@@ -1043,119 +1052,6 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}
/// Removes constant zero values from the tensor
/// Uses parallel processing for tensors larger than a threshold
pub fn remove_const_zero_values(&mut self) {
let size_threshold = 1_000_000; // Tuned using benchmarks
if self.len() < size_threshold {
// Single-threaded for small tensors
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
} else {
// Parallel processing for large tensors
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match self {
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.par_chunks_mut(chunk_size)
.flat_map(|chunk| {
chunk.par_iter_mut().filter_map(|e| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return None;
}
}
Some(e.clone())
})
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => {}
}
}
}
/// Gets the indices of all constant zero values
/// Uses parallel processing for large tensors
///
/// # Returns
/// A vector of indices where constant zero values are located
pub fn get_const_zero_indices(&self) -> Vec<usize> {
let size_threshold = 1_000_000;
if self.len() < size_threshold {
// Single-threaded for small tensors
match &self {
ValTensor::Value { inner: v, .. } => v
.iter()
.enumerate()
.filter_map(|(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(i)
}
_ => None,
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
} else {
// Parallel processing for large tensors
let num_cores = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
let chunk_size = (self.len() / num_cores).max(100_000);
match &self {
ValTensor::Value { inner: v, .. } => v
.par_chunks(chunk_size)
.enumerate()
.flat_map(|(chunk_idx, chunk)| {
chunk
.par_iter()
.enumerate()
.filter_map(move |(i, e)| match e {
ValType::Constant(r) | ValType::AssignedConstant(_, r) => {
(*r == F::ZERO).then_some(chunk_idx * chunk_size + i)
}
_ => None,
})
})
.collect::<Vec<_>>(),
ValTensor::Instance { .. } => vec![],
}
}
}
/// Gets the indices of all constant values
/// Uses parallel processing for large tensors
///
@@ -1276,7 +1172,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Returns an error if called on an Instance tensor
pub fn intercalate_values(
&mut self,
value: ValType<F>,
value: &ValType<F>,
stride: usize,
axis: usize,
) -> Result<(), TensorError> {
@@ -1368,10 +1264,10 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Concatenates two tensors along the first dimension
pub fn concat(&self, other: Self) -> Result<Self, TensorError> {
pub fn concat(&self, other: &Self) -> Result<Self, TensorError> {
let res = match (self, other) {
(ValTensor::Value { inner: v1, .. }, ValTensor::Value { inner: v2, .. }) => {
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2]), &[2])?.combine()?)
ValTensor::from(Tensor::new(Some(&[v1.clone(), v2.clone()]), &[2])?.combine()?)
}
_ => {
return Err(TensorError::WrongMethod);

View File

@@ -1,8 +1,6 @@
use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::{CheckMode, region::ConstantsMap};
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's Column types that represents a tensor of variables in the circuit.
@@ -381,50 +379,6 @@ impl VarTensor {
}
}
/// Assigns values from a ValTensor to this tensor, excluding specified positions
///
/// # Arguments
/// * `region` - The region to assign values in
/// * `offset` - Base offset for assignments
/// * `values` - The ValTensor containing values to assign
/// * `omissions` - Set of positions to skip during assignment
/// * `constants` - Map for tracking constant assignments
///
/// # Returns
/// The assigned ValTensor or an error if assignment fails
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
error!(
"assignment with omissions is not supported on instance columns. increase K if you require more rows."
);
Err(halo2_proofs::plonk::Error::Synthesis)
}
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
if omissions.contains(&coord) {
return Ok::<_, halo2_proofs::plonk::Error>(k);
}
let cell =
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
assigned_coord += 1;
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into(),
),
}?;
res.set_scale(values.scale());
Ok(res)
}
/// Assigns values from a ValTensor to this tensor
///
/// # Arguments

View File

@@ -3,10 +3,10 @@
mod native_tests {
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::Commitments;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
use ezkl::pfsys::Snark;
use ezkl::Commitments;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2curves::bn256::Bn256;
use lazy_static::lazy_static;

View File

@@ -272,16 +272,6 @@ mod py_tests {
anvil_child.kill().unwrap();
}
#[test]
fn postgres_notebook_() {
crate::py_tests::init_binary();
let test_dir: TempDir = TempDir::new("mean_postgres").unwrap();
let path = test_dir.path().to_str().unwrap();
crate::py_tests::mv_test_(path, "mean_postgres.ipynb");
run_notebook(path, "mean_postgres.ipynb");
test_dir.close().unwrap();
}
#[test]
fn tictactoe_autoencoder_notebook_() {
crate::py_tests::init_binary();