Compare commits

...

32 Commits

Author SHA1 Message Date
dante
9bbc89cc89 chore: bump h2 2024-08-19 17:44:46 -04:00
dante
28b65f2639 chore: bump halo2proofs 2024-08-19 17:38:22 -04:00
dante
9592d38a8f chore: add asm feature 2024-08-19 17:33:27 -04:00
dante
2cec49dfc3 chore: bump fft algos 2024-08-18 20:50:10 -04:00
dante
31a1681ca4 chore: update h2 curves 2024-08-18 13:36:00 -04:00
dante
134b54d32b Update Cargo.toml 2024-08-15 12:42:45 -04:00
dante
beb5f12376 chore: use mimalloc 2024-08-15 12:40:55 -04:00
dante
65be3c84bb Update Cargo.toml 2024-08-14 18:05:26 -04:00
dante
6f743c57d3 chore: parallelize prepare and commit 2024-08-13 15:06:47 -04:00
dante
ddb54c5a73 feat: precompute lookup cosets 2024-08-08 18:15:22 -04:00
dante
6e1f22a15b log lack of cache 2024-08-08 10:44:41 -04:00
dante
da97323bde feat: cache lookup tables 2024-08-08 09:12:40 -04:00
dante
55046feeb6 Update Cargo.toml 2024-08-07 23:42:19 -04:00
dante
d0d0596e58 chore: bump h2 2024-08-07 23:40:42 -04:00
dante
b78efdcbf4 fix: add required serde patches 2024-08-07 18:30:08 -04:00
dante
5389012b68 fix: patch large batch ex (#763) 2024-04-03 02:33:57 +01:00
dante
48223cca11 fix: make commitment optional for backwards compat (#762) 2024-04-03 02:26:50 +01:00
dante
32c3a5e159 fix: hold stacked outputs in a separate map 2024-04-02 21:37:20 +01:00
dante
ff563e93a7 fix: bump python version (#761) 2024-04-02 17:08:26 +01:00
dante
5639d36097 chore: verify aggr wasm unit test (#760) 2024-04-01 20:54:20 +01:00
dante
4ec8d13082 chore: verify aggr in wasm (#758) 2024-03-29 23:28:20 +00:00
dante
12735aefd4 chore: reduce softmax recip DR (#756) 2024-03-27 01:14:29 +00:00
dante
7fe179b8d4 feat: dictionary of reusable constants (#754) 2024-03-26 13:12:09 +00:00
Ethan Cemer
3be988a6a0 fix: use pnpm in build script for in-browser-evm-verifier (#752) 2024-03-25 23:23:02 +00:00
dante
3abb3aff56 feat: make selector polynomials optional (#753) 2024-03-22 09:28:28 +00:00
dante
338788cb8f fix: lookup safety = 1 during calibration falls OOR (#750) 2024-03-21 08:53:43 +00:00
Sung Jun Eun
feb3b1b475 fix: array element encapsulation in ezkl_demo.ipynb (#747) 2024-03-21 08:51:01 +00:00
dante
e134d86756 refactor: apply num-inner cols to constant assignments as well (#749) 2024-03-20 23:51:38 +00:00
dante
6819a3acf6 chore: more complete coverage tests (#748) 2024-03-20 18:53:47 +00:00
dante
2ccf056661 fix: logrows reset after graph creation can cause extended K overflow (#745) 2024-03-20 10:15:11 +00:00
dante
a5bf64b1a2 feat!: ipa commitments (#740)
BREAKING CHANGE: commitment is now an added flag
2024-03-16 16:31:01 +00:00
Ethan Cemer
56e2326be1 *nuke (#742) 2024-03-14 14:11:03 -05:00
124 changed files with 8663 additions and 2897 deletions

View File

@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -30,13 +30,13 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
export PATH=$PATH:$PWD/binaryen-version_116/bin
wasm-opt --version
set -e
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
export PATH=$PATH:$PWD/binaryen-version_116/bin
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
@@ -92,7 +92,7 @@ jobs:
const jsonObject = JSONBig.parse(string);
return jsonObject;
}
function serialize(data) { // data is an object // return a Uint8ClampedArray
// Step 1: Stringify the Object with BigInt support
if (typeof data === "object") {
@@ -100,11 +100,11 @@ jobs:
}
// Step 2: Encode the JSON String
const uint8Array = new TextEncoder().encode(data);
// Step 3: Convert to Uint8ClampedArray
return new Uint8ClampedArray(uint8Array.buffer);
}
module.exports = {
deserialize,
serialize
@@ -123,7 +123,7 @@ jobs:
const jsonObject = parse(string);
return jsonObject;
}
export function serialize(data) { // data is an object // return a Uint8ClampedArray
// Step 1: Stringify the Object with BigInt support
if (typeof data === "object") {
@@ -131,7 +131,7 @@ jobs:
}
// Step 2: Encode the JSON String
const uint8Array = new TextEncoder().encode(data);
// Step 3: Convert to Uint8ClampedArray
return new Uint8ClampedArray(uint8Array.buffer);
}

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag

View File

@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -139,6 +139,20 @@ jobs:
target: ${{ matrix.target }}
manylinux: auto
args: --release --out dist --features python-bindings
before-script-linux: |
# If we're running on rhel centos, install needed packages.
if command -v yum &> /dev/null; then
yum update -y && yum install -y perl-core openssl openssl-devel pkgconfig libatomic
# If we're running on i686 we need to symlink libatomic
# in order to build openssl with -latomic flag.
if [[ ! -d "/usr/lib64" ]]; then
ln -s /usr/lib/libatomic.so.1 /usr/lib/libatomic.so
fi
else
# If we're running on debian-based system.
apt update -y && apt-get install -y libssl-dev openssl pkg-config
fi
- name: Install built wheel
if: matrix.target == 'x86_64'
@@ -162,7 +176,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -214,7 +228,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -249,7 +263,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -273,7 +287,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash

View File

@@ -32,7 +32,7 @@ jobs:
token: ${{ secrets.RELEASE_TOKEN }}
tag_name: ${{ env.EZKL_VERSION }}
build-release-gpu:
build-release-gpu:
name: build-release-gpu
needs: ["create-release"]
runs-on: GPU
@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Checkout repo
@@ -60,16 +60,15 @@ jobs:
- name: Set Cargo.toml version to match github tag
shell: bash
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.lock.orig >Cargo.lock
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.lock.orig >Cargo.lock
- name: Install dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get update
- name: Build release binary
run: cargo build --release -Z sparse-registry --features icicle
@@ -91,7 +90,6 @@ jobs:
asset_name: ${{ env.ASSET }}
asset_content_type: application/octet-stream
build-release:
name: build-release
needs: ["create-release"]

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -184,12 +184,12 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -199,7 +199,7 @@ jobs:
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -207,12 +207,12 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -224,12 +224,12 @@ jobs:
mock-proving-tests:
runs-on: non-gpu
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -281,12 +281,12 @@ jobs:
prove-and-verify-evm-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -303,6 +303,8 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
@@ -324,7 +326,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
@@ -352,12 +354,12 @@ jobs:
prove-and-verify-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -365,7 +367,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -392,6 +394,12 @@ jobs:
- name: Replace memory definition in nodejs
run: |
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
- name: IPA prove and verify tests
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
- name: IPA prove and verify tests (ipa outputs)
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
- name: KZG prove and verify tests single inner col
@@ -408,8 +416,6 @@ jobs:
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
- name: KZG prove and verify tests (public outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
- name: KZG prove and verify tests (public outputs + column overflow)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
- name: KZG prove and verify tests (public inputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input
- name: KZG prove and verify tests (fixed params)
@@ -425,11 +431,11 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
@@ -452,44 +458,21 @@ jobs:
- name: KZG prove and verify tests (hashed outputs)
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
fuzz-tests:
runs-on: ubuntu-latest-32-cores
needs: [build, library-tests, python-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
- name: fuzz tests (EVM)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
# - name: fuzz tests
# run: cargo nextest run --release --verbose tests::kzg_fuzz_ --test-threads 6
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Mock aggr tests
- name: Mock aggr tests (KZG)
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
prove-and-verify-aggr-tests-gpu:
@@ -500,7 +483,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -512,29 +495,29 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: KZG )tests
- name: KZG tests
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -544,7 +527,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -555,7 +538,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -574,18 +557,20 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
@@ -593,15 +578,15 @@ jobs:
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -609,7 +594,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
@@ -629,10 +614,10 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.9"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -642,11 +627,15 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -657,12 +646,10 @@ jobs:
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
- name: All notebooks
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --no-capture
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
- name: Voice tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --no-capture
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

3
.gitignore vendored
View File

@@ -48,4 +48,5 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
!tests/wasm/vk_aggr.key

1933
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -15,73 +15,97 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
mimalloc = "0.1"
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
"derive_serde",
] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"]}
clap = { version = "4.5.3", features = ["derive"] }
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = ["float_roundtrip", "raw_value"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
"raw_value",
], optional = true }
log = { version = "0.4.17", default_features = false, optional = true }
thiserror = { version = "1.0.38", default_features = false }
hex = { version = "0.4.3", default_features = false }
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]}
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
"derive_serde",
] }
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
indicatif = {version = "0.17.5", features = ["rayon"]}
gag = { version = "1.0.0", default_features = false}
ethers = { version = "2.0.11", default_features = false, features = [
"ethers-solc",
] }
indicatif = { version = "0.17.5", features = ["rayon"] }
gag = { version = "1.0.0", default_features = false }
instant = { version = "0.1" }
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] }
reqwest = { version = "0.11.14", default-features = false, features = [
"default-tls",
"multipart",
"stream",
] }
openssl = { version = "0.10.55", features = ["vendored"] }
postgres = "0.19.5"
pg_bigdecimal = "0.1.5"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true}
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
tokio = { version = "1.26.0", default_features = false, features = [
"macros",
"rt",
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
pyo3 = { version = "0.20.2", features = [
"extension-module",
"abi3-py37",
"macros",
], default_features = false, optional = true }
pyo3-asyncio = { version = "0.20.0", features = [
"attributes",
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
colored = { version = "2.0.0", default_features = false, optional = true}
env_logger = { version = "0.10.0", default_features = false, optional = true}
colored = { version = "2.0.0", default_features = false, optional = true }
env_logger = { version = "0.10.0", default_features = false, optional = true }
chrono = "0.4.31"
sha256 = "1.4.0"
[target.'cfg(target_arch = "wasm32")'.dependencies]
getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = [ "wasm-bindgen", "inaccurate" ] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional=true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"]}
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
[dev-dependencies]
criterion = {version = "0.3", features = ["html_reports"]}
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
@@ -152,12 +176,27 @@ required-features = ["ezkl"]
[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup"]
render = ["halo2_proofs/dev-graph", "plotters"]
default = ["ezkl", "mv-lookup", "precompute-coset"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled/color", "colored_json", "halo2_proofs/circuit-params"]
mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidity_verifier/mv-lookup"]
ezkl = [
"onnx",
"serde",
"serde_json",
"log",
"colored",
"env_logger",
"tabled/color",
"colored_json",
"halo2_proofs/circuit-params",
]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
"halo2_solidity_verifier/mv-lookup",
]
asm = ["halo2curves/asm", "halo2_proofs/asm"]
precompute-coset = ["halo2_proofs/precompute-coset"]
det-prove = []
icicle = ["halo2_proofs/icicle_gpu"]
empty-cmd = []
@@ -165,7 +204,11 @@ no-banner = []
# icicle patch to 0.1.0 if feature icicle is enabled
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix"}
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#127938f23e7aece10b0f32d2ffc07a6c9f244d03", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
[profile.release]
rustflags = [ "-C", "relocation-model=pic" ]
rustflags = ["-C", "relocation-model=pic"]

View File

@@ -31,9 +31,9 @@ EZKL
[![Notebook](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/zkonduit/ezkl/blob/main/examples/notebooks/simple_demo_all_public.ipynb)
In the backend we use [Halo2](https://github.com/privacy-scaling-explorations/halo2) as a proof system.
In the backend we use the collaboratively-developed [Halo2](https://github.com/privacy-scaling-explorations/halo2) as a proof system.
The generated proofs can then be used on-chain to verify computation, only the Ethereum Virtual Machine (EVM) is supported at the moment.
The generated proofs can then be verified with much less computational resources, including on-chain (with the Ethereum Virtual Machine), in a browser, or on a device.
- If you have any questions, we'd love for you to open up a discussion topic in [Discussions](https://github.com/zkonduit/ezkl/discussions). Alternatively, you can join the ✨[EZKL Community Telegram Group](https://t.me/+QRzaRvTPIthlYWMx)💫.
@@ -45,6 +45,8 @@ The generated proofs can then be used on-chain to verify computation, only the E
### getting started ⚙️
The easiest way to get started is to try out a notebook.
#### Python
Install the python bindings by calling.
@@ -70,7 +72,7 @@ curl https://raw.githubusercontent.com/zkonduit/ezkl/main/install_ezkl_cli.sh |
https://user-images.githubusercontent.com/45801863/236771676-5bbbbfd1-ba6f-418a-902e-20738ce0e9f0.mp4
For more details visit the [docs](https://docs.ezkl.xyz).
For more details visit the [docs](https://docs.ezkl.xyz). The CLI is faster than Python, as it has less overhead. For even more speed and convenience, check out the [remote proving service](https://ei40vx5x6j0.typeform.com/to/sFv1oxvb), which feels like the CLI but is backed by a tuned cluster.
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
@@ -124,17 +126,6 @@ unset ENABLE_ICICLE_GPU
**NOTE:** Even with the above environment variable set, icicle is disabled for circuits where k <= 8. To change the value of `k` where icicle is enabled, you can set the environment variable `ICICLE_SMALL_K`.
### repos
The EZKL project has several libraries and repos.
| Repo | Description |
| --- | --- |
| [@zkonduit/ezkl](https://github.com/zkonduit/ezkl) | the main ezkl repo in rust with wasm and python bindings |
| [@zkonduit/ezkljs](https://github.com/zkonduit/ezkljs) | typescript and javascript tooling to help integrate ezkl into web apps |
----------------------
### contributing 🌎
If you're interested in contributing and are unsure where to start, reach out to one of the maintainers:
@@ -151,7 +142,7 @@ More broadly:
- To report bugs or request new features [create a new issue within Issues](https://github.com/zkonduit/ezkl/issues) to inform the greater community.
Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you shall be licensed to Zkonduit Inc. under the terms and conditions specified in the [CLA](https://github.com/zkonduit/ezkl/blob/main/cla.md), which you agree to by intentionally submitting a contribution. In particular, you have the right to submit the contribution and we can distribute it under the Apache 2.0 license, among other terms and conditions.
Any contribution intentionally submitted for inclusion in the work by you shall be licensed to Zkonduit Inc. under the terms and conditions specified in the [CLA](https://github.com/zkonduit/ezkl/blob/main/cla.md), which you agree to by intentionally submitting a contribution. In particular, you have the right to submit the contribution and we can distribute it, among other terms and conditions.
### no security guarantees
@@ -159,4 +150,7 @@ Ezkl is unaudited, beta software undergoing rapid development. There may be bugs
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
### no warranty
Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -15,6 +17,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut KERNEL_HEIGHT: usize = 2;
static mut KERNEL_WIDTH: usize = 2;
@@ -121,28 +124,35 @@ fn runcnvrl(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -90,25 +93,35 @@ fn rundot(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -94,25 +97,35 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -4,17 +4,20 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
const BITS: Range = (-32768, 32768);
@@ -112,25 +115,35 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -4,17 +4,20 @@ use ezkl::circuit::*;
use ezkl::circuit::lookup::LookupOp;
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::table::Range;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::{Bn256, Fr};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
const BITS: Range = (-8180, 8180);
@@ -115,25 +118,35 @@ fn runmatmul(c: &mut Criterion) {
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(k as u64));
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -86,25 +89,35 @@ fn runsum(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::hybrid::HybridOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -15,6 +17,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
static mut IMAGE_HEIGHT: usize = 2;
static mut IMAGE_WIDTH: usize = 2;
@@ -101,28 +104,35 @@ fn runsumpool(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
&circuit, &params, true,
)
.unwrap();
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -1,11 +1,13 @@
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -14,6 +16,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -84,25 +87,35 @@ fn runadd(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::poly::PolyOp;
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::*;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -15,6 +17,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::marker::PhantomData;
static mut LEN: usize = 4;
@@ -83,25 +86,35 @@ fn runpow(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -1,15 +1,18 @@
use std::collections::HashMap;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use ezkl::circuit::modules::Module;
use ezkl::circuit::*;
use ezkl::pfsys::create_keys;
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::srs::gen_srs;
use ezkl::pfsys::TranscriptType;
use ezkl::tensor::*;
use halo2_proofs::circuit::Value;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
arithmetic::Field,
@@ -18,6 +21,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::rngs::OsRng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const L: usize = 10;
@@ -46,7 +50,7 @@ impl Circuit<Fr> for MyCircuit {
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0)?;
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
}
}
@@ -62,7 +66,7 @@ fn runposeidon(c: &mut Criterion) {
let params = gen_srs::<KZGCommitmentScheme<_>>(k);
let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
let output =
let _output =
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
.unwrap();
@@ -76,25 +80,35 @@ fn runposeidon(c: &mut Criterion) {
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(*size as u64));
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
MyCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
Some(output[0].clone()),
&pk,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
None,
None,
);
prover.unwrap();

View File

@@ -2,11 +2,12 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::pfsys::create_proof_circuit_kzg;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
use ezkl::tensor::*;
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
@@ -14,6 +15,7 @@ use halo2_proofs::{
};
use halo2curves::bn256::{Bn256, Fr};
use rand::Rng;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
const BITS: Range = (-32768, 32768);
static mut LEN: usize = 4;
@@ -91,25 +93,35 @@ fn runrelu(c: &mut Criterion) {
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
b.iter(|| {
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true)
.unwrap();
});
});
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, &params, true)
.unwrap();
let pk =
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, &params, true).unwrap();
group.throughput(Throughput::Elements(len as u64));
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
b.iter(|| {
let prover = create_proof_circuit_kzg(
let prover = create_proof_circuit::<
KZGCommitmentScheme<_>,
NLCircuit,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
CheckMode::UNSAFE,
ezkl::Commitments::KZG,
TranscriptType::EVM,
SingleStrategy::new(&params),
CheckMode::SAFE,
None,
None,
);
prover.unwrap();

View File

@@ -696,10 +696,12 @@
"for i, value in enumerate(proof[\"instances\"]):\n",
" for j, field_element in enumerate(value):\n",
" onchain_input_array.append(ezkl.felt_to_big_endian(field_element))\n",
" formatted_output += str(onchain_input_array[-1])\n",
" formatted_output += '\"' + str(onchain_input_array[-1]) + '\"'\n",
" if j != len(value) - 1:\n",
" formatted_output += \", \"\n",
" formatted_output += \"]\"\n",
" if i != len(proof[\"instances\"]) - 1:\n",
" formatted_output += \", \"\n",
"formatted_output += \"]\"\n",
"\n",
"# This will be the values you use onchain\n",
"# copy them over to remix and see if they verify\n",

View File

@@ -10,7 +10,7 @@
"\n",
"## Generalized Inverse\n",
"\n",
"We show how to use EZKL to prove that we know matrices $A$ and its generalized inverse $B$. Since these are large we deal with the KZG commitments, with $a$ the kzgcommit of $A$, $b$ the kzgcommit of $B$, and $ABA = A$.\n"
"We show how to use EZKL to prove that we know matrices $A$ and its generalized inverse $B$. Since these are large we deal with the KZG commitments, with $a$ the polycommit of $A$, $b$ the polycommit of $B$, and $ABA = A$.\n"
]
},
{
@@ -77,7 +77,7 @@
"outputs": [],
"source": [
"gip_run_args = ezkl.PyRunArgs()\n",
"gip_run_args.input_visibility = \"kzgcommit\" # matrix and generalized inverse commitments\n",
"gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
"gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
"gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
]
@@ -340,4 +340,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -161,7 +161,7 @@
"- `fixed`: known to the prover and verifier (as a commit), but not modifiable by the prover.\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be inserted directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be inserted directly within the proof bytes. \n",
"\n",
"\n",
"Here we create the following setup:\n",
@@ -510,4 +510,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",

View File

@@ -154,11 +154,11 @@
"- `fixed`: known to the prover and verifier (as a commit), but not modifiable by the prover.\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"kzgcommit\"\n",
"- `param_visibility`: \"kzgcommit\"\n",
"- `input_visibility`: \"polycommit\"\n",
"- `param_visibility`: \"polycommit\"\n",
"- `output_visibility`: public\n",
"\n",
"We encourage you to play around with other setups :) \n",
@@ -186,8 +186,8 @@
"data_path = os.path.join('input.json')\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"kzgcommit\"\n",
"run_args.param_visibility = \"kzgcommit\"\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.param_visibility = \"polycommit\"\n",
"run_args.output_visibility = \"public\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"\n",
@@ -512,4 +512,4 @@
},
"nbformat": 4,
"nbformat_minor": 2
}
}

View File

@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",

View File

@@ -208,7 +208,7 @@
"- `private`: known only to the prover\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"public\"\n",
@@ -234,7 +234,7 @@
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n",
"\n",
"ezkl.get_srs(logrows=run_args.logrows)"
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
]
},
{
@@ -385,9 +385,9 @@
"### KZG commitment intermediate calculations\n",
"\n",
"This time the visibility parameters are:\n",
"- `input_visibility`: \"kzgcommit\"\n",
"- `input_visibility`: \"polycommit\"\n",
"- `param_visibility`: \"public\"\n",
"- `output_visibility`: kzgcommit"
"- `output_visibility`: polycommit"
]
},
{
@@ -399,9 +399,9 @@
"import ezkl\n",
"\n",
"run_args = ezkl.PyRunArgs()\n",
"run_args.input_visibility = \"kzgcommit\"\n",
"run_args.input_visibility = \"polycommit\"\n",
"run_args.param_visibility = \"fixed\"\n",
"run_args.output_visibility = \"kzgcommit\"\n",
"run_args.output_visibility = \"polycommit\"\n",
"run_args.variables = [(\"batch_size\", 1)]\n",
"run_args.input_scale = 2\n",
"run_args.logrows = 8\n"

View File

@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -275,7 +275,6 @@
" proof_path,\n",
" settings_path,\n",
" vk_path,\n",
" \n",
" )\n",
"\n",
"assert res == True\n",
@@ -291,7 +290,7 @@
"source": [
"# Generate a larger SRS. This is needed for the aggregated proof\n",
"\n",
"res = ezkl.get_srs(settings_path=None, logrows=21)"
"res = ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
]
},
{

View File

@@ -9,7 +9,7 @@
"source": [
"## Solvency demo\n",
"\n",
"Here we create a demo of a solvency calculation in the manner of [summa-solvency](https://github.com/summa-dev/summa-solvency). The aim here is to demonstrate the use of the new kzgcommit method detailed [here](https://blog.ezkl.xyz/post/commits/). \n",
"Here we create a demo of a solvency calculation in the manner of [summa-solvency](https://github.com/summa-dev/summa-solvency). The aim here is to demonstrate the use of the new polycommit method detailed [here](https://blog.ezkl.xyz/post/commits/). \n",
"\n",
"In this setup:\n",
"- the commitments to users, respective balances, and total balance are known are publicly known to the prover and verifier. \n",
@@ -177,10 +177,10 @@
"- `private`: known only to the prover\n",
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
"\n",
"Here we create the following setup:\n",
"- `input_visibility`: \"kzgcommit\"\n",
"- `input_visibility`: \"polycommit\"\n",
"- `param_visibility`: \"public\"\n",
"- `output_visibility`: public\n",
"\n",
@@ -202,8 +202,8 @@
"outputs": [],
"source": [
"run_args = ezkl.PyRunArgs()\n",
"# \"kzgcommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"kzgcommit\"\n",
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
"run_args.input_visibility = \"polycommit\"\n",
"# the parameters are public\n",
"run_args.param_visibility = \"fixed\"\n",
"# the output is public (this is the inequality test)\n",
@@ -515,4 +515,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,40 @@
from torch import nn
import torch
import json
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.layer = nn.LPPool2d(2, 1, (1, 1))
def forward(self, x):
return self.layer(x)[0]
circuit = Model()
x = torch.empty(1, 3, 2, 2).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.7549541592597961, 0.990360677242279, 0.9473411440849304, 0.3951031565666199, 0.8500555753707886, 0.9352139830589294, 0.11867779493331909, 0.9493132829666138, 0.6588345766067505, 0.1933223009109497, 0.12139874696731567, 0.8547163605690002]]}

Binary file not shown.

42
examples/onnx/celu/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = nn.CELU()(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.35387128591537476, 0.030473172664642334, 0.08707714080810547, 0.2429301142692566, 0.45228832960128784, 0.496021032333374, 0.13245105743408203, 0.8497090339660645]]}

Binary file not shown.

41
examples/onnx/clip/gen.py Normal file
View File

@@ -0,0 +1,41 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.clamp(x, min=0.4, max=0.8)
return m
circuit = MyModel()
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.03297048807144165, 0.46362626552581787, 0.6044231057167053, 0.4949902892112732, 0.48823297023773193, 0.6798646450042725, 0.6824942231178284, 0.03491640090942383, 0.19608813524246216, 0.24129939079284668, 0.9769315123558044, 0.6306831240653992, 0.7690497636795044, 0.252221941947937, 0.9167693853378296, 0.3882681131362915, 0.9307044148445129, 0.33559417724609375, 0.7815426588058472, 0.3435332179069519, 0.7871478796005249, 0.12240773439407349, 0.5295405983924866, 0.4874419569969177, 0.08262640237808228, 0.1124718189239502, 0.5834914445877075, 0.30927878618240356, 0.48899340629577637, 0.9376634955406189, 0.21893149614334106, 0.526070773601532]]}

View File

@@ -0,0 +1,24 @@
pytorch2.2.1:±
?/Constant_output_0 /Constant"Constant*
value*JÍÌÌ> 
C/Constant_1_output_0 /Constant_1"Constant*
value*JÍÌL? 
F
input
/Constant_output_0
/Constant_1_output_0output/Clip"Clip
main_graphZ)
input


batch_size


b*
output


batch_size


B

41
examples/onnx/gru/gen.py Normal file
View File

@@ -0,0 +1,41 @@
import random
import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import json
model = nn.GRU(3, 3) # Input dim is 3, output dim is 3
x = torch.randn(1, 3) # make a sequence of length 5
print(x)
# Flips the neural net into inference mode
model.eval()
model.to('cpu')
# Export the model
torch.onnx.export(model, # model being run
# model input (or a tuple for multiple inputs)
x,
# where to save the model (can be a file or file-like object)
"network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=10, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
data_array = ((x).detach().numpy()).reshape([-1]).tolist()
data_json = dict(input_data=[data_array])
print(data_json)
# Serialize data into file:
json.dump(data_json, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.4145222008228302, -0.4043896496295929, 0.7545749545097351]]}

Binary file not shown.

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.argmax(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.5505883693695068, 0.0766521692276001, 0.12006187438964844, 0.9497959017753601, 0.9100563526153564, 0.968717098236084, 0.5978299379348755, 0.9419963359832764]]}

Binary file not shown.

View File

@@ -9,7 +9,7 @@ class MyModel(nn.Module):
super(MyModel, self).__init__()
def forward(self, x):
m = nn.Logsoftmax()(x)
m = nn.Hardsigmoid()(x)
return m

View File

@@ -1 +1 @@
{"input_data": [[0.2971532940864563, 0.3465197682380676, 0.05381882190704346, 0.058654189109802246, 0.014198064804077148, 0.06088751554489136, 0.1723427176475525, 0.5115123987197876]]}
{"input_data": [[0.8326942324638367, 0.2796096205711365, 0.600328266620636, 0.3701696991920471, 0.17832040786743164, 0.6247223019599915, 0.501872718334198, 0.6961578726768494]]}

View File

@@ -1,4 +1,4 @@
pytorch2.1.0:<3A>
pytorch2.2.1:<3A>
;
inputoutput /HardSigmoid" HardSigmoid*
alpha«ª*> 

View File

@@ -0,0 +1,41 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = nn.Hardswish()(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.6996762752532959, 0.42992985248565674, 0.5102168321609497, 0.5540387630462646, 0.8489438891410828, 0.8533616065979004, 0.36736780405044556, 0.5859147310256958]]}

View File

@@ -0,0 +1,15 @@
pytorch2.2.1:{
&
inputoutput
/HardSwish" HardSwish
main_graphZ!
input


batch_size
b"
output


batch_size
B

View File

@@ -9,7 +9,7 @@ class MyModel(nn.Module):
super(MyModel, self).__init__()
def forward(self, x):
m = nn.Hardsigmoid()(x)
m = nn.LogSoftmax()(x)
return m

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.logsumexp(x, dim=1)
return m
circuit = MyModel()
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.7973018884658813, 0.5245689153671265, 0.34149593114852905, 0.1455438733100891, 0.9482707381248474, 0.4221445322036743, 0.001363217830657959, 0.8736765384674072, 0.42954301834106445, 0.7199509739875793, 0.37641745805740356, 0.5920265316963196, 0.42270803451538086, 0.41761744022369385, 0.603948712348938, 0.7250819802284241, 0.047173500061035156, 0.5115441679954529, 0.3743387460708618, 0.16794061660766602, 0.5352339148521423, 0.037976861000061035, 0.65323406457901, 0.5585184097290039, 0.10559147596359253, 0.07827490568161011, 0.6717077493667603, 0.6480781435966492, 0.9780838489532471, 0.8353415131568909, 0.6491701006889343, 0.6573048233985901]]}

Binary file not shown.

View File

@@ -0,0 +1,13 @@
{
"input_data": [
[
0.8894134163856506,
0.8894201517105103
]
],
"output_data": [
[
0.8436377
]
]
}

Binary file not shown.

42
examples/onnx/mish/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = nn.Mish()(x)
return m
circuit = MyModel()
x = torch.empty(1, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.18563222885131836, 0.4843214750289917, 0.9991059899330139, 0.02534431219100952, 0.8105666041374207, 0.9658406376838684, 0.681107759475708, 0.5365872979164124]]}

View File

@@ -0,0 +1,19 @@
pytorch2.2.1:ä
0
input/Softplus_output_0 /Softplus"Softplus
1
/Softplus_output_0/Tanh_output_0/Tanh"Tanh
*
input
/Tanh_output_0output/Mul"Mul
main_graphZ!
input


batch_size
b"
output


batch_size
B

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.norm(x, p=1, dim=1)
return m
circuit = MyModel()
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.02284395694732666, 0.7941043376922607, 0.07971876859664917, 0.8898420929908752, 0.8233054280281067, 0.11066079139709473, 0.4424799084663391, 0.4355071783065796, 0.6723723411560059, 0.6818525195121765, 0.8726171851158142, 0.17742449045181274, 0.054257750511169434, 0.5775953531265259, 0.7758923172950745, 0.8431423306465149, 0.7602444887161255, 0.29686522483825684, 0.22489851713180542, 0.0675363540649414, 0.981339693069458, 0.15771394968032837, 0.5801441669464111, 0.9044001698493958, 0.49266451597213745, 0.42621421813964844, 0.35345613956451416, 0.042848050594329834, 0.6908614039421082, 0.5422852039337158, 0.01975083351135254, 0.5772860050201416]]}

Binary file not shown.

View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.norm(x, p=2, dim=1)
return m
circuit = MyModel()
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.8709188103675842, 0.11553549766540527, 0.27376580238342285, 0.7518517971038818, 0.7879393100738525, 0.8765475749969482, 0.14315760135650635, 0.8982420563697815, 0.7274006605148315, 0.39007169008255005, 0.729040801525116, 0.11306107044219971, 0.658822774887085, 0.666404664516449, 0.3001367449760437, 0.45343858003616333, 0.7460223436355591, 0.7423691749572754, 0.7544230818748474, 0.5674425959587097, 0.8728761672973633, 0.27062875032424927, 0.1595977544784546, 0.22975260019302368, 0.6711723208427429, 0.8265992403030396, 0.48679041862487793, 0.689740777015686, 0.330846905708313, 0.5630669593811035, 0.8058932423591614, 0.5802426338195801]]}

Binary file not shown.

42
examples/onnx/tril/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.triu(x)
return m
circuit = MyModel()
x = torch.empty(1, 3, 3).uniform_(0, 5)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.4870188236236572, 2.275230646133423, 3.126268148422241, 0.6412187218666077, 0.9967470169067383, 1.9814395904541016, 1.6355383396148682, 0.6397527456283569, 0.7825168967247009]]}

Binary file not shown.

42
examples/onnx/triu/gen.py Normal file
View File

@@ -0,0 +1,42 @@
from torch import nn
import torch
import json
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, x):
m = torch.tril(x)
return m
circuit = MyModel()
x = torch.empty(1, 3, 3).uniform_(0, 5)
out = circuit(x)
print(out)
torch.onnx.export(circuit, x, "network.onnx",
export_params=True, # store the trained parameter weights inside the model file
opset_version=17, # the ONNX version to export the model to
do_constant_folding=True, # whether to execute constant folding for optimization
input_names=['input'], # the model's input names
output_names=['output'], # the model's output names
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
'output': {0: 'batch_size'}})
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
data = dict(
input_data=[d1],
)
# Serialize data into file:
json.dump(data, open("input.json", 'w'))

View File

@@ -0,0 +1 @@
{"input_data": [[0.2898547053337097, 1.8070811033248901, 0.30266255140304565, 3.00581955909729, 0.5379888415336609, 1.7057424783706665, 2.415961265563965, 0.589233934879303, 0.03824889659881592]]}

Binary file not shown.

View File

@@ -17,7 +17,7 @@
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",

View File

@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.0
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2023-08-24"
channel = "nightly-2024-02-06"
components = ["rustfmt", "clippy"]

View File

@@ -1,4 +1,6 @@
// ignore file if compiling for wasm
#[global_allocator]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;
#[cfg(not(target_arch = "wasm32"))]
use clap::Parser;

View File

@@ -2,7 +2,7 @@
pub mod poseidon;
///
pub mod kzg;
pub mod polycommit;
///
pub mod planner;
@@ -15,6 +15,8 @@ pub use planner::*;
use crate::tensor::{TensorType, ValTensor};
use super::region::ConstantsMap;
/// Module trait used to extend ezkl functionality
pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Config
@@ -39,6 +41,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
/// Layout
fn layout(
@@ -46,6 +49,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;

View File

@@ -4,49 +4,49 @@ is already implemented in halo2_gadgets, there is no wrapper chip that makes it
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::poly::commitment::{Blind, Params};
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
use halo2_proofs::{circuit::*, plonk::*};
use halo2curves::bn256::{Bn256, G1Affine};
use halo2curves::bn256::G1Affine;
use halo2curves::group::prime::PrimeCurveAffine;
use halo2curves::group::Curve;
use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::Module;
/// The number of instance columns used by the KZG hash function
/// The number of instance columns used by the PolyCommit hash function
pub const NUM_INSTANCE_COLUMNS: usize = 0;
/// The number of advice columns used by the KZG hash function
/// The number of advice columns used by the PolyCommit hash function
pub const NUM_INNER_COLS: usize = 1;
#[derive(Debug, Clone)]
/// WIDTH, RATE and L are const generics for the struct, which represent the width, rate, and number of inputs for the Poseidon hash function, respectively.
/// This means they are values that are known at compile time and can be used to specialize the implementation of the struct.
/// The actual chip provided by halo2_gadgets is added to the parent Chip.
pub struct KZGConfig {
/// Configuration for the PolyCommit chip
pub struct PolyCommitConfig {
///
pub hash_inputs: VarTensor,
pub inputs: VarTensor,
}
type InputAssignments = ();
/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
///
#[derive(Debug)]
pub struct KZGChip {
config: KZGConfig,
pub struct PolyCommitChip {
config: PolyCommitConfig,
}
impl KZGChip {
impl PolyCommitChip {
/// Commit to the message using the KZG commitment scheme
pub fn commit(
message: Vec<Fp>,
pub fn commit<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
message: Vec<Scheme::Scalar>,
degree: u32,
num_unusable_rows: u32,
params: &ParamsKZG<Bn256>,
params: &Scheme::ParamsProver,
) -> Vec<G1Affine> {
let k = params.k();
let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k);
@@ -81,14 +81,14 @@ impl KZGChip {
}
}
impl Module<Fp> for KZGChip {
type Config = KZGConfig;
impl Module<Fp> for PolyCommitChip {
type Config = PolyCommitConfig;
type InputAssignments = InputAssignments;
type RunInputs = Vec<Fp>;
type Params = (usize, usize);
fn name(&self) -> &'static str {
"KZG"
"PolyCommit"
}
fn instance_increment_input(&self) -> Vec<usize> {
@@ -102,14 +102,15 @@ impl Module<Fp> for KZGChip {
/// Configuration of the PoseidonChip
fn configure(meta: &mut ConstraintSystem<Fp>, params: Self::Params) -> Self::Config {
let hash_inputs = VarTensor::new_unblinded_advice(meta, params.0, NUM_INNER_COLS, params.1);
Self::Config { hash_inputs }
let inputs = VarTensor::new_unblinded_advice(meta, params.0, NUM_INNER_COLS, params.1);
Self::Config { inputs }
}
fn layout_inputs(
&self,
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
Ok(())
}
@@ -122,11 +123,24 @@ impl Module<Fp> for KZGChip {
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "kzg commit",
|mut region| self.config.hash_inputs.assign(&mut region, 0, &input[0]),
|| "PolyCommit",
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
}
@@ -163,7 +177,7 @@ mod tests {
}
impl Circuit<Fp> for HashCircuit {
type Config = KZGConfig;
type Config = PolyCommitConfig;
type FloorPlanner = ModulePlanner;
type Params = ();
@@ -178,7 +192,7 @@ mod tests {
fn configure(meta: &mut ConstraintSystem<Fp>) -> Self::Config {
let params = (K, R);
KZGChip::configure(meta, params)
PolyCommitChip::configure(meta, params)
}
fn synthesize(
@@ -186,8 +200,13 @@ mod tests {
config: Self::Config,
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let kzg_chip = KZGChip::new(config);
kzg_chip.layout(&mut layouter, &[self.message.clone()], 0);
let polycommit_chip = PolyCommitChip::new(config);
polycommit_chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
);
Ok(())
}
@@ -195,7 +214,7 @@ mod tests {
#[test]
#[ignore]
fn kzg_for_a_range_of_input_sizes() {
fn polycommit_chip_for_a_range_of_input_sizes() {
let rng = rand::rngs::OsRng;
#[cfg(not(target_arch = "wasm32"))]
@@ -225,7 +244,7 @@ mod tests {
#[test]
#[ignore]
fn kzg_commit_much_longer_input() {
fn polycommit_chip_much_longer_input() {
#[cfg(not(target_arch = "wasm32"))]
env_logger::init();

View File

@@ -18,6 +18,7 @@ use maybe_rayon::slice::ParallelSlice;
use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::Module;
@@ -172,12 +173,15 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
&self,
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
let start_time = instant::Instant::now();
let local_constants = constants.clone();
let res = layouter.assign_region(
|| "load message",
|mut region| {
@@ -199,12 +203,26 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
Ok(res)
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
@@ -270,8 +288,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?;
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -434,7 +453,7 @@ mod tests {
*,
};
use std::marker::PhantomData;
use std::{collections::HashMap, marker::PhantomData};
use halo2_gadgets::poseidon::primitives::Spec;
use halo2_proofs::{
@@ -477,7 +496,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.message.clone()], 0)?;
chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
)?;
Ok(())
}

View File

@@ -345,7 +345,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {

View File

@@ -46,7 +46,8 @@ pub enum HybridOp {
dim: usize,
},
Softmax {
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
},
RangeCheck(Tolerance),
@@ -70,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -130,9 +131,16 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
kernel_shape,
normalized,
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
HybridOp::Softmax { scale, axes } => {
tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes)
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => tensor::ops::nonlinearities::softmax_axes(
&x,
input_scale.into(),
output_scale.into(),
axes,
),
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
@@ -203,8 +211,15 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
HybridOp::Softmax { scale, axes } => {
format!("SOFTMAX (scale={}, axes={:?})", scale, axes)
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
@@ -324,9 +339,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::ReduceArgMin { dim } => {
layouts::argmin_axes(config, region, values[..].try_into()?, *dim)?
}
HybridOp::Softmax { scale, axes } => {
layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)?
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => layouts::softmax_axes(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
axes,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
config,
region,
@@ -359,8 +383,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
_ => in_scales[0],
};
Ok(scale)

View File

@@ -9,6 +9,7 @@ use halo2curves::ff::PrimeField;
use itertools::Itertools;
use log::{error, trace};
use maybe_rayon::{
iter::IntoParallelRefIterator,
prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
};
@@ -33,7 +34,7 @@ use super::*;
use crate::circuit::ops::lookup::LookupOp;
/// Same as div but splits the division into N parts
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -68,7 +69,7 @@ pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
}
/// Div accumulated layout
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -93,9 +94,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
@@ -133,7 +134,7 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
}
/// recip accumulated layout
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
value: &[ValTensor<F>; 1],
@@ -166,9 +167,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
felt_to_i128(input_scale) as f64,
felt_to_i128(output_scale) as f64,
)
.iter()
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
} else {
Tensor::new(
@@ -226,7 +227,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
}
/// Dot product accumulated layout
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -337,7 +338,7 @@ pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
}
/// Einsum
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
inputs: &[ValTensor<F>],
@@ -524,14 +525,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
// Compute the product of all input tensors
for pair in input_pairs {
let product_across_pair = prod(
config,
region,
&[pair.try_into().map_err(|e| {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?],
)?;
let product_across_pair = prod(config, region, &[pair.into()])?;
if let Some(product) = prod_res {
prod_res = Some(
@@ -563,7 +557,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -574,12 +568,12 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()?;
let sorted = if is_assigned {
input
.get_int_evals()?
.iter()
.sorted_by(|a, b| a.cmp(b))
.map(|x| Ok(Value::known(i128_to_felt(*x))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
let mut int_evals = input.get_int_evals()?;
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
int_evals
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); input.len()]),
@@ -607,7 +601,7 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
}
///
fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -622,7 +616,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
}
/// Select top k elements
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -642,11 +636,12 @@ pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
fn select<F: PrimeField + TensorType + PartialOrd>(
fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let start = instant::Instant::now();
let (mut input, index) = (values[0].clone(), values[1].clone());
input.flatten();
@@ -656,12 +651,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
let output: ValTensor<F> = if is_assigned {
let output: ValTensor<F> = if is_assigned && region.witness_gen() {
let felt_evals = input.get_felt_evals()?;
index
.get_int_evals()?
.iter()
.map(|x| Ok(Value::known(input.get_felt_evals()?.get(&[*x as usize]))))
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
.par_iter()
.map(|x| Value::known(felt_evals.get(&[*x as usize])))
.collect::<Tensor<Value<F>>>()
} else {
Tensor::new(
Some(&vec![Value::<F>::unknown(); index.len()]),
@@ -673,10 +669,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
let (_, assigned_output) =
dynamic_lookup(config, region, &[index, output], &[dim_indices, input])?;
let end = start.elapsed();
trace!("select took: {:?}", end);
Ok(assigned_output)
}
fn one_hot<F: PrimeField + TensorType + PartialOrd>(
fn one_hot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -692,7 +691,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
let output: ValTensor<F> = if is_assigned {
let int_evals = input.get_int_evals()?;
let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<_>>()
} else {
@@ -728,12 +727,13 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
}
/// Dynamic lookup
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
lookups: &[ValTensor<F>; 2],
tables: &[ValTensor<F>; 2],
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
let start = instant::Instant::now();
// if not all lookups same length err
if lookups[0].len() != lookups[1].len() {
return Err("lookups must be same length".into());
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
let table_len = table_0.len();
trace!("assigning tables took: {:?}", start.elapsed());
// now create a vartensor of constants for the dynamic lookup index
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
let _table_index =
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
trace!("assigning table index took: {:?}", start.elapsed());
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
let lookup_len = lookup_0.len();
trace!("assigning lookups took: {:?}", start.elapsed());
// now set the lookup index
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
trace!("assigning lookup index took: {:?}", start.elapsed());
if !region.is_dummy() {
(0..table_len)
.map(|i| {
@@ -802,11 +810,14 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
region.increment_dynamic_lookup_index(1);
region.increment(lookup_len);
let end = start.elapsed();
trace!("dynamic lookup took: {:?}", end);
Ok((lookup_0, lookup_1))
}
/// Shuffle arg
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
input: &[ValTensor<F>; 1],
@@ -869,7 +880,7 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
}
/// One hot accumulated layout
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -922,7 +933,7 @@ pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -950,7 +961,7 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -973,7 +984,7 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1024,7 +1035,7 @@ pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
/// The linearized index is the index of the element in the flattened tensor.
/// FOr instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1032,6 +1043,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
dim: usize,
is_flat_index: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let start_time = instant::Instant::now();
let index = values[0].clone();
if !is_flat_index {
assert_eq!(index.dims().len(), dims.len());
@@ -1105,6 +1117,9 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
region.apply_in_loop(&mut output, inner_loop_function)?;
let elapsed = start_time.elapsed();
trace!("linearize_element_index took: {:?}", elapsed);
Ok(output.into())
}
@@ -1125,7 +1140,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
/// Let us think of each such r-b ranked tensor as indices_slice. Each scalar value corresponding to data[0:b-1,indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below)
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b. Let us think of each such tensors as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below)
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1222,7 +1237,7 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
.iter()
.map(|x| {
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
Ok(index_slice.get_slice(&slice)?)
index_slice.get_slice(&slice)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
};
@@ -1232,7 +1247,6 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
}
let const_offset = create_constant_tensor(const_offset, 1);
let mut results = vec![];
@@ -1250,16 +1264,18 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
let res = sum(config, region, &[res])?;
results.push(res.get_inner_tensor()?.clone());
// assert than res is less than the product of the dims
assert!(
if region.witness_gen() {
assert!(
res.get_int_evals()?
.iter()
.all(|x| *x < dims.iter().product::<usize>() as i128),
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
dims.iter().product::<usize>(),
index_val.show(),
index_val.show(),
index_dim_multiplier.show(),
res.show()
);
}
}
let result_tensor = Tensor::from(results.into_iter());
@@ -1273,7 +1289,9 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
Ok(output.into())
}
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn get_missing_set_elements<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1304,7 +1322,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
}
fullset_evals
.iter()
.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1337,7 +1355,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Gather accumulated layout
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -1354,14 +1372,14 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1419,7 +1437,7 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
}
/// Scatter Nd
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -1433,14 +1451,14 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
let claimed_output: ValTensor<F> = if is_assigned {
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
let input_inner = input.get_int_evals()?;
let index_inner = index.get_int_evals()?.map(|x| x as usize);
let src_inner = src.get_int_evals()?;
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
res.iter()
res.par_iter()
.map(|x| Value::known(i128_to_felt(*x)))
.collect::<Tensor<Value<F>>>()
.into()
@@ -1457,7 +1475,6 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
region.increment(claimed_output.len());
claimed_output.reshape(input.dims())?;
// scatter elements is the inverse of gather elements
let (gather_src, linear_index) =
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
@@ -1498,7 +1515,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
}
/// sum accumulated layout
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1581,7 +1598,7 @@ pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
}
/// product accumulated layout
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1661,7 +1678,7 @@ pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
}
/// Axes wise op wrapper
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1722,7 +1739,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
}
/// Sum accumulated layout
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1733,7 +1750,7 @@ pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Sum accumulated layout
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1744,7 +1761,7 @@ pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// argmax layout
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1762,7 +1779,7 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Max accumulated layout
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1774,7 +1791,7 @@ pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmin layout
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1792,7 +1809,7 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Min accumulated layout
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1804,7 +1821,7 @@ pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
}
/// Pairwise (elementwise) op layout
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1959,7 +1976,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
}
/// expand the tensor to the given shape
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -1972,7 +1989,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -1995,7 +2012,7 @@ pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2018,7 +2035,7 @@ pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2028,7 +2045,7 @@ pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
}
///
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2038,7 +2055,7 @@ pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
}
/// And boolean operation
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2049,7 +2066,7 @@ pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
}
/// Or boolean operation
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2065,7 +2082,7 @@ pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
}
/// Equality boolean operation
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2075,7 +2092,7 @@ pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
}
/// Equality boolean operation
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2109,7 +2126,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
}
/// Xor boolean operation
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2135,7 +2152,7 @@ pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
}
/// Not boolean operation
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2151,7 +2168,7 @@ pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
}
/// Iff
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 3],
@@ -2175,7 +2192,7 @@ pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
}
/// Negation operation accumulated layout
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2185,7 +2202,7 @@ pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
}
/// Sumpool accumulated layout
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
@@ -2239,7 +2256,7 @@ pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
}
/// Convolution accumulated layout
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2304,7 +2321,7 @@ pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
/// DeConvolution accumulated layout
pub(crate) fn deconv<
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -2397,7 +2414,7 @@ pub(crate) fn deconv<
/// Convolution accumulated layout
pub(crate) fn conv<
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
@@ -2578,7 +2595,7 @@ pub(crate) fn conv<
}
/// Power accumulated layout
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2594,7 +2611,7 @@ pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
}
/// Rescaled op accumulated layout
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>],
@@ -2616,7 +2633,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
}
/// Dummy (no contraints) reshape layout
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
new_dims: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
@@ -2626,7 +2643,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
}
/// Dummy (no contraints) move_axis layout
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>; 1],
source: usize,
destination: usize,
@@ -2637,7 +2654,7 @@ pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
}
/// resize layout
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2651,7 +2668,7 @@ pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
}
/// Slice layout
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2660,15 +2677,45 @@ pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
end: &usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// assigns the instance to the advice.
let mut output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
let mut output = values[0].clone();
let is_assigned = output.all_prev_assigned();
if !is_assigned {
output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
}
output.slice(axis, start, end)?;
Ok(output)
}
/// Trilu layout
pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
k: &i32,
upper: &bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// assigns the instance to the advice.
let mut output = values[0].clone();
let is_assigned = output.all_prev_assigned();
if !is_assigned {
output = region.assign(&config.custom_gates.inputs[0], &values[0])?;
}
let res = tensor::ops::trilu(output.get_inner_tensor()?, *k, *upper)?;
let output = region.assign(&config.custom_gates.output, &res.into())?;
region.increment(output.len());
Ok(output)
}
/// Concat layout
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
values: &[ValTensor<F>],
axis: &usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
@@ -2680,7 +2727,7 @@ pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
}
/// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2695,7 +2742,7 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
}
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2731,7 +2778,7 @@ pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
}
/// Downsample layout
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2748,7 +2795,7 @@ pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for enforcing two sets of cells to be equal
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
@@ -2774,7 +2821,7 @@ pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for range check.
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2830,12 +2877,13 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
if region.throw_range_check_error() {
let is_assigned = !w.any_unknowns()?;
if is_assigned && region.witness_gen() {
// assert is within range
let int_values = w.get_int_evals()?;
for v in int_values {
if v < range.0 || v > range.1 {
log::debug!("Value ({:?}) out of range: {:?}", v, range);
for v in int_values.iter() {
if v < &range.0 || v > &range.1 {
log::error!("Value ({:?}) out of range: {:?}", v, range);
return Err(Box::new(TensorError::TableLookupError));
}
}
@@ -2855,7 +2903,7 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
}
/// layout for nonlinearity check.
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2957,7 +3005,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmax
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -2993,7 +3041,7 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
}
/// Argmin
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3029,7 +3077,7 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
}
/// max layout
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3039,7 +3087,7 @@ pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
}
/// min layout
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3047,7 +3095,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
_sort_ascending(config, region, values)?.get_slice(&[0..1])
}
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
@@ -3150,18 +3198,19 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
}
/// softmax layout
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: &[usize],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let soft_max_at_scale = move |config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1]|
-> Result<ValTensor<F>, Box<dyn Error>> {
softmax(config, region, values, scale)
softmax(config, region, values, input_scale, output_scale)
};
let output = multi_dim_axes_op(config, region, values, axes, soft_max_at_scale)?;
@@ -3169,33 +3218,66 @@ pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// softmax func
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd>(
/// percent func
pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// elementwise exponential
let ex = nonlinearity(config, region, values, &LookupOp::Exp { scale })?;
let is_assigned = values[0].all_prev_assigned();
let mut input = values[0].clone();
if !is_assigned {
input = region.assign(&config.custom_gates.inputs[0], &values[0])?;
region.increment(input.len());
};
// sum of exps
let denom = sum(config, region, &[ex.clone()])?;
// get the inverse
let felt_scale = F::from(scale.0 as u64);
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
let denom = sum(config, region, &[input.clone()])?;
let input_felt_scale = F::from(input_scale.0 as u64);
let output_felt_scale = F::from(output_scale.0 as u64);
let inv_denom = recip(
config,
region,
&[denom],
input_felt_scale,
output_felt_scale,
)?;
// product of num * (1 / denom) = 2*output_scale
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
let percent = pairwise(config, region, &[input, inv_denom], BaseOp::Mult)?;
Ok(softmax)
// rebase the percent to 2x the scale
loop_div(config, region, &[percent], input_felt_scale)
}
/// softmax func
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
input_scale: utils::F32,
output_scale: utils::F32,
) -> Result<ValTensor<F>, Box<dyn Error>> {
// get the max then subtract it
let max_val = max(config, region, values)?;
// rebase the input to 0
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
// elementwise exponential
let ex = nonlinearity(
config,
region,
&[sub],
&LookupOp::Exp { scale: input_scale },
)?;
percent(config, region, &[ex.clone()], input_scale, output_scale)
}
/// Checks that the percent error between the expected public output and the actual output value
/// is within the percent error expressed by the `tol` input, where `tol == 1.0` means the percent
/// error tolerance is 1 percent.
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],

View File

@@ -123,6 +123,9 @@ pub enum LookupOp {
scale: utils::F32,
a: utils::F32,
},
HardSwish {
scale: utils::F32,
},
}
impl LookupOp {
@@ -132,9 +135,56 @@ impl LookupOp {
let range = range as i128;
(-range, range)
}
/// as path
pub fn as_path(&self) -> String {
match self {
LookupOp::Abs => "abs".into(),
LookupOp::Ceil { scale } => format!("ceil_{}", scale),
LookupOp::Floor { scale } => format!("floor_{}", scale),
LookupOp::Round { scale } => format!("round_{}", scale),
LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
LookupOp::KroneckerDelta => "kronecker_delta".into(),
LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
LookupOp::Sign => "sign".into(),
LookupOp::LessThan { a } => format!("less_than_{}", a),
LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
LookupOp::Div { denom } => format!("div_{}", denom),
LookupOp::Cast { scale } => format!("cast_{}", scale),
LookupOp::Recip {
input_scale,
output_scale,
} => format!("recip_{}_{}", input_scale, output_scale),
LookupOp::ReLU => "relu".to_string(),
LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
LookupOp::Erf { scale } => format!("erf_{}", scale),
LookupOp::Exp { scale } => format!("exp_{}", scale),
LookupOp::Ln { scale } => format!("ln_{}", scale),
LookupOp::Cos { scale } => format!("cos_{}", scale),
LookupOp::ACos { scale } => format!("acos_{}", scale),
LookupOp::Cosh { scale } => format!("cosh_{}", scale),
LookupOp::ACosh { scale } => format!("acosh_{}", scale),
LookupOp::Sin { scale } => format!("sin_{}", scale),
LookupOp::ASin { scale } => format!("asin_{}", scale),
LookupOp::Sinh { scale } => format!("sinh_{}", scale),
LookupOp::ASinh { scale } => format!("asinh_{}", scale),
LookupOp::Tan { scale } => format!("tan_{}", scale),
LookupOp::ATan { scale } => format!("atan_{}", scale),
LookupOp::ATanh { scale } => format!("atanh_{}", scale),
LookupOp::Tanh { scale } => format!("tanh_{}", scale),
LookupOp::HardSwish { scale } => format!("hardswish_{}", scale),
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
@@ -223,6 +273,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
LookupOp::HardSwish { scale } => {
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
}
}?;
let output = res.map(|x| i128_to_felt(x));
@@ -276,6 +329,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
LookupOp::ASin { scale } => format!("ASIN(scale={})", scale),
LookupOp::Sinh { scale } => format!("SINH(scale={})", scale),
LookupOp::ASinh { scale } => format!("ASINH(scale={})", scale),
LookupOp::HardSwish { scale } => format!("HARDSWISH(scale={})", scale),
}
}

View File

@@ -27,12 +27,14 @@ pub mod region;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
@@ -98,7 +100,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
}
}
impl<F: PrimeField + TensorType + PartialOrd> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -165,7 +167,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(self.scale)
}
@@ -226,7 +228,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(0)
}
@@ -256,7 +258,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
pub quantized_values: Tensor<F>,
///
@@ -266,7 +268,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -293,8 +295,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for Constant<F>
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self

View File

@@ -83,10 +83,20 @@ pub enum PolyOp {
And,
Or,
Xor,
Trilu {
upper: bool,
k: i32,
},
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for PolyOp
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -114,7 +124,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Add => "ADD".into(),
PolyOp::Mult => "MULT".into(),
PolyOp::Sub => "SUB".into(),
PolyOp::Sum { .. } => "SUM".into(),
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { .. } => "CONV".into(),
@@ -128,6 +138,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::And => "AND".into(),
PolyOp::Or => "OR".into(),
PolyOp::Xor => "XOR".into(),
PolyOp::Trilu { upper, k } => format!("TRILU (upper={}, k={})", upper, k),
}
}
@@ -265,6 +276,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
}?;
Ok(ForwardResult { output: res })
@@ -384,6 +396,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Slice { axis, start, end } => {
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
}
PolyOp::Trilu { upper, k } => {
layouts::trilu(config, region, values[..].try_into()?, k, upper)?
}
}))
}

View File

@@ -2,24 +2,28 @@ use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI128 as AtomicInt;
use std::{
cell::RefCell,
collections::HashSet,
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
@@ -120,12 +124,11 @@ impl From<Box<dyn std::error::Error>> for RegionError {
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
@@ -133,13 +136,34 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
throw_range_check_error: bool,
witness_gen: bool,
assigned_constants: ConstantsMap<F>,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
#[cfg(not(target_arch = "wasm32"))]
///
pub fn increment_total_constants(&mut self, n: usize) {
self.total_constants += n;
pub fn debug_report(&self) {
log::debug!(
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
self.row().to_string().blue(),
self.linear_coord().to_string().yellow(),
self.total_constants().to_string().red(),
self.max_lookup_inputs().to_string().green(),
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green());
}
///
pub fn assigned_constants(&self) -> &ConstantsMap<F> {
&self.assigned_constants
}
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants);
}
///
@@ -163,8 +187,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
pub fn witness_gen(&self) -> bool {
self.witness_gen
}
/// Create a new region context
@@ -177,7 +201,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -185,9 +208,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: true,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_with_constants(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols);
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
@@ -202,7 +238,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
@@ -210,16 +245,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -228,7 +260,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -236,17 +267,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy_with_constants(
pub fn new_dummy_with_linear_coord(
row: usize,
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
witness_gen: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -254,7 +285,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -262,7 +292,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
@@ -312,29 +343,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(), RegionError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
let starting_constants = constants.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_constants(
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
starting_constants,
self.num_inner_cols,
self.throw_range_check_error,
self.witness_gen,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -343,10 +372,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
constants.fetch_add(
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
@@ -362,11 +387,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
@@ -410,6 +437,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
Ok(())
}
@@ -435,7 +470,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
}
let range_size = (range.1 - range.0).abs();
@@ -477,7 +512,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Get the total number of constants
pub fn total_constants(&self) -> usize {
self.total_constants
self.assigned_constants.len()
}
/// Get the dynamic lookup index
@@ -525,26 +560,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.max_range_size
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;
if let Some(region) = &self.region {
let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?;
Ok(cell.into())
} else {
Ok(value.into())
}
}
/// Assign a valtensor to a vartensor
pub fn assign(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(&mut region.borrow_mut(), self.linear_coord, values)
var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -560,14 +593,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -594,13 +631,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
} else {
self.total_constants += values.num_constants();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize;
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
}
@@ -615,24 +659,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len, total_assigned_constants) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((res, len))
} else {
let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication(
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((values.clone(), len))
}
}
@@ -699,9 +743,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
Ok(())
}
/// increment constants
pub fn increment_constants(&mut self, n: usize) {
self.total_constants += n
}
}

View File

@@ -6,7 +6,7 @@ use halo2_proofs::{
circuit::{Layouter, Value},
plonk::{ConstraintSystem, Expression, TableColumn},
};
use log::warn;
use log::{debug, warn};
use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
use crate::{
@@ -27,6 +27,13 @@ pub const RANGE_MULTIPLIER: i128 = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
lazy_static::lazy_static! {
/// an optional directory to read and write the lookup table cache
static ref LOOKUP_CACHE: Option<std::path::PathBuf> = std::env::var("LOOKUP_CACHE")
.ok()
.map(std::path::PathBuf::from);
}
#[derive(Debug, Clone)]
///
pub struct SelectorConstructor<F: PrimeField> {
@@ -98,7 +105,7 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
@@ -138,7 +145,7 @@ pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
(range_len / (col_size as i128)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -152,7 +159,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
// number of cols needed to store the range
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
log::debug!("table range: {:?}", range);
debug!("table range: {:?}", range);
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
let mut cols = vec![];
@@ -205,8 +212,46 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
let inputs = Tensor::from(smallest..=largest)
.par_enum_map(|_, x| Ok::<_, crate::tensor::TensorError>(i128_to_felt(x)))?;
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
Ok((inputs, evals.output))
};
let (inputs, evals) = if let Some(cache) = &*LOOKUP_CACHE {
let cache_path = cache.join(self.nonlinearity.as_path());
let input_path = cache_path.join("inputs");
let output_path = cache_path.join("outputs");
if cache_path.exists() {
log::info!("Loading lookup table from cache: {:?}", cache_path);
let (input_cache, output_cache) =
(Tensor::load(&input_path)?, Tensor::load(&output_path)?);
(input_cache, output_cache)
} else {
log::info!(
"Generating lookup table and saving to cache: {:?}",
cache_path
);
// mkdir -p cache_path
std::fs::create_dir_all(&cache_path)?;
let (inputs, evals) = gen_table()?;
inputs.save(&input_path)?;
evals.save(&output_path)?;
(inputs, evals)
}
} else {
log::info!(
"Generating lookup table {} without cache",
self.nonlinearity.as_path()
);
gen_table()?
};
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
@@ -238,7 +283,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
)?;
}
let output = evals.output[row_offset];
let output = evals[row_offset];
table.assign_cell(
|| format!("nl_o_col row {}", row_offset),
@@ -275,7 +320,12 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// as path
pub fn as_path(&self) -> String {
format!("rangecheck_{}_{}", self.range.0, self.range.1)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
@@ -303,7 +353,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
@@ -353,7 +403,31 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let inputs: Tensor<F> = if let Some(cache) = &*LOOKUP_CACHE {
let cache_path = cache.join(self.as_path());
let input_path = cache_path.join("inputs");
if cache_path.exists() {
log::info!("Loading range check table from cache: {:?}", cache_path);
Tensor::load(&input_path)?
} else {
log::info!(
"Generating range check table and saving to cache: {:?}",
cache_path
);
// mkdir -p cache_path
std::fs::create_dir_all(&cache_path)?;
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
inputs.save(&input_path)?;
inputs
}
} else {
log::info!("Generating range check {} without cache", self.as_path());
Tensor::from(smallest..=largest).map(|x| i128_to_felt(x))
};
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;

View File

@@ -245,7 +245,13 @@ mod matmul_col_overflow {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow_double_col {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -324,46 +330,46 @@ mod matmul_col_ultra_overflow_double_col {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
}
}
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod matmul_col_ultra_overflow {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -441,39 +447,33 @@ mod matmul_col_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
MatmulCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
}
}
@@ -1145,7 +1145,15 @@ mod conv {
#[cfg(test)]
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_col_ultra_overflow {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::{
kzg::strategy::SingleStrategy,
kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
},
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -1243,39 +1251,33 @@ mod conv_col_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
}
}
@@ -1283,7 +1285,13 @@ mod conv_col_ultra_overflow {
// not wasm 32 unknown
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
mod conv_relu_col_ultra_overflow {
use halo2_proofs::poly::commitment::{Params, ParamsProver};
use halo2_proofs::poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use super::*;
@@ -1396,39 +1404,33 @@ mod conv_relu_col_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ConvCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
// use safe mode to verify that the proof is correct
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
}
}
@@ -1909,6 +1911,8 @@ mod add_with_overflow {
#[cfg(test)]
mod add_with_overflow_and_poseidon {
use std::collections::HashMap;
use halo2curves::bn256::Fr;
use crate::circuit::modules::{
@@ -1967,8 +1971,10 @@ mod add_with_overflow_and_poseidon {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
PoseidonChip::new(config.poseidon.clone());
let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?;
let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?;
let assigned_inputs_a =
poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0, &mut HashMap::new())?;
let assigned_inputs_b =
poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1, &mut HashMap::new())?;
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
@@ -2417,8 +2423,13 @@ mod lookup_ultra_overflow {
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, ConstraintSystem, Error},
poly::commitment::{Params, ParamsProver},
poly::kzg::{
commitment::KZGCommitmentScheme,
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy,
},
};
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
#[derive(Clone)]
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -2497,38 +2508,32 @@ mod lookup_ultra_overflow {
let pk = crate::pfsys::create_keys::<
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
F,
ReLUCircuit<F>,
>(&circuit, &params, true)
.unwrap();
let prover = crate::pfsys::create_proof_circuit_kzg(
let prover = crate::pfsys::create_proof_circuit::<
KZGCommitmentScheme<_>,
_,
ProverSHPLONK<_>,
VerifierSHPLONK<_>,
SingleStrategy<_>,
_,
EvmTranscript<_, _, _, _>,
EvmTranscript<_, _, _, _>,
>(
circuit.clone(),
vec![],
&params,
None,
&pk,
crate::pfsys::TranscriptType::EVM,
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(&params),
// use safe mode to verify that the proof is correct
CheckMode::SAFE,
crate::Commitments::KZG,
crate::pfsys::TranscriptType::EVM,
None,
None,
);
assert!(prover.is_ok());
let proof = prover.unwrap();
let strategy =
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
let vk = pk.get_vk();
let result = crate::pfsys::verify_proof_circuit_kzg(
params.verifier_params(),
proof,
vk,
strategy,
params.n(),
);
assert!(result.is_ok());
}
}

View File

@@ -13,7 +13,7 @@ use std::path::PathBuf;
use std::{error::Error, str::FromStr};
use tosubcommand::{ToFlags, ToSubcommand};
use crate::{pfsys::ProofType, RunArgs};
use crate::{pfsys::ProofType, Commitments, RunArgs};
use crate::circuit::CheckMode;
#[cfg(not(target_arch = "wasm32"))]
@@ -90,6 +90,8 @@ pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
/// Default only check for range check rebase
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
/// Default commitment
pub const DEFAULT_COMMITMENT: &str = "kzg";
#[cfg(feature = "python-bindings")]
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
@@ -294,21 +296,6 @@ pub enum Commands {
args: RunArgs,
},
#[cfg(feature = "render")]
/// Renders the model circuit to a .png file. For an overview of how to interpret these plots, see https://zcash.github.io/halo2/user/dev-tools.html
#[command(arg_required_else_help = true)]
RenderCircuit {
/// The path to the .onnx model file
#[arg(short = 'M', long)]
model: PathBuf,
/// Path to save the .png circuit render
#[arg(short = 'O', long)]
output: PathBuf,
/// proving arguments
#[clap(flatten)]
args: RunArgs,
},
/// Generates the witness from an input file.
GenWitness {
/// The path to the .json data file
@@ -387,6 +374,9 @@ pub enum Commands {
/// number of logrows to use for srs
#[arg(long)]
logrows: usize,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
},
#[cfg(not(target_arch = "wasm32"))]
@@ -402,6 +392,9 @@ pub enum Commands {
/// Number of logrows to use for srs. Overrides settings_path if specified.
#[arg(long, default_value = None)]
logrows: Option<u32>,
/// Commitment used
#[arg(long, default_value = None)]
commitment: Option<Commitments>,
},
/// Loads model and input and runs mock prover (for testing)
Mock {
@@ -449,6 +442,9 @@ pub enum Commands {
/// compress selectors
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
Aggregate {
@@ -481,6 +477,9 @@ pub enum Commands {
/// whether the accumulated proofs are segments of a larger circuit
#[arg(long, default_value = DEFAULT_SPLIT)]
split_proofs: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
@@ -515,31 +514,6 @@ pub enum Commands {
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Fuzzes the proof pipeline with random inputs, random parameters, and random keys
Fuzz {
/// The path to the .json witness file (generated using the gen-witness command)
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
witness: PathBuf,
/// The path to the compiled model file (generated using the compile-circuit command)
#[arg(short = 'M', long)]
compiled_circuit: PathBuf,
#[arg(
long,
require_equals = true,
num_args = 0..=1,
default_value_t = TranscriptType::default(),
value_enum
)]
transcript: TranscriptType,
/// number of fuzz iterations
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
num_runs: usize,
/// compress selectors
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
disable_selector_compression: bool,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
#[command(arg_required_else_help = true)]
@@ -741,12 +715,18 @@ pub enum Commands {
/// The path to the verification key file (generated using the setup-aggregate command)
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
vk_path: PathBuf,
/// reduced srs
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
reduced_srs: bool,
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
#[arg(long)]
srs_path: Option<PathBuf>,
/// logrows used for aggregation circuit
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
logrows: u32,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl

View File

@@ -455,7 +455,7 @@ pub async fn verify_proof_with_data_attestation(
for val in flattened_instances.clone() {
let bytes = val.to_repr();
let u = U256::from_little_endian(bytes.as_slice());
let u = U256::from_little_endian(bytes.inner());
public_inputs.push(u);
}

File diff suppressed because it is too large Load Diff

View File

@@ -15,7 +15,7 @@ use colored_json::ToColoredJson;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::plonk::VerifyingKey;
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2_proofs::poly::commitment::CommitmentScheme;
pub use input::DataSource;
use itertools::Itertools;
use tosubcommand::ToFlags;
@@ -26,6 +26,7 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
@@ -37,8 +38,8 @@ use halo2_proofs::{
circuit::Layouter,
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -126,7 +127,7 @@ pub enum GraphError {
#[error("failed to rescale inputs for {0}")]
RescalingError(String),
/// Error when attempting to load a model
#[error("failed to load model")]
#[error("failed to load")]
ModelLoad,
/// Packing exponent is too large
#[error("largest packing exponent exceeds max. try reducing the scale")]
@@ -155,7 +156,7 @@ use std::cell::RefCell;
thread_local!(
/// This is a global variable that holds the settings for the graph
/// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = RefCell::new(None)
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = const { RefCell::new(None) }
);
/// Result from a forward pass
@@ -284,20 +285,20 @@ impl GraphWitness {
}
///
pub fn get_kzg_commitments(&self) -> Vec<G1Affine> {
pub fn get_polycommitments(&self) -> Vec<G1Affine> {
let mut commitments = vec![];
if let Some(processed_inputs) = &self.processed_inputs {
if let Some(commits) = &processed_inputs.kzg_commit {
if let Some(commits) = &processed_inputs.polycommit {
commitments.extend(commits.iter().flatten());
}
}
if let Some(processed_params) = &self.processed_params {
if let Some(commits) = &processed_params.kzg_commit {
if let Some(commits) = &processed_params.polycommit {
commitments.extend(commits.iter().flatten());
}
}
if let Some(processed_outputs) = &self.processed_outputs {
if let Some(commits) = &processed_outputs.kzg_commit {
if let Some(commits) = &processed_outputs.polycommit {
commitments.extend(commits.iter().flatten());
}
}
@@ -318,7 +319,7 @@ impl GraphWitness {
/// Load the model input from a file
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
let file = std::fs::File::open(path.clone())
.map_err(|_| format!("failed to load model at {}", path.display()))?;
.map_err(|_| format!("failed to load {}", path.display()))?;
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
serde_json::from_reader(reader).map_err(|e| e.into())
@@ -387,8 +388,8 @@ impl ToPyObject for GraphWitness {
if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash {
insert_poseidon_hash_pydict(dict_inputs, processed_inputs_poseidon_hash).unwrap();
}
if let Some(processed_inputs_kzg_commit) = &processed_inputs.kzg_commit {
insert_kzg_commit_pydict(dict_inputs, processed_inputs_kzg_commit).unwrap();
if let Some(processed_inputs_polycommit) = &processed_inputs.polycommit {
insert_polycommit_pydict(dict_inputs, processed_inputs_polycommit).unwrap();
}
dict.set_item("processed_inputs", dict_inputs).unwrap();
@@ -398,8 +399,8 @@ impl ToPyObject for GraphWitness {
if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash {
insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash).unwrap();
}
if let Some(processed_params_kzg_commit) = &processed_params.kzg_commit {
insert_kzg_commit_pydict(dict_inputs, processed_params_kzg_commit).unwrap();
if let Some(processed_params_polycommit) = &processed_params.polycommit {
insert_polycommit_pydict(dict_inputs, processed_params_polycommit).unwrap();
}
dict.set_item("processed_params", dict_params).unwrap();
@@ -409,8 +410,8 @@ impl ToPyObject for GraphWitness {
if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash {
insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash).unwrap();
}
if let Some(processed_outputs_kzg_commit) = &processed_outputs.kzg_commit {
insert_kzg_commit_pydict(dict_inputs, processed_outputs_kzg_commit).unwrap();
if let Some(processed_outputs_polycommit) = &processed_outputs.polycommit {
insert_polycommit_pydict(dict_inputs, processed_outputs_polycommit).unwrap();
}
dict.set_item("processed_outputs", dict_outputs).unwrap();
@@ -429,13 +430,13 @@ fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Resu
}
#[cfg(feature = "python-bindings")]
fn insert_kzg_commit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
use crate::python::PyG1Affine;
let poseidon_hash: Vec<Vec<PyG1Affine>> = commits
.iter()
.map(|c| c.iter().map(|x| PyG1Affine::from(*x)).collect())
.collect();
pydict.set_item("kzg_commit", poseidon_hash)?;
pydict.set_item("polycommit", poseidon_hash)?;
Ok(())
}
@@ -503,7 +504,9 @@ impl GraphSettings {
}
fn constants_logrows(&self) -> u32 {
(self.total_const_size as f64).log2().ceil() as u32
(self.total_const_size as f64 / self.run_args.num_inner_cols as f64)
.log2()
.ceil() as u32
}
/// calculate the total number of instances
@@ -590,7 +593,7 @@ impl GraphSettings {
|| self.run_args.param_visibility.is_hashed()
}
/// requires dynamic lookup
/// requires dynamic lookup
pub fn requires_dynamic_lookup(&self) -> bool {
self.num_dynamic_lookups > 0
}
@@ -601,10 +604,10 @@ impl GraphSettings {
}
/// any kzg visibility
pub fn module_requires_kzg(&self) -> bool {
self.run_args.input_visibility.is_kzgcommit()
|| self.run_args.output_visibility.is_kzgcommit()
|| self.run_args.param_visibility.is_kzgcommit()
pub fn module_requires_polycommit(&self) -> bool {
self.run_args.input_visibility.is_polycommit()
|| self.run_args.output_visibility.is_polycommit()
|| self.run_args.param_visibility.is_polycommit()
}
}
@@ -1049,16 +1052,10 @@ impl GraphCircuit {
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let mut margin = (
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
if lookup_safety_margin == 1 {
margin.0 += 4;
margin.1 += 4;
}
margin
)
}
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
@@ -1171,17 +1168,6 @@ impl GraphCircuit {
.settings()
.clone();
// recalculate the logrows if there has been overflow on the constants
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.constants_logrows(),
);
// recalculate the logrows if there has been overflow for the model constraints
settings_mut.run_args.logrows = std::cmp::max(
settings_mut.run_args.logrows,
settings_mut.model_constraint_logrows(),
);
debug!(
"setting lookup_range to: {:?}, setting logrows to: {}",
self.settings().run_args.lookup_range,
@@ -1248,12 +1234,12 @@ impl GraphCircuit {
}
/// Runs the forward pass of the model / graph of computations and any associated hashing.
pub fn forward(
pub fn forward<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
&self,
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&ParamsKZG<Bn256>>,
throw_range_check_error: bool,
srs: Option<&Scheme::ParamsProver>,
witness_gen: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1269,7 +1255,8 @@ impl GraphCircuit {
for outlet in &module_outlets {
module_inputs.push(inputs[*outlet].clone());
}
let res = GraphModules::forward(&module_inputs, &visibility.input, vk, srs)?;
let res =
GraphModules::forward::<Scheme>(&module_inputs, &visibility.input, vk, srs)?;
processed_inputs = Some(res.clone());
let module_results = res.get_result(visibility.input.clone());
@@ -1277,7 +1264,12 @@ impl GraphCircuit {
inputs[*outlet] = Tensor::from(module_results[i].clone().into_iter());
}
} else {
processed_inputs = Some(GraphModules::forward(inputs, &visibility.input, vk, srs)?);
processed_inputs = Some(GraphModules::forward::<Scheme>(
inputs,
&visibility.input,
vk,
srs,
)?);
}
}
@@ -1285,7 +1277,7 @@ impl GraphCircuit {
let params = self.model().get_all_params();
if !params.is_empty() {
let flattened_params = Tensor::new(Some(&params), &[params.len()])?.combine()?;
processed_params = Some(GraphModules::forward(
processed_params = Some(GraphModules::forward::<Scheme>(
&[flattened_params],
&visibility.params,
vk,
@@ -1296,7 +1288,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
.forward(inputs, &self.settings().run_args, witness_gen)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1305,7 +1297,8 @@ impl GraphCircuit {
for outlet in &module_outlets {
module_inputs.push(model_results.outputs[*outlet].clone());
}
let res = GraphModules::forward(&module_inputs, &visibility.output, vk, srs)?;
let res =
GraphModules::forward::<Scheme>(&module_inputs, &visibility.output, vk, srs)?;
processed_outputs = Some(res.clone());
let module_results = res.get_result(visibility.output.clone());
@@ -1314,7 +1307,7 @@ impl GraphCircuit {
Tensor::from(module_results[i].clone().into_iter());
}
} else {
processed_outputs = Some(GraphModules::forward(
processed_outputs = Some(GraphModules::forward::<Scheme>(
&model_results.outputs,
&visibility.output,
vk,
@@ -1458,7 +1451,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// The configuration for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1468,7 +1462,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
///
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),
@@ -1610,6 +1605,8 @@ impl Circuit<Fp> for GraphCircuit {
let output_vis = &self.settings().run_args.output_visibility;
let mut graph_modules = GraphModules::new();
let mut constants = ConstantsMap::new();
let mut config = config.clone();
let mut inputs = self
@@ -1655,6 +1652,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut input_outlets,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace inputs with the outlets
for (i, outlet) in outlets.iter().enumerate() {
@@ -1667,6 +1665,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut inputs,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
}
@@ -1703,6 +1702,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut flattened_params,
param_visibility,
&mut instance_offset,
&mut constants,
)?;
let shapes = self.model().const_shapes();
@@ -1731,6 +1731,7 @@ impl Circuit<Fp> for GraphCircuit {
&inputs,
&mut vars,
&outputs,
&mut constants,
)
.map_err(|e| {
log::error!("{}", e);
@@ -1755,6 +1756,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut output_outlets,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace outputs with the outlets
@@ -1768,6 +1770,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut outputs,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
}

View File

@@ -5,6 +5,7 @@ use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::table::Range;
use crate::circuit::Input;
@@ -404,7 +405,7 @@ impl ParsedNodes {
.get(input)
.ok_or(GraphError::MissingNode(*input))?;
let input_dims = node.out_dims();
let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?;
let input_dim = input_dims.first().ok_or(GraphError::MissingNode(*input))?;
inputs.push(input_dim.clone());
}
@@ -514,21 +515,24 @@ impl Model {
instance_shapes.len().to_string().blue(),
"instances".blue()
);
// this is the total number of variables we will need to allocate
// for the circuit
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
let len = shape.iter().product();
let mut t: ValTensor<Fp> = (0..len)
.map(|_| {
if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
t.reshape(shape)?;
Ok(t)
})
@@ -577,13 +581,13 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
Ok(res.into())
}
@@ -799,13 +803,18 @@ impl Model {
let input_state_idx = input_state_idx(&input_mappings);
let mut output_mappings = vec![];
for mapping in b.output_mapping.iter() {
for (i, mapping) in b.output_mapping.iter().enumerate() {
let mut mappings = vec![];
if let Some(outlet) = mapping.last_value_slot {
mappings.push(OutputMapping::Single {
outlet,
is_state: mapping.state,
});
} else if mapping.state {
mappings.push(OutputMapping::Single {
outlet: i,
is_state: mapping.state,
});
}
if let Some(last) = mapping.scan {
mappings.push(OutputMapping::Stacked {
@@ -814,6 +823,7 @@ impl Model {
is_state: false,
});
}
output_mappings.push(mappings);
}
@@ -970,7 +980,7 @@ impl Model {
let (model, _) = Model::load_onnx_using_tract(
&mut std::fs::File::open(model_path)
.map_err(|_| format!("failed to load model at {}", model_path.display()))?,
.map_err(|_| format!("failed to load {}", model_path.display()))?,
run_args,
)?;
@@ -1006,7 +1016,7 @@ impl Model {
) -> Result<Self, Box<dyn Error>> {
Model::new(
&mut std::fs::File::open(model)
.map_err(|_| format!("failed to load model at {}", model.display()))?,
.map_err(|_| format!("failed to load {}", model.display()))?,
run_args,
)
}
@@ -1071,6 +1081,8 @@ impl Model {
/// * `layouter` - Halo2 Layouter.
/// * `inputs` - The values to feed into the circuit.
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1079,6 +1091,7 @@ impl Model {
inputs: &[ValTensor<Fp>],
vars: &mut ModelVars<Fp>,
witnessed_outputs: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
info!("model layout...");
@@ -1104,14 +1117,12 @@ impl Model {
config.base.layout_tables(layouter)?;
config.base.layout_range_checks(layouter)?;
let mut num_rows = 0;
let mut linear_coord = 0;
let mut total_const_size = 0;
let original_constants = constants.clone();
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols);
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1157,29 +1168,17 @@ impl Model {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
} else if !run_args.output_visibility.is_private() {
for output in &outputs {
thread_safe_region.increment_total_constants(output.num_constants());
}
}
num_rows = thread_safe_region.row();
linear_coord = thread_safe_region.linear_coord();
total_const_size = thread_safe_region.total_constants();
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
thread_safe_region.debug_report();
*constants = thread_safe_region.assigned_constants().clone();
Ok(outputs)
},
)?;
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
"rows".blue(),
linear_coord.to_string().yellow(),
total_const_size.to_string().red()
);
)?;
let duration = start_time.elapsed();
trace!("model layout took: {:?}", duration);
@@ -1213,16 +1212,10 @@ impl Model {
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
);
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
@@ -1277,8 +1270,8 @@ impl Model {
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, inputs, model.graph.inputs
"{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
num_iter, inputs, model.graph.inputs, model.graph.outputs
);
let mut full_results: Vec<ValTensor<Fp>> = vec![];
@@ -1310,6 +1303,7 @@ impl Model {
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
let mut outlets = BTreeMap::new();
let mut stacked_outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res) {
for mapping in mappings {
@@ -1322,25 +1316,42 @@ impl Model {
let stacked_res = full_results[*outlet]
.clone()
.concat_axis(outlet_res.clone(), axis)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
stacked_outlets.insert(outlet, stacked_res);
}
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
full_results = outlets.into_values().collect_vec();
// now extend with stacked elements
let mut pre_stacked_outlets = outlets.clone();
pre_stacked_outlets.extend(stacked_outlets);
let outlets = outlets.into_values().collect_vec();
full_results = pre_stacked_outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
assert_eq!(
input_states.len(),
output_states.len(),
"input and output states must be the same length, got {:?} and {:?}",
input_mappings,
output_mappings
);
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
values[*input_idx] = full_results[output_idx].clone();
assert_eq!(
values[*input_idx].dims(),
outlets[output_idx].dims(),
"input and output dims must be the same, got {:?} and {:?}",
values[*input_idx].dims(),
outlets[output_idx].dims()
);
values[*input_idx] = outlets[output_idx].clone();
}
}
@@ -1380,7 +1391,7 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1399,29 +1410,31 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let default_value = if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
comparator.reshape(output.dims())?;
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
@@ -1432,7 +1445,7 @@ impl Model {
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
region.update_constants(output.create_constants_map());
}
}
@@ -1441,14 +1454,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
"rows".blue(),
region.linear_coord().to_string().yellow(),
region.total_constants().to_string().red()
);
region.debug_report();
let outputs = outputs
.iter()

View File

@@ -1,12 +1,13 @@
use crate::circuit::modules::kzg::{KZGChip, KZGConfig};
use crate::circuit::modules::polycommit::{PolyCommitChip, PolyCommitConfig};
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use crate::circuit::modules::Module;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor};
use halo2_proofs::circuit::Layouter;
use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey};
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
use halo2curves::bn256::{Bn256, Fr as Fp, G1Affine};
use halo2_proofs::poly::commitment::CommitmentScheme;
use halo2curves::bn256::{Fr as Fp, G1Affine};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
@@ -14,9 +15,6 @@ use super::{VarVisibility, Visibility};
/// poseidon len to hash in tree
pub const POSEIDON_LEN_GRAPH: usize = 32;
/// ElGamal number of instances
pub const ELGAMAL_INSTANCES: usize = 4;
/// Poseidon number of instancess
pub const POSEIDON_INSTANCES: usize = 1;
@@ -29,8 +27,8 @@ pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;
///
#[derive(Clone, Debug, Default)]
pub struct ModuleConfigs {
/// KZG
kzg: Vec<KZGConfig>,
/// PolyCommit
polycommit: Vec<PolyCommitConfig>,
/// Poseidon
poseidon: Option<ModulePoseidonConfig>,
/// Instance
@@ -46,8 +44,10 @@ impl ModuleConfigs {
) -> Self {
let mut config = Self::default();
for size in module_size.kzg {
config.kzg.push(KZGChip::configure(cs, (logrows, size)));
for size in module_size.polycommit {
config
.polycommit
.push(PolyCommitChip::configure(cs, (logrows, size)));
}
config
@@ -94,8 +94,8 @@ impl ModuleConfigs {
pub struct ModuleForwardResult {
/// The inputs of the forward pass for poseidon
pub poseidon_hash: Option<Vec<Fp>>,
/// The outputs of the forward pass for KZG
pub kzg_commit: Option<Vec<Vec<G1Affine>>>,
/// The outputs of the forward pass for PolyCommit
pub polycommit: Option<Vec<Vec<G1Affine>>>,
}
impl ModuleForwardResult {
@@ -126,7 +126,7 @@ impl ModuleForwardResult {
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
///
pub struct ModuleSizes {
kzg: Vec<usize>,
polycommit: Vec<usize>,
poseidon: (usize, Vec<usize>),
}
@@ -134,7 +134,7 @@ impl ModuleSizes {
/// Create new module sizes
pub fn new() -> Self {
ModuleSizes {
kzg: vec![],
polycommit: vec![],
poseidon: (
0,
vec![0; crate::circuit::modules::poseidon::NUM_INSTANCE_COLUMNS],
@@ -156,17 +156,17 @@ impl ModuleSizes {
/// Graph modules that can process inputs, params and outputs beyond the basic operations
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct GraphModules {
kzg_idx: usize,
polycommit_idx: usize,
}
impl GraphModules {
///
pub fn new() -> GraphModules {
GraphModules { kzg_idx: 0 }
GraphModules { polycommit_idx: 0 }
}
///
pub fn reset_index(&mut self) {
self.kzg_idx = 0;
self.polycommit_idx = 0;
}
}
@@ -179,9 +179,9 @@ impl GraphModules {
for shape in shapes {
let total_len = shape.iter().product::<usize>();
if total_len > 0 {
if visibility.is_kzgcommit() {
// 1 constraint for each kzg commitment
sizes.kzg.push(total_len);
if visibility.is_polycommit() {
// 1 constraint for each polycommit commitment
sizes.polycommit.push(total_len);
} else if visibility.is_hashed() {
sizes.poseidon.0 += ModulePoseidon::num_rows(total_len);
// 1 constraints for hash
@@ -212,12 +212,13 @@ impl GraphModules {
layouter: &mut impl Layouter<Fp>,
x: &mut Vec<ValTensor<Fp>>,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
// reserve module 0 for ... modules
// hash the input and replace the constrained cells in the input
let cloned_x = (*x).clone();
x[0] = module
.layout(layouter, &cloned_x, instance_offset.to_owned())
.layout(layouter, &cloned_x, instance_offset.to_owned(), constants)
.unwrap();
for inc in module.instance_increment_input().iter() {
// increment the instance offset to make way for future module layouts
@@ -235,23 +236,24 @@ impl GraphModules {
values: &mut [ValTensor<Fp>],
element_visibility: &Visibility,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
if element_visibility.is_kzgcommit() && !values.is_empty() {
if element_visibility.is_polycommit() && !values.is_empty() {
// concat values and sk to get the inputs
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
// layout the module
inputs.iter_mut().for_each(|x| {
// create the module
let chip = KZGChip::new(configs.kzg[self.kzg_idx].clone());
// reserve module 2 onwards for kzg modules
let module_offset = 3 + self.kzg_idx;
let chip = PolyCommitChip::new(configs.polycommit[self.polycommit_idx].clone());
// reserve module 2 onwards for polycommit modules
let module_offset = 3 + self.polycommit_idx;
layouter
.assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(()))
.unwrap();
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
// increment the current index
self.kzg_idx += 1;
self.polycommit_idx += 1;
});
// replace the inputs with the outputs
@@ -271,7 +273,7 @@ impl GraphModules {
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
// layout the module
inputs.iter_mut().for_each(|x| {
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
});
// replace the inputs with the outputs
values.iter_mut().enumerate().for_each(|(i, x)| {
@@ -288,14 +290,14 @@ impl GraphModules {
}
/// Run forward pass
pub fn forward(
inputs: &[Tensor<Fp>],
pub fn forward<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
inputs: &[Tensor<Scheme::Scalar>],
element_visibility: &Visibility,
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&ParamsKZG<Bn256>>,
srs: Option<&Scheme::ParamsProver>,
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
let mut poseidon_hash = None;
let mut kzg_commit = None;
let mut polycommit = None;
if element_visibility.is_hashed() {
let field_elements = inputs.iter().fold(vec![], |mut acc, x| {
@@ -306,11 +308,11 @@ impl GraphModules {
poseidon_hash = Some(field_elements);
}
if element_visibility.is_kzgcommit() {
if element_visibility.is_polycommit() {
if let Some(vk) = vk {
if let Some(srs) = srs {
let commitments = inputs.iter().fold(vec![], |mut acc, x| {
let res = KZGChip::commit(
let res = PolyCommitChip::commit::<Scheme>(
x.to_vec(),
vk.cs().degree() as u32,
(vk.cs().blinding_factors() + 1) as u32,
@@ -319,20 +321,20 @@ impl GraphModules {
acc.push(res);
acc
});
kzg_commit = Some(commitments);
polycommit = Some(commitments);
} else {
log::warn!("no srs provided for kzgcommit. processed value will be none");
log::warn!("no srs provided for polycommit. processed value will be none");
}
} else {
log::debug!(
"no verifying key provided for kzgcommit. processed value will be none"
"no verifying key provided for polycommit. processed value will be none"
);
}
}
Ok(ModuleForwardResult {
poseidon_hash,
kzg_commit,
polycommit,
})
}
}

View File

@@ -248,6 +248,8 @@ pub fn new_op_from_onnx(
symbol_values: &SymbolValues,
rebase_frac_zero_constants: bool,
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
use tract_onnx::tract_core::ops::array::Trilu;
use crate::circuit::InputType;
let input_scales = inputs
@@ -363,6 +365,26 @@ pub fn new_op_from_onnx(
SupportedOp::Constant(c)
}
"Trilu" => {
let op = load_op::<Trilu>(node.op(), idx, node.op().name().to_string())?;
let upper = op.upper;
// assert second input is a constant
let diagonal = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
let raw_values = &c.raw_values;
if raw_values.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(idx, "trilu".to_string())));
}
raw_values[0] as i32
} else {
return Err("we only support constant inputs for trilu diagonal".into());
};
SupportedOp::Linear(PolyOp::Trilu { upper, k: diagonal })
}
"Gather" => {
if inputs.len() != 2 {
return Err(Box::new(GraphError::InvalidDims(idx, "gather".to_string())));
@@ -839,6 +861,9 @@ pub fn new_op_from_onnx(
}
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
"Neg" => SupportedOp::Linear(PolyOp::Neg),
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
}),
@@ -1047,8 +1072,12 @@ pub fn new_op_from_onnx(
}
};
let in_scale = inputs[0].out_scales()[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Softmax {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
})
}

Some files were not shown because too many files have changed in this diff Show More