mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-13 08:17:57 -05:00
Compare commits
3 Commits
v8.1.7
...
example-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5cb303b149 | ||
|
|
9fb78c36e0 | ||
|
|
074db5d229 |
2
.github/workflows/large-tests.yml
vendored
2
.github/workflows/large-tests.yml
vendored
@@ -11,7 +11,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: nanoGPT Mock
|
||||
|
||||
2
.github/workflows/release.yml
vendored
2
.github/workflows/release.yml
vendored
@@ -45,7 +45,7 @@ jobs:
|
||||
steps:
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Checkout repo
|
||||
|
||||
65
.github/workflows/rust.yml
vendored
65
.github/workflows/rust.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Build
|
||||
@@ -38,7 +38,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Docs
|
||||
@@ -50,7 +50,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -73,7 +73,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -106,7 +106,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -139,7 +139,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -172,7 +172,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -189,7 +189,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
@@ -198,8 +198,10 @@ jobs:
|
||||
# chromedriver-version: "115.0.5790.102"
|
||||
- name: Install wasm32-unknown-unknown
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
- name: Install wasm runner
|
||||
run: cargo install wasm-server-runner
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
- name: Run wasm verifier tests
|
||||
# on mac:
|
||||
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
|
||||
@@ -212,7 +214,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -229,7 +231,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -284,7 +286,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -310,7 +312,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
- name: KZG prove and verify tests (EVM + VK rendered seperately)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
|
||||
- name: KZG prove and verify tests (EVM + kzg all)
|
||||
@@ -343,15 +345,18 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Install wasm-server-runner
|
||||
run: cargo install wasm-server-runner
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- name: Use pnpm 8
|
||||
uses: pnpm/action-setup@v2
|
||||
@@ -411,11 +416,11 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
- uses: actions/checkout@v3
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
@@ -445,7 +450,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -455,7 +460,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
- name: fuzz tests (EVM)
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
|
||||
# - name: fuzz tests
|
||||
@@ -468,7 +473,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -486,7 +491,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -503,7 +508,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -520,7 +525,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -530,7 +535,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
- name: KZG prove and verify aggr tests
|
||||
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
|
||||
|
||||
@@ -541,7 +546,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -563,7 +568,7 @@ jobs:
|
||||
python-version: "3.7"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- name: Install solc
|
||||
@@ -571,7 +576,7 @@ jobs:
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
- name: Build python ezkl
|
||||
run: source .env/bin/activate; maturin develop --features python-bindings --release
|
||||
- name: Run pytest
|
||||
@@ -587,7 +592,7 @@ jobs:
|
||||
python-version: "3.7"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -618,7 +623,7 @@ jobs:
|
||||
python-version: "3.9"
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
@@ -628,7 +633,7 @@ jobs:
|
||||
- name: Install solc
|
||||
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
|
||||
- name: Install Anvil
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
|
||||
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
|
||||
- name: Setup Virtual Env and Install python dependencies
|
||||
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
|
||||
- name: Build python ezkl
|
||||
|
||||
7
.github/workflows/wasm.yml
vendored
7
.github/workflows/wasm.yml
vendored
@@ -22,15 +22,18 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2024-01-04
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: jetli/wasm-pack-action@v0.4.0
|
||||
- name: Add wasm32-unknown-unknown target
|
||||
run: rustup target add wasm32-unknown-unknown
|
||||
|
||||
- name: Install wasm-server-runner
|
||||
run: cargo install wasm-server-runner
|
||||
|
||||
- name: Add rust-src
|
||||
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
|
||||
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
|
||||
- name: Install binaryen
|
||||
run: |
|
||||
set -e
|
||||
|
||||
539
Cargo.lock
generated
539
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -17,7 +17,7 @@ crate-type = ["cdylib", "rlib"]
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
|
||||
halo2curves = { version = "0.6.0", features = ["derive_serde"] }
|
||||
rand = { version = "0.8", default_features = false }
|
||||
itertools = { version = "0.10.3", default_features = false }
|
||||
clap = { version = "4.3.3", features = ["derive"]}
|
||||
@@ -34,11 +34,10 @@ bincode = { version = "1.3.3", default_features = false }
|
||||
ark-std = { version = "^0.3.0", default-features = false }
|
||||
unzip-n = "0.1.2"
|
||||
num = "0.4.1"
|
||||
portable-atomic = "1.6.0"
|
||||
|
||||
# evm related deps
|
||||
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
|
||||
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
|
||||
ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc"] }
|
||||
indicatif = {version = "0.17.5", features = ["rayon"]}
|
||||
gag = { version = "1.0.0", default_features = false}
|
||||
instant = { version = "0.1" }
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
|
Before Width: | Height: | Size: 109 KiB After Width: | Height: | Size: 109 KiB |
@@ -125,8 +125,8 @@ impl BaseOp {
|
||||
BaseOp::Sum => 1,
|
||||
BaseOp::SumInit => 1,
|
||||
BaseOp::Range { .. } => 1,
|
||||
BaseOp::IsZero => 0,
|
||||
BaseOp::IsBoolean => 0,
|
||||
BaseOp::IsZero => 1,
|
||||
BaseOp::IsBoolean => 1,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -276,20 +276,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
|
||||
let constraints = match base_op {
|
||||
BaseOp::IsBoolean => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
|
||||
let output = expected_output[base_op.constraint_idx()].clone();
|
||||
|
||||
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
BaseOp::IsZero => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
|
||||
.expect("non accum: output query failed");
|
||||
vec![expected_output[base_op.constraint_idx()].clone()]
|
||||
vec![(qis[1].clone()) * (qis[1].clone() - Expression::Constant(F::from(1)))]
|
||||
}
|
||||
BaseOp::IsZero => vec![qis[1].clone()],
|
||||
_ => {
|
||||
let expected_output: Tensor<Expression<F>> = output
|
||||
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
|
||||
@@ -523,10 +512,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
|
||||
|
||||
// we borrow mutably twice so we need to do this dance
|
||||
|
||||
let range_check = if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
|
||||
let range_check = if !self.range_checks.contains_key(&range) {
|
||||
// as all tables have the same input we see if there's another table who's input we can reuse
|
||||
let range_check = RangeCheck::<F>::configure(cs, range);
|
||||
e.insert(range_check.clone());
|
||||
self.range_checks.insert(range, range_check.clone());
|
||||
range_check
|
||||
} else {
|
||||
return Ok(());
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::{
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
// import run args from model
|
||||
|
||||
@@ -68,6 +69,14 @@ pub enum HybridOp {
|
||||
dim: usize,
|
||||
num_classes: usize,
|
||||
},
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
@@ -75,6 +84,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
|
||||
HybridOp::ScatterElements { .. } => vec![0, 2],
|
||||
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
|
||||
_ => vec![],
|
||||
}
|
||||
@@ -88,42 +98,176 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let x = inputs[0].clone().map(|x| felt_to_i128(x));
|
||||
|
||||
let res = match &self {
|
||||
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
|
||||
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
|
||||
HybridOp::Div { denom, .. } => {
|
||||
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
|
||||
let (res, intermediate_lookups) = match &self {
|
||||
HybridOp::ReduceMax { axes, .. } => {
|
||||
let res = tensor::ops::max_axes(&x, axes)?;
|
||||
let max_minus_one =
|
||||
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
|
||||
let unit = Tensor::from(vec![1].into_iter());
|
||||
// relu(x - max(x - 1)
|
||||
let inter_1 = (x.clone() - max_minus_one)?;
|
||||
// relu(1 - sum(relu(inter_1)))
|
||||
let inter_2 = (unit
|
||||
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
|
||||
|
||||
(res.clone(), vec![inter_1, inter_2])
|
||||
}
|
||||
HybridOp::ReduceMin { axes, .. } => {
|
||||
let res = tensor::ops::min_axes(&x, axes)?;
|
||||
let min_plus_one =
|
||||
Tensor::from(vec![x.clone().into_iter().min().unwrap() + 1].into_iter());
|
||||
let unit = Tensor::from(vec![1].into_iter());
|
||||
// relu(min(x + 1) - x)
|
||||
let inter_1 = (min_plus_one - x.clone())?;
|
||||
// relu(1 - sum(relu(inter_1)))
|
||||
let inter_2 = (unit
|
||||
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
|
||||
(res.clone(), vec![inter_1, inter_2])
|
||||
}
|
||||
HybridOp::Div {
|
||||
denom,
|
||||
use_range_check_for_int,
|
||||
..
|
||||
} => {
|
||||
let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
|
||||
// if denom is a round number and use_range_check_for_int is true, use range check check
|
||||
if denom.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
let divisor = Tensor::from(vec![denom.0 as i128 / 2].into_iter());
|
||||
(res, vec![-divisor.clone(), divisor])
|
||||
} else {
|
||||
(res, vec![x])
|
||||
}
|
||||
}
|
||||
HybridOp::Recip {
|
||||
input_scale,
|
||||
output_scale,
|
||||
..
|
||||
} => crate::tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.0 as f64,
|
||||
output_scale.0 as f64,
|
||||
),
|
||||
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
|
||||
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
|
||||
use_range_check_for_int,
|
||||
} => {
|
||||
let res = crate::tensor::ops::nonlinearities::recip(
|
||||
&x,
|
||||
input_scale.0 as f64,
|
||||
output_scale.0 as f64,
|
||||
);
|
||||
// if scale is a round number and use_range_check_for_int is true, use range check check
|
||||
if input_scale.0.fract() == 0.0 && *use_range_check_for_int {
|
||||
let err_tol = Tensor::from(
|
||||
vec![(output_scale.0 * input_scale.0) as i128 / 2].into_iter(),
|
||||
);
|
||||
(res, vec![-err_tol.clone(), err_tol])
|
||||
} else {
|
||||
(res, vec![x])
|
||||
}
|
||||
}
|
||||
HybridOp::ReduceArgMax { dim } => {
|
||||
let res = tensor::ops::argmax_axes(&x, *dim)?;
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let inter =
|
||||
Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups;
|
||||
inter_equals.extend(inter);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
HybridOp::ReduceArgMin { dim } => {
|
||||
let res = tensor::ops::argmin_axes(&x, *dim)?;
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let inter =
|
||||
Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups;
|
||||
inter_equals.extend(inter);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
HybridOp::Gather { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather(&x, idx, *dim)?
|
||||
log::debug!("idx: {}", idx.show());
|
||||
let res = tensor::ops::gather(&x, idx, *dim)?;
|
||||
(res.clone(), vec![])
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
}
|
||||
HybridOp::OneHot { dim, num_classes } => {
|
||||
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
HybridOp::TopK { dim, k, largest } => {
|
||||
let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;
|
||||
|
||||
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
|
||||
let mut inter_equals = x
|
||||
.clone()
|
||||
.into_iter()
|
||||
.flat_map(|elem| {
|
||||
tensor::ops::equals(&res, &vec![elem].into_iter().into())
|
||||
.unwrap()
|
||||
.1
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// sort in descending order and take pairwise differences
|
||||
inter_equals.push(
|
||||
x.into_iter()
|
||||
.sorted()
|
||||
.tuple_windows()
|
||||
.map(|(a, b)| b - a)
|
||||
.into(),
|
||||
);
|
||||
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
HybridOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
log::debug!("idx: {}", idx.show());
|
||||
let res = tensor::ops::gather_elements(&x, idx, *dim)?;
|
||||
(res.clone(), vec![])
|
||||
} else {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
}
|
||||
HybridOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
log::debug!("idx: {}", idx.show());
|
||||
let src = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
|
||||
(res.clone(), vec![])
|
||||
} else {
|
||||
let idx = inputs[1].clone().map(|x| felt_to_i128(x) as usize);
|
||||
let src = inputs[2].clone().map(|x| felt_to_i128(x));
|
||||
let indices = Tensor::from(0..x.dims()[*dim] as i128);
|
||||
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
|
||||
let res = tensor::ops::scatter(&x, &idx, &src, *dim)?;
|
||||
(res.clone(), inter_equals)
|
||||
}
|
||||
}
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
pool_dims,
|
||||
..
|
||||
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
|
||||
} => {
|
||||
let max_minus_one =
|
||||
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
|
||||
let unit = Tensor::from(vec![1].into_iter());
|
||||
// relu(x - max(x - 1)
|
||||
let inter_1 = (x.clone() - max_minus_one)?;
|
||||
// relu(1 - sum(relu(inter_1)))
|
||||
let inter_2 = (unit
|
||||
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
|
||||
(
|
||||
tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
|
||||
vec![inter_1, inter_2],
|
||||
)
|
||||
}
|
||||
HybridOp::SumPool {
|
||||
padding,
|
||||
stride,
|
||||
@@ -135,7 +279,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
}
|
||||
HybridOp::RangeCheck(tol) => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
|
||||
(
|
||||
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val),
|
||||
vec![],
|
||||
)
|
||||
}
|
||||
HybridOp::Greater => {
|
||||
let y = inputs[1].clone().map(|x| felt_to_i128(x));
|
||||
@@ -162,7 +309,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
// convert back to felt
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
Ok(ForwardResult {
|
||||
output,
|
||||
intermediate_lookups,
|
||||
})
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
@@ -216,6 +366,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
HybridOp::TopK { k, dim, largest } => {
|
||||
format!("TOPK (k={}, dim={}, largest={})", k, dim, largest)
|
||||
}
|
||||
HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
|
||||
HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
|
||||
HybridOp::OneHot { dim, num_classes } => {
|
||||
format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
|
||||
}
|
||||
@@ -288,7 +440,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
config,
|
||||
region,
|
||||
values.try_into()?,
|
||||
&LookupOp::Div { denom: *denom },
|
||||
&LookupOp::Div {
|
||||
denom: denom.clone(),
|
||||
},
|
||||
)?
|
||||
}
|
||||
}
|
||||
@@ -299,7 +453,26 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
|
||||
layouts::gather(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
|
||||
HybridOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
HybridOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::scatter(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
values[1].get_inner_tensor()?,
|
||||
*dim,
|
||||
)?
|
||||
.into()
|
||||
} else {
|
||||
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
HybridOp::MaxPool2d {
|
||||
padding,
|
||||
stride,
|
||||
|
||||
@@ -128,7 +128,13 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
|
||||
|
||||
let is_assigned = !input.any_unknowns()?;
|
||||
let mut scaled_unit =
|
||||
Tensor::from(vec![ValType::Constant(output_scale * input_scale)].into_iter());
|
||||
scaled_unit.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let scaled_unit = region.assign(&config.inputs[1], &scaled_unit.into())?;
|
||||
region.increment(scaled_unit.len());
|
||||
|
||||
let is_assigned = !input.any_unknowns()? && !scaled_unit.any_unknowns()?;
|
||||
|
||||
let mut claimed_output: ValTensor<F> = if is_assigned {
|
||||
let input_evals = input.get_int_evals()?;
|
||||
@@ -160,14 +166,27 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
log::debug!("product: {:?}", product.get_int_evals()?);
|
||||
|
||||
// this is now of scale 2 * scale hence why we rescaled the unit scale
|
||||
let diff_with_input = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[product.clone(), scaled_unit.clone()],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
log::debug!("scaled_unit: {:?}", scaled_unit.get_int_evals()?);
|
||||
|
||||
// debug print the diff
|
||||
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
|
||||
|
||||
log::debug!("range_check_bracket: {:?}", range_check_bracket);
|
||||
|
||||
// at most the error should be in the original unit scale's range
|
||||
range_check(
|
||||
config,
|
||||
region,
|
||||
&[product],
|
||||
&(range_check_bracket, 3 * range_check_bracket),
|
||||
&[diff_with_input],
|
||||
&(-range_check_bracket, range_check_bracket),
|
||||
)?;
|
||||
|
||||
Ok(claimed_output)
|
||||
@@ -214,7 +233,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
let mut assigned_len = 0;
|
||||
for (i, input) in values.iter_mut().enumerate() {
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
|
||||
input.pad_to_zero_rem(block_width)?;
|
||||
let inp = {
|
||||
let (res, len) = region.assign_with_duplication(
|
||||
&config.inputs[i],
|
||||
@@ -770,7 +789,14 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
|
||||
let assigned_input = region.assign(&config.inputs[0], &input)?;
|
||||
|
||||
// now assert all elems are 0 or 1
|
||||
let assigned_output = boolean_identity(config, region, &[output.clone()], true)?;
|
||||
let assigned_output = region.assign(&config.inputs[1], &output)?;
|
||||
if !region.is_dummy() {
|
||||
for i in 0..assigned_output.len() {
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
}
|
||||
}
|
||||
region.increment(std::cmp::max(assigned_output.len(), assigned_input.len()));
|
||||
|
||||
let sum = sum(config, region, &[assigned_output.clone()])?;
|
||||
@@ -1142,7 +1168,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
|
||||
let assigned_len: usize;
|
||||
let input = {
|
||||
let mut input = values[0].clone();
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
|
||||
input.pad_to_zero_rem(block_width)?;
|
||||
let (res, len) =
|
||||
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
|
||||
assigned_len = len;
|
||||
@@ -1211,7 +1237,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd>(
|
||||
let assigned_len: usize;
|
||||
let input = {
|
||||
let mut input = values[0].clone();
|
||||
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
|
||||
input.pad_to_zero_rem(block_width)?;
|
||||
let (res, len) =
|
||||
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
|
||||
assigned_len = len;
|
||||
@@ -1675,28 +1701,10 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 2],
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let diff = pairwise(config, region, values, BaseOp::Sub)?;
|
||||
let diff_inverse = diff.inverse()?;
|
||||
let product_diff_and_invert =
|
||||
pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
|
||||
|
||||
// constant of 1
|
||||
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
|
||||
ones.set_visibility(&crate::graph::Visibility::Fixed);
|
||||
let res = nonlinearity(config, region, &[diff], &LookupOp::KroneckerDelta)?;
|
||||
|
||||
// subtract
|
||||
let output = pairwise(
|
||||
config,
|
||||
region,
|
||||
&[ones.into(), product_diff_and_invert],
|
||||
BaseOp::Sub,
|
||||
)?;
|
||||
|
||||
// take the product of diff and output
|
||||
let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
|
||||
|
||||
is_zero_identity(config, region, &[prod_check], false)?;
|
||||
|
||||
Ok(output)
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Xor boolean operation
|
||||
@@ -1760,7 +1768,21 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
|
||||
.into();
|
||||
|
||||
// make sure mask is boolean
|
||||
let assigned_mask = boolean_identity(config, region, &[mask.clone()], true)?;
|
||||
let assigned_mask = region.assign(&config.inputs[1], mask)?;
|
||||
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..assigned_mask.len())
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(assigned_mask.len());
|
||||
|
||||
let one_minus_mask = pairwise(config, region, &[unit, assigned_mask.clone()], BaseOp::Sub)?;
|
||||
|
||||
@@ -1858,11 +1880,13 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
|
||||
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
if normalized {
|
||||
last_elem = div(
|
||||
last_elem = nonlinearity(
|
||||
config,
|
||||
region,
|
||||
&[last_elem],
|
||||
F::from((kernel_shape.0 * kernel_shape.1) as u64),
|
||||
&LookupOp::Div {
|
||||
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
|
||||
},
|
||||
)?;
|
||||
}
|
||||
Ok(last_elem)
|
||||
@@ -2359,60 +2383,18 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// is zero identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
assign: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let output = if assign || !values[0].get_const_indices()?.is_empty() {
|
||||
let output = region.assign(&config.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
output
|
||||
} else {
|
||||
values[0].clone()
|
||||
};
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output.len())
|
||||
.map(|j| {
|
||||
let index = region.linear_coord() - j - 1;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(index);
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
|
||||
pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
assign: bool,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let output = if assign || !values[0].get_const_indices()?.is_empty() {
|
||||
// get zero constants indices
|
||||
let output = region.assign(&config.output, &values[0])?;
|
||||
region.increment(output.len());
|
||||
output
|
||||
} else {
|
||||
values[0].clone()
|
||||
};
|
||||
let output = region.assign(&config.inputs[1], &values[0])?;
|
||||
// Enable the selectors
|
||||
if !region.is_dummy() {
|
||||
(0..output.len())
|
||||
.map(|j| {
|
||||
let index = region.linear_coord() - j - 1;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(index);
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + j);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
|
||||
region.enable(selector, z)?;
|
||||
@@ -2420,6 +2402,7 @@ pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
region.increment(output.len());
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
@@ -2469,7 +2452,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
range: &crate::circuit::table::Range,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
region.add_used_range_check(*range)?;
|
||||
region.add_used_range_check(*range);
|
||||
|
||||
// time the entire operation
|
||||
let timer = instant::Instant::now();
|
||||
@@ -2488,7 +2471,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
|
||||
let (x, y, z) = config
|
||||
.lookup_input
|
||||
.cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.range_check_selectors.get(&(*range, x, y));
|
||||
let selector = config.range_check_selectors.get(&(range.clone(), x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
@@ -2515,7 +2498,7 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
values: &[ValTensor<F>; 1],
|
||||
nl: &LookupOp,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
region.add_used_lookup(nl.clone(), values)?;
|
||||
region.add_used_lookup(nl.clone());
|
||||
|
||||
// time the entire operation
|
||||
let timer = instant::Instant::now();
|
||||
@@ -2598,6 +2581,22 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
/// mean function layout
|
||||
pub fn mean<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>; 1],
|
||||
scale: usize,
|
||||
) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let x = &values[0];
|
||||
|
||||
let sum_x = sum(config, region, &[x.clone()])?;
|
||||
let nl = LookupOp::Div {
|
||||
denom: utils::F32((scale * x.len()) as f32),
|
||||
};
|
||||
nonlinearity(config, region, &[sum_x], &nl)
|
||||
}
|
||||
|
||||
/// Argmax
|
||||
pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
|
||||
config: &BaseConfig<F>,
|
||||
@@ -2710,8 +2709,24 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
)?;
|
||||
// relu(x - max(x - 1))
|
||||
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
|
||||
// constraining relu(x - max(x - 1)) = 0/1
|
||||
boolean_identity(config, region, &[relu.clone()], false)?;
|
||||
|
||||
let len = relu.dims().iter().product();
|
||||
|
||||
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
|
||||
region.assign(&config.inputs[1], &relu)?;
|
||||
|
||||
if !region.is_dummy() {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(len);
|
||||
|
||||
// sum(relu(x - max(x - 1)))
|
||||
let sum_relu = sum(config, region, &[relu])?;
|
||||
@@ -2722,7 +2737,13 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
|
||||
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
|
||||
|
||||
// constraining 1 - sum(relu(x - max(x - 1))) = 0
|
||||
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
|
||||
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
|
||||
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(relu_one_minus_sum_relu.len());
|
||||
|
||||
Ok(assigned_max_val)
|
||||
}
|
||||
@@ -2767,8 +2788,23 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
|
||||
// relu(min(x + 1) - x)
|
||||
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
|
||||
// constraining relu(min(x + 1) - x) = 0/1
|
||||
boolean_identity(config, region, &[relu.clone()], false)?;
|
||||
|
||||
let len = relu.dims().iter().product();
|
||||
|
||||
region.assign(&config.inputs[1], &relu)?;
|
||||
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
|
||||
if !region.is_dummy() {
|
||||
(0..len)
|
||||
.map(|i| {
|
||||
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
|
||||
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
|
||||
region.enable(selector, z)?;
|
||||
Ok(())
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
}
|
||||
|
||||
region.increment(len);
|
||||
|
||||
// sum(relu(min(x + 1) - x))
|
||||
let sum_relu = sum(config, region, &[relu])?;
|
||||
@@ -2779,8 +2815,14 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
|
||||
let relu_one_minus_sum_relu =
|
||||
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
|
||||
|
||||
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
|
||||
|
||||
// constraining product to 0
|
||||
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(relu_one_minus_sum_relu.len());
|
||||
|
||||
Ok(assigned_min_val)
|
||||
}
|
||||
@@ -2999,8 +3041,15 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
|
||||
// Add the lower_bound and upper_bound
|
||||
let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?;
|
||||
|
||||
// Assign the sum tensor to the inputs
|
||||
region.assign(&config.inputs[1], &sum)?;
|
||||
|
||||
// Constrain the sum to be all zeros
|
||||
is_zero_identity(config, region, &[sum.clone()], false)?;
|
||||
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
|
||||
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
|
||||
region.enable(selector, z)?;
|
||||
|
||||
region.increment(sum.len());
|
||||
|
||||
Ok(sum)
|
||||
}
|
||||
|
||||
@@ -227,7 +227,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
|
||||
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
Ok(ForwardResult {
|
||||
output,
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the name of the operation
|
||||
|
||||
@@ -29,6 +29,7 @@ pub mod region;
|
||||
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
|
||||
pub(crate) output: Tensor<F>,
|
||||
pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
|
||||
}
|
||||
|
||||
/// A trait representing operations that can be represented as constraints in a circuit.
|
||||
@@ -177,6 +178,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
Ok(ForwardResult {
|
||||
output: x[0].clone(),
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
@@ -199,7 +201,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
config,
|
||||
region,
|
||||
values[..].try_into()?,
|
||||
true,
|
||||
)?))
|
||||
}
|
||||
_ => Ok(Some(super::layouts::identity(
|
||||
@@ -302,7 +303,10 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
let output = self.quantized_values.clone();
|
||||
|
||||
Ok(ForwardResult { output })
|
||||
Ok(ForwardResult {
|
||||
output,
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use crate::{
|
||||
circuit::layouts,
|
||||
fieldutils::felt_to_i128,
|
||||
tensor::{self, Tensor, TensorError},
|
||||
};
|
||||
|
||||
@@ -10,14 +9,6 @@ use super::{base::BaseOp, *};
|
||||
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub enum PolyOp {
|
||||
GatherElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
ScatterElements {
|
||||
dim: usize,
|
||||
constant_idx: Option<Tensor<usize>>,
|
||||
},
|
||||
MultiBroadcastTo {
|
||||
shape: Vec<usize>,
|
||||
},
|
||||
@@ -90,8 +81,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match &self {
|
||||
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
|
||||
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
|
||||
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
|
||||
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
|
||||
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
|
||||
@@ -214,36 +203,14 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
if 1 != inputs.len() {
|
||||
return Err(TensorError::DimMismatch("slice inputs".to_string()));
|
||||
}
|
||||
tensor::ops::slice(&inputs[0], axis, start, end)
|
||||
}
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
let y = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
tensor::ops::gather_elements(&x, &y, *dim)
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
let x = inputs[0].clone();
|
||||
|
||||
let idx = if let Some(idx) = constant_idx {
|
||||
idx.clone()
|
||||
} else {
|
||||
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
|
||||
};
|
||||
|
||||
let src = if let Some(_) = constant_idx {
|
||||
inputs[1].clone()
|
||||
} else {
|
||||
inputs[2].clone()
|
||||
};
|
||||
tensor::ops::scatter(&x, &idx, &src, *dim)
|
||||
Ok(tensor::ops::slice(&inputs[0], axis, start, end)?)
|
||||
}
|
||||
}?;
|
||||
|
||||
Ok(ForwardResult { output: res })
|
||||
Ok(ForwardResult {
|
||||
output: res,
|
||||
intermediate_lookups: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn layout(
|
||||
@@ -284,26 +251,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
PolyOp::Conv { padding, stride } => {
|
||||
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
|
||||
}
|
||||
PolyOp::GatherElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
|
||||
} else {
|
||||
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
PolyOp::ScatterElements { dim, constant_idx } => {
|
||||
if let Some(idx) = constant_idx {
|
||||
tensor::ops::scatter(
|
||||
values[0].get_inner_tensor()?,
|
||||
idx,
|
||||
values[1].get_inner_tensor()?,
|
||||
*dim,
|
||||
)?
|
||||
.into()
|
||||
} else {
|
||||
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
|
||||
}
|
||||
}
|
||||
PolyOp::DeConv {
|
||||
padding,
|
||||
output_padding,
|
||||
@@ -405,8 +352,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
vec![1, 2]
|
||||
} else if matches!(self, PolyOp::Concat { .. }) {
|
||||
(0..100).collect()
|
||||
} else if matches!(self, PolyOp::ScatterElements { .. }) {
|
||||
vec![0, 2]
|
||||
} else {
|
||||
vec![]
|
||||
}
|
||||
|
||||
@@ -16,8 +16,6 @@ use std::{
|
||||
},
|
||||
};
|
||||
|
||||
use portable_atomic::AtomicI128 as AtomicInt;
|
||||
|
||||
use super::lookup::LookupOp;
|
||||
|
||||
/// Region error
|
||||
@@ -68,8 +66,6 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
|
||||
total_constants: usize,
|
||||
used_lookups: HashSet<LookupOp>,
|
||||
used_range_checks: HashSet<Range>,
|
||||
max_lookup_inputs: i128,
|
||||
min_lookup_inputs: i128,
|
||||
}
|
||||
|
||||
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
@@ -91,8 +87,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
/// Create a new region context from a wrapped region
|
||||
@@ -110,8 +104,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,8 +120,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
total_constants: 0,
|
||||
used_lookups: HashSet::new(),
|
||||
used_range_checks: HashSet::new(),
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -151,8 +141,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
total_constants,
|
||||
used_lookups,
|
||||
used_range_checks,
|
||||
max_lookup_inputs: 0,
|
||||
min_lookup_inputs: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,7 +180,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
|
||||
/// Create a new region context per loop iteration
|
||||
/// hacky but it works
|
||||
|
||||
pub fn dummy_loop<T: TensorType + Send + Sync>(
|
||||
&mut self,
|
||||
output: &mut Tensor<T>,
|
||||
@@ -203,10 +190,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
let row = AtomicUsize::new(self.row());
|
||||
let linear_coord = AtomicUsize::new(self.linear_coord());
|
||||
let constants = AtomicUsize::new(self.total_constants());
|
||||
let max_lookup_inputs =
|
||||
AtomicInt::new(self.max_lookup_inputs().try_into().unwrap_or_default());
|
||||
let min_lookup_inputs =
|
||||
AtomicInt::new(self.min_lookup_inputs().try_into().unwrap_or_default());
|
||||
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
|
||||
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
|
||||
|
||||
@@ -238,14 +221,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
local_reg.total_constants() - starting_constants,
|
||||
Ordering::SeqCst,
|
||||
);
|
||||
|
||||
let local_max_lookup_inputs =
|
||||
local_reg.max_lookup_inputs().try_into().unwrap_or_default();
|
||||
let local_min_lookup_inputs =
|
||||
local_reg.min_lookup_inputs().try_into().unwrap_or_default();
|
||||
|
||||
max_lookup_inputs.fetch_max(local_max_lookup_inputs, Ordering::SeqCst);
|
||||
min_lookup_inputs.fetch_min(local_min_lookup_inputs, Ordering::SeqCst);
|
||||
// update the lookups
|
||||
let mut lookups = lookups.lock().unwrap();
|
||||
lookups.extend(local_reg.used_lookups());
|
||||
@@ -259,11 +234,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
})?;
|
||||
self.total_constants = constants.into_inner();
|
||||
self.linear_coord = linear_coord.into_inner();
|
||||
#[allow(trivial_numeric_casts)]
|
||||
{
|
||||
self.max_lookup_inputs = max_lookup_inputs.into_inner() as i128;
|
||||
self.min_lookup_inputs = min_lookup_inputs.into_inner() as i128;
|
||||
}
|
||||
self.row = row.into_inner();
|
||||
self.used_lookups = Arc::try_unwrap(lookups)
|
||||
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
|
||||
@@ -283,54 +253,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_inputs(
|
||||
&mut self,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in inputs {
|
||||
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
|
||||
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
|
||||
}
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the max and min from inputs
|
||||
pub fn update_max_min_lookup_range(
|
||||
&mut self,
|
||||
range: Range,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
if range.0 > range.1 {
|
||||
return Err("update_max_min_lookup_range: invalid range".into());
|
||||
}
|
||||
|
||||
self.max_lookup_inputs = self.max_lookup_inputs.max(range.1);
|
||||
self.min_lookup_inputs = self.min_lookup_inputs.min(range.0);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if the region is dummy
|
||||
pub fn is_dummy(&self) -> bool {
|
||||
self.region.is_none()
|
||||
}
|
||||
|
||||
/// add used lookup
|
||||
pub fn add_used_lookup(
|
||||
&mut self,
|
||||
lookup: LookupOp,
|
||||
inputs: &[ValTensor<F>],
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn add_used_lookup(&mut self, lookup: LookupOp) {
|
||||
self.used_lookups.insert(lookup);
|
||||
self.update_max_min_lookup_inputs(inputs)
|
||||
}
|
||||
|
||||
/// add used range check
|
||||
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
|
||||
pub fn add_used_range_check(&mut self, range: Range) {
|
||||
self.used_range_checks.insert(range);
|
||||
self.update_max_min_lookup_range(range)
|
||||
}
|
||||
|
||||
/// Get the offset
|
||||
@@ -358,16 +293,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
|
||||
self.used_range_checks.clone()
|
||||
}
|
||||
|
||||
/// max lookup inputs
|
||||
pub fn max_lookup_inputs(&self) -> i128 {
|
||||
self.max_lookup_inputs
|
||||
}
|
||||
|
||||
/// min lookup inputs
|
||||
pub fn min_lookup_inputs(&self) -> i128 {
|
||||
self.min_lookup_inputs
|
||||
}
|
||||
|
||||
/// Assign a constant value
|
||||
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
|
||||
self.total_constants += 1;
|
||||
|
||||
@@ -726,11 +726,11 @@ impl AccuracyResults {
|
||||
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
|
||||
let abs_percentage_error = percentage_error.map(|x| x.abs());
|
||||
|
||||
errors.extend(error);
|
||||
abs_errors.extend(abs_error);
|
||||
squared_errors.extend(squared_error);
|
||||
percentage_errors.extend(percentage_error);
|
||||
abs_percentage_errors.extend(abs_percentage_error);
|
||||
errors.extend(error.into_iter());
|
||||
abs_errors.extend(abs_error.into_iter());
|
||||
squared_errors.extend(squared_error.into_iter());
|
||||
percentage_errors.extend(percentage_error.into_iter());
|
||||
abs_percentage_errors.extend(abs_percentage_error.into_iter());
|
||||
}
|
||||
|
||||
let mean_percent_error =
|
||||
@@ -826,7 +826,10 @@ pub(crate) fn calibrate(
|
||||
let range = if let Some(scales) = scales {
|
||||
scales
|
||||
} else {
|
||||
(10..14).collect::<Vec<crate::Scale>>()
|
||||
match target {
|
||||
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
|
||||
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
|
||||
}
|
||||
};
|
||||
|
||||
let div_rebasing = if let Some(div_rebasing) = div_rebasing {
|
||||
|
||||
@@ -10,6 +10,7 @@ use crate::circuit::table::Range;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::InputType;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::felt_to_i128;
|
||||
use crate::tensor::ValType;
|
||||
use crate::{
|
||||
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
|
||||
@@ -67,16 +68,6 @@ pub struct ForwardResult {
|
||||
pub min_lookup_inputs: i128,
|
||||
}
|
||||
|
||||
impl From<DummyPassRes> for ForwardResult {
|
||||
fn from(res: DummyPassRes) -> Self {
|
||||
Self {
|
||||
outputs: res.outputs,
|
||||
max_lookup_inputs: res.max_lookup_inputs,
|
||||
min_lookup_inputs: res.min_lookup_inputs,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A circuit configuration for the entirety of a model loaded from an Onnx file.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ModelConfig {
|
||||
@@ -102,12 +93,6 @@ pub struct DummyPassRes {
|
||||
pub lookup_ops: HashSet<LookupOp>,
|
||||
/// range checks
|
||||
pub range_checks: HashSet<Range>,
|
||||
/// max lookup inputs
|
||||
pub max_lookup_inputs: i128,
|
||||
/// min lookup inputs
|
||||
pub min_lookup_inputs: i128,
|
||||
/// outputs
|
||||
pub outputs: Vec<Tensor<Fp>>,
|
||||
}
|
||||
|
||||
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
|
||||
@@ -501,25 +486,7 @@ impl Model {
|
||||
);
|
||||
// this is the total number of variables we will need to allocate
|
||||
// for the circuit
|
||||
let default_value = if !self.visibility.input.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let inputs: Vec<ValTensor<Fp>> = self
|
||||
.graph
|
||||
.input_shapes()?
|
||||
.iter()
|
||||
.map(|shape| {
|
||||
let mut t: ValTensor<Fp> =
|
||||
vec![default_value.clone(); shape.iter().product()].into();
|
||||
t.reshape(shape)?;
|
||||
Ok(t)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
let res = self.dummy_layout(run_args, &inputs)?;
|
||||
let res = self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
|
||||
|
||||
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
|
||||
|
||||
@@ -555,12 +522,206 @@ impl Model {
|
||||
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
|
||||
/// * `run_args` - [RunArgs]
|
||||
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
|
||||
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
|
||||
.iter()
|
||||
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
|
||||
.collect();
|
||||
let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?;
|
||||
Ok(res.into())
|
||||
let mut results: BTreeMap<&usize, Vec<Tensor<Fp>>> = BTreeMap::new();
|
||||
let mut max_lookup_inputs = 0;
|
||||
let mut min_lookup_inputs = 0;
|
||||
|
||||
let input_shapes = self.graph.input_shapes()?;
|
||||
|
||||
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
|
||||
let mut input = model_inputs[i].clone();
|
||||
input.reshape(&input_shapes[i])?;
|
||||
results.insert(input_idx, vec![input]);
|
||||
}
|
||||
|
||||
for (idx, n) in self.graph.nodes.iter() {
|
||||
let mut inputs = vec![];
|
||||
if n.is_input() {
|
||||
let t = results.get(idx).ok_or(GraphError::MissingResults)?[0].clone();
|
||||
inputs.push(t);
|
||||
} else {
|
||||
for (idx, outlet) in n.inputs().iter() {
|
||||
match results.get(&idx) {
|
||||
Some(value) => inputs.push(value[*outlet].clone()),
|
||||
None => return Err(Box::new(GraphError::MissingNode(*idx))),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
debug!("executing {}: {}", idx, n.as_str());
|
||||
debug!("dims: {:?}", n.out_dims());
|
||||
debug!(
|
||||
"input_dims: {:?}",
|
||||
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
debug!("input nodes: {:?}", n.inputs());
|
||||
|
||||
if n.is_lookup() {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in &inputs {
|
||||
max = max.max(
|
||||
i.iter()
|
||||
.map(|x| felt_to_i128(*x))
|
||||
.max()
|
||||
.ok_or("missing max")?,
|
||||
);
|
||||
min = min.min(
|
||||
i.iter()
|
||||
.map(|x| felt_to_i128(*x))
|
||||
.min()
|
||||
.ok_or("missing min")?,
|
||||
);
|
||||
}
|
||||
max_lookup_inputs = max_lookup_inputs.max(max);
|
||||
min_lookup_inputs = min_lookup_inputs.min(min);
|
||||
debug!("max lookup inputs: {}", max);
|
||||
debug!("min lookup inputs: {}", min);
|
||||
}
|
||||
|
||||
match n {
|
||||
NodeType::Node(n) => {
|
||||
// execute the op
|
||||
let start = instant::Instant::now();
|
||||
let mut res = Op::<Fp>::f(&n.opkind, &inputs)?;
|
||||
res.output.reshape(&n.out_dims)?;
|
||||
let elapsed = start.elapsed();
|
||||
trace!("op took: {:?}", elapsed);
|
||||
// see if any of the intermediate lookup calcs are the max
|
||||
if !res.intermediate_lookups.is_empty() {
|
||||
let (mut min, mut max) = (0, 0);
|
||||
for i in &res.intermediate_lookups {
|
||||
max = max.max(i.clone().into_iter().max().ok_or("missing max")?);
|
||||
min = min.min(i.clone().into_iter().min().ok_or("missing min")?);
|
||||
}
|
||||
max_lookup_inputs = max_lookup_inputs.max(max);
|
||||
min_lookup_inputs = min_lookup_inputs.min(min);
|
||||
debug!("intermediate max lookup inputs: {}", max);
|
||||
debug!("intermediate min lookup inputs: {}", min);
|
||||
}
|
||||
debug!(
|
||||
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} \n ------------ scale: {}",
|
||||
idx,
|
||||
res.output.map(crate::fieldutils::felt_to_i32).show(),
|
||||
res.output
|
||||
.map(|x| crate::fieldutils::felt_to_f64(x)
|
||||
/ scale_to_multiplier(n.out_scale))
|
||||
.show(),
|
||||
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0),
|
||||
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0),
|
||||
n.out_scale
|
||||
);
|
||||
results.insert(idx, vec![res.output]);
|
||||
}
|
||||
NodeType::SubGraph {
|
||||
model,
|
||||
output_mappings,
|
||||
input_mappings,
|
||||
inputs: input_tuple,
|
||||
..
|
||||
} => {
|
||||
let orig_inputs = inputs.clone();
|
||||
let input_mappings = input_mappings.clone();
|
||||
|
||||
let input_dims = inputs.iter().map(|inp| inp.dims());
|
||||
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
|
||||
|
||||
debug!(
|
||||
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
|
||||
num_iter, input_tuple, model.graph.inputs
|
||||
);
|
||||
|
||||
debug!("input_mappings: {:?}", input_mappings);
|
||||
|
||||
let mut full_results: Vec<Tensor<Fp>> = vec![];
|
||||
|
||||
for i in 0..num_iter {
|
||||
// replace the Stacked input with the current chunk iter
|
||||
for ((mapping, inp), og_input) in
|
||||
input_mappings.iter().zip(&mut inputs).zip(&orig_inputs)
|
||||
{
|
||||
if let InputMapping::Stacked { axis, chunk } = mapping {
|
||||
let start = i * chunk;
|
||||
let end = (i + 1) * chunk;
|
||||
let t = crate::tensor::ops::slice(og_input, axis, &start, &end)?;
|
||||
*inp = t;
|
||||
}
|
||||
}
|
||||
|
||||
let res = model.forward(&inputs)?;
|
||||
// recursively get the max lookup inputs for subgraphs
|
||||
max_lookup_inputs = max_lookup_inputs.max(res.max_lookup_inputs);
|
||||
min_lookup_inputs = min_lookup_inputs.min(res.min_lookup_inputs);
|
||||
|
||||
let mut outlets = BTreeMap::new();
|
||||
for (mappings, outlet_res) in output_mappings.iter().zip(res.outputs) {
|
||||
for mapping in mappings {
|
||||
match mapping {
|
||||
OutputMapping::Single { outlet, .. } => {
|
||||
outlets.insert(outlet, outlet_res.clone());
|
||||
}
|
||||
OutputMapping::Stacked { outlet, axis, .. } => {
|
||||
if !full_results.is_empty() {
|
||||
let stacked_res = crate::tensor::ops::concat(
|
||||
&[&full_results[*outlet], &outlet_res],
|
||||
*axis,
|
||||
)?;
|
||||
|
||||
outlets.insert(outlet, stacked_res);
|
||||
} else {
|
||||
outlets.insert(outlet, outlet_res.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
full_results = outlets.into_values().collect_vec();
|
||||
|
||||
let output_states = output_state_idx(output_mappings);
|
||||
let input_states = input_state_idx(&input_mappings);
|
||||
|
||||
assert_eq!(input_states.len(), output_states.len());
|
||||
|
||||
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
|
||||
inputs[*input_idx] = full_results[output_idx].clone();
|
||||
}
|
||||
}
|
||||
|
||||
trace!(
|
||||
"------------ output subgraph node {}: {:?}",
|
||||
idx,
|
||||
full_results
|
||||
.iter()
|
||||
.map(|x|
|
||||
// convert to tensor i32
|
||||
x.map(crate::fieldutils::felt_to_i32).show())
|
||||
.collect_vec()
|
||||
);
|
||||
|
||||
results.insert(idx, full_results);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let output_nodes = self.graph.outputs.iter();
|
||||
debug!(
|
||||
"model outputs are nodes: {:?}",
|
||||
output_nodes.clone().collect_vec()
|
||||
);
|
||||
let outputs = output_nodes
|
||||
.map(|(idx, outlet)| {
|
||||
Ok(results.get(&idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
|
||||
})
|
||||
.collect::<Result<Vec<_>, GraphError>>()?;
|
||||
|
||||
let res = ForwardResult {
|
||||
outputs,
|
||||
max_lookup_inputs,
|
||||
min_lookup_inputs,
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Loads an Onnx model from a specified path.
|
||||
@@ -1012,8 +1173,8 @@ impl Model {
|
||||
);
|
||||
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
|
||||
let input = &vars.advices[0];
|
||||
let output = &vars.advices[2];
|
||||
let index = &vars.advices[1];
|
||||
let output = &vars.advices[1];
|
||||
let index = &vars.advices[2];
|
||||
for op in required_lookups {
|
||||
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
|
||||
}
|
||||
@@ -1174,29 +1335,18 @@ impl Model {
|
||||
};
|
||||
|
||||
debug!(
|
||||
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
|
||||
"laying out {}: {}, row:{}, coord:{}, total_constants: {}",
|
||||
idx,
|
||||
node.as_str(),
|
||||
region.row(),
|
||||
region.linear_coord(),
|
||||
region.total_constants(),
|
||||
region.max_lookup_inputs(),
|
||||
region.min_lookup_inputs()
|
||||
region.total_constants()
|
||||
);
|
||||
debug!("dims: {:?}", node.out_dims());
|
||||
debug!(
|
||||
"input_dims {:?}",
|
||||
values.iter().map(|v| v.dims()).collect_vec()
|
||||
);
|
||||
debug!("output scales: {:?}", node.out_scales());
|
||||
debug!("input indices: {:?}", node.inputs());
|
||||
debug!(
|
||||
"input scales: {:?}",
|
||||
node.inputs()
|
||||
.iter()
|
||||
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
|
||||
.collect_vec()
|
||||
);
|
||||
|
||||
match &node {
|
||||
NodeType::Node(n) => {
|
||||
@@ -1339,13 +1489,28 @@ impl Model {
|
||||
pub fn dummy_layout(
|
||||
&self,
|
||||
run_args: &RunArgs,
|
||||
inputs: &[ValTensor<Fp>],
|
||||
input_shapes: &[Vec<usize>],
|
||||
) -> Result<DummyPassRes, Box<dyn Error>> {
|
||||
info!("calculating num of constraints using dummy model layout...");
|
||||
|
||||
let start_time = instant::Instant::now();
|
||||
|
||||
let mut results = BTreeMap::<usize, Vec<ValTensor<Fp>>>::new();
|
||||
let default_value = if !self.visibility.input.is_fixed() {
|
||||
ValType::Value(Value::<Fp>::unknown())
|
||||
} else {
|
||||
ValType::Constant(Fp::ONE)
|
||||
};
|
||||
|
||||
let inputs: Vec<ValTensor<Fp>> = input_shapes
|
||||
.iter()
|
||||
.map(|shape| {
|
||||
let mut t: ValTensor<Fp> =
|
||||
vec![default_value.clone(); shape.iter().product()].into();
|
||||
t.reshape(shape)?;
|
||||
Ok(t)
|
||||
})
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
|
||||
results.insert(*input_idx, vec![inputs[i].clone()]);
|
||||
@@ -1380,12 +1545,12 @@ impl Model {
|
||||
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
|
||||
|
||||
let _ = outputs
|
||||
.iter()
|
||||
.into_iter()
|
||||
.zip(comparator)
|
||||
.map(|(o, c)| {
|
||||
dummy_config.layout(
|
||||
&mut region,
|
||||
&[o.clone(), c],
|
||||
&[o, c],
|
||||
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
|
||||
)
|
||||
})
|
||||
@@ -1410,23 +1575,12 @@ impl Model {
|
||||
region.total_constants().to_string().red()
|
||||
);
|
||||
|
||||
let outputs = outputs
|
||||
.iter()
|
||||
.map(|x| {
|
||||
x.get_felt_evals()
|
||||
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
|
||||
})
|
||||
.collect();
|
||||
|
||||
let res = DummyPassRes {
|
||||
num_rows: region.row(),
|
||||
linear_coord: region.linear_coord(),
|
||||
total_const_size: region.total_constants(),
|
||||
lookup_ops: region.used_lookups(),
|
||||
range_checks: region.used_range_checks(),
|
||||
max_lookup_inputs: region.max_lookup_inputs(),
|
||||
min_lookup_inputs: region.min_lookup_inputs(),
|
||||
outputs,
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
|
||||
@@ -148,7 +148,7 @@ impl RebaseScale {
|
||||
SupportedOp::RebaseScale(RebaseScale {
|
||||
inner: op.inner.clone(),
|
||||
target_scale: op.target_scale,
|
||||
multiplier,
|
||||
multiplier: multiplier,
|
||||
rebase_op: HybridOp::Div {
|
||||
denom: crate::circuit::utils::F32((multiplier) as f32),
|
||||
use_range_check_for_int: !div_rebasing,
|
||||
@@ -219,6 +219,8 @@ impl Op<Fp> for RebaseScale {
|
||||
let mut res = Op::<Fp>::f(&*self.inner, x)?;
|
||||
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
|
||||
res.output = rebase_res.output;
|
||||
res.intermediate_lookups
|
||||
.extend(rebase_res.intermediate_lookups);
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
@@ -439,16 +439,17 @@ pub fn new_op_from_onnx(
|
||||
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
let mut op =
|
||||
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
@@ -477,16 +478,17 @@ pub fn new_op_from_onnx(
|
||||
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
|
||||
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
let mut op =
|
||||
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: None,
|
||||
});
|
||||
|
||||
// if param_visibility.is_public() {
|
||||
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
|
||||
inputs[1].decrement_use();
|
||||
deleted_indices.push(inputs.len() - 1);
|
||||
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
|
||||
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
|
||||
dim: axis,
|
||||
constant_idx: Some(c.raw_values.map(|x| x as usize)),
|
||||
})
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
overflowing_literals,
|
||||
path_statements,
|
||||
patterns_in_fns_without_body,
|
||||
private_in_public,
|
||||
unconditional_recursion,
|
||||
unused,
|
||||
unused_allocation,
|
||||
|
||||
@@ -580,16 +580,16 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
|
||||
/// assert_eq!(a.pad_to_zero_rem(4).unwrap(), expected);
|
||||
///
|
||||
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
|
||||
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
|
||||
/// assert_eq!(a.pad_to_zero_rem(9).unwrap(), expected);
|
||||
/// ```
|
||||
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
|
||||
pub fn pad_to_zero_rem(&self, n: usize) -> Result<Tensor<T>, TensorError> {
|
||||
let mut inner = self.inner.clone();
|
||||
let remainder = self.len() % n;
|
||||
if remainder != 0 {
|
||||
inner.resize(self.len() + n - remainder, pad);
|
||||
inner.resize(self.len() + n - remainder, T::zero().unwrap());
|
||||
}
|
||||
Tensor::new(Some(&inner), &[inner.len()])
|
||||
}
|
||||
|
||||
@@ -243,7 +243,7 @@ pub fn and<
|
||||
/// Some(&[1, 0, 1, 0, 1, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = equals(&a, &b).unwrap();
|
||||
/// let result = equals(&a, &b).unwrap().0;
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
@@ -260,7 +260,7 @@ pub fn equals<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
let a = a.clone();
|
||||
let b = b.clone();
|
||||
|
||||
@@ -268,7 +268,7 @@ pub fn equals<
|
||||
|
||||
let result = nonlinearities::kronecker_delta(&diff);
|
||||
|
||||
Ok(result)
|
||||
Ok((result, vec![diff]))
|
||||
}
|
||||
|
||||
/// Greater than operation.
|
||||
@@ -289,7 +289,7 @@ pub fn equals<
|
||||
/// ).unwrap();
|
||||
/// let result = greater(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// ```
|
||||
pub fn greater<
|
||||
T: TensorType
|
||||
@@ -302,7 +302,7 @@ pub fn greater<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
let mask_inter = (a.clone() - b.clone())?;
|
||||
let mask = mask_inter.map(|x| {
|
||||
if x > T::zero().ok_or(TensorError::Unsupported).unwrap() {
|
||||
@@ -311,7 +311,7 @@ pub fn greater<
|
||||
T::zero().ok_or(TensorError::Unsupported).unwrap()
|
||||
}
|
||||
});
|
||||
Ok(mask)
|
||||
Ok((mask, vec![mask_inter]))
|
||||
}
|
||||
|
||||
/// Greater equals than operation.
|
||||
@@ -332,7 +332,7 @@ pub fn greater<
|
||||
/// ).unwrap();
|
||||
/// let result = greater_equal(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// ```
|
||||
pub fn greater_equal<
|
||||
T: TensorType
|
||||
@@ -345,7 +345,7 @@ pub fn greater_equal<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
let mask_inter = (a.clone() - b.clone())?;
|
||||
let mask = mask_inter.map(|x| {
|
||||
if x >= T::zero().ok_or(TensorError::Unsupported).unwrap() {
|
||||
@@ -354,7 +354,7 @@ pub fn greater_equal<
|
||||
T::zero().ok_or(TensorError::Unsupported).unwrap()
|
||||
}
|
||||
});
|
||||
Ok(mask)
|
||||
Ok((mask, vec![mask_inter]))
|
||||
}
|
||||
|
||||
/// Less than to operation.
|
||||
@@ -375,7 +375,7 @@ pub fn greater_equal<
|
||||
/// ).unwrap();
|
||||
/// let result = less(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn less<
|
||||
@@ -389,7 +389,7 @@ pub fn less<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
// a < b <=> b > a
|
||||
greater(b, a)
|
||||
}
|
||||
@@ -412,7 +412,7 @@ pub fn less<
|
||||
/// ).unwrap();
|
||||
/// let result = less_equal(&a, &b).unwrap();
|
||||
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// assert_eq!(result.0, expected);
|
||||
/// ```
|
||||
///
|
||||
pub fn less_equal<
|
||||
@@ -426,7 +426,7 @@ pub fn less_equal<
|
||||
>(
|
||||
a: &Tensor<T>,
|
||||
b: &Tensor<T>,
|
||||
) -> Result<Tensor<T>, TensorError> {
|
||||
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
|
||||
// a < b <=> b > a
|
||||
greater_equal(b, a)
|
||||
}
|
||||
@@ -2300,12 +2300,12 @@ pub fn deconv<
|
||||
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
|
||||
/// &[1, 1, 3, 3],
|
||||
/// ).unwrap();
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap();
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap().0;
|
||||
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(pooled, expected);
|
||||
///
|
||||
/// // This time with normalization
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap();
|
||||
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap().0;
|
||||
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
|
||||
/// assert_eq!(pooled, expected);
|
||||
/// ```
|
||||
@@ -2315,7 +2315,7 @@ pub fn sumpool(
|
||||
stride: (usize, usize),
|
||||
kernel_shape: (usize, usize),
|
||||
normalize: bool,
|
||||
) -> Result<Tensor<i128>, TensorError> {
|
||||
) -> Result<(Tensor<i128>, Vec<Tensor<i128>>), TensorError> {
|
||||
let image_dims = image.dims();
|
||||
let batch_size = image_dims[0];
|
||||
let image_channels = image_dims[1];
|
||||
@@ -2345,12 +2345,15 @@ pub fn sumpool(
|
||||
let mut combined = res.combine()?;
|
||||
combined.reshape(&[&[batch_size, image_channels], shape].concat())?;
|
||||
|
||||
let mut inter = vec![];
|
||||
|
||||
if normalize {
|
||||
inter.push(combined.clone());
|
||||
let norm = kernel.len();
|
||||
combined = nonlinearities::const_div(&combined, norm as f64);
|
||||
}
|
||||
|
||||
Ok(combined)
|
||||
Ok((combined, inter))
|
||||
}
|
||||
|
||||
/// Applies 2D max pooling over a 4D tensor of shape B x C x H x W.
|
||||
@@ -3045,7 +3048,11 @@ pub mod nonlinearities {
|
||||
}
|
||||
|
||||
/// softmax layout
|
||||
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
|
||||
pub fn softmax_axes(
|
||||
a: &Tensor<i128>,
|
||||
scale: f64,
|
||||
axes: &[usize],
|
||||
) -> (Tensor<i128>, Vec<Tensor<i128>>) {
|
||||
// we want this to be as small as possible so we set the output scale to 1
|
||||
let dims = a.dims();
|
||||
|
||||
@@ -3053,6 +3060,8 @@ pub mod nonlinearities {
|
||||
return softmax(a, scale);
|
||||
}
|
||||
|
||||
let mut intermediate_values = vec![];
|
||||
|
||||
let cartesian_coord = dims[..dims.len() - 1]
|
||||
.iter()
|
||||
.map(|x| 0..*x)
|
||||
@@ -3075,7 +3084,8 @@ pub mod nonlinearities {
|
||||
|
||||
let res = softmax(&softmax_input, scale);
|
||||
|
||||
outputs.push(res);
|
||||
outputs.push(res.0);
|
||||
intermediate_values.extend(res.1);
|
||||
}
|
||||
|
||||
let mut res = Tensor::new(Some(&outputs), &[outputs.len()])
|
||||
@@ -3083,7 +3093,7 @@ pub mod nonlinearities {
|
||||
.combine()
|
||||
.unwrap();
|
||||
res.reshape(dims).unwrap();
|
||||
res
|
||||
(res, intermediate_values)
|
||||
}
|
||||
|
||||
/// Applies softmax
|
||||
@@ -3100,20 +3110,24 @@ pub mod nonlinearities {
|
||||
/// Some(&[2, 2, 3, 2, 2, 0]),
|
||||
/// &[2, 3],
|
||||
/// ).unwrap();
|
||||
/// let result = softmax(&x, 128.0);
|
||||
/// let result = softmax(&x, 128.0).0;
|
||||
/// // doubles the scale of the input
|
||||
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
|
||||
/// assert_eq!(result, expected);
|
||||
/// ```
|
||||
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
|
||||
pub fn softmax(a: &Tensor<i128>, scale: f64) -> (Tensor<i128>, Vec<Tensor<i128>>) {
|
||||
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
|
||||
let mut intermediate_values = vec![];
|
||||
|
||||
intermediate_values.push(a.clone());
|
||||
|
||||
let exp = exp(a, scale);
|
||||
|
||||
let sum = sum(&exp).unwrap();
|
||||
intermediate_values.push(sum.clone());
|
||||
let inv_denom = recip(&sum, scale, scale);
|
||||
|
||||
(exp * inv_denom).unwrap()
|
||||
((exp * inv_denom).unwrap(), intermediate_values)
|
||||
}
|
||||
|
||||
/// Applies range_check_percent
|
||||
|
||||
@@ -454,12 +454,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
|
||||
/// Calls `pad_to_zero_rem` on the inner tensor.
|
||||
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
|
||||
pub fn pad_to_zero_rem(&mut self, n: usize) -> Result<(), Box<dyn Error>> {
|
||||
match self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.pad_to_zero_rem(n, pad)?;
|
||||
*v = v.pad_to_zero_rem(n)?;
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
@@ -871,30 +871,3 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
|
||||
/// inverts the inner values
|
||||
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
|
||||
let mut cloned_self = self.clone();
|
||||
|
||||
match &mut cloned_self {
|
||||
ValTensor::Value {
|
||||
inner: v, dims: d, ..
|
||||
} => {
|
||||
*v = v.map(|x| match x {
|
||||
ValType::AssignedValue(v) => ValType::AssignedValue(v.invert()),
|
||||
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
|
||||
ValType::AssignedValue(v.value_field().invert())
|
||||
}
|
||||
ValType::Value(v) => ValType::Value(v.map(|x| x.invert().unwrap_or(F::ZERO))),
|
||||
ValType::Constant(v) => ValType::Constant(v.invert().unwrap_or(F::ZERO)),
|
||||
});
|
||||
*d = v.dims().to_vec();
|
||||
}
|
||||
ValTensor::Instance { .. } => {
|
||||
return Err(Box::new(TensorError::WrongMethod));
|
||||
}
|
||||
};
|
||||
Ok(cloned_self)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,86 +192,86 @@ mod native_tests {
|
||||
];
|
||||
|
||||
const TESTS: [&str; 77] = [
|
||||
"1l_mlp", //0
|
||||
"1l_mlp",
|
||||
"1l_slice",
|
||||
"1l_concat",
|
||||
"1l_flatten",
|
||||
// "1l_average",
|
||||
"1l_div",
|
||||
"1l_pad", // 5
|
||||
"1l_pad",
|
||||
"1l_reshape",
|
||||
"1l_eltwise_div",
|
||||
"1l_sigmoid",
|
||||
"1l_sqrt",
|
||||
"1l_softmax", //10
|
||||
"1l_softmax",
|
||||
// "1l_instance_norm",
|
||||
"1l_batch_norm",
|
||||
"1l_prelu",
|
||||
"1l_leakyrelu",
|
||||
"1l_gelu_noappx",
|
||||
// "1l_gelu_tanh_appx",
|
||||
"1l_relu", //15
|
||||
"1l_relu",
|
||||
"1l_downsample",
|
||||
"1l_tanh",
|
||||
"2l_relu_sigmoid_small",
|
||||
"2l_relu_fc",
|
||||
"2l_relu_small", //20
|
||||
"2l_relu_small",
|
||||
"2l_relu_sigmoid",
|
||||
"1l_conv",
|
||||
"2l_sigmoid_small",
|
||||
"2l_relu_sigmoid_conv",
|
||||
"3l_relu_conv_fc", //25
|
||||
"3l_relu_conv_fc",
|
||||
"4l_relu_conv_fc",
|
||||
"1l_erf",
|
||||
"1l_var",
|
||||
"1l_elu",
|
||||
"min", //30
|
||||
"1l_elu", //30
|
||||
"min",
|
||||
"max",
|
||||
"1l_max_pool",
|
||||
"1l_conv_transpose",
|
||||
"1l_upsample",
|
||||
"1l_identity", //35
|
||||
"1l_upsample", //35
|
||||
"1l_identity",
|
||||
"idolmodel",
|
||||
"trig",
|
||||
"prelu_gmm",
|
||||
"lstm",
|
||||
"rnn", //40
|
||||
"lstm", //40
|
||||
"rnn",
|
||||
"quantize_dequantize",
|
||||
"1l_where",
|
||||
"boolean",
|
||||
"boolean_identity",
|
||||
"decision_tree", // 45
|
||||
"decision_tree", // "variable_cnn",
|
||||
"random_forest",
|
||||
"gradient_boosted_trees",
|
||||
"1l_topk",
|
||||
"xgboost",
|
||||
"lightgbm", //50
|
||||
"xgboost", //50
|
||||
"lightgbm",
|
||||
"hummingbird_decision_tree",
|
||||
"oh_decision_tree",
|
||||
"linear_svc",
|
||||
"gather_elements",
|
||||
"less", //55
|
||||
"less",
|
||||
"xgboost_reg",
|
||||
"1l_powf",
|
||||
"scatter_elements",
|
||||
"1l_linear",
|
||||
"linear_regression", //60
|
||||
"1l_linear", //60
|
||||
"linear_regression",
|
||||
"sklearn_mlp",
|
||||
"1l_mean",
|
||||
"rounding_ops",
|
||||
// "mean_as_constrain",
|
||||
"arange",
|
||||
"layernorm", //65
|
||||
"layernorm",
|
||||
"bitwise_ops",
|
||||
"blackman_window",
|
||||
"softsign", //68
|
||||
"softsign", //70
|
||||
"softplus",
|
||||
"selu", //70
|
||||
"selu",
|
||||
"hard_sigmoid",
|
||||
"log_softmax",
|
||||
"eye",
|
||||
"ltsf",
|
||||
"remainder", //75
|
||||
"remainder",
|
||||
"bitshift",
|
||||
];
|
||||
|
||||
@@ -560,7 +560,7 @@ mod native_tests {
|
||||
crate::native_tests::setup_py_env();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
|
||||
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
@@ -131,7 +131,7 @@ mod py_tests {
|
||||
"simple_demo_aggregated_proofs.ipynb",
|
||||
"ezkl_demo.ipynb", // 10
|
||||
"lstm.ipynb",
|
||||
"set_membership.ipynb", // 12
|
||||
"set_membership.ipynb",
|
||||
"decision_tree.ipynb",
|
||||
"random_forest.ipynb",
|
||||
"gradient_boosted_trees.ipynb", // 15
|
||||
|
||||
@@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_relu',
|
||||
'1l_average',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_relu',
|
||||
'1l_average',
|
||||
'network.onnx'
|
||||
)
|
||||
output_path = os.path.join(
|
||||
@@ -147,13 +147,13 @@ def test_calibrate():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_relu',
|
||||
'1l_average',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_relu',
|
||||
'1l_average',
|
||||
'network.onnx'
|
||||
)
|
||||
output_path = os.path.join(
|
||||
@@ -183,7 +183,7 @@ def test_model_compile():
|
||||
model_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_relu',
|
||||
'1l_average',
|
||||
'network.onnx'
|
||||
)
|
||||
compiled_model_path = os.path.join(
|
||||
@@ -205,7 +205,7 @@ def test_forward():
|
||||
data_path = os.path.join(
|
||||
examples_path,
|
||||
'onnx',
|
||||
'1l_relu',
|
||||
'1l_average',
|
||||
'input.json'
|
||||
)
|
||||
model_path = os.path.join(
|
||||
|
||||
Binary file not shown.
Binary file not shown.
File diff suppressed because one or more lines are too long
Binary file not shown.
Reference in New Issue
Block a user