mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ec8d13082 | ||
|
|
12735aefd4 | ||
|
|
7fe179b8d4 | ||
|
|
3be988a6a0 | ||
|
|
3abb3aff56 | ||
|
|
338788cb8f | ||
|
|
feb3b1b475 | ||
|
|
e134d86756 | ||
|
|
6819a3acf6 | ||
|
|
2ccf056661 | ||
|
|
a5bf64b1a2 | ||
|
|
56e2326be1 |
2
.github/workflows/large-tests.yml
vendored
2
.github/workflows/large-tests.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: nanoGPT Mock
|
||||
|
||||
22
.github/workflows/npm.yml
vendored
22
.github/workflows/npm.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -30,13 +30,13 @@ jobs:
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
- name: Install binaryen
|
||||
run: |
|
||||
set -e
|
||||
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
|
||||
export PATH=$PATH:$PWD/binaryen-version_116/bin
|
||||
wasm-opt --version
|
||||
set -e
|
||||
curl -L https://github.com/WebAssembly/binaryen/releases/download/version_116/binaryen-version_116-x86_64-linux.tar.gz | tar xzf -
|
||||
export PATH=$PATH:$PWD/binaryen-version_116/bin
|
||||
wasm-opt --version
|
||||
- name: Build wasm files for both web and nodejs compilation targets
|
||||
run: |
|
||||
wasm-pack build --release --target nodejs --out-dir ./pkg/nodejs . -- -Z build-std="panic_abort,std"
|
||||
@@ -92,7 +92,7 @@ jobs:
|
||||
const jsonObject = JSONBig.parse(string);
|
||||
return jsonObject;
|
||||
}
|
||||
|
||||
|
||||
function serialize(data) { // data is an object // return a Uint8ClampedArray
|
||||
// Step 1: Stringify the Object with BigInt support
|
||||
if (typeof data === "object") {
|
||||
@@ -100,11 +100,11 @@ jobs:
|
||||
}
|
||||
// Step 2: Encode the JSON String
|
||||
const uint8Array = new TextEncoder().encode(data);
|
||||
|
||||
|
||||
// Step 3: Convert to Uint8ClampedArray
|
||||
return new Uint8ClampedArray(uint8Array.buffer);
|
||||
}
|
||||
|
||||
|
||||
module.exports = {
|
||||
deserialize,
|
||||
serialize
|
||||
@@ -123,7 +123,7 @@ jobs:
|
||||
const jsonObject = parse(string);
|
||||
return jsonObject;
|
||||
}
|
||||
|
||||
|
||||
export function serialize(data) { // data is an object // return a Uint8ClampedArray
|
||||
// Step 1: Stringify the Object with BigInt support
|
||||
if (typeof data === "object") {
|
||||
@@ -131,7 +131,7 @@ jobs:
|
||||
}
|
||||
// Step 2: Encode the JSON String
|
||||
const uint8Array = new TextEncoder().encode(data);
|
||||
|
||||
|
||||
// Step 3: Convert to Uint8ClampedArray
|
||||
return new Uint8ClampedArray(uint8Array.buffer);
|
||||
}
|
||||
|
||||
14
.github/workflows/pypi.yml
vendored
14
.github/workflows/pypi.yml
vendored
@@ -139,6 +139,20 @@ jobs:
|
||||
target: ${{ matrix.target }}
|
||||
manylinux: auto
|
||||
args: --release --out dist --features python-bindings
|
||||
before-script-linux: |
|
||||
# If we're running on rhel centos, install needed packages.
|
||||
if command -v yum &> /dev/null; then
|
||||
yum update -y && yum install -y perl-core openssl openssl-devel pkgconfig libatomic
|
||||
|
||||
# If we're running on i686 we need to symlink libatomic
|
||||
# in order to build openssl with -latomic flag.
|
||||
if [[ ! -d "/usr/lib64" ]]; then
|
||||
ln -s /usr/lib/libatomic.so.1 /usr/lib/libatomic.so
|
||||
fi
|
||||
else
|
||||
# If we're running on debian-based system.
|
||||
apt update -y && apt-get install -y libssl-dev openssl pkg-config
|
||||
fi
|
||||
|
||||
- name: Install built wheel
|
||||
if: matrix.target == 'x86_64'
|
||||
|
||||
16
.github/workflows/release.yml
vendored
16
.github/workflows/release.yml
vendored
@@ -32,7 +32,7 @@ jobs:
|
||||
token: ${{ secrets.RELEASE_TOKEN }}
|
||||
tag_name: ${{ env.EZKL_VERSION }}
|
||||
|
||||
build-release-gpu:
|
||||
build-release-gpu:
|
||||
name: build-release-gpu
|
||||
needs: ["create-release"]
|
||||
runs-on: GPU
|
||||
@@ -45,7 +45,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Checkout repo
|
||||
@@ -60,16 +60,15 @@ jobs:
|
||||
- name: Set Cargo.toml version to match github tag
|
||||
shell: bash
|
||||
run: |
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.toml.orig >Cargo.toml
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
mv Cargo.toml Cargo.toml.orig
|
||||
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.toml.orig >Cargo.toml
|
||||
mv Cargo.lock Cargo.lock.orig
|
||||
sed "s/0\\.0\\.0/${EZKL_VERSION//v}/" Cargo.lock.orig >Cargo.lock
|
||||
|
||||
- name: Install dependencies
|
||||
shell: bash
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get update
|
||||
|
||||
- name: Build release binary
|
||||
run: cargo build --release -Z sparse-registry --features icicle
|
||||
@@ -91,7 +90,6 @@ jobs:
|
||||
asset_name: ${{ env.ASSET }}
|
||||
asset_content_type: application/octet-stream
|
||||
|
||||
|
||||
build-release:
|
||||
name: build-release
|
||||
needs: ["create-release"]
|
||||
|
||||
99
.github/workflows/rust.yml
vendored
99
.github/workflows/rust.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Docs
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -73,7 +73,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -106,7 +106,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -199,7 +199,7 @@ jobs:
|
||||
- name: Install wasm32-unknown-unknown
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
- name: Run wasm verifier tests
|
||||
# on mac:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
@@ -212,7 +212,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -229,7 +229,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -286,7 +286,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -303,6 +303,8 @@ jobs:
|
||||
with:
|
||||
node-version: "18.12.1"
|
||||
cache: "pnpm"
|
||||
- name: "Add rust-src"
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
- name: Install dependencies for js tests and in-browser-evm-verifier package
|
||||
run: |
|
||||
pnpm install --no-frozen-lockfile
|
||||
@@ -324,7 +326,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
- name: KZG prove and verify tests (EVM + VK rendered seperately)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
@@ -357,7 +359,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -365,7 +367,7 @@ jobs:
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
@@ -392,6 +394,12 @@ jobs:
|
||||
- name: Replace memory definition in nodejs
|
||||
run: |
|
||||
sed -i "3s|.*|imports['env'] = {memory: new WebAssembly.Memory({initial:20,maximum:65536,shared:true})}|" tests/wasm/nodejs/ezkl.js
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_tight_lookup_::t
|
||||
- name: IPA prove and verify tests
|
||||
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_::t --test-threads 1
|
||||
- name: IPA prove and verify tests (ipa outputs)
|
||||
run: cargo nextest run --release --verbose tests::ipa_prove_and_verify_ipa_output
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w
|
||||
- name: KZG prove and verify tests single inner col
|
||||
@@ -408,8 +416,6 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_
|
||||
- name: KZG prove and verify tests (public outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
|
||||
- name: KZG prove and verify tests (public outputs + column overflow)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t
|
||||
- name: KZG prove and verify tests (public inputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input
|
||||
- name: KZG prove and verify tests (fixed params)
|
||||
@@ -425,11 +431,11 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
@@ -452,29 +458,6 @@ jobs:
|
||||
- name: KZG prove and verify tests (hashed outputs)
|
||||
run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
|
||||
|
||||
fuzz-tests:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
needs: [build, library-tests, python-tests]
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
- name: fuzz tests (EVM)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
|
||||
# - name: fuzz tests
|
||||
# run: cargo nextest run --release --verbose tests::kzg_fuzz_ --test-threads 6
|
||||
|
||||
prove-and-verify-mock-aggr-tests:
|
||||
runs-on: self-hosted
|
||||
needs: [build, library-tests]
|
||||
@@ -482,14 +465,14 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Mock aggr tests
|
||||
- name: Mock aggr tests (KZG)
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_mock_prove_and_verify_ --test-threads 8
|
||||
|
||||
prove-and-verify-aggr-tests-gpu:
|
||||
@@ -500,7 +483,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -517,14 +500,14 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: KZG )tests
|
||||
- name: KZG tests
|
||||
run: cargo nextest run --release --verbose tests_aggr::kzg_aggr_prove_and_verify_ --test-threads 4 -- --include-ignored
|
||||
|
||||
prove-and-verify-aggr-evm-tests:
|
||||
@@ -534,7 +517,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -544,7 +527,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
- name: KZG prove and verify aggr tests
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
@@ -555,7 +538,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -577,7 +560,7 @@ jobs:
|
||||
python-version: "3.7"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Install solc
|
||||
@@ -585,7 +568,7 @@ jobs:
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Run pytest
|
||||
@@ -601,7 +584,7 @@ jobs:
|
||||
python-version: "3.7"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -629,10 +612,10 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: "3.9"
|
||||
python-version: "3.10"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2024-02-06
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -642,7 +625,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
@@ -657,12 +640,12 @@ jobs:
|
||||
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: All notebooks
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --no-capture
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_ --test-threads 1
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: NBEATS tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::nbeats_
|
||||
- name: Tictactoe tutorials
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_ --no-capture
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::tictactoe_
|
||||
# - name: Postgres tutorials
|
||||
# run: source .env/bin/activate; cargo nextest run py_tests::tests::postgres_ --test-threads 1
|
||||
|
||||
1750
Cargo.lock
generated
1750
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
98
Cargo.toml
98
Cargo.toml
@@ -15,73 +15,96 @@ crate-type = ["cdylib", "rlib"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
rand = { version = "0.8", default_features = false }
|
||||
itertools = { version = "0.10.3", default_features = false }
|
||||
clap = { version = "4.3.3", features = ["derive"]}
|
||||
clap = { version = "4.5.3", features = ["derive"] }
|
||||
serde = { version = "1.0.126", features = ["derive"], optional = true }
|
||||
serde_json = { version = "1.0.97", default_features = false, features = ["float_roundtrip", "raw_value"], optional = true }
|
||||
serde_json = { version = "1.0.97", default_features = false, features = [
|
||||
"float_roundtrip",
|
||||
"raw_value",
|
||||
], optional = true }
|
||||
log = { version = "0.4.17", default_features = false, optional = true }
|
||||
thiserror = { version = "1.0.38", default_features = false }
|
||||
hex = { version = "0.4.3", default_features = false }
|
||||
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features=["derive_serde"]}
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch= "main" }
|
||||
maybe-rayon = { version = "0.1.1", default_features = false }
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
|
||||
"derive_serde",
|
||||
] }
|
||||
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
|
||||
maybe-rayon = { version = "0.1.1", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
ark-std = { version = "^0.3.0", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = "1.6.0"
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
|
||||
tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
|
||||
|
||||
|
||||
# evm related deps
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
|
||||
indicatif = {version = "0.17.5", features = ["rayon"]}
|
||||
gag = { version = "1.0.0", default_features = false}
|
||||
ethers = { version = "2.0.11", default_features = false, features = [
|
||||
"ethers-solc",
|
||||
] }
|
||||
indicatif = { version = "0.17.5", features = ["rayon"] }
|
||||
gag = { version = "1.0.0", default_features = false }
|
||||
instant = { version = "0.1" }
|
||||
reqwest = { version = "0.11.14", default-features = false, features = ["default-tls", "multipart", "stream"] }
|
||||
reqwest = { version = "0.11.14", default-features = false, features = [
|
||||
"default-tls",
|
||||
"multipart",
|
||||
"stream",
|
||||
] }
|
||||
openssl = { version = "0.10.55", features = ["vendored"] }
|
||||
postgres = "0.19.5"
|
||||
pg_bigdecimal = "0.1.5"
|
||||
lazy_static = "1.4.0"
|
||||
colored_json = { version = "3.0.1", default_features = false, optional = true}
|
||||
colored_json = { version = "3.0.1", default_features = false, optional = true }
|
||||
plotters = { version = "0.3.0", default_features = false, optional = true }
|
||||
regex = { version = "1", default_features = false }
|
||||
tokio = { version = "1.26.0", default_features = false, features = ["macros", "rt"] }
|
||||
tokio = { version = "1.26.0", default_features = false, features = [
|
||||
"macros",
|
||||
"rt",
|
||||
] }
|
||||
tokio-util = { version = "0.7.9", features = ["codec"] }
|
||||
pyo3 = { version = "0.20.2", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.20.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
|
||||
pyo3 = { version = "0.20.2", features = [
|
||||
"extension-module",
|
||||
"abi3-py37",
|
||||
"macros",
|
||||
], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.20.0", features = [
|
||||
"attributes",
|
||||
"tokio-runtime",
|
||||
], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.9.0", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7b1aa33b2f7d1f19b80e270c83320f0f94daff69", default_features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true }
|
||||
|
||||
|
||||
[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
|
||||
colored = { version = "2.0.0", default_features = false, optional = true}
|
||||
env_logger = { version = "0.10.0", default_features = false, optional = true}
|
||||
colored = { version = "2.0.0", default_features = false, optional = true }
|
||||
env_logger = { version = "0.10.0", default_features = false, optional = true }
|
||||
chrono = "0.4.31"
|
||||
sha256 = "1.4.0"
|
||||
|
||||
[target.'cfg(target_arch = "wasm32")'.dependencies]
|
||||
getrandom = { version = "0.2.8", features = ["js"] }
|
||||
instant = { version = "0.1", features = [ "wasm-bindgen", "inaccurate" ] }
|
||||
instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
|
||||
|
||||
[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))'.dependencies]
|
||||
wasm-bindgen-rayon = { version = "1.0", optional=true }
|
||||
wasm-bindgen-rayon = { version = "1.0", optional = true }
|
||||
wasm-bindgen-test = "0.3.34"
|
||||
serde-wasm-bindgen = "0.4"
|
||||
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"]}
|
||||
wasm-bindgen = { version = "0.2.81", features = ["serde-serialize"] }
|
||||
console_error_panic_hook = "0.1.7"
|
||||
wasm-bindgen-console-logger = "0.1.1"
|
||||
|
||||
|
||||
[dev-dependencies]
|
||||
criterion = {version = "0.3", features = ["html_reports"]}
|
||||
criterion = { version = "0.3", features = ["html_reports"] }
|
||||
tempfile = "3.3.0"
|
||||
lazy_static = "1.4.0"
|
||||
mnist = "0.5"
|
||||
@@ -153,11 +176,24 @@ required-features = ["ezkl"]
|
||||
[features]
|
||||
web = ["wasm-bindgen-rayon"]
|
||||
default = ["ezkl", "mv-lookup"]
|
||||
render = ["halo2_proofs/dev-graph", "plotters"]
|
||||
onnx = ["dep:tract-onnx"]
|
||||
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
|
||||
ezkl = ["onnx", "serde", "serde_json", "log", "colored", "env_logger", "tabled/color", "colored_json", "halo2_proofs/circuit-params"]
|
||||
mv-lookup = ["halo2_proofs/mv-lookup", "snark-verifier/mv-lookup", "halo2_solidity_verifier/mv-lookup"]
|
||||
ezkl = [
|
||||
"onnx",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"log",
|
||||
"colored",
|
||||
"env_logger",
|
||||
"tabled/color",
|
||||
"colored_json",
|
||||
"halo2_proofs/circuit-params",
|
||||
]
|
||||
mv-lookup = [
|
||||
"halo2_proofs/mv-lookup",
|
||||
"snark-verifier/mv-lookup",
|
||||
"halo2_solidity_verifier/mv-lookup",
|
||||
]
|
||||
det-prove = []
|
||||
icicle = ["halo2_proofs/icicle_gpu"]
|
||||
empty-cmd = []
|
||||
@@ -165,7 +201,11 @@ no-banner = []
|
||||
|
||||
# icicle patch to 0.1.0 if feature icicle is enabled
|
||||
[patch.'https://github.com/ingonyama-zk/icicle']
|
||||
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix"}
|
||||
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
|
||||
|
||||
|
||||
[profile.release]
|
||||
rustflags = [ "-C", "relocation-model=pic" ]
|
||||
rustflags = ["-C", "relocation-model=pic"]
|
||||
|
||||
24
README.md
24
README.md
@@ -31,9 +31,9 @@ EZKL
|
||||
|
||||
[](https://colab.research.google.com/github/zkonduit/ezkl/blob/main/examples/notebooks/simple_demo_all_public.ipynb)
|
||||
|
||||
In the backend we use [Halo2](https://github.com/privacy-scaling-explorations/halo2) as a proof system.
|
||||
In the backend we use the collaboratively-developed [Halo2](https://github.com/privacy-scaling-explorations/halo2) as a proof system.
|
||||
|
||||
The generated proofs can then be used on-chain to verify computation, only the Ethereum Virtual Machine (EVM) is supported at the moment.
|
||||
The generated proofs can then be verified with much less computational resources, including on-chain (with the Ethereum Virtual Machine), in a browser, or on a device.
|
||||
|
||||
- If you have any questions, we'd love for you to open up a discussion topic in [Discussions](https://github.com/zkonduit/ezkl/discussions). Alternatively, you can join the ✨[EZKL Community Telegram Group](https://t.me/+QRzaRvTPIthlYWMx)💫.
|
||||
|
||||
@@ -45,6 +45,8 @@ The generated proofs can then be used on-chain to verify computation, only the E
|
||||
|
||||
### getting started ⚙️
|
||||
|
||||
The easiest way to get started is to try out a notebook.
|
||||
|
||||
#### Python
|
||||
Install the python bindings by calling.
|
||||
|
||||
@@ -70,7 +72,7 @@ curl https://raw.githubusercontent.com/zkonduit/ezkl/main/install_ezkl_cli.sh |
|
||||
|
||||
https://user-images.githubusercontent.com/45801863/236771676-5bbbbfd1-ba6f-418a-902e-20738ce0e9f0.mp4
|
||||
|
||||
For more details visit the [docs](https://docs.ezkl.xyz).
|
||||
For more details visit the [docs](https://docs.ezkl.xyz). The CLI is faster than Python, as it has less overhead. For even more speed and convenience, check out the [remote proving service](https://ei40vx5x6j0.typeform.com/to/sFv1oxvb), which feels like the CLI but is backed by a tuned cluster.
|
||||
|
||||
Build the auto-generated rust documentation and open the docs in your browser locally. `cargo doc --open`
|
||||
|
||||
@@ -124,17 +126,6 @@ unset ENABLE_ICICLE_GPU
|
||||
|
||||
**NOTE:** Even with the above environment variable set, icicle is disabled for circuits where k <= 8. To change the value of `k` where icicle is enabled, you can set the environment variable `ICICLE_SMALL_K`.
|
||||
|
||||
### repos
|
||||
|
||||
The EZKL project has several libraries and repos.
|
||||
|
||||
| Repo | Description |
|
||||
| --- | --- |
|
||||
| [@zkonduit/ezkl](https://github.com/zkonduit/ezkl) | the main ezkl repo in rust with wasm and python bindings |
|
||||
| [@zkonduit/ezkljs](https://github.com/zkonduit/ezkljs) | typescript and javascript tooling to help integrate ezkl into web apps |
|
||||
|
||||
----------------------
|
||||
|
||||
### contributing 🌎
|
||||
|
||||
If you're interested in contributing and are unsure where to start, reach out to one of the maintainers:
|
||||
@@ -151,7 +142,7 @@ More broadly:
|
||||
- To report bugs or request new features [create a new issue within Issues](https://github.com/zkonduit/ezkl/issues) to inform the greater community.
|
||||
|
||||
|
||||
Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion in the work by you shall be licensed to Zkonduit Inc. under the terms and conditions specified in the [CLA](https://github.com/zkonduit/ezkl/blob/main/cla.md), which you agree to by intentionally submitting a contribution. In particular, you have the right to submit the contribution and we can distribute it under the Apache 2.0 license, among other terms and conditions.
|
||||
Any contribution intentionally submitted for inclusion in the work by you shall be licensed to Zkonduit Inc. under the terms and conditions specified in the [CLA](https://github.com/zkonduit/ezkl/blob/main/cla.md), which you agree to by intentionally submitting a contribution. In particular, you have the right to submit the contribution and we can distribute it, among other terms and conditions.
|
||||
|
||||
### no security guarantees
|
||||
|
||||
@@ -159,4 +150,7 @@ Ezkl is unaudited, beta software undergoing rapid development. There may be bugs
|
||||
|
||||
> NOTE: Because operations are quantized when they are converted from an onnx file to a zk-circuit, outputs in python and ezkl may differ slightly.
|
||||
|
||||
### no warranty
|
||||
|
||||
Copyright (c) 2024 Zkonduit Inc. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
||||
@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_keys;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::srs::gen_srs;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -15,6 +17,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
static mut KERNEL_HEIGHT: usize = 2;
|
||||
static mut KERNEL_WIDTH: usize = 2;
|
||||
@@ -121,28 +124,35 @@ fn runcnvrl(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
|
||||
&circuit, ¶ms, true,
|
||||
)
|
||||
.unwrap();
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -14,6 +16,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
@@ -90,25 +93,35 @@ fn rundot(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -14,6 +16,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
@@ -94,25 +97,35 @@ fn runmatmul(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -4,17 +4,20 @@ use ezkl::circuit::*;
|
||||
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: Range = (-32768, 32768);
|
||||
@@ -112,25 +115,35 @@ fn runmatmul(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::SAFE,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -4,17 +4,20 @@ use ezkl::circuit::*;
|
||||
use ezkl::circuit::lookup::LookupOp;
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
const BITS: Range = (-8180, 8180);
|
||||
@@ -115,25 +118,35 @@ fn runmatmul(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(k as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", k), &k, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(k as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", k), &k, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::SAFE,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -14,6 +16,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
@@ -86,25 +89,35 @@ fn runsum(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
|
||||
use ezkl::circuit::hybrid::HybridOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_keys;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::srs::gen_srs;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -15,6 +17,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
static mut IMAGE_HEIGHT: usize = 2;
|
||||
static mut IMAGE_WIDTH: usize = 2;
|
||||
@@ -101,28 +104,35 @@ fn runsumpool(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(
|
||||
&circuit, ¶ms, true,
|
||||
)
|
||||
.unwrap();
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -14,6 +16,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
@@ -84,25 +87,35 @@ fn runadd(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::SAFE,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -2,11 +2,13 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
|
||||
use ezkl::circuit::poly::PolyOp;
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -15,6 +17,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
static mut LEN: usize = 4;
|
||||
@@ -83,25 +86,35 @@ fn runpow(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::SAFE,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Throughput};
|
||||
use ezkl::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use ezkl::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
|
||||
use ezkl::circuit::modules::Module;
|
||||
use ezkl::circuit::*;
|
||||
use ezkl::pfsys::create_keys;
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::srs::gen_srs;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
arithmetic::Field,
|
||||
@@ -18,6 +21,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::rngs::OsRng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
const L: usize = 10;
|
||||
|
||||
@@ -46,7 +50,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
) -> Result<(), Error> {
|
||||
let chip: PoseidonChip<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L> =
|
||||
PoseidonChip::new(config);
|
||||
chip.layout(&mut layouter, &[self.image.clone()], 0)?;
|
||||
chip.layout(&mut layouter, &[self.image.clone()], 0, &mut HashMap::new())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@@ -62,7 +66,7 @@ fn runposeidon(c: &mut Criterion) {
|
||||
let params = gen_srs::<KZGCommitmentScheme<_>>(k);
|
||||
|
||||
let message = (0..*size).map(|_| Fr::random(OsRng)).collect::<Vec<_>>();
|
||||
let output =
|
||||
let _output =
|
||||
PoseidonChip::<PoseidonSpec, POSEIDON_WIDTH, POSEIDON_RATE, L>::run(message.to_vec())
|
||||
.unwrap();
|
||||
|
||||
@@ -76,25 +80,35 @@ fn runposeidon(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, MyCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, MyCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(*size as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", size), &size, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
MyCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
Some(output[0].clone()),
|
||||
&pk,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -2,11 +2,12 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
|
||||
use ezkl::pfsys::create_proof_circuit_kzg;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
|
||||
use halo2_proofs::poly::kzg::strategy::SingleStrategy;
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
@@ -14,6 +15,7 @@ use halo2_proofs::{
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, Fr};
|
||||
use rand::Rng;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
const BITS: Range = (-32768, 32768);
|
||||
static mut LEN: usize = 4;
|
||||
@@ -91,25 +93,35 @@ fn runrelu(c: &mut Criterion) {
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("pk", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, ¶ms, true)
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
});
|
||||
});
|
||||
|
||||
let pk = create_keys::<KZGCommitmentScheme<Bn256>, Fr, NLCircuit>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
let pk =
|
||||
create_keys::<KZGCommitmentScheme<Bn256>, NLCircuit>(&circuit, ¶ms, true).unwrap();
|
||||
|
||||
group.throughput(Throughput::Elements(len as u64));
|
||||
group.bench_with_input(BenchmarkId::new("prove", len), &len, |b, &_| {
|
||||
b.iter(|| {
|
||||
let prover = create_proof_circuit_kzg(
|
||||
let prover = create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
NLCircuit,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
ezkl::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
SingleStrategy::new(¶ms),
|
||||
CheckMode::SAFE,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
prover.unwrap();
|
||||
|
||||
@@ -696,10 +696,12 @@
|
||||
"for i, value in enumerate(proof[\"instances\"]):\n",
|
||||
" for j, field_element in enumerate(value):\n",
|
||||
" onchain_input_array.append(ezkl.felt_to_big_endian(field_element))\n",
|
||||
" formatted_output += str(onchain_input_array[-1])\n",
|
||||
" formatted_output += '\"' + str(onchain_input_array[-1]) + '\"'\n",
|
||||
" if j != len(value) - 1:\n",
|
||||
" formatted_output += \", \"\n",
|
||||
" formatted_output += \"]\"\n",
|
||||
" if i != len(proof[\"instances\"]) - 1:\n",
|
||||
" formatted_output += \", \"\n",
|
||||
"formatted_output += \"]\"\n",
|
||||
"\n",
|
||||
"# This will be the values you use onchain\n",
|
||||
"# copy them over to remix and see if they verify\n",
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
"\n",
|
||||
"## Generalized Inverse\n",
|
||||
"\n",
|
||||
"We show how to use EZKL to prove that we know matrices $A$ and its generalized inverse $B$. Since these are large we deal with the KZG commitments, with $a$ the kzgcommit of $A$, $b$ the kzgcommit of $B$, and $ABA = A$.\n"
|
||||
"We show how to use EZKL to prove that we know matrices $A$ and its generalized inverse $B$. Since these are large we deal with the KZG commitments, with $a$ the polycommit of $A$, $b$ the polycommit of $B$, and $ABA = A$.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -77,7 +77,7 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"gip_run_args = ezkl.PyRunArgs()\n",
|
||||
"gip_run_args.input_visibility = \"kzgcommit\" # matrix and generalized inverse commitments\n",
|
||||
"gip_run_args.input_visibility = \"polycommit\" # matrix and generalized inverse commitments\n",
|
||||
"gip_run_args.output_visibility = \"fixed\" # no parameters used\n",
|
||||
"gip_run_args.param_visibility = \"fixed\" # should be Tensor(True)"
|
||||
]
|
||||
@@ -340,4 +340,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
@@ -161,7 +161,7 @@
|
||||
"- `fixed`: known to the prover and verifier (as a commit), but not modifiable by the prover.\n",
|
||||
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
|
||||
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
|
||||
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be inserted directly within the proof bytes. \n",
|
||||
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be inserted directly within the proof bytes. \n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Here we create the following setup:\n",
|
||||
@@ -510,4 +510,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
@@ -154,11 +154,11 @@
|
||||
"- `fixed`: known to the prover and verifier (as a commit), but not modifiable by the prover.\n",
|
||||
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
|
||||
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
|
||||
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
|
||||
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
|
||||
"\n",
|
||||
"Here we create the following setup:\n",
|
||||
"- `input_visibility`: \"kzgcommit\"\n",
|
||||
"- `param_visibility`: \"kzgcommit\"\n",
|
||||
"- `input_visibility`: \"polycommit\"\n",
|
||||
"- `param_visibility`: \"polycommit\"\n",
|
||||
"- `output_visibility`: public\n",
|
||||
"\n",
|
||||
"We encourage you to play around with other setups :) \n",
|
||||
@@ -186,8 +186,8 @@
|
||||
"data_path = os.path.join('input.json')\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.input_visibility = \"kzgcommit\"\n",
|
||||
"run_args.param_visibility = \"kzgcommit\"\n",
|
||||
"run_args.input_visibility = \"polycommit\"\n",
|
||||
"run_args.param_visibility = \"polycommit\"\n",
|
||||
"run_args.output_visibility = \"public\"\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]\n",
|
||||
"\n",
|
||||
@@ -512,4 +512,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
}
|
||||
@@ -264,9 +264,9 @@
|
||||
"### KZG commitment intermediate calculations\n",
|
||||
"\n",
|
||||
"the visibility parameters are:\n",
|
||||
"- `input_visibility`: \"kzgcommit\"\n",
|
||||
"- `input_visibility`: \"polycommit\"\n",
|
||||
"- `param_visibility`: \"public\"\n",
|
||||
"- `output_visibility`: kzgcommit"
|
||||
"- `output_visibility`: polycommit"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -280,15 +280,15 @@
|
||||
"srs_path = os.path.join('kzg.srs')\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.input_visibility = \"kzgcommit\"\n",
|
||||
"run_args.input_visibility = \"polycommit\"\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"run_args.output_visibility = \"kzgcommit\"\n",
|
||||
"run_args.output_visibility = \"polycommit\"\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]\n",
|
||||
"run_args.input_scale = 0\n",
|
||||
"run_args.param_scale = 0\n",
|
||||
"run_args.logrows = 18\n",
|
||||
"\n",
|
||||
"ezkl.get_srs(logrows=run_args.logrows)\n"
|
||||
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -208,7 +208,7 @@
|
||||
"- `private`: known only to the prover\n",
|
||||
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
|
||||
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
|
||||
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
|
||||
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
|
||||
"\n",
|
||||
"Here we create the following setup:\n",
|
||||
"- `input_visibility`: \"public\"\n",
|
||||
@@ -234,7 +234,7 @@
|
||||
"run_args.input_scale = 2\n",
|
||||
"run_args.logrows = 8\n",
|
||||
"\n",
|
||||
"ezkl.get_srs(logrows=run_args.logrows)"
|
||||
"ezkl.get_srs(logrows=run_args.logrows, commitment=ezkl.PyCommitments.KZG)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -385,9 +385,9 @@
|
||||
"### KZG commitment intermediate calculations\n",
|
||||
"\n",
|
||||
"This time the visibility parameters are:\n",
|
||||
"- `input_visibility`: \"kzgcommit\"\n",
|
||||
"- `input_visibility`: \"polycommit\"\n",
|
||||
"- `param_visibility`: \"public\"\n",
|
||||
"- `output_visibility`: kzgcommit"
|
||||
"- `output_visibility`: polycommit"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -399,9 +399,9 @@
|
||||
"import ezkl\n",
|
||||
"\n",
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"run_args.input_visibility = \"kzgcommit\"\n",
|
||||
"run_args.input_visibility = \"polycommit\"\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"run_args.output_visibility = \"kzgcommit\"\n",
|
||||
"run_args.output_visibility = \"polycommit\"\n",
|
||||
"run_args.variables = [(\"batch_size\", 1)]\n",
|
||||
"run_args.input_scale = 2\n",
|
||||
"run_args.logrows = 8\n"
|
||||
|
||||
@@ -275,7 +275,6 @@
|
||||
" proof_path,\n",
|
||||
" settings_path,\n",
|
||||
" vk_path,\n",
|
||||
" \n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"assert res == True\n",
|
||||
@@ -291,7 +290,7 @@
|
||||
"source": [
|
||||
"# Generate a larger SRS. This is needed for the aggregated proof\n",
|
||||
"\n",
|
||||
"res = ezkl.get_srs(settings_path=None, logrows=21)"
|
||||
"res = ezkl.get_srs(settings_path=None, logrows=21, commitment=ezkl.PyCommitments.KZG)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
"source": [
|
||||
"## Solvency demo\n",
|
||||
"\n",
|
||||
"Here we create a demo of a solvency calculation in the manner of [summa-solvency](https://github.com/summa-dev/summa-solvency). The aim here is to demonstrate the use of the new kzgcommit method detailed [here](https://blog.ezkl.xyz/post/commits/). \n",
|
||||
"Here we create a demo of a solvency calculation in the manner of [summa-solvency](https://github.com/summa-dev/summa-solvency). The aim here is to demonstrate the use of the new polycommit method detailed [here](https://blog.ezkl.xyz/post/commits/). \n",
|
||||
"\n",
|
||||
"In this setup:\n",
|
||||
"- the commitments to users, respective balances, and total balance are known are publicly known to the prover and verifier. \n",
|
||||
@@ -177,10 +177,10 @@
|
||||
"- `private`: known only to the prover\n",
|
||||
"- `hashed`: the hash pre-image is known to the prover, the prover and verifier know the hash. The prover proves that the they know the pre-image to the hash. \n",
|
||||
"- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption hey. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.\n",
|
||||
"- `kzgcommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
|
||||
"- `polycommit`: unblinded advice column which generates a kzg commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes. \n",
|
||||
"\n",
|
||||
"Here we create the following setup:\n",
|
||||
"- `input_visibility`: \"kzgcommit\"\n",
|
||||
"- `input_visibility`: \"polycommit\"\n",
|
||||
"- `param_visibility`: \"public\"\n",
|
||||
"- `output_visibility`: public\n",
|
||||
"\n",
|
||||
@@ -202,8 +202,8 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"run_args = ezkl.PyRunArgs()\n",
|
||||
"# \"kzgcommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
|
||||
"run_args.input_visibility = \"kzgcommit\"\n",
|
||||
"# \"polycommit\" means that the output of the hashing is not visible to the verifier and is instead fed into the computational graph\n",
|
||||
"run_args.input_visibility = \"polycommit\"\n",
|
||||
"# the parameters are public\n",
|
||||
"run_args.param_visibility = \"fixed\"\n",
|
||||
"# the output is public (this is the inequality test)\n",
|
||||
@@ -515,4 +515,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
}
|
||||
40
examples/onnx/1l_lppool/gen.py
Normal file
40
examples/onnx/1l_lppool/gen.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
self.layer = nn.LPPool2d(2, 1, (1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return self.layer(x)[0]
|
||||
|
||||
|
||||
circuit = Model()
|
||||
|
||||
x = torch.empty(1, 3, 2, 2).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
|
||||
1
examples/onnx/1l_lppool/input.json
Normal file
1
examples/onnx/1l_lppool/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.7549541592597961, 0.990360677242279, 0.9473411440849304, 0.3951031565666199, 0.8500555753707886, 0.9352139830589294, 0.11867779493331909, 0.9493132829666138, 0.6588345766067505, 0.1933223009109497, 0.12139874696731567, 0.8547163605690002]]}
|
||||
BIN
examples/onnx/1l_lppool/network.onnx
Normal file
BIN
examples/onnx/1l_lppool/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/celu/gen.py
Normal file
42
examples/onnx/celu/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.CELU()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/celu/input.json
Normal file
1
examples/onnx/celu/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.35387128591537476, 0.030473172664642334, 0.08707714080810547, 0.2429301142692566, 0.45228832960128784, 0.496021032333374, 0.13245105743408203, 0.8497090339660645]]}
|
||||
BIN
examples/onnx/celu/network.onnx
Normal file
BIN
examples/onnx/celu/network.onnx
Normal file
Binary file not shown.
41
examples/onnx/clip/gen.py
Normal file
41
examples/onnx/clip/gen.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.clamp(x, min=0.4, max=0.8)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/clip/input.json
Normal file
1
examples/onnx/clip/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.03297048807144165, 0.46362626552581787, 0.6044231057167053, 0.4949902892112732, 0.48823297023773193, 0.6798646450042725, 0.6824942231178284, 0.03491640090942383, 0.19608813524246216, 0.24129939079284668, 0.9769315123558044, 0.6306831240653992, 0.7690497636795044, 0.252221941947937, 0.9167693853378296, 0.3882681131362915, 0.9307044148445129, 0.33559417724609375, 0.7815426588058472, 0.3435332179069519, 0.7871478796005249, 0.12240773439407349, 0.5295405983924866, 0.4874419569969177, 0.08262640237808228, 0.1124718189239502, 0.5834914445877075, 0.30927878618240356, 0.48899340629577637, 0.9376634955406189, 0.21893149614334106, 0.526070773601532]]}
|
||||
24
examples/onnx/clip/network.onnx
Normal file
24
examples/onnx/clip/network.onnx
Normal file
@@ -0,0 +1,24 @@
|
||||
pytorch2.2.1:±
|
||||
?/Constant_output_0 /Constant"Constant*
|
||||
value*JÍÌÌ>
|
||||
C/Constant_1_output_0/Constant_1"Constant*
|
||||
value*JÍÌL?
|
||||
F
|
||||
input
|
||||
/Constant_output_0
|
||||
/Constant_1_output_0output/Clip"Clip
|
||||
main_graphZ)
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
|
||||
|
||||
b*
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
|
||||
|
||||
B
|
||||
41
examples/onnx/gru/gen.py
Normal file
41
examples/onnx/gru/gen.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import random
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import json
|
||||
|
||||
|
||||
model = nn.GRU(3, 3) # Input dim is 3, output dim is 3
|
||||
x = torch.randn(1, 3) # make a sequence of length 5
|
||||
|
||||
print(x)
|
||||
|
||||
# Flips the neural net into inference mode
|
||||
model.eval()
|
||||
model.to('cpu')
|
||||
|
||||
# Export the model
|
||||
torch.onnx.export(model, # model being run
|
||||
# model input (or a tuple for multiple inputs)
|
||||
x,
|
||||
# where to save the model (can be a file or file-like object)
|
||||
"network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=10, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
data_array = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data_json = dict(input_data=[data_array])
|
||||
|
||||
print(data_json)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data_json, open("input.json", 'w'))
|
||||
1
examples/onnx/gru/input.json
Normal file
1
examples/onnx/gru/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.4145222008228302, -0.4043896496295929, 0.7545749545097351]]}
|
||||
BIN
examples/onnx/gru/network.onnx
Normal file
BIN
examples/onnx/gru/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/hard_max/gen.py
Normal file
42
examples/onnx/hard_max/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.argmax(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/hard_max/input.json
Normal file
1
examples/onnx/hard_max/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.5505883693695068, 0.0766521692276001, 0.12006187438964844, 0.9497959017753601, 0.9100563526153564, 0.968717098236084, 0.5978299379348755, 0.9419963359832764]]}
|
||||
BIN
examples/onnx/hard_max/network.onnx
Normal file
BIN
examples/onnx/hard_max/network.onnx
Normal file
Binary file not shown.
@@ -9,7 +9,7 @@ class MyModel(nn.Module):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Logsoftmax()(x)
|
||||
m = nn.Hardsigmoid()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
@@ -1 +1 @@
|
||||
{"input_data": [[0.2971532940864563, 0.3465197682380676, 0.05381882190704346, 0.058654189109802246, 0.014198064804077148, 0.06088751554489136, 0.1723427176475525, 0.5115123987197876]]}
|
||||
{"input_data": [[0.8326942324638367, 0.2796096205711365, 0.600328266620636, 0.3701696991920471, 0.17832040786743164, 0.6247223019599915, 0.501872718334198, 0.6961578726768494]]}
|
||||
@@ -1,4 +1,4 @@
|
||||
pytorch2.1.0:<3A>
|
||||
pytorch2.2.1:<3A>
|
||||
;
|
||||
inputoutput/HardSigmoid"HardSigmoid*
|
||||
alphaǻ*>
|
||||
|
||||
41
examples/onnx/hard_swish/gen.py
Normal file
41
examples/onnx/hard_swish/gen.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Hardswish()(x)
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/hard_swish/input.json
Normal file
1
examples/onnx/hard_swish/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.6996762752532959, 0.42992985248565674, 0.5102168321609497, 0.5540387630462646, 0.8489438891410828, 0.8533616065979004, 0.36736780405044556, 0.5859147310256958]]}
|
||||
15
examples/onnx/hard_swish/network.onnx
Normal file
15
examples/onnx/hard_swish/network.onnx
Normal file
@@ -0,0 +1,15 @@
|
||||
pytorch2.2.1:{
|
||||
&
|
||||
inputoutput
|
||||
/HardSwish" HardSwish
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
@@ -9,7 +9,7 @@ class MyModel(nn.Module):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Hardsigmoid()(x)
|
||||
m = nn.LogSoftmax()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
42
examples/onnx/logsumexp/gen.py
Normal file
42
examples/onnx/logsumexp/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.logsumexp(x, dim=1)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/logsumexp/input.json
Normal file
1
examples/onnx/logsumexp/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.7973018884658813, 0.5245689153671265, 0.34149593114852905, 0.1455438733100891, 0.9482707381248474, 0.4221445322036743, 0.001363217830657959, 0.8736765384674072, 0.42954301834106445, 0.7199509739875793, 0.37641745805740356, 0.5920265316963196, 0.42270803451538086, 0.41761744022369385, 0.603948712348938, 0.7250819802284241, 0.047173500061035156, 0.5115441679954529, 0.3743387460708618, 0.16794061660766602, 0.5352339148521423, 0.037976861000061035, 0.65323406457901, 0.5585184097290039, 0.10559147596359253, 0.07827490568161011, 0.6717077493667603, 0.6480781435966492, 0.9780838489532471, 0.8353415131568909, 0.6491701006889343, 0.6573048233985901]]}
|
||||
BIN
examples/onnx/logsumexp/network.onnx
Normal file
BIN
examples/onnx/logsumexp/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/mish/gen.py
Normal file
42
examples/onnx/mish/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = nn.Mish()(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/mish/input.json
Normal file
1
examples/onnx/mish/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.18563222885131836, 0.4843214750289917, 0.9991059899330139, 0.02534431219100952, 0.8105666041374207, 0.9658406376838684, 0.681107759475708, 0.5365872979164124]]}
|
||||
19
examples/onnx/mish/network.onnx
Normal file
19
examples/onnx/mish/network.onnx
Normal file
@@ -0,0 +1,19 @@
|
||||
pytorch2.2.1:ä
|
||||
0
|
||||
input/Softplus_output_0 /Softplus"Softplus
|
||||
1
|
||||
/Softplus_output_0/Tanh_output_0/Tanh"Tanh
|
||||
*
|
||||
input
|
||||
/Tanh_output_0output/Mul"Mul
|
||||
main_graphZ!
|
||||
input
|
||||
|
||||
|
||||
batch_size
|
||||
b"
|
||||
output
|
||||
|
||||
|
||||
batch_size
|
||||
B
|
||||
42
examples/onnx/reducel1/gen.py
Normal file
42
examples/onnx/reducel1/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.norm(x, p=1, dim=1)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/reducel1/input.json
Normal file
1
examples/onnx/reducel1/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.02284395694732666, 0.7941043376922607, 0.07971876859664917, 0.8898420929908752, 0.8233054280281067, 0.11066079139709473, 0.4424799084663391, 0.4355071783065796, 0.6723723411560059, 0.6818525195121765, 0.8726171851158142, 0.17742449045181274, 0.054257750511169434, 0.5775953531265259, 0.7758923172950745, 0.8431423306465149, 0.7602444887161255, 0.29686522483825684, 0.22489851713180542, 0.0675363540649414, 0.981339693069458, 0.15771394968032837, 0.5801441669464111, 0.9044001698493958, 0.49266451597213745, 0.42621421813964844, 0.35345613956451416, 0.042848050594329834, 0.6908614039421082, 0.5422852039337158, 0.01975083351135254, 0.5772860050201416]]}
|
||||
BIN
examples/onnx/reducel1/network.onnx
Normal file
BIN
examples/onnx/reducel1/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/reducel2/gen.py
Normal file
42
examples/onnx/reducel2/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.norm(x, p=2, dim=1)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 2, 2, 8).uniform_(0, 1)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/reducel2/input.json
Normal file
1
examples/onnx/reducel2/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.8709188103675842, 0.11553549766540527, 0.27376580238342285, 0.7518517971038818, 0.7879393100738525, 0.8765475749969482, 0.14315760135650635, 0.8982420563697815, 0.7274006605148315, 0.39007169008255005, 0.729040801525116, 0.11306107044219971, 0.658822774887085, 0.666404664516449, 0.3001367449760437, 0.45343858003616333, 0.7460223436355591, 0.7423691749572754, 0.7544230818748474, 0.5674425959587097, 0.8728761672973633, 0.27062875032424927, 0.1595977544784546, 0.22975260019302368, 0.6711723208427429, 0.8265992403030396, 0.48679041862487793, 0.689740777015686, 0.330846905708313, 0.5630669593811035, 0.8058932423591614, 0.5802426338195801]]}
|
||||
BIN
examples/onnx/reducel2/network.onnx
Normal file
BIN
examples/onnx/reducel2/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/tril/gen.py
Normal file
42
examples/onnx/tril/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.triu(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 3, 3).uniform_(0, 5)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/tril/input.json
Normal file
1
examples/onnx/tril/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.4870188236236572, 2.275230646133423, 3.126268148422241, 0.6412187218666077, 0.9967470169067383, 1.9814395904541016, 1.6355383396148682, 0.6397527456283569, 0.7825168967247009]]}
|
||||
BIN
examples/onnx/tril/network.onnx
Normal file
BIN
examples/onnx/tril/network.onnx
Normal file
Binary file not shown.
42
examples/onnx/triu/gen.py
Normal file
42
examples/onnx/triu/gen.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from torch import nn
|
||||
import torch
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
m = torch.tril(x)
|
||||
|
||||
return m
|
||||
|
||||
|
||||
circuit = MyModel()
|
||||
|
||||
x = torch.empty(1, 3, 3).uniform_(0, 5)
|
||||
|
||||
out = circuit(x)
|
||||
|
||||
print(out)
|
||||
|
||||
torch.onnx.export(circuit, x, "network.onnx",
|
||||
export_params=True, # store the trained parameter weights inside the model file
|
||||
opset_version=17, # the ONNX version to export the model to
|
||||
do_constant_folding=True, # whether to execute constant folding for optimization
|
||||
input_names=['input'], # the model's input names
|
||||
output_names=['output'], # the model's output names
|
||||
dynamic_axes={'input': {0: 'batch_size'}, # variable length axes
|
||||
'output': {0: 'batch_size'}})
|
||||
|
||||
|
||||
d1 = ((x).detach().numpy()).reshape([-1]).tolist()
|
||||
|
||||
data = dict(
|
||||
input_data=[d1],
|
||||
)
|
||||
|
||||
# Serialize data into file:
|
||||
json.dump(data, open("input.json", 'w'))
|
||||
1
examples/onnx/triu/input.json
Normal file
1
examples/onnx/triu/input.json
Normal file
@@ -0,0 +1 @@
|
||||
{"input_data": [[0.2898547053337097, 1.8070811033248901, 0.30266255140304565, 3.00581955909729, 0.5379888415336609, 1.7057424783706665, 2.415961265563965, 0.589233934879303, 0.03824889659881592]]}
|
||||
BIN
examples/onnx/triu/network.onnx
Normal file
BIN
examples/onnx/triu/network.onnx
Normal file
Binary file not shown.
@@ -17,7 +17,7 @@
|
||||
"clean": "rm -r dist || true",
|
||||
"build:commonjs": "tsc --project tsconfig.commonjs.json && resolve-tspaths -p tsconfig.commonjs.json",
|
||||
"build:esm": "tsc --project tsconfig.esm.json && resolve-tspaths -p tsconfig.esm.json",
|
||||
"build": "pnpm run clean && pnpm run build:commonjs && pnpm run build:esm"
|
||||
"build": "npm run clean && npm run build:commonjs && npm run build:esm"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ethereumjs/common": "^4.0.0",
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "nightly-2023-08-24"
|
||||
channel = "nightly-2024-02-06"
|
||||
components = ["rustfmt", "clippy"]
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
pub mod poseidon;
|
||||
|
||||
///
|
||||
pub mod kzg;
|
||||
pub mod polycommit;
|
||||
|
||||
///
|
||||
pub mod planner;
|
||||
@@ -15,6 +15,8 @@ pub use planner::*;
|
||||
|
||||
use crate::tensor::{TensorType, ValTensor};
|
||||
|
||||
use super::region::ConstantsMap;
|
||||
|
||||
/// Module trait used to extend ezkl functionality
|
||||
pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// Config
|
||||
@@ -39,6 +41,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
&self,
|
||||
layouter: &mut impl Layouter<F>,
|
||||
input: &[ValTensor<F>],
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<Self::InputAssignments, Error>;
|
||||
/// Layout
|
||||
fn layout(
|
||||
@@ -46,6 +49,7 @@ pub trait Module<F: PrimeField + TensorType + PartialOrd> {
|
||||
layouter: &mut impl Layouter<F>,
|
||||
input: &[ValTensor<F>],
|
||||
row_offset: usize,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<ValTensor<F>, Error>;
|
||||
/// Number of instance values the module uses every time it is applied
|
||||
fn instance_increment_input(&self) -> Vec<usize>;
|
||||
|
||||
@@ -4,49 +4,49 @@ is already implemented in halo2_gadgets, there is no wrapper chip that makes it
|
||||
Thanks to https://github.com/summa-dev/summa-solvency/blob/master/src/chips/poseidon/hash.rs for the inspiration (and also helping us understand how to use this).
|
||||
*/
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
// This chip adds a set of advice columns to the gadget Chip to store the inputs of the hash
|
||||
use halo2_proofs::halo2curves::bn256::Fr as Fp;
|
||||
use halo2_proofs::poly::commitment::{Blind, Params};
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
use halo2_proofs::poly::commitment::{Blind, CommitmentScheme, Params};
|
||||
use halo2_proofs::{circuit::*, plonk::*};
|
||||
use halo2curves::bn256::{Bn256, G1Affine};
|
||||
use halo2curves::bn256::G1Affine;
|
||||
use halo2curves::group::prime::PrimeCurveAffine;
|
||||
use halo2curves::group::Curve;
|
||||
use halo2curves::CurveAffine;
|
||||
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType, VarTensor};
|
||||
|
||||
use super::Module;
|
||||
|
||||
/// The number of instance columns used by the KZG hash function
|
||||
/// The number of instance columns used by the PolyCommit hash function
|
||||
pub const NUM_INSTANCE_COLUMNS: usize = 0;
|
||||
/// The number of advice columns used by the KZG hash function
|
||||
/// The number of advice columns used by the PolyCommit hash function
|
||||
pub const NUM_INNER_COLS: usize = 1;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// WIDTH, RATE and L are const generics for the struct, which represent the width, rate, and number of inputs for the Poseidon hash function, respectively.
|
||||
/// This means they are values that are known at compile time and can be used to specialize the implementation of the struct.
|
||||
/// The actual chip provided by halo2_gadgets is added to the parent Chip.
|
||||
pub struct KZGConfig {
|
||||
/// Configuration for the PolyCommit chip
|
||||
pub struct PolyCommitConfig {
|
||||
///
|
||||
pub hash_inputs: VarTensor,
|
||||
pub inputs: VarTensor,
|
||||
}
|
||||
|
||||
type InputAssignments = ();
|
||||
|
||||
/// PoseidonChip is a wrapper around the Pow5Chip that adds a set of advice columns to the gadget Chip to store the inputs of the hash
|
||||
///
|
||||
#[derive(Debug)]
|
||||
pub struct KZGChip {
|
||||
config: KZGConfig,
|
||||
pub struct PolyCommitChip {
|
||||
config: PolyCommitConfig,
|
||||
}
|
||||
|
||||
impl KZGChip {
|
||||
impl PolyCommitChip {
|
||||
/// Commit to the message using the KZG commitment scheme
|
||||
pub fn commit(
|
||||
message: Vec<Fp>,
|
||||
pub fn commit<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
|
||||
message: Vec<Scheme::Scalar>,
|
||||
degree: u32,
|
||||
num_unusable_rows: u32,
|
||||
params: &ParamsKZG<Bn256>,
|
||||
params: &Scheme::ParamsProver,
|
||||
) -> Vec<G1Affine> {
|
||||
let k = params.k();
|
||||
let domain = halo2_proofs::poly::EvaluationDomain::new(degree, k);
|
||||
@@ -81,14 +81,14 @@ impl KZGChip {
|
||||
}
|
||||
}
|
||||
|
||||
impl Module<Fp> for KZGChip {
|
||||
type Config = KZGConfig;
|
||||
impl Module<Fp> for PolyCommitChip {
|
||||
type Config = PolyCommitConfig;
|
||||
type InputAssignments = InputAssignments;
|
||||
type RunInputs = Vec<Fp>;
|
||||
type Params = (usize, usize);
|
||||
|
||||
fn name(&self) -> &'static str {
|
||||
"KZG"
|
||||
"PolyCommit"
|
||||
}
|
||||
|
||||
fn instance_increment_input(&self) -> Vec<usize> {
|
||||
@@ -102,14 +102,15 @@ impl Module<Fp> for KZGChip {
|
||||
|
||||
/// Configuration of the PoseidonChip
|
||||
fn configure(meta: &mut ConstraintSystem<Fp>, params: Self::Params) -> Self::Config {
|
||||
let hash_inputs = VarTensor::new_unblinded_advice(meta, params.0, NUM_INNER_COLS, params.1);
|
||||
Self::Config { hash_inputs }
|
||||
let inputs = VarTensor::new_unblinded_advice(meta, params.0, NUM_INNER_COLS, params.1);
|
||||
Self::Config { inputs }
|
||||
}
|
||||
|
||||
fn layout_inputs(
|
||||
&self,
|
||||
_: &mut impl Layouter<Fp>,
|
||||
_: &[ValTensor<Fp>],
|
||||
_: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, Error> {
|
||||
Ok(())
|
||||
}
|
||||
@@ -122,11 +123,24 @@ impl Module<Fp> for KZGChip {
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
input: &[ValTensor<Fp>],
|
||||
_: usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, Error> {
|
||||
assert_eq!(input.len(), 1);
|
||||
|
||||
let local_constants = constants.clone();
|
||||
layouter.assign_region(
|
||||
|| "kzg commit",
|
||||
|mut region| self.config.hash_inputs.assign(&mut region, 0, &input[0]),
|
||||
|| "PolyCommit",
|
||||
|mut region| {
|
||||
let mut local_inner_constants = local_constants.clone();
|
||||
let res = self.config.inputs.assign(
|
||||
&mut region,
|
||||
0,
|
||||
&input[0],
|
||||
&mut local_inner_constants,
|
||||
)?;
|
||||
*constants = local_inner_constants;
|
||||
Ok(res)
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
@@ -163,7 +177,7 @@ mod tests {
|
||||
}
|
||||
|
||||
impl Circuit<Fp> for HashCircuit {
|
||||
type Config = KZGConfig;
|
||||
type Config = PolyCommitConfig;
|
||||
type FloorPlanner = ModulePlanner;
|
||||
type Params = ();
|
||||
|
||||
@@ -178,7 +192,7 @@ mod tests {
|
||||
|
||||
fn configure(meta: &mut ConstraintSystem<Fp>) -> Self::Config {
|
||||
let params = (K, R);
|
||||
KZGChip::configure(meta, params)
|
||||
PolyCommitChip::configure(meta, params)
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
@@ -186,8 +200,13 @@ mod tests {
|
||||
config: Self::Config,
|
||||
mut layouter: impl Layouter<Fp>,
|
||||
) -> Result<(), Error> {
|
||||
let kzg_chip = KZGChip::new(config);
|
||||
kzg_chip.layout(&mut layouter, &[self.message.clone()], 0);
|
||||
let polycommit_chip = PolyCommitChip::new(config);
|
||||
polycommit_chip.layout(
|
||||
&mut layouter,
|
||||
&[self.message.clone()],
|
||||
0,
|
||||
&mut HashMap::new(),
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -195,7 +214,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn kzg_for_a_range_of_input_sizes() {
|
||||
fn polycommit_chip_for_a_range_of_input_sizes() {
|
||||
let rng = rand::rngs::OsRng;
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -225,7 +244,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
#[ignore]
|
||||
fn kzg_commit_much_longer_input() {
|
||||
fn polycommit_chip_much_longer_input() {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
env_logger::init();
|
||||
|
||||
@@ -18,6 +18,7 @@ use maybe_rayon::slice::ParallelSlice;
|
||||
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor, ValType};
|
||||
|
||||
use super::Module;
|
||||
@@ -172,12 +173,15 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
&self,
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
message: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Self::InputAssignments, Error> {
|
||||
assert_eq!(message.len(), 1);
|
||||
let message = message[0].clone();
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let local_constants = constants.clone();
|
||||
|
||||
let res = layouter.assign_region(
|
||||
|| "load message",
|
||||
|mut region| {
|
||||
@@ -199,12 +203,26 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
Ok(v.clone())
|
||||
}
|
||||
ValType::Constant(f) => region.assign_advice_from_constant(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
*f,
|
||||
),
|
||||
ValType::Constant(f) => {
|
||||
if local_constants.contains_key(f) {
|
||||
Ok(constants.get(f).unwrap().assigned_cell().ok_or({
|
||||
log::error!("constant not previously assigned");
|
||||
Error::Synthesis
|
||||
})?)
|
||||
} else {
|
||||
let res = region.assign_advice_from_constant(
|
||||
|| format!("load message_{}", i),
|
||||
self.config.hash_inputs[x],
|
||||
y,
|
||||
*f,
|
||||
)?;
|
||||
|
||||
constants
|
||||
.insert(*f, ValType::AssignedConstant(res.clone(), *f));
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
e => {
|
||||
log::error!(
|
||||
"wrong input type {:?}, must be previously assigned",
|
||||
@@ -270,8 +288,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
input: &[ValTensor<Fp>],
|
||||
row_offset: usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<ValTensor<Fp>, Error> {
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input)?;
|
||||
let (mut input_cells, zero_val) = self.layout_inputs(layouter, input, constants)?;
|
||||
// extract the values from the input cells
|
||||
let mut assigned_input: Tensor<ValType<Fp>> =
|
||||
input_cells.iter().map(|e| ValType::from(e.clone())).into();
|
||||
@@ -434,7 +453,7 @@ mod tests {
|
||||
*,
|
||||
};
|
||||
|
||||
use std::marker::PhantomData;
|
||||
use std::{collections::HashMap, marker::PhantomData};
|
||||
|
||||
use halo2_gadgets::poseidon::primitives::Spec;
|
||||
use halo2_proofs::{
|
||||
@@ -477,7 +496,12 @@ mod tests {
|
||||
mut layouter: impl Layouter<Fp>,
|
||||
) -> Result<(), Error> {
|
||||
let chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, L> = PoseidonChip::new(config);
|
||||
chip.layout(&mut layouter, &[self.message.clone()], 0)?;
|
||||
chip.layout(
|
||||
&mut layouter,
|
||||
&[self.message.clone()],
|
||||
0,
|
||||
&mut HashMap::new(),
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -345,7 +345,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -46,7 +46,8 @@ pub enum HybridOp {
|
||||
dim: usize,
|
||||
},
|
||||
Softmax {
|
||||
scale: utils::F32,
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
axes: Vec<usize>,
|
||||
},
|
||||
RangeCheck(Tolerance),
|
||||
@@ -70,7 +71,7 @@ pub enum HybridOp {
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for HybridOp {
|
||||
///
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
@@ -130,9 +131,16 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
kernel_shape,
|
||||
normalized,
|
||||
} => tensor::ops::sumpool(&x, *padding, *stride, *kernel_shape, *normalized)?,
|
||||
HybridOp::Softmax { scale, axes } => {
|
||||
tensor::ops::nonlinearities::softmax_axes(&x, scale.into(), axes)
|
||||
}
|
||||
HybridOp::Softmax {
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
} => tensor::ops::nonlinearities::softmax_axes(
|
||||
&x,
|
||||
input_scale.into(),
|
||||
output_scale.into(),
|
||||
axes,
|
||||
),
|
||||
HybridOp::RangeCheck(tol) => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
|
||||
@@ -203,8 +211,15 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
),
|
||||
HybridOp::ReduceMin { axes } => format!("REDUCEMIN (axes={:?})", axes),
|
||||
HybridOp::ReduceArgMin { dim } => format!("REDUCEARGMIN (dim={})", dim),
|
||||
HybridOp::Softmax { scale, axes } => {
|
||||
format!("SOFTMAX (scale={}, axes={:?})", scale, axes)
|
||||
HybridOp::Softmax {
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
} => {
|
||||
format!(
|
||||
"SOFTMAX (input_scale={}, output_scale={}, axes={:?})",
|
||||
input_scale, output_scale, axes
|
||||
)
|
||||
}
|
||||
HybridOp::RangeCheck(p) => format!("RANGECHECK (tol={:?})", p),
|
||||
HybridOp::Greater => "GREATER".into(),
|
||||
@@ -324,9 +339,18 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
HybridOp::ReduceArgMin { dim } => {
|
||||
layouts::argmin_axes(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
HybridOp::Softmax { scale, axes } => {
|
||||
layouts::softmax_axes(config, region, values[..].try_into()?, *scale, axes)?
|
||||
}
|
||||
HybridOp::Softmax {
|
||||
input_scale,
|
||||
output_scale,
|
||||
axes,
|
||||
} => layouts::softmax_axes(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
*input_scale,
|
||||
*output_scale,
|
||||
axes,
|
||||
)?,
|
||||
HybridOp::RangeCheck(tol) => layouts::range_check_percent(
|
||||
config,
|
||||
region,
|
||||
@@ -359,8 +383,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
| HybridOp::ReduceArgMax { .. }
|
||||
| HybridOp::OneHot { .. }
|
||||
| HybridOp::ReduceArgMin { .. } => 0,
|
||||
HybridOp::Softmax { .. } => 2 * in_scales[0],
|
||||
HybridOp::Recip { output_scale, .. } => multiplier_to_scale(output_scale.0 as f64),
|
||||
HybridOp::Softmax { output_scale, .. } | HybridOp::Recip { output_scale, .. } => {
|
||||
multiplier_to_scale(output_scale.0 as f64)
|
||||
}
|
||||
_ => in_scales[0],
|
||||
};
|
||||
Ok(scale)
|
||||
|
||||
@@ -9,6 +9,7 @@ use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use log::{error, trace};
|
||||
use maybe_rayon::{
|
||||
iter::IntoParallelRefIterator,
|
||||
prelude::{IndexedParallelIterator, IntoParallelIterator, ParallelIterator},
|
||||
slice::ParallelSliceMut,
|
||||
};
|
||||
@@ -33,7 +34,7 @@ use super::*;
|
||||
use crate::circuit::ops::lookup::LookupOp;
|
||||
|
||||
/// Same as div but splits the division into N parts
|
||||
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
@@ -68,7 +69,7 @@ pub(crate) fn loop_div<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Div accumulated layout
|
||||
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn div<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
@@ -93,9 +94,9 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
tensor::ops::nonlinearities::const_div(&input_evals.clone(), felt_to_i128(div) as f64)
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
@@ -133,7 +134,7 @@ pub(crate) fn div<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// recip accumulated layout
|
||||
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
value: &[ValTensor<F>; 1],
|
||||
@@ -166,9 +167,9 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
felt_to_i128(input_scale) as f64,
|
||||
felt_to_i128(output_scale) as f64,
|
||||
)
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
} else {
|
||||
Tensor::new(
|
||||
@@ -226,7 +227,7 @@ pub(crate) fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Dot product accumulated layout
|
||||
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -337,7 +338,7 @@ pub(crate) fn dot<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Einsum
|
||||
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
inputs: &[ValTensor<F>],
|
||||
@@ -524,14 +525,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
// Compute the product of all input tensors
|
||||
for pair in input_pairs {
|
||||
let product_across_pair = prod(
|
||||
config,
|
||||
region,
|
||||
&[pair.try_into().map_err(|e| {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?],
|
||||
)?;
|
||||
let product_across_pair = prod(config, region, &[pair.into()])?;
|
||||
|
||||
if let Some(product) = prod_res {
|
||||
prod_res = Some(
|
||||
@@ -563,7 +557,7 @@ pub(crate) fn einsum<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
|
||||
fn _sort_ascending<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -574,12 +568,12 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
|
||||
let sorted = if is_assigned {
|
||||
input
|
||||
.get_int_evals()?
|
||||
.iter()
|
||||
.sorted_by(|a, b| a.cmp(b))
|
||||
.map(|x| Ok(Value::known(i128_to_felt(*x))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
let mut int_evals = input.get_int_evals()?;
|
||||
int_evals.par_sort_unstable_by(|a, b| a.cmp(b));
|
||||
int_evals
|
||||
.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); input.len()]),
|
||||
@@ -607,7 +601,7 @@ fn _sort_ascending<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
///
|
||||
fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
|
||||
fn _select_topk<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -622,7 +616,7 @@ fn _select_topk<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Select top k elements
|
||||
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -642,11 +636,12 @@ pub(crate) fn topk_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn select<F: PrimeField + TensorType + PartialOrd>(
|
||||
fn select<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let start = instant::Instant::now();
|
||||
let (mut input, index) = (values[0].clone(), values[1].clone());
|
||||
input.flatten();
|
||||
|
||||
@@ -656,12 +651,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()?;
|
||||
|
||||
let output: ValTensor<F> = if is_assigned {
|
||||
let output: ValTensor<F> = if is_assigned && region.witness_gen() {
|
||||
let felt_evals = input.get_felt_evals()?;
|
||||
index
|
||||
.get_int_evals()?
|
||||
.iter()
|
||||
.map(|x| Ok(Value::known(input.get_felt_evals()?.get(&[*x as usize]))))
|
||||
.collect::<Result<Tensor<Value<F>>, Box<dyn Error>>>()?
|
||||
.par_iter()
|
||||
.map(|x| Value::known(felt_evals.get(&[*x as usize])))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
} else {
|
||||
Tensor::new(
|
||||
Some(&vec![Value::<F>::unknown(); index.len()]),
|
||||
@@ -673,10 +669,13 @@ fn select<F: PrimeField + TensorType + PartialOrd>(
|
||||
let (_, assigned_output) =
|
||||
dynamic_lookup(config, region, &[index, output], &[dim_indices, input])?;
|
||||
|
||||
let end = start.elapsed();
|
||||
trace!("select took: {:?}", end);
|
||||
|
||||
Ok(assigned_output)
|
||||
}
|
||||
|
||||
fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
||||
fn one_hot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -692,7 +691,7 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
||||
let output: ValTensor<F> = if is_assigned {
|
||||
let int_evals = input.get_int_evals()?;
|
||||
let res = tensor::ops::one_hot(&int_evals, num_classes, 1)?;
|
||||
res.iter()
|
||||
res.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<_>>()
|
||||
} else {
|
||||
@@ -728,12 +727,13 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Dynamic lookup
|
||||
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
lookups: &[ValTensor<F>; 2],
|
||||
tables: &[ValTensor<F>; 2],
|
||||
) -> Result<(ValTensor<F>, ValTensor<F>), Box<dyn Error>> {
|
||||
let start = instant::Instant::now();
|
||||
// if not all lookups same length err
|
||||
if lookups[0].len() != lookups[1].len() {
|
||||
return Err("lookups must be same length".into());
|
||||
@@ -753,20 +753,28 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
||||
let _table_1 = region.assign_dynamic_lookup(&config.dynamic_lookups.tables[1], &table_1)?;
|
||||
let table_len = table_0.len();
|
||||
|
||||
trace!("assigning tables took: {:?}", start.elapsed());
|
||||
|
||||
// now create a vartensor of constants for the dynamic lookup index
|
||||
let table_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), table_len);
|
||||
let _table_index =
|
||||
region.assign_dynamic_lookup(&config.dynamic_lookups.tables[2], &table_index)?;
|
||||
|
||||
trace!("assigning table index took: {:?}", start.elapsed());
|
||||
|
||||
let lookup_0 = region.assign(&config.dynamic_lookups.inputs[0], &lookup_0)?;
|
||||
let lookup_1 = region.assign(&config.dynamic_lookups.inputs[1], &lookup_1)?;
|
||||
let lookup_len = lookup_0.len();
|
||||
|
||||
trace!("assigning lookups took: {:?}", start.elapsed());
|
||||
|
||||
// now set the lookup index
|
||||
let lookup_index = create_constant_tensor(F::from(dynamic_lookup_index as u64), lookup_len);
|
||||
|
||||
let _lookup_index = region.assign(&config.dynamic_lookups.inputs[2], &lookup_index)?;
|
||||
|
||||
trace!("assigning lookup index took: {:?}", start.elapsed());
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..table_len)
|
||||
.map(|i| {
|
||||
@@ -802,11 +810,14 @@ pub(crate) fn dynamic_lookup<F: PrimeField + TensorType + PartialOrd>(
|
||||
region.increment_dynamic_lookup_index(1);
|
||||
region.increment(lookup_len);
|
||||
|
||||
let end = start.elapsed();
|
||||
trace!("dynamic lookup took: {:?}", end);
|
||||
|
||||
Ok((lookup_0, lookup_1))
|
||||
}
|
||||
|
||||
/// Shuffle arg
|
||||
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
input: &[ValTensor<F>; 1],
|
||||
@@ -869,7 +880,7 @@ pub(crate) fn shuffles<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// One hot accumulated layout
|
||||
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -922,7 +933,7 @@ pub(crate) fn one_hot_axis<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Gather accumulated layout
|
||||
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -950,7 +961,7 @@ pub(crate) fn gather<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Gather accumulated layout
|
||||
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -973,7 +984,7 @@ pub(crate) fn gather_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Gather accumulated layout
|
||||
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -1024,7 +1035,7 @@ pub(crate) fn gather_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
/// Takes a tensor representing a multi-dimensional index and returns a tensor representing the linearized index.
|
||||
/// The linearized index is the index of the element in the flattened tensor.
|
||||
/// FOr instance if the dims is [3,5,2], the linearized index of [2] at dim 1 is 2*5 + 3 = 13
|
||||
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1032,6 +1043,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
dim: usize,
|
||||
is_flat_index: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let start_time = instant::Instant::now();
|
||||
let index = values[0].clone();
|
||||
if !is_flat_index {
|
||||
assert_eq!(index.dims().len(), dims.len());
|
||||
@@ -1105,6 +1117,9 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
region.apply_in_loop(&mut output, inner_loop_function)?;
|
||||
|
||||
let elapsed = start_time.elapsed();
|
||||
trace!("linearize_element_index took: {:?}", elapsed);
|
||||
|
||||
Ok(output.into())
|
||||
}
|
||||
|
||||
@@ -1125,7 +1140,7 @@ pub(crate) fn linearize_element_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
/// If indices_shape[-1] == r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensors containing 1-D tensors of dimension r-b, where N is an integer equals to the product of 1 and all the elements in the batch dimensions of the indices_shape.
|
||||
/// Let us think of each such r-b ranked tensor as indices_slice. Each scalar value corresponding to data[0:b-1,indices_slice] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Example 1 below)
|
||||
/// If indices_shape[-1] < r-b, since the rank of indices is q, indices can be thought of as N (q-b-1)-dimensional tensor containing 1-D tensors of dimension < r-b. Let us think of each such tensors as indices_slice. Each tensor slice corresponding to data[0:b-1, indices_slice , :] is filled into the corresponding location of the (q-b-1)-dimensional tensor to form the output tensor (Examples 2, 3, 4 and 5 below)
|
||||
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1222,7 +1237,7 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
.iter()
|
||||
.map(|x| {
|
||||
let slice = x.iter().map(|x| *x..*x + 1).collect::<Vec<_>>();
|
||||
Ok(index_slice.get_slice(&slice)?)
|
||||
index_slice.get_slice(&slice)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?
|
||||
};
|
||||
@@ -1232,7 +1247,6 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
const_offset += F::from(coord[i] as u64) * dim_multiplier[i];
|
||||
}
|
||||
|
||||
|
||||
let const_offset = create_constant_tensor(const_offset, 1);
|
||||
|
||||
let mut results = vec![];
|
||||
@@ -1250,16 +1264,18 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
let res = sum(config, region, &[res])?;
|
||||
results.push(res.get_inner_tensor()?.clone());
|
||||
// assert than res is less than the product of the dims
|
||||
assert!(
|
||||
if region.witness_gen() {
|
||||
assert!(
|
||||
res.get_int_evals()?
|
||||
.iter()
|
||||
.all(|x| *x < dims.iter().product::<usize>() as i128),
|
||||
"res is greater than the product of the dims {} (coord={}, index_dim_multiplier={}, res={})",
|
||||
dims.iter().product::<usize>(),
|
||||
index_val.show(),
|
||||
index_val.show(),
|
||||
index_dim_multiplier.show(),
|
||||
res.show()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let result_tensor = Tensor::from(results.into_iter());
|
||||
@@ -1273,7 +1289,9 @@ pub(crate) fn linearize_nd_index<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output.into())
|
||||
}
|
||||
|
||||
pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn get_missing_set_elements<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -1304,7 +1322,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
fullset_evals
|
||||
.iter()
|
||||
.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
@@ -1337,7 +1355,7 @@ pub(crate) fn get_missing_set_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Gather accumulated layout
|
||||
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 3],
|
||||
@@ -1354,14 +1372,14 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
|
||||
|
||||
let claimed_output: ValTensor<F> = if is_assigned {
|
||||
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
|
||||
let input_inner = input.get_int_evals()?;
|
||||
let index_inner = index.get_int_evals()?.map(|x| x as usize);
|
||||
let src_inner = src.get_int_evals()?;
|
||||
|
||||
let res = tensor::ops::scatter(&input_inner, &index_inner, &src_inner, dim)?;
|
||||
|
||||
res.iter()
|
||||
res.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
@@ -1419,7 +1437,7 @@ pub(crate) fn scatter_elements<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Scatter Nd
|
||||
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 3],
|
||||
@@ -1433,14 +1451,14 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !index.any_unknowns()? && !src.any_unknowns()?;
|
||||
|
||||
let claimed_output: ValTensor<F> = if is_assigned {
|
||||
let claimed_output: ValTensor<F> = if is_assigned && region.witness_gen() {
|
||||
let input_inner = input.get_int_evals()?;
|
||||
let index_inner = index.get_int_evals()?.map(|x| x as usize);
|
||||
let src_inner = src.get_int_evals()?;
|
||||
|
||||
let res = tensor::ops::scatter_nd(&input_inner, &index_inner, &src_inner)?;
|
||||
|
||||
res.iter()
|
||||
res.par_iter()
|
||||
.map(|x| Value::known(i128_to_felt(*x)))
|
||||
.collect::<Tensor<Value<F>>>()
|
||||
.into()
|
||||
@@ -1457,7 +1475,6 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
region.increment(claimed_output.len());
|
||||
claimed_output.reshape(input.dims())?;
|
||||
|
||||
|
||||
// scatter elements is the inverse of gather elements
|
||||
let (gather_src, linear_index) =
|
||||
gather_nd(config, region, &[claimed_output.clone(), index.clone()], 0)?;
|
||||
@@ -1498,7 +1515,7 @@ pub(crate) fn scatter_nd<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// sum accumulated layout
|
||||
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1581,7 +1598,7 @@ pub(crate) fn sum<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// product accumulated layout
|
||||
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1661,7 +1678,7 @@ pub(crate) fn prod<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Axes wise op wrapper
|
||||
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
|
||||
fn axes_wise_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1722,7 +1739,7 @@ fn axes_wise_op<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Sum accumulated layout
|
||||
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1733,7 +1750,7 @@ pub(crate) fn prod_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Sum accumulated layout
|
||||
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1744,7 +1761,7 @@ pub(crate) fn sum_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// argmax layout
|
||||
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1762,7 +1779,7 @@ pub(crate) fn argmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Max accumulated layout
|
||||
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1774,7 +1791,7 @@ pub(crate) fn max_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Argmin layout
|
||||
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1792,7 +1809,7 @@ pub(crate) fn argmin_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Min accumulated layout
|
||||
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1804,7 +1821,7 @@ pub(crate) fn min_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Pairwise (elementwise) op layout
|
||||
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -1959,7 +1976,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// expand the tensor to the given shape
|
||||
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -1972,7 +1989,7 @@ pub(crate) fn expand<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
///
|
||||
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -1995,7 +2012,7 @@ pub(crate) fn greater<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
///
|
||||
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2018,7 +2035,7 @@ pub(crate) fn greater_equal<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
///
|
||||
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn less<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2028,7 +2045,7 @@ pub(crate) fn less<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
///
|
||||
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2038,7 +2055,7 @@ pub(crate) fn less_equal<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// And boolean operation
|
||||
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn and<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2049,7 +2066,7 @@ pub(crate) fn and<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Or boolean operation
|
||||
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn or<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2065,7 +2082,7 @@ pub(crate) fn or<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Equality boolean operation
|
||||
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2075,7 +2092,7 @@ pub(crate) fn equals<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Equality boolean operation
|
||||
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2109,7 +2126,7 @@ pub(crate) fn equals_zero<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Xor boolean operation
|
||||
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2135,7 +2152,7 @@ pub(crate) fn xor<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Not boolean operation
|
||||
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn not<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2151,7 +2168,7 @@ pub(crate) fn not<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Iff
|
||||
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 3],
|
||||
@@ -2175,7 +2192,7 @@ pub(crate) fn iff<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Negation operation accumulated layout
|
||||
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2185,7 +2202,7 @@ pub(crate) fn neg<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Sumpool accumulated layout
|
||||
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
@@ -2239,7 +2256,7 @@ pub(crate) fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Convolution accumulated layout
|
||||
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2304,7 +2321,7 @@ pub(crate) fn max_pool2d<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
/// DeConvolution accumulated layout
|
||||
pub(crate) fn deconv<
|
||||
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
|
||||
>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
@@ -2397,7 +2414,7 @@ pub(crate) fn deconv<
|
||||
|
||||
/// Convolution accumulated layout
|
||||
pub(crate) fn conv<
|
||||
F: PrimeField + TensorType + PartialOrd + std::marker::Send + std::marker::Sync,
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash + std::marker::Send + std::marker::Sync,
|
||||
>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
@@ -2578,7 +2595,7 @@ pub(crate) fn conv<
|
||||
}
|
||||
|
||||
/// Power accumulated layout
|
||||
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2594,7 +2611,7 @@ pub(crate) fn pow<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Rescaled op accumulated layout
|
||||
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
@@ -2616,7 +2633,7 @@ pub(crate) fn rescale<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Dummy (no contraints) reshape layout
|
||||
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
new_dims: &[usize],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
@@ -2626,7 +2643,7 @@ pub(crate) fn reshape<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Dummy (no contraints) move_axis layout
|
||||
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
source: usize,
|
||||
destination: usize,
|
||||
@@ -2637,7 +2654,7 @@ pub(crate) fn move_axis<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// resize layout
|
||||
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2651,7 +2668,7 @@ pub(crate) fn resize<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Slice layout
|
||||
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2660,15 +2677,45 @@ pub(crate) fn slice<F: PrimeField + TensorType + PartialOrd>(
|
||||
end: &usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// assigns the instance to the advice.
|
||||
let mut output = region.assign(&config.custom_gates.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
let mut output = values[0].clone();
|
||||
|
||||
let is_assigned = output.all_prev_assigned();
|
||||
if !is_assigned {
|
||||
output = region.assign(&config.custom_gates.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
}
|
||||
|
||||
output.slice(axis, start, end)?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Trilu layout
|
||||
pub(crate) fn trilu<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
k: &i32,
|
||||
upper: &bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// assigns the instance to the advice.
|
||||
let mut output = values[0].clone();
|
||||
|
||||
let is_assigned = output.all_prev_assigned();
|
||||
if !is_assigned {
|
||||
output = region.assign(&config.custom_gates.inputs[0], &values[0])?;
|
||||
}
|
||||
|
||||
let res = tensor::ops::trilu(output.get_inner_tensor()?, *k, *upper)?;
|
||||
|
||||
let output = region.assign(&config.custom_gates.output, &res.into())?;
|
||||
region.increment(output.len());
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Concat layout
|
||||
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
values: &[ValTensor<F>],
|
||||
axis: &usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
@@ -2680,7 +2727,7 @@ pub(crate) fn concat<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2695,7 +2742,7 @@ pub(crate) fn identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2731,7 +2778,7 @@ pub(crate) fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Downsample layout
|
||||
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2748,7 +2795,7 @@ pub(crate) fn downsample<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// layout for enforcing two sets of cells to be equal
|
||||
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
@@ -2774,7 +2821,7 @@ pub(crate) fn enforce_equality<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// layout for range check.
|
||||
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2830,12 +2877,13 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
if region.throw_range_check_error() {
|
||||
let is_assigned = !w.any_unknowns()?;
|
||||
if is_assigned && region.witness_gen() {
|
||||
// assert is within range
|
||||
let int_values = w.get_int_evals()?;
|
||||
for v in int_values {
|
||||
if v < range.0 || v > range.1 {
|
||||
log::debug!("Value ({:?}) out of range: {:?}", v, range);
|
||||
for v in int_values.iter() {
|
||||
if v < &range.0 || v > &range.1 {
|
||||
log::error!("Value ({:?}) out of range: {:?}", v, range);
|
||||
return Err(Box::new(TensorError::TableLookupError));
|
||||
}
|
||||
}
|
||||
@@ -2855,7 +2903,7 @@ pub(crate) fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// layout for nonlinearity check.
|
||||
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2957,7 +3005,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Argmax
|
||||
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -2993,7 +3041,7 @@ pub(crate) fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// Argmin
|
||||
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -3029,7 +3077,7 @@ pub(crate) fn argmin<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// max layout
|
||||
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn max<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -3039,7 +3087,7 @@ pub(crate) fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// min layout
|
||||
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn min<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -3047,7 +3095,7 @@ pub(crate) fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
_sort_ascending(config, region, values)?.get_slice(&[0..1])
|
||||
}
|
||||
|
||||
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
|
||||
fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
@@ -3150,18 +3198,19 @@ fn multi_dim_axes_op<F: PrimeField + TensorType + PartialOrd>(
|
||||
}
|
||||
|
||||
/// softmax layout
|
||||
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
axes: &[usize],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let soft_max_at_scale = move |config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1]|
|
||||
-> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
softmax(config, region, values, scale)
|
||||
softmax(config, region, values, input_scale, output_scale)
|
||||
};
|
||||
|
||||
let output = multi_dim_axes_op(config, region, values, axes, soft_max_at_scale)?;
|
||||
@@ -3169,33 +3218,66 @@ pub(crate) fn softmax_axes<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// softmax func
|
||||
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
/// percent func
|
||||
pub(crate) fn percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: utils::F32,
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// elementwise exponential
|
||||
let ex = nonlinearity(config, region, values, &LookupOp::Exp { scale })?;
|
||||
|
||||
let is_assigned = values[0].all_prev_assigned();
|
||||
let mut input = values[0].clone();
|
||||
if !is_assigned {
|
||||
input = region.assign(&config.custom_gates.inputs[0], &values[0])?;
|
||||
region.increment(input.len());
|
||||
};
|
||||
// sum of exps
|
||||
let denom = sum(config, region, &[ex.clone()])?;
|
||||
// get the inverse
|
||||
|
||||
let felt_scale = F::from(scale.0 as u64);
|
||||
let inv_denom = recip(config, region, &[denom], felt_scale, felt_scale)?;
|
||||
let denom = sum(config, region, &[input.clone()])?;
|
||||
|
||||
let input_felt_scale = F::from(input_scale.0 as u64);
|
||||
let output_felt_scale = F::from(output_scale.0 as u64);
|
||||
let inv_denom = recip(
|
||||
config,
|
||||
region,
|
||||
&[denom],
|
||||
input_felt_scale,
|
||||
output_felt_scale,
|
||||
)?;
|
||||
// product of num * (1 / denom) = 2*output_scale
|
||||
let softmax = pairwise(config, region, &[ex, inv_denom], BaseOp::Mult)?;
|
||||
let percent = pairwise(config, region, &[input, inv_denom], BaseOp::Mult)?;
|
||||
|
||||
Ok(softmax)
|
||||
// rebase the percent to 2x the scale
|
||||
loop_div(config, region, &[percent], input_felt_scale)
|
||||
}
|
||||
|
||||
/// softmax func
|
||||
pub(crate) fn softmax<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
input_scale: utils::F32,
|
||||
output_scale: utils::F32,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
// get the max then subtract it
|
||||
let max_val = max(config, region, values)?;
|
||||
// rebase the input to 0
|
||||
let sub = pairwise(config, region, &[values[0].clone(), max_val], BaseOp::Sub)?;
|
||||
// elementwise exponential
|
||||
let ex = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[sub],
|
||||
&LookupOp::Exp { scale: input_scale },
|
||||
)?;
|
||||
|
||||
percent(config, region, &[ex.clone()], input_scale, output_scale)
|
||||
}
|
||||
|
||||
/// Checks that the percent error between the expected public output and the actual output value
|
||||
/// is within the percent error expressed by the `tol` input, where `tol == 1.0` means the percent
|
||||
/// error tolerance is 1 percent.
|
||||
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub(crate) fn range_check_percent<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 2],
|
||||
|
||||
@@ -123,6 +123,9 @@ pub enum LookupOp {
|
||||
scale: utils::F32,
|
||||
a: utils::F32,
|
||||
},
|
||||
HardSwish {
|
||||
scale: utils::F32,
|
||||
},
|
||||
}
|
||||
|
||||
impl LookupOp {
|
||||
@@ -134,7 +137,7 @@ impl LookupOp {
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
@@ -223,6 +226,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
LookupOp::ATan { scale } => Ok(tensor::ops::nonlinearities::atan(&x, scale.into())),
|
||||
LookupOp::ATanh { scale } => Ok(tensor::ops::nonlinearities::atanh(&x, scale.into())),
|
||||
LookupOp::Tanh { scale } => Ok(tensor::ops::nonlinearities::tanh(&x, scale.into())),
|
||||
LookupOp::HardSwish { scale } => {
|
||||
Ok(tensor::ops::nonlinearities::hardswish(&x, scale.into()))
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
@@ -276,6 +282,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
LookupOp::ASin { scale } => format!("ASIN(scale={})", scale),
|
||||
LookupOp::Sinh { scale } => format!("SINH(scale={})", scale),
|
||||
LookupOp::ASinh { scale } => format!("ASINH(scale={})", scale),
|
||||
LookupOp::HardSwish { scale } => format!("HARDSWISH(scale={})", scale),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -27,12 +27,14 @@ pub mod region;
|
||||
|
||||
/// A struct representing the result of a forward pass.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub(crate) output: Tensor<F>,
|
||||
}
|
||||
|
||||
/// A trait representing operations that can be represented as constraints in a circuit.
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
std::fmt::Debug + Send + Sync + Any
|
||||
{
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
|
||||
/// Returns a string representation of the operation.
|
||||
@@ -98,7 +100,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send +
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Clone for Box<dyn Op<F>> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
|
||||
fn clone(&self) -> Self {
|
||||
self.clone_dyn()
|
||||
}
|
||||
@@ -165,7 +167,7 @@ pub struct Input {
|
||||
pub datum_type: InputType,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
Ok(self.scale)
|
||||
}
|
||||
@@ -226,7 +228,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Unknown;
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, Box<dyn Error>> {
|
||||
Ok(0)
|
||||
}
|
||||
@@ -256,7 +258,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
|
||||
|
||||
///
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
///
|
||||
pub quantized_values: Tensor<F>,
|
||||
///
|
||||
@@ -266,7 +268,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub pre_assigned_val: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
///
|
||||
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
|
||||
Self {
|
||||
@@ -293,8 +295,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
|
||||
for Constant<F>
|
||||
impl<
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
|
||||
@@ -83,10 +83,20 @@ pub enum PolyOp {
|
||||
And,
|
||||
Or,
|
||||
Xor,
|
||||
Trilu {
|
||||
upper: bool,
|
||||
k: i32,
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
|
||||
for PolyOp
|
||||
impl<
|
||||
F: PrimeField
|
||||
+ TensorType
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
@@ -114,7 +124,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Add => "ADD".into(),
|
||||
PolyOp::Mult => "MULT".into(),
|
||||
PolyOp::Sub => "SUB".into(),
|
||||
PolyOp::Sum { .. } => "SUM".into(),
|
||||
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
|
||||
PolyOp::Prod { .. } => "PROD".into(),
|
||||
PolyOp::Pow(_) => "POW".into(),
|
||||
PolyOp::Conv { .. } => "CONV".into(),
|
||||
@@ -128,6 +138,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::And => "AND".into(),
|
||||
PolyOp::Or => "OR".into(),
|
||||
PolyOp::Xor => "XOR".into(),
|
||||
PolyOp::Trilu { upper, k } => format!("TRILU (upper={}, k={})", upper, k),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,6 +276,7 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
};
|
||||
tensor::ops::scatter_nd(&x, &idx, &src)
|
||||
}
|
||||
PolyOp::Trilu { upper, k } => tensor::ops::trilu(&inputs[0], *k, *upper),
|
||||
}?;
|
||||
|
||||
Ok(ForwardResult { output: res })
|
||||
@@ -384,6 +396,9 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Slice { axis, start, end } => {
|
||||
layouts::slice(config, region, values[..].try_into()?, axis, start, end)?
|
||||
}
|
||||
PolyOp::Trilu { upper, k } => {
|
||||
layouts::trilu(config, region, values[..].try_into()?, k, upper)?
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -2,24 +2,28 @@ use crate::{
|
||||
circuit::table::Range,
|
||||
tensor::{Tensor, TensorError, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored::Colorize;
|
||||
use halo2_proofs::{
|
||||
circuit::Region,
|
||||
plonk::{Error, Selector},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use portable_atomic::AtomicI128 as AtomicInt;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashSet,
|
||||
collections::{HashMap, HashSet},
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use portable_atomic::AtomicI128 as AtomicInt;
|
||||
|
||||
use super::lookup::LookupOp;
|
||||
|
||||
/// Constants map
|
||||
pub type ConstantsMap<F> = HashMap<F, ValType<F>>;
|
||||
|
||||
/// Dynamic lookup index
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct DynamicLookupIndex {
|
||||
@@ -120,12 +124,11 @@ impl From<Box<dyn std::error::Error>> for RegionError {
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A context for a region
|
||||
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
|
||||
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
region: Option<RefCell<Region<'a, F>>>,
|
||||
row: usize,
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
total_constants: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
shuffle_index: ShuffleIndex,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
@@ -133,13 +136,34 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
|
||||
max_lookup_inputs: i128,
|
||||
min_lookup_inputs: i128,
|
||||
max_range_size: i128,
|
||||
throw_range_check_error: bool,
|
||||
witness_gen: bool,
|
||||
assigned_constants: ConstantsMap<F>,
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a, F> {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
///
|
||||
pub fn increment_total_constants(&mut self, n: usize) {
|
||||
self.total_constants += n;
|
||||
pub fn debug_report(&self) {
|
||||
log::debug!(
|
||||
"(rows={}, coord={}, constants={}, max_lookup_inputs={}, min_lookup_inputs={}, max_range_size={}, dynamic_lookup_col_coord={}, shuffle_col_coord={})",
|
||||
self.row().to_string().blue(),
|
||||
self.linear_coord().to_string().yellow(),
|
||||
self.total_constants().to_string().red(),
|
||||
self.max_lookup_inputs().to_string().green(),
|
||||
self.min_lookup_inputs().to_string().green(),
|
||||
self.max_range_size().to_string().green(),
|
||||
self.dynamic_lookup_col_coord().to_string().green(),
|
||||
self.shuffle_col_coord().to_string().green());
|
||||
}
|
||||
|
||||
///
|
||||
pub fn assigned_constants(&self) -> &ConstantsMap<F> {
|
||||
&self.assigned_constants
|
||||
}
|
||||
|
||||
///
|
||||
pub fn update_constants(&mut self, constants: ConstantsMap<F>) {
|
||||
self.assigned_constants.extend(constants);
|
||||
}
|
||||
|
||||
///
|
||||
@@ -163,8 +187,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn throw_range_check_error(&self) -> bool {
|
||||
self.throw_range_check_error
|
||||
pub fn witness_gen(&self) -> bool {
|
||||
self.witness_gen
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
@@ -177,7 +201,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
num_inner_cols,
|
||||
row,
|
||||
linear_coord,
|
||||
total_constants: 0,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
@@ -185,9 +208,22 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error: false,
|
||||
witness_gen: true,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_with_constants(
|
||||
region: Region<'a, F>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
constants: ConstantsMap<F>,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let mut new_self = Self::new(region, row, num_inner_cols);
|
||||
new_self.assigned_constants = constants;
|
||||
new_self
|
||||
}
|
||||
/// Create a new region context from a wrapped region
|
||||
pub fn from_wrapped_region(
|
||||
region: Option<RefCell<Region<'a, F>>>,
|
||||
@@ -202,7 +238,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
num_inner_cols,
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
dynamic_lookup_index,
|
||||
shuffle_index,
|
||||
used_lookups: HashSet::new(),
|
||||
@@ -210,16 +245,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error: false,
|
||||
witness_gen: false,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_dummy(
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
throw_range_check_error: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
pub fn new_dummy(row: usize, num_inner_cols: usize, witness_gen: bool) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
let linear_coord = row * num_inner_cols;
|
||||
|
||||
@@ -228,7 +260,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
num_inner_cols,
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants: 0,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
@@ -236,17 +267,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error,
|
||||
witness_gen,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_dummy_with_constants(
|
||||
pub fn new_dummy_with_linear_coord(
|
||||
row: usize,
|
||||
linear_coord: usize,
|
||||
total_constants: usize,
|
||||
num_inner_cols: usize,
|
||||
throw_range_check_error: bool,
|
||||
witness_gen: bool,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -254,7 +285,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
num_inner_cols,
|
||||
linear_coord,
|
||||
row,
|
||||
total_constants,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
@@ -262,7 +292,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
throw_range_check_error,
|
||||
witness_gen,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -312,29 +343,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
) -> Result<(), RegionError> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let constants = AtomicUsize::new(self.total_constants());
|
||||
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
|
||||
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
|
||||
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
|
||||
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
|
||||
|
||||
*output = output
|
||||
.par_enum_map(|idx, _| {
|
||||
// we kick off the loop with the current offset
|
||||
let starting_offset = row.load(Ordering::SeqCst);
|
||||
let starting_linear_coord = linear_coord.load(Ordering::SeqCst);
|
||||
let starting_constants = constants.load(Ordering::SeqCst);
|
||||
// get inner value of the locked lookups
|
||||
|
||||
// we need to make sure that the region is not shared between threads
|
||||
let mut local_reg = Self::new_dummy_with_constants(
|
||||
let mut local_reg = Self::new_dummy_with_linear_coord(
|
||||
starting_offset,
|
||||
starting_linear_coord,
|
||||
starting_constants,
|
||||
self.num_inner_cols,
|
||||
self.throw_range_check_error,
|
||||
self.witness_gen,
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -343,10 +372,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
local_reg.linear_coord() - starting_linear_coord,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
constants.fetch_add(
|
||||
local_reg.total_constants() - starting_constants,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
|
||||
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
|
||||
@@ -362,11 +387,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
// update the shuffle index
|
||||
let mut shuffle_index = shuffle_index.lock().unwrap();
|
||||
shuffle_index.update(&local_reg.shuffle_index);
|
||||
// update the constants
|
||||
let mut constants = constants.lock().unwrap();
|
||||
constants.extend(local_reg.assigned_constants);
|
||||
|
||||
res
|
||||
})
|
||||
.map_err(|e| RegionError::from(format!("dummy_loop: {:?}", e)))?;
|
||||
self.total_constants = constants.into_inner();
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
@@ -410,6 +437,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get shuffle index: {:?}", e))
|
||||
})?;
|
||||
self.assigned_constants = Arc::try_unwrap(constants)
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
|
||||
})?
|
||||
.into_inner()
|
||||
.map_err(|e| {
|
||||
RegionError::from(format!("dummy_loop: failed to get constants: {:?}", e))
|
||||
})?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -435,7 +470,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if range.0 > range.1 {
|
||||
return Err("update_max_min_lookup_range: invalid range".into());
|
||||
return Err(format!("update_max_min_lookup_range: invalid range {:?}", range).into());
|
||||
}
|
||||
|
||||
let range_size = (range.1 - range.0).abs();
|
||||
@@ -477,7 +512,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
|
||||
/// Get the total number of constants
|
||||
pub fn total_constants(&self) -> usize {
|
||||
self.total_constants
|
||||
self.assigned_constants.len()
|
||||
}
|
||||
|
||||
/// Get the dynamic lookup index
|
||||
@@ -525,26 +560,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.max_range_size
|
||||
}
|
||||
|
||||
/// Assign a constant value
|
||||
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
|
||||
self.total_constants += 1;
|
||||
if let Some(region) = &self.region {
|
||||
let cell = var.assign_constant(&mut region.borrow_mut(), self.linear_coord, value)?;
|
||||
Ok(cell.into())
|
||||
} else {
|
||||
Ok(value.into())
|
||||
}
|
||||
}
|
||||
/// Assign a valtensor to a vartensor
|
||||
pub fn assign(
|
||||
&mut self,
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, Error> {
|
||||
self.total_constants += values.num_constants();
|
||||
if let Some(region) = &self.region {
|
||||
var.assign(&mut region.borrow_mut(), self.linear_coord, values)
|
||||
var.assign(
|
||||
&mut region.borrow_mut(),
|
||||
self.linear_coord,
|
||||
values,
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
if !values.is_instance() {
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
@@ -560,14 +593,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
var: &VarTensor,
|
||||
values: &ValTensor<F>,
|
||||
) -> Result<ValTensor<F>, Error> {
|
||||
self.total_constants += values.num_constants();
|
||||
if let Some(region) = &self.region {
|
||||
var.assign(
|
||||
&mut region.borrow_mut(),
|
||||
self.combined_dynamic_shuffle_coord(),
|
||||
values,
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
if !values.is_instance() {
|
||||
let values_map = values.create_constants_map_iterator();
|
||||
self.assigned_constants.extend(values_map);
|
||||
}
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
@@ -594,13 +631,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.linear_coord,
|
||||
values,
|
||||
ommissions,
|
||||
&mut self.assigned_constants,
|
||||
)
|
||||
} else {
|
||||
self.total_constants += values.num_constants();
|
||||
let inner_tensor = values.get_inner_tensor().unwrap();
|
||||
let mut values_map = values.create_constants_map();
|
||||
|
||||
for o in ommissions {
|
||||
self.total_constants -= inner_tensor.get_flat_index(**o).is_constant() as usize;
|
||||
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
|
||||
values_map.remove(&value);
|
||||
}
|
||||
}
|
||||
|
||||
self.assigned_constants.extend(values_map);
|
||||
|
||||
Ok(values.clone())
|
||||
}
|
||||
}
|
||||
@@ -615,24 +659,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
) -> Result<(ValTensor<F>, usize), Error> {
|
||||
if let Some(region) = &self.region {
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let (res, len, total_assigned_constants) = var.assign_with_duplication(
|
||||
let (res, len) = var.assign_with_duplication(
|
||||
&mut region.borrow_mut(),
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
check_mode,
|
||||
single_inner_col,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
self.total_constants += total_assigned_constants;
|
||||
Ok((res, len))
|
||||
} else {
|
||||
let (_, len, total_assigned_constants) = var.dummy_assign_with_duplication(
|
||||
let (_, len) = var.dummy_assign_with_duplication(
|
||||
self.row,
|
||||
self.linear_coord,
|
||||
values,
|
||||
single_inner_col,
|
||||
&mut self.assigned_constants,
|
||||
)?;
|
||||
self.total_constants += total_assigned_constants;
|
||||
Ok((values.clone(), len))
|
||||
}
|
||||
}
|
||||
@@ -699,9 +743,4 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// increment constants
|
||||
pub fn increment_constants(&mut self, n: usize) {
|
||||
self.total_constants += n
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ use halo2_proofs::{
|
||||
circuit::{Layouter, Value},
|
||||
plonk::{ConstraintSystem, Expression, TableColumn},
|
||||
};
|
||||
use log::warn;
|
||||
use log::{debug, warn};
|
||||
use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
|
||||
|
||||
use crate::{
|
||||
@@ -98,7 +98,7 @@ pub struct Table<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
@@ -138,7 +138,7 @@ pub fn num_cols_required(range_len: i128, col_size: usize) -> usize {
|
||||
(range_len / (col_size as i128)) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -152,7 +152,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Table<F> {
|
||||
// number of cols needed to store the range
|
||||
let num_cols = num_cols_required((range.1 - range.0).abs(), col_size);
|
||||
|
||||
log::debug!("table range: {:?}", range);
|
||||
debug!("table range: {:?}", range);
|
||||
|
||||
let table_inputs = preexisting_inputs.unwrap_or_else(|| {
|
||||
let mut cols = vec![];
|
||||
@@ -275,7 +275,7 @@ pub struct RangeCheck<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> F {
|
||||
let chunk = chunk as i128;
|
||||
@@ -303,7 +303,7 @@ impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
|
||||
log::debug!("range check range: {:?}", range);
|
||||
|
||||
@@ -245,7 +245,13 @@ mod matmul_col_overflow {
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod matmul_col_ultra_overflow_double_col {
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::KZGCommitmentScheme,
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy,
|
||||
};
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -324,46 +330,46 @@ mod matmul_col_ultra_overflow_double_col {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
MatmulCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
let prover = crate::pfsys::create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod matmul_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::KZGCommitmentScheme,
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy,
|
||||
};
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -441,39 +447,33 @@ mod matmul_col_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
MatmulCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
let prover = crate::pfsys::create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1145,7 +1145,15 @@ mod conv {
|
||||
#[cfg(test)]
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use halo2_proofs::poly::{
|
||||
kzg::strategy::SingleStrategy,
|
||||
kzg::{
|
||||
commitment::KZGCommitmentScheme,
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
},
|
||||
};
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1243,39 +1251,33 @@ mod conv_col_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ConvCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
let prover = crate::pfsys::create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1283,7 +1285,13 @@ mod conv_col_ultra_overflow {
|
||||
// not wasm 32 unknown
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
mod conv_relu_col_ultra_overflow {
|
||||
use halo2_proofs::poly::commitment::{Params, ParamsProver};
|
||||
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::KZGCommitmentScheme,
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy,
|
||||
};
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
use super::*;
|
||||
|
||||
@@ -1396,39 +1404,33 @@ mod conv_relu_col_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ConvCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
let prover = crate::pfsys::create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
// use safe mode to verify that the proof is correct
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1909,6 +1911,8 @@ mod add_with_overflow {
|
||||
|
||||
#[cfg(test)]
|
||||
mod add_with_overflow_and_poseidon {
|
||||
use std::collections::HashMap;
|
||||
|
||||
use halo2curves::bn256::Fr;
|
||||
|
||||
use crate::circuit::modules::{
|
||||
@@ -1967,8 +1971,10 @@ mod add_with_overflow_and_poseidon {
|
||||
let poseidon_chip: PoseidonChip<PoseidonSpec, WIDTH, RATE, WIDTH> =
|
||||
PoseidonChip::new(config.poseidon.clone());
|
||||
|
||||
let assigned_inputs_a = poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0)?;
|
||||
let assigned_inputs_b = poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1)?;
|
||||
let assigned_inputs_a =
|
||||
poseidon_chip.layout(&mut layouter, &self.inputs[0..1], 0, &mut HashMap::new())?;
|
||||
let assigned_inputs_b =
|
||||
poseidon_chip.layout(&mut layouter, &self.inputs[1..2], 1, &mut HashMap::new())?;
|
||||
|
||||
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
|
||||
|
||||
@@ -2417,8 +2423,13 @@ mod lookup_ultra_overflow {
|
||||
use halo2_proofs::{
|
||||
circuit::{Layouter, SimpleFloorPlanner, Value},
|
||||
plonk::{Circuit, ConstraintSystem, Error},
|
||||
poly::commitment::{Params, ParamsProver},
|
||||
poly::kzg::{
|
||||
commitment::KZGCommitmentScheme,
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy,
|
||||
},
|
||||
};
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
@@ -2497,38 +2508,32 @@ mod lookup_ultra_overflow {
|
||||
|
||||
let pk = crate::pfsys::create_keys::<
|
||||
halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme<halo2curves::bn256::Bn256>,
|
||||
F,
|
||||
ReLUCircuit<F>,
|
||||
>(&circuit, ¶ms, true)
|
||||
.unwrap();
|
||||
|
||||
let prover = crate::pfsys::create_proof_circuit_kzg(
|
||||
let prover = crate::pfsys::create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
SingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit.clone(),
|
||||
vec![],
|
||||
¶ms,
|
||||
None,
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(¶ms),
|
||||
// use safe mode to verify that the proof is correct
|
||||
CheckMode::SAFE,
|
||||
crate::Commitments::KZG,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
None,
|
||||
None,
|
||||
);
|
||||
|
||||
assert!(prover.is_ok());
|
||||
|
||||
let proof = prover.unwrap();
|
||||
|
||||
let strategy =
|
||||
halo2_proofs::poly::kzg::strategy::SingleStrategy::new(params.verifier_params());
|
||||
let vk = pk.get_vk();
|
||||
let result = crate::pfsys::verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
proof,
|
||||
vk,
|
||||
strategy,
|
||||
params.n(),
|
||||
);
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ use std::path::PathBuf;
|
||||
use std::{error::Error, str::FromStr};
|
||||
use tosubcommand::{ToFlags, ToSubcommand};
|
||||
|
||||
use crate::{pfsys::ProofType, RunArgs};
|
||||
use crate::{pfsys::ProofType, Commitments, RunArgs};
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -90,6 +90,8 @@ pub const DEFAULT_SCALE_REBASE_MULTIPLIERS: &str = "1,2,10";
|
||||
pub const DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION: &str = "false";
|
||||
/// Default only check for range check rebase
|
||||
pub const DEFAULT_ONLY_RANGE_CHECK_REBASE: &str = "false";
|
||||
/// Default commitment
|
||||
pub const DEFAULT_COMMITMENT: &str = "kzg";
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
/// Converts TranscriptType into a PyObject (Required for TranscriptType to be compatible with Python)
|
||||
@@ -294,21 +296,6 @@ pub enum Commands {
|
||||
args: RunArgs,
|
||||
},
|
||||
|
||||
#[cfg(feature = "render")]
|
||||
/// Renders the model circuit to a .png file. For an overview of how to interpret these plots, see https://zcash.github.io/halo2/user/dev-tools.html
|
||||
#[command(arg_required_else_help = true)]
|
||||
RenderCircuit {
|
||||
/// The path to the .onnx model file
|
||||
#[arg(short = 'M', long)]
|
||||
model: PathBuf,
|
||||
/// Path to save the .png circuit render
|
||||
#[arg(short = 'O', long)]
|
||||
output: PathBuf,
|
||||
/// proving arguments
|
||||
#[clap(flatten)]
|
||||
args: RunArgs,
|
||||
},
|
||||
|
||||
/// Generates the witness from an input file.
|
||||
GenWitness {
|
||||
/// The path to the .json data file
|
||||
@@ -387,6 +374,9 @@ pub enum Commands {
|
||||
/// number of logrows to use for srs
|
||||
#[arg(long)]
|
||||
logrows: usize,
|
||||
/// commitment used
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Commitments,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -402,6 +392,9 @@ pub enum Commands {
|
||||
/// Number of logrows to use for srs. Overrides settings_path if specified.
|
||||
#[arg(long, default_value = None)]
|
||||
logrows: Option<u32>,
|
||||
/// Commitment used
|
||||
#[arg(long, default_value = None)]
|
||||
commitment: Option<Commitments>,
|
||||
},
|
||||
/// Loads model and input and runs mock prover (for testing)
|
||||
Mock {
|
||||
@@ -449,6 +442,9 @@ pub enum Commands {
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
|
||||
disable_selector_compression: bool,
|
||||
/// commitment used
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Commitments,
|
||||
},
|
||||
/// Aggregates proofs :)
|
||||
Aggregate {
|
||||
@@ -481,6 +477,9 @@ pub enum Commands {
|
||||
/// whether the accumulated proofs are segments of a larger circuit
|
||||
#[arg(long, default_value = DEFAULT_SPLIT)]
|
||||
split_proofs: bool,
|
||||
/// commitment used
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Commitments,
|
||||
},
|
||||
/// Compiles a circuit from onnx to a simplified graph (einsum + other ops) and parameters as sets of field elements
|
||||
CompileCircuit {
|
||||
@@ -515,31 +514,6 @@ pub enum Commands {
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
|
||||
disable_selector_compression: bool,
|
||||
},
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Fuzzes the proof pipeline with random inputs, random parameters, and random keys
|
||||
Fuzz {
|
||||
/// The path to the .json witness file (generated using the gen-witness command)
|
||||
#[arg(short = 'W', long, default_value = DEFAULT_WITNESS)]
|
||||
witness: PathBuf,
|
||||
/// The path to the compiled model file (generated using the compile-circuit command)
|
||||
#[arg(short = 'M', long)]
|
||||
compiled_circuit: PathBuf,
|
||||
#[arg(
|
||||
long,
|
||||
require_equals = true,
|
||||
num_args = 0..=1,
|
||||
default_value_t = TranscriptType::default(),
|
||||
value_enum
|
||||
)]
|
||||
transcript: TranscriptType,
|
||||
/// number of fuzz iterations
|
||||
#[arg(long, default_value = DEFAULT_FUZZ_RUNS)]
|
||||
num_runs: usize,
|
||||
/// compress selectors
|
||||
#[arg(long, default_value = DEFAULT_DISABLE_SELECTOR_COMPRESSION)]
|
||||
disable_selector_compression: bool,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys a test contact that the data attester reads from and creates a data attestation formatted input.json file that contains call data information
|
||||
#[command(arg_required_else_help = true)]
|
||||
@@ -741,12 +715,18 @@ pub enum Commands {
|
||||
/// The path to the verification key file (generated using the setup-aggregate command)
|
||||
#[arg(long, default_value = DEFAULT_VK_AGGREGATED)]
|
||||
vk_path: PathBuf,
|
||||
/// reduced srs
|
||||
#[arg(long, default_value = DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION)]
|
||||
reduced_srs: bool,
|
||||
/// The path to SRS, if None will use $EZKL_REPO_PATH/srs/kzg{logrows}.srs
|
||||
#[arg(long)]
|
||||
srs_path: Option<PathBuf>,
|
||||
/// logrows used for aggregation circuit
|
||||
#[arg(long, default_value = DEFAULT_AGGREGATED_LOGROWS)]
|
||||
logrows: u32,
|
||||
/// commitment
|
||||
#[arg(long, default_value = DEFAULT_COMMITMENT)]
|
||||
commitment: Commitments,
|
||||
},
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
/// Deploys an evm verifier that is generated by ezkl
|
||||
|
||||
1271
src/execute.rs
1271
src/execute.rs
File diff suppressed because it is too large
Load Diff
103
src/graph/mod.rs
103
src/graph/mod.rs
@@ -15,7 +15,7 @@ use colored_json::ToColoredJson;
|
||||
#[cfg(unix)]
|
||||
use gag::Gag;
|
||||
use halo2_proofs::plonk::VerifyingKey;
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
use halo2_proofs::poly::commitment::CommitmentScheme;
|
||||
pub use input::DataSource;
|
||||
use itertools::Itertools;
|
||||
use tosubcommand::ToFlags;
|
||||
@@ -26,6 +26,7 @@ use self::input::{FileSource, GraphData};
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
@@ -37,7 +38,7 @@ use halo2_proofs::{
|
||||
circuit::Layouter,
|
||||
plonk::{Circuit, ConstraintSystem, Error as PlonkError},
|
||||
};
|
||||
use halo2curves::bn256::{self, Bn256, Fr as Fp, G1Affine};
|
||||
use halo2curves::bn256::{self, Fr as Fp, G1Affine};
|
||||
use halo2curves::ff::PrimeField;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use lazy_static::lazy_static;
|
||||
@@ -126,7 +127,7 @@ pub enum GraphError {
|
||||
#[error("failed to rescale inputs for {0}")]
|
||||
RescalingError(String),
|
||||
/// Error when attempting to load a model
|
||||
#[error("failed to load model")]
|
||||
#[error("failed to load")]
|
||||
ModelLoad,
|
||||
/// Packing exponent is too large
|
||||
#[error("largest packing exponent exceeds max. try reducing the scale")]
|
||||
@@ -155,7 +156,7 @@ use std::cell::RefCell;
|
||||
thread_local!(
|
||||
/// This is a global variable that holds the settings for the graph
|
||||
/// This is used to pass settings to the layouter and other parts of the circuit without needing to heavily modify the Halo2 API in a new fork
|
||||
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = RefCell::new(None)
|
||||
pub static GLOBAL_SETTINGS: RefCell<Option<GraphSettings>> = const { RefCell::new(None) }
|
||||
);
|
||||
|
||||
/// Result from a forward pass
|
||||
@@ -284,20 +285,20 @@ impl GraphWitness {
|
||||
}
|
||||
|
||||
///
|
||||
pub fn get_kzg_commitments(&self) -> Vec<G1Affine> {
|
||||
pub fn get_polycommitments(&self) -> Vec<G1Affine> {
|
||||
let mut commitments = vec![];
|
||||
if let Some(processed_inputs) = &self.processed_inputs {
|
||||
if let Some(commits) = &processed_inputs.kzg_commit {
|
||||
if let Some(commits) = &processed_inputs.polycommit {
|
||||
commitments.extend(commits.iter().flatten());
|
||||
}
|
||||
}
|
||||
if let Some(processed_params) = &self.processed_params {
|
||||
if let Some(commits) = &processed_params.kzg_commit {
|
||||
if let Some(commits) = &processed_params.polycommit {
|
||||
commitments.extend(commits.iter().flatten());
|
||||
}
|
||||
}
|
||||
if let Some(processed_outputs) = &self.processed_outputs {
|
||||
if let Some(commits) = &processed_outputs.kzg_commit {
|
||||
if let Some(commits) = &processed_outputs.polycommit {
|
||||
commitments.extend(commits.iter().flatten());
|
||||
}
|
||||
}
|
||||
@@ -318,7 +319,7 @@ impl GraphWitness {
|
||||
/// Load the model input from a file
|
||||
pub fn from_path(path: std::path::PathBuf) -> Result<Self, Box<dyn std::error::Error>> {
|
||||
let file = std::fs::File::open(path.clone())
|
||||
.map_err(|_| format!("failed to load model at {}", path.display()))?;
|
||||
.map_err(|_| format!("failed to load {}", path.display()))?;
|
||||
|
||||
let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
|
||||
serde_json::from_reader(reader).map_err(|e| e.into())
|
||||
@@ -387,8 +388,8 @@ impl ToPyObject for GraphWitness {
|
||||
if let Some(processed_inputs_poseidon_hash) = &processed_inputs.poseidon_hash {
|
||||
insert_poseidon_hash_pydict(dict_inputs, processed_inputs_poseidon_hash).unwrap();
|
||||
}
|
||||
if let Some(processed_inputs_kzg_commit) = &processed_inputs.kzg_commit {
|
||||
insert_kzg_commit_pydict(dict_inputs, processed_inputs_kzg_commit).unwrap();
|
||||
if let Some(processed_inputs_polycommit) = &processed_inputs.polycommit {
|
||||
insert_polycommit_pydict(dict_inputs, processed_inputs_polycommit).unwrap();
|
||||
}
|
||||
|
||||
dict.set_item("processed_inputs", dict_inputs).unwrap();
|
||||
@@ -398,8 +399,8 @@ impl ToPyObject for GraphWitness {
|
||||
if let Some(processed_params_poseidon_hash) = &processed_params.poseidon_hash {
|
||||
insert_poseidon_hash_pydict(dict_params, processed_params_poseidon_hash).unwrap();
|
||||
}
|
||||
if let Some(processed_params_kzg_commit) = &processed_params.kzg_commit {
|
||||
insert_kzg_commit_pydict(dict_inputs, processed_params_kzg_commit).unwrap();
|
||||
if let Some(processed_params_polycommit) = &processed_params.polycommit {
|
||||
insert_polycommit_pydict(dict_inputs, processed_params_polycommit).unwrap();
|
||||
}
|
||||
|
||||
dict.set_item("processed_params", dict_params).unwrap();
|
||||
@@ -409,8 +410,8 @@ impl ToPyObject for GraphWitness {
|
||||
if let Some(processed_outputs_poseidon_hash) = &processed_outputs.poseidon_hash {
|
||||
insert_poseidon_hash_pydict(dict_outputs, processed_outputs_poseidon_hash).unwrap();
|
||||
}
|
||||
if let Some(processed_outputs_kzg_commit) = &processed_outputs.kzg_commit {
|
||||
insert_kzg_commit_pydict(dict_inputs, processed_outputs_kzg_commit).unwrap();
|
||||
if let Some(processed_outputs_polycommit) = &processed_outputs.polycommit {
|
||||
insert_polycommit_pydict(dict_inputs, processed_outputs_polycommit).unwrap();
|
||||
}
|
||||
|
||||
dict.set_item("processed_outputs", dict_outputs).unwrap();
|
||||
@@ -429,13 +430,13 @@ fn insert_poseidon_hash_pydict(pydict: &PyDict, poseidon_hash: &Vec<Fp>) -> Resu
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
fn insert_kzg_commit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
|
||||
fn insert_polycommit_pydict(pydict: &PyDict, commits: &Vec<Vec<G1Affine>>) -> Result<(), PyErr> {
|
||||
use crate::python::PyG1Affine;
|
||||
let poseidon_hash: Vec<Vec<PyG1Affine>> = commits
|
||||
.iter()
|
||||
.map(|c| c.iter().map(|x| PyG1Affine::from(*x)).collect())
|
||||
.collect();
|
||||
pydict.set_item("kzg_commit", poseidon_hash)?;
|
||||
pydict.set_item("polycommit", poseidon_hash)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -503,7 +504,9 @@ impl GraphSettings {
|
||||
}
|
||||
|
||||
fn constants_logrows(&self) -> u32 {
|
||||
(self.total_const_size as f64).log2().ceil() as u32
|
||||
(self.total_const_size as f64 / self.run_args.num_inner_cols as f64)
|
||||
.log2()
|
||||
.ceil() as u32
|
||||
}
|
||||
|
||||
/// calculate the total number of instances
|
||||
@@ -590,7 +593,7 @@ impl GraphSettings {
|
||||
|| self.run_args.param_visibility.is_hashed()
|
||||
}
|
||||
|
||||
/// requires dynamic lookup
|
||||
/// requires dynamic lookup
|
||||
pub fn requires_dynamic_lookup(&self) -> bool {
|
||||
self.num_dynamic_lookups > 0
|
||||
}
|
||||
@@ -601,10 +604,10 @@ impl GraphSettings {
|
||||
}
|
||||
|
||||
/// any kzg visibility
|
||||
pub fn module_requires_kzg(&self) -> bool {
|
||||
self.run_args.input_visibility.is_kzgcommit()
|
||||
|| self.run_args.output_visibility.is_kzgcommit()
|
||||
|| self.run_args.param_visibility.is_kzgcommit()
|
||||
pub fn module_requires_polycommit(&self) -> bool {
|
||||
self.run_args.input_visibility.is_polycommit()
|
||||
|| self.run_args.output_visibility.is_polycommit()
|
||||
|| self.run_args.param_visibility.is_polycommit()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1049,16 +1052,10 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i128) -> Range {
|
||||
let mut margin = (
|
||||
(
|
||||
lookup_safety_margin * min_max_lookup.0,
|
||||
lookup_safety_margin * min_max_lookup.1,
|
||||
);
|
||||
if lookup_safety_margin == 1 {
|
||||
margin.0 += 4;
|
||||
margin.1 += 4;
|
||||
}
|
||||
|
||||
margin
|
||||
)
|
||||
}
|
||||
|
||||
fn calc_num_cols(range_len: i128, max_logrows: u32) -> usize {
|
||||
@@ -1171,17 +1168,6 @@ impl GraphCircuit {
|
||||
.settings()
|
||||
.clone();
|
||||
|
||||
// recalculate the logrows if there has been overflow on the constants
|
||||
settings_mut.run_args.logrows = std::cmp::max(
|
||||
settings_mut.run_args.logrows,
|
||||
settings_mut.constants_logrows(),
|
||||
);
|
||||
// recalculate the logrows if there has been overflow for the model constraints
|
||||
settings_mut.run_args.logrows = std::cmp::max(
|
||||
settings_mut.run_args.logrows,
|
||||
settings_mut.model_constraint_logrows(),
|
||||
);
|
||||
|
||||
debug!(
|
||||
"setting lookup_range to: {:?}, setting logrows to: {}",
|
||||
self.settings().run_args.lookup_range,
|
||||
@@ -1248,12 +1234,12 @@ impl GraphCircuit {
|
||||
}
|
||||
|
||||
/// Runs the forward pass of the model / graph of computations and any associated hashing.
|
||||
pub fn forward(
|
||||
pub fn forward<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
|
||||
&self,
|
||||
inputs: &mut [Tensor<Fp>],
|
||||
vk: Option<&VerifyingKey<G1Affine>>,
|
||||
srs: Option<&ParamsKZG<Bn256>>,
|
||||
throw_range_check_error: bool,
|
||||
srs: Option<&Scheme::ParamsProver>,
|
||||
witness_gen: bool,
|
||||
) -> Result<GraphWitness, Box<dyn std::error::Error>> {
|
||||
let original_inputs = inputs.to_vec();
|
||||
|
||||
@@ -1269,7 +1255,8 @@ impl GraphCircuit {
|
||||
for outlet in &module_outlets {
|
||||
module_inputs.push(inputs[*outlet].clone());
|
||||
}
|
||||
let res = GraphModules::forward(&module_inputs, &visibility.input, vk, srs)?;
|
||||
let res =
|
||||
GraphModules::forward::<Scheme>(&module_inputs, &visibility.input, vk, srs)?;
|
||||
processed_inputs = Some(res.clone());
|
||||
let module_results = res.get_result(visibility.input.clone());
|
||||
|
||||
@@ -1277,7 +1264,12 @@ impl GraphCircuit {
|
||||
inputs[*outlet] = Tensor::from(module_results[i].clone().into_iter());
|
||||
}
|
||||
} else {
|
||||
processed_inputs = Some(GraphModules::forward(inputs, &visibility.input, vk, srs)?);
|
||||
processed_inputs = Some(GraphModules::forward::<Scheme>(
|
||||
inputs,
|
||||
&visibility.input,
|
||||
vk,
|
||||
srs,
|
||||
)?);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1285,7 +1277,7 @@ impl GraphCircuit {
|
||||
let params = self.model().get_all_params();
|
||||
if !params.is_empty() {
|
||||
let flattened_params = Tensor::new(Some(¶ms), &[params.len()])?.combine()?;
|
||||
processed_params = Some(GraphModules::forward(
|
||||
processed_params = Some(GraphModules::forward::<Scheme>(
|
||||
&[flattened_params],
|
||||
&visibility.params,
|
||||
vk,
|
||||
@@ -1296,7 +1288,7 @@ impl GraphCircuit {
|
||||
|
||||
let mut model_results =
|
||||
self.model()
|
||||
.forward(inputs, &self.settings().run_args, throw_range_check_error)?;
|
||||
.forward(inputs, &self.settings().run_args, witness_gen)?;
|
||||
|
||||
if visibility.output.requires_processing() {
|
||||
let module_outlets = visibility.output.overwrites_inputs();
|
||||
@@ -1305,7 +1297,8 @@ impl GraphCircuit {
|
||||
for outlet in &module_outlets {
|
||||
module_inputs.push(model_results.outputs[*outlet].clone());
|
||||
}
|
||||
let res = GraphModules::forward(&module_inputs, &visibility.output, vk, srs)?;
|
||||
let res =
|
||||
GraphModules::forward::<Scheme>(&module_inputs, &visibility.output, vk, srs)?;
|
||||
processed_outputs = Some(res.clone());
|
||||
let module_results = res.get_result(visibility.output.clone());
|
||||
|
||||
@@ -1314,7 +1307,7 @@ impl GraphCircuit {
|
||||
Tensor::from(module_results[i].clone().into_iter());
|
||||
}
|
||||
} else {
|
||||
processed_outputs = Some(GraphModules::forward(
|
||||
processed_outputs = Some(GraphModules::forward::<Scheme>(
|
||||
&model_results.outputs,
|
||||
&visibility.output,
|
||||
vk,
|
||||
@@ -1610,6 +1603,8 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
let output_vis = &self.settings().run_args.output_visibility;
|
||||
let mut graph_modules = GraphModules::new();
|
||||
|
||||
let mut constants = ConstantsMap::new();
|
||||
|
||||
let mut config = config.clone();
|
||||
|
||||
let mut inputs = self
|
||||
@@ -1655,6 +1650,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
&mut input_outlets,
|
||||
input_visibility,
|
||||
&mut instance_offset,
|
||||
&mut constants,
|
||||
)?;
|
||||
// replace inputs with the outlets
|
||||
for (i, outlet) in outlets.iter().enumerate() {
|
||||
@@ -1667,6 +1663,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
&mut inputs,
|
||||
input_visibility,
|
||||
&mut instance_offset,
|
||||
&mut constants,
|
||||
)?;
|
||||
}
|
||||
|
||||
@@ -1703,6 +1700,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
&mut flattened_params,
|
||||
param_visibility,
|
||||
&mut instance_offset,
|
||||
&mut constants,
|
||||
)?;
|
||||
|
||||
let shapes = self.model().const_shapes();
|
||||
@@ -1731,6 +1729,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
&inputs,
|
||||
&mut vars,
|
||||
&outputs,
|
||||
&mut constants,
|
||||
)
|
||||
.map_err(|e| {
|
||||
log::error!("{}", e);
|
||||
@@ -1755,6 +1754,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
&mut output_outlets,
|
||||
&self.settings().run_args.output_visibility,
|
||||
&mut instance_offset,
|
||||
&mut constants,
|
||||
)?;
|
||||
|
||||
// replace outputs with the outlets
|
||||
@@ -1768,6 +1768,7 @@ impl Circuit<Fp> for GraphCircuit {
|
||||
&mut outputs,
|
||||
&self.settings().run_args.output_visibility,
|
||||
&mut instance_offset,
|
||||
&mut constants,
|
||||
)?;
|
||||
}
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ use super::vars::*;
|
||||
use super::GraphError;
|
||||
use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::table::Range;
|
||||
use crate::circuit::Input;
|
||||
@@ -404,7 +405,7 @@ impl ParsedNodes {
|
||||
.get(input)
|
||||
.ok_or(GraphError::MissingNode(*input))?;
|
||||
let input_dims = node.out_dims();
|
||||
let input_dim = input_dims.get(0).ok_or(GraphError::MissingNode(*input))?;
|
||||
let input_dim = input_dims.first().ok_or(GraphError::MissingNode(*input))?;
|
||||
inputs.push(input_dim.clone());
|
||||
}
|
||||
|
||||
@@ -514,21 +515,24 @@ impl Model {
|
||||
instance_shapes.len().to_string().blue(),
|
||||
"instances".blue()
|
||||
);
|
||||
// this is the total number of variables we will need to allocate
|
||||
// for the circuit
|
||||
let default_value = if !self.visibility.input.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let inputs: Vec<ValTensor<Fp>> = self
|
||||
.graph
|
||||
.input_shapes()?
|
||||
.iter()
|
||||
.map(|shape| {
|
||||
let mut t: ValTensor<Fp> =
|
||||
vec![default_value.clone(); shape.iter().product()].into();
|
||||
let len = shape.iter().product();
|
||||
let mut t: ValTensor<Fp> = (0..len)
|
||||
.map(|_| {
|
||||
if !self.visibility.input.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::random(&mut rand::thread_rng()))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
|
||||
t.reshape(shape)?;
|
||||
Ok(t)
|
||||
})
|
||||
@@ -577,13 +581,13 @@ impl Model {
|
||||
&self,
|
||||
model_inputs: &[Tensor<Fp>],
|
||||
run_args: &RunArgs,
|
||||
throw_range_check_error: bool,
|
||||
witness_gen: bool,
|
||||
) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, throw_range_check_error)?;
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen)?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
@@ -970,7 +974,7 @@ impl Model {
|
||||
|
||||
let (model, _) = Model::load_onnx_using_tract(
|
||||
&mut std::fs::File::open(model_path)
|
||||
.map_err(|_| format!("failed to load model at {}", model_path.display()))?,
|
||||
.map_err(|_| format!("failed to load {}", model_path.display()))?,
|
||||
run_args,
|
||||
)?;
|
||||
|
||||
@@ -1006,7 +1010,7 @@ impl Model {
|
||||
) -> Result<Self, Box<dyn Error>> {
|
||||
Model::new(
|
||||
&mut std::fs::File::open(model)
|
||||
.map_err(|_| format!("failed to load model at {}", model.display()))?,
|
||||
.map_err(|_| format!("failed to load {}", model.display()))?,
|
||||
run_args,
|
||||
)
|
||||
}
|
||||
@@ -1071,6 +1075,8 @@ impl Model {
|
||||
/// * `layouter` - Halo2 Layouter.
|
||||
/// * `inputs` - The values to feed into the circuit.
|
||||
/// * `vars` - The variables for the circuit.
|
||||
/// * `witnessed_outputs` - The values to compare against.
|
||||
/// * `constants` - The constants for the circuit.
|
||||
pub fn layout(
|
||||
&self,
|
||||
mut config: ModelConfig,
|
||||
@@ -1079,6 +1085,7 @@ impl Model {
|
||||
inputs: &[ValTensor<Fp>],
|
||||
vars: &mut ModelVars<Fp>,
|
||||
witnessed_outputs: &[ValTensor<Fp>],
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<Vec<ValTensor<Fp>>, Box<dyn Error>> {
|
||||
info!("model layout...");
|
||||
|
||||
@@ -1104,14 +1111,12 @@ impl Model {
|
||||
config.base.layout_tables(layouter)?;
|
||||
config.base.layout_range_checks(layouter)?;
|
||||
|
||||
let mut num_rows = 0;
|
||||
let mut linear_coord = 0;
|
||||
let mut total_const_size = 0;
|
||||
let original_constants = constants.clone();
|
||||
|
||||
let outputs = layouter.assign_region(
|
||||
|| "model",
|
||||
|region| {
|
||||
let mut thread_safe_region = RegionCtx::new(region, 0, run_args.num_inner_cols);
|
||||
let mut thread_safe_region = RegionCtx::new_with_constants(region, 0, run_args.num_inner_cols, original_constants.clone());
|
||||
// we need to do this as this loop is called multiple times
|
||||
vars.set_instance_idx(instance_idx);
|
||||
|
||||
@@ -1157,29 +1162,17 @@ impl Model {
|
||||
error!("{}", e);
|
||||
halo2_proofs::plonk::Error::Synthesis
|
||||
})?;
|
||||
} else if !run_args.output_visibility.is_private() {
|
||||
for output in &outputs {
|
||||
thread_safe_region.increment_total_constants(output.num_constants());
|
||||
}
|
||||
}
|
||||
num_rows = thread_safe_region.row();
|
||||
linear_coord = thread_safe_region.linear_coord();
|
||||
total_const_size = thread_safe_region.total_constants();
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
thread_safe_region.debug_report();
|
||||
|
||||
*constants = thread_safe_region.assigned_constants().clone();
|
||||
|
||||
Ok(outputs)
|
||||
},
|
||||
)?;
|
||||
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
debug!(
|
||||
"{} {} {} (coord={}, constants={})",
|
||||
"model uses".blue(),
|
||||
num_rows.to_string().blue(),
|
||||
"rows".blue(),
|
||||
linear_coord.to_string().yellow(),
|
||||
total_const_size.to_string().red()
|
||||
);
|
||||
)?;
|
||||
|
||||
let duration = start_time.elapsed();
|
||||
trace!("model layout took: {:?}", duration);
|
||||
@@ -1213,16 +1206,10 @@ impl Model {
|
||||
vec![results.get(idx).ok_or(GraphError::MissingResults)?[0].clone()]
|
||||
};
|
||||
|
||||
debug!(
|
||||
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
|
||||
idx,
|
||||
node.as_str(),
|
||||
region.row(),
|
||||
region.linear_coord(),
|
||||
region.total_constants(),
|
||||
region.max_lookup_inputs(),
|
||||
region.min_lookup_inputs()
|
||||
);
|
||||
debug!("laying out {}: {}", idx, node.as_str(),);
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
region.debug_report();
|
||||
debug!("dims: {:?}", node.out_dims());
|
||||
debug!(
|
||||
"input_dims {:?}",
|
||||
@@ -1380,7 +1367,7 @@ impl Model {
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
throw_range_check_error: bool,
|
||||
witness_gen: bool,
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
@@ -1399,29 +1386,31 @@ impl Model {
|
||||
vars: ModelVars::new_dummy(),
|
||||
};
|
||||
|
||||
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, throw_range_check_error);
|
||||
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen);
|
||||
|
||||
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
|
||||
|
||||
if self.visibility.output.is_public() || self.visibility.output.is_fixed() {
|
||||
let default_value = if !self.visibility.output.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let output_scales = self.graph.get_output_scales()?;
|
||||
let res = outputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, output)| {
|
||||
let mut comparator: ValTensor<Fp> = (0..output.len())
|
||||
.map(|_| {
|
||||
if !self.visibility.output.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::random(&mut rand::thread_rng()))
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.into();
|
||||
comparator.reshape(output.dims())?;
|
||||
|
||||
let mut tolerance = run_args.tolerance;
|
||||
tolerance.scale = scale_to_multiplier(output_scales[i]).into();
|
||||
|
||||
let mut comparator: ValTensor<Fp> =
|
||||
vec![default_value.clone(); output.dims().iter().product::<usize>()].into();
|
||||
comparator.reshape(output.dims())?;
|
||||
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[output.clone(), comparator],
|
||||
@@ -1432,7 +1421,7 @@ impl Model {
|
||||
res?;
|
||||
} else if !self.visibility.output.is_private() {
|
||||
for output in &outputs {
|
||||
region.increment_total_constants(output.num_constants());
|
||||
region.update_constants(output.create_constants_map());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1441,14 +1430,7 @@ impl Model {
|
||||
|
||||
// Then number of columns in the circuits
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
debug!(
|
||||
"{} {} {} (coord={}, constants={})",
|
||||
"model uses".blue(),
|
||||
region.row().to_string().blue(),
|
||||
"rows".blue(),
|
||||
region.linear_coord().to_string().yellow(),
|
||||
region.total_constants().to_string().red()
|
||||
);
|
||||
region.debug_report();
|
||||
|
||||
let outputs = outputs
|
||||
.iter()
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
use crate::circuit::modules::kzg::{KZGChip, KZGConfig};
|
||||
use crate::circuit::modules::polycommit::{PolyCommitChip, PolyCommitConfig};
|
||||
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use crate::circuit::modules::poseidon::{PoseidonChip, PoseidonConfig};
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use halo2_proofs::circuit::Layouter;
|
||||
use halo2_proofs::plonk::{Column, ConstraintSystem, Error, Instance, VerifyingKey};
|
||||
use halo2_proofs::poly::kzg::commitment::ParamsKZG;
|
||||
use halo2curves::bn256::{Bn256, Fr as Fp, G1Affine};
|
||||
use halo2_proofs::poly::commitment::CommitmentScheme;
|
||||
use halo2curves::bn256::{Fr as Fp, G1Affine};
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
@@ -14,9 +15,6 @@ use super::{VarVisibility, Visibility};
|
||||
|
||||
/// poseidon len to hash in tree
|
||||
pub const POSEIDON_LEN_GRAPH: usize = 32;
|
||||
|
||||
/// ElGamal number of instances
|
||||
pub const ELGAMAL_INSTANCES: usize = 4;
|
||||
/// Poseidon number of instancess
|
||||
pub const POSEIDON_INSTANCES: usize = 1;
|
||||
|
||||
@@ -29,8 +27,8 @@ pub type ModulePoseidonConfig = PoseidonConfig<POSEIDON_WIDTH, POSEIDON_RATE>;
|
||||
///
|
||||
#[derive(Clone, Debug, Default)]
|
||||
pub struct ModuleConfigs {
|
||||
/// KZG
|
||||
kzg: Vec<KZGConfig>,
|
||||
/// PolyCommit
|
||||
polycommit: Vec<PolyCommitConfig>,
|
||||
/// Poseidon
|
||||
poseidon: Option<ModulePoseidonConfig>,
|
||||
/// Instance
|
||||
@@ -46,8 +44,10 @@ impl ModuleConfigs {
|
||||
) -> Self {
|
||||
let mut config = Self::default();
|
||||
|
||||
for size in module_size.kzg {
|
||||
config.kzg.push(KZGChip::configure(cs, (logrows, size)));
|
||||
for size in module_size.polycommit {
|
||||
config
|
||||
.polycommit
|
||||
.push(PolyCommitChip::configure(cs, (logrows, size)));
|
||||
}
|
||||
|
||||
config
|
||||
@@ -94,8 +94,8 @@ impl ModuleConfigs {
|
||||
pub struct ModuleForwardResult {
|
||||
/// The inputs of the forward pass for poseidon
|
||||
pub poseidon_hash: Option<Vec<Fp>>,
|
||||
/// The outputs of the forward pass for KZG
|
||||
pub kzg_commit: Option<Vec<Vec<G1Affine>>>,
|
||||
/// The outputs of the forward pass for PolyCommit
|
||||
pub polycommit: Option<Vec<Vec<G1Affine>>>,
|
||||
}
|
||||
|
||||
impl ModuleForwardResult {
|
||||
@@ -126,7 +126,7 @@ impl ModuleForwardResult {
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
|
||||
///
|
||||
pub struct ModuleSizes {
|
||||
kzg: Vec<usize>,
|
||||
polycommit: Vec<usize>,
|
||||
poseidon: (usize, Vec<usize>),
|
||||
}
|
||||
|
||||
@@ -134,7 +134,7 @@ impl ModuleSizes {
|
||||
/// Create new module sizes
|
||||
pub fn new() -> Self {
|
||||
ModuleSizes {
|
||||
kzg: vec![],
|
||||
polycommit: vec![],
|
||||
poseidon: (
|
||||
0,
|
||||
vec![0; crate::circuit::modules::poseidon::NUM_INSTANCE_COLUMNS],
|
||||
@@ -156,17 +156,17 @@ impl ModuleSizes {
|
||||
/// Graph modules that can process inputs, params and outputs beyond the basic operations
|
||||
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
|
||||
pub struct GraphModules {
|
||||
kzg_idx: usize,
|
||||
polycommit_idx: usize,
|
||||
}
|
||||
impl GraphModules {
|
||||
///
|
||||
pub fn new() -> GraphModules {
|
||||
GraphModules { kzg_idx: 0 }
|
||||
GraphModules { polycommit_idx: 0 }
|
||||
}
|
||||
|
||||
///
|
||||
pub fn reset_index(&mut self) {
|
||||
self.kzg_idx = 0;
|
||||
self.polycommit_idx = 0;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -179,9 +179,9 @@ impl GraphModules {
|
||||
for shape in shapes {
|
||||
let total_len = shape.iter().product::<usize>();
|
||||
if total_len > 0 {
|
||||
if visibility.is_kzgcommit() {
|
||||
// 1 constraint for each kzg commitment
|
||||
sizes.kzg.push(total_len);
|
||||
if visibility.is_polycommit() {
|
||||
// 1 constraint for each polycommit commitment
|
||||
sizes.polycommit.push(total_len);
|
||||
} else if visibility.is_hashed() {
|
||||
sizes.poseidon.0 += ModulePoseidon::num_rows(total_len);
|
||||
// 1 constraints for hash
|
||||
@@ -212,12 +212,13 @@ impl GraphModules {
|
||||
layouter: &mut impl Layouter<Fp>,
|
||||
x: &mut Vec<ValTensor<Fp>>,
|
||||
instance_offset: &mut usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<(), Error> {
|
||||
// reserve module 0 for ... modules
|
||||
// hash the input and replace the constrained cells in the input
|
||||
let cloned_x = (*x).clone();
|
||||
x[0] = module
|
||||
.layout(layouter, &cloned_x, instance_offset.to_owned())
|
||||
.layout(layouter, &cloned_x, instance_offset.to_owned(), constants)
|
||||
.unwrap();
|
||||
for inc in module.instance_increment_input().iter() {
|
||||
// increment the instance offset to make way for future module layouts
|
||||
@@ -235,23 +236,24 @@ impl GraphModules {
|
||||
values: &mut [ValTensor<Fp>],
|
||||
element_visibility: &Visibility,
|
||||
instance_offset: &mut usize,
|
||||
constants: &mut ConstantsMap<Fp>,
|
||||
) -> Result<(), Error> {
|
||||
if element_visibility.is_kzgcommit() && !values.is_empty() {
|
||||
if element_visibility.is_polycommit() && !values.is_empty() {
|
||||
// concat values and sk to get the inputs
|
||||
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
|
||||
|
||||
// layout the module
|
||||
inputs.iter_mut().for_each(|x| {
|
||||
// create the module
|
||||
let chip = KZGChip::new(configs.kzg[self.kzg_idx].clone());
|
||||
// reserve module 2 onwards for kzg modules
|
||||
let module_offset = 3 + self.kzg_idx;
|
||||
let chip = PolyCommitChip::new(configs.polycommit[self.polycommit_idx].clone());
|
||||
// reserve module 2 onwards for polycommit modules
|
||||
let module_offset = 3 + self.polycommit_idx;
|
||||
layouter
|
||||
.assign_region(|| format!("_enter_module_{}", module_offset), |_| Ok(()))
|
||||
.unwrap();
|
||||
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
|
||||
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
|
||||
// increment the current index
|
||||
self.kzg_idx += 1;
|
||||
self.polycommit_idx += 1;
|
||||
});
|
||||
|
||||
// replace the inputs with the outputs
|
||||
@@ -271,7 +273,7 @@ impl GraphModules {
|
||||
let mut inputs = values.iter_mut().map(|x| vec![x.clone()]).collect_vec();
|
||||
// layout the module
|
||||
inputs.iter_mut().for_each(|x| {
|
||||
Self::layout_module(&chip, layouter, x, instance_offset).unwrap();
|
||||
Self::layout_module(&chip, layouter, x, instance_offset, constants).unwrap();
|
||||
});
|
||||
// replace the inputs with the outputs
|
||||
values.iter_mut().enumerate().for_each(|(i, x)| {
|
||||
@@ -288,14 +290,14 @@ impl GraphModules {
|
||||
}
|
||||
|
||||
/// Run forward pass
|
||||
pub fn forward(
|
||||
inputs: &[Tensor<Fp>],
|
||||
pub fn forward<Scheme: CommitmentScheme<Scalar = Fp, Curve = G1Affine>>(
|
||||
inputs: &[Tensor<Scheme::Scalar>],
|
||||
element_visibility: &Visibility,
|
||||
vk: Option<&VerifyingKey<G1Affine>>,
|
||||
srs: Option<&ParamsKZG<Bn256>>,
|
||||
srs: Option<&Scheme::ParamsProver>,
|
||||
) -> Result<ModuleForwardResult, Box<dyn std::error::Error>> {
|
||||
let mut poseidon_hash = None;
|
||||
let mut kzg_commit = None;
|
||||
let mut polycommit = None;
|
||||
|
||||
if element_visibility.is_hashed() {
|
||||
let field_elements = inputs.iter().fold(vec![], |mut acc, x| {
|
||||
@@ -306,11 +308,11 @@ impl GraphModules {
|
||||
poseidon_hash = Some(field_elements);
|
||||
}
|
||||
|
||||
if element_visibility.is_kzgcommit() {
|
||||
if element_visibility.is_polycommit() {
|
||||
if let Some(vk) = vk {
|
||||
if let Some(srs) = srs {
|
||||
let commitments = inputs.iter().fold(vec![], |mut acc, x| {
|
||||
let res = KZGChip::commit(
|
||||
let res = PolyCommitChip::commit::<Scheme>(
|
||||
x.to_vec(),
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
@@ -319,20 +321,20 @@ impl GraphModules {
|
||||
acc.push(res);
|
||||
acc
|
||||
});
|
||||
kzg_commit = Some(commitments);
|
||||
polycommit = Some(commitments);
|
||||
} else {
|
||||
log::warn!("no srs provided for kzgcommit. processed value will be none");
|
||||
log::warn!("no srs provided for polycommit. processed value will be none");
|
||||
}
|
||||
} else {
|
||||
log::debug!(
|
||||
"no verifying key provided for kzgcommit. processed value will be none"
|
||||
"no verifying key provided for polycommit. processed value will be none"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(ModuleForwardResult {
|
||||
poseidon_hash,
|
||||
kzg_commit,
|
||||
polycommit,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -248,6 +248,8 @@ pub fn new_op_from_onnx(
|
||||
symbol_values: &SymbolValues,
|
||||
rebase_frac_zero_constants: bool,
|
||||
) -> Result<(SupportedOp, Vec<usize>), Box<dyn std::error::Error>> {
|
||||
use tract_onnx::tract_core::ops::array::Trilu;
|
||||
|
||||
use crate::circuit::InputType;
|
||||
|
||||
let input_scales = inputs
|
||||
@@ -363,6 +365,26 @@ pub fn new_op_from_onnx(
|
||||
SupportedOp::Constant(c)
|
||||
}
|
||||
|
||||
"Trilu" => {
|
||||
let op = load_op::<Trilu>(node.op(), idx, node.op().name().to_string())?;
|
||||
let upper = op.upper;
|
||||
|
||||
// assert second input is a constant
|
||||
let diagonal = if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
let raw_values = &c.raw_values;
|
||||
if raw_values.len() != 1 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "trilu".to_string())));
|
||||
}
|
||||
raw_values[0] as i32
|
||||
} else {
|
||||
return Err("we only support constant inputs for trilu diagonal".into());
|
||||
};
|
||||
|
||||
SupportedOp::Linear(PolyOp::Trilu { upper, k: diagonal })
|
||||
}
|
||||
|
||||
"Gather" => {
|
||||
if inputs.len() != 2 {
|
||||
return Err(Box::new(GraphError::InvalidDims(idx, "gather".to_string())));
|
||||
@@ -839,6 +861,9 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Abs" => SupportedOp::Nonlinear(LookupOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"HardSwish" => SupportedOp::Nonlinear(LookupOp::HardSwish {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
}),
|
||||
@@ -1047,8 +1072,12 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
let in_scale = inputs[0].out_scales()[0];
|
||||
let max_scale = std::cmp::max(scales.get_max(), in_scale);
|
||||
|
||||
SupportedOp::Hybrid(HybridOp::Softmax {
|
||||
scale: scale_to_multiplier(inputs[0].out_scales()[0]).into(),
|
||||
input_scale: scale_to_multiplier(in_scale).into(),
|
||||
output_scale: scale_to_multiplier(max_scale).into(),
|
||||
axes: softmax_op.axes.to_vec(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -45,7 +45,7 @@ pub enum Visibility {
|
||||
impl Display for Visibility {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
match self {
|
||||
Visibility::KZGCommit => write!(f, "kzgcommit"),
|
||||
Visibility::KZGCommit => write!(f, "polycommit"),
|
||||
Visibility::Private => write!(f, "private"),
|
||||
Visibility::Public => write!(f, "public"),
|
||||
Visibility::Fixed => write!(f, "fixed"),
|
||||
@@ -88,7 +88,7 @@ impl<'a> From<&'a str> for Visibility {
|
||||
match s {
|
||||
"private" => Visibility::Private,
|
||||
"public" => Visibility::Public,
|
||||
"kzgcommit" => Visibility::KZGCommit,
|
||||
"polycommit" => Visibility::KZGCommit,
|
||||
"fixed" => Visibility::Fixed,
|
||||
"hashed" | "hashed/public" => Visibility::Hashed {
|
||||
hash_is_public: true,
|
||||
@@ -111,7 +111,7 @@ impl IntoPy<PyObject> for Visibility {
|
||||
Visibility::Private => "private".to_object(py),
|
||||
Visibility::Public => "public".to_object(py),
|
||||
Visibility::Fixed => "fixed".to_object(py),
|
||||
Visibility::KZGCommit => "kzgcommit".to_object(py),
|
||||
Visibility::KZGCommit => "polycommit".to_object(py),
|
||||
Visibility::Hashed {
|
||||
hash_is_public,
|
||||
outlets,
|
||||
@@ -158,7 +158,7 @@ impl<'source> FromPyObject<'source> for Visibility {
|
||||
match strval.to_lowercase().as_str() {
|
||||
"private" => Ok(Visibility::Private),
|
||||
"public" => Ok(Visibility::Public),
|
||||
"kzgcommit" => Ok(Visibility::KZGCommit),
|
||||
"polycommit" => Ok(Visibility::KZGCommit),
|
||||
"hashed" => Ok(Visibility::Hashed {
|
||||
hash_is_public: true,
|
||||
outlets: vec![],
|
||||
@@ -192,7 +192,7 @@ impl Visibility {
|
||||
matches!(&self, Visibility::Hashed { .. })
|
||||
}
|
||||
#[allow(missing_docs)]
|
||||
pub fn is_kzgcommit(&self) -> bool {
|
||||
pub fn is_polycommit(&self) -> bool {
|
||||
matches!(&self, Visibility::KZGCommit)
|
||||
}
|
||||
|
||||
@@ -323,9 +323,9 @@ impl VarVisibility {
|
||||
& !output_vis.is_hashed()
|
||||
& !params_vis.is_hashed()
|
||||
& !input_vis.is_hashed()
|
||||
& !output_vis.is_kzgcommit()
|
||||
& !params_vis.is_kzgcommit()
|
||||
& !input_vis.is_kzgcommit()
|
||||
& !output_vis.is_polycommit()
|
||||
& !params_vis.is_polycommit()
|
||||
& !input_vis.is_polycommit()
|
||||
{
|
||||
return Err(Box::new(GraphError::Visibility));
|
||||
}
|
||||
@@ -346,7 +346,7 @@ pub struct ModelVars<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub instance: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ModelVars<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ModelVars<F> {
|
||||
/// Get instance col
|
||||
pub fn get_instance_col(&self) -> Option<&Column<Instance>> {
|
||||
if let Some(instance) = &self.instance {
|
||||
|
||||
76
src/lib.rs
76
src/lib.rs
@@ -23,14 +23,19 @@
|
||||
)]
|
||||
// we allow this for our dynamic range based indexing scheme
|
||||
#![allow(clippy::single_range_in_vec_init)]
|
||||
#![feature(round_ties_even)]
|
||||
|
||||
//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
|
||||
//!
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
use graph::Visibility;
|
||||
use halo2_proofs::poly::{
|
||||
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
|
||||
};
|
||||
use halo2curves::bn256::{Bn256, G1Affine};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tosubcommand::ToFlags;
|
||||
|
||||
@@ -97,6 +102,71 @@ const EZKL_KEY_FORMAT: &str = "raw-bytes";
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
const EZKL_BUF_CAPACITY: &usize = &8000;
|
||||
|
||||
#[derive(
|
||||
Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize, Default, Copy,
|
||||
)]
|
||||
/// Commitment scheme
|
||||
pub enum Commitments {
|
||||
#[default]
|
||||
/// KZG
|
||||
KZG,
|
||||
/// IPA
|
||||
IPA,
|
||||
}
|
||||
|
||||
impl FromStr for Commitments {
|
||||
type Err = String;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"kzg" => Ok(Commitments::KZG),
|
||||
"ipa" => Ok(Commitments::IPA),
|
||||
_ => Err("Invalid value for Commitments".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<KZGCommitmentScheme<Bn256>> for Commitments {
|
||||
fn from(_value: KZGCommitmentScheme<Bn256>) -> Self {
|
||||
Commitments::KZG
|
||||
}
|
||||
}
|
||||
|
||||
impl From<IPACommitmentScheme<G1Affine>> for Commitments {
|
||||
fn from(_value: IPACommitmentScheme<G1Affine>) -> Self {
|
||||
Commitments::IPA
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Commitments {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Commitments::KZG => write!(f, "kzg"),
|
||||
Commitments::IPA => write!(f, "ipa"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ToFlags for Commitments {
|
||||
/// Convert the struct to a subcommand string
|
||||
fn to_flags(&self) -> Vec<String> {
|
||||
vec![format!("{}", self)]
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Commitments {
|
||||
fn from(value: String) -> Self {
|
||||
match value.to_lowercase().as_str() {
|
||||
"kzg" => Commitments::KZG,
|
||||
"ipa" => Commitments::IPA,
|
||||
_ => {
|
||||
log::error!("Invalid value for Commitments");
|
||||
log::warn!("defaulting to KZG");
|
||||
Commitments::KZG
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parameters specific to a proving run
|
||||
#[derive(Debug, Args, Deserialize, Serialize, Clone, PartialEq, PartialOrd, ToFlags)]
|
||||
pub struct RunArgs {
|
||||
@@ -142,6 +212,9 @@ pub struct RunArgs {
|
||||
/// check mode (safe, unsafe, etc)
|
||||
#[arg(long, default_value = "unsafe")]
|
||||
pub check_mode: CheckMode,
|
||||
/// commitment scheme
|
||||
#[arg(long, default_value = "kzg")]
|
||||
pub commitment: Commitments,
|
||||
}
|
||||
|
||||
impl Default for RunArgs {
|
||||
@@ -161,6 +234,7 @@ impl Default for RunArgs {
|
||||
div_rebasing: false,
|
||||
rebase_frac_zero_constants: false,
|
||||
check_mode: CheckMode::UNSAFE,
|
||||
commitment: Commitments::KZG,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use thiserror::Error;
|
||||
|
||||
/// Aggregate proof generation for EVM
|
||||
pub mod aggregation;
|
||||
/// Aggregate proof generation for EVM using KZG
|
||||
pub mod aggregation_kzg;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
/// Errors related to evm verification
|
||||
|
||||
268
src/pfsys/mod.rs
268
src/pfsys/mod.rs
@@ -6,17 +6,16 @@ pub mod srs;
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::graph::GraphWitness;
|
||||
use crate::pfsys::evm::aggregation::PoseidonTranscript;
|
||||
use crate::tensor::TensorType;
|
||||
use crate::{EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
|
||||
use crate::{Commitments, EZKL_BUF_CAPACITY, EZKL_KEY_FORMAT};
|
||||
use clap::ValueEnum;
|
||||
use halo2_proofs::circuit::Value;
|
||||
use halo2_proofs::plonk::{
|
||||
create_proof, keygen_pk, keygen_vk_custom, verify_proof, Circuit, ProvingKey, VerifyingKey,
|
||||
};
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, Params, ParamsProver, Prover, Verifier};
|
||||
use halo2_proofs::poly::kzg::commitment::{KZGCommitmentScheme, ParamsKZG};
|
||||
use halo2_proofs::poly::kzg::multiopen::{ProverSHPLONK, VerifierSHPLONK};
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_proofs::transcript::{EncodedChallenge, TranscriptReadBuffer, TranscriptWriterBuffer};
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField, WithSmallOrderMulGroup};
|
||||
@@ -32,7 +31,6 @@ use serde::de::DeserializeOwned;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use snark_verifier::system::halo2::{compile, Config};
|
||||
use snark_verifier::verifier::plonk::PlonkProtocol;
|
||||
use std::error::Error;
|
||||
use std::fs::File;
|
||||
@@ -293,6 +291,8 @@ where
|
||||
pub pretty_public_inputs: Option<PrettyElements>,
|
||||
/// timestamp
|
||||
pub timestamp: Option<u128>,
|
||||
/// commitment
|
||||
pub commitment: Option<Commitments>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "python-bindings")]
|
||||
@@ -336,6 +336,7 @@ where
|
||||
transcript_type: TranscriptType,
|
||||
split: Option<ProofSplitCommit>,
|
||||
pretty_public_inputs: Option<PrettyElements>,
|
||||
commitment: Option<Commitments>,
|
||||
) -> Self {
|
||||
Self {
|
||||
protocol,
|
||||
@@ -352,6 +353,7 @@ where
|
||||
.unwrap()
|
||||
.as_millis(),
|
||||
),
|
||||
commitment,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -398,27 +400,36 @@ impl From<GraphWitness> for Option<ProofSplitCommit> {
|
||||
let mut elem_offset = 0;
|
||||
|
||||
if let Some(input) = witness.processed_inputs {
|
||||
if let Some(kzg) = input.kzg_commit {
|
||||
if let Some(polycommit) = input.polycommit {
|
||||
// flatten and count number of elements
|
||||
let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::<usize>();
|
||||
let num_elements = polycommit
|
||||
.iter()
|
||||
.map(|polycommit| polycommit.len())
|
||||
.sum::<usize>();
|
||||
|
||||
elem_offset += num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(params) = witness.processed_params {
|
||||
if let Some(kzg) = params.kzg_commit {
|
||||
if let Some(polycommit) = params.polycommit {
|
||||
// flatten and count number of elements
|
||||
let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::<usize>();
|
||||
let num_elements = polycommit
|
||||
.iter()
|
||||
.map(|polycommit| polycommit.len())
|
||||
.sum::<usize>();
|
||||
|
||||
elem_offset += num_elements;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(output) = witness.processed_outputs {
|
||||
if let Some(kzg) = output.kzg_commit {
|
||||
if let Some(polycommit) = output.polycommit {
|
||||
// flatten and count number of elements
|
||||
let num_elements = kzg.iter().map(|kzg| kzg.len()).sum::<usize>();
|
||||
let num_elements = polycommit
|
||||
.iter()
|
||||
.map(|polycommit| polycommit.len())
|
||||
.sum::<usize>();
|
||||
|
||||
Some(ProofSplitCommit {
|
||||
start: elem_offset,
|
||||
@@ -481,7 +492,7 @@ where
|
||||
}
|
||||
|
||||
/// Creates a [VerifyingKey] and [ProvingKey] for a [crate::graph::GraphCircuit] (`circuit`) with specific [CommitmentScheme] parameters (`params`).
|
||||
pub fn create_keys<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
pub fn create_keys<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
circuit: &C,
|
||||
params: &'_ Scheme::ParamsProver,
|
||||
disable_selector_compression: bool,
|
||||
@@ -491,7 +502,7 @@ where
|
||||
<Scheme as CommitmentScheme>::Scalar: FromUniformBytes<64>,
|
||||
{
|
||||
// Real proof
|
||||
let empty_circuit = <C as Circuit<F>>::without_witnesses(circuit);
|
||||
let empty_circuit = <C as Circuit<Scheme::Scalar>>::without_witnesses(circuit);
|
||||
|
||||
// Initialize verifying key
|
||||
let now = Instant::now();
|
||||
@@ -513,8 +524,7 @@ where
|
||||
pub fn create_proof_circuit<
|
||||
'params,
|
||||
Scheme: CommitmentScheme,
|
||||
F: PrimeField + TensorType,
|
||||
C: Circuit<F>,
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
P: Prover<'params, Scheme>,
|
||||
V: Verifier<'params, Scheme>,
|
||||
Strategy: VerificationStrategy<'params, Scheme, V>,
|
||||
@@ -526,40 +536,28 @@ pub fn create_proof_circuit<
|
||||
instances: Vec<Vec<Scheme::Scalar>>,
|
||||
params: &'params Scheme::ParamsProver,
|
||||
pk: &ProvingKey<Scheme::Curve>,
|
||||
strategy: Strategy,
|
||||
check_mode: CheckMode,
|
||||
commitment: Commitments,
|
||||
transcript_type: TranscriptType,
|
||||
split: Option<ProofSplitCommit>,
|
||||
set_protocol: bool,
|
||||
protocol: Option<PlonkProtocol<Scheme::Curve>>,
|
||||
) -> Result<Snark<Scheme::Scalar, Scheme::Curve>, Box<dyn Error>>
|
||||
where
|
||||
C: Circuit<Scheme::Scalar>,
|
||||
Scheme::ParamsVerifier: 'params,
|
||||
Scheme::Scalar: Serialize
|
||||
+ DeserializeOwned
|
||||
+ SerdeObject
|
||||
+ PrimeField
|
||||
+ FromUniformBytes<64>
|
||||
+ WithSmallOrderMulGroup<3>
|
||||
+ Ord,
|
||||
+ WithSmallOrderMulGroup<3>,
|
||||
Scheme::Curve: Serialize + DeserializeOwned,
|
||||
{
|
||||
let strategy = Strategy::new(params.verifier_params());
|
||||
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
|
||||
#[cfg(feature = "det-prove")]
|
||||
let mut rng = <StdRng as rand::SeedableRng>::from_seed([0u8; 32]);
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
let mut rng = OsRng;
|
||||
let number_instance = instances.iter().map(|x| x.len()).collect();
|
||||
trace!("number_instance {:?}", number_instance);
|
||||
let mut protocol = None;
|
||||
|
||||
if set_protocol {
|
||||
protocol = Some(compile(
|
||||
params,
|
||||
pk.get_vk(),
|
||||
Config::kzg().with_num_instance(number_instance),
|
||||
))
|
||||
}
|
||||
|
||||
let pi_inner = instances
|
||||
.iter()
|
||||
@@ -595,13 +593,14 @@ where
|
||||
transcript_type,
|
||||
split,
|
||||
None,
|
||||
Some(commitment),
|
||||
);
|
||||
|
||||
// sanity check that the generated proof is valid
|
||||
if check_mode == CheckMode::SAFE {
|
||||
debug!("verifying generated proof");
|
||||
let verifier_params = params.verifier_params();
|
||||
verify_proof_circuit::<F, V, Scheme, Strategy, E, TR>(
|
||||
verify_proof_circuit::<V, Scheme, Strategy, E, TR>(
|
||||
&checkable_pf,
|
||||
verifier_params,
|
||||
pk.get_vk(),
|
||||
@@ -621,7 +620,6 @@ where
|
||||
|
||||
/// Swaps the proof commitments to a new set in the proof
|
||||
pub fn swap_proof_commitments<
|
||||
F: PrimeField,
|
||||
Scheme: CommitmentScheme,
|
||||
E: EncodedChallenge<Scheme::Curve>,
|
||||
TW: TranscriptWriterBuffer<Vec<u8>, Scheme::Curve, E>,
|
||||
@@ -641,7 +639,7 @@ where
|
||||
{
|
||||
let mut transcript_new: TW = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
|
||||
|
||||
// kzg commitments are the first set of points in the proof, this we'll always be the first set of advice
|
||||
// polycommit commitments are the first set of points in the proof, this we'll always be the first set of advice
|
||||
for commit in commitments {
|
||||
transcript_new
|
||||
.write_point(*commit)
|
||||
@@ -659,31 +657,46 @@ where
|
||||
}
|
||||
|
||||
/// Swap the proof commitments to a new set in the proof for KZG
|
||||
pub fn swap_proof_commitments_kzg(
|
||||
pub fn swap_proof_commitments_polycommit(
|
||||
snark: &Snark<Fr, G1Affine>,
|
||||
commitments: &[G1Affine],
|
||||
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
|
||||
let proof = match snark.transcript_type {
|
||||
TranscriptType::EVM => swap_proof_commitments::<
|
||||
Fr,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(snark, commitments)?,
|
||||
TranscriptType::Poseidon => swap_proof_commitments::<
|
||||
Fr,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(snark, commitments)?,
|
||||
let proof = match snark.commitment {
|
||||
Some(Commitments::KZG) => match snark.transcript_type {
|
||||
TranscriptType::EVM => swap_proof_commitments::<
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(snark, commitments)?,
|
||||
TranscriptType::Poseidon => swap_proof_commitments::<
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(snark, commitments)?,
|
||||
},
|
||||
Some(Commitments::IPA) => match snark.transcript_type {
|
||||
TranscriptType::EVM => swap_proof_commitments::<
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(snark, commitments)?,
|
||||
TranscriptType::Poseidon => swap_proof_commitments::<
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(snark, commitments)?,
|
||||
},
|
||||
None => {
|
||||
return Err("commitment scheme not found".into());
|
||||
}
|
||||
};
|
||||
|
||||
Ok(proof)
|
||||
}
|
||||
|
||||
/// A wrapper around halo2's verify_proof
|
||||
pub fn verify_proof_circuit<
|
||||
'params,
|
||||
F: PrimeField,
|
||||
V: Verifier<'params, Scheme>,
|
||||
Scheme: CommitmentScheme,
|
||||
Strategy: VerificationStrategy<'params, Scheme, V>,
|
||||
@@ -701,7 +714,6 @@ where
|
||||
+ PrimeField
|
||||
+ FromUniformBytes<64>
|
||||
+ WithSmallOrderMulGroup<3>
|
||||
+ Ord
|
||||
+ Serialize
|
||||
+ DeserializeOwned,
|
||||
Scheme::Curve: Serialize + DeserializeOwned,
|
||||
@@ -719,7 +731,7 @@ where
|
||||
}
|
||||
|
||||
/// Loads a [VerifyingKey] at `path`.
|
||||
pub fn load_vk<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
pub fn load_vk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
path: PathBuf,
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<VerifyingKey<Scheme::Curve>, Box<dyn Error>>
|
||||
@@ -742,7 +754,7 @@ where
|
||||
}
|
||||
|
||||
/// Loads a [ProvingKey] at `path`.
|
||||
pub fn load_pk<Scheme: CommitmentScheme, F: PrimeField + TensorType, C: Circuit<F>>(
|
||||
pub fn load_pk<Scheme: CommitmentScheme, C: Circuit<Scheme::Scalar>>(
|
||||
path: PathBuf,
|
||||
params: <C as Circuit<Scheme::Scalar>>::Params,
|
||||
) -> Result<ProvingKey<Scheme::Curve>, Box<dyn Error>>
|
||||
@@ -765,13 +777,12 @@ where
|
||||
}
|
||||
|
||||
/// Saves a [ProvingKey] to `path`.
|
||||
pub fn save_pk<Scheme: CommitmentScheme>(
|
||||
pub fn save_pk<C: SerdeObject + CurveAffine>(
|
||||
path: &PathBuf,
|
||||
pk: &ProvingKey<Scheme::Curve>,
|
||||
pk: &ProvingKey<C>,
|
||||
) -> Result<(), io::Error>
|
||||
where
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
|
||||
{
|
||||
info!("saving proving key 💾");
|
||||
let f = File::create(path)?;
|
||||
@@ -783,13 +794,12 @@ where
|
||||
}
|
||||
|
||||
/// Saves a [VerifyingKey] to `path`.
|
||||
pub fn save_vk<Scheme: CommitmentScheme>(
|
||||
pub fn save_vk<C: CurveAffine + SerdeObject>(
|
||||
path: &PathBuf,
|
||||
vk: &VerifyingKey<Scheme::Curve>,
|
||||
vk: &VerifyingKey<C>,
|
||||
) -> Result<(), io::Error>
|
||||
where
|
||||
Scheme::Curve: SerdeObject + CurveAffine,
|
||||
Scheme::Scalar: PrimeField + SerdeObject + FromUniformBytes<64>,
|
||||
C::ScalarExt: FromUniformBytes<64> + SerdeObject,
|
||||
{
|
||||
info!("saving verification key 💾");
|
||||
let f = File::create(path)?;
|
||||
@@ -813,156 +823,25 @@ pub fn save_params<Scheme: CommitmentScheme>(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// helper function
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn create_proof_circuit_kzg<
|
||||
'params,
|
||||
C: Circuit<Fr>,
|
||||
Strategy: VerificationStrategy<'params, KZGCommitmentScheme<Bn256>, VerifierSHPLONK<'params, Bn256>>,
|
||||
>(
|
||||
circuit: C,
|
||||
params: &'params ParamsKZG<Bn256>,
|
||||
public_inputs: Option<Vec<Fr>>,
|
||||
pk: &ProvingKey<G1Affine>,
|
||||
transcript: TranscriptType,
|
||||
strategy: Strategy,
|
||||
check_mode: CheckMode,
|
||||
split: Option<ProofSplitCommit>,
|
||||
) -> Result<Snark<Fr, G1Affine>, Box<dyn Error>> {
|
||||
let public_inputs = if let Some(public_inputs) = public_inputs {
|
||||
if !public_inputs.is_empty() {
|
||||
vec![public_inputs]
|
||||
} else {
|
||||
vec![vec![]]
|
||||
}
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
match transcript {
|
||||
TranscriptType::EVM => create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
Fr,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
_,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(
|
||||
circuit,
|
||||
public_inputs,
|
||||
params,
|
||||
pk,
|
||||
strategy,
|
||||
check_mode,
|
||||
transcript,
|
||||
split,
|
||||
false,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from),
|
||||
TranscriptType::Poseidon => create_proof_circuit::<
|
||||
KZGCommitmentScheme<_>,
|
||||
Fr,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
_,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(
|
||||
circuit,
|
||||
public_inputs,
|
||||
params,
|
||||
pk,
|
||||
strategy,
|
||||
check_mode,
|
||||
transcript,
|
||||
split,
|
||||
true,
|
||||
)
|
||||
.map_err(Box::<dyn Error>::from),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
/// helper function
|
||||
pub(crate) fn verify_proof_circuit_kzg<
|
||||
'params,
|
||||
Strategy: VerificationStrategy<'params, KZGCommitmentScheme<Bn256>, VerifierSHPLONK<'params, Bn256>>,
|
||||
>(
|
||||
params: &'params ParamsKZG<Bn256>,
|
||||
proof: Snark<Fr, G1Affine>,
|
||||
vk: &VerifyingKey<G1Affine>,
|
||||
strategy: Strategy,
|
||||
orig_n: u64,
|
||||
) -> Result<Strategy::Output, halo2_proofs::plonk::Error> {
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
Fr,
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, params, vk, strategy, orig_n),
|
||||
TranscriptType::Poseidon => verify_proof_circuit::<
|
||||
Fr,
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
_,
|
||||
_,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, params, vk, strategy, orig_n),
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////
|
||||
|
||||
#[cfg(test)]
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
mod tests {
|
||||
use std::io::copy;
|
||||
|
||||
use super::*;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
use tempfile::Builder;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_can_load_pre_generated_srs() {
|
||||
let tmp_dir = Builder::new().prefix("example").tempdir().unwrap();
|
||||
// lets hope this link never rots
|
||||
let target = "https://trusted-setup-halo2kzg.s3.eu-central-1.amazonaws.com/hermez-raw-1";
|
||||
let response = reqwest::get(target).await.unwrap();
|
||||
|
||||
let fname = response
|
||||
.url()
|
||||
.path_segments()
|
||||
.and_then(|segments| segments.last())
|
||||
.and_then(|name| if name.is_empty() { None } else { Some(name) })
|
||||
.unwrap_or("tmp.bin");
|
||||
|
||||
info!("file to download: '{}'", fname);
|
||||
let fname = tmp_dir.path().join(fname);
|
||||
info!("will be located under: '{:?}'", fname);
|
||||
let mut dest = File::create(fname.clone()).unwrap();
|
||||
let content = response.bytes().await.unwrap();
|
||||
copy(&mut &content[..], &mut dest).unwrap();
|
||||
let res = srs::load_srs::<KZGCommitmentScheme<Bn256>>(fname);
|
||||
assert!(res.is_ok())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_can_load_saved_srs() {
|
||||
let tmp_dir = Builder::new().prefix("example").tempdir().unwrap();
|
||||
let fname = tmp_dir.path().join("kzg.params");
|
||||
let fname = tmp_dir.path().join("polycommit.params");
|
||||
let srs = srs::gen_srs::<KZGCommitmentScheme<Bn256>>(1);
|
||||
let res = save_params::<KZGCommitmentScheme<Bn256>>(&fname, &srs);
|
||||
assert!(res.is_ok());
|
||||
let res = srs::load_srs::<KZGCommitmentScheme<Bn256>>(fname);
|
||||
let res = srs::load_srs_prover::<KZGCommitmentScheme<Bn256>>(fname);
|
||||
assert!(res.is_ok())
|
||||
}
|
||||
|
||||
@@ -977,6 +856,7 @@ mod tests {
|
||||
split: None,
|
||||
pretty_public_inputs: None,
|
||||
timestamp: None,
|
||||
commitment: None,
|
||||
};
|
||||
|
||||
snark
|
||||
|
||||
@@ -17,7 +17,7 @@ pub fn gen_srs<Scheme: CommitmentScheme>(k: u32) -> Scheme::ParamsProver {
|
||||
}
|
||||
|
||||
/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
|
||||
pub fn load_srs<Scheme: CommitmentScheme>(
|
||||
pub fn load_srs_verifier<Scheme: CommitmentScheme>(
|
||||
path: PathBuf,
|
||||
) -> Result<Scheme::ParamsVerifier, Box<dyn Error>> {
|
||||
info!("loading srs from {:?}", path);
|
||||
@@ -26,3 +26,14 @@ pub fn load_srs<Scheme: CommitmentScheme>(
|
||||
let mut reader = BufReader::new(f);
|
||||
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
|
||||
}
|
||||
|
||||
/// Loads the [CommitmentScheme::ParamsVerifier] at `path`.
|
||||
pub fn load_srs_prover<Scheme: CommitmentScheme>(
|
||||
path: PathBuf,
|
||||
) -> Result<Scheme::ParamsProver, Box<dyn Error>> {
|
||||
info!("loading srs from {:?}", path);
|
||||
let f = File::open(path.clone())
|
||||
.map_err(|_| format!("failed to load srs at {}", path.display()))?;
|
||||
let mut reader = BufReader::new(f);
|
||||
Params::<'_, Scheme::Curve>::read(&mut reader).map_err(Box::<dyn Error>::from)
|
||||
}
|
||||
|
||||
145
src/python.rs
145
src/python.rs
@@ -1,4 +1,4 @@
|
||||
use crate::circuit::modules::kzg::KZGChip;
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
@@ -12,12 +12,14 @@ use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
quantize_float, scale_to_multiplier, GraphCircuit, GraphSettings, Model, Visibility,
|
||||
};
|
||||
use crate::pfsys::evm::aggregation::AggregationCircuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::{
|
||||
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs, ProofType,
|
||||
TranscriptType,
|
||||
load_pk, load_vk, save_params, save_vk, srs::gen_srs as ezkl_gen_srs, srs::load_srs_prover,
|
||||
ProofType, TranscriptType,
|
||||
};
|
||||
use crate::Commitments;
|
||||
use crate::RunArgs;
|
||||
use halo2_proofs::poly::ipa::commitment::IPACommitmentScheme;
|
||||
use halo2_proofs::poly::kzg::commitment::KZGCommitmentScheme;
|
||||
use halo2curves::bn256::{Bn256, Fq, Fr, G1Affine, G1};
|
||||
use pyo3::exceptions::{PyIOError, PyRuntimeError};
|
||||
@@ -25,6 +27,7 @@ use pyo3::prelude::*;
|
||||
use pyo3::wrap_pyfunction;
|
||||
use pyo3_log;
|
||||
use snark_verifier::util::arithmetic::PrimeField;
|
||||
use std::str::FromStr;
|
||||
use std::{fs::File, path::PathBuf};
|
||||
use tokio::runtime::Runtime;
|
||||
|
||||
@@ -163,6 +166,8 @@ struct PyRunArgs {
|
||||
pub rebase_frac_zero_constants: bool,
|
||||
#[pyo3(get, set)]
|
||||
pub check_mode: CheckMode,
|
||||
#[pyo3(get, set)]
|
||||
pub commitment: PyCommitments,
|
||||
}
|
||||
|
||||
/// default instantiation of PyRunArgs
|
||||
@@ -192,6 +197,7 @@ impl From<PyRunArgs> for RunArgs {
|
||||
div_rebasing: py_run_args.div_rebasing,
|
||||
rebase_frac_zero_constants: py_run_args.rebase_frac_zero_constants,
|
||||
check_mode: py_run_args.check_mode,
|
||||
commitment: py_run_args.commitment.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -213,6 +219,46 @@ impl Into<PyRunArgs> for RunArgs {
|
||||
div_rebasing: self.div_rebasing,
|
||||
rebase_frac_zero_constants: self.rebase_frac_zero_constants,
|
||||
check_mode: self.check_mode,
|
||||
commitment: self.commitment.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pyclass]
|
||||
#[derive(Debug, Clone)]
|
||||
/// Pyclass marking the type of commitment
|
||||
pub enum PyCommitments {
|
||||
/// KZG commitment
|
||||
KZG,
|
||||
/// IPA commitment
|
||||
IPA,
|
||||
}
|
||||
|
||||
impl From<PyCommitments> for Commitments {
|
||||
fn from(py_commitments: PyCommitments) -> Self {
|
||||
match py_commitments {
|
||||
PyCommitments::KZG => Commitments::KZG,
|
||||
PyCommitments::IPA => Commitments::IPA,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Into<PyCommitments> for Commitments {
|
||||
fn into(self) -> PyCommitments {
|
||||
match self {
|
||||
Commitments::KZG => PyCommitments::KZG,
|
||||
Commitments::IPA => PyCommitments::IPA,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PyCommitments {
|
||||
type Err = String;
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"kzg" => Ok(PyCommitments::KZG),
|
||||
"ipa" => Ok(PyCommitments::IPA),
|
||||
_ => Err("Invalid value for Commitments".to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -366,15 +412,56 @@ fn kzg_commit(
|
||||
let settings = GraphSettings::load(&settings_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
|
||||
|
||||
let srs_path = crate::execute::get_srs_path(settings.run_args.logrows, srs_path);
|
||||
let srs_path =
|
||||
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::KZG);
|
||||
|
||||
let srs = load_srs::<KZGCommitmentScheme<Bn256>>(srs_path)
|
||||
let srs = load_srs_prover::<KZGCommitmentScheme<Bn256>>(srs_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
|
||||
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(vk_path, settings)
|
||||
let vk = load_vk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(vk_path, settings)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load vk"))?;
|
||||
|
||||
let output = KZGChip::commit(
|
||||
let output = PolyCommitChip::commit::<KZGCommitmentScheme<Bn256>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
&srs,
|
||||
);
|
||||
|
||||
Ok(output.iter().map(|x| (*x).into()).collect::<Vec<_>>())
|
||||
}
|
||||
|
||||
/// Generate an ipa commitment.
|
||||
#[pyfunction(signature = (
|
||||
message,
|
||||
vk_path=PathBuf::from(DEFAULT_VK),
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
srs_path=None
|
||||
))]
|
||||
fn ipa_commit(
|
||||
message: Vec<PyFelt>,
|
||||
vk_path: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> PyResult<Vec<PyG1Affine>> {
|
||||
let message: Vec<Fr> = message
|
||||
.iter()
|
||||
.map(crate::pfsys::string_to_field::<Fr>)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let settings = GraphSettings::load(&settings_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
|
||||
|
||||
let srs_path =
|
||||
crate::execute::get_srs_path(settings.run_args.logrows, srs_path, Commitments::KZG);
|
||||
|
||||
let srs = load_srs_prover::<IPACommitmentScheme<G1Affine>>(srs_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load srs"))?;
|
||||
|
||||
let vk = load_vk::<IPACommitmentScheme<G1Affine>, GraphCircuit>(vk_path, settings)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load vk"))?;
|
||||
|
||||
let output = PolyCommitChip::commit::<IPACommitmentScheme<G1Affine>>(
|
||||
message,
|
||||
vk.cs().degree() as u32,
|
||||
(vk.cs().blinding_factors() + 1) as u32,
|
||||
@@ -390,7 +477,7 @@ fn kzg_commit(
|
||||
witness_path=PathBuf::from(DEFAULT_WITNESS),
|
||||
))]
|
||||
fn swap_proof_commitments(proof_path: PathBuf, witness_path: PathBuf) -> PyResult<()> {
|
||||
crate::execute::swap_proof_commitments(proof_path, witness_path)
|
||||
crate::execute::swap_proof_commitments_cmd(proof_path, witness_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to swap commitments"))?;
|
||||
|
||||
Ok(())
|
||||
@@ -410,13 +497,13 @@ fn gen_vk_from_pk_single(
|
||||
let settings = GraphSettings::load(&circuit_settings_path)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load circuit settings"))?;
|
||||
|
||||
let pk = load_pk::<KZGCommitmentScheme<Bn256>, Fr, GraphCircuit>(path_to_pk, settings)
|
||||
let pk = load_pk::<KZGCommitmentScheme<Bn256>, GraphCircuit>(path_to_pk, settings)
|
||||
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
|
||||
|
||||
let vk = pk.get_vk();
|
||||
|
||||
// now save
|
||||
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_output_path, vk)
|
||||
save_vk::<G1Affine>(&vk_output_path, vk)
|
||||
.map_err(|_| PyIOError::new_err("Failed to save vk"))?;
|
||||
|
||||
Ok(true)
|
||||
@@ -428,13 +515,13 @@ fn gen_vk_from_pk_single(
|
||||
vk_output_path=PathBuf::from(DEFAULT_VK_AGGREGATED),
|
||||
))]
|
||||
fn gen_vk_from_pk_aggr(path_to_pk: PathBuf, vk_output_path: PathBuf) -> PyResult<bool> {
|
||||
let pk = load_pk::<KZGCommitmentScheme<Bn256>, Fr, AggregationCircuit>(path_to_pk, ())
|
||||
let pk = load_pk::<KZGCommitmentScheme<Bn256>, AggregationCircuit>(path_to_pk, ())
|
||||
.map_err(|_| PyIOError::new_err("Failed to load pk"))?;
|
||||
|
||||
let vk = pk.get_vk();
|
||||
|
||||
// now save
|
||||
save_vk::<KZGCommitmentScheme<Bn256>>(&vk_output_path, vk)
|
||||
save_vk::<G1Affine>(&vk_output_path, vk)
|
||||
.map_err(|_| PyIOError::new_err("Failed to save vk"))?;
|
||||
|
||||
Ok(true)
|
||||
@@ -471,19 +558,27 @@ fn gen_srs(srs_path: PathBuf, logrows: usize) -> PyResult<()> {
|
||||
#[pyfunction(signature = (
|
||||
settings_path=PathBuf::from(DEFAULT_SETTINGS),
|
||||
logrows=None,
|
||||
srs_path=None
|
||||
srs_path=None,
|
||||
commitment=None,
|
||||
))]
|
||||
fn get_srs(
|
||||
settings_path: Option<PathBuf>,
|
||||
logrows: Option<u32>,
|
||||
srs_path: Option<PathBuf>,
|
||||
commitment: Option<PyCommitments>,
|
||||
) -> PyResult<bool> {
|
||||
let commitment: Option<Commitments> = match commitment {
|
||||
Some(c) => Some(c.into()),
|
||||
None => None,
|
||||
};
|
||||
|
||||
Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(crate::execute::get_srs_cmd(
|
||||
srs_path,
|
||||
settings_path,
|
||||
logrows,
|
||||
commitment,
|
||||
))
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to get srs: {}", e);
|
||||
@@ -719,6 +814,7 @@ fn verify(
|
||||
split_proofs = false,
|
||||
srs_path = None,
|
||||
disable_selector_compression=DEFAULT_DISABLE_SELECTOR_COMPRESSION.parse().unwrap(),
|
||||
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
|
||||
))]
|
||||
fn setup_aggregate(
|
||||
sample_snarks: Vec<PathBuf>,
|
||||
@@ -728,6 +824,7 @@ fn setup_aggregate(
|
||||
split_proofs: bool,
|
||||
srs_path: Option<PathBuf>,
|
||||
disable_selector_compression: bool,
|
||||
commitment: PyCommitments,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::setup_aggregate(
|
||||
sample_snarks,
|
||||
@@ -737,6 +834,7 @@ fn setup_aggregate(
|
||||
logrows,
|
||||
split_proofs,
|
||||
disable_selector_compression,
|
||||
commitment.into(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to setup aggregate: {}", e);
|
||||
@@ -774,6 +872,7 @@ fn compile_circuit(
|
||||
check_mode=CheckMode::UNSAFE,
|
||||
split_proofs = false,
|
||||
srs_path=None,
|
||||
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
|
||||
))]
|
||||
fn aggregate(
|
||||
aggregation_snarks: Vec<PathBuf>,
|
||||
@@ -784,6 +883,7 @@ fn aggregate(
|
||||
check_mode: CheckMode,
|
||||
split_proofs: bool,
|
||||
srs_path: Option<PathBuf>,
|
||||
commitment: PyCommitments,
|
||||
) -> Result<bool, PyErr> {
|
||||
// the K used for the aggregation circuit
|
||||
crate::execute::aggregate(
|
||||
@@ -795,6 +895,7 @@ fn aggregate(
|
||||
logrows,
|
||||
check_mode,
|
||||
split_proofs,
|
||||
commitment.into(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run aggregate: {}", e);
|
||||
@@ -809,15 +910,27 @@ fn aggregate(
|
||||
proof_path=PathBuf::from(DEFAULT_PROOF_AGGREGATED),
|
||||
vk_path=PathBuf::from(DEFAULT_VK),
|
||||
logrows=DEFAULT_AGGREGATED_LOGROWS.parse().unwrap(),
|
||||
commitment=DEFAULT_COMMITMENT.parse().unwrap(),
|
||||
reduced_srs=DEFAULT_USE_REDUCED_SRS_FOR_VERIFICATION.parse().unwrap(),
|
||||
srs_path=None,
|
||||
))]
|
||||
fn verify_aggr(
|
||||
proof_path: PathBuf,
|
||||
vk_path: PathBuf,
|
||||
logrows: u32,
|
||||
commitment: PyCommitments,
|
||||
reduced_srs: bool,
|
||||
srs_path: Option<PathBuf>,
|
||||
) -> Result<bool, PyErr> {
|
||||
crate::execute::verify_aggr(proof_path, vk_path, srs_path, logrows).map_err(|e| {
|
||||
crate::execute::verify_aggr(
|
||||
proof_path,
|
||||
vk_path,
|
||||
srs_path,
|
||||
logrows,
|
||||
reduced_srs,
|
||||
commitment.into(),
|
||||
)
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to run verify_aggr: {}", e);
|
||||
PyRuntimeError::new_err(err_str)
|
||||
})?;
|
||||
@@ -1103,11 +1216,13 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||
m.add_class::<PyG1Affine>()?;
|
||||
m.add_class::<PyG1>()?;
|
||||
m.add_class::<PyTestDataSource>()?;
|
||||
m.add_class::<PyCommitments>()?;
|
||||
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_big_endian, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_int, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(felt_to_float, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(kzg_commit, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(ipa_commit, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(swap_proof_commitments, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(poseidon_hash, m)?)?;
|
||||
m.add_function(wrap_pyfunction!(float_to_felt, m)?)?;
|
||||
|
||||
@@ -378,7 +378,7 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
|
||||
{
|
||||
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<Value<F>> {
|
||||
let mut output = Vec::new();
|
||||
for (_, x) in value.iter().enumerate() {
|
||||
for x in value.iter() {
|
||||
output.push(x.value_field().evaluate());
|
||||
}
|
||||
Tensor::new(Some(&output), value.dims()).unwrap()
|
||||
@@ -434,6 +434,18 @@ impl<F: PrimeField + TensorType + Clone> From<Tensor<i128>> for Tensor<Value<F>>
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
|
||||
maybe_rayon::iter::FromParallelIterator<T> for Tensor<T>
|
||||
{
|
||||
fn from_par_iter<I>(par_iter: I) -> Self
|
||||
where
|
||||
I: maybe_rayon::iter::IntoParallelIterator<Item = T>,
|
||||
{
|
||||
let inner: Vec<T> = par_iter.into_par_iter().collect();
|
||||
Tensor::new(Some(&inner), &[inner.len()]).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType + std::marker::Send + std::marker::Sync>
|
||||
maybe_rayon::iter::IntoParallelIterator for Tensor<T>
|
||||
{
|
||||
|
||||
@@ -8,6 +8,142 @@ use maybe_rayon::{
|
||||
use std::collections::{HashMap, HashSet};
|
||||
pub use std::ops::{Add, Div, Mul, Neg, Sub};
|
||||
|
||||
/// Trilu operation.
|
||||
/// # Arguments
|
||||
/// * `a` - Tensor
|
||||
/// * `k` - i32
|
||||
/// * `upper` - Boolean
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::tensor::ops::trilu;
|
||||
/// let a = Tensor::<i128>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[1, 3, 2],
|
||||
/// ).unwrap();
|
||||
/// let result = trilu(&a, 1, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 2, 0, 0, 0, 0]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 1, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 0, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 0, 4, 0, 0]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 3, 4, 5, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 3, 4, 0, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 0, 3, 0, 5, 6]), &[1, 3, 2]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let a = Tensor::<i128>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6]),
|
||||
/// &[1, 2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = trilu(&a, 1, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 2, 3, 0, 0, 6]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 1, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 0, 4, 5, 6]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 0, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 3, 0, 5, 6]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 0, 4, 5, 0]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 3, 4, 5, 6]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 0, 0, 4, 0, 0]), &[1, 2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let a = Tensor::<i128>::new(
|
||||
/// Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9]),
|
||||
/// &[1, 3, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = trilu(&a, 1, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 2, 3, 0, 0, 6, 0, 0, 0]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 1, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 0, 4, 5, 6, 7, 8, 9]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 0, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 3, 0, 5, 6, 0, 0, 9]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, 0, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 0, 4, 5, 0, 7, 8, 9]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, true).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 8, 9]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// let result = trilu(&a, -1, false).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 0, 0, 4, 0, 0, 7, 8, 0]), &[1, 3, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn trilu<T: TensorType + std::marker::Send + std::marker::Sync>(
|
||||
a: &Tensor<T>,
|
||||
k: i32,
|
||||
upper: bool,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
let mut output = a.clone();
|
||||
|
||||
// Given a 2-D matrix or batches of 2-D matrices, returns the upper or lower triangular part of the tensor(s).
|
||||
// The attribute “upper” determines whether the upper or lower part is retained.
|
||||
// If set to true, the upper triangular matrix is retained. Lower triangular matrix is retained otherwise.
|
||||
// Default value for the “upper” attribute is true. Trilu takes one input tensor of shape [*, N, M], where * is zero or more batch dimensions.
|
||||
// The upper triangular part consists of the elements on and above the given diagonal (k).
|
||||
// The lower triangular part consists of elements on and below the diagonal. All other elements in the matrix are set to zero.
|
||||
|
||||
let batch_dims = &a.dims()[0..a.dims().len() - 2];
|
||||
let batch_cartiesian = batch_dims.iter().map(|d| 0..*d).multi_cartesian_product();
|
||||
|
||||
for b in batch_cartiesian {
|
||||
for i in 0..a.dims()[1] {
|
||||
for j in 0..a.dims()[2] {
|
||||
let mut coord = b.clone();
|
||||
coord.push(i);
|
||||
coord.push(j);
|
||||
// If k = 0, the triangular part on and above/below the main diagonal is retained.
|
||||
|
||||
if upper {
|
||||
// If upper is set to true, a positive k retains the upper triangular matrix excluding the main diagonal and (k-1) diagonals above it.
|
||||
if (j as i32) < (i as i32) + k {
|
||||
output.set(&coord, T::zero().ok_or(TensorError::Unsupported)?);
|
||||
}
|
||||
} else {
|
||||
// If upper is set to false, a positive k retains the lower triangular matrix including the main diagonal and k diagonals above it.
|
||||
if (j as i32) > (i as i32) + k {
|
||||
output.set(&coord, T::zero().ok_or(TensorError::Unsupported)?);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// IFF operation.
|
||||
/// # Arguments
|
||||
/// * `mask` - Tensor of 0s and 1s
|
||||
@@ -1739,11 +1875,7 @@ pub fn topk<T: TensorType + PartialOrd>(
|
||||
let mut indexed_a = a.clone();
|
||||
indexed_a.flatten();
|
||||
|
||||
let mut indexed_a = a
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, x)| (i, x))
|
||||
.collect::<Vec<_>>();
|
||||
let mut indexed_a = a.iter().enumerate().collect::<Vec<_>>();
|
||||
|
||||
if largest {
|
||||
indexed_a.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
|
||||
@@ -3256,6 +3388,47 @@ pub mod nonlinearities {
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies hardswish to a tensor of integers.
|
||||
/// Hardswish is defined as:
|
||||
// Hardswish(x)={0if x≤−3,xif x≥+3,x⋅(x+3)/6otherwise
|
||||
// Hardswish(x)=⎩
|
||||
// ⎨
|
||||
// ⎧0xx⋅(x+3)/6if x≤−3,if x≥+3,otherwise
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `a` - Tensor
|
||||
/// * `scale_input` - Single value
|
||||
/// * `scale_output` - Single value
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::tensor::ops::nonlinearities::hardswish;
|
||||
/// let x = Tensor::<i128>::new(
|
||||
/// Some(&[-12, -3, 2, 1, 1, 15]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = hardswish(&x, 1.0);
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 0, 2, 1, 1, 15]), &[2, 3]).unwrap();
|
||||
///
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// ```
|
||||
pub fn hardswish(a: &Tensor<i128>, scale_input: f64) -> Tensor<i128> {
|
||||
a.par_enum_map(|_, a_i| {
|
||||
let kix = (a_i as f64) / scale_input;
|
||||
let res = if kix <= -3.0 {
|
||||
0.0
|
||||
} else if kix >= 3.0 {
|
||||
kix
|
||||
} else {
|
||||
kix * (kix + 3.0) / 6.0
|
||||
};
|
||||
let rounded = (res * scale_input).round();
|
||||
Ok::<_, TensorError>(rounded as i128)
|
||||
})
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Elementwise applies exponential to a tensor of integers.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -3355,12 +3528,17 @@ pub mod nonlinearities {
|
||||
}
|
||||
|
||||
/// softmax layout
|
||||
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
|
||||
pub fn softmax_axes(
|
||||
a: &Tensor<i128>,
|
||||
input_scale: f64,
|
||||
output_scale: f64,
|
||||
axes: &[usize],
|
||||
) -> Tensor<i128> {
|
||||
// we want this to be as small as possible so we set the output scale to 1
|
||||
let dims = a.dims();
|
||||
|
||||
if dims.len() == 1 {
|
||||
return softmax(a, scale);
|
||||
return softmax(a, input_scale, output_scale);
|
||||
}
|
||||
|
||||
let cartesian_coord = dims[..dims.len() - 1]
|
||||
@@ -3383,7 +3561,7 @@ pub mod nonlinearities {
|
||||
|
||||
let softmax_input = a.get_slice(&sum_dims).unwrap();
|
||||
|
||||
let res = softmax(&softmax_input, scale);
|
||||
let res = softmax(&softmax_input, input_scale, output_scale);
|
||||
|
||||
outputs.push(res);
|
||||
}
|
||||
@@ -3410,20 +3588,25 @@ pub mod nonlinearities {
|
||||
/// Some(&[2, 2, 3, 2, 2, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = softmax(&x, 128.0);
|
||||
/// let result = softmax(&x, 128.0, 128.0 * 128.0);
|
||||
/// // doubles the scale of the input
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2734, 2734, 2755, 2734, 2734, 2692]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
|
||||
pub fn softmax(a: &Tensor<i128>, input_scale: f64, output_scale: f64) -> Tensor<i128> {
|
||||
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
|
||||
|
||||
let exp = exp(a, scale);
|
||||
let exp = exp(a, input_scale);
|
||||
|
||||
let sum = sum(&exp).unwrap();
|
||||
let inv_denom = recip(&sum, scale, scale);
|
||||
|
||||
(exp * inv_denom).unwrap()
|
||||
let inv_denom = recip(&sum, input_scale, output_scale);
|
||||
let mut res = (exp * inv_denom).unwrap();
|
||||
res = res
|
||||
.iter()
|
||||
.map(|x| ((*x as f64) / input_scale).round() as i128)
|
||||
.collect();
|
||||
res.reshape(a.dims()).unwrap();
|
||||
res
|
||||
}
|
||||
|
||||
/// Applies range_check_percent
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
|
||||
use super::{
|
||||
ops::{intercalate_values, pad, resize},
|
||||
*,
|
||||
};
|
||||
use halo2_proofs::{arithmetic::Field, plonk::Instance};
|
||||
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
|
||||
|
||||
pub(crate) fn create_constant_tensor<
|
||||
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
|
||||
@@ -51,6 +53,24 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
|
||||
/// Returns the inner cell of the [ValType].
|
||||
pub fn cell(&self) -> Option<Cell> {
|
||||
match self {
|
||||
ValType::PrevAssigned(cell) => Some(cell.cell()),
|
||||
ValType::AssignedConstant(cell, _) => Some(cell.cell()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the assigned cell of the [ValType].
|
||||
pub fn assigned_cell(&self) -> Option<AssignedCell<F, F>> {
|
||||
match self {
|
||||
ValType::PrevAssigned(cell) => Some(cell.clone()),
|
||||
ValType::AssignedConstant(cell, _) => Some(cell.clone()),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the value is previously assigned.
|
||||
pub fn is_prev_assigned(&self) -> bool {
|
||||
matches!(
|
||||
@@ -293,7 +313,7 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// Allocate a new [ValTensor::Instance] from the ConstraintSystem with the given tensor `dims`, optionally enabling `equality`.
|
||||
pub fn new_instance(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -428,10 +448,40 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Returns the number of constants in the [ValTensor].
|
||||
pub fn num_constants(&self) -> usize {
|
||||
pub fn create_constants_map_iterator(
|
||||
&self,
|
||||
) -> core::iter::FilterMap<
|
||||
core::slice::Iter<'_, ValType<F>>,
|
||||
fn(&ValType<F>) -> Option<(F, ValType<F>)>,
|
||||
> {
|
||||
match self {
|
||||
ValTensor::Value { inner, .. } => inner.iter().filter(|x| x.is_constant()).count(),
|
||||
ValTensor::Instance { .. } => 0,
|
||||
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
|
||||
if let ValType::Constant(v) = x {
|
||||
Some((*v, x.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}),
|
||||
ValTensor::Instance { .. } => {
|
||||
unreachable!("Instance tensors do not have constants")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the number of constants in the [ValTensor].
|
||||
pub fn create_constants_map(&self) -> ConstantsMap<F> {
|
||||
match self {
|
||||
ValTensor::Value { inner, .. } => inner
|
||||
.par_iter()
|
||||
.filter_map(|x| {
|
||||
if let ValType::Constant(v) = x {
|
||||
Some((*v, x.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
ValTensor::Instance { .. } => ConstantsMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use log::{error, warn};
|
||||
use log::{debug, error, warn};
|
||||
|
||||
use crate::circuit::CheckMode;
|
||||
use crate::circuit::{region::ConstantsMap, CheckMode};
|
||||
|
||||
use super::*;
|
||||
/// A wrapper around Halo2's `Column<Fixed>` or `Column<Advice>`.
|
||||
@@ -104,7 +104,7 @@ impl VarTensor {
|
||||
let mut advices = vec![];
|
||||
|
||||
if modulo > 1 {
|
||||
warn!("using column duplication for {} advice blocks", modulo - 1);
|
||||
debug!("using column duplication for {} advice blocks", modulo - 1);
|
||||
}
|
||||
|
||||
for _ in 0..modulo {
|
||||
@@ -150,7 +150,7 @@ impl VarTensor {
|
||||
modulo = (num_constants + modulo) / max_rows + 1;
|
||||
|
||||
if modulo > 1 {
|
||||
warn!("using column duplication for {} fixed columns", modulo - 1);
|
||||
debug!("using column duplication for {} fixed columns", modulo - 1);
|
||||
}
|
||||
|
||||
for _ in 0..modulo {
|
||||
@@ -289,9 +289,10 @@ impl VarTensor {
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
offset: usize,
|
||||
coord: usize,
|
||||
constant: F,
|
||||
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
|
||||
let (x, y, z) = self.cartesian_coord(offset);
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord);
|
||||
match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
region.assign_advice_from_constant(|| "constant", advices[x][y], z, constant)
|
||||
@@ -304,33 +305,28 @@ impl VarTensor {
|
||||
}
|
||||
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub fn assign_with_omissions<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
omissions: &HashSet<&usize>,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
|
||||
let mut assigned_coord = 0;
|
||||
let mut res: ValTensor<F> = match values {
|
||||
ValTensor::Instance { .. } => {
|
||||
unimplemented!("cannot assign instance to advice columns with omissions")
|
||||
}
|
||||
ValTensor::Value { inner: v, .. } => Ok::<_, halo2_proofs::plonk::Error>(
|
||||
ValTensor::Value { inner: v, .. } => Ok::<ValTensor<F>, halo2_proofs::plonk::Error>(
|
||||
v.enum_map(|coord, k| {
|
||||
if omissions.contains(&coord) {
|
||||
return Ok(k);
|
||||
return Ok::<_, halo2_proofs::plonk::Error>(k);
|
||||
}
|
||||
let cell = self.assign_value(region, offset, k.clone(), assigned_coord)?;
|
||||
let cell =
|
||||
self.assign_value(region, offset, k.clone(), assigned_coord, constants)?;
|
||||
assigned_coord += 1;
|
||||
|
||||
match k {
|
||||
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
|
||||
ValType::AssignedConstant(cell, f),
|
||||
),
|
||||
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
|
||||
_ => Ok(ValType::PrevAssigned(cell)),
|
||||
}
|
||||
Ok::<_, halo2_proofs::plonk::Error>(cell)
|
||||
})?
|
||||
.into(),
|
||||
),
|
||||
@@ -340,11 +336,12 @@ impl VarTensor {
|
||||
}
|
||||
|
||||
/// Assigns [ValTensor] to the columns of the inner tensor.
|
||||
pub fn assign<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub fn assign<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
|
||||
let mut res: ValTensor<F> = match values {
|
||||
ValTensor::Instance {
|
||||
@@ -382,14 +379,7 @@ impl VarTensor {
|
||||
},
|
||||
ValTensor::Value { inner: v, .. } => Ok(v
|
||||
.enum_map(|coord, k| {
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord)?;
|
||||
match k {
|
||||
ValType::Constant(f) => Ok::<ValType<F>, halo2_proofs::plonk::Error>(
|
||||
ValType::AssignedConstant(cell, f),
|
||||
),
|
||||
ValType::AssignedConstant(_, f) => Ok(ValType::AssignedConstant(cell, f)),
|
||||
_ => Ok(ValType::PrevAssigned(cell)),
|
||||
}
|
||||
self.assign_value(region, offset, k.clone(), coord, constants)
|
||||
})?
|
||||
.into()),
|
||||
}?;
|
||||
@@ -399,13 +389,16 @@ impl VarTensor {
|
||||
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn dummy_assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub fn dummy_assign_with_duplication<
|
||||
F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
|
||||
>(
|
||||
&self,
|
||||
row: usize,
|
||||
offset: usize,
|
||||
values: &ValTensor<F>,
|
||||
single_inner_col: bool,
|
||||
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
|
||||
match values {
|
||||
ValTensor::Instance { .. } => unimplemented!("duplication is not supported on instance columns. increase K if you require more rows."),
|
||||
ValTensor::Value { inner: v, dims , ..} => {
|
||||
@@ -430,21 +423,24 @@ impl VarTensor {
|
||||
|
||||
// duplicates every nth element to adjust for column overflow
|
||||
let mut res: ValTensor<F> = v.duplicate_every_n(duplication_freq, num_repeats, duplication_offset).unwrap().into();
|
||||
|
||||
let constants_map = res.create_constants_map();
|
||||
constants.extend(constants_map);
|
||||
|
||||
let total_used_len = res.len();
|
||||
let total_constants = res.num_constants();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
|
||||
res.reshape(dims).unwrap();
|
||||
res.set_scale(values.scale());
|
||||
|
||||
Ok((res, total_used_len, total_constants))
|
||||
Ok((res, total_used_len))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Assigns specific values (`ValTensor`) to the columns of the inner tensor but allows for column wrapping for accumulated operations.
|
||||
/// Duplication occurs by copying the last cell of the column to the first cell next column and creating a copy constraint between the two.
|
||||
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd>(
|
||||
pub fn assign_with_duplication<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
row: usize,
|
||||
@@ -452,7 +448,8 @@ impl VarTensor {
|
||||
values: &ValTensor<F>,
|
||||
check_mode: &CheckMode,
|
||||
single_inner_col: bool,
|
||||
) -> Result<(ValTensor<F>, usize, usize), halo2_proofs::plonk::Error> {
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<(ValTensor<F>, usize), halo2_proofs::plonk::Error> {
|
||||
let mut prev_cell = None;
|
||||
|
||||
match values {
|
||||
@@ -494,7 +491,7 @@ impl VarTensor {
|
||||
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
|
||||
};
|
||||
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord * step)?;
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
|
||||
|
||||
if single_inner_col {
|
||||
if z == 0 {
|
||||
@@ -502,28 +499,23 @@ impl VarTensor {
|
||||
prev_cell = Some(cell.clone());
|
||||
} else if coord > 0 && z == 0 && single_inner_col {
|
||||
if let Some(prev_cell) = prev_cell.as_ref() {
|
||||
region.constrain_equal(prev_cell.cell(),cell.cell())?;
|
||||
let cell = cell.cell().ok_or({
|
||||
error!("Error getting cell: {:?}", (x,y));
|
||||
halo2_proofs::plonk::Error::Synthesis})?;
|
||||
let prev_cell = prev_cell.cell().ok_or({
|
||||
error!("Error getting cell: {:?}", (x,y));
|
||||
halo2_proofs::plonk::Error::Synthesis})?;
|
||||
region.constrain_equal(prev_cell,cell)?;
|
||||
} else {
|
||||
error!("Error copy-constraining previous value: {:?}", (x,y));
|
||||
return Err(halo2_proofs::plonk::Error::Synthesis);
|
||||
}
|
||||
}}
|
||||
|
||||
match k {
|
||||
ValType::Constant(f) => {
|
||||
Ok(ValType::AssignedConstant(cell, f))
|
||||
},
|
||||
ValType::AssignedConstant(_, f) => {
|
||||
Ok(ValType::AssignedConstant(cell, f))
|
||||
},
|
||||
_ => {
|
||||
Ok(ValType::PrevAssigned(cell))
|
||||
}
|
||||
}
|
||||
Ok(cell)
|
||||
|
||||
})?.into()};
|
||||
let total_used_len = res.len();
|
||||
let total_constants = res.num_constants();
|
||||
res.remove_every_n(duplication_freq, num_repeats, duplication_offset).unwrap();
|
||||
|
||||
res.reshape(dims).unwrap();
|
||||
@@ -542,42 +534,61 @@ impl VarTensor {
|
||||
)};
|
||||
}
|
||||
|
||||
Ok((res, total_used_len, total_constants))
|
||||
Ok((res, total_used_len))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn assign_value<F: PrimeField + TensorType + PartialOrd>(
|
||||
fn assign_value<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
region: &mut Region<F>,
|
||||
offset: usize,
|
||||
k: ValType<F>,
|
||||
coord: usize,
|
||||
) -> Result<AssignedCell<F, F>, halo2_proofs::plonk::Error> {
|
||||
constants: &mut ConstantsMap<F>,
|
||||
) -> Result<ValType<F>, halo2_proofs::plonk::Error> {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord);
|
||||
match k {
|
||||
let res = match k {
|
||||
ValType::Value(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
region.assign_advice(|| "k", advices[x][y], z, || v)
|
||||
ValType::PrevAssigned(region.assign_advice(|| "k", advices[x][y], z, || v)?)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => match &self {
|
||||
ValType::PrevAssigned(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
v.copy_advice(|| "k", region, advices[x][y], z)
|
||||
ValType::PrevAssigned(v.copy_advice(|| "k", region, advices[x][y], z)?)
|
||||
}
|
||||
_ => {
|
||||
error!("PrevAssigned is only supported for advice columns");
|
||||
Err(halo2_proofs::plonk::Error::Synthesis)
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
ValType::AssignedConstant(v, val) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => {
|
||||
ValType::AssignedConstant(v.copy_advice(|| "k", region, advices[x][y], z)?, val)
|
||||
}
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
ValType::AssignedValue(v) => match &self {
|
||||
VarTensor::Advice { inner: advices, .. } => region
|
||||
.assign_advice(|| "k", advices[x][y], z, || v)
|
||||
.map(|a| a.evaluate()),
|
||||
VarTensor::Advice { inner: advices, .. } => ValType::PrevAssigned(
|
||||
region
|
||||
.assign_advice(|| "k", advices[x][y], z, || v)?
|
||||
.evaluate(),
|
||||
),
|
||||
_ => unimplemented!(),
|
||||
},
|
||||
ValType::Constant(v) => self.assign_constant(region, offset + coord, v),
|
||||
}
|
||||
ValType::Constant(v) => {
|
||||
if let std::collections::hash_map::Entry::Vacant(e) = constants.entry(v) {
|
||||
let value = ValType::AssignedConstant(
|
||||
self.assign_constant(region, offset, coord, v)?,
|
||||
v,
|
||||
);
|
||||
e.insert(value.clone());
|
||||
value
|
||||
} else {
|
||||
let cell = constants.get(&v).unwrap();
|
||||
self.assign_value(region, offset, cell.clone(), coord, constants)?
|
||||
}
|
||||
}
|
||||
};
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
253
src/wasm.rs
253
src/wasm.rs
@@ -6,22 +6,39 @@ use crate::fieldutils::i128_to_felt;
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::quantize_float;
|
||||
use crate::graph::scale_to_multiplier;
|
||||
use crate::graph::{GraphCircuit, GraphSettings};
|
||||
use crate::pfsys::create_proof_circuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
|
||||
use crate::pfsys::verify_proof_circuit;
|
||||
use crate::pfsys::TranscriptType;
|
||||
use crate::tensor::TensorType;
|
||||
use crate::CheckMode;
|
||||
use crate::Commitments;
|
||||
use console_error_panic_hook;
|
||||
use halo2_proofs::plonk::*;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
|
||||
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
|
||||
use halo2_proofs::poly::ipa::{
|
||||
commitment::{IPACommitmentScheme, ParamsIPA},
|
||||
strategy::SingleStrategy as IPASingleStrategy,
|
||||
};
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
strategy::SingleStrategy as KZGSingleStrategy,
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField};
|
||||
|
||||
use crate::tensor::TensorType;
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use std::str::FromStr;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
|
||||
|
||||
use console_error_panic_hook;
|
||||
|
||||
#[cfg(feature = "web")]
|
||||
pub use wasm_bindgen_rayon::init_thread_pool;
|
||||
|
||||
@@ -37,9 +54,6 @@ pub fn init_panic_hook() {
|
||||
console_error_panic_hook::set_once();
|
||||
}
|
||||
|
||||
use crate::graph::{GraphCircuit, GraphSettings};
|
||||
use crate::pfsys::{create_proof_circuit_kzg, verify_proof_circuit_kzg};
|
||||
|
||||
/// Wrapper around the halo2 encode call data method
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
@@ -222,7 +236,7 @@ pub fn genWitness(
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let witness = circuit
|
||||
.forward(&mut input, None, None, false)
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false)
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
serde_json::to_vec(&witness)
|
||||
@@ -307,15 +321,10 @@ pub fn verify(
|
||||
settings: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
srs: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
) -> Result<bool, JsError> {
|
||||
let mut reader = std::io::BufReader::new(&srs[..]);
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
|
||||
let circuit_settings: GraphSettings = serde_json::from_slice(&settings[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize settings: {}", e)))?;
|
||||
|
||||
let snark: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
|
||||
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(&vk[..]);
|
||||
@@ -326,15 +335,142 @@ pub fn verify(
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
let orig_n = 1 << circuit_settings.run_args.logrows;
|
||||
|
||||
let result = verify_proof_circuit_kzg(
|
||||
params.verifier_params(),
|
||||
snark,
|
||||
&vk,
|
||||
strategy,
|
||||
1 << circuit_settings.run_args.logrows,
|
||||
);
|
||||
let mut reader = std::io::BufReader::new(&srs[..]);
|
||||
let result = match circuit_settings.run_args.commitment {
|
||||
Commitments::KZG => {
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n),
|
||||
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n)
|
||||
}
|
||||
}
|
||||
}
|
||||
Commitments::IPA => {
|
||||
let params: ParamsIPA<_> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
let strategy = IPASingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n),
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, orig_n)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(_) => Ok(true),
|
||||
Err(e) => Err(JsError::new(&format!("{}", e))),
|
||||
}
|
||||
}
|
||||
|
||||
#[wasm_bindgen]
|
||||
#[allow(non_snake_case)]
|
||||
/// Verify aggregate proof in browser using wasm
|
||||
pub fn verifyAggr(
|
||||
proof_js: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
vk: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
logrows: u64,
|
||||
srs: wasm_bindgen::Clamped<Vec<u8>>,
|
||||
commitment: &str,
|
||||
) -> Result<bool, JsError> {
|
||||
let proof: crate::pfsys::Snark<Fr, G1Affine> = serde_json::from_slice(&proof_js[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize proof: {}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(&vk[..]);
|
||||
let vk = VerifyingKey::<G1Affine>::read::<_, AggregationCircuit>(
|
||||
&mut reader,
|
||||
halo2_proofs::SerdeFormat::RawBytes,
|
||||
(),
|
||||
)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize vk: {}", e)))?;
|
||||
|
||||
let commit = Commitments::from_str(commitment).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let mut reader = std::io::BufReader::new(&srs[..]);
|
||||
let result = match commit {
|
||||
Commitments::KZG => {
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
let strategy = KZGSingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows),
|
||||
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierSHPLONK<'_, Bn256>,
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows)
|
||||
}
|
||||
}
|
||||
}
|
||||
Commitments::IPA => {
|
||||
let params: ParamsIPA<_> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize params: {}", e)))?;
|
||||
let strategy = IPASingleStrategy::new(params.verifier_params());
|
||||
match proof.transcript_type {
|
||||
TranscriptType::EVM => verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<G1Affine, _, _, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows),
|
||||
TranscriptType::Poseidon => {
|
||||
verify_proof_circuit::<
|
||||
VerifierIPA<_>,
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
PoseidonTranscript<NativeLoader, _>,
|
||||
>(&proof, ¶ms, &vk, strategy, logrows)
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(_) => Ok(true),
|
||||
@@ -354,11 +490,6 @@ pub fn prove(
|
||||
log::set_max_level(log::LevelFilter::Debug);
|
||||
#[cfg(not(feature = "det-prove"))]
|
||||
log::set_max_level(log::LevelFilter::Info);
|
||||
// read in kzg params
|
||||
let mut reader = std::io::BufReader::new(&srs[..]);
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?;
|
||||
|
||||
// read in circuit
|
||||
let mut circuit: crate::graph::GraphCircuit = bincode::deserialize(&compiled_circuit[..])
|
||||
@@ -386,17 +517,63 @@ pub fn prove(
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
let proof_split_commits: Option<crate::pfsys::ProofSplitCommit> = data.into();
|
||||
|
||||
let strategy = KZGSingleStrategy::new(¶ms);
|
||||
let proof = create_proof_circuit_kzg(
|
||||
circuit,
|
||||
¶ms,
|
||||
Some(public_inputs),
|
||||
&pk,
|
||||
crate::pfsys::TranscriptType::EVM,
|
||||
strategy,
|
||||
crate::circuit::CheckMode::UNSAFE,
|
||||
proof_split_commits,
|
||||
)
|
||||
// read in kzg params
|
||||
let mut reader = std::io::BufReader::new(&srs[..]);
|
||||
// creates and verifies the proof
|
||||
let proof = match circuit.settings().run_args.commitment {
|
||||
Commitments::KZG => {
|
||||
let params: ParamsKZG<Bn256> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?;
|
||||
|
||||
create_proof_circuit::<
|
||||
KZGCommitmentScheme<Bn256>,
|
||||
_,
|
||||
ProverSHPLONK<_>,
|
||||
VerifierSHPLONK<_>,
|
||||
KZGSingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit,
|
||||
vec![public_inputs],
|
||||
¶ms,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
crate::Commitments::KZG,
|
||||
TranscriptType::EVM,
|
||||
proof_split_commits,
|
||||
None,
|
||||
)
|
||||
}
|
||||
Commitments::IPA => {
|
||||
let params: ParamsIPA<_> =
|
||||
halo2_proofs::poly::commitment::Params::<'_, G1Affine>::read(&mut reader)
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize srs: {}", e)))?;
|
||||
|
||||
create_proof_circuit::<
|
||||
IPACommitmentScheme<G1Affine>,
|
||||
_,
|
||||
ProverIPA<_>,
|
||||
VerifierIPA<_>,
|
||||
IPASingleStrategy<_>,
|
||||
_,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
EvmTranscript<_, _, _, _>,
|
||||
>(
|
||||
circuit,
|
||||
vec![public_inputs],
|
||||
¶ms,
|
||||
&pk,
|
||||
CheckMode::UNSAFE,
|
||||
crate::Commitments::IPA,
|
||||
TranscriptType::EVM,
|
||||
proof_split_commits,
|
||||
None,
|
||||
)
|
||||
}
|
||||
}
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
Ok(serde_json::to_string(&proof)
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
{"protocol":null,"instances":[[[12436184717236109307,3962172157175319849,7381016538464732718,1011752739694698287]],[[6425625360762666998,7924344314350639699,14762033076929465436,2023505479389396574]]],"proof":[1,2,3,4,5,6,7,8],"transcript_type":"EVM"}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user