mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1da7cac21a | ||
|
|
ada45a3197 | ||
|
|
616b421967 | ||
|
|
f64f0ecd23 | ||
|
|
5be12b7a54 |
4
.github/workflows/engine.yml
vendored
4
.github/workflows/engine.yml
vendored
@@ -22,7 +22,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -30,7 +30,7 @@ jobs:
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- name: Install binaryen
|
||||
run: |
|
||||
set -e
|
||||
|
||||
2
.github/workflows/large-tests.yml
vendored
2
.github/workflows/large-tests.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: nanoGPT Mock
|
||||
|
||||
4
.github/workflows/pypi.yml
vendored
4
.github/workflows/pypi.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
@@ -86,7 +86,7 @@ jobs:
|
||||
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
|
||||
|
||||
14
.github/workflows/release.yml
vendored
14
.github/workflows/release.yml
vendored
@@ -45,7 +45,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Checkout repo
|
||||
@@ -106,27 +106,27 @@ jobs:
|
||||
include:
|
||||
- build: windows-msvc
|
||||
os: windows-latest
|
||||
rust: nightly-2024-02-06
|
||||
rust: nightly-2024-07-18
|
||||
target: x86_64-pc-windows-msvc
|
||||
- build: macos
|
||||
os: macos-13
|
||||
rust: nightly-2024-02-06
|
||||
rust: nightly-2024-07-18
|
||||
target: x86_64-apple-darwin
|
||||
- build: macos-aarch64
|
||||
os: macos-13
|
||||
rust: nightly-2024-02-06
|
||||
rust: nightly-2024-07-18
|
||||
target: aarch64-apple-darwin
|
||||
- build: linux-musl
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2024-02-06
|
||||
rust: nightly-2024-07-18
|
||||
target: x86_64-unknown-linux-musl
|
||||
- build: linux-gnu
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2024-02-06
|
||||
rust: nightly-2024-07-18
|
||||
target: x86_64-unknown-linux-gnu
|
||||
- build: linux-aarch64
|
||||
os: ubuntu-22.04
|
||||
rust: nightly-2024-02-06
|
||||
rust: nightly-2024-07-18
|
||||
target: aarch64-unknown-linux-gnu
|
||||
|
||||
steps:
|
||||
|
||||
60
.github/workflows/rust.yml
vendored
60
.github/workflows/rust.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Docs
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -73,7 +73,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -106,7 +106,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -199,7 +199,7 @@ jobs:
|
||||
- name: Install wasm32-unknown-unknown
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- name: Run wasm verifier tests
|
||||
# on mac:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
@@ -212,7 +212,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -229,7 +229,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -260,6 +260,8 @@ jobs:
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
|
||||
- name: hashed params
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
|
||||
- name: hashed params public inputs
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
|
||||
- name: hashed outputs
|
||||
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
|
||||
- name: hashed inputs + params + outputs
|
||||
@@ -288,7 +290,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -306,7 +308,7 @@ jobs:
|
||||
node-version: "18.12.1"
|
||||
cache: "pnpm"
|
||||
- name: "Add rust-src"
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- name: Install dependencies for js tests and in-browser-evm-verifier package
|
||||
run: |
|
||||
pnpm install --frozen-lockfile
|
||||
@@ -328,7 +330,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: KZG prove and verify tests (EVM + VK rendered seperately)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
@@ -367,7 +369,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -375,7 +377,7 @@ jobs:
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
@@ -441,11 +443,11 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
@@ -475,7 +477,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -493,7 +495,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -510,7 +512,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -527,7 +529,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -537,7 +539,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: KZG prove and verify aggr tests
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
@@ -548,7 +550,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -568,7 +570,7 @@ jobs:
|
||||
python-version: "3.12"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Install cmake
|
||||
@@ -578,7 +580,7 @@ jobs:
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
|
||||
- name: Run pytest
|
||||
@@ -594,7 +596,7 @@ jobs:
|
||||
python-version: "3.12"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -643,7 +645,7 @@ jobs:
|
||||
python-version: "3.11"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-02-06
|
||||
toolchain: nightly-2024-07-18
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -653,7 +655,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
|
||||
- name: Install pip
|
||||
run: python -m ensurepip --upgrade
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
|
||||
50
Cargo.lock
generated
50
Cargo.lock
generated
@@ -1536,6 +1536,15 @@ dependencies = [
|
||||
"zeroize",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.3.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
|
||||
dependencies = [
|
||||
"powerfmt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "derivative"
|
||||
version = "2.2.0"
|
||||
@@ -2297,7 +2306,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.3.0"
|
||||
source = "git+https://github.com/zkonduit/halo2#8cfca221f53069a0374687654882b99e729041d7#8cfca221f53069a0374687654882b99e729041d7"
|
||||
source = "git+https://github.com/zkonduit/halo2#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"env_logger",
|
||||
@@ -2310,6 +2319,7 @@ dependencies = [
|
||||
"rand_chacha",
|
||||
"rand_core 0.6.4",
|
||||
"rustacuda",
|
||||
"rustc-hash 2.0.0",
|
||||
"sha3 0.9.1",
|
||||
"tracing",
|
||||
]
|
||||
@@ -3262,6 +3272,12 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-conv"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.46"
|
||||
@@ -3775,6 +3791,12 @@ dependencies = [
|
||||
"postgres-protocol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "powerfmt"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
@@ -4007,7 +4029,7 @@ dependencies = [
|
||||
"pin-project-lite",
|
||||
"quinn-proto",
|
||||
"quinn-udp",
|
||||
"rustc-hash",
|
||||
"rustc-hash 1.1.0",
|
||||
"rustls",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
@@ -4023,7 +4045,7 @@ dependencies = [
|
||||
"bytes",
|
||||
"rand 0.8.5",
|
||||
"ring",
|
||||
"rustc-hash",
|
||||
"rustc-hash 1.1.0",
|
||||
"rustls",
|
||||
"slab",
|
||||
"thiserror",
|
||||
@@ -4447,6 +4469,12 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
|
||||
|
||||
[[package]]
|
||||
name = "rustc-hex"
|
||||
version = "2.1.0"
|
||||
@@ -5217,11 +5245,14 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "time"
|
||||
version = "0.3.23"
|
||||
version = "0.3.36"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
|
||||
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
|
||||
dependencies = [
|
||||
"deranged",
|
||||
"itoa",
|
||||
"num-conv",
|
||||
"powerfmt",
|
||||
"serde",
|
||||
"time-core",
|
||||
"time-macros",
|
||||
@@ -5229,16 +5260,17 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "time-core"
|
||||
version = "0.1.1"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
|
||||
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
|
||||
|
||||
[[package]]
|
||||
name = "time-macros"
|
||||
version = "0.2.10"
|
||||
version = "0.2.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
|
||||
checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
|
||||
dependencies = [
|
||||
"num-conv",
|
||||
"time-core",
|
||||
]
|
||||
|
||||
|
||||
@@ -209,7 +209,7 @@ metal = ["dep:metal", "dep:objc"]
|
||||
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
|
||||
|
||||
[patch.'https://github.com/zkonduit/halo2']
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#8cfca221f53069a0374687654882b99e729041d7", package = "halo2_proofs" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd", package = "halo2_proofs" }
|
||||
|
||||
|
||||
[profile.release]
|
||||
|
||||
@@ -72,6 +72,7 @@ impl Circuit<Fr> for MyCircuit {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(0, 0)],
|
||||
stride: vec![1; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
@@ -2,6 +2,7 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
|
||||
use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::table::Range;
|
||||
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
|
||||
use ezkl::fieldutils::IntegerRep;
|
||||
use ezkl::pfsys::create_proof_circuit;
|
||||
use ezkl::pfsys::TranscriptType;
|
||||
use ezkl::pfsys::{create_keys, srs::gen_srs};
|
||||
@@ -84,7 +85,7 @@ fn runrelu(c: &mut Criterion) {
|
||||
};
|
||||
|
||||
let input: Tensor<Value<Fr>> =
|
||||
Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();
|
||||
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
|
||||
|
||||
let circuit = NLCircuit {
|
||||
input: ValTensor::from(input.clone()),
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
ezkl==0.0.0
|
||||
ezkl==12.0.2
|
||||
sphinx
|
||||
sphinx-rtd-theme
|
||||
sphinxcontrib-napoleon
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import ezkl
|
||||
|
||||
project = 'ezkl'
|
||||
release = '0.0.0'
|
||||
release = '12.0.2'
|
||||
version = release
|
||||
|
||||
|
||||
|
||||
@@ -2,8 +2,7 @@ use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::{
|
||||
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
|
||||
};
|
||||
use ezkl::fieldutils;
|
||||
use ezkl::fieldutils::i32_to_felt;
|
||||
use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::poly::commitment::Params;
|
||||
@@ -42,8 +41,8 @@ const NUM_INNER_COLS: usize = 1;
|
||||
struct Config<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -66,8 +65,8 @@ struct Config<
|
||||
struct MyCircuit<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -90,8 +89,8 @@ struct MyCircuit<
|
||||
impl<
|
||||
const LEN: usize,
|
||||
const CLASSES: usize,
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
// Convolution
|
||||
const KERNEL_HEIGHT: usize,
|
||||
const KERNEL_WIDTH: usize,
|
||||
@@ -205,6 +204,7 @@ where
|
||||
let op = PolyOp::Conv {
|
||||
padding: vec![(PADDING, PADDING); 2],
|
||||
stride: vec![STRIDE; 2],
|
||||
group: 1,
|
||||
};
|
||||
let x = config
|
||||
.layer_config
|
||||
@@ -315,7 +315,11 @@ pub fn runconv() {
|
||||
.test_set_length(10_000)
|
||||
.finalize();
|
||||
|
||||
let mut train_data = Tensor::from(trn_img.iter().map(|x| i32_to_felt::<F>(*x as i32 / 16)));
|
||||
let mut train_data = Tensor::from(
|
||||
trn_img
|
||||
.iter()
|
||||
.map(|x| integer_rep_to_felt::<F>(*x as IntegerRep / 16)),
|
||||
);
|
||||
train_data.reshape(&[50_000, 28, 28]).unwrap();
|
||||
|
||||
let mut train_labels = Tensor::from(trn_lbl.iter().map(|x| *x as f32));
|
||||
@@ -343,8 +347,8 @@ pub fn runconv() {
|
||||
.map(|fl| {
|
||||
let dx = fl * 32_f32;
|
||||
let rounded = dx.round();
|
||||
let integral: i32 = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::i32_to_felt(integral)
|
||||
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::integer_rep_to_felt(integral)
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -355,7 +359,8 @@ pub fn runconv() {
|
||||
|
||||
let l0_kernels = l0_kernels.try_into().unwrap();
|
||||
|
||||
let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
|
||||
let mut l0_bias =
|
||||
Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::integer_rep_to_felt(0)));
|
||||
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let l0_bias = l0_bias.try_into().unwrap();
|
||||
@@ -363,8 +368,8 @@ pub fn runconv() {
|
||||
let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
|
||||
let dx = fl * 32_f32;
|
||||
let rounded = dx.round();
|
||||
let integral: i32 = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::i32_to_felt(integral)
|
||||
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::integer_rep_to_felt(integral)
|
||||
}));
|
||||
l2_biases.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
|
||||
@@ -374,8 +379,8 @@ pub fn runconv() {
|
||||
let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
|
||||
let dx = fl * 32_f32;
|
||||
let rounded = dx.round();
|
||||
let integral: i32 = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::i32_to_felt(integral)
|
||||
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
|
||||
fieldutils::integer_rep_to_felt(integral)
|
||||
}));
|
||||
l2_weights.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
|
||||
@@ -401,13 +406,13 @@ pub fn runconv() {
|
||||
l2_params: [l2_weights, l2_biases],
|
||||
};
|
||||
|
||||
let public_input: Tensor<i32> = vec![
|
||||
-25124i32, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
|
||||
let public_input: Tensor<IntegerRep> = vec![
|
||||
-25124, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
|
||||
]
|
||||
.into_iter()
|
||||
.into();
|
||||
|
||||
let pi_inner: Tensor<F> = public_input.map(i32_to_felt::<F>);
|
||||
let pi_inner: Tensor<F> = public_input.map(integer_rep_to_felt::<F>);
|
||||
|
||||
println!("MOCK PROVING");
|
||||
let now = Instant::now();
|
||||
|
||||
@@ -2,7 +2,7 @@ use ezkl::circuit::region::RegionCtx;
|
||||
use ezkl::circuit::{
|
||||
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
|
||||
};
|
||||
use ezkl::fieldutils::i32_to_felt;
|
||||
use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
|
||||
use ezkl::tensor::*;
|
||||
use halo2_proofs::dev::MockProver;
|
||||
use halo2_proofs::{
|
||||
@@ -23,8 +23,8 @@ struct MyConfig {
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<
|
||||
const LEN: usize, //LEN = CHOUT x OH x OW flattened
|
||||
const LOOKUP_MIN: i64,
|
||||
const LOOKUP_MAX: i64,
|
||||
const LOOKUP_MIN: IntegerRep,
|
||||
const LOOKUP_MAX: IntegerRep,
|
||||
> {
|
||||
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
|
||||
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
|
||||
@@ -34,7 +34,7 @@ struct MyCircuit<
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
|
||||
impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRep> Circuit<F>
|
||||
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
|
||||
{
|
||||
type Config = MyConfig;
|
||||
@@ -215,33 +215,33 @@ pub fn runmlp() {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
env_logger::init();
|
||||
// parameters
|
||||
let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
|
||||
let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
|
||||
Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
|
||||
&[4, 4],
|
||||
)
|
||||
.unwrap()
|
||||
.map(i32_to_felt);
|
||||
.map(integer_rep_to_felt);
|
||||
l0_kernel.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
let mut l0_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
.unwrap()
|
||||
.map(i32_to_felt);
|
||||
.map(integer_rep_to_felt);
|
||||
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
|
||||
let mut l2_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
|
||||
Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
|
||||
&[4, 4],
|
||||
)
|
||||
.unwrap()
|
||||
.map(i32_to_felt);
|
||||
.map(integer_rep_to_felt);
|
||||
l2_kernel.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
// input data, with 1 padding to allow for bias
|
||||
let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
|
||||
let input: Tensor<Value<F>> = Tensor::<IntegerRep>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
|
||||
.unwrap()
|
||||
.into();
|
||||
let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
let mut l2_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
|
||||
.unwrap()
|
||||
.map(i32_to_felt);
|
||||
.map(integer_rep_to_felt);
|
||||
l2_bias.set_visibility(&ezkl::graph::Visibility::Private);
|
||||
|
||||
let circuit = MyCircuit::<4, -8192, 8192> {
|
||||
@@ -251,12 +251,12 @@ pub fn runmlp() {
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let public_input: Vec<i32> = unsafe {
|
||||
let public_input: Vec<IntegerRep> = unsafe {
|
||||
vec![
|
||||
(531f32 / 128f32).round().to_int_unchecked::<i32>(),
|
||||
(103f32 / 128f32).round().to_int_unchecked::<i32>(),
|
||||
(4469f32 / 128f32).round().to_int_unchecked::<i32>(),
|
||||
(2849f32 / 128f32).to_int_unchecked::<i32>(),
|
||||
(531f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
|
||||
(103f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
|
||||
(4469f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
|
||||
(2849f32 / 128f32).to_int_unchecked::<IntegerRep>(),
|
||||
]
|
||||
};
|
||||
|
||||
@@ -265,7 +265,10 @@ pub fn runmlp() {
|
||||
let prover = MockProver::run(
|
||||
K as u32,
|
||||
&circuit,
|
||||
vec![public_input.iter().map(|x| i32_to_felt::<F>(*x)).collect()],
|
||||
vec![public_input
|
||||
.iter()
|
||||
.map(|x| integer_rep_to_felt::<F>(*x))
|
||||
.collect()],
|
||||
)
|
||||
.unwrap();
|
||||
prover.assert_satisfied();
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
[toolchain]
|
||||
channel = "nightly-2024-02-06"
|
||||
channel = "nightly-2024-07-18"
|
||||
components = ["rustfmt", "clippy"]
|
||||
|
||||
@@ -22,7 +22,7 @@ use crate::{
|
||||
table::{Range, RangeCheck, Table},
|
||||
utils,
|
||||
},
|
||||
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
|
||||
tensor::{Tensor, TensorType, ValTensor, VarTensor},
|
||||
};
|
||||
use std::{collections::BTreeMap, marker::PhantomData};
|
||||
|
||||
@@ -327,7 +327,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
|
||||
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
|
||||
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
|
||||
Self {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::convert::Infallible;
|
||||
|
||||
use crate::tensor::TensorError;
|
||||
use crate::{fieldutils::IntegerRep, tensor::TensorError};
|
||||
use halo2_proofs::plonk::Error as PlonkError;
|
||||
use thiserror::Error;
|
||||
|
||||
@@ -57,7 +57,7 @@ pub enum CircuitError {
|
||||
InvalidConversion(#[from] Infallible),
|
||||
/// Invalid min/max lookup range
|
||||
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
|
||||
InvalidMinMaxRange(i64, i64),
|
||||
InvalidMinMaxRange(IntegerRep, IntegerRep),
|
||||
/// Missing product in einsum
|
||||
#[error("missing product in einsum")]
|
||||
MissingEinsumProduct,
|
||||
@@ -81,7 +81,7 @@ pub enum CircuitError {
|
||||
MissingSelectors(String),
|
||||
/// Table lookup error
|
||||
#[error("value ({0}) out of range: ({1}, {2})")]
|
||||
TableOOR(i64, i64, i64),
|
||||
TableOOR(IntegerRep, IntegerRep, IntegerRep),
|
||||
/// Loookup not configured
|
||||
#[error("lookup not configured: {0}")]
|
||||
LookupNotConfigured(String),
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use crate::{
|
||||
circuit::{layouts, utils, Tolerance},
|
||||
fieldutils::i64_to_felt,
|
||||
fieldutils::integer_rep_to_felt,
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
@@ -71,7 +71,7 @@ pub enum HybridOp {
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
|
||||
///
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
@@ -184,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i64_to_felt(input_scale.0 as i64),
|
||||
i64_to_felt(output_scale.0 as i64),
|
||||
integer_rep_to_felt(input_scale.0 as i128),
|
||||
integer_rep_to_felt(output_scale.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
i64_to_felt(denom.0 as i64),
|
||||
integer_rep_to_felt(denom.0 as i128),
|
||||
)?
|
||||
} else {
|
||||
layouts::nonlinearity(
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
circuit::{layouts, table::Range, utils},
|
||||
fieldutils::{felt_to_i64, i64_to_felt},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
|
||||
graph::multiplier_to_scale,
|
||||
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
|
||||
tensor::{self, Tensor, TensorError, TensorType},
|
||||
};
|
||||
|
||||
use super::Op;
|
||||
@@ -131,16 +131,16 @@ impl LookupOp {
|
||||
/// Returns the range of values that can be represented by the table
|
||||
pub fn bit_range(max_len: usize) -> Range {
|
||||
let range = (max_len - 1) as f64 / 2_f64;
|
||||
let range = range as i64;
|
||||
let range = range as IntegerRep;
|
||||
(-range, range)
|
||||
}
|
||||
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
|
||||
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
|
||||
&self,
|
||||
x: &[Tensor<F>],
|
||||
) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = x[0].clone().map(|x| felt_to_i64(x));
|
||||
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
|
||||
let res = match &self {
|
||||
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
|
||||
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
|
||||
@@ -227,13 +227,13 @@ impl LookupOp {
|
||||
}
|
||||
}?;
|
||||
|
||||
let output = res.map(|x| i64_to_felt(x));
|
||||
let output = res.map(|x| integer_rep_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
|
||||
/// Returns a reference to the Any trait.
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
|
||||
@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
graph::quantize_tensor,
|
||||
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
|
||||
tensor::{self, Tensor, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
@@ -31,12 +31,12 @@ pub use errors::CircuitError;
|
||||
|
||||
/// A struct representing the result of a forward pass.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
pub(crate) output: Tensor<F>,
|
||||
}
|
||||
|
||||
/// A trait representing operations that can be represented as constraints in a circuit.
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
|
||||
std::fmt::Debug + Send + Sync + Any
|
||||
{
|
||||
/// Returns a string representation of the operation.
|
||||
@@ -75,7 +75,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64
|
||||
fn as_any(&self) -> &dyn Any;
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
|
||||
fn clone(&self) -> Self {
|
||||
self.clone_dyn()
|
||||
}
|
||||
@@ -142,7 +142,7 @@ pub struct Input {
|
||||
pub datum_type: InputType,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(self.scale)
|
||||
}
|
||||
@@ -197,7 +197,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Unknown;
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
|
||||
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
|
||||
Ok(0)
|
||||
}
|
||||
@@ -224,7 +224,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F>
|
||||
|
||||
///
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
|
||||
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
///
|
||||
pub quantized_values: Tensor<F>,
|
||||
///
|
||||
@@ -234,7 +234,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash +
|
||||
pub pre_assigned_val: Option<ValTensor<F>>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
|
||||
///
|
||||
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
|
||||
Self {
|
||||
@@ -267,8 +267,7 @@ impl<
|
||||
+ PartialOrd
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>
|
||||
+ IntoI64,
|
||||
+ for<'de> Deserialize<'de>,
|
||||
> Op<F> for Constant<F>
|
||||
{
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
|
||||
@@ -33,6 +33,7 @@ pub enum PolyOp {
|
||||
Conv {
|
||||
padding: Vec<(usize, usize)>,
|
||||
stride: Vec<usize>,
|
||||
group: usize,
|
||||
},
|
||||
Downsample {
|
||||
axis: usize,
|
||||
@@ -43,6 +44,7 @@ pub enum PolyOp {
|
||||
padding: Vec<(usize, usize)>,
|
||||
output_padding: Vec<usize>,
|
||||
stride: Vec<usize>,
|
||||
group: usize,
|
||||
},
|
||||
Add,
|
||||
Sub,
|
||||
@@ -98,7 +100,7 @@ impl<
|
||||
+ std::hash::Hash
|
||||
+ Serialize
|
||||
+ for<'de> Deserialize<'de>
|
||||
+ IntoI64,
|
||||
,
|
||||
> Op<F> for PolyOp
|
||||
{
|
||||
/// Returns a reference to the Any trait.
|
||||
@@ -148,17 +150,25 @@ impl<
|
||||
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
|
||||
PolyOp::Prod { .. } => "PROD".into(),
|
||||
PolyOp::Pow(_) => "POW".into(),
|
||||
PolyOp::Conv { stride, padding } => {
|
||||
format!("CONV (stride={:?}, padding={:?})", stride, padding)
|
||||
PolyOp::Conv {
|
||||
stride,
|
||||
padding,
|
||||
group,
|
||||
} => {
|
||||
format!(
|
||||
"CONV (stride={:?}, padding={:?}, group={})",
|
||||
stride, padding, group
|
||||
)
|
||||
}
|
||||
PolyOp::DeConv {
|
||||
stride,
|
||||
padding,
|
||||
output_padding,
|
||||
group,
|
||||
} => {
|
||||
format!(
|
||||
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
|
||||
stride, padding, output_padding
|
||||
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
|
||||
stride, padding, output_padding, group
|
||||
)
|
||||
}
|
||||
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
|
||||
@@ -212,9 +222,18 @@ impl<
|
||||
PolyOp::Prod { axes, .. } => {
|
||||
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
|
||||
}
|
||||
PolyOp::Conv { padding, stride } => {
|
||||
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
|
||||
}
|
||||
PolyOp::Conv {
|
||||
padding,
|
||||
stride,
|
||||
group,
|
||||
} => layouts::conv(
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
padding,
|
||||
stride,
|
||||
*group,
|
||||
)?,
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
@@ -261,6 +280,7 @@ impl<
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
group,
|
||||
} => layouts::deconv(
|
||||
config,
|
||||
region,
|
||||
@@ -268,6 +288,7 @@ impl<
|
||||
padding,
|
||||
output_padding,
|
||||
stride,
|
||||
*group,
|
||||
)?,
|
||||
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
|
||||
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use crate::{
|
||||
circuit::table::Range,
|
||||
fieldutils::IntegerRep,
|
||||
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
|
||||
};
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -11,7 +12,6 @@ use halo2_proofs::{
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use maybe_rayon::iter::ParallelExtend;
|
||||
use portable_atomic::AtomicI64 as AtomicInt;
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
collections::{HashMap, HashSet},
|
||||
@@ -86,6 +86,78 @@ impl ShuffleIndex {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
/// Some settings for a region to differentiate it across the different phases of proof generation
|
||||
pub struct RegionSettings {
|
||||
/// whether we are in witness generation mode
|
||||
pub witness_gen: bool,
|
||||
/// whether we should check range checks for validity
|
||||
pub check_range: bool,
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Sync for RegionSettings {}
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Send for RegionSettings {}
|
||||
|
||||
impl RegionSettings {
|
||||
/// Create a new region settings
|
||||
pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen,
|
||||
check_range,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region settings with all true
|
||||
pub fn all_true() -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen: true,
|
||||
check_range: true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region settings with all false
|
||||
pub fn all_false() -> RegionSettings {
|
||||
RegionSettings {
|
||||
witness_gen: false,
|
||||
check_range: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
/// Region statistics
|
||||
pub struct RegionStatistics {
|
||||
/// the current maximum value of the lookup inputs
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
/// the current minimum value of the lookup inputs
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// the current maximum value of the range size
|
||||
pub max_range_size: IntegerRep,
|
||||
/// the current set of used lookups
|
||||
pub used_lookups: HashSet<LookupOp>,
|
||||
/// the current set of used range checks
|
||||
pub used_range_checks: HashSet<Range>,
|
||||
}
|
||||
|
||||
impl RegionStatistics {
|
||||
/// update the statistics with another set of statistics
|
||||
pub fn update(&mut self, other: &RegionStatistics) {
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(other.max_lookup_inputs);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(other.min_lookup_inputs);
|
||||
self.max_range_size = self.max_range_size.max(other.max_range_size);
|
||||
self.used_lookups.extend(other.used_lookups.clone());
|
||||
self.used_range_checks
|
||||
.extend(other.used_range_checks.clone());
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Sync for RegionStatistics {}
|
||||
#[allow(unsafe_code)]
|
||||
unsafe impl Send for RegionStatistics {}
|
||||
|
||||
#[derive(Debug)]
|
||||
/// A context for a region
|
||||
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
|
||||
@@ -95,13 +167,8 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
|
||||
num_inner_cols: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
shuffle_index: ShuffleIndex,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
max_lookup_inputs: i64,
|
||||
min_lookup_inputs: i64,
|
||||
max_range_size: i64,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
statistics: RegionStatistics,
|
||||
settings: RegionSettings,
|
||||
assigned_constants: ConstantsMap<F>,
|
||||
}
|
||||
|
||||
@@ -153,12 +220,17 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
///
|
||||
pub fn witness_gen(&self) -> bool {
|
||||
self.witness_gen
|
||||
self.settings.witness_gen
|
||||
}
|
||||
|
||||
///
|
||||
pub fn check_lookup_range(&self) -> bool {
|
||||
self.check_lookup_range
|
||||
pub fn check_range(&self) -> bool {
|
||||
self.settings.check_range
|
||||
}
|
||||
|
||||
///
|
||||
pub fn statistics(&self) -> &RegionStatistics {
|
||||
&self.statistics
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
@@ -173,13 +245,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
linear_coord,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen: true,
|
||||
check_lookup_range: true,
|
||||
statistics: RegionStatistics::default(),
|
||||
settings: RegionSettings::all_true(),
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -195,39 +262,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
new_self.assigned_constants = constants;
|
||||
new_self
|
||||
}
|
||||
/// Create a new region context from a wrapped region
|
||||
pub fn from_wrapped_region(
|
||||
region: Option<RefCell<Region<'a, F>>>,
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
dynamic_lookup_index: DynamicLookupIndex,
|
||||
shuffle_index: ShuffleIndex,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let linear_coord = row * num_inner_cols;
|
||||
RegionCtx {
|
||||
region,
|
||||
num_inner_cols,
|
||||
linear_coord,
|
||||
row,
|
||||
dynamic_lookup_index,
|
||||
shuffle_index,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen: false,
|
||||
check_lookup_range: false,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new region context
|
||||
pub fn new_dummy(
|
||||
row: usize,
|
||||
num_inner_cols: usize,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
settings: RegionSettings,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
let linear_coord = row * num_inner_cols;
|
||||
@@ -239,13 +279,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen,
|
||||
check_lookup_range,
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -255,8 +290,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row: usize,
|
||||
linear_coord: usize,
|
||||
num_inner_cols: usize,
|
||||
witness_gen: bool,
|
||||
check_lookup_range: bool,
|
||||
settings: RegionSettings,
|
||||
) -> RegionCtx<'a, F> {
|
||||
let region = None;
|
||||
RegionCtx {
|
||||
@@ -266,13 +300,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
row,
|
||||
dynamic_lookup_index: DynamicLookupIndex::default(),
|
||||
shuffle_index: ShuffleIndex::default(),
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
max_range_size: 0,
|
||||
witness_gen,
|
||||
check_lookup_range,
|
||||
statistics: RegionStatistics::default(),
|
||||
settings,
|
||||
assigned_constants: HashMap::new(),
|
||||
}
|
||||
}
|
||||
@@ -323,12 +352,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
) -> Result<(), CircuitError> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
|
||||
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
|
||||
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
|
||||
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
|
||||
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
|
||||
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
|
||||
|
||||
*output = output.par_enum_map(|idx, _| {
|
||||
@@ -342,8 +368,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
starting_offset,
|
||||
starting_linear_coord,
|
||||
self.num_inner_cols,
|
||||
self.witness_gen,
|
||||
self.check_lookup_range,
|
||||
self.settings.clone(),
|
||||
);
|
||||
let res = inner_loop_function(idx, &mut local_reg);
|
||||
// we update the offset and constants
|
||||
@@ -353,14 +378,9 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
|
||||
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
|
||||
// update the lookups
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
// update the range checks
|
||||
let mut range_checks = range_checks.lock().unwrap();
|
||||
range_checks.extend(local_reg.used_range_checks());
|
||||
let mut statistics = statistics.lock().unwrap();
|
||||
statistics.update(local_reg.statistics());
|
||||
// update the dynamic lookup index
|
||||
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
|
||||
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
|
||||
@@ -374,20 +394,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
res
|
||||
})?;
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
self.max_lookup_inputs = max_lookup_inputs.into_inner();
|
||||
self.min_lookup_inputs = min_lookup_inputs.into_inner();
|
||||
}
|
||||
self.row = row.into_inner();
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
self.statistics = Arc::try_unwrap(statistics)
|
||||
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
|
||||
self.used_range_checks = Arc::try_unwrap(range_checks)
|
||||
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?;
|
||||
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
|
||||
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
|
||||
.into_inner()
|
||||
@@ -411,11 +422,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
) -> Result<(), CircuitError> {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in inputs {
|
||||
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
|
||||
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
|
||||
}
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
|
||||
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
|
||||
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -427,7 +438,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
let range_size = (range.1 - range.0).abs();
|
||||
|
||||
self.max_range_size = self.max_range_size.max(range_size);
|
||||
self.statistics.max_range_size = self.statistics.max_range_size.max(range_size);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -442,13 +453,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
lookup: LookupOp,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), CircuitError> {
|
||||
self.used_lookups.insert(lookup);
|
||||
self.statistics.used_lookups.insert(lookup);
|
||||
self.update_max_min_lookup_inputs(inputs)
|
||||
}
|
||||
|
||||
/// add used range check
|
||||
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
|
||||
self.used_range_checks.insert(range);
|
||||
self.statistics.used_range_checks.insert(range);
|
||||
self.update_max_min_lookup_range(range)
|
||||
}
|
||||
|
||||
@@ -489,27 +500,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
|
||||
|
||||
/// get used lookups
|
||||
pub fn used_lookups(&self) -> HashSet<LookupOp> {
|
||||
self.used_lookups.clone()
|
||||
self.statistics.used_lookups.clone()
|
||||
}
|
||||
|
||||
/// get used range checks
|
||||
pub fn used_range_checks(&self) -> HashSet<Range> {
|
||||
self.used_range_checks.clone()
|
||||
self.statistics.used_range_checks.clone()
|
||||
}
|
||||
|
||||
/// max lookup inputs
|
||||
pub fn max_lookup_inputs(&self) -> i64 {
|
||||
self.max_lookup_inputs
|
||||
pub fn max_lookup_inputs(&self) -> IntegerRep {
|
||||
self.statistics.max_lookup_inputs
|
||||
}
|
||||
|
||||
/// min lookup inputs
|
||||
pub fn min_lookup_inputs(&self) -> i64 {
|
||||
self.min_lookup_inputs
|
||||
pub fn min_lookup_inputs(&self) -> IntegerRep {
|
||||
self.statistics.min_lookup_inputs
|
||||
}
|
||||
|
||||
/// max range check
|
||||
pub fn max_range_size(&self) -> i64 {
|
||||
self.max_range_size
|
||||
pub fn max_range_size(&self) -> IntegerRep {
|
||||
self.statistics.max_range_size
|
||||
}
|
||||
|
||||
/// Assign a valtensor to a vartensor
|
||||
|
||||
@@ -11,17 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
|
||||
|
||||
use crate::{
|
||||
circuit::CircuitError,
|
||||
fieldutils::i64_to_felt,
|
||||
tensor::{IntoI64, Tensor, TensorType},
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
tensor::{Tensor, TensorType},
|
||||
};
|
||||
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
|
||||
/// The range of the lookup table.
|
||||
pub type Range = (i64, i64);
|
||||
pub type Range = (IntegerRep, IntegerRep);
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i64 = 2;
|
||||
pub const RANGE_MULTIPLIER: IntegerRep = 2;
|
||||
/// The safety factor offset for the number of rows in the lookup table.
|
||||
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
|
||||
|
||||
@@ -96,21 +96,22 @@ pub struct Table<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
let chunk =
|
||||
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
|
||||
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
|
||||
/ (self.col_size as IntegerRep);
|
||||
|
||||
i64_to_felt(chunk)
|
||||
integer_rep_to_felt(chunk)
|
||||
}
|
||||
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
|
||||
let chunk = chunk as i64;
|
||||
let chunk = chunk as IntegerRep;
|
||||
// we index from 1 to prevent soundness issues
|
||||
let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
|
||||
let first_element =
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0);
|
||||
let op_f = self
|
||||
.nonlinearity
|
||||
.f(&[Tensor::from(vec![first_element].into_iter())])
|
||||
@@ -130,12 +131,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
|
||||
}
|
||||
|
||||
///
|
||||
pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
|
||||
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
|
||||
// number of cols needed to store the range
|
||||
(range_len / (col_size as i64)) as usize + 1
|
||||
(range_len / (col_size as IntegerRep)) as usize + 1
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(
|
||||
cs: &mut ConstraintSystem<F>,
|
||||
@@ -202,7 +203,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
|
||||
let evals = self.nonlinearity.f(&[inputs.clone()])?;
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
@@ -272,12 +273,12 @@ pub struct RangeCheck<F: PrimeField> {
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
/// get first_element of column
|
||||
pub fn get_first_element(&self, chunk: usize) -> F {
|
||||
let chunk = chunk as i64;
|
||||
let chunk = chunk as IntegerRep;
|
||||
// we index from 1 to prevent soundness issues
|
||||
i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
|
||||
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
|
||||
}
|
||||
|
||||
///
|
||||
@@ -293,14 +294,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
|
||||
/// get column index given input
|
||||
pub fn get_col_index(&self, input: F) -> F {
|
||||
// range is split up into chunks of size col_size, find the chunk that input is in
|
||||
let chunk =
|
||||
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
|
||||
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
|
||||
/ (self.col_size as IntegerRep);
|
||||
|
||||
i64_to_felt(chunk)
|
||||
integer_rep_to_felt(chunk)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
|
||||
/// Configures the table.
|
||||
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
|
||||
log::debug!("range check range: {:?}", range);
|
||||
@@ -350,7 +351,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeC
|
||||
let smallest = self.range.0;
|
||||
let largest = self.range.1;
|
||||
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
|
||||
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
|
||||
let chunked_inputs = inputs.chunks(self.col_size);
|
||||
|
||||
self.is_assigned = true;
|
||||
|
||||
@@ -1050,6 +1050,7 @@ mod conv {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1200,6 +1201,7 @@ mod conv_col_ultra_overflow {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
@@ -1345,6 +1347,7 @@ mod conv_relu_col_ultra_overflow {
|
||||
Box::new(PolyOp::Conv {
|
||||
padding: vec![(1, 1); 2],
|
||||
stride: vec![2; 2],
|
||||
group: 1,
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis);
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use crate::circuit::region::RegionSettings;
|
||||
use crate::circuit::CheckMode;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::commands::CalibrationTarget;
|
||||
@@ -785,6 +786,8 @@ pub(crate) async fn gen_witness(
|
||||
|
||||
let commitment: Commitments = settings.run_args.commitment.into();
|
||||
|
||||
let region_settings = RegionSettings::all_true();
|
||||
|
||||
let start_time = Instant::now();
|
||||
let witness = if settings.module_requires_polycommit() {
|
||||
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
|
||||
@@ -799,8 +802,7 @@ pub(crate) async fn gen_witness(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
Some(&srs),
|
||||
true,
|
||||
true,
|
||||
region_settings,
|
||||
)?
|
||||
}
|
||||
Commitments::IPA => {
|
||||
@@ -814,8 +816,7 @@ pub(crate) async fn gen_witness(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
Some(&srs),
|
||||
true,
|
||||
true,
|
||||
region_settings,
|
||||
)?
|
||||
}
|
||||
}
|
||||
@@ -825,12 +826,16 @@ pub(crate) async fn gen_witness(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
None,
|
||||
true,
|
||||
true,
|
||||
region_settings,
|
||||
)?
|
||||
}
|
||||
} else {
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true, true)?
|
||||
circuit.forward::<KZGCommitmentScheme<Bn256>>(
|
||||
&mut input,
|
||||
vk.as_ref(),
|
||||
None,
|
||||
region_settings,
|
||||
)?
|
||||
};
|
||||
|
||||
// print each variable tuple (symbol, value) as symbol=value
|
||||
@@ -1023,6 +1028,8 @@ pub(crate) async fn calibrate(
|
||||
use std::collections::HashMap;
|
||||
use tabled::Table;
|
||||
|
||||
use crate::fieldutils::IntegerRep;
|
||||
|
||||
let data = GraphData::from_path(data)?;
|
||||
// load the pre-generated settings
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
@@ -1131,7 +1138,7 @@ pub(crate) async fn calibrate(
|
||||
param_scale,
|
||||
scale_rebase_multiplier,
|
||||
div_rebasing,
|
||||
lookup_range: (i64::MIN, i64::MAX),
|
||||
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
|
||||
..settings.run_args.clone()
|
||||
};
|
||||
|
||||
@@ -1171,8 +1178,7 @@ pub(crate) async fn calibrate(
|
||||
&mut data.clone(),
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
false,
|
||||
RegionSettings::all_true(),
|
||||
)
|
||||
.map_err(|e| format!("failed to forward: {}", e))?;
|
||||
|
||||
|
||||
@@ -2,17 +2,11 @@ use halo2_proofs::arithmetic::Field;
|
||||
/// Utilities for converting from Halo2 PrimeField types to integers (and vice-versa).
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
/// Converts an i32 to a PrimeField element.
|
||||
pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
|
||||
if x >= 0 {
|
||||
F::from(x as u64)
|
||||
} else {
|
||||
-F::from(x.unsigned_abs() as u64)
|
||||
}
|
||||
}
|
||||
/// Integer representation of a PrimeField element.
|
||||
pub type IntegerRep = i128;
|
||||
|
||||
/// Converts an i64 to a PrimeField element.
|
||||
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
|
||||
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
|
||||
if x >= 0 {
|
||||
F::from_u128(x as u128)
|
||||
} else {
|
||||
@@ -20,24 +14,9 @@ pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an i32.
|
||||
pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
|
||||
if x > F::from(i32::MAX as u64) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_32 = u32::from_le_bytes(negtmp[..4].try_into().unwrap());
|
||||
-(lower_32 as i32)
|
||||
} else {
|
||||
let rep = (x).to_repr();
|
||||
let tmp: &[u8] = rep.as_ref();
|
||||
let lower_32 = u32::from_le_bytes(tmp[..4].try_into().unwrap());
|
||||
lower_32 as i32
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an f64.
|
||||
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
if x > F::from_u128(i64::MAX as u128) {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
@@ -51,17 +30,17 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
|
||||
}
|
||||
|
||||
/// Converts a PrimeField element to an i64.
|
||||
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
|
||||
if x > F::from_u128(i64::MAX as u128) {
|
||||
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
|
||||
if x > F::from_u128(IntegerRep::MAX as u128) {
|
||||
let rep = (-x).to_repr();
|
||||
let negtmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
|
||||
-(lower_128 as i64)
|
||||
-(lower_128 as IntegerRep)
|
||||
} else {
|
||||
let rep = (x).to_repr();
|
||||
let tmp: &[u8] = rep.as_ref();
|
||||
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
|
||||
lower_128 as i64
|
||||
lower_128 as IntegerRep
|
||||
}
|
||||
}
|
||||
|
||||
@@ -73,33 +52,24 @@ mod test {
|
||||
|
||||
#[test]
|
||||
fn test_conv() {
|
||||
let res: F = i32_to_felt(-15i32);
|
||||
let res: F = integer_rep_to_felt(-15);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
let res: F = i32_to_felt(2_i32.pow(17));
|
||||
let res: F = integer_rep_to_felt(2_i128.pow(17));
|
||||
assert_eq!(res, F::from(131072));
|
||||
|
||||
let res: F = i64_to_felt(-15i64);
|
||||
let res: F = integer_rep_to_felt(-15);
|
||||
assert_eq!(res, -F::from(15));
|
||||
|
||||
let res: F = i64_to_felt(2_i64.pow(17));
|
||||
let res: F = integer_rep_to_felt(2_i128.pow(17));
|
||||
assert_eq!(res, F::from(131072));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttoi32() {
|
||||
for x in -(2i32.pow(16))..(2i32.pow(16)) {
|
||||
let fieldx: F = i32_to_felt::<F>(x);
|
||||
let xf: i32 = felt_to_i32::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn felttoi64() {
|
||||
for x in -(2i64.pow(20))..(2i64.pow(20)) {
|
||||
let fieldx: F = i64_to_felt::<F>(x);
|
||||
let xf: i64 = felt_to_i64::<F>(fieldx);
|
||||
fn felttointegerrep() {
|
||||
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
|
||||
let fieldx: F = integer_rep_to_felt::<F>(x);
|
||||
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
|
||||
assert_eq!(x, xf);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use super::errors::GraphError;
|
||||
use super::quantize_float;
|
||||
use crate::circuit::InputType;
|
||||
use crate::fieldutils::i64_to_felt;
|
||||
use crate::fieldutils::integer_rep_to_felt;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::graph::postgres::Client;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
@@ -128,7 +128,7 @@ impl FileSourceInner {
|
||||
/// Convert to a field element
|
||||
pub fn to_field(&self, scale: crate::Scale) -> Fp {
|
||||
match self {
|
||||
FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
||||
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
|
||||
FileSourceInner::Bool(f) => {
|
||||
if *f {
|
||||
Fp::one()
|
||||
@@ -150,7 +150,7 @@ impl FileSourceInner {
|
||||
0.0
|
||||
}
|
||||
}
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
|
||||
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,10 +34,10 @@ use self::input::{FileSource, GraphData};
|
||||
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::modules::ModulePlanner;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::region::{ConstantsMap, RegionSettings};
|
||||
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
|
||||
use crate::circuit::{CheckMode, InputType};
|
||||
use crate::fieldutils::felt_to_f64;
|
||||
use crate::fieldutils::{felt_to_f64, IntegerRep};
|
||||
use crate::pfsys::PrettyElements;
|
||||
use crate::tensor::{Tensor, ValTensor};
|
||||
use crate::{RunArgs, EZKL_BUF_CAPACITY};
|
||||
@@ -69,13 +69,14 @@ pub use vars::*;
|
||||
use crate::pfsys::field_to_string;
|
||||
|
||||
/// The safety factor for the range of the lookup table.
|
||||
pub const RANGE_MULTIPLIER: i64 = 2;
|
||||
pub const RANGE_MULTIPLIER: IntegerRep = 2;
|
||||
|
||||
/// The maximum number of columns in a lookup table.
|
||||
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
|
||||
|
||||
/// Max representation of a lookup table input
|
||||
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
|
||||
pub const MAX_LOOKUP_ABS: IntegerRep =
|
||||
(MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS);
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
lazy_static! {
|
||||
@@ -126,11 +127,11 @@ pub struct GraphWitness {
|
||||
/// Any hashes of outputs generated during the forward pass
|
||||
pub processed_outputs: Option<ModuleForwardResult>,
|
||||
/// max lookup input
|
||||
pub max_lookup_inputs: i64,
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
/// max lookup input
|
||||
pub min_lookup_inputs: i64,
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// max range check size
|
||||
pub max_range_size: i64,
|
||||
pub max_range_size: IntegerRep,
|
||||
}
|
||||
|
||||
impl GraphWitness {
|
||||
@@ -804,18 +805,26 @@ impl GraphCircuit {
|
||||
// the ordering here is important, we want the inputs to come before the outputs
|
||||
// as they are configured in that order as Column<Instances>
|
||||
let mut public_inputs: Vec<Fp> = vec![];
|
||||
if self.settings().run_args.input_visibility.is_public() {
|
||||
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
|
||||
} else if let Some(processed_inputs) = &data.processed_inputs {
|
||||
|
||||
// we first process the inputs
|
||||
if let Some(processed_inputs) = &data.processed_inputs {
|
||||
public_inputs.extend(processed_inputs.get_instances().into_iter().flatten());
|
||||
}
|
||||
|
||||
// we then process the params
|
||||
if let Some(processed_params) = &data.processed_params {
|
||||
public_inputs.extend(processed_params.get_instances().into_iter().flatten());
|
||||
}
|
||||
|
||||
// if the inputs are public, we add them to the public inputs AFTER the processed params as they are configured in that order as Column<Instances>
|
||||
if self.settings().run_args.input_visibility.is_public() {
|
||||
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
|
||||
}
|
||||
|
||||
// if the outputs are public, we add them to the public inputs
|
||||
if self.settings().run_args.output_visibility.is_public() {
|
||||
public_inputs.extend(self.graph_witness.outputs.clone().into_iter().flatten());
|
||||
// if the outputs are processed, we add the processed outputs to the public inputs
|
||||
} else if let Some(processed_outputs) = &data.processed_outputs {
|
||||
public_inputs.extend(processed_outputs.get_instances().into_iter().flatten());
|
||||
}
|
||||
@@ -1036,12 +1045,12 @@ impl GraphCircuit {
|
||||
|
||||
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
|
||||
(
|
||||
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as i64,
|
||||
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as i64,
|
||||
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
|
||||
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
|
||||
)
|
||||
}
|
||||
|
||||
fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
|
||||
fn calc_num_cols(range_len: IntegerRep, max_logrows: u32) -> usize {
|
||||
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
|
||||
num_cols_required(range_len, max_col_size)
|
||||
}
|
||||
@@ -1049,7 +1058,7 @@ impl GraphCircuit {
|
||||
fn table_size_logrows(
|
||||
&self,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: i64,
|
||||
max_range_size: IntegerRep,
|
||||
) -> Result<u32, GraphError> {
|
||||
// pick the range with the largest absolute size safe_lookup_range or max_range_size
|
||||
let safe_range = std::cmp::max(
|
||||
@@ -1068,7 +1077,7 @@ impl GraphCircuit {
|
||||
pub fn calc_min_logrows(
|
||||
&mut self,
|
||||
min_max_lookup: Range,
|
||||
max_range_size: i64,
|
||||
max_range_size: IntegerRep,
|
||||
max_logrows: Option<u32>,
|
||||
lookup_safety_margin: f64,
|
||||
) -> Result<(), GraphError> {
|
||||
@@ -1086,7 +1095,7 @@ impl GraphCircuit {
|
||||
(safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
|
||||
// check if has overflowed max lookup input
|
||||
|
||||
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as i64 {
|
||||
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as IntegerRep {
|
||||
return Err(GraphError::LookupRangeTooLarge(
|
||||
lookup_size.unsigned_abs() as usize
|
||||
));
|
||||
@@ -1166,7 +1175,7 @@ impl GraphCircuit {
|
||||
&self,
|
||||
k: u32,
|
||||
safe_lookup_range: Range,
|
||||
max_range_size: i64,
|
||||
max_range_size: IntegerRep,
|
||||
) -> bool {
|
||||
// if num cols is too large then the extended k is too large
|
||||
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
|
||||
@@ -1224,8 +1233,7 @@ impl GraphCircuit {
|
||||
inputs: &mut [Tensor<Fp>],
|
||||
vk: Option<&VerifyingKey<G1Affine>>,
|
||||
srs: Option<&Scheme::ParamsProver>,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
region_settings: RegionSettings,
|
||||
) -> Result<GraphWitness, GraphError> {
|
||||
let original_inputs = inputs.to_vec();
|
||||
|
||||
@@ -1274,7 +1282,7 @@ impl GraphCircuit {
|
||||
|
||||
let mut model_results =
|
||||
self.model()
|
||||
.forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
|
||||
.forward(inputs, &self.settings().run_args, region_settings)?;
|
||||
|
||||
if visibility.output.requires_processing() {
|
||||
let module_outlets = visibility.output.overwrites_inputs();
|
||||
|
||||
@@ -7,10 +7,12 @@ use super::GraphSettings;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::circuit::region::RegionCtx;
|
||||
use crate::circuit::region::RegionSettings;
|
||||
use crate::circuit::table::Range;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::ValType;
|
||||
use crate::{
|
||||
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
|
||||
@@ -64,11 +66,11 @@ pub struct ForwardResult {
|
||||
/// The outputs of the forward pass.
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
/// The maximum value of any input to a lookup operation.
|
||||
pub max_lookup_inputs: i64,
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
/// The minimum value of any input to a lookup operation.
|
||||
pub min_lookup_inputs: i64,
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// The max range check size
|
||||
pub max_range_size: i64,
|
||||
pub max_range_size: IntegerRep,
|
||||
}
|
||||
|
||||
impl From<DummyPassRes> for ForwardResult {
|
||||
@@ -116,11 +118,11 @@ pub struct DummyPassRes {
|
||||
/// range checks
|
||||
pub range_checks: HashSet<Range>,
|
||||
/// max lookup inputs
|
||||
pub max_lookup_inputs: i64,
|
||||
pub max_lookup_inputs: IntegerRep,
|
||||
/// min lookup inputs
|
||||
pub min_lookup_inputs: i64,
|
||||
pub min_lookup_inputs: IntegerRep,
|
||||
/// min range check
|
||||
pub max_range_size: i64,
|
||||
pub max_range_size: IntegerRep,
|
||||
/// outputs
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
}
|
||||
@@ -545,7 +547,7 @@ impl Model {
|
||||
})
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs, false, false)?;
|
||||
let res = self.dummy_layout(run_args, &inputs, RegionSettings::all_false())?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
@@ -588,14 +590,13 @@ impl Model {
|
||||
&self,
|
||||
model_inputs: &[Tensor<Fp>],
|
||||
run_args: &RunArgs,
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
region_settings: RegionSettings,
|
||||
) -> Result<ForwardResult, GraphError> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
|
||||
let res = self.dummy_layout(run_args, &valtensor_inputs, region_settings)?;
|
||||
Ok(res.into())
|
||||
}
|
||||
|
||||
@@ -1390,8 +1391,7 @@ impl Model {
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
witness_gen: bool,
|
||||
check_lookup: bool,
|
||||
region_settings: RegionSettings,
|
||||
) -> Result<DummyPassRes, GraphError> {
|
||||
debug!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
@@ -1410,8 +1410,7 @@ impl Model {
|
||||
vars: ModelVars::new_dummy(),
|
||||
};
|
||||
|
||||
let mut region =
|
||||
RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
|
||||
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, region_settings);
|
||||
|
||||
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::circuit::lookup::LookupOp;
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::fieldutils::IntegerRep;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
use halo2curves::ff::PrimeField;
|
||||
@@ -50,16 +51,20 @@ use tract_onnx::tract_hir::{
|
||||
/// * `dims` - the dimensionality of the resulting [Tensor].
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64, TensorError> {
|
||||
pub fn quantize_float(
|
||||
elem: &f64,
|
||||
shift: f64,
|
||||
scale: crate::Scale,
|
||||
) -> Result<IntegerRep, TensorError> {
|
||||
let mult = scale_to_multiplier(scale);
|
||||
let max_value = ((i64::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
|
||||
|
||||
if *elem > max_value {
|
||||
return Err(TensorError::SigBitTruncationError);
|
||||
}
|
||||
|
||||
// we parallelize the quantization process as it seems to be quite slow at times
|
||||
let scaled = (mult * *elem + shift).round() as i64;
|
||||
let scaled = (mult * *elem + shift).round() as IntegerRep;
|
||||
|
||||
Ok(scaled)
|
||||
}
|
||||
@@ -70,7 +75,7 @@ pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64
|
||||
/// * `scale` - `2^scale` used in the fixed point representation.
|
||||
/// * `shift` - offset used in the fixed point representation.
|
||||
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
|
||||
let int_rep = crate::fieldutils::felt_to_i64(felt);
|
||||
let int_rep = crate::fieldutils::felt_to_integer_rep(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
int_rep as f64 / multiplier - shift
|
||||
}
|
||||
@@ -283,10 +288,7 @@ pub fn new_op_from_onnx(
|
||||
.flat_map(|x| x.out_scales())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let input_dims = inputs
|
||||
.iter()
|
||||
.flat_map(|x| x.out_dims())
|
||||
.collect::<Vec<_>>();
|
||||
let input_dims = inputs.iter().flat_map(|x| x.out_dims()).collect::<Vec<_>>();
|
||||
|
||||
let mut replace_const = |scale: crate::Scale,
|
||||
index: usize,
|
||||
@@ -1192,7 +1194,13 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
}
|
||||
|
||||
SupportedOp::Linear(PolyOp::Conv { padding, stride })
|
||||
let group = conv_node.group;
|
||||
|
||||
SupportedOp::Linear(PolyOp::Conv {
|
||||
padding,
|
||||
stride,
|
||||
group,
|
||||
})
|
||||
}
|
||||
"Not" => SupportedOp::Linear(PolyOp::Not),
|
||||
"And" => SupportedOp::Linear(PolyOp::And),
|
||||
@@ -1247,6 +1255,7 @@ pub fn new_op_from_onnx(
|
||||
padding,
|
||||
output_padding: deconv_node.adjustments.to_vec(),
|
||||
stride,
|
||||
group: deconv_node.group,
|
||||
})
|
||||
}
|
||||
"Downsample" => {
|
||||
@@ -1420,7 +1429,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
|
||||
visibility: &Visibility,
|
||||
) -> Result<Tensor<F>, TensorError> {
|
||||
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
|
||||
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
|
||||
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
|
||||
&(x).into(),
|
||||
0.0,
|
||||
scale,
|
||||
|
||||
@@ -85,6 +85,7 @@ use std::str::FromStr;
|
||||
|
||||
use circuit::{table::Range, CheckMode, Tolerance};
|
||||
use clap::Args;
|
||||
use fieldutils::IntegerRep;
|
||||
use graph::Visibility;
|
||||
use halo2_proofs::poly::{
|
||||
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
|
||||
@@ -243,7 +244,7 @@ pub struct RunArgs {
|
||||
#[arg(long, default_value = "1", value_hint = clap::ValueHint::Other)]
|
||||
pub scale_rebase_multiplier: u32,
|
||||
/// The min and max elements in the lookup table input column
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
|
||||
#[arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768")]
|
||||
pub lookup_range: Range,
|
||||
/// The log_2 number of rows
|
||||
#[arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other)]
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::circuit::modules::poseidon::{
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::circuit::{CheckMode, Tolerance};
|
||||
use crate::commands::*;
|
||||
use crate::fieldutils::{felt_to_i64, i64_to_felt};
|
||||
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::TestDataSource;
|
||||
use crate::graph::{
|
||||
@@ -331,9 +331,9 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
|
||||
#[pyfunction(signature = (
|
||||
felt,
|
||||
))]
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
|
||||
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
Ok(int_rep)
|
||||
}
|
||||
|
||||
@@ -357,7 +357,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
|
||||
))]
|
||||
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
let float_rep = int_rep as f64 / multiplier;
|
||||
Ok(float_rep)
|
||||
@@ -385,7 +385,7 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
|
||||
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
|
||||
let int_rep = quantize_float(&input, 0.0, scale)
|
||||
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
|
||||
let felt = i64_to_felt(int_rep);
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
|
||||
}
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ pub mod var;
|
||||
|
||||
pub use errors::TensorError;
|
||||
|
||||
use halo2curves::{bn256::Fr, ff::PrimeField};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use maybe_rayon::{
|
||||
prelude::{
|
||||
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
|
||||
@@ -26,7 +26,7 @@ use instant::Instant;
|
||||
|
||||
use crate::{
|
||||
circuit::utils,
|
||||
fieldutils::{felt_to_i32, felt_to_i64, i32_to_felt, i64_to_felt},
|
||||
fieldutils::{integer_rep_to_felt, IntegerRep},
|
||||
graph::Visibility,
|
||||
};
|
||||
|
||||
@@ -108,7 +108,7 @@ impl TensorType for f32 {
|
||||
Some(0.0)
|
||||
}
|
||||
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for i32, usize.
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
|
||||
// A comparison between f32s needs to handle NAN values.
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
match (self.is_nan(), other.is_nan()) {
|
||||
@@ -131,7 +131,7 @@ impl TensorType for f64 {
|
||||
Some(0.0)
|
||||
}
|
||||
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for i32, usize.
|
||||
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
|
||||
// A comparison between f32s needs to handle NAN values.
|
||||
fn tmax(&self, other: &Self) -> Option<Self> {
|
||||
match (self.is_nan(), other.is_nan()) {
|
||||
@@ -150,8 +150,7 @@ impl TensorType for f64 {
|
||||
}
|
||||
|
||||
tensor_type!(bool, Bool, false, true);
|
||||
tensor_type!(i64, Int64, 0, 1);
|
||||
tensor_type!(i32, Int32, 0, 1);
|
||||
tensor_type!(IntegerRep, IntegerRep, 0, 1);
|
||||
tensor_type!(usize, USize, 0, 1);
|
||||
tensor_type!((), Empty, (), ());
|
||||
tensor_type!(utils::F32, F32, utils::F32(0.0), utils::F32(1.0));
|
||||
@@ -316,92 +315,6 @@ impl<T: TensorType> DerefMut for Tensor<T> {
|
||||
self.inner.deref_mut()
|
||||
}
|
||||
}
|
||||
/// Convert to i64 trait
|
||||
pub trait IntoI64 {
|
||||
/// Convert to i64
|
||||
fn into_i64(self) -> i64;
|
||||
|
||||
/// From i64
|
||||
fn from_i64(i: i64) -> Self;
|
||||
}
|
||||
|
||||
impl IntoI64 for i64 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self
|
||||
}
|
||||
fn from_i64(i: i64) -> i64 {
|
||||
i
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for i32 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as i32
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for usize {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as usize
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for f32 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for f64 {
|
||||
fn into_i64(self) -> i64 {
|
||||
self as i64
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i as f64
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoI64 for () {
|
||||
fn into_i64(self) -> i64 {
|
||||
0
|
||||
}
|
||||
fn from_i64(_: i64) -> Self {}
|
||||
}
|
||||
|
||||
impl IntoI64 for Fr {
|
||||
fn into_i64(self) -> i64 {
|
||||
felt_to_i64(self)
|
||||
}
|
||||
fn from_i64(i: i64) -> Self {
|
||||
i64_to_felt::<Fr>(i)
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + IntoI64> IntoI64 for Value<F> {
|
||||
fn into_i64(self) -> i64 {
|
||||
let mut res = vec![];
|
||||
self.map(|x| res.push(x.into_i64()));
|
||||
|
||||
if res.is_empty() {
|
||||
0
|
||||
} else {
|
||||
res[0]
|
||||
}
|
||||
}
|
||||
|
||||
fn from_i64(i: i64) -> Self {
|
||||
Value::known(F::from_i64(i))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
|
||||
fn eq(&self, other: &Tensor<T>) -> bool {
|
||||
@@ -431,42 +344,6 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
|
||||
for Tensor<i32>
|
||||
{
|
||||
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<i32> {
|
||||
let mut output = Vec::new();
|
||||
value.map(|x| {
|
||||
x.evaluate().value().map(|y| {
|
||||
let e = felt_to_i32(*y);
|
||||
output.push(e);
|
||||
e
|
||||
})
|
||||
});
|
||||
Tensor::new(Some(&output), value.dims()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>>
|
||||
for Tensor<i32>
|
||||
{
|
||||
fn from(value: Tensor<AssignedCell<F, F>>) -> Tensor<i32> {
|
||||
let mut output = Vec::new();
|
||||
value.map(|x| {
|
||||
let mut i = 0;
|
||||
x.value().map(|y| {
|
||||
let e = felt_to_i32(*y);
|
||||
output.push(e);
|
||||
i += 1;
|
||||
});
|
||||
if i == 0 {
|
||||
output.push(0);
|
||||
}
|
||||
});
|
||||
Tensor::new(Some(&output), value.dims()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
|
||||
for Tensor<Value<F>>
|
||||
{
|
||||
@@ -479,24 +356,6 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>> for Tensor<i32> {
|
||||
fn from(t: Tensor<Value<F>>) -> Tensor<i32> {
|
||||
let mut output = Vec::new();
|
||||
t.map(|x| {
|
||||
let mut i = 0;
|
||||
x.map(|y| {
|
||||
let e = felt_to_i32(y);
|
||||
output.push(e);
|
||||
i += 1;
|
||||
});
|
||||
if i == 0 {
|
||||
output.push(0);
|
||||
}
|
||||
});
|
||||
Tensor::new(Some(&output), t.dims()).unwrap()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
|
||||
for Tensor<Value<Assigned<F>>>
|
||||
{
|
||||
@@ -508,20 +367,10 @@ impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<i32>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<i32>) -> Tensor<Value<F>> {
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<IntegerRep>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<IntegerRep>) -> Tensor<Value<F>> {
|
||||
let mut ta: Tensor<Value<F>> =
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(i32_to_felt::<F>(t[i]))));
|
||||
// safe to unwrap as we know the dims are correct
|
||||
ta.reshape(t.dims()).unwrap();
|
||||
ta
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + Clone> From<Tensor<i64>> for Tensor<Value<F>> {
|
||||
fn from(t: Tensor<i64>) -> Tensor<Value<F>> {
|
||||
let mut ta: Tensor<Value<F>> =
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(i64_to_felt::<F>(t[i]))));
|
||||
Tensor::from((0..t.len()).map(|i| Value::known(integer_rep_to_felt::<F>(t[i]))));
|
||||
// safe to unwrap as we know the dims are correct
|
||||
ta.reshape(t.dims()).unwrap();
|
||||
ta
|
||||
@@ -633,7 +482,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[3, 3, 3]).unwrap();
|
||||
///
|
||||
/// a.set(&[0, 0, 1], 10);
|
||||
/// assert_eq!(a[0 + 0 + 1], 10);
|
||||
@@ -650,7 +500,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
|
||||
///
|
||||
/// a[1*15 + 1*5 + 1] = 5;
|
||||
/// assert_eq!(a.get(&[1, 1, 1]), 5);
|
||||
@@ -664,7 +515,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
|
||||
///
|
||||
/// a[1*15 + 1*5 + 1] = 5;
|
||||
/// assert_eq!(a.get(&[1, 1, 1]), 5);
|
||||
@@ -684,11 +536,12 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Pad to a length that is divisible by n
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
|
||||
///
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
|
||||
/// ```
|
||||
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
|
||||
@@ -704,7 +557,8 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
|
||||
///
|
||||
/// let flat_index = 1*15 + 1*5 + 1;
|
||||
/// a[1*15 + 1*5 + 1] = 5;
|
||||
@@ -731,8 +585,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Get a slice from the Tensor.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<i32>::new(Some(&[1, 2]), &[2]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2]), &[2]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.get_slice(&[0..2]).unwrap(), b);
|
||||
/// ```
|
||||
@@ -782,9 +637,10 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Set a slice of the Tensor.
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
|
||||
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
|
||||
/// a.set_slice(&[1..2], &Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
|
||||
/// assert_eq!(a, b);
|
||||
/// ```
|
||||
pub fn set_slice(
|
||||
@@ -845,6 +701,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.get_index(&[2, 2, 2]), 26);
|
||||
@@ -868,12 +725,13 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
|
||||
/// assert_eq!(a.duplicate_every_n(3, 1, 0).unwrap(), expected);
|
||||
/// assert_eq!(a.duplicate_every_n(7, 1, 0).unwrap(), a);
|
||||
///
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
|
||||
/// assert_eq!(a.duplicate_every_n(3, 1, 2).unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
@@ -900,8 +758,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
|
||||
/// assert_eq!(a.remove_every_n(4, 1, 0).unwrap(), expected);
|
||||
///
|
||||
///
|
||||
@@ -935,14 +794,15 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// WARN: assumes indices are in ascending order for speed
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
|
||||
/// let mut indices = vec![3, 4];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
|
||||
///
|
||||
///
|
||||
/// let a = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
|
||||
/// let b = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
|
||||
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
|
||||
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
|
||||
/// let mut indices = vec![63, 67];
|
||||
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
|
||||
/// ```
|
||||
@@ -972,6 +832,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///Reshape the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// a.reshape(&[9, 3]);
|
||||
/// assert_eq!(a.dims(), &[9, 3]);
|
||||
@@ -1006,22 +867,23 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Move axis of the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// let b = a.move_axis(0, 2).unwrap();
|
||||
/// assert_eq!(b.dims(), &[3, 3, 3]);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(0, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(1, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.move_axis(2, 1).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
/// ```
|
||||
@@ -1086,22 +948,23 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Swap axes of the tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// let b = a.swap_axes(0, 2).unwrap();
|
||||
/// assert_eq!(b.dims(), &[3, 3, 3]);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
|
||||
/// let b = a.swap_axes(0, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.swap_axes(1, 2).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
///
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
|
||||
/// let b = a.swap_axes(2, 1).unwrap();
|
||||
/// assert_eq!(b, expected);
|
||||
/// ```
|
||||
@@ -1148,9 +1011,10 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Broadcasts the tensor to a given shape
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
///
|
||||
/// let mut expected = Tensor::<i32>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
|
||||
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
|
||||
/// assert_eq!(a.expand(&[3, 3]).unwrap(), expected);
|
||||
///
|
||||
/// ```
|
||||
@@ -1204,6 +1068,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
///Flatten the tensor shape
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
|
||||
/// a.flatten();
|
||||
/// assert_eq!(a.dims(), &[27]);
|
||||
@@ -1217,8 +1082,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.map(|x| i32::pow(x,2));
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.map(|x| IntegerRep::pow(x,2));
|
||||
/// assert_eq!(c, Tensor::from([1, 16].into_iter()))
|
||||
/// ```
|
||||
pub fn map<F: FnMut(T) -> G, G: TensorType>(&self, mut f: F) -> Tensor<G> {
|
||||
@@ -1231,8 +1097,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors and enumerates
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
|
||||
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
|
||||
/// ```
|
||||
pub fn enum_map<F: FnMut(usize, T) -> Result<G, E>, G: TensorType, E: Error>(
|
||||
@@ -1254,8 +1121,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors and enumerates in parallel
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
|
||||
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
|
||||
/// ```
|
||||
pub fn par_enum_map<
|
||||
@@ -1284,8 +1152,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Get last elem from Tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<i32>::new(Some(&[3]), &[1]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<IntegerRep>::new(Some(&[3]), &[1]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.last().unwrap(), b);
|
||||
/// ```
|
||||
@@ -1308,8 +1177,9 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// Maps a function to tensors and enumerates in parallel
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
|
||||
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
|
||||
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
|
||||
/// ```
|
||||
pub fn par_enum_map_mut_filtered<
|
||||
@@ -1332,103 +1202,13 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
#[allow(unsafe_code)]
|
||||
/// Perform a tensor operation on the GPU using Metal.
|
||||
pub fn metal_tensor_op<T: Clone + TensorType + IntoI64 + Send + Sync>(
|
||||
v: &Tensor<T>,
|
||||
w: &Tensor<T>,
|
||||
op: &str,
|
||||
) -> Tensor<T> {
|
||||
assert_eq!(v.dims(), w.dims());
|
||||
|
||||
log::trace!("------------------------------------------------");
|
||||
|
||||
let start = Instant::now();
|
||||
let v = v
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
let w = w
|
||||
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
|
||||
.unwrap();
|
||||
log::trace!("Time to map tensors: {:?}", start.elapsed());
|
||||
|
||||
objc::rc::autoreleasepool(|| {
|
||||
// create function pipeline.
|
||||
// this compiles the function, so a pipline can't be created in performance sensitive code.
|
||||
|
||||
let pipeline = &PIPELINES[op];
|
||||
|
||||
let length = v.len() as u64;
|
||||
let size = length * core::mem::size_of::<i64>() as u64;
|
||||
assert_eq!(v.len(), w.len());
|
||||
|
||||
let start = Instant::now();
|
||||
|
||||
let buffer_a = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(v.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_b = DEVICE.new_buffer_with_data(
|
||||
unsafe { std::mem::transmute(w.as_ptr()) },
|
||||
size,
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
let buffer_result = DEVICE.new_buffer(
|
||||
size, // the operation will return an array with the same size.
|
||||
MTLResourceOptions::StorageModeShared,
|
||||
);
|
||||
|
||||
log::trace!("Time to load buffers: {:?}", start.elapsed());
|
||||
|
||||
// for sending commands, a command buffer is needed.
|
||||
let start = Instant::now();
|
||||
let command_buffer = QUEUE.new_command_buffer();
|
||||
log::trace!("Time to load command buffer: {:?}", start.elapsed());
|
||||
// to write commands into a buffer an encoder is needed, in our case a compute encoder.
|
||||
let start = Instant::now();
|
||||
let compute_encoder = command_buffer.new_compute_command_encoder();
|
||||
compute_encoder.set_compute_pipeline_state(&pipeline);
|
||||
compute_encoder.set_buffers(
|
||||
0,
|
||||
&[Some(&buffer_a), Some(&buffer_b), Some(&buffer_result)],
|
||||
&[0; 3],
|
||||
);
|
||||
log::trace!("Time to load compute encoder: {:?}", start.elapsed());
|
||||
|
||||
// specify thread count and organization
|
||||
let start = Instant::now();
|
||||
let grid_size = MTLSize::new(length, 1, 1);
|
||||
let threadgroup_size = MTLSize::new(length, 1, 1);
|
||||
compute_encoder.dispatch_threads(grid_size, threadgroup_size);
|
||||
log::trace!("Time to dispatch threads: {:?}", start.elapsed());
|
||||
|
||||
// end encoding and execute commands
|
||||
let start = Instant::now();
|
||||
compute_encoder.end_encoding();
|
||||
command_buffer.commit();
|
||||
|
||||
command_buffer.wait_until_completed();
|
||||
log::trace!("Time to commit: {:?}", start.elapsed());
|
||||
|
||||
let start = Instant::now();
|
||||
let ptr = buffer_result.contents() as *const i64;
|
||||
let len = buffer_result.length() as usize / std::mem::size_of::<i64>();
|
||||
let slice = unsafe { core::slice::from_raw_parts(ptr, len) };
|
||||
let res = Tensor::new(Some(&slice.to_vec()), &v.dims()).unwrap();
|
||||
log::trace!("Time to get result: {:?}", start.elapsed());
|
||||
|
||||
res.map(|x| T::from_i64(x))
|
||||
})
|
||||
}
|
||||
|
||||
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
/// Flattens a tensor of tensors
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let mut b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
/// let mut c = Tensor::new(Some(&[a,b]), &[2]).unwrap();
|
||||
/// let mut d = c.combine().unwrap();
|
||||
/// assert_eq!(d.dims(), &[8]);
|
||||
@@ -1444,9 +1224,7 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Add
|
||||
for Tensor<T>
|
||||
{
|
||||
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Add for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Adds tensors.
|
||||
/// # Arguments
|
||||
@@ -1456,42 +1234,43 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + I
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Add;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.add(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 1D casting
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1]).unwrap();
|
||||
/// let result = x.add(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
///
|
||||
/// // Now test 2D casting
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 3]),
|
||||
/// &[2]).unwrap();
|
||||
/// let result = x.add(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
@@ -1525,13 +1304,14 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Neg;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.neg();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn neg(self) -> Self {
|
||||
@@ -1544,9 +1324,7 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Sub
|
||||
for Tensor<T>
|
||||
{
|
||||
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Sub for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Subtracts tensors.
|
||||
/// # Arguments
|
||||
@@ -1556,43 +1334,44 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + I
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Sub;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.sub(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 1D sub
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1],
|
||||
/// ).unwrap();
|
||||
/// let result = x.sub(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 2D sub
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 3]),
|
||||
/// &[2],
|
||||
/// ).unwrap();
|
||||
/// let result = x.sub(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
@@ -1618,9 +1397,7 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + I
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Mul
|
||||
for Tensor<T>
|
||||
{
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mul for Tensor<T> {
|
||||
type Output = Result<Tensor<T>, TensorError>;
|
||||
/// Elementwise multiplies tensors.
|
||||
/// # Arguments
|
||||
@@ -1630,41 +1407,42 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + I
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Mul;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 3, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.mul(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 1D mult
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1]).unwrap();
|
||||
/// let result = x.mul(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // Now test 2D mult
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let k = Tensor::<i32>::new(
|
||||
/// let k = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 2]),
|
||||
/// &[2]).unwrap();
|
||||
/// let result = x.mul(k).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn mul(self, rhs: Self) -> Self::Output {
|
||||
@@ -1690,7 +1468,7 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + I
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Tensor<T> {
|
||||
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Tensor<T> {
|
||||
/// Elementwise raise a tensor to the nth power.
|
||||
/// # Arguments
|
||||
///
|
||||
@@ -1699,13 +1477,14 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + I
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Mul;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 15, 2, 1, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.pow(3).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn pow(&self, mut exp: u32) -> Result<Self, TensorError> {
|
||||
@@ -1739,30 +1518,31 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Div;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.div(y).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
///
|
||||
/// // test 1D casting
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2]),
|
||||
/// &[1],
|
||||
/// ).unwrap();
|
||||
/// let result = x.div(y).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn div(self, rhs: Self) -> Self::Output {
|
||||
@@ -1789,17 +1569,18 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
|
||||
/// # Examples
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// use ezkl::fieldutils::IntegerRep;
|
||||
/// use std::ops::Rem;
|
||||
/// let x = Tensor::<i32>::new(
|
||||
/// let x = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[4, 1, 4, 1, 1, 4]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let y = Tensor::<i32>::new(
|
||||
/// let y = Tensor::<IntegerRep>::new(
|
||||
/// Some(&[2, 1, 2, 1, 1, 1]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = x.rem(y).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
fn rem(self, rhs: Self) -> Self::Output {
|
||||
@@ -1883,25 +1664,25 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn tensor_clone() {
|
||||
let x = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
let x = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
assert_eq!(x, x.clone());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_eq() {
|
||||
let a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
let mut b = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
|
||||
b.reshape(&[3]).unwrap();
|
||||
let c = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3]).unwrap();
|
||||
let d = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
|
||||
let c = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3]).unwrap();
|
||||
let d = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
|
||||
assert_eq!(a, b);
|
||||
assert_ne!(a, c);
|
||||
assert_ne!(a, d);
|
||||
}
|
||||
#[test]
|
||||
fn tensor_slice() {
|
||||
let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
let b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
|
||||
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
|
||||
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
|
||||
}
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,4 @@
|
||||
use crate::circuit::region::ConstantsMap;
|
||||
use crate::{circuit::region::ConstantsMap, fieldutils::felt_to_integer_rep};
|
||||
use maybe_rayon::slice::Iter;
|
||||
|
||||
use super::{
|
||||
@@ -54,6 +54,44 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
|
||||
AssignedConstant(AssignedCell<F, F>, F),
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for IntegerRep {
|
||||
fn from(val: ValType<F>) -> Self {
|
||||
match val {
|
||||
ValType::Value(v) => {
|
||||
let mut output = 0;
|
||||
let mut i = 0;
|
||||
v.map(|y| {
|
||||
let e = felt_to_integer_rep(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::AssignedValue(v) => {
|
||||
let mut output = 0;
|
||||
let mut i = 0;
|
||||
v.evaluate().map(|y| {
|
||||
let e = felt_to_integer_rep(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
let mut output = 0;
|
||||
let mut i = 0;
|
||||
v.value().map(|y| {
|
||||
let e = felt_to_integer_rep(*y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::Constant(v) => felt_to_integer_rep(v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
|
||||
/// Returns the inner cell of the [ValType].
|
||||
pub fn cell(&self) -> Option<Cell> {
|
||||
@@ -121,44 +159,6 @@ impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + Partia
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for i32 {
|
||||
fn from(val: ValType<F>) -> Self {
|
||||
match val {
|
||||
ValType::Value(v) => {
|
||||
let mut output = 0_i32;
|
||||
let mut i = 0;
|
||||
v.map(|y| {
|
||||
let e = felt_to_i32(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::AssignedValue(v) => {
|
||||
let mut output = 0_i32;
|
||||
let mut i = 0;
|
||||
v.evaluate().map(|y| {
|
||||
let e = felt_to_i32(y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
let mut output = 0_i32;
|
||||
let mut i = 0;
|
||||
v.value().map(|y| {
|
||||
let e = felt_to_i32(*y);
|
||||
output = e;
|
||||
i += 1;
|
||||
});
|
||||
output
|
||||
}
|
||||
ValType::Constant(v) => felt_to_i32(v),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> From<F> for ValType<F> {
|
||||
fn from(t: F) -> ValType<F> {
|
||||
ValType::Constant(t)
|
||||
@@ -317,8 +317,8 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i64].
|
||||
pub fn from_i64_tensor(t: Tensor<i64>) -> ValTensor<F> {
|
||||
let inner = t.map(|x| ValType::Value(Value::known(i64_to_felt(x))));
|
||||
pub fn from_integer_rep_tensor(t: Tensor<IntegerRep>) -> ValTensor<F> {
|
||||
let inner = t.map(|x| ValType::Value(Value::known(integer_rep_to_felt(x))));
|
||||
inner.into()
|
||||
}
|
||||
|
||||
@@ -521,9 +521,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Calls `int_evals` on the inner tensor.
|
||||
pub fn get_int_evals(&self) -> Result<Tensor<i64>, TensorError> {
|
||||
pub fn int_evals(&self) -> Result<Tensor<IntegerRep>, TensorError> {
|
||||
// finally convert to vector of integers
|
||||
let mut integer_evals: Vec<i64> = vec![];
|
||||
let mut integer_evals: Vec<IntegerRep> = vec![];
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: _, ..
|
||||
@@ -531,25 +531,26 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
|
||||
let _ = v.map(|vaf| match vaf {
|
||||
ValType::Value(v) => v.map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f));
|
||||
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f));
|
||||
}),
|
||||
ValType::AssignedValue(v) => v.map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
|
||||
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
|
||||
}),
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
v.value_field().map(|f| {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
|
||||
integer_evals
|
||||
.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
|
||||
})
|
||||
}
|
||||
ValType::Constant(v) => {
|
||||
integer_evals.push(crate::fieldutils::felt_to_i64(v));
|
||||
integer_evals.push(crate::fieldutils::felt_to_integer_rep(v));
|
||||
Value::unknown()
|
||||
}
|
||||
});
|
||||
}
|
||||
_ => return Err(TensorError::WrongMethod),
|
||||
};
|
||||
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
|
||||
let mut tensor: Tensor<IntegerRep> = integer_evals.into_iter().into();
|
||||
match tensor.reshape(self.dims()) {
|
||||
_ => {}
|
||||
};
|
||||
@@ -1002,25 +1003,22 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
|
||||
}
|
||||
/// A [String] representation of the [ValTensor] for display, for example in showing intermediate values in a computational graph.
|
||||
pub fn show(&self) -> String {
|
||||
match self.clone() {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: _, ..
|
||||
} => {
|
||||
let r: Tensor<i32> = v.map(|x| x.into());
|
||||
if r.len() > 10 {
|
||||
let start = r[..5].to_vec();
|
||||
let end = r[r.len() - 5..].to_vec();
|
||||
// print the two split by ... in the middle
|
||||
format!(
|
||||
"[{} ... {}]",
|
||||
start.iter().map(|x| format!("{}", x)).join(", "),
|
||||
end.iter().map(|x| format!("{}", x)).join(", ")
|
||||
)
|
||||
} else {
|
||||
format!("{:?}", r)
|
||||
}
|
||||
}
|
||||
_ => "ValTensor not PrevAssigned".into(),
|
||||
let r = match self.int_evals() {
|
||||
Ok(v) => v,
|
||||
Err(_) => return "ValTensor not PrevAssigned".into(),
|
||||
};
|
||||
|
||||
if r.len() > 10 {
|
||||
let start = r[..5].to_vec();
|
||||
let end = r[r.len() - 5..].to_vec();
|
||||
// print the two split by ... in the middle
|
||||
format!(
|
||||
"[{} ... {}]",
|
||||
start.iter().map(|x| format!("{}", x)).join(", "),
|
||||
end.iter().map(|x| format!("{}", x)).join(", ")
|
||||
)
|
||||
} else {
|
||||
format!("{:?}", r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -368,7 +368,7 @@ impl VarTensor {
|
||||
.sum::<usize>();
|
||||
let dims = &dims[*idx];
|
||||
// this should never ever fail
|
||||
let t: Tensor<i32> = Tensor::new(None, dims).unwrap();
|
||||
let t: Tensor<IntegerRep> = Tensor::new(None, dims).unwrap();
|
||||
Ok(t.enum_map(|coord, _| {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord);
|
||||
region.assign_advice_from_instance(
|
||||
@@ -497,7 +497,7 @@ impl VarTensor {
|
||||
let (x, y, z) = self.cartesian_coord(offset + coord * step);
|
||||
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
|
||||
// assert that duplication occurred correctly
|
||||
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
|
||||
assert_eq!(Into::<IntegerRep>::into(k.clone()), Into::<IntegerRep>::into(v[coord - 1].clone()));
|
||||
};
|
||||
|
||||
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
|
||||
@@ -533,13 +533,14 @@ impl VarTensor {
|
||||
if matches!(check_mode, CheckMode::SAFE) {
|
||||
// during key generation this will be 0 so we use this as a flag to check
|
||||
// TODO: this isn't very safe and would be better to get the phase directly
|
||||
let is_assigned = !Into::<Tensor<i32>>::into(res.clone().get_inner().unwrap())
|
||||
let res_evals = res.int_evals().unwrap();
|
||||
let is_assigned = res_evals
|
||||
.iter()
|
||||
.all(|&x| x == 0);
|
||||
if is_assigned {
|
||||
if !is_assigned {
|
||||
assert_eq!(
|
||||
Into::<Tensor<i32>>::into(values.get_inner().unwrap()),
|
||||
Into::<Tensor<i32>>::into(res.get_inner().unwrap())
|
||||
values.int_evals().unwrap(),
|
||||
res_evals
|
||||
)};
|
||||
}
|
||||
|
||||
|
||||
89
src/wasm.rs
89
src/wasm.rs
@@ -1,41 +1,52 @@
|
||||
use crate::circuit::modules::polycommit::PolyCommitChip;
|
||||
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
|
||||
use crate::circuit::modules::poseidon::PoseidonChip;
|
||||
use crate::circuit::modules::Module;
|
||||
use crate::fieldutils::felt_to_i64;
|
||||
use crate::fieldutils::i64_to_felt;
|
||||
use crate::graph::modules::POSEIDON_LEN_GRAPH;
|
||||
use crate::graph::quantize_float;
|
||||
use crate::graph::scale_to_multiplier;
|
||||
use crate::graph::{GraphCircuit, GraphSettings};
|
||||
use crate::pfsys::create_proof_circuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
|
||||
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
|
||||
use crate::pfsys::verify_proof_circuit;
|
||||
use crate::pfsys::TranscriptType;
|
||||
use crate::tensor::TensorType;
|
||||
use crate::CheckMode;
|
||||
use crate::Commitments;
|
||||
use crate::{
|
||||
circuit::{
|
||||
modules::{
|
||||
polycommit::PolyCommitChip,
|
||||
poseidon::{
|
||||
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
|
||||
PoseidonChip,
|
||||
},
|
||||
Module,
|
||||
},
|
||||
region::RegionSettings,
|
||||
},
|
||||
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
|
||||
graph::{
|
||||
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
|
||||
GraphSettings,
|
||||
},
|
||||
pfsys::{
|
||||
create_proof_circuit,
|
||||
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
|
||||
verify_proof_circuit, TranscriptType,
|
||||
},
|
||||
tensor::TensorType,
|
||||
CheckMode, Commitments,
|
||||
};
|
||||
use console_error_panic_hook;
|
||||
use halo2_proofs::plonk::*;
|
||||
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
|
||||
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
|
||||
use halo2_proofs::poly::ipa::{
|
||||
commitment::{IPACommitmentScheme, ParamsIPA},
|
||||
strategy::SingleStrategy as IPASingleStrategy,
|
||||
use halo2_proofs::{
|
||||
plonk::*,
|
||||
poly::{
|
||||
commitment::{CommitmentScheme, ParamsProver},
|
||||
ipa::{
|
||||
commitment::{IPACommitmentScheme, ParamsIPA},
|
||||
multiopen::{ProverIPA, VerifierIPA},
|
||||
strategy::SingleStrategy as IPASingleStrategy,
|
||||
},
|
||||
kzg::{
|
||||
commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
multiopen::{ProverSHPLONK, VerifierSHPLONK},
|
||||
strategy::SingleStrategy as KZGSingleStrategy,
|
||||
},
|
||||
VerificationStrategy,
|
||||
},
|
||||
};
|
||||
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
|
||||
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
|
||||
use halo2_proofs::poly::kzg::{
|
||||
commitment::{KZGCommitmentScheme, ParamsKZG},
|
||||
strategy::SingleStrategy as KZGSingleStrategy,
|
||||
};
|
||||
use halo2_proofs::poly::VerificationStrategy;
|
||||
use halo2_solidity_verifier::encode_calldata;
|
||||
use halo2curves::bn256::{Bn256, Fr, G1Affine};
|
||||
use halo2curves::ff::{FromUniformBytes, PrimeField};
|
||||
use snark_verifier::loader::native::NativeLoader;
|
||||
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
|
||||
use halo2curves::{
|
||||
bn256::{Bn256, Fr, G1Affine},
|
||||
ff::{FromUniformBytes, PrimeField},
|
||||
};
|
||||
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
|
||||
use std::str::FromStr;
|
||||
use wasm_bindgen::prelude::*;
|
||||
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
|
||||
@@ -113,7 +124,7 @@ pub fn feltToInt(
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
Ok(wasm_bindgen::Clamped(
|
||||
serde_json::to_vec(&felt_to_i64(felt))
|
||||
serde_json::to_vec(&felt_to_integer_rep(felt))
|
||||
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
|
||||
))
|
||||
}
|
||||
@@ -127,7 +138,7 @@ pub fn feltToFloat(
|
||||
) -> Result<f64, JsError> {
|
||||
let felt: Fr = serde_json::from_slice(&array[..])
|
||||
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
|
||||
let int_rep = felt_to_i64(felt);
|
||||
let int_rep = felt_to_integer_rep(felt);
|
||||
let multiplier = scale_to_multiplier(scale);
|
||||
Ok(int_rep as f64 / multiplier)
|
||||
}
|
||||
@@ -141,7 +152,7 @@ pub fn floatToFelt(
|
||||
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
|
||||
let int_rep =
|
||||
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
let felt = i64_to_felt(int_rep);
|
||||
let felt = integer_rep_to_felt(int_rep);
|
||||
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
|
||||
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|
||||
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
|
||||
@@ -275,7 +286,7 @@ pub fn genWitness(
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
let witness = circuit
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false, false)
|
||||
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, RegionSettings::all_false())
|
||||
.map_err(|e| JsError::new(&format!("{}", e)))?;
|
||||
|
||||
serde_json::to_vec(&witness)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
mod native_tests {
|
||||
|
||||
use ezkl::circuit::Tolerance;
|
||||
use ezkl::fieldutils::{felt_to_i64, i64_to_felt};
|
||||
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
|
||||
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
|
||||
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
|
||||
@@ -646,6 +646,15 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_hashed_params_public_inputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn mock_fixed_inputs_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
@@ -1413,10 +1422,10 @@ mod native_tests {
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::zero()
|
||||
} else {
|
||||
i64_to_felt(
|
||||
(felt_to_i64(*v) as f32
|
||||
integer_rep_to_felt(
|
||||
(felt_to_integer_rep(*v) as f32
|
||||
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
|
||||
as i64,
|
||||
as IntegerRep,
|
||||
)
|
||||
};
|
||||
|
||||
@@ -1436,10 +1445,10 @@ mod native_tests {
|
||||
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
|
||||
halo2curves::bn256::Fr::from(2)
|
||||
} else {
|
||||
i64_to_felt(
|
||||
(felt_to_i64(*v) as f32
|
||||
integer_rep_to_felt(
|
||||
(felt_to_integer_rep(*v) as f32
|
||||
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
|
||||
as i64,
|
||||
as IntegerRep,
|
||||
)
|
||||
};
|
||||
*v + perturbation
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user