Mirror of https://github.com/zkonduit/ezkl.git (synced 2026-01-13 16:27:59 -05:00)

Compare commits (17 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | b49b0487c4 |  |
|  | 61b7a8e9b5 |  |
|  | 5dbc7d5176 |  |
|  | ada45a3197 |  |
|  | 616b421967 |  |
|  | f64f0ecd23 |  |
|  | 5be12b7a54 |  |
|  | 2fd877c716 |  |
|  | 8197340985 |  |
|  | 6855ea1947 |  |
|  | 2ca57bde2c |  |
|  | 390de88194 |  |
|  | cd91f0af26 |  |
|  | 4771192823 |  |
|  | a863ccc868 |  |
|  | 8e6ccc863d |  |
|  | 00d6873f9a |  |
.github/workflows/engine.yml (vendored, 7 changes)

@@ -22,15 +22,18 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
         with:
           # Pin to version 0.12.1
           version: 'v0.12.1'
       - name: Add wasm32-unknown-unknown target
         run: rustup target add wasm32-unknown-unknown

       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
       - name: Install binaryen
         run: |
           set -e
.github/workflows/large-tests.yml (vendored, 2 changes)

@@ -11,7 +11,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: nanoGPT Mock
.github/workflows/pypi.yml (vendored, 4 changes)

@@ -40,7 +40,7 @@ jobs:

       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy

@@ -86,7 +86,7 @@ jobs:

       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
.github/workflows/release.yml (vendored, 21 changes)

@@ -45,7 +45,7 @@ jobs:
     steps:
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: Checkout repo
@@ -106,27 +106,27 @@ jobs:
         include:
           - build: windows-msvc
             os: windows-latest
-            rust: nightly-2024-02-06
+            rust: nightly-2024-07-18
             target: x86_64-pc-windows-msvc
           - build: macos
             os: macos-13
-            rust: nightly-2024-02-06
+            rust: nightly-2024-07-18
             target: x86_64-apple-darwin
           - build: macos-aarch64
             os: macos-13
-            rust: nightly-2024-02-06
+            rust: nightly-2024-07-18
             target: aarch64-apple-darwin
           - build: linux-musl
             os: ubuntu-22.04
-            rust: nightly-2024-02-06
+            rust: nightly-2024-07-18
             target: x86_64-unknown-linux-musl
           - build: linux-gnu
             os: ubuntu-22.04
-            rust: nightly-2024-02-06
+            rust: nightly-2024-07-18
             target: x86_64-unknown-linux-gnu
           - build: linux-aarch64
             os: ubuntu-22.04
-            rust: nightly-2024-02-06
+            rust: nightly-2024-07-18
             target: aarch64-unknown-linux-gnu

     steps:
@@ -181,9 +181,14 @@ jobs:
           echo "target flag is: ${{ env.TARGET_FLAGS }}"
           echo "target dir is: ${{ env.TARGET_DIR }}"

-      - name: Build release binary
+      - name: Build release binary (no asm)
+        if: matrix.build != 'linux-gnu'
         run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry

+      - name: Build release binary (asm)
+        if: matrix.build == 'linux-gnu'
+        run: ${{ env.CARGO }} build --release ${{ env.TARGET_FLAGS }} -Z sparse-registry --features asm
+
       - name: Strip release binary
         if: matrix.build != 'windows-msvc' && matrix.build != 'linux-aarch64'
         run: strip "target/${{ matrix.target }}/release/ezkl"
.github/workflows/rust.yml (vendored, 150 changes)

@@ -26,7 +26,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: Build
@@ -38,7 +38,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: Docs
@@ -50,7 +50,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -189,17 +189,20 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
         with:
           # Pin to version 0.12.1
           version: 'v0.12.1'
       - uses: nanasess/setup-chromedriver@v2
         # with:
         #   chromedriver-version: "115.0.5790.102"
       - name: Install wasm32-unknown-unknown
         run: rustup target add wasm32-unknown-unknown
       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
       - name: Run wasm verifier tests
         # on mac:
         # AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -212,7 +215,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -229,13 +232,15 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
         with:
           crate: cargo-nextest
           locked: true
+      # - name: The Worm Mock
+      #   run: cargo nextest run --release --verbose tests::large_mock_::large_tests_5_expects -- --include-ignored
       - name: public outputs and tolerance > 0
         run: cargo nextest run --release --verbose tests::mock_tolerance_public_outputs_ --test-threads 32
       - name: public outputs + batch size == 10
@@ -258,6 +263,8 @@ jobs:
         run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
       - name: hashed params
         run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
+      - name: hashed params public inputs
+        run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
       - name: hashed outputs
         run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
       - name: hashed inputs + params + outputs
@@ -286,7 +293,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -304,7 +311,7 @@ jobs:
           node-version: "18.12.1"
           cache: "pnpm"
       - name: "Add rust-src"
-        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
       - name: Install dependencies for js tests and in-browser-evm-verifier package
         run: |
           pnpm install --frozen-lockfile
@@ -323,10 +330,10 @@ jobs:
           cd in-browser-evm-verifier
           pnpm build:commonjs
           cd ..
-      - name: Install solc
-        run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
+      # - name: Install solc
+      #   run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
       - name: KZG prove and verify tests (EVM + VK rendered seperately)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
       - name: KZG prove and verify tests (EVM + kzg all)
@@ -345,6 +352,8 @@ jobs:
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_kzg_output_kzg_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain outputs & kzg inputs + params)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_output_kzg_input_kzg_params_prove_and_verify --test-threads 1
+      - name: KZG prove and verify tests (EVM + on chain all kzg)
+        run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_all_kzg_params_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM + on chain inputs & outputs hashes)
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_on_chain_input_output_hashed_prove_and_verify --test-threads 1
       - name: KZG prove and verify tests (EVM)
@@ -363,15 +372,18 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: jetli/wasm-pack-action@v0.4.0
         with:
           # Pin to version 0.12.1
           version: 'v0.12.1'
       - name: Add wasm32-unknown-unknown target
         run: rustup target add wasm32-unknown-unknown

       - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
+        run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
       - uses: actions/checkout@v3
       - name: Use pnpm 8
         uses: pnpm/action-setup@v2
@@ -429,40 +441,40 @@ jobs:
       - name: KZG prove and verify tests (hashed outputs)
         run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed

-  prove-and-verify-tests-gpu:
-    runs-on: GPU
-    env:
-      ENABLE_ICICLE_GPU: true
-    steps:
-      - uses: actions/checkout@v4
-      - uses: actions-rs/toolchain@v1
-        with:
-          toolchain: nightly-2024-02-06
-          override: true
-          components: rustfmt, clippy
-      - name: Add rust-src
-        run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
-      - uses: actions/checkout@v3
-      - uses: baptiste0928/cargo-install@v1
-        with:
-          crate: cargo-nextest
-          locked: true
-      - name: KZG prove and verify tests (kzg outputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
-      - name: KZG prove and verify tests (public outputs + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
-      - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
-      - name: KZG prove and verify tests (public outputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
-      - name: KZG prove and verify tests (public outputs + column overflow)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
-      - name: KZG prove and verify tests (public inputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
-      - name: KZG prove and verify tests (fixed params)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
-      - name: KZG prove and verify tests (hashed outputs)
-        run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1
+  # prove-and-verify-tests-gpu:
+  #   runs-on: GPU
+  #   env:
+  #     ENABLE_ICICLE_GPU: true
+  #   steps:
+  #     - uses: actions/checkout@v4
+  #     - uses: actions-rs/toolchain@v1
+  #       with:
+  #         toolchain: nightly-2024-07-18
+  #         override: true
+  #         components: rustfmt, clippy
+  #     - name: Add rust-src
+  #       run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
+  #     - uses: actions/checkout@v3
+  #     - uses: baptiste0928/cargo-install@v1
+  #       with:
+  #         crate: cargo-nextest
+  #         locked: true
+  #     - name: KZG prove and verify tests (kzg outputs)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_kzg_output --features icicle --test-threads 1
+  #     - name: KZG prove and verify tests (public outputs + column overflow)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_::w --features icicle --test-threads 1
+  #     - name: KZG prove and verify tests (public outputs + fixed params + column overflow)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_with_overflow_fixed_params_ --features icicle --test-threads 1
+  #     - name: KZG prove and verify tests (public outputs)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
+  #     - name: KZG prove and verify tests (public outputs + column overflow)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_::t --features icicle --test-threads 1
+  #     - name: KZG prove and verify tests (public inputs)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_public_input --features icicle --test-threads 1
+  #     - name: KZG prove and verify tests (fixed params)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_fixed_params --features icicle --test-threads 1
+  #     - name: KZG prove and verify tests (hashed outputs)
+  #       run: cargo nextest run --release --verbose tests::kzg_prove_and_verify_hashed --features icicle --test-threads 1

   prove-and-verify-mock-aggr-tests:
     runs-on: self-hosted
@@ -471,7 +483,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -489,7 +501,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -506,7 +518,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -523,17 +535,17 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
         with:
           crate: cargo-nextest
           locked: true
-      - name: Install solc
-        run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
+      # - name: Install solc
+      #   run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
       - name: KZG prove and verify aggr tests
         run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -544,7 +556,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -564,17 +576,17 @@ jobs:
           python-version: "3.12"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - name: Install cmake
         run: sudo apt-get install -y cmake
-      - name: Install solc
-        run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
+      # - name: Install solc
+      #   run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Setup Virtual Env and Install python dependencies
         run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
       - name: Build python ezkl
         run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
       - name: Run pytest
@@ -590,7 +602,7 @@ jobs:
           python-version: "3.12"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
@@ -639,17 +651,17 @@ jobs:
           python-version: "3.11"
       - uses: actions-rs/toolchain@v1
         with:
-          toolchain: nightly-2024-02-06
+          toolchain: nightly-2024-07-18
           override: true
           components: rustfmt, clippy
       - uses: baptiste0928/cargo-install@v1
         with:
           crate: cargo-nextest
           locked: true
-      - name: Install solc
-        run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
+      # - name: Install solc
+      #   run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
       - name: Install Anvil
-        run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
+        run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
       - name: Install pip
         run: python -m ensurepip --upgrade
       - name: Setup Virtual Env and Install python dependencies
Cargo.lock (generated, 746 changes): file diff suppressed because it is too large.
Cargo.toml (50 changes)

@@ -16,35 +16,33 @@ crate-type = ["cdylib", "rlib"]

 [dependencies]
 halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
-halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch = "ac/optional-selector-poly" }
-halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "9fff22c", features = [
-    "derive_serde",
+halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "b753a832e92d5c86c5c997327a9cf9de86a18851", features = [
+    "derive_serde"
 ] }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }
 rand = { version = "0.8", default_features = false }
 itertools = { version = "0.10.3", default_features = false }
 clap = { version = "4.5.3", features = ["derive"] }
-clap_complete = "4.5.2"
 serde = { version = "1.0.126", features = ["derive"], optional = true }
 serde_json = { version = "1.0.97", default_features = false, features = [
     "float_roundtrip",
     "raw_value",
 ], optional = true }
+clap_complete = "4.5.2"
 log = { version = "0.4.17", default_features = false, optional = true }
 thiserror = { version = "1.0.38", default_features = false }
 hex = { version = "0.4.3", default_features = false }
 halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/chunked-mv-lookup", package = "ecc" }
-snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
+snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/chunked-mv-lookup", features = [
     "derive_serde",
 ] }
-halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
+halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "ac/update-h2-curves" }
 maybe-rayon = { version = "0.1.1", default_features = false }
 bincode = { version = "1.3.3", default_features = false }
 ark-std = { version = "^0.3.0", default-features = false }
 unzip-n = "0.1.2"
 num = "0.4.1"
 portable-atomic = "1.6.0"
 tosubcommand = { git = "https://github.com/zkonduit/enum_to_subcommand", package = "tosubcommand" }
-metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
 semver = "1.0.22"

 # evm related deps
@@ -63,31 +61,29 @@ reqwest = { version = "0.12.4", default-features = false, features = [
 openssl = { version = "0.10.55", features = ["vendored"] }
 tokio-postgres = "0.7.10"
 pg_bigdecimal = "0.1.5"
 futures-util = "0.3.30"
 lazy_static = "1.4.0"
 colored_json = { version = "3.0.1", default_features = false, optional = true }
 plotters = { version = "0.3.0", default_features = false, optional = true }
 regex = { version = "1", default_features = false }
-tokio = { version = "1.35", default_features = false, features = [
+tokio = { version = "1.35.0", default_features = false, features = [
     "macros",
-    "rt-multi-thread"
+    "rt-multi-thread",
 ] }
 tokio-util = { version = "0.7.9", features = ["codec"] }
 pyo3 = { version = "0.21.2", features = [
     "extension-module",
     "abi3-py37",
     "macros",
 ], default_features = false, optional = true }
 pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="migration-pyo3-0.21", features = [
-    "attributes",
+    "attributes",
     "tokio-runtime",
 ], default_features = false, optional = true }
 pyo3-log = { version = "0.10.0", default_features = false, optional = true }
-tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
+tract-onnx = { git = "https://github.com/sonos/tract/", rev = "40c64319291184814d9fea5fdf4fa16f5a4f7116", default_features = false, optional = true }
 tabled = { version = "0.12.0", optional = true }

+metal = { git = "https://github.com/gfx-rs/metal-rs", optional = true }
+objc = { version = "0.2.4", optional = true }
+
+mimalloc = "0.1"

 [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dependencies]
 colored = { version = "2.0.0", default_features = false, optional = true }
@@ -95,6 +91,7 @@ env_logger = { version = "0.10.0", default_features = false, optional = true }
 chrono = "0.4.31"
 sha256 = "1.4.0"

 [target.'cfg(target_arch = "wasm32")'.dependencies]
 getrandom = { version = "0.2.8", features = ["js"] }
 instant = { version = "0.1", features = ["wasm-bindgen", "inaccurate"] }
@@ -108,8 +105,10 @@ console_error_panic_hook = "0.1.7"
 wasm-bindgen-console-logger = "0.1.1"

+[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
+criterion = { version = "0.5.1", features = ["html_reports"] }

 [dev-dependencies]
-criterion = { version = "0.3", features = ["html_reports"] }
 tempfile = "3.3.0"
 lazy_static = "1.4.0"
 mnist = "0.5"
@@ -180,7 +179,7 @@ required-features = ["ezkl"]

 [features]
 web = ["wasm-bindgen-rayon"]
-default = ["ezkl", "mv-lookup", "no-banner"]
+default = ["ezkl", "mv-lookup", "precompute-coset", "no-banner", "parallel-poly-read"]
 onnx = ["dep:tract-onnx"]
 python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
 ezkl = [
@@ -194,24 +193,31 @@ ezkl = [
     "colored_json",
     "halo2_proofs/circuit-params",
 ]
+parallel-poly-read = ["halo2_proofs/parallel-poly-read"]
 mv-lookup = [
     "halo2_proofs/mv-lookup",
     "snark-verifier/mv-lookup",
     "halo2_solidity_verifier/mv-lookup",
 ]
+asm = ["halo2curves/asm", "halo2_proofs/asm"]
+precompute-coset = ["halo2_proofs/precompute-coset"]
 det-prove = []
 icicle = ["halo2_proofs/icicle_gpu"]
 empty-cmd = []
 no-banner = []
+metal = ["dep:metal", "dep:objc"]
 no-update = []

 # icicle patch to 0.1.0 if feature icicle is enabled
 [patch.'https://github.com/ingonyama-zk/icicle']
 icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }

 [patch.'https://github.com/zkonduit/halo2']
-halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
+halo2_proofs = { git = "https://github.com/zkonduit/halo2?branch=ac/cache-lookup-commitments#8b13a0d2a7a34d8daab010dadb2c47dfa47d37d0", package = "halo2_proofs", branch = "ac/cache-lookup-commitments" }

 [profile.release]
 rustflags = ["-C", "relocation-model=pic"]
 # lto = "fat"
 # codegen-units = 1
 # panic = "abort"
@@ -72,6 +72,7 @@ impl Circuit<Fr> for MyCircuit {
                 Box::new(PolyOp::Conv {
                     padding: vec![(0, 0)],
                     stride: vec![1; 2],
+                    group: 1,
                 }),
             )
             .unwrap();

@@ -2,6 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
 use ezkl::circuit::region::RegionCtx;
 use ezkl::circuit::table::Range;
 use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
+use ezkl::fieldutils::IntegerRep;
 use ezkl::pfsys::create_proof_circuit;
 use ezkl::pfsys::TranscriptType;
 use ezkl::pfsys::{create_keys, srs::gen_srs};
@@ -84,7 +85,7 @@ fn runrelu(c: &mut Criterion) {
     };

     let input: Tensor<Value<Fr>> =
-        Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();
+        Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();

     let circuit = NLCircuit {
         input: ValTensor::from(input.clone()),
@@ -93,9 +93,6 @@ contract LoadInstances {
     }
 }

-// Contract that checks that the COMMITMENT_KZG bytes is equal to the first part of the proof.
-pragma solidity ^0.8.0;
-
 // The kzg commitments of a given model, all aggregated into a single bytes array.
 // At solidity generation time, the commitments are hardcoded into the contract via the COMMITMENT_KZG constant.
 // It will be used to check that the proof commitments match the expected commitments.
@@ -163,7 +160,7 @@ contract SwapProofCommitments {
     }

         return equal; // Return true if the commitment comparison passed
     }
 } /// end checkKzgCommits
 }

 // This contract serves as a Data Attestation Verifier for the EZKL model.
@@ -2,8 +2,7 @@ use ezkl::circuit::region::RegionCtx;
 use ezkl::circuit::{
     ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
 };
-use ezkl::fieldutils;
-use ezkl::fieldutils::i32_to_felt;
+use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
 use ezkl::tensor::*;
 use halo2_proofs::dev::MockProver;
 use halo2_proofs::poly::commitment::Params;
@@ -42,8 +41,8 @@ const NUM_INNER_COLS: usize = 1;
 struct Config<
     const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
     const CLASSES: usize,
-    const LOOKUP_MIN: i64,
-    const LOOKUP_MAX: i64,
+    const LOOKUP_MIN: IntegerRep,
+    const LOOKUP_MAX: IntegerRep,
     // Convolution
     const KERNEL_HEIGHT: usize,
     const KERNEL_WIDTH: usize,
@@ -66,8 +65,8 @@ struct Config<
 struct MyCircuit<
     const LEN: usize, //LEN = CHOUT x OH x OW flattened
     const CLASSES: usize,
-    const LOOKUP_MIN: i64,
-    const LOOKUP_MAX: i64,
+    const LOOKUP_MIN: IntegerRep,
+    const LOOKUP_MAX: IntegerRep,
     // Convolution
     const KERNEL_HEIGHT: usize,
     const KERNEL_WIDTH: usize,
@@ -90,8 +89,8 @@ struct MyCircuit<
 impl<
         const LEN: usize,
         const CLASSES: usize,
-        const LOOKUP_MIN: i64,
-        const LOOKUP_MAX: i64,
+        const LOOKUP_MIN: IntegerRep,
+        const LOOKUP_MAX: IntegerRep,
         // Convolution
         const KERNEL_HEIGHT: usize,
         const KERNEL_WIDTH: usize,
@@ -205,6 +204,7 @@ where
         let op = PolyOp::Conv {
             padding: vec![(PADDING, PADDING); 2],
             stride: vec![STRIDE; 2],
+            group: 1,
         };
         let x = config
             .layer_config
@@ -315,7 +315,11 @@ pub fn runconv() {
         .test_set_length(10_000)
         .finalize();

-    let mut train_data = Tensor::from(trn_img.iter().map(|x| i32_to_felt::<F>(*x as i32 / 16)));
+    let mut train_data = Tensor::from(
+        trn_img
+            .iter()
+            .map(|x| integer_rep_to_felt::<F>(*x as IntegerRep / 16)),
+    );
     train_data.reshape(&[50_000, 28, 28]).unwrap();

     let mut train_labels = Tensor::from(trn_lbl.iter().map(|x| *x as f32));
@@ -343,8 +347,8 @@ pub fn runconv() {
             .map(|fl| {
                 let dx = fl * 32_f32;
                 let rounded = dx.round();
-                let integral: i32 = unsafe { rounded.to_int_unchecked() };
-                fieldutils::i32_to_felt(integral)
+                let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
+                fieldutils::integer_rep_to_felt(integral)
             }),
     );
@@ -355,7 +359,8 @@ pub fn runconv() {

     let l0_kernels = l0_kernels.try_into().unwrap();

-    let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
+    let mut l0_bias =
+        Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::integer_rep_to_felt(0)));
     l0_bias.set_visibility(&ezkl::graph::Visibility::Private);

     let l0_bias = l0_bias.try_into().unwrap();
@@ -363,8 +368,8 @@ pub fn runconv() {
     let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
         let dx = fl * 32_f32;
         let rounded = dx.round();
-        let integral: i32 = unsafe { rounded.to_int_unchecked() };
-        fieldutils::i32_to_felt(integral)
+        let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
+        fieldutils::integer_rep_to_felt(integral)
     }));
     l2_biases.set_visibility(&ezkl::graph::Visibility::Private);
     l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
@@ -374,8 +379,8 @@ pub fn runconv() {
     let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
         let dx = fl * 32_f32;
         let rounded = dx.round();
-        let integral: i32 = unsafe { rounded.to_int_unchecked() };
-        fieldutils::i32_to_felt(integral)
+        let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
+        fieldutils::integer_rep_to_felt(integral)
     }));
     l2_weights.set_visibility(&ezkl::graph::Visibility::Private);
     l2_weights.reshape(&[CLASSES, LEN]).unwrap();
@@ -401,13 +406,13 @@ pub fn runconv() {
         l2_params: [l2_weights, l2_biases],
     };

-    let public_input: Tensor<i32> = vec![
-        -25124i32, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
+    let public_input: Tensor<IntegerRep> = vec![
+        -25124, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
     ]
     .into_iter()
     .into();

-    let pi_inner: Tensor<F> = public_input.map(i32_to_felt::<F>);
+    let pi_inner: Tensor<F> = public_input.map(integer_rep_to_felt::<F>);

     println!("MOCK PROVING");
     let now = Instant::now();
@@ -2,7 +2,7 @@ use ezkl::circuit::region::RegionCtx;
 use ezkl::circuit::{
     ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
 };
-use ezkl::fieldutils::i32_to_felt;
+use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
 use ezkl::tensor::*;
 use halo2_proofs::dev::MockProver;
 use halo2_proofs::{
@@ -23,8 +23,8 @@ struct MyConfig {
 #[derive(Clone)]
 struct MyCircuit<
     const LEN: usize, //LEN = CHOUT x OH x OW flattened
-    const LOOKUP_MIN: i64,
-    const LOOKUP_MAX: i64,
+    const LOOKUP_MIN: IntegerRep,
+    const LOOKUP_MAX: IntegerRep,
 > {
     // Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
     // Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
@@ -34,7 +34,7 @@ struct MyCircuit<
     _marker: PhantomData<F>,
 }

-impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
+impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRep> Circuit<F>
     for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
 {
     type Config = MyConfig;
@@ -215,33 +215,33 @@ pub fn runmlp() {
     #[cfg(not(target_arch = "wasm32"))]
     env_logger::init();
     // parameters
-    let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
+    let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
         Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
         &[4, 4],
     )
     .unwrap()
-    .map(i32_to_felt);
+    .map(integer_rep_to_felt);
     l0_kernel.set_visibility(&ezkl::graph::Visibility::Private);

-    let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
+    let mut l0_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
         .unwrap()
-        .map(i32_to_felt);
+        .map(integer_rep_to_felt);
     l0_bias.set_visibility(&ezkl::graph::Visibility::Private);

-    let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
+    let mut l2_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
         Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
         &[4, 4],
     )
     .unwrap()
-    .map(i32_to_felt);
+    .map(integer_rep_to_felt);
     l2_kernel.set_visibility(&ezkl::graph::Visibility::Private);
     // input data, with 1 padding to allow for bias
-    let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
+    let input: Tensor<Value<F>> = Tensor::<IntegerRep>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
         .unwrap()
         .into();
-    let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
+    let mut l2_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
         .unwrap()
-        .map(i32_to_felt);
+        .map(integer_rep_to_felt);
     l2_bias.set_visibility(&ezkl::graph::Visibility::Private);

     let circuit = MyCircuit::<4, -8192, 8192> {
@@ -251,12 +251,12 @@ pub fn runmlp() {
         _marker: PhantomData,
     };

-    let public_input: Vec<i32> = unsafe {
+    let public_input: Vec<IntegerRep> = unsafe {
         vec![
-            (531f32 / 128f32).round().to_int_unchecked::<i32>(),
-            (103f32 / 128f32).round().to_int_unchecked::<i32>(),
-            (4469f32 / 128f32).round().to_int_unchecked::<i32>(),
-            (2849f32 / 128f32).to_int_unchecked::<i32>(),
+            (531f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
+            (103f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
+            (4469f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
+            (2849f32 / 128f32).to_int_unchecked::<IntegerRep>(),
         ]
     };
@@ -265,7 +265,10 @@ pub fn runmlp() {
     let prover = MockProver::run(
         K as u32,
         &circuit,
-        vec![public_input.iter().map(|x| i32_to_felt::<F>(*x)).collect()],
+        vec![public_input
+            .iter()
+            .map(|x| integer_rep_to_felt::<F>(*x))
+            .collect()],
     )
     .unwrap();
     prover.assert_satisfied();
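The recurring change across these example diffs is mechanical: every raw `i32`/`i64` integer representation is swapped for the `IntegerRep` alias from `ezkl::fieldutils`, and `i32_to_felt`/`i64_to_felt` become `integer_rep_to_felt`. A minimal Rust sketch of the conversion pattern, under stated assumptions: the alias below is a guess (the `as i128` casts in the hybrid-op hunks further down suggest `i128`), and the helper is an illustration rather than the library's actual implementation.

```rust
use halo2curves::ff::PrimeField;

// Assumed alias: the `as i128` casts elsewhere in this diff suggest the
// representation widened from i64 to i128. The real definition lives in
// ezkl::fieldutils.
pub type IntegerRep = i128;

// Sketch of an integer-representation -> field-element conversion: lift the
// magnitude with from_u128, then negate for values below zero.
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
    if x >= 0 {
        F::from_u128(x as u128)
    } else {
        // unsigned_abs avoids overflow at IntegerRep::MIN
        -F::from_u128(x.unsigned_abs())
    }
}
```

Widening the alias also explains why lookup bounds such as `LOOKUP_MIN`/`LOOKUP_MAX` above become `IntegerRep` const generics instead of hard-coded `i64`s.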
@@ -39,7 +39,7 @@
     "import json\n",
     "import numpy as np\n",
     "from sklearn.svm import SVC\n",
-    "import sk2torch\n",
+    "from hummingbird.ml import convert\n",
     "import torch\n",
     "import ezkl\n",
     "import os\n",
@@ -59,11 +59,11 @@
     "# Train an SVM on the data and wrap it in PyTorch.\n",
     "sk_model = SVC(probability=True)\n",
     "sk_model.fit(xs, ys)\n",
-    "model = sk2torch.wrap(sk_model)\n",
-    "\n",
+    "model = convert(sk_model, \"torch\").model\n",
     "\n",
     "\n",
     "\n",
     "model\n",
     "\n"
    ]
   },
@@ -84,33 +84,6 @@
     "data_path = os.path.join('input.json')"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "7f0ca328",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "import matplotlib.pyplot as plt\n",
-    "# Create a coordinate grid to compute a vector field on.\n",
-    "spaced = np.linspace(-2, 2, num=25)\n",
-    "grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
-    "\n",
-    "\n",
-    "# Compute the gradients of the SVM output.\n",
-    "outputs = model.predict_proba(grid_xs)[:, 1]\n",
-    "(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n",
-    "\n",
-    "\n",
-    "# Create a quiver plot of the vector field.\n",
-    "plt.quiver(\n",
-    "    grid_xs[:, 0].detach().numpy(),\n",
-    "    grid_xs[:, 1].detach().numpy(),\n",
-    "    input_grads[:, 0].detach().numpy(),\n",
-    "    input_grads[:, 1].detach().numpy(),\n",
-    ")\n"
-   ]
-  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -119,14 +92,14 @@
    "outputs": [],
    "source": [
     "\n",
     "\n",
     "spaced = np.linspace(-2, 2, num=25)\n",
     "grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
     "# export to onnx format\n",
     "# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
     "\n",
     "# Input to the model\n",
     "shape = xs.shape[1:]\n",
     "x = grid_xs[0:1]\n",
     "torch_out = model.predict(x)\n",
     "# Export the model\n",
     "torch.onnx.export(model, # model being run\n",
     "                  # model input (or a tuple for multiple inputs)\n",
@@ -143,9 +116,7 @@
     "\n",
     "d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
     "\n",
-    "data = dict(input_shapes=[shape],\n",
-    "            input_data=[d],\n",
-    "            output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
+    "data = dict(input_data=[d])\n",
     "\n",
     "# Serialize data into file:\n",
     "json.dump(data, open(\"input.json\", 'w'))\n"
@@ -167,6 +138,7 @@
   {
    "cell_type": "code",
    "execution_count": null,
+   "id": "0bee4d7f",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -220,7 +192,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": null,
    "id": "b1c561a8",
    "metadata": {},
    "outputs": [],
@@ -441,9 +413,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.15"
+   "version": "3.12.3"
   }
  },
 "nbformat": 4,
 "nbformat_minor": 5
 }
examples/onnx/1l_conv_transpose/network.compiled (new file, binary): binary file not shown.
examples/onnx/1l_conv_transpose/settings.json (new file, 1 line)

@@ -0,0 +1 @@
+{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":7,"param_scale":7,"scale_rebase_multiplier":10,"lookup_range":[0,0],"logrows":13,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Private"},"num_constraints":5619,"total_const_size":513,"model_instance_shapes":[[1,3,10,10]],"model_output_scales":[14],"model_input_scales":[7],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}
examples/onnx/lenet_5/input.json (new file, 1 line): file diff suppressed because one or more lines are too long.
examples/onnx/lenet_5/network.onnx (new file, binary): binary file not shown.
examples/onnx/smallworm/.gitattributes (vendored, new file, 1 line)

@@ -0,0 +1 @@
+network.onnx filter=lfs diff=lfs merge=lfs -text
examples/onnx/smallworm/Readme.md (new file, 47 lines)

## The worm

This is an onnx file for a [WormVAE](https://github.com/TuragaLab/wormvae?tab=readme-ov-file) model, a VAE / latent-space representation of the C. elegans connectome.

The model "is a large-scale latent variable model with a very high-dimensional latent space consisting of voltage dynamics of 300 neurons over 5 minutes of time at the simulation frequency of 160 Hz. The generative model for these latent variables is described by stochastic differential equations modeling the nonlinear dynamics of the network activity." (see [here](https://openreview.net/pdf?id=CJzi3dRlJE-)).

In effect this is a generative model of a worm's voltage dynamics, which can be used to generate new worm-like voltage dynamics given a previous connectome state.

Using ezkl you can create a zk circuit equivalent to the WormVAE model, allowing you to "prove" execution of the worm model. If you're feeling particularly adventurous, you can also use the zk circuit to generate new worm state that can be verified on chain.

To do so, first fetch the files using git-lfs (the onnx file is too large to be stored in git):

```bash
git lfs fetch --all
```

Then run the usual ezkl loop to generate the zk circuit. We recommend fixed visibility for the model parameters: the model is quite large, and fixing the parameters prunes the circuit significantly.

```bash
ezkl gen-settings --param-visibility=fixed
cp input.json calibration.json
ezkl calibrate-settings
ezkl compile-circuit
ezkl gen-witness
ezkl prove
```

You might also need to aggregate the proof to get it to fit on chain.

```bash
ezkl aggregate
```

You can then create a smart contract that verifies this aggregate proof:

```bash
ezkl create-evm-verifier-aggr
```

This can then be deployed on the chain of your choice.

> Note: the model is large, so we recommend a machine with at least 512GB of RAM to run the above commands. If you're compute-constrained, you can always use the lilith service to generate the zk circuit. Message us on discord or telegram for more details :)
examples/onnx/smallworm/input.json (new file, 1 line): file diff suppressed because one or more lines are too long.
examples/onnx/smallworm/network.onnx (new file, 3 lines)

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2f88c5901d3768ec21e3cf2f2840d255e84fa13c364df86b24d960cca3333769
+size 82095882
examples/onnx/smallworm/settings.json (new file, 1 line)

@@ -0,0 +1 @@
+{"run_args":{"tolerance":{"val":0.0,"scale":1.0},"input_scale":0,"param_scale":6,"scale_rebase_multiplier":1,"lookup_range":[-32768,32768],"logrows":17,"variables":[["batch_size",1]],"input_visibility":"Private","output_visibility":"Public","param_visibility":"Fixed"},"num_constraints":367422820,"total_const_size":365577160,"model_instance_shapes":[[1,300,1200]],"model_output_scales":[6],"model_input_scales":[0,0,0],"module_sizes":{"kzg":[],"poseidon":[0,[0]],"elgamal":[0,[0]]},"required_lookups":[{"Div":{"denom":64.0}},"ReLU",{"Ln":{"scale":64.0}},{"Exp":{"scale":64.0}}],"check_mode":"UNSAFE","version":"0.0.0","num_blinding_factors":null}
@@ -9,7 +9,6 @@ import { EVM } from '@ethereumjs/evm'
 import { buildTransaction, encodeDeployment } from './utils/tx-builder'
 import { getAccountNonce, insertAccount } from './utils/account-utils'
 import { encodeVerifierCalldata } from '../nodejs/ezkl';
-import { error } from 'console'

 async function deployContract(
   vm: VM,
@@ -66,7 +65,7 @@ async function verify(
     vkAddress = new Uint8Array(uint8Array.buffer);

     // convert uitn8array of length
-    error('vkAddress', vkAddress)
+    console.error('vkAddress', vkAddress)
   }
   const data = encodeVerifierCalldata(proof, vkAddress)
@@ -1,3 +1,3 @@
 [toolchain]
-channel = "nightly-2024-02-06"
+channel = "nightly-2024-07-18"
 components = ["rustfmt", "clippy"]
@@ -1,4 +1,7 @@
 // ignore file if compiling for wasm
+#[global_allocator]
+#[cfg(not(target_arch = "wasm32"))]
+static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

 #[cfg(not(target_arch = "wasm32"))]
 use clap::{CommandFactory, Parser};
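This hunk registers mimalloc as the process-wide allocator for non-wasm builds (the `mimalloc` dependency is added in the Cargo.toml hunk above). A minimal standalone sketch of the same pattern, just to show how the attribute pair behaves; the surrounding binary is hypothetical, not part of the diff:

```rust
// Register mimalloc as the global allocator on native targets. The cfg gate
// compiles the whole item out on wasm, attribute included, so wasm builds
// keep the default allocator.
#[global_allocator]
#[cfg(not(target_arch = "wasm32"))]
static GLOBAL: mimalloc::MiMalloc = mimalloc::MiMalloc;

fn main() {
    // All heap allocations below go through mimalloc on native targets.
    let v: Vec<u64> = (0..1_000).collect();
    println!("allocated {} elements", v.len());
}
```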
@@ -22,7 +22,7 @@ use crate::{
         table::{Range, RangeCheck, Table},
         utils,
     },
-    tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
+    tensor::{Tensor, TensorType, ValTensor, VarTensor},
 };
 use std::{collections::BTreeMap, marker::PhantomData};
@@ -327,7 +327,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
     _marker: PhantomData<F>,
 }

-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
     /// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
     pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
         Self {
@@ -1,6 +1,6 @@
 use std::convert::Infallible;

-use crate::tensor::TensorError;
+use crate::{fieldutils::IntegerRep, tensor::TensorError};
 use halo2_proofs::plonk::Error as PlonkError;
 use thiserror::Error;
@@ -57,7 +57,7 @@ pub enum CircuitError {
     InvalidConversion(#[from] Infallible),
     /// Invalid min/max lookup range
     #[error("invalid min/max lookup range: min: {0}, max: {1}")]
-    InvalidMinMaxRange(i64, i64),
+    InvalidMinMaxRange(IntegerRep, IntegerRep),
     /// Missing product in einsum
     #[error("missing product in einsum")]
     MissingEinsumProduct,
@@ -81,7 +81,7 @@ pub enum CircuitError {
     MissingSelectors(String),
     /// Table lookup error
     #[error("value ({0}) out of range: ({1}, {2})")]
-    TableOOR(i64, i64, i64),
+    TableOOR(IntegerRep, IntegerRep, IntegerRep),
     /// Loookup not configured
     #[error("lookup not configured: {0}")]
     LookupNotConfigured(String),
@@ -91,4 +91,7 @@ pub enum CircuitError {
     /// Missing layout
     #[error("missing layout for op: {0}")]
     MissingLayout(String),
+    #[error("[io] {0}")]
+    /// IO error
+    IoError(#[from] std::io::Error),
 }
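The new `IoError` variant wires `std::io::Error` into `CircuitError` through thiserror's `#[from]`, so io failures can be bubbled up with `?` instead of being mapped by hand. A hedged sketch of the mechanism: the enum here is a stand-in carrying only the new variant, and `read_settings` is a hypothetical caller, not code from this diff.

```rust
use thiserror::Error;

#[derive(Debug, Error)]
pub enum CircuitError {
    /// IO error; #[from] generates `From<std::io::Error> for CircuitError`
    #[error("[io] {0}")]
    IoError(#[from] std::io::Error),
}

// Hypothetical caller: the `?` operator converts std::io::Error into
// CircuitError::IoError via the generated From impl.
fn read_settings(path: &str) -> Result<String, CircuitError> {
    Ok(std::fs::read_to_string(path)?)
}
```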
@@ -1,7 +1,7 @@
 use super::*;
 use crate::{
     circuit::{layouts, utils, Tolerance},
-    fieldutils::i64_to_felt,
+    fieldutils::integer_rep_to_felt,
     graph::multiplier_to_scale,
     tensor::{self, Tensor, TensorType, ValTensor},
 };
@@ -71,7 +71,7 @@ pub enum HybridOp {
     },
 }

-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
     ///
     fn requires_homogenous_input_scales(&self) -> Vec<usize> {
         match self {
@@ -184,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
                     config,
                     region,
                     values[..].try_into()?,
-                    i64_to_felt(input_scale.0 as i64),
-                    i64_to_felt(output_scale.0 as i64),
+                    integer_rep_to_felt(input_scale.0 as i128),
+                    integer_rep_to_felt(output_scale.0 as i128),
                 )?
             } else {
                 layouts::nonlinearity(
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
                     config,
                     region,
                     values[..].try_into()?,
-                    i64_to_felt(denom.0 as i64),
+                    integer_rep_to_felt(denom.0 as i128),
                 )?
             } else {
                 layouts::nonlinearity(
File diff suppressed because it is too large
@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
 
 use crate::{
     circuit::{layouts, table::Range, utils},
-    fieldutils::{felt_to_i64, i64_to_felt},
+    fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
     graph::multiplier_to_scale,
-    tensor::{self, IntoI64, Tensor, TensorError, TensorType},
+    tensor::{self, Tensor, TensorError, TensorType},
 };
 
 use super::Op;
@@ -131,16 +131,63 @@ impl LookupOp {
     /// Returns the range of values that can be represented by the table
     pub fn bit_range(max_len: usize) -> Range {
         let range = (max_len - 1) as f64 / 2_f64;
-        let range = range as i64;
+        let range = range as IntegerRep;
         (-range, range)
     }
 
+    /// as path
+    pub fn as_path(&self) -> String {
+        match self {
+            LookupOp::Abs => "abs".into(),
+            LookupOp::Ceil { scale } => format!("ceil_{}", scale),
+            LookupOp::Floor { scale } => format!("floor_{}", scale),
+            LookupOp::Round { scale } => format!("round_{}", scale),
+            LookupOp::RoundHalfToEven { scale } => format!("round_half_to_even_{}", scale),
+            LookupOp::Pow { scale, a } => format!("pow_{}_{}", scale, a),
+            LookupOp::KroneckerDelta => "kronecker_delta".into(),
+            LookupOp::Max { scale, a } => format!("max_{}_{}", scale, a),
+            LookupOp::Min { scale, a } => format!("min_{}_{}", scale, a),
+            LookupOp::Sign => "sign".into(),
+            LookupOp::LessThan { a } => format!("less_than_{}", a),
+            LookupOp::LessThanEqual { a } => format!("less_than_equal_{}", a),
+            LookupOp::GreaterThan { a } => format!("greater_than_{}", a),
+            LookupOp::GreaterThanEqual { a } => format!("greater_than_equal_{}", a),
+            LookupOp::Div { denom } => format!("div_{}", denom),
+            LookupOp::Cast { scale } => format!("cast_{}", scale),
+            LookupOp::Recip {
+                input_scale,
+                output_scale,
+            } => format!("recip_{}_{}", input_scale, output_scale),
+            LookupOp::ReLU => "relu".to_string(),
+            LookupOp::LeakyReLU { slope: a } => format!("leaky_relu_{}", a),
+            LookupOp::Sigmoid { scale } => format!("sigmoid_{}", scale),
+            LookupOp::Sqrt { scale } => format!("sqrt_{}", scale),
+            LookupOp::Rsqrt { scale } => format!("rsqrt_{}", scale),
+            LookupOp::Erf { scale } => format!("erf_{}", scale),
+            LookupOp::Exp { scale } => format!("exp_{}", scale),
+            LookupOp::Ln { scale } => format!("ln_{}", scale),
+            LookupOp::Cos { scale } => format!("cos_{}", scale),
+            LookupOp::ACos { scale } => format!("acos_{}", scale),
+            LookupOp::Cosh { scale } => format!("cosh_{}", scale),
+            LookupOp::ACosh { scale } => format!("acosh_{}", scale),
+            LookupOp::Sin { scale } => format!("sin_{}", scale),
+            LookupOp::ASin { scale } => format!("asin_{}", scale),
+            LookupOp::Sinh { scale } => format!("sinh_{}", scale),
+            LookupOp::ASinh { scale } => format!("asinh_{}", scale),
+            LookupOp::Tan { scale } => format!("tan_{}", scale),
+            LookupOp::ATan { scale } => format!("atan_{}", scale),
+            LookupOp::ATanh { scale } => format!("atanh_{}", scale),
+            LookupOp::Tanh { scale } => format!("tanh_{}", scale),
+            LookupOp::HardSwish { scale } => format!("hardswish_{}", scale),
+        }
+    }
+
     /// Matches a [Op] to an operation in the `tensor::ops` module.
-    pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
+    pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
        &self,
        x: &[Tensor<F>],
    ) -> Result<ForwardResult<F>, TensorError> {
-        let x = x[0].clone().map(|x| felt_to_i64(x));
+        let x = x[0].clone().map(|x| felt_to_integer_rep(x));
        let res = match &self {
            LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
            LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
@@ -227,13 +274,13 @@ impl LookupOp {
             }
         }?;
 
-        let output = res.map(|x| i64_to_felt(x));
+        let output = res.map(|x| integer_rep_to_felt(x));
 
         Ok(ForwardResult { output })
     }
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
     /// Returns a reference to the Any trait.
     fn as_any(&self) -> &dyn Any {
         self
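Note: taken together, the new `as_path` gives every lookup op a stable string key. A minimal sketch of how such a key can name an on-disk cache directory, mirroring the `{name}_{lo}_{hi}` convention the table code below adopts (the `cache_dir_for` helper and its `root` argument are illustrative, not part of the diff):

    // Sketch: derive a cache directory from a lookup op's as_path() key.
    // `root` stands in for the configured cache root (see LOOKUP_CACHE below).
    use std::path::PathBuf;

    fn cache_dir_for(op_key: &str, range: (i128, i128), root: &str) -> PathBuf {
        // Mirrors the "{op}_{lo}_{hi}" naming used by Table::name in the table module.
        PathBuf::from(root).join(format!("{}_{}_{}", op_key, range.0, range.1))
    }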
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
 
 use crate::{
     graph::quantize_tensor,
-    tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
+    tensor::{self, Tensor, TensorType, ValTensor},
 };
 use halo2curves::ff::PrimeField;
 
@@ -31,12 +31,12 @@ pub use errors::CircuitError;
 
 /// A struct representing the result of a forward pass.
 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
-pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
+pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
     pub(crate) output: Tensor<F>,
 }
 
 /// A trait representing operations that can be represented as constraints in a circuit.
-pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
+pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
     std::fmt::Debug + Send + Sync + Any
 {
     /// Returns a string representation of the operation.
@@ -75,7 +75,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
     fn as_any(&self) -> &dyn Any;
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
     fn clone(&self) -> Self {
         self.clone_dyn()
     }
@@ -142,7 +142,7 @@ pub struct Input {
     pub datum_type: InputType,
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
     fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
         Ok(self.scale)
     }
@@ -197,7 +197,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
 #[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
 pub struct Unknown;
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
     fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
         Ok(0)
     }
@@ -224,7 +224,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
 
 ///
 #[derive(Clone, Debug, Serialize, Deserialize)]
-pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
+pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
     ///
     pub quantized_values: Tensor<F>,
     ///
@@ -234,7 +234,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash +
     pub pre_assigned_val: Option<ValTensor<F>>,
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
     ///
     pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
         Self {
@@ -267,8 +267,7 @@ impl<
         + PartialOrd
         + std::hash::Hash
         + Serialize
-        + for<'de> Deserialize<'de>
-        + IntoI64,
+        + for<'de> Deserialize<'de>,
     > Op<F> for Constant<F>
 {
     fn as_any(&self) -> &dyn Any {
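Note: with `IntoI64` gone from the `Op` bounds, generic callers shrink accordingly. A hedged sketch of the reduced bound set a helper now needs; the `describe_op` function is hypothetical and assumes the trait's string accessor documented above:

    // Sketch only: the IntoI64 bound disappears from generic helpers over Op<F>.
    fn describe_op<F, O>(op: &O) -> String
    where
        F: PrimeField + TensorType + PartialOrd + std::hash::Hash,
        O: Op<F> + ?Sized,
    {
        // as_string is the trait's "string representation" accessor noted above.
        op.as_string()
    }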
@@ -33,6 +33,7 @@ pub enum PolyOp {
     Conv {
         padding: Vec<(usize, usize)>,
         stride: Vec<usize>,
+        group: usize,
     },
     Downsample {
         axis: usize,
@@ -43,6 +44,7 @@ pub enum PolyOp {
         padding: Vec<(usize, usize)>,
         output_padding: Vec<usize>,
         stride: Vec<usize>,
+        group: usize,
     },
     Add,
     Sub,
@@ -98,7 +100,7 @@ impl<
         + std::hash::Hash
         + Serialize
         + for<'de> Deserialize<'de>
-        + IntoI64,
+        ,
     > Op<F> for PolyOp
 {
     /// Returns a reference to the Any trait.
@@ -148,17 +150,25 @@ impl<
             PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
             PolyOp::Prod { .. } => "PROD".into(),
             PolyOp::Pow(_) => "POW".into(),
-            PolyOp::Conv { stride, padding } => {
-                format!("CONV (stride={:?}, padding={:?})", stride, padding)
+            PolyOp::Conv {
+                stride,
+                padding,
+                group,
+            } => {
+                format!(
+                    "CONV (stride={:?}, padding={:?}, group={})",
+                    stride, padding, group
+                )
             }
             PolyOp::DeConv {
                 stride,
                 padding,
                 output_padding,
+                group,
             } => {
                 format!(
-                    "DECONV (stride={:?}, padding={:?}, output_padding={:?})",
-                    stride, padding, output_padding
+                    "DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
+                    stride, padding, output_padding, group
                 )
             }
             PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -212,9 +222,18 @@ impl<
             PolyOp::Prod { axes, .. } => {
                 layouts::prod_axes(config, region, values[..].try_into()?, axes)?
             }
-            PolyOp::Conv { padding, stride } => {
-                layouts::conv(config, region, values[..].try_into()?, padding, stride)?
-            }
+            PolyOp::Conv {
+                padding,
+                stride,
+                group,
+            } => layouts::conv(
+                config,
+                region,
+                values[..].try_into()?,
+                padding,
+                stride,
+                *group,
+            )?,
             PolyOp::GatherElements { dim, constant_idx } => {
                 if let Some(idx) = constant_idx {
                     tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -261,6 +280,7 @@ impl<
                 padding,
                 output_padding,
                 stride,
+                group,
             } => layouts::deconv(
                 config,
                 region,
@@ -268,6 +288,7 @@ impl<
                 padding,
                 output_padding,
                 stride,
+                *group,
             )?,
             PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
             PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
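Note: the new `group` field makes grouped convolutions first-class in `PolyOp`. A minimal construction sketch (the shapes are illustrative): a standard convolution uses `group: 1`, while a depthwise convolution over C channels would use `group: C`.

    let conv = PolyOp::Conv {
        padding: vec![(1, 1); 2], // one (before, after) pair per spatial axis
        stride: vec![1; 2],
        group: 1,                 // 1 = ungrouped; C = depthwise over C channels
    };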
@@ -1,5 +1,6 @@
 use crate::{
     circuit::table::Range,
+    fieldutils::IntegerRep,
     tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
 };
 #[cfg(not(target_arch = "wasm32"))]
@@ -9,7 +10,8 @@ use halo2_proofs::{
     plonk::{Error, Selector},
 };
 use halo2curves::ff::PrimeField;
-use portable_atomic::AtomicI64 as AtomicInt;
+use itertools::Itertools;
+use maybe_rayon::iter::ParallelExtend;
 use std::{
     cell::RefCell,
     collections::{HashMap, HashSet},
@@ -84,6 +86,78 @@ impl ShuffleIndex {
     }
 }
 
+#[derive(Debug, Clone)]
+/// Some settings for a region to differentiate it across the different phases of proof generation
+pub struct RegionSettings {
+    /// whether we are in witness generation mode
+    pub witness_gen: bool,
+    /// whether we should check range checks for validity
+    pub check_range: bool,
+}
+
+#[allow(unsafe_code)]
+unsafe impl Sync for RegionSettings {}
+#[allow(unsafe_code)]
+unsafe impl Send for RegionSettings {}
+
+impl RegionSettings {
+    /// Create a new region settings
+    pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
+        RegionSettings {
+            witness_gen,
+            check_range,
+        }
+    }
+
+    /// Create a new region settings with all true
+    pub fn all_true() -> RegionSettings {
+        RegionSettings {
+            witness_gen: true,
+            check_range: true,
+        }
+    }
+
+    /// Create a new region settings with all false
+    pub fn all_false() -> RegionSettings {
+        RegionSettings {
+            witness_gen: false,
+            check_range: false,
+        }
+    }
+}
+
+#[derive(Debug, Default, Clone)]
+/// Region statistics
+pub struct RegionStatistics {
+    /// the current maximum value of the lookup inputs
+    pub max_lookup_inputs: IntegerRep,
+    /// the current minimum value of the lookup inputs
+    pub min_lookup_inputs: IntegerRep,
+    /// the current maximum value of the range size
+    pub max_range_size: IntegerRep,
+    /// the current set of used lookups
+    pub used_lookups: HashSet<LookupOp>,
+    /// the current set of used range checks
+    pub used_range_checks: HashSet<Range>,
+}
+
+impl RegionStatistics {
+    /// update the statistics with another set of statistics
+    pub fn update(&mut self, other: &RegionStatistics) {
+        self.max_lookup_inputs = self.max_lookup_inputs.max(other.max_lookup_inputs);
+        self.min_lookup_inputs = self.min_lookup_inputs.min(other.min_lookup_inputs);
+        self.max_range_size = self.max_range_size.max(other.max_range_size);
+        self.used_lookups.extend(other.used_lookups.clone());
+        self.used_range_checks
+            .extend(other.used_range_checks.clone());
+    }
+}
+
+#[allow(unsafe_code)]
+unsafe impl Sync for RegionStatistics {}
+#[allow(unsafe_code)]
+unsafe impl Send for RegionStatistics {}
+
 #[derive(Debug)]
 /// A context for a region
 pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
@@ -93,13 +167,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
     num_inner_cols: usize,
     dynamic_lookup_index: DynamicLookupIndex,
     shuffle_index: ShuffleIndex,
-    used_lookups: HashSet<LookupOp>,
-    used_range_checks: HashSet<Range>,
-    max_lookup_inputs: i64,
-    min_lookup_inputs: i64,
-    max_range_size: i64,
-    witness_gen: bool,
-    check_lookup_range: bool,
+    statistics: RegionStatistics,
+    settings: RegionSettings,
     assigned_constants: ConstantsMap<F>,
 }
 
@@ -151,12 +220,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
 
     ///
     pub fn witness_gen(&self) -> bool {
-        self.witness_gen
+        self.settings.witness_gen
     }
 
     ///
-    pub fn check_lookup_range(&self) -> bool {
-        self.check_lookup_range
+    pub fn check_range(&self) -> bool {
+        self.settings.check_range
     }
 
+    ///
+    pub fn statistics(&self) -> &RegionStatistics {
+        &self.statistics
+    }
+
     /// Create a new region context
@@ -171,13 +245,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
             linear_coord,
             dynamic_lookup_index: DynamicLookupIndex::default(),
             shuffle_index: ShuffleIndex::default(),
-            used_lookups: HashSet::new(),
-            used_range_checks: HashSet::new(),
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
-            max_range_size: 0,
-            witness_gen: true,
-            check_lookup_range: true,
+            statistics: RegionStatistics::default(),
+            settings: RegionSettings::all_true(),
             assigned_constants: HashMap::new(),
         }
     }
@@ -193,39 +262,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
         new_self.assigned_constants = constants;
         new_self
     }
-    /// Create a new region context from a wrapped region
-    pub fn from_wrapped_region(
-        region: Option<RefCell<Region<'a, F>>>,
-        row: usize,
-        num_inner_cols: usize,
-        dynamic_lookup_index: DynamicLookupIndex,
-        shuffle_index: ShuffleIndex,
-    ) -> RegionCtx<'a, F> {
-        let linear_coord = row * num_inner_cols;
-        RegionCtx {
-            region,
-            num_inner_cols,
-            linear_coord,
-            row,
-            dynamic_lookup_index,
-            shuffle_index,
-            used_lookups: HashSet::new(),
-            used_range_checks: HashSet::new(),
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
-            max_range_size: 0,
-            witness_gen: false,
-            check_lookup_range: false,
-            assigned_constants: HashMap::new(),
-        }
-    }
 
     /// Create a new region context
     pub fn new_dummy(
         row: usize,
         num_inner_cols: usize,
-        witness_gen: bool,
-        check_lookup_range: bool,
+        settings: RegionSettings,
     ) -> RegionCtx<'a, F> {
         let region = None;
         let linear_coord = row * num_inner_cols;
@@ -237,13 +279,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
             row,
             dynamic_lookup_index: DynamicLookupIndex::default(),
             shuffle_index: ShuffleIndex::default(),
-            used_lookups: HashSet::new(),
-            used_range_checks: HashSet::new(),
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
-            max_range_size: 0,
-            witness_gen,
-            check_lookup_range,
+            statistics: RegionStatistics::default(),
+            settings,
             assigned_constants: HashMap::new(),
         }
     }
@@ -253,8 +290,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
         row: usize,
         linear_coord: usize,
         num_inner_cols: usize,
-        witness_gen: bool,
-        check_lookup_range: bool,
+        settings: RegionSettings,
     ) -> RegionCtx<'a, F> {
         let region = None;
         RegionCtx {
@@ -264,13 +300,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
             row,
             dynamic_lookup_index: DynamicLookupIndex::default(),
             shuffle_index: ShuffleIndex::default(),
-            used_lookups: HashSet::new(),
-            used_range_checks: HashSet::new(),
-            max_lookup_inputs: 0,
-            min_lookup_inputs: 0,
-            max_range_size: 0,
-            witness_gen,
-            check_lookup_range,
+            statistics: RegionStatistics::default(),
+            settings,
             assigned_constants: HashMap::new(),
         }
     }
@@ -321,12 +352,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
     ) -> Result<(), CircuitError> {
         let row = AtomicUsize::new(self.row());
         let linear_coord = AtomicUsize::new(self.linear_coord());
-        let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
-        let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
-        let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
-        let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
-        let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
+        let statistics = Arc::new(Mutex::new(self.statistics.clone()));
         let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
+        let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
         let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
 
         *output = output.par_enum_map(|idx, _| {
@@ -340,8 +368,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
                 starting_offset,
                 starting_linear_coord,
                 self.num_inner_cols,
-                self.witness_gen,
-                self.check_lookup_range,
+                self.settings.clone(),
             );
             let res = inner_loop_function(idx, &mut local_reg);
             // we update the offset and constants
@@ -351,14 +378,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
                 Ordering::SeqCst,
             );
 
-            max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
-            min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
-            // update the lookups
-            let mut lookups = lookups.lock().unwrap();
-            lookups.extend(local_reg.used_lookups());
-            // update the range checks
-            let mut range_checks = range_checks.lock().unwrap();
-            range_checks.extend(local_reg.used_range_checks());
+            let mut statistics = statistics.lock().unwrap();
+            statistics.update(local_reg.statistics());
             // update the dynamic lookup index
             let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
             dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
@@ -372,20 +394,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
             res
         })?;
         self.linear_coord = linear_coord.into_inner();
-        #[allow(trivial_numeric_casts)]
-        {
-            self.max_lookup_inputs = max_lookup_inputs.into_inner();
-            self.min_lookup_inputs = min_lookup_inputs.into_inner();
-        }
         self.row = row.into_inner();
-        self.used_lookups = Arc::try_unwrap(lookups)
+        self.statistics = Arc::try_unwrap(statistics)
             .map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
             .into_inner()
             .map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
-        self.used_range_checks = Arc::try_unwrap(range_checks)
-            .map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?
-            .into_inner()
-            .map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?;
         self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
             .map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
             .into_inner()
@@ -409,11 +422,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
     ) -> Result<(), CircuitError> {
         let (mut min, mut max) = (0, 0);
         for i in inputs {
-            max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
-            min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
+            max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
+            min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
         }
-        self.max_lookup_inputs = self.max_lookup_inputs.max(max);
-        self.min_lookup_inputs = self.min_lookup_inputs.min(min);
+        self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
+        self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
         Ok(())
     }
 
@@ -425,7 +438,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
 
         let range_size = (range.1 - range.0).abs();
 
-        self.max_range_size = self.max_range_size.max(range_size);
+        self.statistics.max_range_size = self.statistics.max_range_size.max(range_size);
         Ok(())
     }
 
@@ -440,13 +453,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
         lookup: LookupOp,
         inputs: &[ValTensor<F>],
     ) -> Result<(), CircuitError> {
-        self.used_lookups.insert(lookup);
+        self.statistics.used_lookups.insert(lookup);
         self.update_max_min_lookup_inputs(inputs)
     }
 
     /// add used range check
     pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
-        self.used_range_checks.insert(range);
+        self.statistics.used_range_checks.insert(range);
         self.update_max_min_lookup_range(range)
     }
 
@@ -487,27 +500,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
 
     /// get used lookups
     pub fn used_lookups(&self) -> HashSet<LookupOp> {
-        self.used_lookups.clone()
+        self.statistics.used_lookups.clone()
     }
 
     /// get used range checks
     pub fn used_range_checks(&self) -> HashSet<Range> {
-        self.used_range_checks.clone()
+        self.statistics.used_range_checks.clone()
     }
 
     /// max lookup inputs
-    pub fn max_lookup_inputs(&self) -> i64 {
-        self.max_lookup_inputs
+    pub fn max_lookup_inputs(&self) -> IntegerRep {
+        self.statistics.max_lookup_inputs
     }
 
     /// min lookup inputs
-    pub fn min_lookup_inputs(&self) -> i64 {
-        self.min_lookup_inputs
+    pub fn min_lookup_inputs(&self) -> IntegerRep {
+        self.statistics.min_lookup_inputs
     }
 
     /// max range check
-    pub fn max_range_size(&self) -> i64 {
-        self.max_range_size
+    pub fn max_range_size(&self) -> IntegerRep {
+        self.statistics.max_range_size
     }
 
     /// Assign a valtensor to a vartensor
@@ -515,18 +528,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
         &mut self,
         var: &VarTensor,
         values: &ValTensor<F>,
-    ) -> Result<ValTensor<F>, Error> {
+    ) -> Result<ValTensor<F>, CircuitError> {
         if let Some(region) = &self.region {
-            var.assign(
+            Ok(var.assign(
                 &mut region.borrow_mut(),
                 self.linear_coord,
                 values,
                 &mut self.assigned_constants,
-            )
+            )?)
         } else {
             if !values.is_instance() {
                 let values_map = values.create_constants_map_iterator();
-                self.assigned_constants.extend(values_map);
+                self.assigned_constants.par_extend(values_map);
             }
             Ok(values.clone())
         }
@@ -542,18 +555,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
         &mut self,
         var: &VarTensor,
         values: &ValTensor<F>,
-    ) -> Result<ValTensor<F>, Error> {
+    ) -> Result<ValTensor<F>, CircuitError> {
        if let Some(region) = &self.region {
-            var.assign(
+            Ok(var.assign(
                &mut region.borrow_mut(),
                self.combined_dynamic_shuffle_coord(),
                values,
                &mut self.assigned_constants,
-            )
+            )?)
        } else {
            if !values.is_instance() {
                let values_map = values.create_constants_map_iterator();
-                self.assigned_constants.extend(values_map);
+                self.assigned_constants.par_extend(values_map);
            }
            Ok(values.clone())
        }
@@ -564,7 +577,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
         &mut self,
         var: &VarTensor,
         values: &ValTensor<F>,
-    ) -> Result<ValTensor<F>, Error> {
+    ) -> Result<ValTensor<F>, CircuitError> {
         self.assign_dynamic_lookup(var, values)
     }
 
@@ -573,27 +586,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
         &mut self,
         var: &VarTensor,
         values: &ValTensor<F>,
-        ommissions: &HashSet<&usize>,
-    ) -> Result<ValTensor<F>, Error> {
+        ommissions: &HashSet<usize>,
+    ) -> Result<ValTensor<F>, CircuitError> {
         if let Some(region) = &self.region {
-            var.assign_with_omissions(
+            Ok(var.assign_with_omissions(
                 &mut region.borrow_mut(),
                 self.linear_coord,
                 values,
                 ommissions,
                 &mut self.assigned_constants,
-            )
+            )?)
         } else {
-            let inner_tensor = values.get_inner_tensor().unwrap();
-            let mut values_map = values.create_constants_map();
+            let mut values_clone = values.clone();
+            let mut indices = ommissions.clone().into_iter().collect_vec();
+            values_clone.remove_indices(&mut indices, false)?;
 
-            for o in ommissions {
-                if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
-                    values_map.remove(&value);
-                }
-            }
+            let values_map = values.create_constants_map();
 
-            self.assigned_constants.extend(values_map);
+            self.assigned_constants.par_extend(values_map);
 
             Ok(values.clone())
         }
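Note: `RegionSettings` collapses the `witness_gen`/`check_lookup_range` boolean pair that was previously threaded positionally through every constructor. A usage sketch against the new API; `Fr` stands in for the concrete circuit field:

    // Sketch: one value instead of two easily-swapped positional booleans.
    let estimation = RegionSettings::all_false();  // no witness gen, no range checks
    let proving = RegionSettings::all_true();      // full witness generation
    let custom = RegionSettings::new(true, false); // witness gen without range checks

    let region = RegionCtx::<Fr>::new_dummy(0, 2, estimation);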
@@ -11,20 +11,33 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
 
 use crate::{
     circuit::CircuitError,
-    fieldutils::i64_to_felt,
-    tensor::{IntoI64, Tensor, TensorType},
+    fieldutils::{integer_rep_to_felt, IntegerRep},
+    tensor::{Tensor, TensorType},
 };
 
+#[cfg(not(target_arch = "wasm32"))]
+use crate::execute::EZKL_REPO_PATH;
+
 use crate::circuit::lookup::LookupOp;
 
 /// The range of the lookup table.
-pub type Range = (i64, i64);
+pub type Range = (IntegerRep, IntegerRep);
 
 /// The safety factor for the range of the lookup table.
-pub const RANGE_MULTIPLIER: i64 = 2;
+pub const RANGE_MULTIPLIER: IntegerRep = 2;
 /// The safety factor offset for the number of rows in the lookup table.
 pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
 
+#[cfg(not(target_arch = "wasm32"))]
+lazy_static::lazy_static! {
+    /// an optional directory to read and write the lookup table cache
+    pub static ref LOOKUP_CACHE: String = format!("{}/cache", *EZKL_REPO_PATH);
+}
+
+/// The lookup table cache is disabled on wasm32 target.
+#[cfg(target_arch = "wasm32")]
+pub const LOOKUP_CACHE: &str = "";
+
 #[derive(Debug, Clone)]
 ///
 pub struct SelectorConstructor<F: PrimeField> {
@@ -96,21 +109,22 @@ pub struct Table<F: PrimeField> {
     _marker: PhantomData<F>,
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
     /// get column index given input
     pub fn get_col_index(&self, input: F) -> F {
         // range is split up into chunks of size col_size, find the chunk that input is in
-        let chunk =
-            (crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
+        let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
+            / (self.col_size as IntegerRep);
 
-        i64_to_felt(chunk)
+        integer_rep_to_felt(chunk)
     }
 
     /// get first_element of column
     pub fn get_first_element(&self, chunk: usize) -> (F, F) {
-        let chunk = chunk as i64;
+        let chunk = chunk as IntegerRep;
         // we index from 1 to prevent soundness issues
-        let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
+        let first_element =
+            integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0);
         let op_f = self
             .nonlinearity
             .f(&[Tensor::from(vec![first_element].into_iter())])
@@ -130,12 +144,20 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
     }
 
     ///
-    pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
+    pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
         // number of cols needed to store the range
-        (range_len / (col_size as i64)) as usize + 1
+        (range_len / (col_size as IntegerRep)) as usize + 1
     }
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
+    fn name(&self) -> String {
+        format!(
+            "{}_{}_{}",
+            self.nonlinearity.as_path(),
+            self.range.0,
+            self.range.1
+        )
+    }
     /// Configures the table.
     pub fn configure(
         cs: &mut ConstraintSystem<F>,
@@ -202,8 +224,51 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
         let smallest = self.range.0;
         let largest = self.range.1;
 
-        let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
-        let evals = self.nonlinearity.f(&[inputs.clone()])?;
+        let gen_table = || -> Result<(Tensor<F>, Tensor<F>), crate::tensor::TensorError> {
+            let inputs = Tensor::from(smallest..=largest)
+                .par_enum_map(|_, x| Ok::<_, crate::tensor::TensorError>(integer_rep_to_felt(x)))?;
+            let evals = self.nonlinearity.f(&[inputs.clone()])?;
+            Ok((inputs, evals.output))
+        };
+
+        let (inputs, evals) = if !LOOKUP_CACHE.is_empty() {
+            let cache = std::path::Path::new(&*LOOKUP_CACHE);
+            let cache_path = cache.join(self.name());
+            let input_path = cache_path.join("inputs");
+            let output_path = cache_path.join("outputs");
+            if cache_path.exists() {
+                log::info!("Loading lookup table from cache: {:?}", cache_path);
+                let (input_cache, output_cache) =
+                    (Tensor::load(&input_path)?, Tensor::load(&output_path)?);
+                (input_cache, output_cache)
+            } else {
+                log::info!(
+                    "Generating lookup table and saving to cache: {:?}",
+                    cache_path
+                );
+
+                // mkdir -p cache_path
+                std::fs::create_dir_all(&cache_path).map_err(|e| {
+                    CircuitError::TensorError(crate::tensor::TensorError::FileSaveError(
+                        e.to_string(),
+                    ))
+                })?;
+
+                let (inputs, evals) = gen_table()?;
+                inputs.save(&input_path)?;
+                evals.save(&output_path)?;
+
+                (inputs, evals)
+            }
+        } else {
+            log::info!(
+                "Generating lookup table {} without cache",
+                self.nonlinearity.as_path()
+            );
+
+            gen_table()?
+        };
 
         let chunked_inputs = inputs.chunks(self.col_size);
 
         self.is_assigned = true;
@@ -235,7 +300,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
             )?;
         }
 
-        let output = evals.output[row_offset];
+        let output = evals[row_offset];
 
         table.assign_cell(
             || format!("nl_o_col row {}", row_offset),
@@ -272,12 +337,17 @@ pub struct RangeCheck<F: PrimeField> {
     _marker: PhantomData<F>,
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
+    /// as path
+    pub fn as_path(&self) -> String {
+        format!("rangecheck_{}_{}", self.range.0, self.range.1)
+    }
+
     /// get first_element of column
     pub fn get_first_element(&self, chunk: usize) -> F {
-        let chunk = chunk as i64;
+        let chunk = chunk as IntegerRep;
         // we index from 1 to prevent soundness issues
-        i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
+        integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
     }
 
     ///
@@ -293,14 +363,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
     /// get column index given input
     pub fn get_col_index(&self, input: F) -> F {
         // range is split up into chunks of size col_size, find the chunk that input is in
-        let chunk =
-            (crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
+        let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
+            / (self.col_size as IntegerRep);
 
-        i64_to_felt(chunk)
+        integer_rep_to_felt(chunk)
     }
 }
 
-impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
+impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
     /// Configures the table.
     pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
         log::debug!("range check range: {:?}", range);
@@ -350,7 +420,32 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
         let smallest = self.range.0;
         let largest = self.range.1;
 
-        let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
+        let inputs: Tensor<F> = if !LOOKUP_CACHE.is_empty() {
+            let cache = std::path::Path::new(&*LOOKUP_CACHE);
+            let cache_path = cache.join(self.as_path());
+            let input_path = cache_path.join("inputs");
+            if cache_path.exists() {
+                log::info!("Loading range check table from cache: {:?}", cache_path);
+                Tensor::load(&input_path)?
+            } else {
+                log::info!(
+                    "Generating range check table and saving to cache: {:?}",
+                    cache_path
+                );
+
+                // mkdir -p cache_path
+                std::fs::create_dir_all(&cache_path)?;
+
+                let inputs = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
+                inputs.save(&input_path)?;
+                inputs
+            }
+        } else {
+            log::info!("Generating range check {} without cache", self.as_path());
+
+            Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x))
+        };
 
         let chunked_inputs = inputs.chunks(self.col_size);
 
         self.is_assigned = true;
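Note: with caching enabled, each table lands in a directory keyed by `Table::name` ("{as_path}_{lo}_{hi}") or `RangeCheck::as_path`, holding `Tensor::save` output. An illustrative layout, assuming the default cache root; the concrete keys below are examples only:

    cache/
      sigmoid_7_-16384_16383/   # Table::name(): "{op.as_path()}_{range.0}_{range.1}"
        inputs                  # serialized input tensor (Tensor::save / Tensor::load)
        outputs                 # precomputed nonlinearity evaluations
      rangecheck_-128_128/      # RangeCheck::as_path()
        inputs                  # range checks need no outputs file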
@@ -256,7 +256,7 @@ mod matmul_col_ultra_overflow_double_col {
     use super::*;
 
     const K: usize = 4;
-    const LEN: usize = 20;
+    const LEN: usize = 10;
     const NUM_INNER_COLS: usize = 2;
 
     #[derive(Clone)]
@@ -374,7 +374,7 @@ mod matmul_col_ultra_overflow {
     use super::*;
 
     const K: usize = 4;
-    const LEN: usize = 20;
+    const LEN: usize = 10;
 
     #[derive(Clone)]
     struct MatmulCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -1050,6 +1050,7 @@ mod conv {
             Box::new(PolyOp::Conv {
                 padding: vec![(1, 1); 2],
                 stride: vec![2; 2],
+                group: 1,
             }),
         )
         .map_err(|_| Error::Synthesis)
@@ -1158,7 +1159,7 @@ mod conv_col_ultra_overflow {
     use super::*;
 
     const K: usize = 4;
-    const LEN: usize = 28;
+    const LEN: usize = 10;
 
     #[derive(Clone)]
     struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -1200,6 +1201,7 @@ mod conv_col_ultra_overflow {
             Box::new(PolyOp::Conv {
                 padding: vec![(1, 1); 2],
                 stride: vec![2; 2],
+                group: 1,
             }),
         )
         .map_err(|_| Error::Synthesis)
@@ -1296,7 +1298,7 @@ mod conv_relu_col_ultra_overflow {
     use super::*;
 
     const K: usize = 4;
-    const LEN: usize = 28;
+    const LEN: usize = 15;
 
     #[derive(Clone)]
     struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
@@ -1345,6 +1347,7 @@ mod conv_relu_col_ultra_overflow {
             Box::new(PolyOp::Conv {
                 padding: vec![(1, 1); 2],
                 stride: vec![2; 2],
+                group: 1,
             }),
         )
        .map_err(|_| Error::Synthesis);
@@ -379,9 +379,9 @@ pub enum Commands {
     #[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET, value_hint = clap::ValueHint::Other)]
     /// Target for calibration. Set to "resources" to optimize for computational resource. Otherwise, set to "accuracy" to optimize for accuracy.
     target: CalibrationTarget,
-    /// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
+    /// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be ceil(2^k * lookup_safety_margin). larger = safer but slower
     #[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN, value_hint = clap::ValueHint::Other)]
-    lookup_safety_margin: i64,
+    lookup_safety_margin: f64,
     /// Optional scales to specifically try for calibration. Example, --scales 0,4
     #[arg(long, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::Other)]
     scales: Option<Vec<crate::Scale>>,
@@ -868,6 +868,7 @@ pub enum Commands {
         #[arg(long, value_hint = clap::ValueHint::Other)]
         addr_vk: Option<H160Flag>,
     },
+    #[cfg(not(feature = "no-update"))]
     /// Updates ezkl binary to version specified (or latest if not specified)
     Update {
         /// The version to update to
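Note: the margin is now fractional; per the updated doc comment, the required table size becomes ceil(2^k * lookup_safety_margin). A sketch of the implied arithmetic; the helper name is illustrative, not from the diff:

    // Sketch: f64 margins permit e.g. 1.5x instead of integer multiples only.
    fn required_lookup_size(max_lookup: i128, lookup_safety_margin: f64) -> i128 {
        (max_lookup as f64 * lookup_safety_margin).ceil() as i128
    }

    assert_eq!(required_lookup_size(1 << 12, 1.5), 6144);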
99
src/eth.rs
@@ -327,11 +327,7 @@ pub async fn setup_eth_backend(
         ProviderBuilder::new()
             .with_recommended_fillers()
             .signer(EthereumSigner::from(wallet))
-            .on_http(
-                endpoint
-                    .parse()
-                    .map_err(|_| EthError::UrlParse(endpoint.clone()))?,
-            ),
+            .on_http(endpoint.parse().map_err(|_| EthError::UrlParse(endpoint))?),
     );
 
     let chain_id = client.get_chain_id().await?;
@@ -354,8 +350,7 @@ pub async fn deploy_contract_via_solidity(
     let (abi, bytecode, runtime_bytecode) =
         get_contract_artifacts(sol_code_path, contract_name, runs).await?;
 
-    let factory =
-        get_sol_contract_factory(abi, bytecode, runtime_bytecode, client.clone(), None::<()>)?;
+    let factory = get_sol_contract_factory(abi, bytecode, runtime_bytecode, client, None::<()>)?;
     let contract = factory.deploy().await?;
 
     Ok(contract)
@@ -452,20 +447,30 @@ pub async fn deploy_da_verifier_via_solidity(
         }
     }
 
+    let (abi, bytecode, runtime_bytecode) =
+        get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
+
     let (contract_addresses, call_data, decimals) = if !calls_to_accounts.is_empty() {
         parse_calls_to_accounts(calls_to_accounts)?
     } else {
-        return Err(EthError::OnChainDataSource);
+        // if calls_to_accounts is empty then we need to check that at least one artifact has kzg (polycommit) visibility in the settings file
+        let kzg_visibility = settings.run_args.input_visibility.is_polycommit()
+            || settings.run_args.output_visibility.is_polycommit()
+            || settings.run_args.param_visibility.is_polycommit();
+        if !kzg_visibility {
+            return Err(EthError::OnChainDataSource);
+        }
+        let factory =
+            get_sol_contract_factory::<_, ()>(abi, bytecode, runtime_bytecode, client, None)?;
+        let contract = factory.deploy().await?;
+        return Ok(contract);
     };
 
-    let (abi, bytecode, runtime_bytecode) =
-        get_contract_artifacts(sol_code_path, "DataAttestation", runs).await?;
-
     let factory = get_sol_contract_factory(
         abi,
         bytecode,
         runtime_bytecode,
-        client.clone(),
+        client,
         Some((
             // address[] memory _contractAddresses,
             DynSeqToken(
@@ -506,7 +511,7 @@ pub async fn deploy_da_verifier_via_solidity(
         ),
         // uint8 _instanceOffset,
         WordToken(U256::from(contract_instance_offset as u32).into()),
-        //address _admin
+        // address _admin
         WordToken(client_address.into_word()),
     )),
 )?;
@@ -529,7 +534,7 @@ fn parse_calls_to_accounts(
     let mut call_data = vec![];
     let mut decimals: Vec<Vec<U256>> = vec![];
     for (i, val) in calls_to_accounts.iter().enumerate() {
-        let contract_address_bytes = hex::decode(val.address.clone())?;
+        let contract_address_bytes = hex::decode(&val.address)?;
         let contract_address = H160::from_slice(&contract_address_bytes);
         contract_addresses.push(contract_address);
         call_data.push(vec![]);
@@ -573,7 +578,7 @@ pub async fn update_account_calls(
 
     let (client, client_address) = setup_eth_backend(rpc_url, None).await?;
 
-    let contract = DataAttestation::new(addr, client.clone());
+    let contract = DataAttestation::new(addr, &client);
 
     info!("contract_addresses: {:#?}", contract_addresses);
 
@@ -726,7 +731,7 @@ pub async fn verify_proof_with_data_attestation(
 
     for val in flattened_instances.clone() {
         let bytes = val.to_repr();
-        let u = U256::from_le_slice(bytes.as_slice());
+        let u = U256::from_le_slice(bytes.inner().as_slice());
         public_inputs.push(u);
     }
 
@@ -804,7 +809,7 @@ pub async fn test_on_chain_data<M: 'static + Provider<Http<Client>, Ethereum>>(
     client: Arc<M>,
     data: &[Vec<FileSourceInner>],
 ) -> Result<Vec<CallsToAccount>, EthError> {
-    let (contract, decimals) = setup_test_contract(client.clone(), data).await?;
+    let (contract, decimals) = setup_test_contract(client, data).await?;
 
     // Get the encoded call data for each input
     let mut calldata = vec![];
@@ -836,10 +841,10 @@ pub async fn read_on_chain_inputs<M: 'static + Provider<Http<Client>, Ethereum>>
     let mut decimals = vec![];
     for on_chain_data in data {
         // Construct the address
-        let contract_address_bytes = hex::decode(on_chain_data.address.clone())?;
+        let contract_address_bytes = hex::decode(&on_chain_data.address)?;
         let contract_address = H160::from_slice(&contract_address_bytes);
         for (call_data, decimal) in &on_chain_data.call_data {
-            let call_data_bytes = hex::decode(call_data.clone())?;
+            let call_data_bytes = hex::decode(call_data)?;
             let input: TransactionInput = call_data_bytes.into();
 
             let tx = TransactionRequest::default()
@@ -866,8 +871,8 @@ pub async fn evm_quantize<M: 'static + Provider<Http<Client>, Ethereum>>(
 ) -> Result<Vec<Fr>, EthError> {
     let contract = QuantizeData::deploy(&client).await?;
 
-    let fetched_inputs = data.0.clone();
-    let decimals = data.1.clone();
+    let fetched_inputs = &data.0;
+    let decimals = &data.1;
 
     let fetched_inputs = fetched_inputs
         .iter()
@@ -943,7 +948,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
         (None, false) => {
             return Err(EthError::NoConstructor);
         }
-        (None, true) => bytecode.clone(),
+        (None, true) => bytecode,
         (Some(_), _) => {
             let mut data = bytecode.to_vec();

@@ -955,7 +960,7 @@ fn get_sol_contract_factory<'a, M: 'static + Provider<Http<Client>, Ethereum>, T
         }
     };
 
-    Ok(CallBuilder::new_raw_deploy(client.clone(), data))
+    Ok(CallBuilder::new_raw_deploy(client, data))
 }
 
 /// Compiles a solidity verifier contract and returns the abi, bytecode, and runtime bytecode
@@ -1030,7 +1035,7 @@ pub fn fix_da_sol(
 
     // fill in the quantization params and total calls
     // as constants to the contract to save on gas
-    if let Some(input_data) = input_data {
+    if let Some(input_data) = &input_data {
         let input_calls: usize = input_data.iter().map(|v| v.call_data.len()).sum();
         accounts_len = input_data.len();
         contract = contract.replace(
@@ -1038,7 +1043,7 @@ pub fn fix_da_sol(
             &format!("uint256 constant INPUT_CALLS = {};", input_calls),
         );
     }
-    if let Some(output_data) = output_data {
+    if let Some(output_data) = &output_data {
         let output_calls: usize = output_data.iter().map(|v| v.call_data.len()).sum();
         accounts_len += output_data.len();
         contract = contract.replace(
@@ -1048,8 +1053,9 @@ pub fn fix_da_sol(
     }
     contract = contract.replace("AccountCall[]", &format!("AccountCall[{}]", accounts_len));
 
-    if commitment_bytes.clone().is_some() && !commitment_bytes.clone().unwrap().is_empty() {
-        let commitment_bytes = commitment_bytes.unwrap();
+    // The case where a combination of on-chain data source + kzg commit is provided.
+    if commitment_bytes.is_some() && !commitment_bytes.as_ref().unwrap().is_empty() {
+        let commitment_bytes = commitment_bytes.as_ref().unwrap();
         let hex_string = hex::encode(commitment_bytes);
         contract = contract.replace(
             "bytes constant COMMITMENT_KZG = hex\"\";",
@@ -1064,5 +1070,44 @@ pub fn fix_da_sol(
         );
     }
 
+    // if both input and output data is none then we will only deploy the DataAttest contract, adding in the verifyWithDataAttestation function
+    if input_data.is_none()
+        && output_data.is_none()
+        && commitment_bytes.as_ref().is_some()
+        && !commitment_bytes.as_ref().unwrap().is_empty()
+    {
+        contract = contract.replace(
+            "contract SwapProofCommitments {",
+            "contract DataAttestation {",
+        );
+
+        // Remove everything past the end of the checkKzgCommits function
+        if let Some(pos) = contract.find("    } /// end checkKzgCommits") {
+            contract.truncate(pos);
+            contract.push('}');
+        }
+
+        // Add the Solidity function below checkKzgCommits
+        contract.push_str(
+            r#"
+    function verifyWithDataAttestation(
+        address verifier,
+        bytes calldata encoded
+    ) public view returns (bool) {
+        require(verifier.code.length > 0, "Address: call to non-contract");
+        require(checkKzgCommits(encoded), "Invalid KZG commitments");
+        // static call the verifier contract to verify the proof
+        (bool success, bytes memory returndata) = verifier.staticcall(encoded);

+        if (success) {
+            return abi.decode(returndata, (bool));
+        } else {
+            revert("low-level call to verifier failed");
+        }
+    }
+}"#,
+        );
+    }
+
     Ok(contract)
 }
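Note: the deploy path gains a fallback: with no on-chain account calls it can still deploy a bare DataAttestation contract when some artifact is polycommit-visible. The gate, isolated as a sketch under the same names as the diff:

    // Sketch of the new gate in deploy_da_verifier_via_solidity.
    let kzg_visibility = settings.run_args.input_visibility.is_polycommit()
        || settings.run_args.output_visibility.is_polycommit()
        || settings.run_args.param_visibility.is_polycommit();
    if calls_to_accounts.is_empty() && !kzg_visibility {
        return Err(EthError::OnChainDataSource); // nothing to attest to
    }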
@@ -1,3 +1,4 @@
+use crate::circuit::region::RegionSettings;
 use crate::circuit::CheckMode;
 #[cfg(not(target_arch = "wasm32"))]
 use crate::commands::CalibrationTarget;
@@ -506,10 +507,12 @@ pub async fn run(command: Commands) -> Result<String, EZKLError> {
             )
             .await
         }
+        #[cfg(not(feature = "no-update"))]
         Commands::Update { version } => update_ezkl_binary(&version).map(|e| e.to_string()),
     }
 }
 
+#[cfg(not(feature = "no-update"))]
 /// Assert that the version is valid
 fn assert_version_is_valid(version: &str) -> Result<(), EZKLError> {
     let err_string = "Invalid version string. Must be in the format v0.0.0";
@@ -527,13 +530,28 @@ fn assert_version_is_valid(version: &str) -> Result<(), EZKLError> {
     Ok(())
 }
 
+#[cfg(not(feature = "no-update"))]
 const INSTALL_BYTES: &[u8] = include_bytes!("../install_ezkl_cli.sh");
 
+#[cfg(not(feature = "no-update"))]
 fn update_ezkl_binary(version: &Option<String>) -> Result<String, EZKLError> {
     // run the install script with the version
     let install_script = std::str::from_utf8(INSTALL_BYTES)?;
-    // now run as sh script with the version as an argument
-    let mut command = std::process::Command::new("sh");
+
+    // check if bash is installed
+    let command = if std::process::Command::new("bash")
+        .arg("--version")
+        .status()
+        .is_err()
+    {
+        log::warn!("bash is not installed on this system, trying to run the install script with sh (may fail)");
+        "sh"
+    } else {
+        "bash"
+    };
+
+    let mut command = std::process::Command::new(command);
     let mut command = command.arg("-c").arg(install_script);
 
     if let Some(version) = version {
@@ -768,6 +786,8 @@ pub(crate) async fn gen_witness(
 
     let commitment: Commitments = settings.run_args.commitment.into();
 
+    let region_settings = RegionSettings::all_true();
+
     let start_time = Instant::now();
     let witness = if settings.module_requires_polycommit() {
         if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
@@ -782,8 +802,7 @@ pub(crate) async fn gen_witness(
                     &mut input,
                     vk.as_ref(),
                     Some(&srs),
-                    true,
-                    true,
+                    region_settings,
                 )?
             }
             Commitments::IPA => {
@@ -797,8 +816,7 @@ pub(crate) async fn gen_witness(
                     &mut input,
                     vk.as_ref(),
                     Some(&srs),
-                    true,
-                    true,
+                    region_settings,
                 )?
             }
         }
@@ -808,12 +826,16 @@ pub(crate) async fn gen_witness(
                 &mut input,
                 vk.as_ref(),
                 None,
-                true,
-                true,
+                region_settings,
             )?
         }
     } else {
-        circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true, true)?
+        circuit.forward::<KZGCommitmentScheme<Bn256>>(
+            &mut input,
+            vk.as_ref(),
+            None,
+            region_settings,
+        )?
     };
 
     // print each variable tuple (symbol, value) as symbol=value
@@ -996,7 +1018,7 @@ pub(crate) async fn calibrate(
     data: PathBuf,
     settings_path: PathBuf,
     target: CalibrationTarget,
-    lookup_safety_margin: i64,
+    lookup_safety_margin: f64,
     scales: Option<Vec<crate::Scale>>,
     scale_rebase_multiplier: Vec<u32>,
     only_range_check_rebase: bool,
@@ -1006,6 +1028,8 @@ pub(crate) async fn calibrate(
     use std::collections::HashMap;
     use tabled::Table;
 
+    use crate::fieldutils::IntegerRep;
+
     let data = GraphData::from_path(data)?;
     // load the pre-generated settings
     let settings = GraphSettings::load(&settings_path)?;
@@ -1114,7 +1138,7 @@ pub(crate) async fn calibrate(
             param_scale,
             scale_rebase_multiplier,
             div_rebasing,
-            lookup_range: (i64::MIN, i64::MAX),
+            lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
             ..settings.run_args.clone()
         };
 
@@ -1154,8 +1178,7 @@ pub(crate) async fn calibrate(
                 &mut data.clone(),
                 None,
                 None,
-                true,
-                false,
+                RegionSettings::all_true(),
             )
             .map_err(|e| format!("failed to forward: {}", e))?;
 
@@ -1485,10 +1508,10 @@ pub(crate) async fn create_evm_vk(
 #[cfg(not(target_arch = "wasm32"))]
 pub(crate) async fn create_evm_data_attestation(
     settings_path: PathBuf,
-    _sol_code_path: PathBuf,
-    _abi_path: PathBuf,
-    _input: PathBuf,
-    _witness: Option<PathBuf>,
+    sol_code_path: PathBuf,
+    abi_path: PathBuf,
+    input: PathBuf,
+    witness: Option<PathBuf>,
 ) -> Result<String, EZKLError> {
     #[allow(unused_imports)]
     use crate::graph::{DataSource, VarVisibility};
@@ -1500,7 +1523,7 @@ pub(crate) async fn create_evm_data_attestation(
     trace!("params computed");
 
     // if input is not provided, we just instantiate dummy input data
-    let data = GraphData::from_path(_input).unwrap_or(GraphData::new(DataSource::File(vec![])));
+    let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));
 
     let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
         if visibility.output.is_private() {
@@ -1535,7 +1558,7 @@ pub(crate) async fn create_evm_data_attestation(
         || settings.run_args.output_visibility == Visibility::KZGCommit
         || settings.run_args.param_visibility == Visibility::KZGCommit
     {
-        let witness = GraphWitness::from_path(_witness.unwrap_or(DEFAULT_WITNESS.into()))?;
+        let witness = GraphWitness::from_path(witness.unwrap_or(DEFAULT_WITNESS.into()))?;
         let commitments = witness.get_polycommitments();
         let proof_first_bytes = get_proof_commitments::<
             KZGCommitmentScheme<Bn256>,
@@ -1549,12 +1572,12 @@ pub(crate) async fn create_evm_data_attestation(
     };
 
     let output = fix_da_sol(input_data, output_data, commitment_bytes)?;
-    let mut f = File::create(_sol_code_path.clone())?;
+    let mut f = File::create(sol_code_path.clone())?;
     let _ = f.write(output.as_bytes());
     // fetch abi of the contract
-    let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
+    let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
     // save abi to file
-    serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
+    serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
 
     Ok(String::new())
 }
@@ -2,42 +2,21 @@ use halo2_proofs::arithmetic::Field;
|
||||
/// Utilities for converting from Halo2 PrimeField types to integers (and vice-versa).
use halo2curves::ff::PrimeField;

/// Converts an i32 to a PrimeField element.
pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
    if x >= 0 {
        F::from(x as u64)
    } else {
        -F::from(x.unsigned_abs() as u64)
    }
}

/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;

/// Converts an i64 to a PrimeField element.
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
    if x >= 0 {
        F::from_u128(x as u128)
    } else {
        -F::from_u128((-x) as u128)
        -F::from_u128(x.saturating_neg() as u128)
    }
}

/// Converts a PrimeField element to an i32.
pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
    if x > F::from(i32::MAX as u64) {
        let rep = (-x).to_repr();
        let negtmp: &[u8] = rep.as_ref();
        let lower_32 = u32::from_le_bytes(negtmp[..4].try_into().unwrap());
        -(lower_32 as i32)
    } else {
        let rep = (x).to_repr();
        let tmp: &[u8] = rep.as_ref();
        let lower_32 = u32::from_le_bytes(tmp[..4].try_into().unwrap());
        lower_32 as i32
    }
}

/// Converts a PrimeField element to an f64.
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
    if x > F::from_u128(i64::MAX as u128) {
    if x > F::from_u128(IntegerRep::MAX as u128) {
        let rep = (-x).to_repr();
        let negtmp: &[u8] = rep.as_ref();
        let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -51,17 +30,17 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
}

/// Converts a PrimeField element to an i64.
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
    if x > F::from_u128(i64::MAX as u128) {
    if x > F::from_u128(IntegerRep::MAX as u128) {
        let rep = (-x).to_repr();
        let negtmp: &[u8] = rep.as_ref();
        let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
        -(lower_128 as i64)
        -(lower_128 as IntegerRep)
    } else {
        let rep = (x).to_repr();
        let tmp: &[u8] = rep.as_ref();
        let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
        lower_128 as i64
        lower_128 as IntegerRep
    }
}

@@ -73,33 +52,24 @@ mod test {

    #[test]
    fn test_conv() {
        let res: F = i32_to_felt(-15i32);
        let res: F = integer_rep_to_felt(-15);
        assert_eq!(res, -F::from(15));

        let res: F = i32_to_felt(2_i32.pow(17));
        let res: F = integer_rep_to_felt(2_i128.pow(17));
        assert_eq!(res, F::from(131072));

        let res: F = i64_to_felt(-15i64);
        let res: F = integer_rep_to_felt(-15);
        assert_eq!(res, -F::from(15));

        let res: F = i64_to_felt(2_i64.pow(17));
        let res: F = integer_rep_to_felt(2_i128.pow(17));
        assert_eq!(res, F::from(131072));
    }

    #[test]
    fn felttoi32() {
        for x in -(2i32.pow(16))..(2i32.pow(16)) {
            let fieldx: F = i32_to_felt::<F>(x);
            let xf: i32 = felt_to_i32::<F>(fieldx);
            assert_eq!(x, xf);
        }
    }

    #[test]
    fn felttoi64() {
        for x in -(2i64.pow(20))..(2i64.pow(20)) {
            let fieldx: F = i64_to_felt::<F>(x);
            let xf: i64 = felt_to_i64::<F>(fieldx);
    fn felttointegerrep() {
        for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
            let fieldx: F = integer_rep_to_felt::<F>(x);
            let xf: i128 = felt_to_integer_rep::<F>(fieldx);
            assert_eq!(x, xf);
        }
    }
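A minimal round-trip sketch of the renamed conversions above, assuming the bn256 scalar field Fr used elsewhere in this diff; a negative x lands at the field's additive inverse p - |x|, and recovery checks whether the element exceeds IntegerRep::MAX before negating:

use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use halo2curves::bn256::Fr;

fn main() {
    let x: IntegerRep = -15;
    // negative values map to the field's additive inverse, i.e. p - 15
    let felt: Fr = integer_rep_to_felt(x);
    // recovered by the IntegerRep::MAX comparison in felt_to_integer_rep
    let back: IntegerRep = felt_to_integer_rep(felt);
    assert_eq!(x, back);
}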
@@ -41,16 +41,16 @@ pub enum GraphError {
    /// Error when attempting to rescale an operation
    #[error("failed to rescale inputs for {0}")]
    RescalingError(String),
    /// Error when attempting to load a model from a file
    #[error("failed to load model")]
    ModelLoad(#[from] std::io::Error),
    /// Reading a file failed
    #[error("[io] ({0}) {1}")]
    ReadWriteFileError(String, String),
    /// Model serialization error
    #[error("failed to ser/deser model: {0}")]
    ModelSerialize(#[from] bincode::Error),
    /// Tract error
    #[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
    #[error("[tract] {0}")]
    TractError(#[from] tract_onnx::tract_core::anyhow::Error),
    TractError(#[from] tract_onnx::prelude::TractError),
    /// Packing exponent is too large
    #[error("largest packing exponent exceeds max. try reducing the scale")]
    PackingExponent,

@@ -1,7 +1,7 @@
use super::errors::GraphError;
use super::quantize_float;
use crate::circuit::InputType;
use crate::fieldutils::i64_to_felt;
use crate::fieldutils::integer_rep_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::postgres::Client;
#[cfg(not(target_arch = "wasm32"))]
@@ -128,7 +128,7 @@ impl FileSourceInner {
    /// Convert to a field element
    pub fn to_field(&self, scale: crate::Scale) -> Fp {
        match self {
            FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
            FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
            FileSourceInner::Bool(f) => {
                if *f {
                    Fp::one()
@@ -150,7 +150,7 @@ impl FileSourceInner {
                    0.0
                }
            }
            FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
            FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
        }
    }
}
@@ -485,18 +485,25 @@ impl GraphData {

    /// Load the model input from a file
    pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
        let reader = std::fs::File::open(path)?;
        let reader = std::fs::File::open(&path).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let mut reader = BufReader::with_capacity(*EZKL_BUF_CAPACITY, reader);
        let mut buf = String::new();
        reader.read_to_string(&mut buf)?;
        reader.read_to_string(&mut buf).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let graph_input = serde_json::from_str(&buf)?;
        Ok(graph_input)
    }

    /// Save the model input to a file
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
        let file = std::fs::File::create(path.clone()).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        // buf writer
        let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
        let writer = BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);
        serde_json::to_writer(writer, self)?;
        Ok(())
    }
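A self-contained sketch of the error-tagging pattern these hunks introduce: every filesystem call now maps its io::Error into a (path, message) pair rather than bubbling up a bare std::io::Error. GraphErrorish below is a stand-in for the real GraphError::ReadWriteFileError variant added above:

use std::path::PathBuf;

#[derive(Debug)]
enum GraphErrorish {
    ReadWriteFileError(String, String),
}

fn open_tagged(path: PathBuf) -> Result<std::fs::File, GraphErrorish> {
    std::fs::File::open(&path)
        .map_err(|e| GraphErrorish::ReadWriteFileError(path.display().to_string(), e.to_string()))
}

fn main() {
    // a missing file now reports which path failed, not just the raw io message
    let err = open_tagged(PathBuf::from("does_not_exist.json")).unwrap_err();
    println!("{:?}", err);
}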
@@ -34,10 +34,10 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::felt_to_f64;
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
@@ -69,13 +69,14 @@ pub use vars::*;
use crate::pfsys::field_to_string;

/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: i64 = 2;
pub const RANGE_MULTIPLIER: IntegerRep = 2;

/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;

/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: IntegerRep =
    (MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS);

#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
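A back-of-the-envelope check of the widened constant. MAX_PUBLIC_SRS is defined elsewhere in the crate and not shown in this diff; the value 26 below is an assumption for illustration only:

type IntegerRep = i128;
const MAX_NUM_LOOKUP_COLS: usize = 12;
const MAX_PUBLIC_SRS: u32 = 26; // assumed, not taken from this diff
const MAX_LOOKUP_ABS: IntegerRep =
    (MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS);

fn main() {
    // 12 * 2^26 = 805_306_368; the i128 IntegerRep leaves ample headroom once
    // per-op scale factors multiply this bound, where i64 was getting tight
    assert_eq!(MAX_LOOKUP_ABS, 805_306_368);
}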
@@ -126,11 +127,11 @@ pub struct GraphWitness {
    /// Any hashes of outputs generated during the forward pass
    pub processed_outputs: Option<ModuleForwardResult>,
    /// max lookup input
    pub max_lookup_inputs: i64,
    pub max_lookup_inputs: IntegerRep,
    /// min lookup input
    pub min_lookup_inputs: i64,
    pub min_lookup_inputs: IntegerRep,
    /// max range check size
    pub max_range_size: i64,
    pub max_range_size: IntegerRep,
}

impl GraphWitness {
@@ -267,7 +268,9 @@ impl GraphWitness {

    /// Load the model input from a file
    pub fn from_path(path: std::path::PathBuf) -> Result<Self, GraphError> {
        let file = std::fs::File::open(path.clone())?;
        let file = std::fs::File::open(path.clone()).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;

        let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, file);
        serde_json::from_reader(reader).map_err(|e| e.into())
@@ -275,9 +278,11 @@

    /// Save the model input to a file
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
        let file = std::fs::File::create(path.clone()).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        // use buf writer
        let writer =
            std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, std::fs::File::create(path)?);
        let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, file);

        serde_json::to_writer(writer, &self).map_err(|e| e.into())
    }
@@ -640,7 +645,9 @@ impl GraphCircuit {
    }
    ///
    pub fn save(&self, path: std::path::PathBuf) -> Result<(), GraphError> {
        let f = std::fs::File::create(path)?;
        let f = std::fs::File::create(&path).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let writer = std::io::BufWriter::with_capacity(*EZKL_BUF_CAPACITY, f);
        bincode::serialize_into(writer, &self)?;
        Ok(())
@@ -649,7 +656,9 @@ impl GraphCircuit {
    ///
    pub fn load(path: std::path::PathBuf) -> Result<Self, GraphError> {
        // read bytes from file
        let f = std::fs::File::open(path)?;
        let f = std::fs::File::open(&path).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let reader = std::io::BufReader::with_capacity(*EZKL_BUF_CAPACITY, f);
        let result: GraphCircuit = bincode::deserialize_from(reader)?;
@@ -796,18 +805,26 @@ impl GraphCircuit {
        // the ordering here is important, we want the inputs to come before the outputs
        // as they are configured in that order as Column<Instances>
        let mut public_inputs: Vec<Fp> = vec![];
        if self.settings().run_args.input_visibility.is_public() {
            public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
        } else if let Some(processed_inputs) = &data.processed_inputs {

        // we first process the inputs
        if let Some(processed_inputs) = &data.processed_inputs {
            public_inputs.extend(processed_inputs.get_instances().into_iter().flatten());
        }

        // we then process the params
        if let Some(processed_params) = &data.processed_params {
            public_inputs.extend(processed_params.get_instances().into_iter().flatten());
        }

        // if the inputs are public, we add them to the public inputs AFTER the processed params as they are configured in that order as Column<Instances>
        if self.settings().run_args.input_visibility.is_public() {
            public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
        }

        // if the outputs are public, we add them to the public inputs
        if self.settings().run_args.output_visibility.is_public() {
            public_inputs.extend(self.graph_witness.outputs.clone().into_iter().flatten());
        // if the outputs are processed, we add the processed outputs to the public inputs
        } else if let Some(processed_outputs) = &data.processed_outputs {
            public_inputs.extend(processed_outputs.get_instances().into_iter().flatten());
        }
@@ -1026,14 +1043,14 @@ impl GraphCircuit {
        Ok(data)
    }

    fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range {
    fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
        (
            lookup_safety_margin * min_max_lookup.0,
            lookup_safety_margin * min_max_lookup.1,
            (lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
            (lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
        )
    }

    fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
    fn calc_num_cols(range_len: IntegerRep, max_logrows: u32) -> usize {
        let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
        num_cols_required(range_len, max_col_size)
    }
@@ -1041,7 +1058,7 @@ impl GraphCircuit {
    fn table_size_logrows(
        &self,
        safe_lookup_range: Range,
        max_range_size: i64,
        max_range_size: IntegerRep,
    ) -> Result<u32, GraphError> {
        // pick the range with the largest absolute size safe_lookup_range or max_range_size
        let safe_range = std::cmp::max(
@@ -1060,9 +1077,9 @@ impl GraphCircuit {
    pub fn calc_min_logrows(
        &mut self,
        min_max_lookup: Range,
        max_range_size: i64,
        max_range_size: IntegerRep,
        max_logrows: Option<u32>,
        lookup_safety_margin: i64,
        lookup_safety_margin: f64,
    ) -> Result<(), GraphError> {
        // load the max logrows
        let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
@@ -1072,9 +1089,13 @@ impl GraphCircuit {

        let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);

        let lookup_size = (safe_lookup_range.1 - safe_lookup_range.0).abs();
        // check if subtraction overflows
        let lookup_size =
            (safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
        // check if has overflowed max lookup input
        if lookup_size > MAX_LOOKUP_ABS / lookup_safety_margin {
        if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as IntegerRep {
            return Err(GraphError::LookupRangeTooLarge(
                lookup_size.unsigned_abs() as usize
            ));
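A worked sketch of the new fractional safety margin and overflow guard, mirroring the two hunks above with local stand-ins: with margin 1.5 an observed lookup range of (-100, 200) is padded to (-150, 300), and the saturating subtraction keeps pathological ranges from wrapping silently:

type IntegerRep = i128;
type Range = (IntegerRep, IntegerRep);

fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
    (
        (lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
        (lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
    )
}

fn main() {
    let safe = calc_safe_lookup_range((-100, 200), 1.5);
    assert_eq!(safe, (-150, 300));
    // same guard as calc_min_logrows: no silent i128 overflow on degenerate ranges
    let lookup_size = safe.1.saturating_sub(safe.0).saturating_abs();
    assert_eq!(lookup_size, 450);
}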
@@ -1154,7 +1175,7 @@ impl GraphCircuit {
        &self,
        k: u32,
        safe_lookup_range: Range,
        max_range_size: i64,
        max_range_size: IntegerRep,
    ) -> bool {
        // if num cols is too large then the extended k is too large
        if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
@@ -1212,8 +1233,7 @@ impl GraphCircuit {
        inputs: &mut [Tensor<Fp>],
        vk: Option<&VerifyingKey<G1Affine>>,
        srs: Option<&Scheme::ParamsProver>,
        witness_gen: bool,
        check_lookup: bool,
        region_settings: RegionSettings,
    ) -> Result<GraphWitness, GraphError> {
        let original_inputs = inputs.to_vec();

@@ -1262,7 +1282,7 @@ impl GraphCircuit {

        let mut model_results =
            self.model()
                .forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
                .forward(inputs, &self.settings().run_args, region_settings)?;

        if visibility.output.requires_processing() {
            let module_outlets = visibility.output.overwrites_inputs();
@@ -7,10 +7,12 @@ use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::region::RegionSettings;
use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::IntegerRep;
use crate::tensor::ValType;
use crate::{
    circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
@@ -64,11 +66,11 @@ pub struct ForwardResult {
    /// The outputs of the forward pass.
    pub outputs: Vec<Tensor<Fp>>,
    /// The maximum value of any input to a lookup operation.
    pub max_lookup_inputs: i64,
    pub max_lookup_inputs: IntegerRep,
    /// The minimum value of any input to a lookup operation.
    pub min_lookup_inputs: i64,
    pub min_lookup_inputs: IntegerRep,
    /// The max range check size
    pub max_range_size: i64,
    pub max_range_size: IntegerRep,
}

impl From<DummyPassRes> for ForwardResult {
@@ -116,11 +118,11 @@ pub struct DummyPassRes {
    /// range checks
    pub range_checks: HashSet<Range>,
    /// max lookup inputs
    pub max_lookup_inputs: i64,
    pub max_lookup_inputs: IntegerRep,
    /// min lookup inputs
    pub min_lookup_inputs: i64,
    pub min_lookup_inputs: IntegerRep,
    /// max range check
    pub max_range_size: i64,
    pub max_range_size: IntegerRep,
    /// outputs
    pub outputs: Vec<Tensor<Fp>>,
}
@@ -483,7 +485,9 @@ impl Model {

    ///
    pub fn save(&self, path: PathBuf) -> Result<(), GraphError> {
        let f = std::fs::File::create(path)?;
        let f = std::fs::File::create(&path).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let writer = std::io::BufWriter::new(f);
        bincode::serialize_into(writer, &self)?;
        Ok(())
@@ -492,10 +496,16 @@ impl Model {
    ///
    pub fn load(path: PathBuf) -> Result<Self, GraphError> {
        // read bytes from file
        let mut f = std::fs::File::open(&path)?;
        let metadata = fs::metadata(&path)?;
        let mut f = std::fs::File::open(&path).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let metadata = fs::metadata(&path).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let mut buffer = vec![0; metadata.len() as usize];
        f.read_exact(&mut buffer)?;
        f.read_exact(&mut buffer).map_err(|e| {
            GraphError::ReadWriteFileError(path.display().to_string(), e.to_string())
        })?;
        let result = bincode::deserialize(&buffer)?;
        Ok(result)
    }
@@ -537,7 +547,7 @@ impl Model {
        })
        .collect::<Result<Vec<_>, GraphError>>()?;

        let res = self.dummy_layout(run_args, &inputs, false, false)?;
        let res = self.dummy_layout(run_args, &inputs, RegionSettings::all_false())?;

        // if we're using percentage tolerance, we need to add the necessary range check ops for it.

@@ -580,14 +590,13 @@ impl Model {
        &self,
        model_inputs: &[Tensor<Fp>],
        run_args: &RunArgs,
        witness_gen: bool,
        check_lookup: bool,
        region_settings: RegionSettings,
    ) -> Result<ForwardResult, GraphError> {
        let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
            .iter()
            .map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
            .collect();
        let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
        let res = self.dummy_layout(run_args, &valtensor_inputs, region_settings)?;
        Ok(res.into())
    }
@@ -601,9 +610,7 @@ impl Model {
        reader: &mut dyn std::io::Read,
        run_args: &RunArgs,
    ) -> Result<TractResult, GraphError> {
        use tract_onnx::{
            tract_core::internal::IntoArcTensor, tract_hir::internal::GenericFactoid,
        };
        use tract_onnx::tract_hir::internal::GenericFactoid;

        let mut model = tract_onnx::onnx().model_for_read(reader)?;

@@ -640,29 +647,11 @@ impl Model {
        }

        // Note: do not optimize the model, as the layout will depend on underlying hardware
        let mut typed_model = model
        let typed_model = model
            .into_typed()?
            .concretize_dims(&symbol_values)?
            .into_decluttered()?;

        // concretize constants
        for node in typed_model.eval_order()? {
            let node = typed_model.node_mut(node);
            if let Some(op) = node.op_as_mut::<tract_onnx::tract_core::ops::konst::Const>() {
                if op.0.datum_type() == DatumType::TDim {
                    // get inner value to Arc<Tensor>
                    let mut constant = op.0.as_ref().clone();
                    // Generally a shape or hyperparam
                    constant
                        .as_slice_mut::<tract_onnx::prelude::TDim>()?
                        .iter_mut()
                        .for_each(|x| *x = x.eval(&symbol_values));

                    op.0 = constant.into_arc_tensor();
                }
            }
        }

        Ok((typed_model, symbol_values))
    }
@@ -975,8 +964,11 @@ impl Model {
    ) -> Result<Vec<Vec<Tensor<f32>>>, GraphError> {
        use tract_onnx::tract_core::internal::IntoArcTensor;

        let (model, _) =
            Model::load_onnx_using_tract(&mut std::fs::File::open(model_path)?, run_args)?;
        let mut file = std::fs::File::open(model_path).map_err(|e| {
            GraphError::ReadWriteFileError(model_path.display().to_string(), e.to_string())
        })?;

        let (model, _) = Model::load_onnx_using_tract(&mut file, run_args)?;

        let datum_types: Vec<DatumType> = model
            .input_outlets()?
@@ -1005,7 +997,10 @@ impl Model {
    /// * `params` - A [GraphSettings] struct holding parsed CLI arguments.
    #[cfg(not(target_arch = "wasm32"))]
    pub fn from_run_args(run_args: &RunArgs, model: &std::path::Path) -> Result<Self, GraphError> {
        Model::new(&mut std::fs::File::open(model)?, run_args)
        let mut file = std::fs::File::open(model).map_err(|e| {
            GraphError::ReadWriteFileError(model.display().to_string(), e.to_string())
        })?;
        Model::new(&mut file, run_args)
    }

    /// Configures a model for the circuit
@@ -1396,8 +1391,7 @@ impl Model {
        &self,
        run_args: &RunArgs,
        inputs: &[ValTensor<Fp>],
        witness_gen: bool,
        check_lookup: bool,
        region_settings: RegionSettings,
    ) -> Result<DummyPassRes, GraphError> {
        debug!("calculating num of constraints using dummy model layout...");

@@ -1416,8 +1410,7 @@ impl Model {
            vars: ModelVars::new_dummy(),
        };

        let mut region =
            RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
        let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, region_settings);

        let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
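The (witness_gen, check_lookup) boolean pair threaded through forward() and dummy_layout() is folded into a single RegionSettings value. A hypothetical mirror of that parameter object, for illustration only; the real struct lives in crate::circuit::region and its exact fields are not shown in this diff:

#[derive(Clone, Copy, Debug)]
struct RegionSettings {
    witness_gen: bool,
    check_lookup: bool,
}

impl RegionSettings {
    fn all_false() -> Self {
        RegionSettings { witness_gen: false, check_lookup: false }
    }
}

fn main() {
    // call sites like `self.dummy_layout(run_args, &inputs, false, false)` become
    // `self.dummy_layout(run_args, &inputs, RegionSettings::all_false())`
    let settings = RegionSettings::all_false();
    println!("{:?}", settings);
}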
@@ -9,6 +9,7 @@ use crate::circuit::lookup::LookupOp;
#[cfg(not(target_arch = "wasm32"))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
use halo2curves::ff::PrimeField;
@@ -50,16 +51,20 @@ use tract_onnx::tract_hir::{
/// * `dims` - the dimensionality of the resulting [Tensor].
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64, TensorError> {
pub fn quantize_float(
    elem: &f64,
    shift: f64,
    scale: crate::Scale,
) -> Result<IntegerRep, TensorError> {
    let mult = scale_to_multiplier(scale);
    let max_value = ((i64::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
    let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation

    if *elem > max_value {
        return Err(TensorError::SigBitTruncationError);
    }

    // we parallelize the quantization process as it seems to be quite slow at times
    let scaled = (mult * *elem + shift).round() as IntegerRep;

    Ok(scaled)
}
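A worked instance of the fixed-point quantization above, simplified to omit the truncation guard: scale = 7 gives a multiplier of 2^7 = 128, so 0.5 quantizes to 64 exactly, and off-grid values round with error at most half a grid step:

type IntegerRep = i128;

fn scale_to_multiplier(scale: i32) -> f64 {
    f64::powi(2.0, scale)
}

fn quantize_float(elem: f64, shift: f64, scale: i32) -> IntegerRep {
    let mult = scale_to_multiplier(scale);
    (mult * elem + shift).round() as IntegerRep
}

fn main() {
    assert_eq!(quantize_float(0.5, 0.0, 7), 64);
    // off-grid values round to the nearest step: 0.333... * 128 = 42.66... -> 43
    assert_eq!(quantize_float(1.0 / 3.0, 0.0, 7), 43);
    // dequantization error is bounded by half a step, i.e. 2^-8 at scale 7
    assert!((43.0 / 128.0 - 1.0 / 3.0).abs() <= 0.5 / 128.0);
}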
@@ -70,7 +75,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64
/// * `scale` - `2^scale` used in the fixed point representation.
/// * `shift` - offset used in the fixed point representation.
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
    let int_rep = crate::fieldutils::felt_to_i64(felt);
    let int_rep = crate::fieldutils::felt_to_integer_rep(felt);
    let multiplier = scale_to_multiplier(scale);
    int_rep as f64 / multiplier - shift
}
@@ -85,6 +90,34 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
    mult.log2().round() as crate::Scale
}

#[cfg(not(target_arch = "wasm32"))]
/// Extracts the padding from an onnx node.
pub fn extract_padding(
    pool_spec: &PoolSpec,
    num_dims: usize,
) -> Result<Vec<(usize, usize)>, GraphError> {
    let padding = match &pool_spec.padding {
        PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
            b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
        }
        PaddingSpec::Valid => vec![(0, 0); num_dims],
        _ => {
            return Err(GraphError::MissingParams("padding".to_string()));
        }
    };
    Ok(padding)
}

#[cfg(not(target_arch = "wasm32"))]
/// Extracts the strides from an onnx node.
pub fn extract_strides(pool_spec: &PoolSpec) -> Result<Vec<usize>, GraphError> {
    Ok(pool_spec
        .strides
        .clone()
        .ok_or(GraphError::MissingParams("stride".to_string()))?
        .to_vec())
}
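A self-contained mini-version of extract_padding's match, using a simplified stand-in for tract's PaddingSpec (the real enum, including the ExplicitOnnxPool variant, lives in tract_onnx and is not reproduced here):

#[derive(Debug)]
enum PaddingSpec {
    Explicit(Vec<usize>, Vec<usize>),
    Valid,
    SameUpper, // stand-in for the unsupported variants
}

fn extract_padding(spec: &PaddingSpec, num_dims: usize) -> Result<Vec<(usize, usize)>, String> {
    match spec {
        PaddingSpec::Explicit(b, a) => {
            Ok(b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect())
        }
        // Valid padding means zero padding on every spatial dim
        PaddingSpec::Valid => Ok(vec![(0, 0); num_dims]),
        other => Err(format!("missing params: padding ({:?})", other)),
    }
}

fn main() {
    assert_eq!(extract_padding(&PaddingSpec::Valid, 2).unwrap(), vec![(0, 0), (0, 0)]);
    assert_eq!(
        extract_padding(&PaddingSpec::Explicit(vec![1, 1], vec![2, 2]), 2).unwrap(),
        vec![(1, 2), (1, 2)]
    );
}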
/// Gets the shape of an onnx node's outlets.
#[cfg(not(target_arch = "wasm32"))]
pub fn node_output_shapes(
@@ -255,6 +288,8 @@ pub fn new_op_from_onnx(
        .flat_map(|x| x.out_scales())
        .collect::<Vec<_>>();

    let input_dims = inputs.iter().flat_map(|x| x.out_dims()).collect::<Vec<_>>();

    let mut replace_const = |scale: crate::Scale,
                             index: usize,
                             default_op: SupportedOp|
@@ -309,12 +344,9 @@ pub fn new_op_from_onnx(
            }
        }
        "MultiBroadcastTo" => {
            let op = load_op::<MultiBroadcastTo>(node.op(), idx, node.op().name().to_string())?;
            let shape = op.shape.clone();
            let shape = shape
                .iter()
                .map(|x| x.to_usize())
                .collect::<Result<Vec<_>, _>>()?;
            let _op = load_op::<MultiBroadcastTo>(node.op(), idx, node.op().name().to_string())?;
            let shapes = node_output_shapes(&node, symbol_values)?;
            let shape = shapes[0].clone();
            SupportedOp::Linear(PolyOp::MultiBroadcastTo { shape })
        }
@@ -1073,18 +1105,8 @@ pub fn new_op_from_onnx(
            ));
        }

        let stride = pool_spec
            .strides
            .clone()
            .ok_or(GraphError::MissingParams("stride".to_string()))?;
        let padding = match &pool_spec.padding {
            PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
                b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
            }
            _ => {
                return Err(GraphError::MissingParams("padding".to_string()));
            }
        };
        let stride = extract_strides(pool_spec)?;
        let padding = extract_padding(pool_spec, input_dims[0].len())?;
        let kernel_shape = &pool_spec.kernel_shape;

        SupportedOp::Hybrid(HybridOp::MaxPool {
@@ -1151,21 +1173,10 @@ pub fn new_op_from_onnx(
            ));
        }

        let stride = match conv_node.pool_spec.strides.clone() {
            Some(s) => s.to_vec(),
            None => {
                return Err(GraphError::MissingParams("strides".to_string()));
            }
        };
        let pool_spec = &conv_node.pool_spec;

        let padding = match &conv_node.pool_spec.padding {
            PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
                b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
            }
            _ => {
                return Err(GraphError::MissingParams("padding".to_string()));
            }
        };
        let stride = extract_strides(pool_spec)?;
        let padding = extract_padding(pool_spec, input_dims[0].len())?;

        // if bias exists then rescale it to the input + kernel scale
        if input_scales.len() == 3 {
@@ -1183,7 +1194,13 @@ pub fn new_op_from_onnx(
            }
        }

        SupportedOp::Linear(PolyOp::Conv { padding, stride })
        let group = conv_node.group;

        SupportedOp::Linear(PolyOp::Conv {
            padding,
            stride,
            group,
        })
    }
    "Not" => SupportedOp::Linear(PolyOp::Not),
    "And" => SupportedOp::Linear(PolyOp::And),
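A small arithmetic sketch of what the new `group` field changes for Conv: in a grouped convolution each group sees only in_channels / group input channels, with the old behaviour corresponding to group = 1 (this is standard ONNX conv semantics, not code from the diff):

fn conv_weight_shape(out_c: usize, in_c: usize, group: usize, k: usize) -> (usize, usize, usize, usize) {
    assert_eq!(in_c % group, 0, "in_channels must be divisible by group");
    (out_c, in_c / group, k, k)
}

fn main() {
    // a depthwise 3x3 conv over 8 channels: group = in_channels = 8
    assert_eq!(conv_weight_shape(8, 8, 8, 3), (8, 1, 3, 3));
    // the previously assumed behaviour corresponds to group = 1
    assert_eq!(conv_weight_shape(16, 8, 1, 3), (16, 8, 3, 3));
}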
@@ -1214,21 +1231,10 @@ pub fn new_op_from_onnx(
            ));
        }

        let stride = match deconv_node.pool_spec.strides.clone() {
            Some(s) => s.to_vec(),
            None => {
                return Err(GraphError::MissingParams("strides".to_string()));
            }
        };
        let padding = match &deconv_node.pool_spec.padding {
            PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
                b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
            }
            _ => {
                return Err(GraphError::MissingParams("padding".to_string()));
            }
        };
        let pool_spec = &deconv_node.pool_spec;

        let stride = extract_strides(pool_spec)?;
        let padding = extract_padding(pool_spec, input_dims[0].len())?;
        // if bias exists then rescale it to the input + kernel scale
        if input_scales.len() == 3 {
            let bias_scale = input_scales[2];
@@ -1249,6 +1255,7 @@ pub fn new_op_from_onnx(
            padding,
            output_padding: deconv_node.adjustments.to_vec(),
            stride,
            group: deconv_node.group,
        })
    }
    "Downsample" => {
@@ -1339,18 +1346,8 @@ pub fn new_op_from_onnx(
            ));
        }

        let stride = pool_spec
            .strides
            .clone()
            .ok_or(GraphError::MissingParams("stride".to_string()))?;
        let padding = match &pool_spec.padding {
            PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
                b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
            }
            _ => {
                return Err(GraphError::MissingParams("padding".to_string()));
            }
        };
        let stride = extract_strides(pool_spec)?;
        let padding = extract_padding(pool_spec, input_dims[0].len())?;

        SupportedOp::Hybrid(HybridOp::SumPool {
            padding,
@@ -1432,7 +1429,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
    visibility: &Visibility,
) -> Result<Tensor<F>, TensorError> {
    let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
        Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
        Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
            &(x).into(),
            0.0,
            scale,
@@ -23,6 +23,7 @@
)]
// we allow this for our dynamic range based indexing scheme
#![allow(clippy::single_range_in_vec_init)]
#![feature(buf_read_has_data_left)]
#![feature(stmt_expr_attributes)]

//! A library for turning computational graphs, such as neural networks, into ZK-circuits.
@@ -85,6 +86,7 @@ use std::str::FromStr;

use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
use fieldutils::IntegerRep;
use graph::Visibility;
use halo2_proofs::poly::{
    ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
@@ -243,7 +245,7 @@ pub struct RunArgs {
    #[arg(long, default_value = "1", value_hint = clap::ValueHint::Other)]
    pub scale_rebase_multiplier: u32,
    /// The min and max elements in the lookup table input column
    #[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
    #[arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768")]
    pub lookup_range: Range,
    /// The log_2 number of rows
    #[arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other)]
|
||||
prefix_token(&record.level()),
|
||||
// pretty print UTC time
|
||||
chrono::Utc::now()
|
||||
.format("%Y-%m-%d %H:%M:%S")
|
||||
.format("%Y-%m-%d %H:%M:%S:%3f")
|
||||
.to_string()
|
||||
.bright_magenta(),
|
||||
record.metadata().target(),
|
||||
|
||||
@@ -558,7 +558,8 @@ where
|
||||
+ PrimeField
|
||||
+ FromUniformBytes<64>
|
||||
+ WithSmallOrderMulGroup<3>,
|
||||
Scheme::Curve: Serialize + DeserializeOwned,
|
||||
Scheme::Curve: Serialize + DeserializeOwned + SerdeObject,
|
||||
Scheme::ParamsProver: Send + Sync,
|
||||
{
|
||||
let strategy = Strategy::new(params.verifier_params());
|
||||
let mut transcript = TranscriptWriterBuffer::<_, Scheme::Curve, _>::init(vec![]);
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::circuit::modules::poseidon::{
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::{CheckMode, Tolerance};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_i64, i64_to_felt};
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
@@ -331,9 +331,9 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
Ok(int_rep)
|
||||
}
|
||||
|
||||
@@ -357,7 +357,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
|
||||
))]
|
||||
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
let float_rep = int_rep as f64 / multiplier;
|
||||
Ok(float_rep)
|
||||
@@ -385,7 +385,7 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
let int_rep = quantize_float(&input, 0.0, scale)
|
||||
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
|
||||
let felt = i64_to_felt(int_rep);
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
|
||||
}
|
||||
|
||||
@@ -887,7 +887,7 @@ fn calibrate_settings(
    model: PathBuf,
    settings: PathBuf,
    target: CalibrationTarget,
    lookup_safety_margin: i64,
    lookup_safety_margin: f64,
    scales: Option<Vec<crate::Scale>>,
    scale_rebase_multiplier: Vec<u32>,
    max_logrows: Option<u32>,
@@ -1491,7 +1491,7 @@ fn encode_evm_calldata<'a>(
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create_evm_vk command
///
/// Returns
/// -------
@@ -1533,6 +1533,56 @@ fn create_evm_verifier(
    })
}
/// Creates an EVM verifier key. This command should be called after create_evm_verifier with the render_vk_separately arg set to true. By rendering a verification key separately you can reuse the same verifier for similar circuit setups with different verifying keys, helping to reduce the amount of state our verifiers store on the blockchain.
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path at which to create the solidity verifying key
///
/// abi_path: str
/// The path at which to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
    vk_path=PathBuf::from(DEFAULT_VK),
    settings_path=PathBuf::from(DEFAULT_SETTINGS),
    sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
    abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
    srs_path=None
))]
fn create_evm_vk(
    py: Python,
    vk_path: PathBuf,
    settings_path: PathBuf,
    sol_code_path: PathBuf,
    abi_path: PathBuf,
    srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
    pyo3_asyncio::tokio::future_into_py(py, async move {
        crate::execute::create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path)
            .await
            .map_err(|e| {
                let err_str = format!("Failed to run create_evm_vk: {}", e);
                PyRuntimeError::new_err(err_str)
            })?;

        Ok(true)
    })
}
/// Creates an EVM compatible data attestation verifier; you will need solc installed in your environment to run this
///
/// Arguments
@@ -1762,7 +1812,7 @@ fn deploy_da_evm(
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to verifier contract's address
/// The verifier contract's address as a hex string
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
@@ -1774,7 +1824,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation?
///
/// addr_vk: str
///
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool
@@ -1925,6 +1975,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
    m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
    m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
    m.add_function(wrap_pyfunction!(create_evm_vk, m)?)?;
    m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
    m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
    m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;

@@ -27,4 +27,10 @@ pub enum TensorError {
    /// Unset visibility
    #[error("unset visibility")]
    UnsetVisibility,
    /// File save error
    #[error("save error: {0}")]
    FileSaveError(String),
    /// File load error
    #[error("load error: {0}")]
    FileLoadError(String),
}
@@ -9,7 +9,7 @@ pub mod var;

pub use errors::TensorError;

use halo2curves::{bn256::Fr, ff::PrimeField};
use halo2curves::ff::PrimeField;
use maybe_rayon::{
    prelude::{
        IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
@@ -18,6 +18,9 @@ use maybe_rayon::{
    slice::ParallelSliceMut,
};
use serde::{Deserialize, Serialize};
use std::io::BufRead;
use std::io::Write;
use std::path::PathBuf;
pub use val::*;
pub use var::*;

@@ -26,7 +29,7 @@ use instant::Instant;

use crate::{
    circuit::utils,
    fieldutils::{felt_to_i32, felt_to_i64, i32_to_felt, i64_to_felt},
    fieldutils::{integer_rep_to_felt, IntegerRep},
    graph::Visibility,
};

@@ -41,6 +44,7 @@ use itertools::Itertools;
use metal::{Device, MTLResourceOptions, MTLSize};
use std::error::Error;
use std::fmt::Debug;
use std::io::Read;
use std::iter::Iterator;
use std::ops::{Add, Deref, DerefMut, Div, Mul, Neg, Range, Sub};
use std::{cmp::max, ops::Rem};
@@ -108,7 +112,7 @@ impl TensorType for f32 {
        Some(0.0)
    }

    // f32 doesn't impl Ord so we can't just use max like we can for i32, usize.
    // f32 doesn't impl Ord so we can't just use max like we can for IntegerRep, usize.
    // A comparison between f32s needs to handle NAN values.
    fn tmax(&self, other: &Self) -> Option<Self> {
        match (self.is_nan(), other.is_nan()) {
@@ -131,7 +135,7 @@ impl TensorType for f64 {
        Some(0.0)
    }

    // f64 doesn't impl Ord so we can't just use max like we can for i32, usize.
    // f64 doesn't impl Ord so we can't just use max like we can for IntegerRep, usize.
    // A comparison between f64s needs to handle NAN values.
    fn tmax(&self, other: &Self) -> Option<Self> {
        match (self.is_nan(), other.is_nan()) {
@@ -150,8 +154,7 @@ impl TensorType for f64 {
}

tensor_type!(bool, Bool, false, true);
tensor_type!(i64, Int64, 0, 1);
tensor_type!(i32, Int32, 0, 1);
tensor_type!(IntegerRep, IntegerRep, 0, 1);
tensor_type!(usize, USize, 0, 1);
tensor_type!((), Empty, (), ());
tensor_type!(utils::F32, F32, utils::F32(0.0), utils::F32(1.0));
@@ -316,92 +319,6 @@ impl<T: TensorType> DerefMut for Tensor<T> {
        self.inner.deref_mut()
    }
}
/// Convert to i64 trait
pub trait IntoI64 {
    /// Convert to i64
    fn into_i64(self) -> i64;

    /// From i64
    fn from_i64(i: i64) -> Self;
}

impl IntoI64 for i64 {
    fn into_i64(self) -> i64 {
        self
    }
    fn from_i64(i: i64) -> i64 {
        i
    }
}

impl IntoI64 for i32 {
    fn into_i64(self) -> i64 {
        self as i64
    }
    fn from_i64(i: i64) -> Self {
        i as i32
    }
}

impl IntoI64 for usize {
    fn into_i64(self) -> i64 {
        self as i64
    }
    fn from_i64(i: i64) -> Self {
        i as usize
    }
}

impl IntoI64 for f32 {
    fn into_i64(self) -> i64 {
        self as i64
    }
    fn from_i64(i: i64) -> Self {
        i as f32
    }
}

impl IntoI64 for f64 {
    fn into_i64(self) -> i64 {
        self as i64
    }
    fn from_i64(i: i64) -> Self {
        i as f64
    }
}

impl IntoI64 for () {
    fn into_i64(self) -> i64 {
        0
    }
    fn from_i64(_: i64) -> Self {}
}

impl IntoI64 for Fr {
    fn into_i64(self) -> i64 {
        felt_to_i64(self)
    }
    fn from_i64(i: i64) -> Self {
        i64_to_felt::<Fr>(i)
    }
}

impl<F: PrimeField + IntoI64> IntoI64 for Value<F> {
    fn into_i64(self) -> i64 {
        let mut res = vec![];
        self.map(|x| res.push(x.into_i64()));

        if res.is_empty() {
            0
        } else {
            res[0]
        }
    }

    fn from_i64(i: i64) -> Self {
        Value::known(F::from_i64(i))
    }
}

impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
    fn eq(&self, other: &Tensor<T>) -> bool {
@@ -431,42 +348,6 @@
    }
}
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
    for Tensor<i32>
{
    fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<i32> {
        let mut output = Vec::new();
        value.map(|x| {
            x.evaluate().value().map(|y| {
                let e = felt_to_i32(*y);
                output.push(e);
                e
            })
        });
        Tensor::new(Some(&output), value.dims()).unwrap()
    }
}

impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>>
    for Tensor<i32>
{
    fn from(value: Tensor<AssignedCell<F, F>>) -> Tensor<i32> {
        let mut output = Vec::new();
        value.map(|x| {
            let mut i = 0;
            x.value().map(|y| {
                let e = felt_to_i32(*y);
                output.push(e);
                i += 1;
            });
            if i == 0 {
                output.push(0);
            }
        });
        Tensor::new(Some(&output), value.dims()).unwrap()
    }
}

impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
    for Tensor<Value<F>>
{
@@ -479,24 +360,6 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
    }
}

impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>> for Tensor<i32> {
    fn from(t: Tensor<Value<F>>) -> Tensor<i32> {
        let mut output = Vec::new();
        t.map(|x| {
            let mut i = 0;
            x.map(|y| {
                let e = felt_to_i32(y);
                output.push(e);
                i += 1;
            });
            if i == 0 {
                output.push(0);
            }
        });
        Tensor::new(Some(&output), t.dims()).unwrap()
    }
}

impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
    for Tensor<Value<Assigned<F>>>
{
@@ -508,20 +371,10 @@ impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
    }
}

impl<F: PrimeField + TensorType + Clone> From<Tensor<i32>> for Tensor<Value<F>> {
    fn from(t: Tensor<i32>) -> Tensor<Value<F>> {
impl<F: PrimeField + TensorType + Clone> From<Tensor<IntegerRep>> for Tensor<Value<F>> {
    fn from(t: Tensor<IntegerRep>) -> Tensor<Value<F>> {
        let mut ta: Tensor<Value<F>> =
            Tensor::from((0..t.len()).map(|i| Value::known(i32_to_felt::<F>(t[i]))));
        // safe to unwrap as we know the dims are correct
        ta.reshape(t.dims()).unwrap();
        ta
    }
}

impl<F: PrimeField + TensorType + Clone> From<Tensor<i64>> for Tensor<Value<F>> {
    fn from(t: Tensor<i64>) -> Tensor<Value<F>> {
        let mut ta: Tensor<Value<F>> =
            Tensor::from((0..t.len()).map(|i| Value::known(i64_to_felt::<F>(t[i]))));
            Tensor::from((0..t.len()).map(|i| Value::known(integer_rep_to_felt::<F>(t[i]))));
        // safe to unwrap as we know the dims are correct
        ta.reshape(t.dims()).unwrap();
        ta
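Usage sketch for the retargeted From impl above, lifting a quantized IntegerRep tensor directly into circuit values; this assumes bn256 Fr and the halo2_proofs re-export paths used in the surrounding doc-tests:

use ezkl::fieldutils::IntegerRep;
use ezkl::tensor::Tensor;
use halo2_proofs::circuit::Value;
use halo2curves::bn256::Fr;

fn main() {
    let t = Tensor::<IntegerRep>::new(Some(&[-1, 0, 1]), &[3]).unwrap();
    // negatives are mapped through integer_rep_to_felt element-wise
    let vals: Tensor<Value<Fr>> = t.into();
    assert_eq!(vals.dims(), &[3]);
}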
@@ -560,6 +413,45 @@ impl<'data, T: Clone + TensorType + std::marker::Send + std::marker::Sync>
    }
}

impl<T: Clone + TensorType + PrimeField> Tensor<T> {
    /// save to a file
    pub fn save(&self, path: &PathBuf) -> Result<(), TensorError> {
        let writer =
            std::fs::File::create(path).map_err(|e| TensorError::FileSaveError(e.to_string()))?;
        let mut buf_writer = std::io::BufWriter::new(writer);

        self.inner.iter().map(|x| x.clone()).for_each(|x| {
            let x = x.to_repr();
            buf_writer.write_all(x.as_ref()).unwrap();
        });

        Ok(())
    }

    /// load from a file
    pub fn load(path: &PathBuf) -> Result<Self, TensorError> {
        let reader =
            std::fs::File::open(path).map_err(|e| TensorError::FileLoadError(e.to_string()))?;
        let mut buf_reader = std::io::BufReader::new(reader);

        let mut inner = Vec::new();
        while let Ok(true) = buf_reader.has_data_left() {
            let mut repr = T::Repr::default();
            match buf_reader.read_exact(repr.as_mut()) {
                Ok(_) => {
                    inner.push(T::from_repr(repr).unwrap());
                }
                Err(_) => {
                    return Err(TensorError::FileLoadError(
                        "Failed to read tensor".to_string(),
                    ))
                }
            }
        }
        Ok(Tensor::new(Some(&inner), &[inner.len()]).unwrap())
    }
}
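Round-trip sketch for the new binary (de)serialization, which streams each element's canonical PrimeField representation; bn256 Fr and a writable working directory are assumed, and note that load flattens to a rank-1 tensor of length inner.len():

use ezkl::tensor::Tensor;
use halo2curves::bn256::Fr;
use std::path::PathBuf;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let path = PathBuf::from("tensor.bin");
    let t = Tensor::<Fr>::new(Some(&[Fr::from(1), Fr::from(2)]), &[2])?;
    t.save(&path)?;
    // the on-disk format carries no shape, so the result is rank-1
    let loaded = Tensor::<Fr>::load(&path)?;
    assert_eq!(loaded.dims(), &[2]);
    Ok(())
}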
impl<T: Clone + TensorType> Tensor<T> {
    /// Sets (copies) the tensor values to the provided ones.
    pub fn new(values: Option<&[T]>, dims: &[usize]) -> Result<Self, TensorError> {
@@ -633,7 +525,8 @@ impl<T: Clone + TensorType> Tensor<T> {
    ///
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(None, &[3, 3, 3]).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(None, &[3, 3, 3]).unwrap();
    ///
    /// a.set(&[0, 0, 1], 10);
    /// assert_eq!(a[0 + 0 + 1], 10);
@@ -650,7 +543,8 @@ impl<T: Clone + TensorType> Tensor<T> {
    ///
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
    ///
    /// a[1*15 + 1*5 + 1] = 5;
    /// assert_eq!(a.get(&[1, 1, 1]), 5);
@@ -664,7 +558,8 @@ impl<T: Clone + TensorType> Tensor<T> {
    ///
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
    ///
    /// a[1*15 + 1*5 + 1] = 5;
    /// assert_eq!(a.get(&[1, 1, 1]), 5);
@@ -684,11 +579,12 @@ impl<T: Clone + TensorType> Tensor<T> {
    /// Pad to a length that is divisible by n
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
    /// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
    /// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
    ///
    /// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
    /// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
    /// ```
    pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
@@ -704,7 +600,8 @@ impl<T: Clone + TensorType> Tensor<T> {
    ///
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
    ///
    /// let flat_index = 1*15 + 1*5 + 1;
    /// a[1*15 + 1*5 + 1] = 5;
@@ -731,8 +628,9 @@ impl<T: Clone + TensorType> Tensor<T> {
    /// Get a slice from the Tensor.
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
    /// let mut b = Tensor::<i32>::new(Some(&[1, 2]), &[2]).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
    /// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2]), &[2]).unwrap();
    ///
    /// assert_eq!(a.get_slice(&[0..2]).unwrap(), b);
    /// ```
@@ -782,9 +680,10 @@ impl<T: Clone + TensorType> Tensor<T> {
    /// Set a slice of the Tensor.
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
    /// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
    /// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
    /// let b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
    /// a.set_slice(&[1..2], &Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
    /// assert_eq!(a, b);
    /// ```
    pub fn set_slice(
@@ -845,6 +744,7 @@ impl<T: Clone + TensorType> Tensor<T> {
    ///
    /// ```
    /// use ezkl::tensor::Tensor;
    /// use ezkl::fieldutils::IntegerRep;
    /// let a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
    ///
    /// assert_eq!(a.get_index(&[2, 2, 2]), 26);
@@ -868,12 +768,13 @@ impl<T: Clone + TensorType> Tensor<T> {
    ///
    /// ```
    /// use ezkl::tensor::Tensor;
    /// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
    /// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
    /// use ezkl::fieldutils::IntegerRep;
    /// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
    /// assert_eq!(a.duplicate_every_n(3, 1, 0).unwrap(), expected);
    /// assert_eq!(a.duplicate_every_n(7, 1, 0).unwrap(), a);
    ///
    /// let expected = Tensor::<i32>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
    /// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
    /// assert_eq!(a.duplicate_every_n(3, 1, 2).unwrap(), expected);
    ///
    /// ```
@@ -900,8 +801,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
|
||||
/// assert_eq!(a.remove_every_n(4, 1, 0).unwrap(), expected);
|
||||
///
|
||||
///
|
||||
@@ -935,14 +837,15 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// WARN: assumes indices are in ascending order for speed
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
|
||||
/// let mut indices = vec![3, 4];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
|
||||
///
|
||||
///
|
||||
/// let a = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
|
||||
/// let b = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
|
||||
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
|
||||
/// let mut indices = vec![63, 67];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
|
||||
/// ```
|
||||
@@ -972,6 +875,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///Reshape the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// a.reshape(&[9, 3]);
|
||||
/// assert_eq!(a.dims(), &[9, 3]);
|
||||
@@ -1006,22 +910,23 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Move axis of the tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
/// let b = a.move_axis(0, 2).unwrap();
/// assert_eq!(b.dims(), &[3, 3, 3]);
///
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
/// let b = a.move_axis(0, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.move_axis(1, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.move_axis(2, 1).unwrap();
/// assert_eq!(b, expected);
/// ```
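For intuition, `move_axis(src, dst)` is a pure index permutation: the element at multi-index `(i0, i1, i2)` lands at the position whose axes have been reordered. A standalone sketch for 3-D shapes (my own illustration, not the library internals):

```rust
/// Moves axis `src` of a 3-D tensor (flat `data` with shape `dims`) to position `dst`.
fn move_axis(data: &[i64], dims: [usize; 3], src: usize, dst: usize) -> (Vec<i64>, [usize; 3]) {
    // Build the axis permutation: remove `src`, reinsert it at `dst`.
    let mut order = vec![0usize, 1, 2];
    let ax = order.remove(src);
    order.insert(dst, ax);
    let new_dims = [dims[order[0]], dims[order[1]], dims[order[2]]];
    let mut out = vec![0i64; data.len()];
    for i in 0..dims[0] {
        for j in 0..dims[1] {
            for k in 0..dims[2] {
                let old = [i, j, k];
                // New multi-index: axis m of the result reads old axis order[m].
                let new = [old[order[0]], old[order[1]], old[order[2]]];
                let flat_new = (new[0] * new_dims[1] + new[1]) * new_dims[2] + new[2];
                out[flat_new] = data[(i * dims[1] + j) * dims[2] + k];
            }
        }
    }
    (out, new_dims)
}

fn main() {
    // Mirrors the doctest: shape [3, 1, 2], move axis 0 to the end -> [1, 2, 3].
    let (out, dims) = move_axis(&[1, 2, 3, 4, 5, 6], [3, 1, 2], 0, 2);
    assert_eq!(dims, [1, 2, 3]);
    assert_eq!(out, vec![1, 3, 5, 2, 4, 6]);
}
```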
@@ -1086,22 +991,23 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Swap axes of the tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
/// let b = a.swap_axes(0, 2).unwrap();
/// assert_eq!(b.dims(), &[3, 3, 3]);
///
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
/// let b = a.swap_axes(0, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.swap_axes(1, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.swap_axes(2, 1).unwrap();
/// assert_eq!(b, expected);
/// ```
@@ -1148,9 +1054,10 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Broadcasts the tensor to a given shape
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
///
/// let mut expected = Tensor::<i32>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
/// assert_eq!(a.expand(&[3, 3]).unwrap(), expected);
///
/// ```
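`expand` follows the usual broadcasting rule: an axis of size 1 may be stretched to the target size by repetition, and any other mismatch is an error. A small sketch of that rule for the 2-D case (illustrative only, not the ezkl implementation):

```rust
/// Broadcasts a 2-D tensor (flat `data`, shape `dims`) to `target`,
/// repeating along any axis of size 1.
fn expand(data: &[i64], dims: [usize; 2], target: [usize; 2]) -> Result<Vec<i64>, String> {
    for a in 0..2 {
        if dims[a] != target[a] && dims[a] != 1 {
            return Err(format!("axis {} of size {} cannot expand to {}", a, dims[a], target[a]));
        }
    }
    let mut out = Vec::with_capacity(target[0] * target[1]);
    for i in 0..target[0] {
        for j in 0..target[1] {
            // A size-1 axis always reads index 0.
            let (si, sj) = (if dims[0] == 1 { 0 } else { i }, if dims[1] == 1 { 0 } else { j });
            out.push(data[si * dims[1] + sj]);
        }
    }
    Ok(out)
}

fn main() {
    // Mirrors the doctest: [3, 1] -> [3, 3].
    let out = expand(&[1, 2, 3], [3, 1], [3, 3]).unwrap();
    assert_eq!(out, vec![1, 1, 1, 2, 2, 2, 3, 3, 3]);
}
```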
@@ -1204,6 +1111,7 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Flatten the tensor shape
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
/// a.flatten();
/// assert_eq!(a.dims(), &[27]);
@@ -1217,8 +1125,9 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.map(|x| i32::pow(x,2));
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.map(|x| IntegerRep::pow(x,2));
/// assert_eq!(c, Tensor::from([1, 16].into_iter()))
/// ```
pub fn map<F: FnMut(T) -> G, G: TensorType>(&self, mut f: F) -> Tensor<G> {
@@ -1231,8 +1140,9 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors and enumerates
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn enum_map<F: FnMut(usize, T) -> Result<G, E>, G: TensorType, E: Error>(
@@ -1254,8 +1164,9 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map<
@@ -1281,11 +1192,37 @@ impl<T: Clone + TensorType> Tensor<T> {
Ok(t)
}

/// Get the last element from the tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<IntegerRep>::new(Some(&[3]), &[1]).unwrap();
///
/// assert_eq!(a.last().unwrap(), b);
/// ```
pub fn last(&self) -> Result<Tensor<T>, TensorError>
where
T: Send + Sync,
{
let res = match self.inner.last() {
Some(e) => e.clone(),
None => {
return Err(TensorError::DimError(
"Cannot get last element of empty tensor".to_string(),
))
}
};

Tensor::new(Some(&[res]), &[1])
}

/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map_mut_filtered<
@@ -1293,7 +1230,7 @@ impl<T: Clone + TensorType> Tensor<T> {
E: Error + std::marker::Send + std::marker::Sync,
>(
&mut self,
filter_indices: &std::collections::HashSet<&usize>,
filter_indices: &std::collections::HashSet<usize>,
f: F,
) -> Result<(), E>
where
@@ -1308,103 +1245,13 @@ impl<T: Clone + TensorType> Tensor<T> {
}
}

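The signature change from `HashSet<&usize>` to `HashSet<usize>` drops a borrowed lifetime from the filter set. The pattern itself, mutating only the entries whose index matches the filter, can be sketched with plain iterators; this hunk doesn't show whether the set selects or skips indices, so the sketch below assumes it selects the indices to mutate (the real method also runs in parallel via `maybe_rayon`):

```rust
use std::collections::HashSet;

/// Applies `f` in place at every index contained in `filter_indices`
/// (assumption: the set selects which entries to mutate).
fn enum_map_mut_filtered<T, F: Fn(usize, &T) -> T>(
    data: &mut [T],
    filter_indices: &HashSet<usize>,
    f: F,
) {
    for (i, x) in data.iter_mut().enumerate() {
        if filter_indices.contains(&i) {
            *x = f(i, x);
        }
    }
}

fn main() {
    let mut v = vec![1i64, 4, 9];
    let select: HashSet<usize> = [0, 2].into_iter().collect();
    enum_map_mut_filtered(&mut v, &select, |_, x| x * 10);
    assert_eq!(v, vec![10, 4, 90]); // index 1 left untouched
}
```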
#[cfg(feature = "metal")]
|
||||
#[allow(unsafe_code)]
|
||||
/// Perform a tensor operation on the GPU using Metal.
|
||||
pub fn metal_tensor_op<T: Clone + TensorType + IntoI64 + Send + Sync>(
|
||||
v: &Tensor<T>,
|
||||
w: &Tensor<T>,
|
||||
op: &str,
|
||||
) -> Tensor<T> {
|
||||
assert_eq!(v.dims(), w.dims());
|
||||
|
||||
log::trace!("------------------------------------------------");
|
||||
|
||||
let start = Instant::now();
|
||||
let v = v
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
let w = w
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
log::trace!("Time to map tensors: {:?}", start.elapsed());
|
||||
|
||||
objc::rc::autoreleasepool(|| {
|
||||
// create function pipeline.
|
||||
// this compiles the function, so a pipline can't be created in performance sensitive code.
|
||||
|
||||
let pipeline = &PIPELINES[op];
|
||||
|
||||
let length = v.len() as u64;
|
||||
let size = length * core::mem::size_of::<i64>() as u64;
|
||||
assert_eq!(v.len(), w.len());
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let buffer_a = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(v.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_b = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(w.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_result = DEVICE.new_buffer(
|
||||
size, // the operation will return an array with the same size.
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
log::trace!("Time to load buffers: {:?}", start.elapsed());
|
||||
|
||||
// for sending commands, a command buffer is needed.
|
||||
let start = Instant::now();
|
||||
let command_buffer = QUEUE.new_command_buffer();
|
||||
log::trace!("Time to load command buffer: {:?}", start.elapsed());
|
||||
// to write commands into a buffer an encoder is needed, in our case a compute encoder.
|
||||
let start = Instant::now();
|
||||
let compute_encoder = command_buffer.new_compute_command_encoder();
|
||||
compute_encoder.set_compute_pipeline_state(&pipeline);
|
||||
compute_encoder.set_buffers(
|
||||
0,
|
||||
&[Some(&buffer_a), Some(&buffer_b), Some(&buffer_result)],
|
||||
&[0; 3],
|
||||
);
|
||||
log::trace!("Time to load compute encoder: {:?}", start.elapsed());
|
||||
|
||||
// specify thread count and organization
|
||||
let start = Instant::now();
|
||||
let grid_size = MTLSize::new(length, 1, 1);
|
||||
let threadgroup_size = MTLSize::new(length, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_size, threadgroup_size);
|
||||
log::trace!("Time to dispatch threads: {:?}", start.elapsed());
|
||||
|
||||
// end encoding and execute commands
|
||||
let start = Instant::now();
|
||||
compute_encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
|
||||
command_buffer.wait_until_completed();
|
||||
log::trace!("Time to commit: {:?}", start.elapsed());
|
||||
|
||||
let start = Instant::now();
|
||||
let ptr = buffer_result.contents() as *const i64;
|
||||
let len = buffer_result.length() as usize / std::mem::size_of::<i64>();
|
||||
let slice = unsafe { core::slice::from_raw_parts(ptr, len) };
|
||||
let res = Tensor::new(Some(&slice.to_vec()), &v.dims()).unwrap();
|
||||
log::trace!("Time to get result: {:?}", start.elapsed());
|
||||
|
||||
res.map(|x| T::from_i64(x))
|
||||
})
|
||||
}
|
||||
|
||||
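Since the Metal path round-trips every element through `i64` shared buffers, the work the kernel does is just an elementwise `i64` operation per thread. A CPU analogue of that round trip (assuming an `add` kernel; the actual op names keyed into `PIPELINES` are not shown in this hunk) is:

```rust
/// CPU analogue of `metal_tensor_op(v, w, "add")`: widen to i64,
/// run the elementwise kernel, then narrow back to the source type.
fn elementwise_add_i64_roundtrip(v: &[i32], w: &[i32]) -> Vec<i32> {
    assert_eq!(v.len(), w.len()); // the GPU path asserts matching dims too
    v.iter()
        .zip(w)
        .map(|(&a, &b)| {
            let wide = a as i64 + b as i64; // what one Metal thread computes
            wide as i32 // stands in for `T::from_i64` in the real code
        })
        .collect()
}

fn main() {
    assert_eq!(elementwise_add_i64_roundtrip(&[1, 2], &[3, 4]), vec![4, 6]);
}
```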
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
/// Flattens a tensor of tensors
/// ```
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let mut b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
/// let mut c = Tensor::new(Some(&[a,b]), &[2]).unwrap();
/// let mut d = c.combine().unwrap();
/// assert_eq!(d.dims(), &[8]);
@@ -1420,9 +1267,7 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
}
}

impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Add
for Tensor<T>
{
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Add for Tensor<T> {
type Output = Result<Tensor<T>, TensorError>;
/// Adds tensors.
/// # Arguments
@@ -1432,42 +1277,43 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + I
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Add;
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.add(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 1D casting
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2]),
/// &[1]).unwrap();
/// let result = x.add(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // Now test 2D casting
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2, 3]),
/// &[2]).unwrap();
/// let result = x.add(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn add(self, rhs: Self) -> Self::Output {
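The "2D casting" example above is worth unpacking: a rank-1 `k` of length equal to the first axis is broadcast one value per slice along that axis, so `k = [2, 3]` adds 2 to row 0 and 3 to row 1 of a `[2, 3]` tensor. A sketch of that rule, as inferred from the doctest rather than from the library source:

```rust
/// Adds `k` (one value per row) to a row-major `[rows, cols]` tensor.
fn broadcast_add_rows(x: &[i64], rows: usize, cols: usize, k: &[i64]) -> Vec<i64> {
    assert_eq!(k.len(), rows);
    x.iter()
        .enumerate()
        .map(|(idx, &v)| v + k[idx / cols]) // idx / cols is the row index
        .collect()
}

fn main() {
    // Mirrors the doctest: x = [[2,1,2],[1,1,1]], k = [2, 3].
    let out = broadcast_add_rows(&[2, 1, 2, 1, 1, 1], 2, 3, &[2, 3]);
    assert_eq!(out, vec![4, 3, 4, 4, 4, 4]);
}
```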
@@ -1501,13 +1347,14 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Neg;
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.neg();
/// let expected = Tensor::<i32>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn neg(self) -> Self {
@@ -1520,9 +1367,7 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
}
}

impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Sub
for Tensor<T>
{
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Sub for Tensor<T> {
type Output = Result<Tensor<T>, TensorError>;
/// Subtracts tensors.
/// # Arguments
@@ -1532,43 +1377,44 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + I
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Sub;
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.sub(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 1D sub
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2]),
/// &[1],
/// ).unwrap();
/// let result = x.sub(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 2D sub
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2, 3]),
/// &[2],
/// ).unwrap();
/// let result = x.sub(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn sub(self, rhs: Self) -> Self::Output {
@@ -1594,9 +1440,7 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + I
}
}

impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Mul
for Tensor<T>
{
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mul for Tensor<T> {
type Output = Result<Tensor<T>, TensorError>;
/// Elementwise multiplies tensors.
/// # Arguments
@@ -1606,41 +1450,42 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + I
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Mul;
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.mul(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 1D mult
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2]),
/// &[1]).unwrap();
/// let result = x.mul(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 2D mult
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<i32>::new(
/// let k = Tensor::<IntegerRep>::new(
/// Some(&[2, 2]),
/// &[2]).unwrap();
/// let result = x.mul(k).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn mul(self, rhs: Self) -> Self::Output {
@@ -1666,7 +1511,7 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + I
}
}

impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Tensor<T> {
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Tensor<T> {
/// Elementwise raise a tensor to the nth power.
/// # Arguments
///
@@ -1675,13 +1520,14 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + I
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Mul;
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.pow(3).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn pow(&self, mut exp: u32) -> Result<Self, TensorError> {
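The `mut exp: u32` parameter hints at square-and-multiply, which needs only about log2(exp) multiplications instead of exp − 1. A scalar sketch of that loop (my reading of the signature; the tensor version would multiply elementwise):

```rust
/// Exponentiation by squaring: O(log exp) multiplications.
fn pow_scalar(base: i64, mut exp: u32) -> i64 {
    let mut acc = 1i64;
    let mut sq = base;
    while exp > 0 {
        if exp & 1 == 1 {
            acc *= sq; // fold in the current power-of-two of `base`
        }
        sq *= sq; // square for the next bit; may overflow for large inputs
        exp >>= 1;
    }
    acc
}

fn main() {
    assert_eq!(pow_scalar(15, 3), 3375); // matches the doctest entry
    assert_eq!(pow_scalar(2, 10), 1024);
}
```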
@@ -1715,30 +1561,31 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Div;
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[4, 1, 4, 1, 1, 4]),
/// &[2, 3],
/// ).unwrap();
/// let y = Tensor::<i32>::new(
/// let y = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.div(y).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // test 1D casting
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[4, 1, 4, 1, 1, 4]),
/// &[2, 3],
/// ).unwrap();
/// let y = Tensor::<i32>::new(
/// let y = Tensor::<IntegerRep>::new(
/// Some(&[2]),
/// &[1],
/// ).unwrap();
/// let result = x.div(y).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn div(self, rhs: Self) -> Self::Output {
@@ -1765,17 +1612,18 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Rem;
/// let x = Tensor::<i32>::new(
/// let x = Tensor::<IntegerRep>::new(
/// Some(&[4, 1, 4, 1, 1, 4]),
/// &[2, 3],
/// ).unwrap();
/// let y = Tensor::<i32>::new(
/// let y = Tensor::<IntegerRep>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.rem(y).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn rem(self, rhs: Self) -> Self::Output {
@@ -1859,25 +1707,25 @@ mod tests {

#[test]
fn tensor_clone() {
let x = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
let x = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
assert_eq!(x, x.clone());
}

#[test]
fn tensor_eq() {
let a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
let mut b = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
b.reshape(&[3]).unwrap();
let c = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3]).unwrap();
let d = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
let c = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3]).unwrap();
let d = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
assert_eq!(a, b);
assert_ne!(a, c);
assert_ne!(a, d);
}
#[test]
fn tensor_slice() {
let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
}

File diff suppressed because it is too large
@@ -1,12 +1,12 @@
use core::{iter::FilterMap, slice::Iter};

use crate::circuit::region::ConstantsMap;
use crate::{circuit::region::ConstantsMap, fieldutils::felt_to_integer_rep};
use maybe_rayon::slice::Iter;

use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
use maybe_rayon::iter::{FilterMap, IntoParallelIterator, ParallelIterator};

pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -54,6 +54,44 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
AssignedConstant(AssignedCell<F, F>, F),
}

impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for IntegerRep {
fn from(val: ValType<F>) -> Self {
match val {
ValType::Value(v) => {
let mut output = 0;
let mut i = 0;
v.map(|y| {
let e = felt_to_integer_rep(y);
output = e;
i += 1;
});
output
}
ValType::AssignedValue(v) => {
let mut output = 0;
let mut i = 0;
v.evaluate().map(|y| {
let e = felt_to_integer_rep(y);
output = e;
i += 1;
});
output
}
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
let mut output = 0;
let mut i = 0;
v.value().map(|y| {
let e = felt_to_integer_rep(*y);
output = e;
i += 1;
});
output
}
ValType::Constant(v) => felt_to_integer_rep(v),
}
}
}

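This conversion leans on `felt_to_integer_rep`, which, by ezkl's convention for signed values, maps field elements in the upper half of the field back to negative integers. A toy model of that mapping over a small prime (illustrative only; the real field is bn256's scalar field):

```rust
/// Maps x in [0, p) to its signed representative in roughly (-p/2, p/2].
fn felt_to_signed(x: u64, p: u64) -> i64 {
    if x > p / 2 {
        -((p - x) as i64) // upper half of the field reads as negative
    } else {
        x as i64
    }
}

fn main() {
    let p = 97u64;
    assert_eq!(felt_to_signed(5, p), 5);
    assert_eq!(felt_to_signed(96, p), -1); // 96 ≡ -1 (mod 97)
    assert_eq!(felt_to_signed(49, p), -48); // just past p/2
}
```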
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
/// Returns the inner cell of the [ValType].
pub fn cell(&self) -> Option<Cell> {
@@ -121,44 +159,6 @@ impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + Partia
}
}

impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for i32 {
fn from(val: ValType<F>) -> Self {
match val {
ValType::Value(v) => {
let mut output = 0_i32;
let mut i = 0;
v.map(|y| {
let e = felt_to_i32(y);
output = e;
i += 1;
});
output
}
ValType::AssignedValue(v) => {
let mut output = 0_i32;
let mut i = 0;
v.evaluate().map(|y| {
let e = felt_to_i32(y);
output = e;
i += 1;
});
output
}
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
let mut output = 0_i32;
let mut i = 0;
v.value().map(|y| {
let e = felt_to_i32(*y);
output = e;
i += 1;
});
output
}
ValType::Constant(v) => felt_to_i32(v),
}
}
}

impl<F: PrimeField + TensorType + PartialOrd> From<F> for ValType<F> {
fn from(t: F) -> ValType<F> {
ValType::Constant(t)
@@ -317,8 +317,8 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f

impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i64].
pub fn from_i64_tensor(t: Tensor<i64>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(i64_to_felt(x))));
pub fn from_integer_rep_tensor(t: Tensor<IntegerRep>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(integer_rep_to_felt(x))));
inner.into()
}

@@ -460,7 +460,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
ValTensor::Value { inner, .. } => inner.par_iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
@@ -521,9 +521,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}

/// Calls `int_evals` on the inner tensor.
pub fn get_int_evals(&self) -> Result<Tensor<i64>, TensorError> {
pub fn int_evals(&self) -> Result<Tensor<IntegerRep>, TensorError> {
// finally convert to vector of integers
let mut integer_evals: Vec<i64> = vec![];
let mut integer_evals: Vec<IntegerRep> = vec![];
match self {
ValTensor::Value {
inner: v, dims: _, ..
@@ -531,25 +531,26 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
let _ = v.map(|vaf| match vaf {
ValType::Value(v) => v.map(|f| {
integer_evals.push(crate::fieldutils::felt_to_i64(f));
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f));
}),
ValType::AssignedValue(v) => v.map(|f| {
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
}),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
v.value_field().map(|f| {
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
integer_evals
.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
})
}
ValType::Constant(v) => {
integer_evals.push(crate::fieldutils::felt_to_i64(v));
integer_evals.push(crate::fieldutils::felt_to_integer_rep(v));
Value::unknown()
}
});
}
_ => return Err(TensorError::WrongMethod),
};
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
let mut tensor: Tensor<IntegerRep> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
@@ -573,6 +574,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}

/// Calls `last` on the inner tensor.
pub fn last(&self) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
inner: v,
dims: _,
scale,
} => {
let inner = v.last()?;
let dims = inner.dims().to_vec();
ValTensor::Value {
inner,
dims,
scale: *scale,
}
}
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}

/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
@@ -753,43 +775,72 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}

/// gets constants
pub fn get_const_zero_indices(&self) -> Result<Vec<usize>, TensorError> {
/// removes constant zero values
pub fn remove_const_zero_values(&mut self) {
match self {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
indices.push(i);
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_par_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
indices.push(i);
}
}
}
Ok(indices)
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => Ok(vec![]),
ValTensor::Instance { .. } => {}
}
}

/// gets constants
pub fn get_const_indices(&self) -> Result<Vec<usize>, TensorError> {
pub fn get_const_zero_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(_) = e {
indices.push(i);
} else if let ValType::AssignedConstant(_, _) = e {
indices.push(i);
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return Some(i);
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return Some(i);
}
}
}
Ok(indices)
}
ValTensor::Instance { .. } => Ok(vec![]),
None
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
}

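The rewrite replaces "collect the zero indices, then remove them" with a single parallel `filter_map` that rebuilds the tensor without its zero-valued constants and refreshes `dims` afterwards. The shape of that transformation on a plain vector, with std iterators standing in for `maybe_rayon`:

```rust
/// Drops every zero from `v` in one pass and reports the new length,
/// mirroring `remove_const_zero_values` plus the `*dims = ...` refresh.
fn remove_zeros(v: Vec<i64>) -> (Vec<i64>, usize) {
    let kept: Vec<i64> = v.into_iter().filter(|&e| e != 0).collect();
    let len = kept.len(); // the tensor version re-derives dims the same way
    (kept, len)
}

fn main() {
    let (kept, len) = remove_zeros(vec![3, 0, -1, 0, 7]);
    assert_eq!(kept, vec![3, -1, 7]);
    assert_eq!(len, 3);
}
```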
/// gets constants
pub fn get_const_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(_) = e {
Some(i)
} else if let ValType::AssignedConstant(_, _) = e {
Some(i)
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
}

@@ -952,25 +1003,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// A [String] representation of the [ValTensor] for display, for example in showing intermediate values in a computational graph.
pub fn show(&self) -> String {
match self.clone() {
ValTensor::Value {
inner: v, dims: _, ..
} => {
let r: Tensor<i32> = v.map(|x| x.into());
if r.len() > 10 {
let start = r[..5].to_vec();
let end = r[r.len() - 5..].to_vec();
// print the two split by ... in the middle
format!(
"[{} ... {}]",
start.iter().map(|x| format!("{}", x)).join(", "),
end.iter().map(|x| format!("{}", x)).join(", ")
)
} else {
format!("{:?}", r)
}
}
_ => "ValTensor not PrevAssigned".into(),
let r = match self.int_evals() {
Ok(v) => v,
Err(_) => return "ValTensor not PrevAssigned".into(),
};

if r.len() > 10 {
let start = r[..5].to_vec();
let end = r[r.len() - 5..].to_vec();
// print the two split by ... in the middle
format!(
"[{} ... {}]",
start.iter().map(|x| format!("{}", x)).join(", "),
end.iter().map(|x| format!("{}", x)).join(", ")
)
} else {
format!("{:?}", r)
}
}
}

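`show` now funnels every variant through `int_evals` and keeps the abbreviated `[a, b ... y, z]` rendering for long tensors. That formatting on its own, without the `itertools::join` dependency the real code uses:

```rust
/// Formats a slice as "[first 5 ... last 5]" once it exceeds 10 elements.
fn show(r: &[i64]) -> String {
    if r.len() > 10 {
        let fmt = |s: &[i64]| {
            s.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(", ")
        };
        format!("[{} ... {}]", fmt(&r[..5]), fmt(&r[r.len() - 5..]))
    } else {
        format!("{:?}", r)
    }
}

fn main() {
    let v: Vec<i64> = (0..12).collect();
    assert_eq!(show(&v), "[0, 1, 2, 3, 4 ... 7, 8, 9, 10, 11]");
    assert_eq!(show(&[1, 2]), "[1, 2]");
}
```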
@@ -319,7 +319,7 @@ impl VarTensor {
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<&usize>,
omissions: &HashSet<usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
@@ -368,7 +368,7 @@ impl VarTensor {
.sum::<usize>();
let dims = &dims[*idx];
// this should never ever fail
let t: Tensor<i32> = Tensor::new(None, dims).unwrap();
let t: Tensor<IntegerRep> = Tensor::new(None, dims).unwrap();
Ok(t.enum_map(|coord, _| {
let (x, y, z) = self.cartesian_coord(offset + coord);
region.assign_advice_from_instance(
@@ -497,7 +497,7 @@ impl VarTensor {
let (x, y, z) = self.cartesian_coord(offset + coord * step);
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
// assert that duplication occurred correctly
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
assert_eq!(Into::<IntegerRep>::into(k.clone()), Into::<IntegerRep>::into(v[coord - 1].clone()));
};

let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
@@ -533,13 +533,14 @@ impl VarTensor {
if matches!(check_mode, CheckMode::SAFE) {
// during key generation this will be 0 so we use this as a flag to check
// TODO: this isn't very safe and would be better to get the phase directly
let is_assigned = !Into::<Tensor<i32>>::into(res.clone().get_inner().unwrap())
let res_evals = res.int_evals().unwrap();
let is_assigned = res_evals
.iter()
.all(|&x| x == 0);
if is_assigned {
if !is_assigned {
assert_eq!(
Into::<Tensor<i32>>::into(values.get_inner().unwrap()),
Into::<Tensor<i32>>::into(res.get_inner().unwrap())
values.int_evals().unwrap(),
res_evals
)};
}

89 src/wasm.rs
@@ -1,41 +1,52 @@
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::PoseidonChip;
use crate::circuit::modules::Module;
use crate::fieldutils::felt_to_i64;
use crate::fieldutils::i64_to_felt;
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use crate::{
circuit::{
modules::{
polycommit::PolyCommitChip,
poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
},
Module,
},
region::RegionSettings,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
GraphSettings,
},
pfsys::{
create_proof_circuit,
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
verify_proof_circuit, TranscriptType,
},
tensor::TensorType,
CheckMode, Commitments,
};
use console_error_panic_hook;
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
use halo2_proofs::poly::ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
strategy::SingleStrategy as IPASingleStrategy,
use halo2_proofs::{
plonk::*,
poly::{
commitment::{CommitmentScheme, ParamsProver},
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
multiopen::{ProverIPA, VerifierIPA},
strategy::SingleStrategy as IPASingleStrategy,
},
kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy as KZGSingleStrategy,
},
VerificationStrategy,
},
};
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::{FromUniformBytes, PrimeField},
};
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
@@ -113,7 +124,7 @@ pub fn feltToInt(
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&felt_to_i64(felt))
serde_json::to_vec(&felt_to_integer_rep(felt))
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
))
}
@@ -127,7 +138,7 @@ pub fn feltToFloat(
) -> Result<f64, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let int_rep = felt_to_i64(felt);
let int_rep = felt_to_integer_rep(felt);
let multiplier = scale_to_multiplier(scale);
Ok(int_rep as f64 / multiplier)
}
@@ -141,7 +152,7 @@ pub fn floatToFelt(
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = i64_to_felt(int_rep);
let felt = integer_rep_to_felt(int_rep);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
@@ -275,7 +286,7 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;

let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false, false)
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, RegionSettings::all_false())
.map_err(|e| JsError::new(&format!("{}", e)))?;

serde_json::to_vec(&witness)

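`feltToFloat` and `floatToFelt` are inverse fixed-point conversions: `scale_to_multiplier(scale)` is, to my understanding, 2^scale, so a float is quantized as round(x · 2^scale) and recovered as int_rep / 2^scale. The arithmetic, with a hypothetical `quantize`/`dequantize` pair (the names and the power-of-two multiplier are my assumptions):

```rust
/// Fixed-point quantization at `scale` fractional bits (assumed 2^scale multiplier).
fn quantize(x: f64, scale: u32) -> i64 {
    (x * f64::powi(2.0, scale as i32)).round() as i64
}

fn dequantize(q: i64, scale: u32) -> f64 {
    q as f64 / f64::powi(2.0, scale as i32)
}

fn main() {
    let q = quantize(0.3, 7); // 0.3 * 128 = 38.4 -> 38
    assert_eq!(q, 38);
    let x = dequantize(q, 7); // 38 / 128 = 0.296875
    assert!((x - 0.3).abs() < 1.0 / 128.0); // error bounded by one quantum
}
```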
@@ -3,7 +3,7 @@
mod native_tests {

use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_i64, i64_to_felt};
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
@@ -183,12 +183,13 @@ mod native_tests {

const PF_FAILURE_AGGR: &str = "examples/test_failure_aggr_proof.json";

const LARGE_TESTS: [&str; 5] = [
const LARGE_TESTS: [&str; 6] = [
"self_attention",
"nanoGPT",
"multihead_attention",
"mobilenet",
"mnist_gan",
"smallworm",
];

const ACCURACY_CAL_TESTS: [&str; 6] = [
@@ -200,7 +201,7 @@ mod native_tests {
"1l_tiny_div",
];

const TESTS: [&str; 93] = [
const TESTS: [&str; 94] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -298,6 +299,7 @@ mod native_tests {
"1l_lppool",
"lstm_large", // 91
"lstm_medium", // 92
"lenet_5", // 93
];

const WASM_TESTS: [&str; 46] = [
@@ -536,7 +538,7 @@ mod native_tests {
}
});

seq!(N in 0..=92 {
seq!(N in 0..=93 {

#(#[test_case(TESTS[N])])*
#[ignore]
@@ -644,6 +646,15 @@ mod native_tests {
test_dir.close().unwrap();
}

#(#[test_case(TESTS[N])])*
fn mock_hashed_params_public_inputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}

#(#[test_case(TESTS[N])])*
fn mock_fixed_inputs_(test: &str) {
crate::native_tests::init_binary();
@@ -940,7 +951,7 @@ mod native_tests {

});

seq!(N in 0..=4 {
seq!(N in 0..=5 {

#(#[test_case(LARGE_TESTS[N])])*
#[ignore]
@@ -1066,6 +1077,15 @@ mod native_tests {
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "on-chain", "polycommit", "public", "polycommit");
test_dir.close().unwrap();
}
#(#[test_case(TESTS_ON_CHAIN_INPUT[N])])*
fn kzg_evm_on_chain_all_kzg_params_prove_and_verify_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
let _anvil_child = crate::native_tests::start_anvil(true, Hardfork::Latest);
kzg_evm_on_chain_input_prove_and_verify(path, test.to_string(), "file", "file", "polycommit", "polycommit", "polycommit");
test_dir.close().unwrap();
}
});

@@ -1402,10 +1422,10 @@ mod native_tests {
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::zero()
} else {
i64_to_felt(
(felt_to_i64(*v) as f32
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
as i64,
as IntegerRep,
)
};

@@ -1425,10 +1445,10 @@ mod native_tests {
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::from(2)
} else {
i64_to_felt(
(felt_to_i64(*v) as f32
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
as i64,
as IntegerRep,
)
};
*v + perturbation
@@ -2330,7 +2350,6 @@ mod native_tests {

let model_path = format!("{}/{}/network.compiled", test_dir, example_name);
let settings_path = format!("{}/{}/settings.json", test_dir, example_name);

init_params(settings_path.clone().into());

let data_path = format!("{}/{}/input.json", test_dir, example_name);
@@ -2342,62 +2361,6 @@ mod native_tests {
let test_input_source = format!("--input-source={}", input_source);
let test_output_source = format!("--output-source={}", output_source);

// load witness
let witness: GraphWitness = GraphWitness::from_path(witness_path.clone().into()).unwrap();
let mut input: GraphData = GraphData::from_path(data_path.clone().into()).unwrap();

if input_visibility == "hashed" {
let hashes = witness.processed_inputs.unwrap().poseidon_hash.unwrap();
input.input_data = DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
);
}
if output_visibility == "hashed" {
let hashes = witness.processed_outputs.unwrap().poseidon_hash.unwrap();
input.output_data = Some(DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
));
} else {
input.output_data = Some(DataSource::File(
witness
.pretty_elements
.unwrap()
.rescaled_outputs
.iter()
.map(|o| {
o.iter()
.map(|f| FileSourceInner::Float(f.parse().unwrap()))
.collect()
})
.collect(),
));
}

input.save(data_path.clone().into()).unwrap();

let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());

let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup",
@@ -2412,6 +2375,82 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());

// generate the witness, passing the vk path to generate the necessary kzg commits
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"gen-witness",
"-D",
&data_path,
"-M",
&model_path,
"-O",
&witness_path,
"--vk-path",
&format!("{}/{}/key.vk", test_dir, example_name),
])
.status()
.expect("failed to execute process");
assert!(status.success());

// load witness
let witness: GraphWitness = GraphWitness::from_path(witness_path.clone().into()).unwrap();
// print out the witness
println!("WITNESS: {:?}", witness);
let mut input: GraphData = GraphData::from_path(data_path.clone().into()).unwrap();
if input_source != "file" || output_source != "file" {
println!("on chain input");
if input_visibility == "hashed" {
let hashes = witness.processed_inputs.unwrap().poseidon_hash.unwrap();
input.input_data = DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
);
}
if output_visibility == "hashed" {
let hashes = witness.processed_outputs.unwrap().poseidon_hash.unwrap();
input.output_data = Some(DataSource::File(
hashes
.iter()
.map(|h| vec![FileSourceInner::Field(*h)])
.collect(),
));
} else {
input.output_data = Some(DataSource::File(
witness
.pretty_elements
.unwrap()
.rescaled_outputs
.iter()
.map(|o| {
o.iter()
.map(|f| FileSourceInner::Float(f.parse().unwrap()))
.collect()
})
.collect(),
));
}
input.save(data_path.clone().into()).unwrap();

let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}

let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"prove",
@@ -2502,13 +2541,19 @@ mod native_tests {
.expect("failed to execute process");
assert!(status.success());

let deploy_evm_data_path = if input_source != "file" || output_source != "file" {
test_on_chain_data_path.clone()
} else {
data_path.clone()
};

let addr_path_da_arg = format!("--addr-path={}/{}/addr_da.txt", test_dir, example_name);
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"deploy-evm-da",
format!("--settings-path={}", settings_path).as_str(),
"-D",
test_on_chain_data_path.as_str(),
deploy_evm_data_path.as_str(),
"--sol-code-path",
sol_arg.as_str(),
rpc_arg.as_str(),
@@ -2546,40 +2591,42 @@ mod native_tests {
.status()
.expect("failed to execute process");
assert!(status.success());
// Create a new set of test on chain data
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
// Create a new set of test on chain data only for the on-chain input source
if input_source != "file" || output_source != "file" {
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"setup-test-evm-data",
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");

assert!(status.success());

let deployed_addr_arg = format!("--addr={}", addr_da);

let args: Vec<&str> = vec![
"test-update-account-calls",
deployed_addr_arg.as_str(),
"-D",
data_path.as_str(),
"-M",
&model_path,
"--test-data",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
test_input_source.as_str(),
test_output_source.as_str(),
])
.status()
.expect("failed to execute process");
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args(&args)
.status()
.expect("failed to execute process");

assert!(status.success());

let deployed_addr_arg = format!("--addr={}", addr_da);

let args = vec![
"test-update-account-calls",
deployed_addr_arg.as_str(),
"-D",
test_on_chain_data_path.as_str(),
rpc_arg.as_str(),
];
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args(&args)
.status()
.expect("failed to execute process");

assert!(status.success());
assert!(status.success());
}
// As sanity check, add example that should fail.
let args = vec![
"verify-evm",

@@ -423,6 +423,74 @@ async def test_create_evm_verifier():
    assert res == True
    assert os.path.isfile(sol_code_path)


async def test_create_evm_verifier_separate_vk():
    """
    Create an EVM verifier with solidity code and a separate vk
    In order to run this test you will need to install solc in your environment
    """
    vk_path = os.path.join(folder_path, 'test_evm.vk')
    settings_path = os.path.join(folder_path, 'settings.json')
    sol_code_path = os.path.join(folder_path, 'test_separate.sol')
    vk_code_path = os.path.join(folder_path, 'test_vk.sol')
    abi_path = os.path.join(folder_path, 'test_separate.abi')
    abi_vk_path = os.path.join(folder_path, 'test_vk_separate.abi')
    proof_path = os.path.join(folder_path, 'test_evm.pf')
    calldata_path = os.path.join(folder_path, 'calldata.bytes')

    # # res is now a vector of bytes
    # res = ezkl.encode_evm_calldata(proof_path, calldata_path)

    # assert os.path.isfile(calldata_path)
    # assert len(res) > 0

    res = await ezkl.create_evm_verifier(
        vk_path,
        settings_path,
        sol_code_path,
        abi_path,
        srs_path=srs_path,
        render_vk_seperately=True
    )

    res = await ezkl.create_evm_vk(
        vk_path,
        settings_path,
        vk_code_path,
        abi_vk_path,
        srs_path=srs_path,
    )

    assert res == True
    assert os.path.isfile(sol_code_path)
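Since the docstring above requires solc, a quick sanity check on the two generated artifacts is to compile them. A minimal sketch, assuming solc is on PATH and reusing the sol_code_path / vk_code_path names from the test:

import subprocess

# Compile both rendered contracts; a non-zero solc exit (check=True)
# means the generated Solidity is malformed.
for path in (sol_code_path, vk_code_path):
    subprocess.run(["solc", "--bin", path], check=True)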


async def test_deploy_evm_separate_vk():
    """
    Test deployment of the separate verifier smart contract + vk
    In order to run this you will need to install solc in your environment
    """
    addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
    addr_path_vk = os.path.join(folder_path, 'address_vk.json')
    sol_code_path = os.path.join(folder_path, 'test_separate.sol')
    vk_code_path = os.path.join(folder_path, 'test_vk.sol')

    # TODO: without optimization there will be out of gas errors
    # sol_code_path = os.path.join(folder_path, 'test.sol')

    res = await ezkl.deploy_evm(
        addr_path_verifier,
        sol_code_path,
        rpc_url=anvil_url,
    )

    res = await ezkl.deploy_vk_evm(
        addr_path_vk,
        vk_code_path,
        rpc_url=anvil_url,
    )

    assert res == True
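deploy_evm and deploy_vk_evm write the deployed addresses to addr_path_verifier and addr_path_vk. A minimal sketch of reading them back, mirroring how test_verify_evm_separate_vk does it below; the address-shape assertion is an illustrative addition, not part of the test:

# Read back the two deployed addresses, as the verify test below does.
for path in (addr_path_verifier, addr_path_vk):
    with open(path, 'r') as f:
        addr = f.read().rstrip()
    # illustrative check: a 20-byte hex address is 42 chars with 0x prefix
    assert addr.startswith('0x') and len(addr) == 42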

async def test_deploy_evm():
    """
@@ -503,6 +571,47 @@ async def test_verify_evm():

    assert res == True


async def test_verify_evm_separate_vk():
    """
    Verifies an EVM proof
    In order to run this you will need to install solc in your environment
    """
    proof_path = os.path.join(folder_path, 'test_evm.pf')
    addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
    addr_path_vk = os.path.join(folder_path, 'address_vk.json')
    calldata_path = os.path.join(folder_path, 'calldata_separate.bytes')

    with open(addr_path_verifier, 'r') as file:
        addr_verifier = file.read().rstrip()

    print(addr_verifier)

    with open(addr_path_vk, 'r') as file:
        addr_vk = file.read().rstrip()

    print(addr_vk)

    # res is now a vector of bytes
    res = ezkl.encode_evm_calldata(proof_path, calldata_path, addr_vk=addr_vk)

    assert os.path.isfile(calldata_path)
    assert len(res) > 0

    # TODO: without optimization there will be out of gas errors
    # sol_code_path = os.path.join(folder_path, 'test.sol')

    res = await ezkl.verify_evm(
        addr_verifier,
        proof_path,
        rpc_url=anvil_url,
        addr_vk=addr_vk,
        # sol_code_path
        # optimizer_runs
    )

    assert res == True
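encode_evm_calldata writes raw calldata for the verifier contract, so in principle it can be submitted with any Ethereum client, not only through ezkl.verify_evm. A hedged sketch using web3.py (an assumption; it is not used by this test suite) against the same anvil_url:

from web3 import Web3

w3 = Web3(Web3.HTTPProvider(anvil_url))
with open(calldata_path, 'rb') as f:
    calldata = f.read()

# A failed verification reverts (web3 raises); success returns normally.
result = w3.eth.call({
    'to': Web3.to_checksum_address(addr_verifier),
    'data': calldata,
})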


async def test_aggregate_and_verify_aggr():
    data_path = os.path.join(
@@ -761,6 +870,7 @@ def get_examples():
        'accuracy',
        'linear_regression',
        "mnist_gan",
        "smallworm",
    ]
    examples = []
    for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
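The get_examples hunk is truncated at the os.walk line. For context, a minimal sketch of the filtering it presumably performs, where NOT_WORKING_EXAMPLES is a hypothetical name for the skip list above:

examples = []
for subdir, _, _ in os.walk(os.path.join(examples_path, "onnx")):
    name = os.path.basename(subdir)
    # skip the walk root and any example known not to work (the list above)
    if name and name != "onnx" and name not in NOT_WORKING_EXAMPLES:
        examples.append(name)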