Compare commits

..

3 Commits

Author SHA1 Message Date
dante
5cb303b149 Merge branch 'main' into example-reorg 2024-02-05 14:43:01 +00:00
Sofia Wawrzyniak
9fb78c36e0 readding examples 2024-02-05 09:41:01 -05:00
Sofia Wawrzyniak
074db5d229 preliminary bucketing of examples 2024-02-05 09:09:41 -05:00
87 changed files with 871 additions and 2862 deletions

View File

@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- name: Checkout repo

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -189,7 +189,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -198,8 +198,10 @@ jobs:
# chromedriver-version: "115.0.5790.102"
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Install wasm runner
run: cargo install wasm-server-runner
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -212,7 +214,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -229,7 +231,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -284,7 +286,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -310,7 +312,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
@@ -343,15 +345,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Install wasm-server-runner
run: cargo install wasm-server-runner
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -411,11 +416,11 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
@@ -445,7 +450,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -455,7 +460,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
- name: fuzz tests (EVM)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_fuzz_ --test-threads 2
# - name: fuzz tests
@@ -468,7 +473,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -486,7 +491,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -503,7 +508,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -520,7 +525,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -530,7 +535,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -541,7 +546,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -563,7 +568,7 @@ jobs:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- name: Install solc
@@ -571,7 +576,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; maturin develop --features python-bindings --release
- name: Run pytest
@@ -587,7 +592,7 @@ jobs:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -618,7 +623,7 @@ jobs:
python-version: "3.9"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -628,7 +633,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev b320f350156a0fb15c2eb13dc380deb2367c4474 --profile local --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev 95a93cd397f25f3f8d49d2851eb52bc2d52dd983 --profile local --locked anvil --force
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env; source .env/bin/activate; pip install -r requirements.txt;
- name: Build python ezkl

View File

@@ -22,15 +22,18 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-01-04
toolchain: nightly-2023-08-24
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
- name: Add wasm32-unknown-unknown target
run: rustup target add wasm32-unknown-unknown
- name: Install wasm-server-runner
run: cargo install wasm-server-runner
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-01-04-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2023-08-24-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e

539
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -17,7 +17,7 @@ crate-type = ["cdylib", "rlib"]
[dependencies]
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "main" }
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev="9fff22c", features=["derive_serde"] }
halo2curves = { version = "0.6.0", features = ["derive_serde"] }
rand = { version = "0.8", default_features = false }
itertools = { version = "0.10.3", default_features = false }
clap = { version = "4.3.3", features = ["derive"]}
@@ -34,11 +34,10 @@ bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
# evm related deps
[target.'cfg(not(target_arch = "wasm32"))'.dependencies]
ethers = { version = "2.0.11", default_features = false, features = ["ethers-solc"] }
ethers = { version = "2.0.7", default_features = false, features = ["ethers-solc"] }
indicatif = {version = "0.17.5", features = ["rayon"]}
gag = { version = "1.0.0", default_features = false}
instant = { version = "0.1" }

File diff suppressed because it is too large Load Diff

View File

Before

Width:  |  Height:  |  Size: 109 KiB

After

Width:  |  Height:  |  Size: 109 KiB

View File

@@ -125,8 +125,8 @@ impl BaseOp {
BaseOp::Sum => 1,
BaseOp::SumInit => 1,
BaseOp::Range { .. } => 1,
BaseOp::IsZero => 0,
BaseOp::IsBoolean => 0,
BaseOp::IsZero => 1,
BaseOp::IsBoolean => 1,
}
}

View File

@@ -276,20 +276,9 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
let constraints = match base_op {
BaseOp::IsBoolean => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
let output = expected_output[base_op.constraint_idx()].clone();
vec![(output.clone()) * (output.clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, 0, 1)
.expect("non accum: output query failed");
vec![expected_output[base_op.constraint_idx()].clone()]
vec![(qis[1].clone()) * (qis[1].clone() - Expression::Constant(F::from(1)))]
}
BaseOp::IsZero => vec![qis[1].clone()],
_ => {
let expected_output: Tensor<Expression<F>> = output
.query_rng(meta, *block_idx, *inner_col_idx, rotation_offset, rng)
@@ -523,10 +512,10 @@ impl<F: PrimeField + TensorType + PartialOrd> BaseConfig<F> {
// we borrow mutably twice so we need to do this dance
let range_check = if let std::collections::btree_map::Entry::Vacant(e) = self.range_checks.entry(range) {
let range_check = if !self.range_checks.contains_key(&range) {
// as all tables have the same input we see if there's another table who's input we can reuse
let range_check = RangeCheck::<F>::configure(cs, range);
e.insert(range_check.clone());
self.range_checks.insert(range, range_check.clone());
range_check
} else {
return Ok(());

View File

@@ -6,6 +6,7 @@ use crate::{
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use serde::{Deserialize, Serialize};
// import run args from model
@@ -68,6 +69,14 @@ pub enum HybridOp {
dim: usize,
num_classes: usize,
},
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
@@ -75,6 +84,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
HybridOp::Greater | HybridOp::Less | HybridOp::Equals => vec![0, 1],
HybridOp::ScatterElements { .. } => vec![0, 2],
HybridOp::GreaterEqual | HybridOp::LessEqual => vec![0, 1],
_ => vec![],
}
@@ -88,42 +98,176 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));
let res = match &self {
HybridOp::ReduceMax { axes, .. } => tensor::ops::max_axes(&x, axes)?,
HybridOp::ReduceMin { axes, .. } => tensor::ops::min_axes(&x, axes)?,
HybridOp::Div { denom, .. } => {
crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64)
let (res, intermediate_lookups) = match &self {
HybridOp::ReduceMax { axes, .. } => {
let res = tensor::ops::max_axes(&x, axes)?;
let max_minus_one =
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(x - max(x - 1)
let inter_1 = (x.clone() - max_minus_one)?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::ReduceMin { axes, .. } => {
let res = tensor::ops::min_axes(&x, axes)?;
let min_plus_one =
Tensor::from(vec![x.clone().into_iter().min().unwrap() + 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(min(x + 1) - x)
let inter_1 = (min_plus_one - x.clone())?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(res.clone(), vec![inter_1, inter_2])
}
HybridOp::Div {
denom,
use_range_check_for_int,
..
} => {
let res = crate::tensor::ops::nonlinearities::const_div(&x, denom.0 as f64);
// if denom is a round number and use_range_check_for_int is true, use range check check
if denom.0.fract() == 0.0 && *use_range_check_for_int {
let divisor = Tensor::from(vec![denom.0 as i128 / 2].into_iter());
(res, vec![-divisor.clone(), divisor])
} else {
(res, vec![x])
}
}
HybridOp::Recip {
input_scale,
output_scale,
..
} => crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
),
HybridOp::ReduceArgMax { dim } => tensor::ops::argmax_axes(&x, *dim)?,
HybridOp::ReduceArgMin { dim } => tensor::ops::argmin_axes(&x, *dim)?,
use_range_check_for_int,
} => {
let res = crate::tensor::ops::nonlinearities::recip(
&x,
input_scale.0 as f64,
output_scale.0 as f64,
);
// if scale is a round number and use_range_check_for_int is true, use range check check
if input_scale.0.fract() == 0.0 && *use_range_check_for_int {
let err_tol = Tensor::from(
vec![(output_scale.0 * input_scale.0) as i128 / 2].into_iter(),
);
(res, vec![-err_tol.clone(), err_tol])
} else {
(res, vec![x])
}
}
HybridOp::ReduceArgMax { dim } => {
let res = tensor::ops::argmax_axes(&x, *dim)?;
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let inter =
Op::f(&HybridOp::ReduceMax { axes: vec![*dim] }, inputs)?.intermediate_lookups;
inter_equals.extend(inter);
(res.clone(), inter_equals)
}
HybridOp::ReduceArgMin { dim } => {
let res = tensor::ops::argmin_axes(&x, *dim)?;
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let mut inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let inter =
Op::f(&HybridOp::ReduceMin { axes: vec![*dim] }, inputs)?.intermediate_lookups;
inter_equals.extend(inter);
(res.clone(), inter_equals)
}
HybridOp::Gather { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather(&x, idx, *dim)?
log::debug!("idx: {}", idx.show());
let res = tensor::ops::gather(&x, idx, *dim)?;
(res.clone(), vec![])
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::gather(&x, &y.map(|x| x as usize), *dim)?;
(res.clone(), inter_equals)
}
}
HybridOp::OneHot { dim, num_classes } => {
tensor::ops::one_hot(&x, *num_classes, *dim)?.clone()
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::one_hot(&x, *num_classes, *dim)?;
(res.clone(), inter_equals)
}
HybridOp::TopK { dim, k, largest } => {
let res = tensor::ops::topk_axes(&x, *k, *dim, *largest)?;
HybridOp::TopK { dim, k, largest } => tensor::ops::topk_axes(&x, *k, *dim, *largest)?,
let mut inter_equals = x
.clone()
.into_iter()
.flat_map(|elem| {
tensor::ops::equals(&res, &vec![elem].into_iter().into())
.unwrap()
.1
})
.collect::<Vec<_>>();
// sort in descending order and take pairwise differences
inter_equals.push(
x.into_iter()
.sorted()
.tuple_windows()
.map(|(a, b)| b - a)
.into(),
);
(res.clone(), inter_equals)
}
HybridOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let res = tensor::ops::gather_elements(&x, idx, *dim)?;
(res.clone(), vec![])
} else {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::gather_elements(&x, &y.map(|x| x as usize), *dim)?;
(res.clone(), inter_equals)
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
log::debug!("idx: {}", idx.show());
let src = inputs[1].clone().map(|x| felt_to_i128(x));
let res = tensor::ops::scatter(&x, idx, &src, *dim)?;
(res.clone(), vec![])
} else {
let idx = inputs[1].clone().map(|x| felt_to_i128(x) as usize);
let src = inputs[2].clone().map(|x| felt_to_i128(x));
let indices = Tensor::from(0..x.dims()[*dim] as i128);
let inter_equals: Vec<Tensor<i128>> = vec![indices.clone(), -indices];
let res = tensor::ops::scatter(&x, &idx, &src, *dim)?;
(res.clone(), inter_equals)
}
}
HybridOp::MaxPool2d {
padding,
stride,
pool_dims,
..
} => tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
} => {
let max_minus_one =
Tensor::from(vec![x.clone().into_iter().max().unwrap() - 1].into_iter());
let unit = Tensor::from(vec![1].into_iter());
// relu(x - max(x - 1)
let inter_1 = (x.clone() - max_minus_one)?;
// relu(1 - sum(relu(inter_1)))
let inter_2 = (unit
- tensor::ops::sum(&tensor::ops::nonlinearities::leakyrelu(&inter_1, 0.0))?)?;
(
tensor::ops::max_pool2d(&x, padding, stride, pool_dims)?,
vec![inter_1, inter_2],
)
}
HybridOp::SumPool {
padding,
stride,
@@ -135,7 +279,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
}
HybridOp::RangeCheck(tol) => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val)
(
tensor::ops::nonlinearities::range_check_percent(&[x, y], 128, 128, tol.val),
vec![],
)
}
HybridOp::Greater => {
let y = inputs[1].clone().map(|x| felt_to_i128(x));
@@ -162,7 +309,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
// convert back to felt
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult { output })
Ok(ForwardResult {
output,
intermediate_lookups,
})
}
fn as_string(&self) -> String {
@@ -216,6 +366,8 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
HybridOp::TopK { k, dim, largest } => {
format!("TOPK (k={}, dim={}, largest={})", k, dim, largest)
}
HybridOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
HybridOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
HybridOp::OneHot { dim, num_classes } => {
format!("ONEHOT (dim={}, num_classes={})", dim, num_classes)
}
@@ -288,7 +440,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
config,
region,
values.try_into()?,
&LookupOp::Div { denom: *denom },
&LookupOp::Div {
denom: denom.clone(),
},
)?
}
}
@@ -299,7 +453,26 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
layouts::gather(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
*dim,
)?
.into()
} else {
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
HybridOp::MaxPool2d {
padding,
stride,

View File

@@ -128,7 +128,13 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
let range_check_bracket = felt_to_i128(output_scale * input_scale) / 2;
let is_assigned = !input.any_unknowns()?;
let mut scaled_unit =
Tensor::from(vec![ValType::Constant(output_scale * input_scale)].into_iter());
scaled_unit.set_visibility(&crate::graph::Visibility::Fixed);
let scaled_unit = region.assign(&config.inputs[1], &scaled_unit.into())?;
region.increment(scaled_unit.len());
let is_assigned = !input.any_unknowns()? && !scaled_unit.any_unknowns()?;
let mut claimed_output: ValTensor<F> = if is_assigned {
let input_evals = input.get_int_evals()?;
@@ -160,14 +166,27 @@ pub fn recip<F: PrimeField + TensorType + PartialOrd>(
log::debug!("product: {:?}", product.get_int_evals()?);
// this is now of scale 2 * scale hence why we rescaled the unit scale
let diff_with_input = pairwise(
config,
region,
&[product.clone(), scaled_unit.clone()],
BaseOp::Sub,
)?;
log::debug!("scaled_unit: {:?}", scaled_unit.get_int_evals()?);
// debug print the diff
log::debug!("diff_with_input: {:?}", diff_with_input.get_int_evals()?);
log::debug!("range_check_bracket: {:?}", range_check_bracket);
// at most the error should be in the original unit scale's range
range_check(
config,
region,
&[product],
&(range_check_bracket, 3 * range_check_bracket),
&[diff_with_input],
&(-range_check_bracket, range_check_bracket),
)?;
Ok(claimed_output)
@@ -214,7 +233,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd>(
let mut assigned_len = 0;
for (i, input) in values.iter_mut().enumerate() {
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
input.pad_to_zero_rem(block_width)?;
let inp = {
let (res, len) = region.assign_with_duplication(
&config.inputs[i],
@@ -770,7 +789,14 @@ fn one_hot<F: PrimeField + TensorType + PartialOrd>(
let assigned_input = region.assign(&config.inputs[0], &input)?;
// now assert all elems are 0 or 1
let assigned_output = boolean_identity(config, region, &[output.clone()], true)?;
let assigned_output = region.assign(&config.inputs[1], &output)?;
if !region.is_dummy() {
for i in 0..assigned_output.len() {
let (x, y, z) = config.output.cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
}
}
region.increment(std::cmp::max(assigned_output.len(), assigned_input.len()));
let sum = sum(config, region, &[assigned_output.clone()])?;
@@ -1142,7 +1168,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd>(
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ZERO))?;
input.pad_to_zero_rem(block_width)?;
let (res, len) =
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
assigned_len = len;
@@ -1211,7 +1237,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd>(
let assigned_len: usize;
let input = {
let mut input = values[0].clone();
input.pad_to_zero_rem(block_width, ValType::Constant(F::ONE))?;
input.pad_to_zero_rem(block_width)?;
let (res, len) =
region.assign_with_duplication(&config.inputs[1], &input, &config.check_mode, false)?;
assigned_len = len;
@@ -1675,28 +1701,10 @@ pub fn equals<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, Box<dyn Error>> {
let diff = pairwise(config, region, values, BaseOp::Sub)?;
let diff_inverse = diff.inverse()?;
let product_diff_and_invert =
pairwise(config, region, &[diff.clone(), diff_inverse], BaseOp::Mult)?;
// constant of 1
let mut ones = Tensor::from(vec![ValType::Constant(F::from(1))].into_iter());
ones.set_visibility(&crate::graph::Visibility::Fixed);
let res = nonlinearity(config, region, &[diff], &LookupOp::KroneckerDelta)?;
// subtract
let output = pairwise(
config,
region,
&[ones.into(), product_diff_and_invert],
BaseOp::Sub,
)?;
// take the product of diff and output
let prod_check = pairwise(config, region, &[diff, output.clone()], BaseOp::Mult)?;
is_zero_identity(config, region, &[prod_check], false)?;
Ok(output)
Ok(res)
}
/// Xor boolean operation
@@ -1760,7 +1768,21 @@ pub fn iff<F: PrimeField + TensorType + PartialOrd>(
.into();
// make sure mask is boolean
let assigned_mask = boolean_identity(config, region, &[mask.clone()], true)?;
let assigned_mask = region.assign(&config.inputs[1], mask)?;
// Enable the selectors
if !region.is_dummy() {
(0..assigned_mask.len())
.map(|i| {
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(assigned_mask.len());
let one_minus_mask = pairwise(config, region, &[unit, assigned_mask.clone()], BaseOp::Sub)?;
@@ -1858,11 +1880,13 @@ pub fn sumpool<F: PrimeField + TensorType + PartialOrd>(
last_elem.reshape(&[&[batch_size, image_channels], shape].concat())?;
if normalized {
last_elem = div(
last_elem = nonlinearity(
config,
region,
&[last_elem],
F::from((kernel_shape.0 * kernel_shape.1) as u64),
&LookupOp::Div {
denom: utils::F32((kernel_shape.0 * kernel_shape.1) as f32),
},
)?;
}
Ok(last_elem)
@@ -2359,60 +2383,18 @@ pub fn identity<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// is zero identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub fn is_zero_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
let output = region.assign(&config.output, &values[0])?;
region.increment(output.len());
output
} else {
values[0].clone()
};
// Enable the selectors
if !region.is_dummy() {
(0..output.len())
.map(|j| {
let index = region.linear_coord() - j - 1;
let (x, y, z) = config.output.cartesian_coord(index);
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
Ok(output)
}
/// Boolean identity constraint. Usually used to constrain an instance column to an advice so the returned cells / values can be operated upon.
pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
// get zero constants indices
let output = region.assign(&config.output, &values[0])?;
region.increment(output.len());
output
} else {
values[0].clone()
};
let output = region.assign(&config.inputs[1], &values[0])?;
// Enable the selectors
if !region.is_dummy() {
(0..output.len())
.map(|j| {
let index = region.linear_coord() - j - 1;
let (x, y, z) = config.output.cartesian_coord(index);
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + j);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
@@ -2420,6 +2402,7 @@ pub fn boolean_identity<F: PrimeField + TensorType + PartialOrd>(
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(output.len());
Ok(output)
}
@@ -2469,7 +2452,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 1],
range: &crate::circuit::table::Range,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_range_check(*range)?;
region.add_used_range_check(*range);
// time the entire operation
let timer = instant::Instant::now();
@@ -2488,7 +2471,7 @@ pub fn range_check<F: PrimeField + TensorType + PartialOrd>(
let (x, y, z) = config
.lookup_input
.cartesian_coord(region.linear_coord() + i);
let selector = config.range_check_selectors.get(&(*range, x, y));
let selector = config.range_check_selectors.get(&(range.clone(), x, y));
region.enable(selector, z)?;
Ok(())
})
@@ -2515,7 +2498,7 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
values: &[ValTensor<F>; 1],
nl: &LookupOp,
) -> Result<ValTensor<F>, Box<dyn Error>> {
region.add_used_lookup(nl.clone(), values)?;
region.add_used_lookup(nl.clone());
// time the entire operation
let timer = instant::Instant::now();
@@ -2598,6 +2581,22 @@ pub fn nonlinearity<F: PrimeField + TensorType + PartialOrd>(
Ok(output)
}
/// mean function layout
pub fn mean<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 1],
scale: usize,
) -> Result<ValTensor<F>, Box<dyn Error>> {
let x = &values[0];
let sum_x = sum(config, region, &[x.clone()])?;
let nl = LookupOp::Div {
denom: utils::F32((scale * x.len()) as f32),
};
nonlinearity(config, region, &[sum_x], &nl)
}
/// Argmax
pub fn argmax<F: PrimeField + TensorType + PartialOrd>(
config: &BaseConfig<F>,
@@ -2710,8 +2709,24 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
)?;
// relu(x - max(x - 1))
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
// constraining relu(x - max(x - 1)) = 0/1
boolean_identity(config, region, &[relu.clone()], false)?;
let len = relu.dims().iter().product();
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
region.assign(&config.inputs[1], &relu)?;
if !region.is_dummy() {
(0..len)
.map(|i| {
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(len);
// sum(relu(x - max(x - 1)))
let sum_relu = sum(config, region, &[relu])?;
@@ -2722,7 +2737,13 @@ pub fn max<F: PrimeField + TensorType + PartialOrd>(
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
// constraining 1 - sum(relu(x - max(x - 1))) = 0
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
region.increment(relu_one_minus_sum_relu.len());
Ok(assigned_max_val)
}
@@ -2767,8 +2788,23 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
// relu(min(x + 1) - x)
let relu = nonlinearity(config, region, &[diff], &LookupOp::ReLU)?;
// constraining relu(min(x + 1) - x) = 0/1
boolean_identity(config, region, &[relu.clone()], false)?;
let len = relu.dims().iter().product();
region.assign(&config.inputs[1], &relu)?;
// y_i*(1 - y_i) =0 // assert the values are either 0 or 1
if !region.is_dummy() {
(0..len)
.map(|i| {
let (x, y, z) = config.inputs[1].cartesian_coord(region.linear_coord() + i);
let selector = config.selectors.get(&(BaseOp::IsBoolean, x, y));
region.enable(selector, z)?;
Ok(())
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
}
region.increment(len);
// sum(relu(min(x + 1) - x))
let sum_relu = sum(config, region, &[relu])?;
@@ -2779,8 +2815,14 @@ pub fn min<F: PrimeField + TensorType + PartialOrd>(
let relu_one_minus_sum_relu =
nonlinearity(config, region, &[one_minus_sum_relu], &LookupOp::ReLU)?;
region.assign(&config.inputs[1], &relu_one_minus_sum_relu)?;
// constraining product to 0
is_zero_identity(config, region, &[relu_one_minus_sum_relu], false)?;
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
region.increment(relu_one_minus_sum_relu.len());
Ok(assigned_min_val)
}
@@ -2999,8 +3041,15 @@ pub fn range_check_percent<F: PrimeField + TensorType + PartialOrd>(
// Add the lower_bound and upper_bound
let sum = pairwise(config, region, &[lower_bound, upper_bound], BaseOp::Add)?;
// Assign the sum tensor to the inputs
region.assign(&config.inputs[1], &sum)?;
// Constrain the sum to be all zeros
is_zero_identity(config, region, &[sum.clone()], false)?;
let (x, y, z) = config.output.cartesian_coord(region.linear_coord());
let selector = config.selectors.get(&(BaseOp::IsZero, x, y));
region.enable(selector, z)?;
region.increment(sum.len());
Ok(sum)
}

View File

@@ -227,7 +227,10 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
let output = res.map(|x| i128_to_felt(x));
Ok(ForwardResult { output })
Ok(ForwardResult {
output,
intermediate_lookups: vec![],
})
}
/// Returns the name of the operation

View File

@@ -29,6 +29,7 @@ pub mod region;
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
pub(crate) output: Tensor<F>,
pub(crate) intermediate_lookups: Vec<Tensor<i128>>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
@@ -177,6 +178,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
Ok(ForwardResult {
output: x[0].clone(),
intermediate_lookups: vec![],
})
}
@@ -199,7 +201,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
config,
region,
values[..].try_into()?,
true,
)?))
}
_ => Ok(Some(super::layouts::identity(
@@ -302,7 +303,10 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let output = self.quantized_values.clone();
Ok(ForwardResult { output })
Ok(ForwardResult {
output,
intermediate_lookups: vec![],
})
}
fn as_string(&self) -> String {

View File

@@ -1,6 +1,5 @@
use crate::{
circuit::layouts,
fieldutils::felt_to_i128,
tensor::{self, Tensor, TensorError},
};
@@ -10,14 +9,6 @@ use super::{base::BaseOp, *};
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp {
GatherElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
ScatterElements {
dim: usize,
constant_idx: Option<Tensor<usize>>,
},
MultiBroadcastTo {
shape: Vec<usize>,
},
@@ -90,8 +81,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
fn as_string(&self) -> String {
match &self {
PolyOp::GatherElements { dim, .. } => format!("GATHERELEMENTS (dim={})", dim),
PolyOp::ScatterElements { dim, .. } => format!("SCATTERELEMENTS (dim={})", dim),
PolyOp::MultiBroadcastTo { shape } => format!("MULTIBROADCASTTO (shape={:?})", shape),
PolyOp::MoveAxis { .. } => "MOVEAXIS".into(),
PolyOp::Downsample { .. } => "DOWNSAMPLE".into(),
@@ -214,36 +203,14 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
if 1 != inputs.len() {
return Err(TensorError::DimMismatch("slice inputs".to_string()));
}
tensor::ops::slice(&inputs[0], axis, start, end)
}
PolyOp::GatherElements { dim, constant_idx } => {
let x = inputs[0].clone();
let y = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
tensor::ops::gather_elements(&x, &y, *dim)
}
PolyOp::ScatterElements { dim, constant_idx } => {
let x = inputs[0].clone();
let idx = if let Some(idx) = constant_idx {
idx.clone()
} else {
inputs[1].clone().map(|x| felt_to_i128(x) as usize)
};
let src = if let Some(_) = constant_idx {
inputs[1].clone()
} else {
inputs[2].clone()
};
tensor::ops::scatter(&x, &idx, &src, *dim)
Ok(tensor::ops::slice(&inputs[0], axis, start, end)?)
}
}?;
Ok(ForwardResult { output: res })
Ok(ForwardResult {
output: res,
intermediate_lookups: vec![],
})
}
fn layout(
@@ -284,26 +251,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
} else {
layouts::gather_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::ScatterElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::scatter(
values[0].get_inner_tensor()?,
idx,
values[1].get_inner_tensor()?,
*dim,
)?
.into()
} else {
layouts::scatter_elements(config, region, values[..].try_into()?, *dim)?
}
}
PolyOp::DeConv {
padding,
output_padding,
@@ -405,8 +352,6 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
vec![1, 2]
} else if matches!(self, PolyOp::Concat { .. }) {
(0..100).collect()
} else if matches!(self, PolyOp::ScatterElements { .. }) {
vec![0, 2]
} else {
vec![]
}

View File

@@ -16,8 +16,6 @@ use std::{
},
};
use portable_atomic::AtomicI128 as AtomicInt;
use super::lookup::LookupOp;
/// Region error
@@ -68,8 +66,6 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd> {
total_constants: usize,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i128,
min_lookup_inputs: i128,
}
impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
@@ -91,8 +87,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
/// Create a new region context from a wrapped region
@@ -110,8 +104,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
@@ -128,8 +120,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants: 0,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
@@ -151,8 +141,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
total_constants,
used_lookups,
used_range_checks,
max_lookup_inputs: 0,
min_lookup_inputs: 0,
}
}
@@ -192,7 +180,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
/// Create a new region context per loop iteration
/// hacky but it works
pub fn dummy_loop<T: TensorType + Send + Sync>(
&mut self,
output: &mut Tensor<T>,
@@ -203,10 +190,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let constants = AtomicUsize::new(self.total_constants());
let max_lookup_inputs =
AtomicInt::new(self.max_lookup_inputs().try_into().unwrap_or_default());
let min_lookup_inputs =
AtomicInt::new(self.min_lookup_inputs().try_into().unwrap_or_default());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
@@ -238,14 +221,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
local_reg.total_constants() - starting_constants,
Ordering::SeqCst,
);
let local_max_lookup_inputs =
local_reg.max_lookup_inputs().try_into().unwrap_or_default();
let local_min_lookup_inputs =
local_reg.min_lookup_inputs().try_into().unwrap_or_default();
max_lookup_inputs.fetch_max(local_max_lookup_inputs, Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_min_lookup_inputs, Ordering::SeqCst);
// update the lookups
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
@@ -259,11 +234,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
})?;
self.total_constants = constants.into_inner();
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
self.max_lookup_inputs = max_lookup_inputs.into_inner() as i128;
self.min_lookup_inputs = min_lookup_inputs.into_inner() as i128;
}
self.row = row.into_inner();
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| RegionError::from(format!("dummy_loop: failed to get lookups: {:?}", e)))?
@@ -283,54 +253,19 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_inputs(
&mut self,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
Ok(())
}
/// Update the max and min from inputs
pub fn update_max_min_lookup_range(
&mut self,
range: Range,
) -> Result<(), Box<dyn std::error::Error>> {
if range.0 > range.1 {
return Err("update_max_min_lookup_range: invalid range".into());
}
self.max_lookup_inputs = self.max_lookup_inputs.max(range.1);
self.min_lookup_inputs = self.min_lookup_inputs.min(range.0);
Ok(())
}
/// Check if the region is dummy
pub fn is_dummy(&self) -> bool {
self.region.is_none()
}
/// add used lookup
pub fn add_used_lookup(
&mut self,
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), Box<dyn std::error::Error>> {
pub fn add_used_lookup(&mut self, lookup: LookupOp) {
self.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), Box<dyn std::error::Error>> {
pub fn add_used_range_check(&mut self, range: Range) {
self.used_range_checks.insert(range);
self.update_max_min_lookup_range(range)
}
/// Get the offset
@@ -358,16 +293,6 @@ impl<'a, F: PrimeField + TensorType + PartialOrd> RegionCtx<'a, F> {
self.used_range_checks.clone()
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> i128 {
self.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> i128 {
self.min_lookup_inputs
}
/// Assign a constant value
pub fn assign_constant(&mut self, var: &VarTensor, value: F) -> Result<ValType<F>, Error> {
self.total_constants += 1;

View File

@@ -726,11 +726,11 @@ impl AccuracyResults {
let percentage_error = error.enum_map(|i, x| Ok::<_, TensorError>(x / original[i]))?;
let abs_percentage_error = percentage_error.map(|x| x.abs());
errors.extend(error);
abs_errors.extend(abs_error);
squared_errors.extend(squared_error);
percentage_errors.extend(percentage_error);
abs_percentage_errors.extend(abs_percentage_error);
errors.extend(error.into_iter());
abs_errors.extend(abs_error.into_iter());
squared_errors.extend(squared_error.into_iter());
percentage_errors.extend(percentage_error.into_iter());
abs_percentage_errors.extend(abs_percentage_error.into_iter());
}
let mean_percent_error =
@@ -826,7 +826,10 @@ pub(crate) fn calibrate(
let range = if let Some(scales) = scales {
scales
} else {
(10..14).collect::<Vec<crate::Scale>>()
match target {
CalibrationTarget::Resources { .. } => (8..10).collect::<Vec<crate::Scale>>(),
CalibrationTarget::Accuracy => (10..14).collect::<Vec<crate::Scale>>(),
}
};
let div_rebasing = if let Some(div_rebasing) = div_rebasing {

View File

@@ -10,6 +10,7 @@ use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::felt_to_i128;
use crate::tensor::ValType;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
@@ -67,16 +68,6 @@ pub struct ForwardResult {
pub min_lookup_inputs: i128,
}
impl From<DummyPassRes> for ForwardResult {
fn from(res: DummyPassRes) -> Self {
Self {
outputs: res.outputs,
max_lookup_inputs: res.max_lookup_inputs,
min_lookup_inputs: res.min_lookup_inputs,
}
}
}
/// A circuit configuration for the entirety of a model loaded from an Onnx file.
#[derive(Clone, Debug)]
pub struct ModelConfig {
@@ -102,12 +93,6 @@ pub struct DummyPassRes {
pub lookup_ops: HashSet<LookupOp>,
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: i128,
/// min lookup inputs
pub min_lookup_inputs: i128,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
@@ -501,25 +486,7 @@ impl Model {
);
// this is the total number of variables we will need to allocate
// for the circuit
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = self
.graph
.input_shapes()?
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let res = self.dummy_layout(run_args, &inputs)?;
let res = self.dummy_layout(run_args, &self.graph.input_shapes()?)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -555,12 +522,206 @@ impl Model {
/// * `model_inputs` - A vector of [Tensor]s to use as inputs to the model.
/// * `run_args` - [RunArgs]
pub fn forward(&self, model_inputs: &[Tensor<Fp>]) -> Result<ForwardResult, Box<dyn Error>> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(&RunArgs::default(), &valtensor_inputs)?;
Ok(res.into())
let mut results: BTreeMap<&usize, Vec<Tensor<Fp>>> = BTreeMap::new();
let mut max_lookup_inputs = 0;
let mut min_lookup_inputs = 0;
let input_shapes = self.graph.input_shapes()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
let mut input = model_inputs[i].clone();
input.reshape(&input_shapes[i])?;
results.insert(input_idx, vec![input]);
}
for (idx, n) in self.graph.nodes.iter() {
let mut inputs = vec![];
if n.is_input() {
let t = results.get(idx).ok_or(GraphError::MissingResults)?[0].clone();
inputs.push(t);
} else {
for (idx, outlet) in n.inputs().iter() {
match results.get(&idx) {
Some(value) => inputs.push(value[*outlet].clone()),
None => return Err(Box::new(GraphError::MissingNode(*idx))),
}
}
};
debug!("executing {}: {}", idx, n.as_str());
debug!("dims: {:?}", n.out_dims());
debug!(
"input_dims: {:?}",
inputs.iter().map(|x| x.dims()).collect::<Vec<_>>()
);
debug!("input nodes: {:?}", n.inputs());
if n.is_lookup() {
let (mut min, mut max) = (0, 0);
for i in &inputs {
max = max.max(
i.iter()
.map(|x| felt_to_i128(*x))
.max()
.ok_or("missing max")?,
);
min = min.min(
i.iter()
.map(|x| felt_to_i128(*x))
.min()
.ok_or("missing min")?,
);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("max lookup inputs: {}", max);
debug!("min lookup inputs: {}", min);
}
match n {
NodeType::Node(n) => {
// execute the op
let start = instant::Instant::now();
let mut res = Op::<Fp>::f(&n.opkind, &inputs)?;
res.output.reshape(&n.out_dims)?;
let elapsed = start.elapsed();
trace!("op took: {:?}", elapsed);
// see if any of the intermediate lookup calcs are the max
if !res.intermediate_lookups.is_empty() {
let (mut min, mut max) = (0, 0);
for i in &res.intermediate_lookups {
max = max.max(i.clone().into_iter().max().ok_or("missing max")?);
min = min.min(i.clone().into_iter().min().ok_or("missing min")?);
}
max_lookup_inputs = max_lookup_inputs.max(max);
min_lookup_inputs = min_lookup_inputs.min(min);
debug!("intermediate max lookup inputs: {}", max);
debug!("intermediate min lookup inputs: {}", min);
}
debug!(
"------------ output node int {}: {} \n ------------ float: {} \n ------------ max: {} \n ------------ min: {} \n ------------ scale: {}",
idx,
res.output.map(crate::fieldutils::felt_to_i32).show(),
res.output
.map(|x| crate::fieldutils::felt_to_f64(x)
/ scale_to_multiplier(n.out_scale))
.show(),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).max().unwrap_or(0),
res.output.clone().into_iter().map(crate::fieldutils::felt_to_i128).min().unwrap_or(0),
n.out_scale
);
results.insert(idx, vec![res.output]);
}
NodeType::SubGraph {
model,
output_mappings,
input_mappings,
inputs: input_tuple,
..
} => {
let orig_inputs = inputs.clone();
let input_mappings = input_mappings.clone();
let input_dims = inputs.iter().map(|inp| inp.dims());
let num_iter = number_of_iterations(&input_mappings, input_dims.collect());
debug!(
"{} iteration(s) in a subgraph with inputs {:?} and sources {:?}",
num_iter, input_tuple, model.graph.inputs
);
debug!("input_mappings: {:?}", input_mappings);
let mut full_results: Vec<Tensor<Fp>> = vec![];
for i in 0..num_iter {
// replace the Stacked input with the current chunk iter
for ((mapping, inp), og_input) in
input_mappings.iter().zip(&mut inputs).zip(&orig_inputs)
{
if let InputMapping::Stacked { axis, chunk } = mapping {
let start = i * chunk;
let end = (i + 1) * chunk;
let t = crate::tensor::ops::slice(og_input, axis, &start, &end)?;
*inp = t;
}
}
let res = model.forward(&inputs)?;
// recursively get the max lookup inputs for subgraphs
max_lookup_inputs = max_lookup_inputs.max(res.max_lookup_inputs);
min_lookup_inputs = min_lookup_inputs.min(res.min_lookup_inputs);
let mut outlets = BTreeMap::new();
for (mappings, outlet_res) in output_mappings.iter().zip(res.outputs) {
for mapping in mappings {
match mapping {
OutputMapping::Single { outlet, .. } => {
outlets.insert(outlet, outlet_res.clone());
}
OutputMapping::Stacked { outlet, axis, .. } => {
if !full_results.is_empty() {
let stacked_res = crate::tensor::ops::concat(
&[&full_results[*outlet], &outlet_res],
*axis,
)?;
outlets.insert(outlet, stacked_res);
} else {
outlets.insert(outlet, outlet_res.clone());
}
}
}
}
}
full_results = outlets.into_values().collect_vec();
let output_states = output_state_idx(output_mappings);
let input_states = input_state_idx(&input_mappings);
assert_eq!(input_states.len(), output_states.len());
for (input_idx, output_idx) in input_states.iter().zip(output_states) {
inputs[*input_idx] = full_results[output_idx].clone();
}
}
trace!(
"------------ output subgraph node {}: {:?}",
idx,
full_results
.iter()
.map(|x|
// convert to tensor i32
x.map(crate::fieldutils::felt_to_i32).show())
.collect_vec()
);
results.insert(idx, full_results);
}
}
}
let output_nodes = self.graph.outputs.iter();
debug!(
"model outputs are nodes: {:?}",
output_nodes.clone().collect_vec()
);
let outputs = output_nodes
.map(|(idx, outlet)| {
Ok(results.get(&idx).ok_or(GraphError::MissingResults)?[*outlet].clone())
})
.collect::<Result<Vec<_>, GraphError>>()?;
let res = ForwardResult {
outputs,
max_lookup_inputs,
min_lookup_inputs,
};
Ok(res)
}
/// Loads an Onnx model from a specified path.
@@ -1012,8 +1173,8 @@ impl Model {
);
// set scale for HybridOp::RangeCheck and call self.conf_lookup on that op for percentage tolerance case
let input = &vars.advices[0];
let output = &vars.advices[2];
let index = &vars.advices[1];
let output = &vars.advices[1];
let index = &vars.advices[2];
for op in required_lookups {
base_gate.configure_lookup(meta, input, output, index, lookup_range, logrows, &op)?;
}
@@ -1174,29 +1335,18 @@ impl Model {
};
debug!(
"laying out {}: {}, row:{}, coord:{}, total_constants: {}, max_lookup_inputs: {}, min_lookup_inputs: {}",
"laying out {}: {}, row:{}, coord:{}, total_constants: {}",
idx,
node.as_str(),
region.row(),
region.linear_coord(),
region.total_constants(),
region.max_lookup_inputs(),
region.min_lookup_inputs()
region.total_constants()
);
debug!("dims: {:?}", node.out_dims());
debug!(
"input_dims {:?}",
values.iter().map(|v| v.dims()).collect_vec()
);
debug!("output scales: {:?}", node.out_scales());
debug!("input indices: {:?}", node.inputs());
debug!(
"input scales: {:?}",
node.inputs()
.iter()
.map(|(idx, outlet)| self.graph.nodes[idx].out_scales()[*outlet])
.collect_vec()
);
match &node {
NodeType::Node(n) => {
@@ -1339,13 +1489,28 @@ impl Model {
pub fn dummy_layout(
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
input_shapes: &[Vec<usize>],
) -> Result<DummyPassRes, Box<dyn Error>> {
info!("calculating num of constraints using dummy model layout...");
let start_time = instant::Instant::now();
let mut results = BTreeMap::<usize, Vec<ValTensor<Fp>>>::new();
let default_value = if !self.visibility.input.is_fixed() {
ValType::Value(Value::<Fp>::unknown())
} else {
ValType::Constant(Fp::ONE)
};
let inputs: Vec<ValTensor<Fp>> = input_shapes
.iter()
.map(|shape| {
let mut t: ValTensor<Fp> =
vec![default_value.clone(); shape.iter().product()].into();
t.reshape(shape)?;
Ok(t)
})
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
for (i, input_idx) in self.graph.inputs.iter().enumerate() {
results.insert(*input_idx, vec![inputs[i].clone()]);
@@ -1380,12 +1545,12 @@ impl Model {
.collect::<Result<Vec<_>, Box<dyn Error>>>()?;
let _ = outputs
.iter()
.into_iter()
.zip(comparator)
.map(|(o, c)| {
dummy_config.layout(
&mut region,
&[o.clone(), c],
&[o, c],
Box::new(HybridOp::RangeCheck(run_args.tolerance)),
)
})
@@ -1410,23 +1575,12 @@ impl Model {
region.total_constants().to_string().red()
);
let outputs = outputs
.iter()
.map(|x| {
x.get_felt_evals()
.unwrap_or(Tensor::new(Some(&[Fp::ZERO]), &[1]).unwrap())
})
.collect();
let res = DummyPassRes {
num_rows: region.row(),
linear_coord: region.linear_coord(),
total_const_size: region.total_constants(),
lookup_ops: region.used_lookups(),
range_checks: region.used_range_checks(),
max_lookup_inputs: region.max_lookup_inputs(),
min_lookup_inputs: region.min_lookup_inputs(),
outputs,
};
Ok(res)

View File

@@ -148,7 +148,7 @@ impl RebaseScale {
SupportedOp::RebaseScale(RebaseScale {
inner: op.inner.clone(),
target_scale: op.target_scale,
multiplier,
multiplier: multiplier,
rebase_op: HybridOp::Div {
denom: crate::circuit::utils::F32((multiplier) as f32),
use_range_check_for_int: !div_rebasing,
@@ -219,6 +219,8 @@ impl Op<Fp> for RebaseScale {
let mut res = Op::<Fp>::f(&*self.inner, x)?;
let rebase_res = Op::<Fp>::f(&self.rebase_op, &[res.output])?;
res.output = rebase_res.output;
res.intermediate_lookups
.extend(rebase_res.intermediate_lookups);
Ok(res)
}

View File

@@ -439,16 +439,17 @@ pub fn new_op_from_onnx(
let op = load_op::<ScatterElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
dim: axis,
constant_idx: None,
});
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
dim: axis,
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::ScatterElements {
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::ScatterElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})
@@ -477,16 +478,17 @@ pub fn new_op_from_onnx(
let op = load_op::<GatherElements>(node.op(), idx, node.op().name().to_string())?;
let axis = op.axis;
let mut op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
dim: axis,
constant_idx: None,
});
let mut op =
SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
dim: axis,
constant_idx: None,
});
// if param_visibility.is_public() {
if let Some(c) = inputs[1].opkind().get_mutable_constant() {
inputs[1].decrement_use();
deleted_indices.push(inputs.len() - 1);
op = SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::GatherElements {
op = SupportedOp::Hybrid(crate::circuit::ops::hybrid::HybridOp::GatherElements {
dim: axis,
constant_idx: Some(c.raw_values.map(|x| x as usize)),
})

View File

@@ -7,6 +7,7 @@
overflowing_literals,
path_statements,
patterns_in_fns_without_body,
private_in_public,
unconditional_recursion,
unused,
unused_allocation,

View File

@@ -580,16 +580,16 @@ impl<T: Clone + TensorType> Tensor<T> {
/// use ezkl::tensor::Tensor;
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(4).unwrap(), expected);
///
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// assert_eq!(a.pad_to_zero_rem(9).unwrap(), expected);
/// ```
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
pub fn pad_to_zero_rem(&self, n: usize) -> Result<Tensor<T>, TensorError> {
let mut inner = self.inner.clone();
let remainder = self.len() % n;
if remainder != 0 {
inner.resize(self.len() + n - remainder, pad);
inner.resize(self.len() + n - remainder, T::zero().unwrap());
}
Tensor::new(Some(&inner), &[inner.len()])
}

View File

@@ -243,7 +243,7 @@ pub fn and<
/// Some(&[1, 0, 1, 0, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = equals(&a, &b).unwrap();
/// let result = equals(&a, &b).unwrap().0;
/// let expected = Tensor::<i128>::new(Some(&[1, 0, 1, 0, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
@@ -260,7 +260,7 @@ pub fn equals<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
let a = a.clone();
let b = b.clone();
@@ -268,7 +268,7 @@ pub fn equals<
let result = nonlinearities::kronecker_delta(&diff);
Ok(result)
Ok((result, vec![diff]))
}
/// Greater than operation.
@@ -289,7 +289,7 @@ pub fn equals<
/// ).unwrap();
/// let result = greater(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 1, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
pub fn greater<
T: TensorType
@@ -302,7 +302,7 @@ pub fn greater<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x > T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -311,7 +311,7 @@ pub fn greater<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok(mask)
Ok((mask, vec![mask_inter]))
}
/// Greater equals than operation.
@@ -332,7 +332,7 @@ pub fn greater<
/// ).unwrap();
/// let result = greater_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 1, 1, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
pub fn greater_equal<
T: TensorType
@@ -345,7 +345,7 @@ pub fn greater_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
let mask_inter = (a.clone() - b.clone())?;
let mask = mask_inter.map(|x| {
if x >= T::zero().ok_or(TensorError::Unsupported).unwrap() {
@@ -354,7 +354,7 @@ pub fn greater_equal<
T::zero().ok_or(TensorError::Unsupported).unwrap()
}
});
Ok(mask)
Ok((mask, vec![mask_inter]))
}
/// Less than to operation.
@@ -375,7 +375,7 @@ pub fn greater_equal<
/// ).unwrap();
/// let result = less(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[0, 1, 0, 0, 0, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
///
pub fn less<
@@ -389,7 +389,7 @@ pub fn less<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
// a < b <=> b > a
greater(b, a)
}
@@ -412,7 +412,7 @@ pub fn less<
/// ).unwrap();
/// let result = less_equal(&a, &b).unwrap();
/// let expected = Tensor::<i128>::new(Some(&[1, 1, 0, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// assert_eq!(result.0, expected);
/// ```
///
pub fn less_equal<
@@ -426,7 +426,7 @@ pub fn less_equal<
>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> Result<Tensor<T>, TensorError> {
) -> Result<(Tensor<T>, Vec<Tensor<T>>), TensorError> {
// a < b <=> b > a
greater_equal(b, a)
}
@@ -2300,12 +2300,12 @@ pub fn deconv<
/// Some(&[5, 2, 3, 0, 4, -1, 3, 1, 6]),
/// &[1, 1, 3, 3],
/// ).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), false).unwrap().0;
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[11, 8, 8, 10]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
///
/// // This time with normalization
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap();
/// let pooled = sumpool(&x, [(0, 0); 2], (1, 1), (2, 2), true).unwrap().0;
/// let expected: Tensor<i128> = Tensor::<i128>::new(Some(&[3, 2, 2, 3]), &[1, 1, 2, 2]).unwrap();
/// assert_eq!(pooled, expected);
/// ```
@@ -2315,7 +2315,7 @@ pub fn sumpool(
stride: (usize, usize),
kernel_shape: (usize, usize),
normalize: bool,
) -> Result<Tensor<i128>, TensorError> {
) -> Result<(Tensor<i128>, Vec<Tensor<i128>>), TensorError> {
let image_dims = image.dims();
let batch_size = image_dims[0];
let image_channels = image_dims[1];
@@ -2345,12 +2345,15 @@ pub fn sumpool(
let mut combined = res.combine()?;
combined.reshape(&[&[batch_size, image_channels], shape].concat())?;
let mut inter = vec![];
if normalize {
inter.push(combined.clone());
let norm = kernel.len();
combined = nonlinearities::const_div(&combined, norm as f64);
}
Ok(combined)
Ok((combined, inter))
}
/// Applies 2D max pooling over a 4D tensor of shape B x C x H x W.
@@ -3045,7 +3048,11 @@ pub mod nonlinearities {
}
/// softmax layout
pub fn softmax_axes(a: &Tensor<i128>, scale: f64, axes: &[usize]) -> Tensor<i128> {
pub fn softmax_axes(
a: &Tensor<i128>,
scale: f64,
axes: &[usize],
) -> (Tensor<i128>, Vec<Tensor<i128>>) {
// we want this to be as small as possible so we set the output scale to 1
let dims = a.dims();
@@ -3053,6 +3060,8 @@ pub mod nonlinearities {
return softmax(a, scale);
}
let mut intermediate_values = vec![];
let cartesian_coord = dims[..dims.len() - 1]
.iter()
.map(|x| 0..*x)
@@ -3075,7 +3084,8 @@ pub mod nonlinearities {
let res = softmax(&softmax_input, scale);
outputs.push(res);
outputs.push(res.0);
intermediate_values.extend(res.1);
}
let mut res = Tensor::new(Some(&outputs), &[outputs.len()])
@@ -3083,7 +3093,7 @@ pub mod nonlinearities {
.combine()
.unwrap();
res.reshape(dims).unwrap();
res
(res, intermediate_values)
}
/// Applies softmax
@@ -3100,20 +3110,24 @@ pub mod nonlinearities {
/// Some(&[2, 2, 3, 2, 2, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = softmax(&x, 128.0);
/// let result = softmax(&x, 128.0).0;
/// // doubles the scale of the input
/// let expected = Tensor::<i128>::new(Some(&[2730, 2730, 2751, 2730, 2730, 2688]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn softmax(a: &Tensor<i128>, scale: f64) -> Tensor<i128> {
pub fn softmax(a: &Tensor<i128>, scale: f64) -> (Tensor<i128>, Vec<Tensor<i128>>) {
// the more accurate calculation is commented out and we implement as below so it matches the steps in layout
let mut intermediate_values = vec![];
intermediate_values.push(a.clone());
let exp = exp(a, scale);
let sum = sum(&exp).unwrap();
intermediate_values.push(sum.clone());
let inv_denom = recip(&sum, scale, scale);
(exp * inv_denom).unwrap()
((exp * inv_denom).unwrap(), intermediate_values)
}
/// Applies range_check_percent

View File

@@ -454,12 +454,12 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
/// Calls `pad_to_zero_rem` on the inner tensor.
pub fn pad_to_zero_rem(&mut self, n: usize, pad: ValType<F>) -> Result<(), Box<dyn Error>> {
pub fn pad_to_zero_rem(&mut self, n: usize) -> Result<(), Box<dyn Error>> {
match self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.pad_to_zero_rem(n, pad)?;
*v = v.pad_to_zero_rem(n)?;
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
@@ -871,30 +871,3 @@ impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> ValTensor<F> {
/// inverts the inner values
pub fn inverse(&self) -> Result<ValTensor<F>, Box<dyn Error>> {
let mut cloned_self = self.clone();
match &mut cloned_self {
ValTensor::Value {
inner: v, dims: d, ..
} => {
*v = v.map(|x| match x {
ValType::AssignedValue(v) => ValType::AssignedValue(v.invert()),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
ValType::AssignedValue(v.value_field().invert())
}
ValType::Value(v) => ValType::Value(v.map(|x| x.invert().unwrap_or(F::ZERO))),
ValType::Constant(v) => ValType::Constant(v.invert().unwrap_or(F::ZERO)),
});
*d = v.dims().to_vec();
}
ValTensor::Instance { .. } => {
return Err(Box::new(TensorError::WrongMethod));
}
};
Ok(cloned_self)
}
}

View File

@@ -192,86 +192,86 @@ mod native_tests {
];
const TESTS: [&str; 77] = [
"1l_mlp", //0
"1l_mlp",
"1l_slice",
"1l_concat",
"1l_flatten",
// "1l_average",
"1l_div",
"1l_pad", // 5
"1l_pad",
"1l_reshape",
"1l_eltwise_div",
"1l_sigmoid",
"1l_sqrt",
"1l_softmax", //10
"1l_softmax",
// "1l_instance_norm",
"1l_batch_norm",
"1l_prelu",
"1l_leakyrelu",
"1l_gelu_noappx",
// "1l_gelu_tanh_appx",
"1l_relu", //15
"1l_relu",
"1l_downsample",
"1l_tanh",
"2l_relu_sigmoid_small",
"2l_relu_fc",
"2l_relu_small", //20
"2l_relu_small",
"2l_relu_sigmoid",
"1l_conv",
"2l_sigmoid_small",
"2l_relu_sigmoid_conv",
"3l_relu_conv_fc", //25
"3l_relu_conv_fc",
"4l_relu_conv_fc",
"1l_erf",
"1l_var",
"1l_elu",
"min", //30
"1l_elu", //30
"min",
"max",
"1l_max_pool",
"1l_conv_transpose",
"1l_upsample",
"1l_identity", //35
"1l_upsample", //35
"1l_identity",
"idolmodel",
"trig",
"prelu_gmm",
"lstm",
"rnn", //40
"lstm", //40
"rnn",
"quantize_dequantize",
"1l_where",
"boolean",
"boolean_identity",
"decision_tree", // 45
"decision_tree", // "variable_cnn",
"random_forest",
"gradient_boosted_trees",
"1l_topk",
"xgboost",
"lightgbm", //50
"xgboost", //50
"lightgbm",
"hummingbird_decision_tree",
"oh_decision_tree",
"linear_svc",
"gather_elements",
"less", //55
"less",
"xgboost_reg",
"1l_powf",
"scatter_elements",
"1l_linear",
"linear_regression", //60
"1l_linear", //60
"linear_regression",
"sklearn_mlp",
"1l_mean",
"rounding_ops",
// "mean_as_constrain",
"arange",
"layernorm", //65
"layernorm",
"bitwise_ops",
"blackman_window",
"softsign", //68
"softsign", //70
"softplus",
"selu", //70
"selu",
"hard_sigmoid",
"log_softmax",
"eye",
"ltsf",
"remainder", //75
"remainder",
"bitshift",
];
@@ -560,7 +560,7 @@ mod native_tests {
crate::native_tests::setup_py_env();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 3.1, false);
accuracy_measurement(path, test.to_string(), "private", "private", "public", 1, "resources", 18.0, false);
test_dir.close().unwrap();
}

View File

@@ -131,7 +131,7 @@ mod py_tests {
"simple_demo_aggregated_proofs.ipynb",
"ezkl_demo.ipynb", // 10
"lstm.ipynb",
"set_membership.ipynb", // 12
"set_membership.ipynb",
"decision_tree.ipynb",
"random_forest.ipynb",
"gradient_boosted_trees.ipynb", // 15

View File

@@ -113,13 +113,13 @@ def test_calibrate_over_user_range():
data_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'network.onnx'
)
output_path = os.path.join(
@@ -147,13 +147,13 @@ def test_calibrate():
data_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'input.json'
)
model_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'network.onnx'
)
output_path = os.path.join(
@@ -183,7 +183,7 @@ def test_model_compile():
model_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'network.onnx'
)
compiled_model_path = os.path.join(
@@ -205,7 +205,7 @@ def test_forward():
data_path = os.path.join(
examples_path,
'onnx',
'1l_relu',
'1l_average',
'input.json'
)
model_path = os.path.join(

Binary file not shown.

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.