Compare commits

...

14 Commits

Author | SHA1 | Message | Date
github-actions[bot] | 541565cc4b | ci: update version string in docs | 2024-04-06 22:29:47 +00:00
dante | f78618ec59 | feat: full ND conv and pool (#770) | 2024-04-06 23:29:30 +01:00
Jseam | 0943e534ee | docs: automated sphinx documentation for python bindings (#714) | 2024-04-05 18:33:06 +01:00
  Co-authored-by: dante <45801863+alexander-camuto@users.noreply.github.com>
  Co-authored-by: Ethan Cemer <tylercemer@gmail.com>
  Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
dante | 316a9a3b40 | chore: update tract (#766) | 2024-04-04 18:07:08 +01:00
dante | 5389012b68 | fix: patch large batch ex (#763) | 2024-04-03 02:33:57 +01:00
dante | 48223cca11 | fix: make commitment optional for backwards compat (#762) | 2024-04-03 02:26:50 +01:00
dante | 32c3a5e159 | fix: hold stacked outputs in a separate map | 2024-04-02 21:37:20 +01:00
dante | ff563e93a7 | fix: bump python version (#761) | 2024-04-02 17:08:26 +01:00
dante | 5639d36097 | chore: verify aggr wasm unit test (#760) | 2024-04-01 20:54:20 +01:00
dante | 4ec8d13082 | chore: verify aggr in wasm (#758) | 2024-03-29 23:28:20 +00:00
dante | 12735aefd4 | chore: reduce softmax recip DR (#756) | 2024-03-27 01:14:29 +00:00
dante | 7fe179b8d4 | feat: dictionary of reusable constants (#754) | 2024-03-26 13:12:09 +00:00
Ethan Cemer | 3be988a6a0 | fix: use pnpm in build script for in-browser-evm-verifier (#752) | 2024-03-25 23:23:02 +00:00
dante | 3abb3aff56 | feat: make selector polynomials optional (#753) | 2024-03-22 09:28:28 +00:00
69 changed files with 7876 additions and 4311 deletions

View File

@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -30,13 +30,13 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
export PATH=$PATH:$PWD/binaryen-version_116/bin
wasm-opt --version
set -e
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
export PATH=$PATH:$PWD/binaryen-version_116/bin
wasm-opt --version
- name: Build wasm files for both web and nodejs compilation targets
run: |
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
@@ -92,7 +92,7 @@ jobs:
const jsonObject = JSONBig.parse(string);
return jsonObject;
}
function serialize(data) { // data is an object // return a Uint8ClampedArray
// Step 1: Stringify the Object with BigInt support
if (typeof data === "object") {
@@ -100,11 +100,11 @@ jobs:
}
// Step 2: Encode the JSON String
const uint8Array = new TextEncoder().encode(data);
// Step 3: Convert to Uint8ClampedArray
return new Uint8ClampedArray(uint8Array.buffer);
}
module.exports = {
deserialize,
serialize
@@ -123,7 +123,7 @@ jobs:
const jsonObject = parse(string);
return jsonObject;
}
export function serialize(data) { // data is an object // return a Uint8ClampedArray
// Step 1: Stringify the Object with BigInt support
if (typeof data === "object") {
@@ -131,7 +131,7 @@ jobs:
}
// Step 2: Encode the JSON String
const uint8Array = new TextEncoder().encode(data);
// Step 3: Convert to Uint8ClampedArray
return new Uint8ClampedArray(uint8Array.buffer);
}

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set pyproject.toml version to match github tag

View File

@@ -25,7 +25,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -70,7 +70,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: ${{ matrix.target }}
- name: Set Cargo.toml version to match github tag
@@ -115,7 +115,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -128,6 +128,7 @@ jobs:
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" Cargo.lock.orig >Cargo.lock
- name: Install required libraries
shell: bash
run: |
@@ -139,6 +140,20 @@ jobs:
target: ${{ matrix.target }}
manylinux: auto
args: --release --out dist --features python-bindings
before-script-linux: |
# If we're running on rhel centos, install needed packages.
if command -v yum &> /dev/null; then
yum update -y && yum install -y perl-core openssl openssl-devel pkgconfig libatomic
# If we're running on i686 we need to symlink libatomic
# in order to build openssl with -latomic flag.
if [[ ! -d "/usr/lib64" ]]; then
ln -s /usr/lib/libatomic.so.1 /usr/lib/libatomic.so
fi
else
# If we're running on debian-based system.
apt update -y && apt-get install -y libssl-dev openssl pkg-config
fi
- name: Install built wheel
if: matrix.target == 'x86_64'
@@ -162,7 +177,7 @@ jobs:
# - uses: actions/checkout@v4
# - uses: actions/setup-python@v4
# with:
# python-version: 3.7
# python-version: 3.12
# - name: Install cross-compilation tools for aarch64
# if: matrix.target == 'aarch64'
@@ -214,7 +229,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
architecture: x64
- name: Set Cargo.toml version to match github tag
@@ -249,7 +264,7 @@ jobs:
apk add py3-pip
pip3 install -U pip
python3 -m venv .venv
source .venv/bin/activate
source .venv/bin/activate
pip3 install ezkl --no-index --find-links /io/dist/ --force-reinstall
python3 -c "import ezkl"
@@ -273,7 +288,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: 3.7
python-version: 3.12
- name: Set Cargo.toml version to match github tag
shell: bash
@@ -345,3 +360,17 @@ jobs:
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ./
doc-publish:
name: Trigger ReadTheDocs Build
runs-on: ubuntu-latest
needs: pypi-publish
steps:
- uses: actions/checkout@v4
- name: Trigger RTDs build
uses: dfm/rtds-action@v1
with:
webhook_url: ${{ secrets.RTDS_WEBHOOK_URL }}
webhook_token: ${{ secrets.RTDS_WEBHOOK_TOKEN }}
commit_ref: ${{ github.ref_name }}

View File

@@ -32,7 +32,7 @@ jobs:
token: ${{ secrets.RELEASE_TOKEN }}
tag_name: ${{ env.EZKL_VERSION }}
build-release-gpu:
build-release-gpu:
name: build-release-gpu
needs: ["create-release"]
runs-on: GPU
@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Checkout repo
@@ -60,16 +60,15 @@ jobs:
- name: Set Cargo.toml version to match github tag
shell: bash
run: |
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.lock.orig >Cargo.lock
mv Cargo.toml Cargo.toml.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.toml.orig >Cargo.toml
mv Cargo.lock Cargo.lock.orig
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.lock.orig >Cargo.lock
- name: Install dependencies
shell: bash
run: |
sudo apt-get update
sudo apt-get update
- name: Build release binary
run: cargo build --release -Z sparse-registry --features icicle
@@ -91,7 +90,6 @@ jobs:
asset_name: ${{ env.ASSET }}
asset_content_type: application/octet-stream
build-release:
name: build-release
needs: ["create-release"]

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -184,12 +184,12 @@ jobs:
wasm32-tests:
runs-on: ubuntu-latest
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -199,7 +199,7 @@ jobs:
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -207,12 +207,12 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -224,12 +224,12 @@ jobs:
mock-proving-tests:
runs-on: non-gpu
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -281,12 +281,12 @@ jobs:
prove-and-verify-evm-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -303,6 +303,8 @@ jobs:
with:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --no-frozen-lockfile
@@ -324,7 +326,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
@@ -352,12 +354,12 @@ jobs:
prove-and-verify-tests:
runs-on: non-gpu
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -365,7 +367,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -429,11 +431,11 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
@@ -458,12 +460,12 @@ jobs:
prove-and-verify-mock-aggr-tests:
runs-on: self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -481,7 +483,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -493,12 +495,12 @@ jobs:
prove-and-verify-aggr-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -510,12 +512,12 @@ jobs:
prove-and-verify-aggr-evm-tests:
runs-on: large-self-hosted
needs: [build, library-tests]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -525,7 +527,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -536,7 +538,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -555,18 +557,20 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
run: sudo apt-get install -y cmake
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
@@ -574,15 +578,15 @@ jobs:
accuracy-measurement-tests:
runs-on: ubuntu-latest-32-cores
# needs: [build, library-tests, docs]
needs: [build, library-tests, docs, python-tests, python-integration-tests]
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.7"
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -590,7 +594,7 @@ jobs:
crate: cargo-nextest
locked: true
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Div rebase
@@ -610,10 +614,10 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -623,11 +627,15 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt; python -m ensurepip --upgrade
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --test-threads 1
# - name: authenticate-kaggle-cli
# shell: bash
# env:
@@ -643,7 +651,5 @@ jobs:
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
- name: NBEATS tutorial
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
- name: Tictactoe tutorials
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
# - name: Postgres tutorials
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1

View File

@@ -14,6 +14,40 @@ jobs:
- uses: actions/checkout@v4
- name: Bump version and push tag
id: tag_version
uses: mathieudutour/github-tag-action@v6.1
uses: mathieudutour/github-tag-action@v6.2
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
- name: Set Cargo.toml version to match github tag for docs
shell: bash
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
mv docs/python/src/conf.py docs/python/src/conf.py.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/src/conf.py.orig >docs/python/src/conf.py
rm docs/python/src/conf.py.orig
mv docs/python/requirements-docs.txt docs/python/requirements-docs.txt.orig
sed "s/0\\.0\\.0/${RELEASE_TAG//v}/" docs/python/requirements-docs.txt.orig >docs/python/requirements-docs.txt
rm docs/python/requirements-docs.txt.orig
- name: Commit files and create tag
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
run: |
git config --local user.email "github-actions[bot]@users.noreply.github.com"
git config --local user.name "github-actions[bot]"
git fetch --tags
git checkout -b release-$RELEASE_TAG
git add .
git commit -m "ci: update version string in docs"
git tag -d $RELEASE_TAG
git tag $RELEASE_TAG
- name: Push changes
uses: ad-m/github-push-action@master
env:
RELEASE_TAG: ${{ steps.tag_version.outputs.new_tag }}
with:
branch: release-${{ steps.tag_version.outputs.new_tag }}
force: true
tags: true
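
Note the re-tagging dance at the end of this new job: the tag-action step has already created $RELEASE_TAG on the pre-bump commit, so the workflow commits the rewritten version strings on a release-$RELEASE_TAG branch, deletes the tag, and re-creates it pointing at that follow-up commit before force-pushing branch and tags together. That way the published tag actually contains the updated docs version strings.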

.gitignore (vendored): 4 changed lines

@@ -48,4 +48,6 @@ node_modules
/dist
timingData.json
!tests/wasm/pk.key
!tests/wasm/vk.key
!tests/wasm/vk.key
docs/python/build
!tests/wasm/vk_aggr.key

.python-version (new file): 1 changed line

@@ -0,0 +1 @@
3.12.1

.readthedocs.yaml (new file): 26 changed lines

@@ -0,0 +1,26 @@
# .readthedocs.yaml
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
version: 2
build:
os: ubuntu-22.04
tools:
python: "3.12"
# Build documentation in the "docs/" directory with Sphinx
sphinx:
configuration: ./docs/python/src/conf.py
# Optionally build your docs in additional formats such as PDF and ePub
# formats:
# - pdf
# - epub
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: ./docs/python/requirements-docs.txt

Cargo.lock (generated): 1718 changed lines. Diff suppressed because it is too large.

View File

@@ -15,14 +15,14 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "main" }
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
"derive_serde",
] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"] }
clap = { version = "4.5.3", features = ["derive"] }
serde = { version = "1.0.126", features = ["derive"], optional = true }
serde_json = { version = "1.0.97", default_features = false, features = [
"float_roundtrip",
@@ -80,7 +80,7 @@ pyo3-asyncio = { version = "0.20.0", features = [
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "681a096f02c9d7d363102d9fb0e446d1710ac2c8", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }
@@ -95,10 +95,10 @@ getrandom = { version = "0.2.8", features = ["js"] }
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
wasm-bindgen-rayon = { version = "1.0", optional = true }
wasm-bindgen-test = "0.3.34"
serde-wasm-bindgen = "0.4"
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
wasm-bindgen-rayon = { version = "1.2.1", optional = true }
wasm-bindgen-test = "0.3.42"
serde-wasm-bindgen = "0.6.5"
wasm-bindgen = { version = "0.2.92", features = ["serde-serialize"] }
console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"
@@ -203,5 +203,9 @@ no-banner = []
[patch.'https://github.com/ingonyama-zk/icicle']
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
[profile.release]
rustflags = ["-C", "relocation-model=pic"]

View File

@@ -70,8 +70,8 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone(), self.kernel.clone(), self.bias.clone()],
Box::new(PolyOp::Conv {
padding: [(0, 0); 2],
stride: (1, 1),
padding: vec![(0, 0)],
stride: vec![1; 2],
}),
)
.unwrap();

View File

@@ -65,9 +65,9 @@ impl Circuit<Fr> for MyCircuit {
&mut region,
&[self.image.clone()],
Box::new(HybridOp::SumPool {
padding: [(0, 0); 2],
stride: (1, 1),
kernel_shape: (2, 2),
padding: vec![(0, 0); 2],
stride: vec![1, 1],
kernel_shape: vec![2, 2],
normalized: false,
}),
)
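
The two benchmark fixes above track the headline change in this range, "feat: full ND conv and pool (#770)": PolyOp::Conv, HybridOp::SumPool, and the renamed HybridOp::MaxPool now take Vec-valued padding, stride, and kernel/pool shapes (one entry per spatial axis) instead of fixed-length 2D arrays and tuples. A hedged sketch of the new constructors; the field names come from the diffs in this compare, while the module paths and surrounding code are assumptions, not verbatim repo code:

    // Sketch: ND op construction after #770. Module paths are assumed.
    use ezkl::circuit::ops::hybrid::HybridOp;
    use ezkl::circuit::ops::poly::PolyOp;

    // 2D convolution: one (before, after) padding pair and one stride per axis.
    let conv2d = PolyOp::Conv {
        padding: vec![(1, 1), (1, 1)],
        stride: vec![2, 2],
    };
    // The Vec-based fields are what make 1D and 3D pooling expressible:
    let sumpool1d = HybridOp::SumPool {
        padding: vec![(0, 0)],
        stride: vec![1],
        kernel_shape: vec![2],
        normalized: false,
    };
    let maxpool3d = HybridOp::MaxPool {
        padding: vec![(0, 0); 3],
        stride: vec![1, 1, 1],
        pool_dims: vec![2, 2, 2],
    };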

View File

@@ -1,3 +1,5 @@
use std::collections::HashMap;
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
@@ -48,7 +50,7 @@ impl Circuit<Fr> for MyCircuit {
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.image.clone()], 0)?;
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
Ok(())
}
}

docs/python/build.sh (new executable file): 2 changed lines

@@ -0,0 +1,2 @@
#!/bin/sh
sphinx-build ./src build

View File

@@ -0,0 +1,4 @@
ezkl==10.3.0
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

docs/python/src/conf.py (new file): 29 changed lines

@@ -0,0 +1,29 @@
import ezkl
project = 'ezkl'
release = '10.3.0'
version = release
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.inheritance_diagram',
'sphinx.ext.autosectionlabel',
'sphinx.ext.napoleon',
'sphinx_rtd_theme',
]
autosummary_generate = True
autosummary_imported_members = True
templates_path = ['_templates']
exclude_patterns = []
# -- Options for HTML output -------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output
html_theme = 'sphinx_rtd_theme'
html_static_path = ['_static']
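
Together with the tag workflow and .readthedocs.yaml earlier in this compare, this conf.py completes the docs pipeline from #714: the release workflow rewrites the 0.0.0 placeholders in conf.py and requirements-docs.txt to the new tag (the 10.3.0 values visible in these new files are evidently that substitution already applied by the bot commit 541565cc4b), the doc-publish job pings the Read the Docs webhook once the PyPI publish succeeds, and Read the Docs then runs sphinx-build over docs/python/src/conf.py with the freshly released ezkl wheel pinned in requirements-docs.txt, so automodule:: ezkl documents the matching bindings.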

docs/python/src/index.rst (new file): 11 changed lines

@@ -0,0 +1,11 @@
.. extension documentation master file, created by
sphinx-quickstart on Mon Jun 19 15:02:05 2023.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
ezkl python bindings
================================================
.. automodule:: ezkl
:members:
:undoc-members:

View File

@@ -203,8 +203,8 @@ where
let mut region = RegionCtx::new(region, 0, NUM_INNER_COLS);
let op = PolyOp::Conv {
padding: [(PADDING, PADDING); 2],
stride: (STRIDE, STRIDE),
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
};
let x = config
.layer_config

View File

@@ -67,6 +67,7 @@
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(10, activation='softmax'))\n",
"model.output_names=['output']\n",
"\n",
"\n",
"# Train the model as you like here (skipped for brevity)\n",

View File

@@ -38,7 +38,7 @@
"import logging\n",
"\n",
"import tensorflow as tf\n",
"from tensorflow.keras.optimizers.legacy import Adam\n",
"from tensorflow.keras.optimizers import Adam\n",
"from tensorflow.keras.layers import *\n",
"from tensorflow.keras.models import Model\n",
"from tensorflow.keras.datasets import mnist\n",
@@ -71,9 +71,11 @@
},
"outputs": [],
"source": [
"opt = Adam()\n",
"ZDIM = 100\n",
"\n",
"opt = Adam()\n",
"\n",
"\n",
"# discriminator\n",
"# 0 if it's fake, 1 if it's real\n",
"x = in1 = Input((28,28))\n",
@@ -114,8 +116,11 @@
"\n",
"gm = Model(in1, x)\n",
"gm.compile('adam', 'mse')\n",
"gm.output_names=['output']\n",
"gm.summary()\n",
"\n",
"opt = Adam()\n",
"\n",
"# GAN\n",
"dm.trainable = False\n",
"x = dm(gm.output)\n",
@@ -415,7 +420,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

File diff suppressed because one or more lines are too long

View File

@@ -349,6 +349,8 @@
"z_log_var = Dense(ZDIM)(x)\n",
"z = Lambda(lambda x: x[0] + K.exp(0.5 * x[1]) * K.random_normal(shape=K.shape(x[0])))([z_mu, z_log_var])\n",
"dec = get_decoder()\n",
"dec.output_names=['output']\n",
"\n",
"out = dec(z)\n",
"\n",
"mse_loss = mse(Reshape((28*28,))(in1), Reshape((28*28,))(out)) * 28 * 28\n",

View File

@@ -61,11 +61,10 @@
"from sklearn.datasets import load_iris\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.ensemble import RandomForestClassifier as Rf\n",
"import sk2torch\n",
"import torch\n",
"import ezkl\n",
"import os\n",
"from torch import nn\n",
"from hummingbird.ml import convert\n",
"\n",
"\n",
"\n",
@@ -77,28 +76,12 @@
"clr.fit(X_train, y_train)\n",
"\n",
"\n",
"trees = []\n",
"for tree in clr.estimators_:\n",
" trees.append(sk2torch.wrap(tree))\n",
"\n",
"\n",
"class RandomForest(nn.Module):\n",
" def __init__(self, trees):\n",
" super(RandomForest, self).__init__()\n",
" self.trees = nn.ModuleList(trees)\n",
"\n",
" def forward(self, x):\n",
" out = self.trees[0](x)\n",
" for tree in self.trees[1:]:\n",
" out += tree(x)\n",
" return out / len(self.trees)\n",
"\n",
"\n",
"torch_rf = RandomForest(trees)\n",
"torch_rf = convert(clr, 'torch')\n",
"# assert predictions from torch are = to sklearn \n",
"diffs = []\n",
"for i in range(len(X_test)):\n",
" torch_pred = torch_rf(torch.tensor(X_test[i].reshape(1, -1)))\n",
" torch_pred = torch_rf.predict(torch.tensor(X_test[i].reshape(1, -1)))\n",
" sk_pred = clr.predict(X_test[i].reshape(1, -1))\n",
" diffs.append(torch_pred[0].round() - sk_pred[0])\n",
"\n",
@@ -134,14 +117,12 @@
"\n",
"# export to onnx format\n",
"\n",
"torch_rf.eval()\n",
"\n",
"# Input to the model\n",
"shape = X_train.shape[1:]\n",
"x = torch.rand(1, *shape, requires_grad=False)\n",
"torch_out = torch_rf(x)\n",
"torch_out = torch_rf.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(torch_rf, # model being run\n",
"torch.onnx.export(torch_rf.model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
" x,\n",
" # where to save the model (can be a file or file-like object)\n",
@@ -158,7 +139,7 @@
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[((o).detach().numpy()).reshape([-1]).tolist() for o in torch_out])\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -321,7 +302,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -13,7 +13,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -119,7 +119,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -163,7 +163,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
@@ -217,7 +217,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
@@ -227,6 +227,10 @@
" self.length = self.compute_length(self.file_good)\n",
" self.data = self.load_data(self.file_good)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -749,7 +753,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -209,6 +209,11 @@
" self.length = self.compute_length(self.file_good, self.file_bad)\n",
" self.data = self.load_data(self.file_good, self.file_bad)\n",
"\n",
" def __iter__(self):\n",
" for i in range(len(self.data)):\n",
" yield self.data[i]\n",
"\n",
"\n",
" def parse_json_object(self, line):\n",
" try:\n",
" return json.loads(line)\n",
@@ -637,7 +642,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.2"
}
},
"nbformat": 4,

View File

@@ -0,0 +1,13 @@
{
"input_data": [
[
0.8894134163856506,
0.8894201517105103
]
],
"output_data": [
[
0.8436377
]
]
}

Binary file not shown.

View File

@@ -17,7 +17,7 @@
"clean": "rm -r dist || true",
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
},
"dependencies": {
"@ethereumjs/common": "^4.0.0",

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["maturin>=0.14,<0.15"]
requires = ["maturin>=1.0,<2.0"]
build-backend = "maturin"
[tool.pytest.ini_options]

View File

@@ -1,14 +1,14 @@
attrs==22.2.0
exceptiongroup==1.1.1
importlib-metadata==6.1.0
attrs==23.2.0
exceptiongroup==1.2.0
importlib-metadata==7.1.0
iniconfig==2.0.0
maturin==1.0.1
packaging==23.0
pluggy==1.0.0
pytest==7.2.2
maturin==1.5.1
packaging==24.0
pluggy==1.4.0
pytest==8.1.1
tomli==2.0.1
typing-extensions==4.5.0
zipp==3.15.0
onnx==1.14.1
onnxruntime==1.14.1
numpy==1.21.6
typing-extensions==4.10.0
zipp==3.18.1
onnx==1.15.0
onnxruntime==1.17.1
numpy==1.26.4

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2023-08-24"
channel = "nightly-2024-02-06"
components = ["rustfmt", "clippy"]

View File

@@ -15,6 +15,8 @@ pub use planner::*;
use crate::tensor::{TensorType, ValTensor};
use super::region::ConstantsMap;
/// Module trait used to extend ezkl functionality
pub trait Module<F: PrimeField + TensorType + PartialOrd> {
/// Config
@@ -39,6 +41,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
&self,
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
constants: &mut ConstantsMap<F>,
) -> Result<Self::InputAssignments, Error>;
/// Layout
fn layout(
@@ -46,6 +49,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
layouter: &mut impl Layouter<F>,
input: &[ValTensor<F>],
row_offset: usize,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, Error>;
/// Number of instance values the module uses every time it is applied
fn instance_increment_input(&self) -> Vec<usize>;
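
The extra constants: &mut ConstantsMap<F> parameter on layout_inputs and layout is the plumbing for "feat: dictionary of reusable constants (#754)": callers thread one shared map from constant values to already-assigned cells through every module, so a constant that appears in several places is assigned once and then copied. A minimal call-site sketch under the new signatures; the HashMap shape of ConstantsMap is inferred from the insert/get calls in the Poseidon diff below, and the chip and layouter names are illustrative:

    // Sketch only: assumes ConstantsMap<F> maps a field element to its
    // ValType (the Poseidon diff below inserts ValType::AssignedConstant).
    let mut constants = std::collections::HashMap::new();
    // The first layout assigns any ValType::Constant cells and records them...
    let digest = poseidon_chip.layout(&mut layouter, &[message.clone()], 0, &mut constants)?;
    // ...and a later module that loads the same constant reuses the recorded
    // cell instead of assigning a fresh one.
    polycommit_chip.layout(&mut layouter, &[message], 0, &mut constants)?;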

View File

@@ -4,6 +4,8 @@ is already implemented in halo2_gadgets, there is no wrapper chip that makes it
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
*/
use std::collections::HashMap;
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
use halo2_proofs::halo2curves::bn256::Fr as Fp;
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
@@ -13,6 +15,7 @@ use halo2curves::group::prime::PrimeCurveAffine;
use halo2curves::group::Curve;
use halo2curves::CurveAffine;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
use super::Module;
@@ -107,6 +110,7 @@ impl Module<Fp> for PolyCommitChip {
&self,
_: &mut impl Layouter<Fp>,
_: &[ValTensor<Fp>],
_: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
Ok(())
}
@@ -119,11 +123,24 @@ impl Module<Fp> for PolyCommitChip {
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
_: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
assert_eq!(input.len(), 1);
let local_constants = constants.clone();
layouter.assign_region(
|| "PolyCommit",
|mut region| self.config.inputs.assign(&mut region, 0, &input[0]),
|mut region| {
let mut local_inner_constants = local_constants.clone();
let res = self.config.inputs.assign(
&mut region,
0,
&input[0],
&mut local_inner_constants,
)?;
*constants = local_inner_constants;
Ok(res)
},
)
}
@@ -184,7 +201,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let polycommit_chip = PolyCommitChip::new(config);
polycommit_chip.layout(&mut layouter, &[self.message.clone()], 0);
polycommit_chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
);
Ok(())
}
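
The clone-then-write-back shape here looks deliberate rather than incidental: halo2's assign_region may invoke its closure more than once (for example in a measurement pass before the real assignment), so each invocation works on its own local_inner_constants clone and only the pass that actually completes is copied back into the shared map, keeping the constants dictionary consistent across retries.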

View File

@@ -18,6 +18,7 @@ use maybe_rayon::slice::ParallelSlice;
use std::marker::PhantomData;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor, ValType};
use super::Module;
@@ -172,12 +173,15 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
&self,
layouter: &mut impl Layouter<Fp>,
message: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Self::InputAssignments, Error> {
assert_eq!(message.len(), 1);
let message = message[0].clone();
let start_time = instant::Instant::now();
let local_constants = constants.clone();
let res = layouter.assign_region(
|| "load message",
|mut region| {
@@ -199,12 +203,26 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
Ok(v.clone())
}
ValType::Constant(f) => region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
),
ValType::Constant(f) => {
if local_constants.contains_key(f) {
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
log::error!("constant not previously assigned");
Error::Synthesis
})?)
} else {
let res = region.assign_advice_from_constant(
|| format!("load message_{}", i),
self.config.hash_inputs[x],
y,
*f,
)?;
constants
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
Ok(res)
}
}
e => {
log::error!(
"wrong input type {:?}, must be previously assigned",
@@ -270,8 +288,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
layouter: &mut impl Layouter<Fp>,
input: &[ValTensor<Fp>],
row_offset: usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<ValTensor<Fp>, Error> {
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?;
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
// extract the values from the input cells
let mut assigned_input: Tensor<ValType<Fp>> =
input_cells.iter().map(|e| ValType::from(e.clone())).into();
@@ -434,7 +453,7 @@ mod tests {
*,
};
use std::marker::PhantomData;
use std::{collections::HashMap, marker::PhantomData};
use halo2_gadgets::poseidon::primitives::Spec;
use halo2_proofs::{
@@ -477,7 +496,12 @@ mod tests {
mut layouter: impl Layouter<Fp>,
) -> Result<(), Error> {
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
chip.layout(&mut layouter, &[self.message.clone()], 0)?;
chip.layout(
&mut layouter,
&[self.message.clone()],
0,
&mut HashMap::new(),
)?;
Ok(())
}

View File

@@ -345,7 +345,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {
@@ -956,20 +956,6 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
values: &[ValTensor<F>],
op: Box<dyn Op<F>>,
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
let res = op.layout(self, region, values)?;
if matches!(&self.check_mode, CheckMode::SAFE) && !region.is_dummy() {
if let Some(claimed_output) = &res {
// during key generation this will be unknown vals so we use this as a flag to check
let mut is_assigned = !claimed_output.any_unknowns()?;
for val in values.iter() {
is_assigned = is_assigned && !val.any_unknowns()?;
}
if is_assigned {
op.safe_mode_check(claimed_output, values)?;
}
}
};
Ok(res)
op.layout(self, region, values)
}
}

View File

@@ -1,9 +1,9 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::{felt_to_i128, i128_to_felt},
fieldutils::i128_to_felt,
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use serde::{Deserialize, Serialize};
@@ -29,15 +29,15 @@ pub enum HybridOp {
dim: usize,
},
SumPool {
padding: [(usize, usize); 2],
stride: (usize, usize),
kernel_shape: (usize, usize),
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
kernel_shape: Vec<usize>,
normalized: bool,
},
MaxPool2d {
padding: [(usize, usize); 2],
stride: (usize, usize),
pool_dims: (usize, usize),
MaxPool {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
pool_dims: Vec<usize>,
},
ReduceMin {
axes: Vec<usize>,
@@ -46,7 +46,8 @@ pub enum HybridOp {
dim: usize,
},
Softmax {
scale: utils::F32,
input_scale: utils::F32,
output_scale: utils::F32,
axes: Vec<usize>,
},
RangeCheck(Tolerance),
@@ -70,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -84,86 +85,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
HybridOp::Div { denom, .. } => {
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
}
HybridOp::Recip {
input_scale,
output_scale,
..
} => crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
),
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(&x, idx, *dim)?
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
}
}
HybridOp::OneHot { dim, num_classes } => {
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
}
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
HybridOp::SumPool {
padding,
stride,
kernel_shape,
normalized,
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
HybridOp::Softmax { scale, axes } => {
tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes)
}
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
}
HybridOp::Greater => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater(&x, &y)?
}
HybridOp::GreaterEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::greater_equal(&x, &y)?
}
HybridOp::Less => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less(&x, &y)?
}
HybridOp::LessEqual => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::less_equal(&x, &y)?
}
HybridOp::Equals => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::equals(&x, &y)?
}
};
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
match self {
@@ -193,18 +114,25 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
),
HybridOp::ReduceMax { axes } => format!("REDUCEMAX (axes={:?})", axes),
HybridOp::ReduceArgMax { dim } => format!("REDUCEARGMAX (dim={})", dim),
HybridOp::MaxPool2d {
HybridOp::MaxPool {
padding,
stride,
pool_dims,
} => format!(
"MAXPOOL2D (padding={:?}, stride={:?}, pool_dims={:?})",
"MaxPool (padding={:?}, stride={:?}, pool_dims={:?})",
padding, stride, pool_dims
),
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
HybridOp::Softmax { scale, axes } => {
format!("SOFTMAX (scale={}, axes={:?})", scale, axes)
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => {
format!(
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
input_scale, output_scale, axes
)
}
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
HybridOp::Greater => "GREATER".into(),
@@ -238,9 +166,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
config,
region,
values[..].try_into()?,
*padding,
*stride,
*kernel_shape,
padding,
stride,
kernel_shape,
*normalized,
)?,
HybridOp::Recip {
@@ -300,17 +228,17 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
}
}
HybridOp::MaxPool2d {
HybridOp::MaxPool {
padding,
stride,
pool_dims,
} => layouts::max_pool2d(
} => layouts::max_pool(
config,
region,
values[..].try_into()?,
*padding,
*stride,
*pool_dims,
padding,
stride,
pool_dims,
)?,
HybridOp::ReduceMax { axes } => {
layouts::max_axes(config, region, values[..].try_into()?, axes)?
@@ -324,9 +252,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::ReduceArgMin { dim } => {
layouts::argmin_axes(config, region, values[..].try_into()?, *dim)?
}
HybridOp::Softmax { scale, axes } => {
layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)?
}
HybridOp::Softmax {
input_scale,
output_scale,
axes,
} => layouts::softmax_axes(
config,
region,
values[..].try_into()?,
*input_scale,
*output_scale,
axes,
)?,
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
config,
region,
@@ -359,8 +296,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
| HybridOp::ReduceArgMax { .. }
| HybridOp::OneHot { .. }
| HybridOp::ReduceArgMin { .. } => 0,
HybridOp::Softmax { .. } => 2 * in_scales[0],
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
multiplier_to_scale(output_scale.0 as f64)
}
_ => in_scales[0],
};
Ok(scale)
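
A related signature change in this file: HybridOp::Softmax now carries separate input_scale and output_scale fields, and out_scale derives from multiplier_to_scale(output_scale) instead of the old hard-coded 2 * in_scales[0] (the "chore: reduce softmax recip DR (#756)" commit at work). A hedged construction sketch with illustrative values:

    // Sketch: the two-scale softmax after #756. Values are illustrative;
    // utils::F32 is the wrapper type visible in the diff above.
    let softmax = HybridOp::Softmax {
        input_scale: utils::F32(128.0),  // multiplier of the quantized input
        output_scale: utils::F32(128.0), // multiplier claimed for the output
        axes: vec![1],                   // softmax across the class axis
    };
    // out_scale is now multiplier_to_scale(128.0) regardless of the input
    // scale (presumably log2 of the multiplier, i.e. 7 here).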

File diff suppressed because it is too large.

View File

@@ -135,15 +135,12 @@ impl LookupOp {
let range = range as i128;
(-range, range)
}
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
@@ -235,6 +232,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
Ok(ForwardResult { output })
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Returns the name of the operation
fn as_string(&self) -> String {

View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
graph::quantize_tensor,
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
tensor::{self, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -27,14 +27,14 @@ pub mod region;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
std::fmt::Debug + Send + Sync + Any
{
/// Returns a string representation of the operation.
fn as_string(&self) -> String;
@@ -69,36 +69,9 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any;
/// Safe mode output checl
fn safe_mode_check(
&self,
claimed_output: &ValTensor<F>,
original_values: &[ValTensor<F>],
) -> Result<(), TensorError> {
let felt_evals = original_values
.iter()
.map(|v| {
let mut evals = v.get_felt_evals().map_err(|_| TensorError::FeltError)?;
evals.reshape(v.dims())?;
Ok(evals)
})
.collect::<Result<Vec<_>, _>>()?;
let ref_op: Tensor<F> = self.f(&felt_evals)?.output;
let mut output = claimed_output
.get_felt_evals()
.map_err(|_| TensorError::FeltError)?;
output.reshape(claimed_output.dims())?;
assert_eq!(output, ref_op);
Ok(())
}
}
impl<F: PrimeField + TensorType + PartialOrd> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -165,7 +138,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(self.scale)
}
@@ -174,12 +147,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
self
}
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
})
}
fn as_string(&self) -> String {
"Input".into()
}
@@ -226,16 +193,13 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
Ok(0)
}
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Err(TensorError::WrongMethod)
}
fn as_string(&self) -> String {
"Unknown".into()
@@ -256,7 +220,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
///
pub quantized_values: Tensor<F>,
///
@@ -266,7 +230,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -293,17 +257,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for Constant<F>
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();
Ok(ForwardResult { output })
}
fn as_string(&self) -> String {
format!("CONST (scale={})", self.quantized_values.scale().unwrap())

View File

@@ -1,6 +1,5 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};
@@ -32,8 +31,8 @@ pub enum PolyOp {
equation: String,
},
Conv {
padding: [(usize, usize); 2],
stride: (usize, usize),
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
},
Downsample {
axis: usize,
@@ -41,9 +40,9 @@ pub enum PolyOp {
modulo: usize,
},
DeConv {
padding: [(usize, usize); 2],
output_padding: (usize, usize),
stride: (usize, usize),
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
},
Add,
Sub,
@@ -58,10 +57,13 @@ pub enum PolyOp {
destination: usize,
},
Flatten(Vec<usize>),
Pad([(usize, usize); 2]),
Pad(Vec<(usize, usize)>),
Sum {
axes: Vec<usize>,
},
MeanOfSquares {
axes: Vec<usize>,
},
Prod {
axes: Vec<usize>,
len_prod: usize,
@@ -89,8 +91,14 @@ pub enum PolyOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for PolyOp
impl<
F: PrimeField
+ TensorType
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
@@ -99,10 +107,28 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::GatherND { batch_dims, .. } => format!("GATHERND (batch_dims={})", batch_dims),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::ScatterND { .. } => "SCATTERND".into(),
PolyOp::GatherElements { dim, constant_idx } => format!(
"GATHERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::GatherND {
batch_dims,
indices,
} => format!(
"GATHERND (batch_dims={}, constant_idx{})",
batch_dims,
indices.is_some()
),
PolyOp::MeanOfSquares { axes } => format!("MEANOFSQUARES (axes={:?})", axes),
PolyOp::ScatterElements { dim, constant_idx } => format!(
"SCATTERELEMENTS (dim={}, constant_idx{})",
dim,
constant_idx.is_some()
),
PolyOp::ScatterND { constant_idx } => {
format!("SCATTERND (constant_idx={})", constant_idx.is_some())
}
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -114,15 +140,26 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
PolyOp::Reshape(shape) => format!("RESHAPE (shape={:?})", shape),
PolyOp::Flatten(_) => "FLATTEN".into(),
PolyOp::Pad(_) => "PAD".into(),
PolyOp::Pad(pads) => format!("PAD (pads={:?})", pads),
PolyOp::Add => "ADD".into(),
PolyOp::Mult => "MULT".into(),
PolyOp::Sub => "SUB".into(),
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv { .. } => "CONV".into(),
PolyOp::DeConv { .. } => "DECONV".into(),
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
PolyOp::Slice { axis, start, end } => {
format!("SLICE (axis={}, start={}, end={})", axis, start, end)
@@ -136,146 +173,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
}
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut inputs = inputs.to_vec();
let res = match &self {
PolyOp::MultiBroadcastTo { shape } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch(
"multibroadcastto inputs".to_string(),
));
}
inputs[0].expand(shape)
}
PolyOp::And => tensor::ops::and(&inputs[0], &inputs[1]),
PolyOp::Or => tensor::ops::or(&inputs[0], &inputs[1]),
PolyOp::Xor => tensor::ops::xor(&inputs[0], &inputs[1]),
PolyOp::Not => tensor::ops::not(&inputs[0]),
PolyOp::Downsample {
axis,
stride,
modulo,
} => tensor::ops::downsample(&inputs[0], *axis, *stride, *modulo),
PolyOp::Resize { scale_factor } => tensor::ops::resize(&inputs[0], scale_factor),
PolyOp::Iff => tensor::ops::iff(&inputs[0], &inputs[1], &inputs[2]),
PolyOp::Einsum { equation } => tensor::ops::einsum(equation, &inputs),
PolyOp::Identity { .. } => Ok(inputs[0].clone()),
PolyOp::Reshape(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::MoveAxis {
source,
destination,
} => inputs[0].move_axis(*source, *destination),
PolyOp::Flatten(new_dims) => {
let mut t = inputs[0].clone();
t.reshape(new_dims)?;
Ok(t)
}
PolyOp::Pad(p) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pad inputs".to_string()));
}
tensor::ops::pad(&inputs[0], *p)
}
PolyOp::Add => tensor::ops::add(&inputs),
PolyOp::Neg => tensor::ops::neg(&inputs[0]),
PolyOp::Sub => tensor::ops::sub(&inputs),
PolyOp::Mult => tensor::ops::mult(&inputs),
PolyOp::Conv { padding, stride } => tensor::ops::conv(&inputs, *padding, *stride),
PolyOp::DeConv {
padding,
output_padding,
stride,
} => tensor::ops::deconv(&inputs, *padding, *output_padding, *stride),
PolyOp::Pow(u) => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("pow inputs".to_string()));
}
inputs[0].pow(*u)
}
PolyOp::Sum { axes } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("sum inputs".to_string()));
}
tensor::ops::sum_axes(&inputs[0], axes)
}
PolyOp::Prod { axes, .. } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("prod inputs".to_string()));
}
tensor::ops::prod_axes(&inputs[0], axes)
}
PolyOp::Concat { axis } => {
tensor::ops::concat(&inputs.iter().collect::<Vec<_>>(), *axis)
}
PolyOp::Slice { axis, start, end } => {
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::GatherND {
indices,
batch_dims,
} => {
let x = inputs[0].clone();
let y = if let Some(idx) = indices {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_nd(&x, &y, *batch_dims)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
}
PolyOp::ScatterND { constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if constant_idx.is_some() {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter_nd(&x, &idx, &src)
}
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
}?;
Ok(ForwardResult { output: res })
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<F>,
@@ -286,6 +183,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::MultiBroadcastTo { shape } => {
layouts::expand(config, region, values[..].try_into()?, shape)?
}
PolyOp::MeanOfSquares { axes } => {
layouts::mean_of_squares_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Xor => layouts::xor(config, region, values[..].try_into()?)?,
PolyOp::Or => layouts::or(config, region, values[..].try_into()?)?,
PolyOp::And => layouts::and(config, region, values[..].try_into()?)?,
@@ -312,7 +212,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
@@ -364,9 +264,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
config,
region,
values[..].try_into()?,
*padding,
*output_padding,
*stride,
padding,
output_padding,
stride,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
@@ -382,7 +282,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
)));
}
let mut input = values[0].clone();
input.pad(*p)?;
input.pad(p.clone(), 0)?;
input
}
PolyOp::Pow(exp) => layouts::pow(config, region, values[..].try_into()?, *exp)?,
@@ -398,6 +298,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn out_scale(&self, in_scales: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
let scale = match self {
PolyOp::MeanOfSquares { .. } => 2 * in_scales[0],
PolyOp::Xor | PolyOp::Or | PolyOp::And | PolyOp::Not => 0,
PolyOp::Iff => in_scales[1],
PolyOp::Einsum { .. } => {

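Why the new MeanOfSquares arm in the out_scale match above returns 2 * in_scales[0]: squaring a fixed-point value doubles its scale exponent, and averaging leaves the multiplier untouched (setting aside how the division by N is realized in-circuit). As a worked equation, with inputs encoded at scale s, i.e. x_felt = 2^s * x:

\frac{1}{N}\sum_{i=1}^{N}\left(2^{s}x_i\right)^{2} \;=\; 2^{2s}\cdot\frac{1}{N}\sum_{i=1}^{N}x_i^{2}

so the output carries multiplier 2^{2s}, i.e. scale 2s.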

@@ -2,24 +2,28 @@ use crate::{
circuit::table::Range,
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
use halo2_proofs::{
circuit::Region,
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use portable_atomic::AtomicI128 as AtomicInt;
use std::{
cell::RefCell,
collections::HashSet,
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Constants map
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
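The map above is the mechanism behind the constant-deduplication changes in the rest of this file: the first assignment of a given field element is recorded here, and later requests for the same value reuse the recorded cell instead of assigning a new one, which is also why total_constants() below becomes assigned_constants.len(), a count of distinct values. A minimal sketch of the reuse pattern, with assign_or_reuse and fresh_cell as hypothetical names:

// Sketch (assumed semantics): reuse an already-assigned constant cell when
// the same field element appears again; otherwise assign a fresh one.
fn assign_or_reuse<F, V>(
    map: &mut std::collections::HashMap<F, V>,
    value: F,
    fresh_cell: impl FnOnce() -> V,
) -> V
where
    F: std::hash::Hash + Eq,
    V: Clone,
{
    map.entry(value).or_insert_with(fresh_cell).clone()
}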
/// Dynamic lookup index
#[derive(Clone, Debug, Default)]
pub struct DynamicLookupIndex {
@@ -120,12 +124,11 @@ impl From<Box<dyn std::error::Error>> for RegionError {
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
region: Option<RefCell<Region<'a, F>>>,
row: usize,
linear_coord: usize,
num_inner_cols: usize,
total_constants: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
used_lookups: HashSet<LookupOp>,
@@ -133,13 +136,34 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
max_lookup_inputs: i128,
min_lookup_inputs: i128,
max_range_size: i128,
throw_range_check_error: bool,
witness_gen: bool,
assigned_constants: ConstantsMap<F>,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
#[cfg(not(target_arch = "wasm32"))]
///
pub fn increment_total_constants(&mut self, n: usize) {
self.total_constants += n;
pub fn debug_report(&self) {
log::debug!(
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
self.row().to_string().blue(),
self.linear_coord().to_string().yellow(),
self.total_constants().to_string().red(),
self.max_lookup_inputs().to_string().green(),
self.min_lookup_inputs().to_string().green(),
self.max_range_size().to_string().green(),
self.dynamic_lookup_col_coord().to_string().green(),
self.shuffle_col_coord().to_string().green());
}
///
pub fn assigned_constants(&self) -> &ConstantsMap<F> {
&self.assigned_constants
}
///
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
self.assigned_constants.extend(constants);
}
///
@@ -163,8 +187,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
///
pub fn throw_range_check_error(&self) -> bool {
self.throw_range_check_error
pub fn witness_gen(&self) -> bool {
self.witness_gen
}
/// Create a new region context
@@ -177,7 +201,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
row,
linear_coord,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -185,9 +208,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: true,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_with_constants(
region: Region<'a, F>,
row: usize,
num_inner_cols: usize,
constants: ConstantsMap<F>,
) -> RegionCtx<'a, F> {
let mut new_self = Self::new(region, row, num_inner_cols);
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
@@ -202,7 +238,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
@@ -210,16 +245,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error: false,
witness_gen: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
) -> RegionCtx<'a, F> {
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -228,7 +260,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants: 0,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -236,17 +267,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy_with_constants(
pub fn new_dummy_with_linear_coord(
row: usize,
linear_coord: usize,
total_constants: usize,
num_inner_cols: usize,
throw_range_check_error: bool,
witness_gen: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -254,7 +285,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
num_inner_cols,
linear_coord,
row,
total_constants,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
used_lookups: HashSet::new(),
@@ -262,7 +292,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
throw_range_check_error,
witness_gen,
assigned_constants: HashMap::new(),
}
}
@@ -312,29 +343,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(), RegionError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output
.par_enum_map(|idx, _| {
// we kick off the loop with the current offset
let starting_offset = row.load(Ordering::SeqCst);
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
let starting_constants = constants.load(Ordering::SeqCst);
// get inner value of the locked lookups
// we need to make sure that the region is not shared between threads
let mut local_reg = Self::new_dummy_with_constants(
let mut local_reg = Self::new_dummy_with_linear_coord(
starting_offset,
starting_linear_coord,
starting_constants,
self.num_inner_cols,
self.throw_range_check_error,
self.witness_gen,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -343,10 +372,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.linear_coord() - starting_linear_coord,
Ordering::SeqCst,
);
constants.fetch_add(
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
@@ -362,11 +387,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
// update the shuffle index
let mut shuffle_index = shuffle_index.lock().unwrap();
shuffle_index.update(&local_reg.shuffle_index);
// update the constants
let mut constants = constants.lock().unwrap();
constants.extend(local_reg.assigned_constants);
res
})
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
@@ -410,6 +437,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
})?;
self.assigned_constants = Arc::try_unwrap(constants)
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?
.into_inner()
.map_err(|e| {
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
})?;
Ok(())
}
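The parallel dummy loop above uses a standard share-then-unwrap pattern: each worker folds its thread-local constants into an Arc<Mutex<...>> under the lock, and once all threads are done the sole remaining owner recovers the merged map via Arc::try_unwrap followed by into_inner. A self-contained sketch of the pattern with simplified types (not the real RegionCtx):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

fn main() {
    let shared: Arc<Mutex<HashMap<u64, u64>>> = Arc::new(Mutex::new(HashMap::new()));
    let handles: Vec<_> = (0..4u64)
        .map(|i| {
            let shared = Arc::clone(&shared);
            thread::spawn(move || {
                let local = HashMap::from([(i, i * 10)]); // per-thread result
                shared.lock().unwrap().extend(local); // fold in under the lock
            })
        })
        .collect();
    for h in handles {
        h.join().unwrap();
    }
    // every clone has been dropped, so try_unwrap yields the merged map
    let merged = Arc::try_unwrap(shared).unwrap().into_inner().unwrap();
    assert_eq!(merged.len(), 4);
}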
@@ -435,7 +470,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
}
let range_size = (range.1 - range.0).abs();
@@ -477,7 +512,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Get the total number of constants
pub fn total_constants(&self) -> usize {
self.total_constants
self.assigned_constants.len()
}
/// Get the dynamic lookup index
@@ -525,26 +560,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.max_range_size
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;
if let Some(region) = &self.region {
let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?;
Ok(cell.into())
} else {
Ok(value.into())
}
}
/// Assign a valtensor to a vartensor
pub fn assign(
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(&mut region.borrow_mut(), self.linear_coord, values)
var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -560,14 +593,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
self.total_constants += values.num_constants();
if let Some(region) = &self.region {
var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
}
Ok(values.clone())
}
}
@@ -594,13 +631,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
} else {
self.total_constants += values.num_constants();
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
for o in ommissions {
self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize;
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
self.assigned_constants.extend(values_map);
Ok(values.clone())
}
}
@@ -615,24 +659,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
) -> Result<(ValTensor<F>, usize), Error> {
if let Some(region) = &self.region {
// duplicates every nth element to adjust for column overflow
let (res, len, total_assigned_constants) = var.assign_with_duplication(
let (res, len) = var.assign_with_duplication(
&mut region.borrow_mut(),
self.row,
self.linear_coord,
values,
check_mode,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((res, len))
} else {
let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication(
let (_, len) = var.dummy_assign_with_duplication(
self.row,
self.linear_coord,
values,
single_inner_col,
&mut self.assigned_constants,
)?;
self.total_constants += total_assigned_constants;
Ok((values.clone(), len))
}
}
@@ -699,9 +743,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
}
Ok(())
}
/// increment constants
pub fn increment_constants(&mut self, n: usize) {
self.total_constants += n
}
}


@@ -17,8 +17,6 @@ use crate::{
use crate::circuit::lookup::LookupOp;
use super::Op;
/// The range of the lookup table.
pub type Range = (i128, i128);
@@ -98,7 +96,7 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
@@ -113,11 +111,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let chunk = chunk as i128;
// we index from 1 to prevent soundness issues
let first_element = i128_to_felt(chunk * (self.col_size as i128) + self.range.0);
let op_f = Op::<F>::f(
&self.nonlinearity,
&[Tensor::from(vec![first_element].into_iter())],
)
.unwrap();
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
.unwrap();
(first_element, op_f.output[0])
}
@@ -138,7 +135,7 @@ pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
(range_len / (col_size as i128)) as usize + 1
}
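A quick worked example for the column count above, with assumed numbers: a lookup over the range (-64, 64) has range_len = 128, so a col_size of 100 needs 128 / 100 + 1 = 2 columns, and any range that fits in one column still gets at least one.

// Worked example (assumed numbers) for the helper above:
fn num_cols_required(range_len: i128, col_size: usize) -> usize {
    (range_len / (col_size as i128)) as usize + 1
}

fn main() {
    assert_eq!(num_cols_required(128, 100), 2); // range (-64, 64)
    assert_eq!(num_cols_required(50, 100), 1); // fits in a single column
}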
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -152,7 +149,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
// number of cols needed to store the range
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
log::debug!("table range: {:?}", range);
debug!("table range: {:?}", range);
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
let mut cols = vec![];
@@ -165,7 +162,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let num_cols = table_inputs.len();
if num_cols > 1 {
debug!("Using {} columns for non-linearity table.", num_cols);
warn!("Using {} columns for non-linearity table.", num_cols);
}
let table_outputs = table_inputs
@@ -205,8 +202,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = Op::<F>::f(&self.nonlinearity, &[inputs.clone()])?;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i128_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;
@@ -275,7 +272,7 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as i128;
@@ -303,7 +300,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
}
}
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);


@@ -1048,8 +1048,8 @@ mod conv {
&mut region,
&self.inputs,
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis)
@@ -1198,8 +1198,8 @@ mod conv_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis)
@@ -1343,8 +1343,8 @@ mod conv_relu_col_ultra_overflow {
&mut region,
&[self.image.clone(), self.kernel.clone()],
Box::new(PolyOp::Conv {
padding: [(1, 1); 2],
stride: (2, 2),
padding: vec![(1, 1); 2],
stride: vec![2; 2],
}),
)
.map_err(|_| Error::Synthesis);
@@ -1911,6 +1911,8 @@ mod add_with_overflow {
#[cfg(test)]
mod add_with_overflow_and_poseidon {
use std::collections::HashMap;
use halo2curves::bn256::Fr;
use crate::circuit::modules::{
@@ -1969,8 +1971,10 @@ mod add_with_overflow_and_poseidon {
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
PoseidonChip::new(config.poseidon.clone());
let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?;
let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?;
let assigned_inputs_a =
poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0, &mut HashMap::new())?;
let assigned_inputs_b =
poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1, &mut HashMap::new())?;
layouter.assign_region(|| "_new_module", |_| Ok(()))?;


@@ -444,7 +444,7 @@ pub enum Commands {
disable_selector_compression: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Aggregates proofs :)
Aggregate {
@@ -479,7 +479,7 @@ pub enum Commands {
split_proofs: bool,
/// commitment used
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
CompileCircuit {
@@ -726,7 +726,7 @@ pub enum Commands {
logrows: u32,
/// commitment
#[arg(long, default_value = DEFAULT_COMMITMENT)]
commitment: Commitments,
commitment: Option<Commitments>,
},
#[cfg(not(target_arch = "wasm32"))]
/// Deploys an evm verifier that is generated by ezkl

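The flag changes above make --commitment an Option<Commitments>, presumably so invocations and settings files that predate the field keep working; the .into() calls in the next section rely on a conversion that falls back to a default when the value is absent. A self-contained sketch of that conversion, using a stand-in enum (the KZG default is an assumption mirroring DEFAULT_COMMITMENT):

// Sketch (assumed impl): an absent commitment falls back to the default.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, PartialEq)]
enum Commitments {
    KZG,
    IPA,
}

impl From<Option<Commitments>> for Commitments {
    fn from(c: Option<Commitments>) -> Self {
        c.unwrap_or(Commitments::KZG) // hypothetical default
    }
}

fn main() {
    let from_old_settings: Option<Commitments> = None; // field missing pre-upgrade
    assert_eq!(Commitments::from(from_old_settings), Commitments::KZG);
}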

@@ -24,6 +24,8 @@ use crate::pfsys::{
use crate::pfsys::{save_vk, srs::*};
use crate::tensor::TensorError;
use crate::{Commitments, RunArgs};
#[cfg(not(target_arch = "wasm32"))]
use colored::Colorize;
#[cfg(unix)]
use gag::Gag;
use halo2_proofs::dev::VerifyFailure;
@@ -337,7 +339,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
split_proofs,
disable_selector_compression,
commitment,
commitment.into(),
),
Commands::Aggregate {
proof_path,
@@ -358,7 +360,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
logrows,
check_mode,
split_proofs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
Commands::Verify {
@@ -382,7 +384,7 @@ pub async fn run(command: Commands) -> Result<String, Box<dyn Error>> {
srs_path,
logrows,
reduced_srs,
commitment,
commitment.into(),
)
.map(|e| serde_json::to_string(&e).unwrap()),
#[cfg(not(target_arch = "wasm32"))]
@@ -538,7 +540,7 @@ fn check_srs_hash(
let path = get_srs_path(logrows, srs_path, commitment);
let hash = get_file_hash(&path)?;
let predefined_hash = match { crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) } {
let predefined_hash = match crate::srs_sha::PUBLIC_SRS_SHA256_HASHES.get(&logrows) {
Some(h) => h,
None => return Err(format!("SRS (k={}) hash not found in public set", logrows).into()),
};
@@ -584,7 +586,7 @@ pub(crate) async fn get_srs_cmd(
} else if let Some(settings_p) = settings_path {
if settings_p.exists() {
let settings = GraphSettings::load(&settings_p)?;
settings.run_args.commitment
settings.run_args.commitment.into()
} else {
return Err(err_string.into());
}
@@ -664,27 +666,23 @@ pub(crate) async fn gen_witness(
// if any of the settings have kzg visibility then we need to load the srs
let commitment: Commitments = settings.run_args.commitment.into();
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
if get_srs_path(
settings.run_args.logrows,
srs_path.clone(),
settings.run_args.commitment,
)
.exists()
{
match settings.run_args.commitment {
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
match Commitments::from(settings.run_args.commitment) {
Commitments::KZG => {
let srs: ParamsKZG<Bn256> = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<KZGCommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
)?
}
Commitments::IPA => {
@@ -692,22 +690,22 @@ pub(crate) async fn gen_witness(
load_params_prover::<IPACommitmentScheme<G1Affine>>(
srs_path.clone(),
settings.run_args.logrows,
settings.run_args.commitment,
commitment,
)?;
circuit.forward::<IPACommitmentScheme<_>>(
&mut input,
vk.as_ref(),
Some(&srs),
false,
true,
)?
}
}
} else {
warn!("SRS for poly commit does not exist (will be ignored)");
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
}
} else {
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, false)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true)?
};
// print each variable tuple (symbol, value) as symbol=value
@@ -819,7 +817,15 @@ impl AccuracyResults {
let error = (original.clone() - calibrated.clone())?;
let abs_error = error.map(|x| x.abs());
let squared_error = error.map(|x| x.powi(2));
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let percentage_error = error.enum_map(|i, x| {
// if both the original value and the error are 0, avoid dividing 0 by 0 and return 0
let res = if original[i] == 0.0 && x == 0.0 {
0.0
} else {
x / original[i]
};
Ok::<f32, TensorError>(res)
})?;
let abs_percentage_error = percentage_error.map(|x| x.abs());
errors.extend(error);
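A worked example of the zero-guard above, with assumed values: the relative error is defined as 0 whenever both the original value and the error are 0, so an all-zero entry no longer produces a 0/0 NaN that would poison the accuracy summary.

fn main() {
    let original = [0.0_f32, 2.0];
    let error = [0.0_f32, 0.5];
    let pct: Vec<f32> = error
        .iter()
        .zip(original.iter())
        .map(|(&e, &o)| if o == 0.0 && e == 0.0 { 0.0 } else { e / o })
        .collect();
    assert_eq!(pct, vec![0.0, 0.25]); // no NaN from the all-zero entry
}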
@@ -888,6 +894,7 @@ pub(crate) fn calibrate(
only_range_check_rebase: bool,
max_logrows: Option<u32>,
) -> Result<GraphSettings, Box<dyn Error>> {
use log::error;
use std::collections::HashMap;
use tabled::Table;
@@ -900,9 +907,9 @@ pub(crate) fn calibrate(
let model = Model::from_run_args(&settings.run_args, &model_path)?;
let chunks = data.split_into_batches(model.graph.input_shapes()?)?;
info!("num of calibration batches: {}", chunks.len());
info!("num calibration batches: {}", chunks.len());
info!("running onnx predictions...");
debug!("running onnx predictions...");
let original_predictions = Model::run_onnx_predictions(
&settings.run_args,
&model_path,
@@ -970,10 +977,18 @@ pub(crate) fn calibrate(
let pb = init_bar(range_grid.len() as u64);
pb.set_message("calibrating...");
let mut num_failed = 0;
let mut num_passed = 0;
for (((input_scale, param_scale), scale_rebase_multiplier), div_rebasing) in range_grid {
pb.set_message(format!(
"input scale: {}, param scale: {}, scale rebase multiplier: {}, div rebasing: {}",
input_scale, param_scale, scale_rebase_multiplier, div_rebasing
"i-scale: {}, p-scale: {}, rebase-(x): {}, div-rebase: {}, fail: {}, pass: {}",
input_scale.to_string().blue(),
param_scale.to_string().blue(),
scale_rebase_multiplier.to_string().blue(),
div_rebasing.to_string().yellow(),
num_failed.to_string().red(),
num_passed.to_string().green()
));
let key = (
@@ -1007,7 +1022,9 @@ pub(crate) fn calibrate(
let mut circuit = match GraphCircuit::from_run_args(&local_run_args, &model_path) {
Ok(c) => c,
Err(e) => {
debug!("circuit creation from run args failed: {:?}", e);
error!("circuit creation from run args failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
};
@@ -1039,7 +1056,9 @@ pub(crate) fn calibrate(
Ok(_) => (),
// typically errors will be due to the circuit overflowing the i128 limit
Err(e) => {
debug!("forward pass failed: {:?}", e);
error!("forward pass failed: {:?}", e);
pb.inc(1);
num_failed += 1;
continue;
}
}
@@ -1104,8 +1123,10 @@ pub(crate) fn calibrate(
"found settings: \n {}",
found_settings.as_json()?.to_colored_json_auto()?
);
num_passed += 1;
} else {
debug!("calibration failed {}", res.err().unwrap());
error!("calibration failed {}", res.err().unwrap());
num_failed += 1;
}
pb.inc(1);
@@ -1278,17 +1299,19 @@ pub(crate) fn create_evm_verifier(
render_vk_seperately: bool,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1322,17 +1345,18 @@ pub(crate) fn create_evm_vk(
abi_path: PathBuf,
) -> Result<String, Box<dyn Error>> {
check_solc_requirement();
let circuit_settings = GraphSettings::load(&settings_path)?;
let settings = GraphSettings::load(&settings_path)?;
let commitment: Commitments = settings.run_args.commitment.into();
let params = load_params_verifier::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
circuit_settings.run_args.commitment,
settings.run_args.logrows,
commitment,
)?;
let num_instance = circuit_settings.total_instances();
let num_instance = settings.total_instances();
let num_instance: usize = num_instance.iter().sum::<usize>();
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, circuit_settings)?;
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)?;
trace!("params computed");
let generator = halo2_solidity_verifier::SolidityGenerator::new(
@@ -1601,8 +1625,9 @@ pub(crate) fn setup(
}
let logrows = circuit.settings().run_args.logrows;
let commitment: Commitments = circuit.settings().run_args.commitment.into();
let pk = match circuit.settings().run_args.commitment {
let pk = match commitment {
Commitments::KZG => {
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
@@ -1711,7 +1736,8 @@ pub(crate) fn prove(
let transcript: TranscriptType = proof_type.into();
let proof_split_commits: Option<ProofSplitCommit> = data.into();
let commitment = circuit_settings.run_args.commitment;
let commitment = circuit_settings.run_args.commitment.into();
let logrows = circuit_settings.run_args.logrows;
// creates and verifies the proof
let mut snark = match commitment {
Commitments::KZG => {
@@ -1720,7 +1746,7 @@ pub(crate) fn prove(
let params = load_params_prover::<KZGCommitmentScheme<Bn256>>(
srs_path,
circuit_settings.run_args.logrows,
logrows,
Commitments::KZG,
)?;
match strategy {
@@ -1879,7 +1905,9 @@ pub(crate) fn mock_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1922,7 +1950,9 @@ pub(crate) fn setup_aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG",).into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -1983,7 +2013,9 @@ pub(crate) fn aggregate(
}
Err(_) => {
return Err(
format!("invalid sample commitment type for aggregation, must be KZG").into(),
"invalid sample commitment type for aggregation, must be KZG"
.to_string()
.into(),
);
}
}
@@ -2156,8 +2188,9 @@ pub(crate) fn verify(
let circuit_settings = GraphSettings::load(&settings_path)?;
let logrows = circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
match circuit_settings.run_args.commitment {
match commitment {
Commitments::KZG => {
let proof = Snark::load::<KZGCommitmentScheme<Bn256>>(&proof_path)?;
let params: ParamsKZG<Bn256> = if reduced_srs {


@@ -26,6 +26,7 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
@@ -38,7 +39,7 @@ use halo2_proofs::{
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
};
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
use halo2curves::ff::PrimeField;
use halo2curves::ff::{Field, PrimeField};
#[cfg(not(target_arch = "wasm32"))]
use lazy_static::lazy_static;
use log::{debug, error, trace, warn};
@@ -155,7 +156,7 @@ use std::cell::RefCell;
thread_local!(
/// This is a global variable that holds the settings for the graph
/// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = RefCell::new(None)
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = const { RefCell::new(None) }
);
/// Result from a forward pass
@@ -1051,12 +1052,10 @@ impl GraphCircuit {
}
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
let margin = (
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
);
margin
)
}
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
@@ -1240,7 +1239,7 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
let original_inputs = inputs.to_vec();
@@ -1289,7 +1288,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
.forward(inputs, &self.settings().run_args, witness_gen)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();
@@ -1452,7 +1451,8 @@ impl GraphCircuit {
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
struct CircuitSize {
/// The configuration for the graph circuit
pub struct CircuitSize {
num_instances: usize,
num_advice_columns: usize,
num_fixed: usize,
@@ -1462,7 +1462,8 @@ struct CircuitSize {
}
impl CircuitSize {
pub fn from_cs(cs: &ConstraintSystem<Fp>, logrows: u32) -> Self {
///
pub fn from_cs<F: Field>(cs: &ConstraintSystem<F>, logrows: u32) -> Self {
CircuitSize {
num_instances: cs.num_instance_columns(),
num_advice_columns: cs.num_advice_columns(),
@@ -1604,6 +1605,8 @@ impl Circuit<Fp> for GraphCircuit {
let output_vis = &self.settings().run_args.output_visibility;
let mut graph_modules = GraphModules::new();
let mut constants = ConstantsMap::new();
let mut config = config.clone();
let mut inputs = self
@@ -1649,6 +1652,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut input_outlets,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace inputs with the outlets
for (i, outlet) in outlets.iter().enumerate() {
@@ -1661,6 +1665,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut inputs,
input_visibility,
&mut instance_offset,
&mut constants,
)?;
}
@@ -1697,6 +1702,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut flattened_params,
param_visibility,
&mut instance_offset,
&mut constants,
)?;
let shapes = self.model().const_shapes();
@@ -1725,6 +1731,7 @@ impl Circuit<Fp> for GraphCircuit {
&inputs,
&mut vars,
&outputs,
&mut constants,
)
.map_err(|e| {
log::error!("{}", e);
@@ -1749,6 +1756,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut output_outlets,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
// replace outputs with the outlets
@@ -1762,6 +1770,7 @@ impl Circuit<Fp> for GraphCircuit {
&mut outputs,
&self.settings().run_args.output_visibility,
&mut instance_offset,
&mut constants,
)?;
}


@@ -5,6 +5,7 @@ use super::vars::*;
use super::GraphError;
use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::table::Range;
use crate::circuit::Input;
@@ -404,7 +405,7 @@ impl ParsedNodes {
.get(input)
.ok_or(GraphError::MissingNode(*input))?;
let input_dims = node.out_dims();
let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?;
let input_dim = input_dims.first().ok_or(GraphError::MissingNode(*input))?;
inputs.push(input_dim.clone());
}
@@ -514,21 +515,24 @@ impl Model {
instance_shapes.len().to_string().blue(),
"instances".blue()
);
// this is the total number of variables we will need to allocate
// for the circuit
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
let len = shape.iter().product();
let mut t: ValTensor<Fp> = (0..len)
.map(|_| {
if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
t.reshape(shape)?;
Ok(t)
})
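On the switch above from a shared default_value to per-element randomness for fixed inputs: with constants now deduplicated through the ConstantsMap, a single repeated placeholder such as Fp::ONE would collapse into one assigned cell and understate the constant count during sizing, whereas distinct random placeholders size like real, distinct inputs (this rationale is inferred, not stated in the diff). A self-contained sketch of the effect on the dedup count:

use std::collections::HashSet;

fn main() {
    let shared: Vec<u64> = vec![1; 8]; // old: the same placeholder everywhere
    let distinct: Vec<u64> = (0..8).map(|i| 1000 + i).collect(); // new: distinct per element

    let count = |vals: &[u64]| vals.iter().copied().collect::<HashSet<u64>>().len();
    assert_eq!(count(&shared), 1); // dedup collapses to a single constant
    assert_eq!(count(&distinct), 8); // sizes like 8 independent constants
}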
@@ -577,13 +581,13 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
Ok(res.into())
}
@@ -799,13 +803,18 @@ impl Model {
let input_state_idx = input_state_idx(&input_mappings);
let mut output_mappings = vec![];
for mapping in b.output_mapping.iter() {
for (i, mapping) in b.output_mapping.iter().enumerate() {
let mut mappings = vec![];
if let Some(outlet) = mapping.last_value_slot {
mappings.push(OutputMapping::Single {
outlet,
is_state: mapping.state,
});
} else if mapping.state {
mappings.push(OutputMapping::Single {
outlet: i,
is_state: mapping.state,
});
}
if let Some(last) = mapping.scan {
mappings.push(OutputMapping::Stacked {
@@ -814,6 +823,7 @@ impl Model {
is_state: false,
});
}
output_mappings.push(mappings);
}
@@ -1071,6 +1081,8 @@ impl Model {
/// * `layouter` - Halo2 Layouter.
/// * `inputs` - The values to feed into the circuit.
/// * `vars` - The variables for the circuit.
/// * `witnessed_outputs` - The values to compare against.
/// * `constants` - The constants for the circuit.
pub fn layout(
&self,
mut config: ModelConfig,
@@ -1079,6 +1091,7 @@ impl Model {
inputs: &[ValTensor<Fp>],
vars: &mut ModelVars<Fp>,
witnessed_outputs: &[ValTensor<Fp>],
constants: &mut ConstantsMap<Fp>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
info!("model layout...");
@@ -1104,14 +1117,12 @@ impl Model {
config.base.layout_tables(layouter)?;
config.base.layout_range_checks(layouter)?;
let mut num_rows = 0;
let mut linear_coord = 0;
let mut total_const_size = 0;
let original_constants = constants.clone();
let outputs = layouter.assign_region(
|| "model",
|region| {
let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols);
let mut thread_safe_region = RegionCtx::new_with_constants(
    region,
    0,
    run_args.num_inner_cols,
    original_constants.clone(),
);
// we need to do this as this loop is called multiple times
vars.set_instance_idx(instance_idx);
@@ -1157,29 +1168,17 @@ impl Model {
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
} else if !run_args.output_visibility.is_private() {
for output in &outputs {
thread_safe_region.increment_total_constants(output.num_constants());
}
}
num_rows = thread_safe_region.row();
linear_coord = thread_safe_region.linear_coord();
total_const_size = thread_safe_region.total_constants();
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
thread_safe_region.debug_report();
*constants = thread_safe_region.assigned_constants().clone();
Ok(outputs)
},
)?;
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
num_rows.to_string().blue(),
"rows".blue(),
linear_coord.to_string().yellow(),
total_const_size.to_string().red()
);
)?;
let duration = start_time.elapsed();
trace!("model layout took: {:?}", duration);
@@ -1201,6 +1200,20 @@ impl Model {
.collect();
for (idx, node) in self.graph.nodes.iter() {
debug!("laying out {}: {}", idx, node.as_str(),);
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
region.debug_report();
debug!("input indices: {:?}", node.inputs());
debug!("output scales: {:?}", node.out_scales());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
let mut values: Vec<ValTensor<Fp>> = if !node.is_input() {
node.inputs()
.iter()
@@ -1212,31 +1225,11 @@ impl Model {
// we re-assign inputs, always from the 0 outlet
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
};
debug!("output dims: {:?}", node.out_dims());
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
);
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
"input dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {
@@ -1277,8 +1270,8 @@ impl Model {
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, inputs, model.graph.inputs
"{} iteration(s) in a subgraph with inputs {:?}, sources {:?}, and outputs {:?}",
num_iter, inputs, model.graph.inputs, model.graph.outputs
);
let mut full_results: Vec<ValTensor<Fp>> = vec![];
@@ -1310,6 +1303,7 @@ impl Model {
let res = model.layout_nodes(config, region, &mut subgraph_results)?;
let mut outlets = BTreeMap::new();
let mut stacked_outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res) {
for mapping in mappings {
@@ -1322,25 +1316,42 @@ impl Model {
let stacked_res = full_results[*outlet]
.clone()
.concat_axis(outlet_res.clone(), axis)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
stacked_outlets.insert(outlet, stacked_res);
}
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
full_results = outlets.into_values().collect_vec();
// now extend with stacked elements
let mut pre_stacked_outlets = outlets.clone();
pre_stacked_outlets.extend(stacked_outlets);
let outlets = outlets.into_values().collect_vec();
full_results = pre_stacked_outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
assert_eq!(
input_states.len(),
output_states.len(),
"input and output states must be the same length, got {:?} and {:?}",
input_mappings,
output_mappings
);
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
values[*input_idx] = full_results[output_idx].clone();
assert_eq!(
values[*input_idx].dims(),
outlets[output_idx].dims(),
"input and output dims must be the same, got {:?} and {:?}",
values[*input_idx].dims(),
outlets[output_idx].dims()
);
values[*input_idx] = outlets[output_idx].clone();
}
}
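The bookkeeping above, from the fix that holds stacked outputs in a separate map, keeps two views of a subgraph's outlets: state feedback between scan iterations must use the per-iteration tensors in outlets, while the subgraph's final results should prefer the concatenated versions, so stacked entries are merged in only at the end. A simplified sketch of the merge order:

use std::collections::BTreeMap;

fn main() {
    // per-iteration outlet results (outlet index -> stand-in "tensor")
    let outlets = BTreeMap::from([(0, "state_t"), (1, "scan_t")]);
    // concatenated-over-iterations results, held separately
    let stacked_outlets = BTreeMap::from([(1, "scan_0..t")]);

    // final outputs start from per-iteration values; stacked entries override
    let mut pre_stacked = outlets.clone();
    pre_stacked.extend(stacked_outlets);
    assert_eq!(pre_stacked[&0], "state_t"); // state stays per-iteration
    assert_eq!(pre_stacked[&1], "scan_0..t"); // scan output is the stacked one
}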
@@ -1380,7 +1391,7 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
throw_range_check_error: bool,
witness_gen: bool,
) -> Result<DummyPassRes, Box<dyn Error>> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1399,29 +1410,31 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
let default_value = if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let output_scales = self.graph.get_output_scales()?;
let res = outputs
.iter()
.enumerate()
.map(|(i, output)| {
let mut comparator: ValTensor<Fp> = (0..output.len())
.map(|_| {
if !self.visibility.output.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::random(&mut rand::thread_rng()))
}
})
.collect::<Vec<_>>()
.into();
comparator.reshape(output.dims())?;
let mut tolerance = run_args.tolerance;
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
let mut comparator: ValTensor<Fp> =
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
comparator.reshape(output.dims())?;
dummy_config.layout(
&mut region,
&[output.clone(), comparator],
@@ -1432,7 +1445,7 @@ impl Model {
res?;
} else if !self.visibility.output.is_private() {
for output in &outputs {
region.increment_total_constants(output.num_constants());
region.update_constants(output.create_constants_map());
}
}
@@ -1441,14 +1454,7 @@ impl Model {
// Then number of columns in the circuits
#[cfg(not(target_arch = "wasm32"))]
debug!(
"{} {} {} (coord={}, constants={})",
"model uses".blue(),
region.row().to_string().blue(),
"rows".blue(),
region.linear_coord().to_string().yellow(),
region.total_constants().to_string().red()
);
region.debug_report();
let outputs = outputs
.iter()


@@ -2,6 +2,7 @@ use crate::circuit::modules::polycommit::{PolyCommitChip, PolyCommitConfig};
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
use crate::circuit::modules::Module;
use crate::circuit::region::ConstantsMap;
use crate::tensor::{Tensor, ValTensor};
use halo2_proofs::circuit::Layouter;
use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey};
@@ -211,12 +212,13 @@ impl GraphModules {
layouter: &mut impl Layouter<Fp>,
x: &mut Vec<ValTensor<Fp>>,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
// reserve module 0 for ... modules
// hash the input and replace the constrained cells in the input
let cloned_x = (*x).clone();
x[0] = module
.layout(layouter, &cloned_x, instance_offset.to_owned())
.layout(layouter, &cloned_x, instance_offset.to_owned(), constants)
.unwrap();
for inc in module.instance_increment_input().iter() {
// increment the instance offset to make way for future module layouts
@@ -234,6 +236,7 @@ impl GraphModules {
values: &mut [ValTensor<Fp>],
element_visibility: &Visibility,
instance_offset: &mut usize,
constants: &mut ConstantsMap<Fp>,
) -> Result<(), Error> {
if element_visibility.is_polycommit() && !values.is_empty() {
// concat values and sk to get the inputs
@@ -248,7 +251,7 @@ impl GraphModules {
layouter
.assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(()))
.unwrap();
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
// increment the current index
self.polycommit_idx += 1;
});
@@ -270,7 +273,7 @@ impl GraphModules {
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
// layout the module
inputs.iter_mut().for_each(|x| {
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
});
// replace the inputs with the outputs
values.iter_mut().enumerate().for_each(|(i, x)| {


@@ -14,7 +14,6 @@ use crate::circuit::Op;
use crate::circuit::Unknown;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::new_op_from_onnx;
use crate::tensor::Tensor;
use crate::tensor::TensorError;
use halo2curves::bn256::Fr as Fp;
#[cfg(not(target_arch = "wasm32"))]
@@ -61,20 +60,6 @@ impl Op<Fp> for Rescaled {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
if self.scale.len() != x.len() {
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
}
let mut rescaled_inputs = vec![];
let inputs = &mut x.to_vec();
for (i, ri) in inputs.iter_mut().enumerate() {
let mult_tensor = Tensor::from([Fp::from(self.scale[i].1 as u64)].into_iter());
let res = (ri.clone() * mult_tensor)?;
rescaled_inputs.push(res);
}
Op::<Fp>::f(&*self.inner, &rescaled_inputs)
}
fn as_string(&self) -> String {
format!("RESCALED INPUT ({})", self.inner.as_string())
@@ -215,13 +200,6 @@ impl Op<Fp> for RebaseScale {
fn as_any(&self) -> &dyn std::any::Any {
self
}
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
Ok(res)
}
fn as_string(&self) -> String {
format!(
@@ -389,13 +367,6 @@ impl From<Box<dyn Op<Fp>>> for SupportedOp {
}
impl Op<Fp> for SupportedOp {
fn f(
&self,
inputs: &[Tensor<Fp>],
) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
self.as_op().f(inputs)
}
fn layout(
&self,
config: &mut crate::circuit::BaseConfig<Fp>,


@@ -509,7 +509,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterND {
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -545,7 +545,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherND {
batch_dims,
indices: Some(c.raw_values.map(|x| x as usize)),
@@ -582,7 +582,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
@@ -734,6 +734,19 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::Sum { axes })
}
"Reduce<MeanOfSquares>" => {
if inputs.len() != 1 {
return Err(Box::new(GraphError::InvalidDims(
idx,
"mean of squares".to_string(),
)));
};
let op = load_op::<Reduce>(node.op(), idx, node.op().name().to_string())?;
let axes = op.axes.into_iter().collect();
SupportedOp::Linear(PolyOp::MeanOfSquares { axes })
}
"Max" => {
// Extract the max value
// first find the input that is a constant
@@ -1072,8 +1085,12 @@ pub fn new_op_from_onnx(
}
};
let in_scale = inputs[0].out_scales()[0];
let max_scale = std::cmp::max(scales.get_max(), in_scale);
SupportedOp::Hybrid(HybridOp::Softmax {
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
input_scale: scale_to_multiplier(in_scale).into(),
output_scale: scale_to_multiplier(max_scale).into(),
axes: softmax_op.axes.to_vec(),
})
}
@@ -1102,17 +1119,7 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1120,26 +1127,10 @@ pub fn new_op_from_onnx(
};
let kernel_shape = &pool_spec.kernel_shape;
let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams("kernel".to_string())));
};
SupportedOp::Hybrid(HybridOp::MaxPool2d {
SupportedOp::Hybrid(HybridOp::MaxPool {
padding,
stride: (stride_h, stride_w),
pool_dims: (kernel_height, kernel_width),
stride: stride.to_vec(),
pool_dims: kernel_shape.to_vec(),
})
}
"Ceil" => SupportedOp::Nonlinear(LookupOp::Ceil {
@@ -1161,7 +1152,7 @@ pub fn new_op_from_onnx(
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
deleted_indices.push(1);
if c.raw_values.len() > 1 {
unimplemented!("only support scalar pow")
}
@@ -1201,15 +1192,7 @@ pub fn new_op_from_onnx(
}
let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => {
if s.len() == 1 {
(s[0], s[0])
} else if s.len() == 2 {
(s[0], s[1])
} else {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
}
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
@@ -1217,17 +1200,7 @@ pub fn new_op_from_onnx(
let padding = match &conv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
@@ -1282,33 +1255,20 @@ pub fn new_op_from_onnx(
}
let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => (s[0], s[1]),
Some(s) => s.to_vec(),
None => {
return Err(Box::new(GraphError::MissingParams("strides".to_string())));
}
};
let padding = match &deconv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};
let output_padding: (usize, usize) =
(deconv_node.adjustments[0], deconv_node.adjustments[1]);
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
let bias_scale = input_scales[2];
@@ -1327,7 +1287,7 @@ pub fn new_op_from_onnx(
SupportedOp::Linear(PolyOp::DeConv {
padding,
output_padding,
output_padding: deconv_node.adjustments.to_vec(),
stride,
})
}
@@ -1428,46 +1388,17 @@ pub fn new_op_from_onnx(
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
if b.len() == 2 && a.len() == 2 {
[(b[0], b[1]), (a[0], a[1])]
} else if b.len() == 1 && a.len() == 1 {
[(b[0], b[0]), (a[0], a[0])]
} else if b.len() == 1 && a.len() == 2 {
[(b[0], b[0]), (a[0], a[1])]
} else if b.len() == 2 && a.len() == 1 {
[(b[0], b[1]), (a[0], a[0])]
} else {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(Box::new(GraphError::MissingParams("padding".to_string())));
}
};
let kernel_shape = &pool_spec.kernel_shape;
let (stride_h, stride_w) = if stride.len() == 1 {
(1, stride[0])
} else if stride.len() == 2 {
(stride[0], stride[1])
} else {
return Err(Box::new(GraphError::MissingParams("stride".to_string())));
};
let (kernel_height, kernel_width) = if kernel_shape.len() == 1 {
(1, kernel_shape[0])
} else if kernel_shape.len() == 2 {
(kernel_shape[0], kernel_shape[1])
} else {
return Err(Box::new(GraphError::MissingParams(
"kernel shape".to_string(),
)));
};
SupportedOp::Hybrid(HybridOp::SumPool {
padding,
stride: (stride_h, stride_w),
kernel_shape: (kernel_height, kernel_width),
stride: stride.to_vec(),
kernel_shape: pool_spec.kernel_shape.to_vec(),
normalized: sumpool_node.normalize,
})
}
@@ -1494,29 +1425,7 @@ pub fn new_op_from_onnx(
)));
}
let padding_len = pad_node.pads.len();
// we only support symmetrical padding that affects the last 2 dims (height and width params)
for (i, pad_params) in pad_node.pads.iter().enumerate() {
if (i < padding_len - 2) && ((pad_params.0 != 0) || (pad_params.1 != 0)) {
return Err(Box::new(GraphError::MisformedParams(
"ezkl currently only supports padding height and width dimensions"
.to_string(),
)));
}
}
let padding = [
(
pad_node.pads[padding_len - 2].0,
pad_node.pads[padding_len - 1].0,
),
(
pad_node.pads[padding_len - 2].1,
pad_node.pads[padding_len - 1].1,
),
];
SupportedOp::Linear(PolyOp::Pad(padding))
SupportedOp::Linear(PolyOp::Pad(pad_node.pads.to_vec()))
}
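// Sketch: Pad now forwards one (before, after) pair per axis rather than
// restricting padding to the trailing height/width dims, so output dims are:
fn padded_dims_sketch(dims: &[usize], pads: &[(usize, usize)]) -> Vec<usize> {
    dims.iter().zip(pads.iter()).map(|(d, (b, a))| d + b + a).collect()
}
// padded_dims_sketch(&[1, 3, 32, 32], &[(0, 0), (0, 0), (1, 1), (2, 2)])
// == vec![1, 3, 34, 36]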
"RmAxis" | "Reshape" | "AddAxis" => {
// Extract the slope layer hyperparams
View File
@@ -346,7 +346,7 @@ pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
pub instance: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
/// Get instance col
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
if let Some(instance) = &self.instance {
View File
@@ -23,7 +23,6 @@
)]
// we allow this for our dynamic range based indexing scheme
#![allow(clippy::single_range_in_vec_init)]
#![feature(round_ties_even)]
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
//!
@@ -115,6 +114,12 @@ pub enum Commitments {
IPA,
}
impl From<Option<Commitments>> for Commitments {
fn from(value: Option<Commitments>) -> Self {
value.unwrap_or(Commitments::KZG)
}
}
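// With the From<Option<_>> impl above, settings files that predate the
// commitment field deserialize to None and fall back to KZG, which is what
// keeps older artifacts loadable:
let legacy = Commitments::from(None);
let explicit = Commitments::from(Some(Commitments::IPA));
assert!(matches!(legacy, Commitments::KZG));
assert!(matches!(explicit, Commitments::IPA));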
impl FromStr for Commitments {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
@@ -215,7 +220,7 @@ pub struct RunArgs {
pub check_mode: CheckMode,
/// commitment scheme
#[arg(long, default_value = "kzg")]
pub commitment: Commitments,
pub commitment: Option<Commitments>,
}
impl Default for RunArgs {
@@ -235,7 +240,7 @@ impl Default for RunArgs {
div_rebasing: false,
rebase_frac_zero_constants: false,
check_mode: CheckMode::UNSAFE,
commitment: Commitments::KZG,
commitment: None,
}
}
}
View File
@@ -1,4 +1,8 @@
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::CircuitSize;
use crate::pfsys::{Snark, SnarkWitness};
#[cfg(not(target_arch = "wasm32"))]
use colored_json::ToColoredJson;
use halo2_proofs::circuit::AssignedCell;
use halo2_proofs::plonk::{self};
use halo2_proofs::{
@@ -16,6 +20,8 @@ use halo2_wrong_ecc::{
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
#[cfg(not(target_arch = "wasm32"))]
use log::debug;
use log::trace;
use rand::rngs::OsRng;
use snark_verifier::loader::native::NativeLoader;
@@ -193,6 +199,23 @@ impl AggregationConfig {
let main_gate_config = MainGate::<F>::configure(meta);
let range_config =
RangeChip::<F>::configure(meta, &main_gate_config, composition_bits, overflow_bits);
#[cfg(not(target_arch = "wasm32"))]
{
let circuit_size = CircuitSize::from_cs(meta, 23);
// not wasm
debug!(
"circuit size: \n {}",
circuit_size
.as_json()
.unwrap()
.to_colored_json_auto()
.unwrap()
);
}
AggregationConfig {
main_gate_config,
range_config,
View File
@@ -33,6 +33,7 @@ use tokio::runtime::Runtime;
type PyFelt = String;
/// pyclass representing an enum
#[pyclass]
#[derive(Debug, Clone)]
enum PyTestDataSource {
@@ -51,14 +52,17 @@ impl From<PyTestDataSource> for TestDataSource {
}
}
/// pyclass containing the struct used for G1
/// pyclass containing the struct used for G1; this is mostly a helper class
#[pyclass]
#[derive(Debug, Clone)]
struct PyG1 {
#[pyo3(get, set)]
/// Field Element representing x
x: PyFelt,
#[pyo3(get, set)]
/// Field Element representing y
y: PyFelt,
/// Field Element representing z
#[pyo3(get, set)]
z: PyFelt,
}
@@ -134,39 +138,59 @@ impl pyo3::ToPyObject for PyG1Affine {
}
}
/// pyclass containing the struct used for run_args
/// Python class containing the struct used for run_args
///
/// Returns
/// -------
/// PyRunArgs
///
#[pyclass]
#[derive(Clone)]
struct PyRunArgs {
#[pyo3(get, set)]
/// float: The tolerance for error on model outputs
pub tolerance: f32,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing inputs
pub input_scale: crate::Scale,
#[pyo3(get, set)]
/// int: The denominator in the fixed point representation used when quantizing parameters
pub param_scale: crate::Scale,
#[pyo3(get, set)]
/// int: If the scale is ever > scale_rebase_multiplier * input_scale, then the scale is rebased to input_scale (this is a more advanced parameter, use with caution)
pub scale_rebase_multiplier: u32,
#[pyo3(get, set)]
/// list[int]: The min and max elements in the lookup table input column
pub lookup_range: crate::circuit::table::Range,
#[pyo3(get, set)]
/// int: The log_2 number of rows
pub logrows: u32,
#[pyo3(get, set)]
/// int: The number of inner columns used for the lookup table
pub num_inner_cols: usize,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub input_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub output_visibility: Visibility,
#[pyo3(get, set)]
/// string: accepts `public`, `private`, `fixed`, `hashed/public`, `hashed/private`, `polycommit`
pub param_visibility: Visibility,
#[pyo3(get, set)]
/// list[tuple[str, int]]: Hand-written parser for graph variables, e.g. batch_size=1
pub variables: Vec<(String, usize)>,
#[pyo3(get, set)]
/// bool: Rebase the scale using lookup table for division instead of using a range check
pub div_rebasing: bool,
#[pyo3(get, set)]
/// bool: Should constants with 0.0 fraction be rebased to scale 0
pub rebase_frac_zero_constants: bool,
#[pyo3(get, set)]
/// str: check mode, accepts `safe`, `unsafe`
pub check_mode: CheckMode,
#[pyo3(get, set)]
/// str: commitment type, accepts `kzg`, `ipa`
pub commitment: PyCommitments,
}
@@ -197,7 +221,7 @@ impl From<PyRunArgs> for RunArgs {
div_rebasing: py_run_args.div_rebasing,
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
check_mode: py_run_args.check_mode,
commitment: py_run_args.commitment.into(),
commitment: Some(py_run_args.commitment.into()),
}
}
}
@@ -226,7 +250,7 @@ impl Into<PyRunArgs> for RunArgs {
#[pyclass]
#[derive(Debug, Clone)]
/// Pyclass marking the type of commitment
/// pyclass representing an enum, denoting the type of commitment
pub enum PyCommitments {
/// KZG commitment
KZG,
@@ -234,6 +258,16 @@ pub enum PyCommitments {
IPA,
}
impl From<Option<Commitments>> for PyCommitments {
fn from(commitment: Option<Commitments>) -> Self {
match commitment {
Some(Commitments::KZG) => PyCommitments::KZG,
Some(Commitments::IPA) => PyCommitments::IPA,
None => PyCommitments::KZG,
}
}
}
impl From<PyCommitments> for Commitments {
fn from(py_commitments: PyCommitments) -> Self {
match py_commitments {
@@ -263,7 +297,19 @@ impl FromStr for PyCommitments {
}
}
/// Converts a felt to big endian
/// Converts a field element hex string to big endian
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
///
/// Returns
/// -------
/// str
/// field element represented as a string
///
#[pyfunction(signature = (
felt,
))]
@@ -273,22 +319,45 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
}
/// Converts a field element hex string to an integer
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// Returns
/// -------
/// int
///
#[pyfunction(signature = (
array,
felt,
))]
fn felt_to_int(array: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
fn felt_to_int(felt: PyFelt) -> PyResult<i128> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i128(felt);
Ok(int_rep)
}
/// Converts a field eleement hex string to a floating point number
/// Converts a field element hex string to a floating point number
///
/// Arguments
/// -------
/// felt: str
/// The field element represented as a string
///
/// scale: int
/// The scaling factor used to convert the field element into a floating point representation
///
/// Returns
/// -------
/// float
///
#[pyfunction(signature = (
array,
felt,
scale
))]
fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&array);
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_i128(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
@@ -296,9 +365,23 @@ fn felt_to_float(array: PyFelt, scale: crate::Scale) -> PyResult<f64> {
}
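// Worked example of the decoding above: with scale = 7 the multiplier is
// 2^7 = 128, so a field element whose signed integer representation is 161
// decodes to 161 / 128 = 1.2578125. float_to_felt below is presumably the
// inverse, rounding input * 2^scale into the field.
let int_rep: i128 = 161;
let multiplier = 2f64.powi(7);
assert_eq!(int_rep as f64 / multiplier, 1.2578125);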
/// Converts a floating point element to a field element hex string
///
/// Arguments
/// -------
/// input: float
/// The floating point number to convert into a field element
///
/// scale: int
/// The scaling factor used to quantize the float into a field element
///
/// Returns
/// -------
/// str
/// The field element represented as a string
///
#[pyfunction(signature = (
input,
scale
input,
scale
))]
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
@@ -308,9 +391,20 @@ fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
}
/// Converts a buffer to vector of field elements
///
/// Arguments
/// -------
/// buffer: list[int]
/// List of integers representing a buffer
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
buffer
))]
))]
fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
fn u8_array_to_u128_le(arr: [u8; 16]) -> u128 {
let mut n: u128 = 0;
@@ -369,9 +463,20 @@ fn buffer_to_felts(buffer: Vec<u8>) -> PyResult<Vec<String>> {
}
/// Generate a poseidon hash.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// Returns
/// -------
/// list[str]
/// List of field elements represented as strings
///
#[pyfunction(signature = (
message,
))]
))]
fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
let message: Vec<Fr> = message
.iter()
@@ -392,12 +497,31 @@ fn poseidon_hash(message: Vec<PyFelt>) -> PyResult<Vec<PyFelt>> {
}
/// Generate a kzg commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn kzg_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -432,12 +556,31 @@ fn kzg_commit(
}
/// Generate an ipa commitment.
///
/// Arguments
/// -------
/// message: list[str]
/// List of field elements represented as strings
///
/// vk_path: str
/// Path to the verification key
///
/// settings_path: str
/// Path to the settings file
///
/// srs_path: str
/// Path to the Structured Reference String (SRS) file
///
/// Returns
/// -------
/// list[PyG1Affine]
///
#[pyfunction(signature = (
message,
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
srs_path=None
))]
))]
fn ipa_commit(
message: Vec<PyFelt>,
vk_path: PathBuf,
@@ -472,10 +615,19 @@ fn ipa_commit(
}
/// Swap the commitments in a proof
///
/// Arguments
/// -------
/// proof_path: str
/// Path to the proof file
///
/// witness_path: str
/// Path to the witness file
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
witness_path=PathBuf::from(DEFAULT_WITNESS),
))]
))]
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
@@ -484,11 +636,27 @@ fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResul
}
/// Generates a vk from a pk for a model circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// circuit_settings_path: str
/// Path to the settings file
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK),
circuit_settings_path=PathBuf::from(DEFAULT_SETTINGS),
vk_output_path=PathBuf::from(DEFAULT_VK),
))]
))]
fn gen_vk_from_pk_single(
path_to_pk: PathBuf,
circuit_settings_path: PathBuf,
@@ -510,10 +678,22 @@ fn gen_vk_from_pk_single(
}
/// Generates a vk from a pk for an aggregate circuit and saves it to a file
///
/// Arguments
/// -------
/// path_to_pk: str
/// Path to the proving key
///
/// vk_output_path: str
/// Path to create the vk file
///
/// Returns
/// -------
/// bool
#[pyfunction(signature = (
path_to_pk=PathBuf::from(DEFAULT_PK_AGGREGATED),
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
))]
))]
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
@@ -528,6 +708,17 @@ fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult
}
/// Displays the table as a string in python
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// Returns
/// ---------
/// str
/// Table of the nodes in the onnx file
///
#[pyfunction(signature = (
model = PathBuf::from(DEFAULT_MODEL),
py_run_args = None
@@ -543,7 +734,16 @@ fn table(model: PathBuf, py_run_args: Option<PyRunArgs>) -> PyResult<String> {
}
}
/// generates the srs
/// Generates the Structured Reference String (SRS); use this only for testing purposes
///
/// Arguments
/// ---------
/// srs_path: str
/// Path to create the SRS file
///
/// logrows: int
/// The number of logrows for the SRS file
///
#[pyfunction(signature = (
srs_path,
logrows,
@@ -554,7 +754,26 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
Ok(())
}
/// gets a public srs
/// Gets a public SRS
///
/// Arguments
/// ---------
/// settings_path: str
/// Path to the settings file
///
/// logrows: int
/// The number of logrows for the SRS file
///
/// srs_path: str
/// Path to create the SRS file
///
/// commitment: str
/// Specify the commitment used ("kzg", "ipa")
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
settings_path=PathBuf::from(DEFAULT_SETTINGS),
logrows=None,
@@ -587,7 +806,23 @@ fn get_srs(
Ok(true)
}
/// generates the circuit settings
/// Generates the circuit settings
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx file
///
/// output: str
/// Path to create the settings file
///
/// py_run_args: PyRunArgs
/// PyRunArgs object to initialize the settings
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
output=PathBuf::from(DEFAULT_SETTINGS),
@@ -608,7 +843,38 @@ fn gen_settings(
Ok(true)
}
/// calibrates the circuit settings
/// Calibrates the circuit settings
///
/// Arguments
/// ---------
/// data: str
/// Path to the calibration data
///
/// model: str
/// Path to the onnx file
///
/// settings: str
/// Path to the settings file
///
/// lookup_safety_margin: int
/// The lookup safety margin to use for calibration. If the max lookup is 2^k, the table will be sized for 2^k * lookup_safety_margin. Larger = safer but slower
///
/// scales: list[int]
/// Optional scales to specifically try for calibration
///
/// scale_rebase_multiplier: list[int]
/// Optional scale rebase multipliers to specifically try for calibration. This is the multiplier by which we divide to return to the input scale.
///
/// max_logrows: int
/// Optional max logrows to use for calibration
///
/// only_range_check_rebase: bool
/// Check ranges when rebasing
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data = PathBuf::from(DEFAULT_CALIBRATION_FILE),
model = PathBuf::from(DEFAULT_MODEL),
@@ -650,7 +916,30 @@ fn calibrate_settings(
Ok(true)
}
/// runs the forward pass operation
/// Runs the forward pass operation to generate a witness
///
/// Arguments
/// ---------
/// data: str
/// Path to the data file
///
/// model: str
/// Path to the compiled model file
///
/// output: str
/// Path to create the witness file
///
/// vk_path: str
/// Path to the verification key
///
/// srs_path: str
/// Path to the SRS file
///
/// Returns
/// -------
/// dict
/// Python object containing the witness values
///
#[pyfunction(signature = (
data=PathBuf::from(DEFAULT_DATA),
model=PathBuf::from(DEFAULT_MODEL),
@@ -677,7 +966,20 @@ fn gen_witness(
Python::with_gil(|py| Ok(output.to_object(py)))
}
/// mocks the prover
/// Mocks the prover
///
/// Arguments
/// ---------
/// witness: str
/// Path to the witness file
///
/// model: str
/// Path to the compiled model file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -690,7 +992,23 @@ fn mock(witness: PathBuf, model: PathBuf) -> PyResult<bool> {
Ok(true)
}
/// mocks the aggregate prover
/// Mocks the aggregate prover
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the relevant proof files
///
/// logrows: int
/// Number of logrows to use for the aggregation circuit
///
/// split_proofs: bool
/// Indicates whether the accumulated proofs are segments of a larger proof
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
@@ -709,7 +1027,32 @@ fn mock_aggregate(
Ok(true)
}
/// runs the prover on a set of inputs
/// Runs the setup process
///
/// Arguments
/// ---------
/// model: str
/// Path to the compiled model file
///
/// vk_path: str
/// Path to create the verification key file
///
/// pk_path: str
/// Path to create the proving key file
///
/// srs_path: str
/// Path to the SRS file
///
/// witness_path: str
/// Path to the witness file
///
/// disable_selector_compression: bool
/// Whether to disable selector compression
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -742,7 +1085,32 @@ fn setup(
Ok(true)
}
/// runs the prover on a set of inputs
/// Runs the prover on a set of inputs
///
/// Arguments
/// ---------
/// witness: str
/// Path to the witness file
///
/// model: str
/// Path to the compiled model file
///
/// pk_path: str
/// Path to the proving key file
///
/// proof_path: str
/// Path to create the proof file
///
/// proof_type: str
/// Accepts `single`, `for-aggr`
///
/// srs_path: str
/// Path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
witness=PathBuf::from(DEFAULT_WITNESS),
model=PathBuf::from(DEFAULT_MODEL),
@@ -776,7 +1144,29 @@ fn prove(
Python::with_gil(|py| Ok(snark.to_object(py)))
}
/// verifies a given proof
/// Verifies a given proof
///
/// Arguments
/// ---------
/// proof_path: str
/// Path to create the proof file
///
/// settings_path: str
/// Path to the settings file
///
/// vk_path: str
/// Path to the verification key file
///
/// srs_path: str
/// Path to the SRS file
///
/// non_reduced_srs: bool
/// Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -806,6 +1196,38 @@ fn verify(
Ok(true)
}
/// Runs the setup process for an aggregate setup
///
/// Arguments
/// ---------
/// sample_snarks: list[str]
/// List of paths to the various proofs
///
/// vk_path: str
/// Path to create the aggregated VK
///
/// pk_path: str
/// Path to create the aggregated PK
///
/// logrows: int
/// Number of logrows to use
///
/// split_proofs: bool
/// Whether the accumulated proofs are segments of a larger proof
///
/// srs_path: str
/// Path to the SRS file
///
/// disable_selector_compression: bool
/// Whether to disable selector compression
///
/// commitment: str
/// Accepts `kzg`, `ipa`
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
sample_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
@@ -844,6 +1266,23 @@ fn setup_aggregate(
Ok(true)
}
/// Compiles the circuit for use in other steps
///
/// Arguments
/// ---------
/// model: str
/// Path to the onnx model file
///
/// compiled_circuit: str
/// Path to output the compiled circuit
///
/// settings_path: str
/// Path to the settings file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
model=PathBuf::from(DEFAULT_MODEL),
compiled_circuit=PathBuf::from(DEFAULT_COMPILED_CIRCUIT),
@@ -862,7 +1301,41 @@ fn compile_circuit(
Ok(true)
}
/// creates an aggregated proof
/// Creates an aggregated proof
///
/// Arguments
/// ---------
/// aggregation_snarks: list[str]
/// List of paths to the various proofs
///
/// proof_path: str
/// Path to output the aggregated proof
///
/// vk_path: str
/// Path to the VK file
///
/// transcript: str
/// Proof transcript type to be used. `evm` is used by default. `poseidon` is also supported
///
/// logrows: int
/// Logrows used for the aggregation circuit
///
/// check_mode: str
/// Run sanity checks during calculations. Accepts `safe` or `unsafe`
///
/// split_proofs: bool
/// Whether the accumulated proofs are segments of a larger circuit
///
/// srs_path: str
/// Path to the SRS used
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_snarks=vec![PathBuf::from(DEFAULT_PROOF)],
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
@@ -905,7 +1378,32 @@ fn aggregate(
Ok(true)
}
/// verifies and aggregate proof
/// Verifies an aggregate proof
///
/// Arguments
/// ---------
/// proof_path: str
/// The path to the proof file
///
/// vk_path: str
/// The path to the verification key file
///
/// logrows: int
/// Logrows used for the aggregation circuit
///
/// commitment: str
/// Accepts "kzg" or "ipa"
///
/// reduced_srs: bool
/// Whether to reduce the number of SRS logrows to the number of instances rather than the number of logrows used for proofs (only works if the srs were generated in the same ceremony)
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
vk_path=PathBuf::from(DEFAULT_VK),
@@ -938,7 +1436,32 @@ fn verify_aggr(
Ok(true)
}
/// creates an EVM compatible verifier, you will need solc installed in your environment to run this
/// Creates an EVM compatible verifier; you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to create the Solidity verifier
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -971,7 +1494,26 @@ fn create_evm_verifier(
Ok(true)
}
// creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
/// Creates an EVM compatible data attestation verifier; you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// input_data: str
/// The path to the .json data file, which should contain the necessary calldata and account addresses needed to read from all the on-chain view functions that return the data that the network ingests as inputs
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path to create the Solidity verifier
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
input_data=PathBuf::from(DEFAULT_DATA),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
@@ -993,6 +1535,32 @@ fn create_evm_data_attestation(
Ok(true)
}
/// Setup test evm witness
///
/// Arguments
/// ---------
/// data_path: str
/// The path to the .json data file, which should include both the network input (possibly private) and the network output (public input to the proof)
///
/// compiled_circuit_path: str
/// The path to the compiled model file (generated using the compile-circuit command)
///
/// test_data: str
/// For testing purposes only. The optional path to the .json data file that will be generated, containing the on-chain data storage information derived from the data .json file. Should include both the network input (possibly private) and the network output (public input to the proof)
///
/// input_sources: str
/// Where the input data comes from
///
/// output_source: str
/// Where the output data comes from
///
/// rpc_url: str
/// RPC URL for an EVM compatible node, if None, uses Anvil as a local RPC node
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
data_path,
compiled_circuit_path,
@@ -1027,6 +1595,7 @@ fn setup_test_evm_witness(
Ok(true)
}
/// Deploys the Solidity verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_SOL_CODE),
@@ -1059,6 +1628,7 @@ fn deploy_evm(
Ok(true)
}
/// Deploys the Solidity vk verifier
#[pyfunction(signature = (
addr_path,
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
@@ -1091,6 +1661,7 @@ fn deploy_vk_evm(
Ok(true)
}
/// Deploys the Solidity da verifier
#[pyfunction(signature = (
addr_path,
input_data,
@@ -1128,6 +1699,27 @@ fn deploy_da_evm(
Ok(true)
}
/// Verifies an EVM compatible proof; you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to the verifier contract's address
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
///
/// rpc_url: str
/// RPC URL for an Ethereum node, if None will use Anvil but WON'T persist state
///
/// addr_da: str
/// Does the verifier use data attestation?
///
/// addr_vk: str
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
addr_verifier,
proof_path=PathBuf::from(DEFAULT_PROOF),
@@ -1173,7 +1765,35 @@ fn verify_evm(
Ok(true)
}
/// creates an evm compatible aggregate verifier, you will need solc installed in your environment to run this
/// Creates an EVM compatible aggregate verifier; you will need solc installed in your environment to run this
///
/// Arguments
/// ---------
/// aggregation_settings: str
/// The path to the settings file
///
/// vk_path: str
/// The path to load the desired verification key file
///
/// sol_code_path: str
/// The path to the Solidity code
///
/// abi_path: str
/// The path to output the Solidity verifier ABI
///
/// logrows: int
/// Number of logrows used during aggregated setup
///
/// srs_path: str
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
aggregation_settings=vec![PathBuf::from(DEFAULT_PROOF)],
vk_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
View File
@@ -60,6 +60,9 @@ pub enum TensorError {
/// Unsupported operation
#[error("Unsupported operation on a tensor type")]
Unsupported,
/// Overflow
#[error("Unsigned integer overflow or underflow error in op: {0}")]
Overflow(String),
}
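// A hypothetical sketch of how the new Overflow variant might be raised via
// checked unsigned arithmetic (not necessarily the crate's exact call sites):
fn checked_sub_sketch(a: usize, b: usize) -> Result<usize, TensorError> {
    a.checked_sub(b)
        .ok_or_else(|| TensorError::Overflow("sub".to_string()))
}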
/// The (inner) type of tensor elements.
@@ -378,7 +381,7 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
{
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<Value<F>> {
let mut output = Vec::new();
for (_, x) in value.iter().enumerate() {
for x in value.iter() {
output.push(x.value_field().evaluate());
}
Tensor::new(Some(&output), value.dims()).unwrap()
@@ -434,6 +437,18 @@ impl<F: PrimeField + TensorType + Clone> From<Tensor<i128>> for Tensor<Value<F>>
}
}
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::FromParallelIterator<T> for Tensor<T>
{
fn from_par_iter<I>(par_iter: I) -> Self
where
I: maybe_rayon::iter::IntoParallelIterator<Item = T>,
{
let inner: Vec<T> = par_iter.into_par_iter().collect();
Tensor::new(Some(&inner), &[inner.len()]).unwrap()
}
}
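// With the FromParallelIterator impl above, a parallel iterator can collect
// directly into a flat 1-D Tensor; a sketch, assuming maybe_rayon re-exports
// rayon's prelude as usual:
use maybe_rayon::prelude::*;
let t: Tensor<i32> = (0..1024i32).into_par_iter().map(|x| x * x).collect();
assert_eq!(t.dims(), &[1024]);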
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
maybe_rayon::iter::IntoParallelIterator for Tensor<T>
{
@@ -922,6 +937,7 @@ impl<T: Clone + TensorType> Tensor<T> {
pub fn move_axis(&mut self, source: usize, destination: usize) -> Result<Self, TensorError> {
assert!(source < self.dims.len());
assert!(destination < self.dims.len());
let mut new_dims = self.dims.clone();
new_dims.remove(source);
new_dims.insert(destination, self.dims[source]);
@@ -953,6 +969,8 @@ impl<T: Clone + TensorType> Tensor<T> {
old_coord[source - 1] = *c;
} else if (i < source && source < destination)
|| (i < destination && source > destination)
|| (i > source && source > destination)
|| (i > destination && source < destination)
{
old_coord[i] = *c;
} else if i > source && source < destination {
@@ -965,7 +983,10 @@ impl<T: Clone + TensorType> Tensor<T> {
));
}
}
output.set(&coord, self.get(&old_coord));
let value = self.get(&old_coord);
output.set(&coord, value);
}
Ok(output)
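// Worked example of move_axis: on dims [2, 3, 4], moving axis 0 to position 2
// removes the leading 2 and reinserts it at the back, so the result has dims
// [3, 4, 2] and output[j][k][i] == input[i][j][k] for every coordinate.
let vals: Vec<i32> = (0..24).collect();
let t = Tensor::new(Some(&vals), &[2, 3, 4]).unwrap();
let moved = t.clone().move_axis(0, 2).unwrap();
assert_eq!(moved.dims(), &[3, 4, 2]);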
File diff suppressed because it is too large

View File
@@ -1,8 +1,12 @@
use core::{iter::FilterMap, slice::Iter};
use crate::circuit::region::ConstantsMap;
use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, plonk::Instance};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -51,6 +55,24 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
}
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
/// Returns the inner cell of the [ValType].
pub fn cell(&self) -> Option<Cell> {
match self {
ValType::PrevAssigned(cell) => Some(cell.cell()),
ValType::AssignedConstant(cell, _) => Some(cell.cell()),
_ => None,
}
}
/// Returns the assigned cell of the [ValType].
pub fn assigned_cell(&self) -> Option<AssignedCell<F, F>> {
match self {
ValType::PrevAssigned(cell) => Some(cell.clone()),
ValType::AssignedConstant(cell, _) => Some(cell.clone()),
_ => None,
}
}
/// Returns true if the value is previously assigned.
pub fn is_prev_assigned(&self) -> bool {
matches!(
@@ -293,7 +315,13 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
}
}
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i128].
pub fn from_i128_tensor(t: Tensor<i128>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(i128_to_felt(x))));
inner.into()
}
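// Sketch of lifting quantized witness values into a ValTensor via the new
// helper, assuming Fr is the BN254 scalar field used elsewhere in the crate:
let quantized = Tensor::new(Some(&[-3i128, 0, 7]), &[3]).unwrap();
let vals: ValTensor<Fr> = ValTensor::from_i128_tensor(quantized);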
/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
pub fn new_instance(
cs: &mut ConstraintSystem<F>,
@@ -428,10 +456,37 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
/// Returns an iterator over the constants in the [ValTensor].
pub fn num_constants(&self) -> usize {
pub fn create_constants_map_iterator(
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
ValTensor::Instance { .. } => 0,
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
}),
ValTensor::Instance { .. } => {
unreachable!("Instance tensors do not have constants")
}
}
}
/// Returns a map of the constants in the [ValTensor].
pub fn create_constants_map(&self) -> ConstantsMap<F> {
match self {
ValTensor::Value { inner, .. } => inner
.par_iter()
.filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => ConstantsMap::new(),
}
}
@@ -824,13 +879,13 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
};
Ok(())
}
/// Calls `pad` on the inner [Tensor].
pub fn pad(&mut self, padding: [(usize, usize); 2]) -> Result<(), TensorError> {
/// Calls `pad_spatial_dims` on the inner [Tensor].
pub fn pad(&mut self, padding: Vec<(usize, usize)>, offset: usize) -> Result<(), TensorError> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = pad(v, padding)?;
*v = pad(v, padding, offset)?;
*d = v.dims().to_vec();
}
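// A sketch of the new call shape, assuming `offset` counts the leading axes
// (e.g. batch and channel) that the padding pairs do not cover, so that
//     vt.pad(vec![(1, 1), (2, 2)], 2)
// pads H by (1, 1) and W by (2, 2), taking dims [1, 3, 32, 32] to [1, 3, 34, 36].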
ValTensor::Instance { .. } => {
View File
@@ -2,7 +2,7 @@ use std::collections::HashSet;
use log::{debug, error, warn};
use crate::circuit::CheckMode;
use crate::circuit::{region::ConstantsMap, CheckMode};
use super::*;
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
@@ -289,9 +289,10 @@ impl VarTensor {
&self,
region: &mut Region<F>,
offset: usize,
coord: usize,
constant: F,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset);
let (x, y, z) = self.cartesian_coord(offset + coord);
match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice_from_constant(|| "constant", advices[x][y], z, constant)
@@ -304,33 +305,28 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<&usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
let mut res: ValTensor<F> = match values {
ValTensor::Instance { .. } => {
unimplemented!("cannot assign instance to advice columns with omissions")
}
ValTensor::Value { inner: v, .. } => Ok::<_, halo2_proofs::plonk::Error>(
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
v.enum_map(|coord, k| {
if omissions.contains(&coord) {
return Ok(k);
return Ok::<_, halo2_proofs::plonk::Error>(k);
}
let cell = self.assign_value(region, offset, k.clone(), assigned_coord)?;
let cell =
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
assigned_coord += 1;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
Ok::<_, halo2_proofs::plonk::Error>(cell)
})?
.into(),
),
@@ -340,11 +336,12 @@ impl VarTensor {
}
/// Assigns [ValTensor] to the columns of the inner tensor.
pub fn assign<F: PrimeField + TensorType + PartialOrd>(
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut res: ValTensor<F> = match values {
ValTensor::Instance {
@@ -382,14 +379,7 @@ impl VarTensor {
},
ValTensor::Value { inner: v, .. } => Ok(v
.enum_map(|coord, k| {
let cell = self.assign_value(region, offset, k.clone(), coord)?;
match k {
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
ValType::AssignedConstant(cell, f),
),
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
_ => Ok(ValType::PrevAssigned(cell)),
}
self.assign_value(region, offset, k.clone(), coord, constants)
})?
.into()),
}?;
@@ -399,13 +389,16 @@ impl VarTensor {
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell of the next column and creating a copy constraint between the two.
pub fn dummy_assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn dummy_assign_with_duplication<
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
>(
&self,
row: usize,
offset: usize,
values: &ValTensor<F>,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
match values {
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
ValTensor::Value { inner: v, dims , ..} => {
@@ -430,21 +423,24 @@ impl VarTensor {
// duplicates every nth element to adjust for column overflow
let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();
let constants_map = res.create_constants_map();
constants.extend(constants_map);
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
res.set_scale(values.scale());
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
/// Duplication occurs by copying the last cell of the column to the first cell of the next column and creating a copy constraint between the two.
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
row: usize,
@@ -452,7 +448,8 @@ impl VarTensor {
values: &ValTensor<F>,
check_mode: &CheckMode,
single_inner_col: bool,
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
let mut prev_cell = None;
match values {
@@ -494,7 +491,7 @@ impl VarTensor {
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
};
let cell = self.assign_value(region, offset, k.clone(), coord * step)?;
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
if single_inner_col {
if z == 0 {
@@ -502,28 +499,23 @@ impl VarTensor {
prev_cell = Some(cell.clone());
} else if coord > 0 && z == 0 && single_inner_col {
if let Some(prev_cell) = prev_cell.as_ref() {
region.constrain_equal(prev_cell.cell(),cell.cell())?;
let cell = cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
let prev_cell = prev_cell.cell().ok_or({
error!("Error getting cell: {:?}", (x,y));
halo2_proofs::plonk::Error::Synthesis})?;
region.constrain_equal(prev_cell,cell)?;
} else {
error!("Error copy-constraining previous value: {:?}", (x,y));
return Err(halo2_proofs::plonk::Error::Synthesis);
}
}}
match k {
ValType::Constant(f) => {
Ok(ValType::AssignedConstant(cell, f))
},
ValType::AssignedConstant(_, f) => {
Ok(ValType::AssignedConstant(cell, f))
},
_ => {
Ok(ValType::PrevAssigned(cell))
}
}
Ok(cell)
})?.into()};
let total_used_len = res.len();
let total_constants = res.num_constants();
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
res.reshape(dims).unwrap();
@@ -542,42 +534,61 @@ impl VarTensor {
)};
}
Ok((res, total_used_len, total_constants))
Ok((res, total_used_len))
}
}
}
fn assign_value<F: PrimeField + TensorType + PartialOrd>(
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
&self,
region: &mut Region<F>,
offset: usize,
k: ValType<F>,
coord: usize,
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
constants: &mut ConstantsMap<F>,
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
let (x, y, z) = self.cartesian_coord(offset + coord);
match k {
let res = match k {
ValType::Value(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
region.assign_advice(|| "k", advices[x][y], z, || v)
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
}
_ => unimplemented!(),
},
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => match &self {
ValType::PrevAssigned(v) => match &self {
VarTensor::Advice { inner: advices, .. } => {
v.copy_advice(|| "k", region, advices[x][y], z)
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
}
_ => {
error!("PrevAssigned is only supported for advice columns");
Err(halo2_proofs::plonk::Error::Synthesis)
_ => unimplemented!(),
},
ValType::AssignedConstant(v, val) => match &self {
VarTensor::Advice { inner: advices, .. } => {
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
}
_ => unimplemented!(),
},
ValType::AssignedValue(v) => match &self {
VarTensor::Advice { inner: advices, .. } => region
.assign_advice(|| "k", advices[x][y], z, || v)
.map(|a| a.evaluate()),
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
region
.assign_advice(|| "k", advices[x][y], z, || v)?
.evaluate(),
),
_ => unimplemented!(),
},
ValType::Constant(v) => self.assign_constant(region, offset + coord, v),
}
ValType::Constant(v) => {
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
let value = ValType::AssignedConstant(
self.assign_constant(region, offset, coord, v)?,
v,
);
e.insert(value.clone());
value
} else {
let cell = constants.get(&v).unwrap();
self.assign_value(region, offset, cell.clone(), coord, constants)?
}
}
};
Ok(res)
}
}
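// The constant-deduplication pattern above, in miniature (a sketch, not the
// crate's API): the first occurrence of a constant is assigned to a fresh
// cell and cached; every later occurrence reuses the cached assignment, which
// the AssignedConstant path turns into a copy_advice rather than a new cell.
use std::collections::{hash_map::Entry, HashMap};

fn assign_once_sketch<K: std::hash::Hash + Eq, V: Clone>(
    cache: &mut HashMap<K, V>,
    key: K,
    assign_fresh: impl FnOnce() -> V,
) -> V {
    match cache.entry(key) {
        Entry::Vacant(e) => e.insert(assign_fresh()).clone(),
        Entry::Occupied(e) => e.get().clone(),
    }
}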
View File
@@ -8,12 +8,14 @@ use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use console_error_panic_hook;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
@@ -33,11 +35,10 @@ use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
use console_error_panic_hook;
#[cfg(feature = "web")]
pub use wasm_bindgen_rayon::init_thread_pool;
@@ -336,8 +337,94 @@ pub fn verify(
let orig_n = 1 << circuit_settings.run_args.logrows;
let commitment = circuit_settings.run_args.commitment.into();
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match circuit_settings.run_args.commitment {
let result = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = KZGSingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierSHPLONK<'_, Bn256>,
KZGCommitmentScheme<Bn256>,
KZGSingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
Commitments::IPA => {
let params: ParamsIPA<_> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
let strategy = IPASingleStrategy::new(params.verifier_params());
match proof.transcript_type {
TranscriptType::EVM => verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
EvmTranscript<G1Affine, _, _, _>,
>(&proof, &params, &vk, strategy, orig_n),
TranscriptType::Poseidon => {
verify_proof_circuit::<
VerifierIPA<_>,
IPACommitmentScheme<G1Affine>,
IPASingleStrategy<_>,
_,
PoseidonTranscript<NativeLoader, _>,
>(&proof, &params, &vk, strategy, orig_n)
}
}
}
};
match result {
Ok(_) => Ok(true),
Err(e) => Err(JsError::new(&format!("{}", e))),
}
}
#[wasm_bindgen]
#[allow(non_snake_case)]
/// Verify aggregate proof in browser using wasm
pub fn verifyAggr(
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
vk: wasm_bindgen::Clamped<Vec<u8>>,
logrows: u64,
srs: wasm_bindgen::Clamped<Vec<u8>>,
commitment: &str,
) -> Result<bool, JsError> {
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
let mut reader = std::io::BufReader::new(&vk[..]);
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
&mut reader,
halo2_proofs::SerdeFormat::RawBytes,
(),
)
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
let orig_n = 1 << logrows;
let mut reader = std::io::BufReader::new(&srs[..]);
let result = match commit {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
@@ -436,8 +523,9 @@ pub fn prove(
// read in kzg params
let mut reader = std::io::BufReader::new(&srs[..]);
let commitment = circuit.settings().run_args.commitment.into();
// creates and verifies the proof
let proof = match circuit.settings().run_args.commitment {
let proof = match commitment {
Commitments::KZG => {
let params: ParamsKZG<Bn256> =
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
View File
@@ -122,7 +122,7 @@ mod native_tests {
let settings: GraphSettings = serde_json::from_str(&settings).unwrap();
let logrows = settings.run_args.logrows;
download_srs(logrows, settings.run_args.commitment);
download_srs(logrows, settings.run_args.commitment.into());
}
fn mv_test_(test_dir: &str, test: &str) {
@@ -200,7 +200,7 @@ mod native_tests {
"1l_tiny_div",
];
const TESTS: [&str; 91] = [
const TESTS: [&str; 92] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -296,6 +296,7 @@ mod native_tests {
"reducel1",
"reducel2", // 89
"1l_lppool",
"lstm_large", // 91
];
const WASM_TESTS: [&str; 46] = [
@@ -534,7 +535,7 @@ mod native_tests {
}
});
seq!(N in 0..=90 {
seq!(N in 0..=91 {
#(#[test_case(TESTS[N])])*
#[ignore]
@@ -622,7 +623,7 @@ mod native_tests {
#(#[test_case(TESTS[N])])*
fn mock_large_batch_public_outputs_(test: &str) {
// currently variable output rank is not supported in ONNX
if test != "gather_nd" {
if test != "gather_nd" && test != "lstm_large" {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
@@ -1970,7 +1971,7 @@ mod native_tests {
.expect("failed to parse settings file");
// get_srs for the graph_settings_num_instances
download_srs(1, graph_settings.run_args.commitment);
download_srs(1, graph_settings.run_args.commitment.into());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
View File
@@ -56,33 +56,38 @@ mod py_tests {
// source .env/bin/activate
// pip install -r requirements.txt
// maturin develop --release --features python-bindings
// first install tf2onnx as it has a protobuf conflict with onnx
let status = Command::new("pip")
.args(["install", "tf2onnx==1.16.1"])
.status()
.expect("failed to execute process");
assert!(status.success());
// now install torch, pandas, numpy, seaborn, jupyter
let status = Command::new("pip")
.args([
"install",
"torch-geometric==2.5.0",
"torch==2.0.1",
"torchvision==0.15.2",
"pandas==2.0.3",
"numpy==1.23",
"seaborn==0.12.2",
"jupyter==1.0.0",
"onnx==1.14.0",
"kaggle==1.5.15",
"py-solc-x==1.1.1",
"web3==6.5.0",
"librosa==0.10.0.post2",
"keras==2.12.0",
"tensorflow==2.12.0",
"tensorflow-datasets==4.9.3",
"tf2onnx==1.14.0",
"pytorch-lightning==2.0.6",
"torch-geometric==2.5.2",
"torch==2.2.2",
"torchvision==0.17.2",
"pandas==2.2.1",
"numpy==1.26.4",
"seaborn==0.13.2",
"notebook==7.1.2",
"nbconvert==7.16.3",
"onnx==1.16.0",
"kaggle==1.6.8",
"py-solc-x==2.0.2",
"web3==6.16.0",
"librosa==0.10.1",
"keras==3.1.1",
"tensorflow==2.16.1",
"tensorflow-datasets==4.9.4",
"pytorch-lightning==2.2.1",
"sk2torch==1.2.0",
"scikit-learn==1.3.1",
"xgboost==1.7.6",
"hummingbird-ml==0.4.9",
"lightgbm==4.0.0",
"scikit-learn==1.4.1.post1",
"xgboost==2.0.3",
"hummingbird-ml==0.4.11",
"lightgbm==4.3.0",
])
.status()
.expect("failed to execute process");
View File
@@ -11,7 +11,7 @@ mod wasm32 {
bufferToVecOfFelt, compiledCircuitValidation, encodeVerifierCalldata, feltToBigEndian,
feltToFloat, feltToInt, feltToLittleEndian, genPk, genVk, genWitness, inputValidation,
pkValidation, poseidonHash, proofValidation, prove, settingsValidation, srsValidation,
u8_array_to_u128_le, verify, vkValidation, witnessValidation,
u8_array_to_u128_le, verify, verifyAggr, vkValidation, witnessValidation,
};
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Fr, G1Affine};
@@ -27,10 +27,29 @@ mod wasm32 {
pub const NETWORK: &[u8] = include_bytes!("../tests/wasm/network.onnx");
pub const INPUT: &[u8] = include_bytes!("../tests/wasm/input.json");
pub const PROOF: &[u8] = include_bytes!("../tests/wasm/proof.json");
pub const PROOF_AGGR: &[u8] = include_bytes!("../tests/wasm/proof_aggr.json");
pub const SETTINGS: &[u8] = include_bytes!("../tests/wasm/settings.json");
pub const PK: &[u8] = include_bytes!("../tests/wasm/pk.key");
pub const VK: &[u8] = include_bytes!("../tests/wasm/vk.key");
pub const VK_AGGR: &[u8] = include_bytes!("../tests/wasm/vk_aggr.key");
pub const SRS: &[u8] = include_bytes!("../tests/wasm/kzg");
pub const SRS1: &[u8] = include_bytes!("../tests/wasm/kzg1.srs");
#[wasm_bindgen_test]
async fn can_verify_aggr() {
let value = verifyAggr(
wasm_bindgen::Clamped(PROOF_AGGR.to_vec()),
wasm_bindgen::Clamped(VK_AGGR.to_vec()),
21,
wasm_bindgen::Clamped(SRS1.to_vec()),
"kzg",
)
.map_err(|_| "failed")
.unwrap();
// should not fail
assert!(value);
}
#[wasm_bindgen_test]
async fn verify_encode_verifier_calldata() {
BIN
tests/wasm/kzg1.srs Normal file

Binary file not shown.

Binary file not shown.

3075
tests/wasm/proof_aggr.json Normal file

File diff suppressed because one or more lines are too long

View File
@@ -24,8 +24,7 @@
"param_visibility": "Private",
"div_rebasing": false,
"rebase_frac_zero_constants": false,
"check_mode": "UNSAFE",
"commitment": "KZG"
"check_mode": "UNSAFE"
},
"num_rows": 16,
"total_dynamic_col_size": 0,
BIN
tests/wasm/vk_aggr.key Normal file
Binary file not shown.