Compare commits

..

26 Commits

Author SHA1 Message Date
dante
d21b43502d Merge branch 'main' into ac/small-worm 2024-07-14 22:41:16 -04:00
dante
9f1495bf7f Update rust.yml 2024-07-14 17:25:28 -04:00
dante
355aed3b1f Update binding_tests.py 2024-07-12 23:50:21 +01:00
dante
ca7fe53687 Update Readme.md 2024-07-12 13:04:29 +01:00
dante
1b49f8774f Create Readme.md 2024-07-12 13:02:32 +01:00
dante
ec6be275cf lock 2024-07-12 12:48:34 +01:00
dante
51ace92b73 Update Cargo.toml 2024-07-12 12:47:51 +01:00
dante
e4332ebb14 Update model.rs 2024-07-12 12:45:16 +01:00
dante
186515448c Merge branch 'ac/small-worm' of https://github.com/zkonduit/ezkl into ac/small-worm 2024-07-12 12:44:45 +01:00
dante
c70fd153e2 revert 2024-07-12 12:43:57 +01:00
dante
66d733d0ce Delete examples/onnx/1l_conv_transpose/witness.json 2024-07-12 12:41:46 +01:00
dante
89c5238130 Revert "rm for now"
This reverts commit c6cba69bc3.
2024-07-12 12:40:10 +01:00
dante
e88f362596 Merge branch 'main' into ac/small-worm 2024-07-12 12:39:26 +01:00
Alexander Camuto
069ac6ee6e Update ops.rs 2024-05-02 17:09:20 +01:00
Alexander Camuto
c6cba69bc3 rm for now 2024-05-02 17:02:01 +01:00
Alexander Camuto
eec19d6058 accelerate dot calc 2024-05-02 16:57:09 +01:00
Alexander Camuto
685f462766 cleanup 2024-05-02 15:12:12 +01:00
Alexander Camuto
9b027fc908 smoller 2024-05-02 15:07:01 +01:00
Alexander Camuto
2e211d1314 smol 2024-05-02 15:07:01 +01:00
Alexander Camuto
feae28042e Update rust.yml 2024-05-02 15:07:01 +01:00
Alexander Camuto
885cd880d2 Update rust.yml 2024-05-02 15:07:01 +01:00
Alexander Camuto
b88bb6ccda Update rust.yml 2024-05-02 15:07:01 +01:00
Alexander Camuto
7b32a99856 Update rust.yml 2024-05-02 15:07:00 +01:00
Alexander Camuto
e80eae41f0 worm town 2024-05-02 15:07:00 +01:00
Alexander Camuto
f4d2fc0ccb chore: smallworm example 2024-05-02 15:06:41 +01:00
Alexander Camuto
e7f76c5ae6 Create kmeans.ipynb 2024-05-02 15:04:43 +01:00
41 changed files with 1502 additions and 1482 deletions

View File

@@ -22,7 +22,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -30,7 +30,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install binaryen
run: |
set -e

View File

@@ -11,7 +11,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: nanoGPT Mock

View File

@@ -40,7 +40,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
@@ -86,7 +86,7 @@ jobs:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy

View File

@@ -45,7 +45,7 @@ jobs:
steps:
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Checkout repo
@@ -106,27 +106,27 @@ jobs:
include:
- build: windows-msvc
os: windows-latest
rust: nightly-2024-07-18
rust: nightly-2024-02-06
target: x86_64-pc-windows-msvc
- build: macos
os: macos-13
rust: nightly-2024-07-18
rust: nightly-2024-02-06
target: x86_64-apple-darwin
- build: macos-aarch64
os: macos-13
rust: nightly-2024-07-18
rust: nightly-2024-02-06
target: aarch64-apple-darwin
- build: linux-musl
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2024-02-06
target: x86_64-unknown-linux-musl
- build: linux-gnu
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2024-02-06
target: x86_64-unknown-linux-gnu
- build: linux-aarch64
os: ubuntu-22.04
rust: nightly-2024-07-18
rust: nightly-2024-02-06
target: aarch64-unknown-linux-gnu
steps:

View File

@@ -26,7 +26,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Build
@@ -38,7 +38,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Docs
@@ -50,7 +50,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -73,7 +73,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -106,7 +106,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -139,7 +139,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -172,7 +172,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -189,7 +189,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -199,7 +199,7 @@ jobs:
- name: Install wasm32-unknown-unknown
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
# on mac:
# AR=/opt/homebrew/opt/llvm/bin/llvm-ar CC=/opt/homebrew/opt/llvm/bin/clang wasm-pack test --firefox --headless -- -Z build-std="panic_abort,std" --features web
@@ -212,7 +212,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -229,7 +229,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -260,8 +260,6 @@ jobs:
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
- name: hashed params
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
- name: hashed params public inputs
run: cargo nextest run --release --verbose tests::mock_hashed_params_public_inputs_::t --test-threads 32
- name: hashed outputs
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
- name: hashed inputs + params + outputs
@@ -290,7 +288,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -308,7 +306,7 @@ jobs:
node-version: "18.12.1"
cache: "pnpm"
- name: "Add rust-src"
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- name: Install dependencies for js tests and in-browser-evm-verifier package
run: |
pnpm install --frozen-lockfile
@@ -330,7 +328,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify tests (EVM + VK rendered seperately)
run: cargo nextest run --release --verbose tests_evm::kzg_evm_prove_and_verify_render_seperately_ --test-threads 1
- name: KZG prove and verify tests (EVM + kzg all)
@@ -369,7 +367,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -377,7 +375,7 @@ jobs:
run: rustup target add wasm32-unknown-unknown
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- name: Use pnpm 8
uses: pnpm/action-setup@v2
@@ -443,11 +441,11 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2024-07-18-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2024-02-06-x86_64-unknown-linux-gnu
- uses: actions/checkout@v3
- uses: baptiste0928/cargo-install@v1
with:
@@ -477,7 +475,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -495,7 +493,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -512,7 +510,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -529,7 +527,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -539,7 +537,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: KZG prove and verify aggr tests
run: cargo nextest run --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_::t --test-threads 4 -- --include-ignored
@@ -550,7 +548,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -570,7 +568,7 @@ jobs:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- name: Install cmake
@@ -580,7 +578,7 @@ jobs:
- name: Setup Virtual Env and Install python dependencies
run: python -m venv .env --clear; source .env/bin/activate; pip install -r requirements.txt;
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Build python ezkl
run: source .env/bin/activate; unset CONDA_PREFIX; maturin develop --features python-bindings --release
- name: Run pytest
@@ -596,7 +594,7 @@ jobs:
python-version: "3.12"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -645,7 +643,7 @@ jobs:
python-version: "3.11"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2024-07-18
toolchain: nightly-2024-02-06
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -655,7 +653,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.20 && solc --version
- name: Install Anvil
run: cargo install --git https://github.com/foundry-rs/foundry --rev 62cdea8ff9e6efef011f77e295823b5f2dbeb3a1 --locked anvil --force
run: cargo install --git https://github.com/foundry-rs/foundry --rev c2233ec9fe61e0920c61c6d779bc707252852037 --profile local --locked anvil --force
- name: Install pip
run: python -m ensurepip --upgrade
- name: Setup Virtual Env and Install python dependencies

50
Cargo.lock generated
View File

@@ -1536,15 +1536,6 @@ dependencies = [
"zeroize",
]
[[package]]
name = "deranged"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
dependencies = [
"powerfmt",
]
[[package]]
name = "derivative"
version = "2.2.0"
@@ -2306,7 +2297,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd"
source = "git+https://github.com/zkonduit/halo2#8cfca221f53069a0374687654882b99e729041d7#8cfca221f53069a0374687654882b99e729041d7"
dependencies = [
"blake2b_simd",
"env_logger",
@@ -2319,7 +2310,6 @@ dependencies = [
"rand_chacha",
"rand_core 0.6.4",
"rustacuda",
"rustc-hash 2.0.0",
"sha3 0.9.1",
"tracing",
]
@@ -3272,12 +3262,6 @@ dependencies = [
"num-traits",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
@@ -3791,12 +3775,6 @@ dependencies = [
"postgres-protocol",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
@@ -4029,7 +4007,7 @@ dependencies = [
"pin-project-lite",
"quinn-proto",
"quinn-udp",
"rustc-hash 1.1.0",
"rustc-hash",
"rustls",
"thiserror",
"tokio",
@@ -4045,7 +4023,7 @@ dependencies = [
"bytes",
"rand 0.8.5",
"ring",
"rustc-hash 1.1.0",
"rustc-hash",
"rustls",
"slab",
"thiserror",
@@ -4469,12 +4447,6 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustc-hash"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "583034fd73374156e66797ed8e5b0d5690409c9226b22d87cb7f19821c05d152"
[[package]]
name = "rustc-hex"
version = "2.1.0"
@@ -5245,14 +5217,11 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.36"
version = "0.3.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
checksum = "59e399c068f43a5d116fedaf73b203fa4f9c519f17e2b34f63221d3792f81446"
dependencies = [
"deranged",
"itoa",
"num-conv",
"powerfmt",
"serde",
"time-core",
"time-macros",
@@ -5260,17 +5229,16 @@ dependencies = [
[[package]]
name = "time-core"
version = "0.1.2"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
checksum = "7300fbefb4dadc1af235a9cef3737cea692a9d97e1b9cbcd4ebdae6f8868e6fb"
[[package]]
name = "time-macros"
version = "0.2.18"
version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
checksum = "96ba15a897f3c86766b757e5ac7221554c6750054d74d5b28844fce5fb36a6c4"
dependencies = [
"num-conv",
"time-core",
]

View File

@@ -209,7 +209,7 @@ metal = ["dep:metal", "dep:objc"]
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }
[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2#098ac0ef3b29255e0e2524ecbb4e7e325ae6e7fd", package = "halo2_proofs" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2#8cfca221f53069a0374687654882b99e729041d7", package = "halo2_proofs" }
[profile.release]

View File

@@ -72,7 +72,6 @@ impl Circuit<Fr> for MyCircuit {
Box::new(PolyOp::Conv {
padding: vec![(0, 0)],
stride: vec![1; 2],
group: 1,
}),
)
.unwrap();

View File

@@ -2,7 +2,6 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion, Through
use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::table::Range;
use ezkl::circuit::{ops::lookup::LookupOp, BaseConfig as Config, CheckMode};
use ezkl::fieldutils::IntegerRep;
use ezkl::pfsys::create_proof_circuit;
use ezkl::pfsys::TranscriptType;
use ezkl::pfsys::{create_keys, srs::gen_srs};
@@ -85,7 +84,7 @@ fn runrelu(c: &mut Criterion) {
};
let input: Tensor<Value<Fr>> =
Tensor::<IntegerRep>::from((0..len).map(|_| rng.gen_range(0..10))).into();
Tensor::<i32>::from((0..len).map(|_| rng.gen_range(0..10))).into();
let circuit = NLCircuit {
input: ValTensor::from(input.clone()),

View File

@@ -1,4 +1,4 @@
ezkl==12.0.2
ezkl==0.0.0
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon

View File

@@ -1,7 +1,7 @@
import ezkl
project = 'ezkl'
release = '12.0.2'
release = '0.0.0'
version = release

View File

@@ -2,7 +2,8 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils::{self, integer_rep_to_felt, IntegerRep};
use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::commitment::Params;
@@ -41,8 +42,8 @@ const NUM_INNER_COLS: usize = 1;
struct Config<
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
const CLASSES: usize,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -65,8 +66,8 @@ struct Config<
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const CLASSES: usize,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -89,8 +90,8 @@ struct MyCircuit<
impl<
const LEN: usize,
const CLASSES: usize,
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
// Convolution
const KERNEL_HEIGHT: usize,
const KERNEL_WIDTH: usize,
@@ -204,7 +205,6 @@ where
let op = PolyOp::Conv {
padding: vec![(PADDING, PADDING); 2],
stride: vec![STRIDE; 2],
group: 1,
};
let x = config
.layer_config
@@ -315,11 +315,7 @@ pub fn runconv() {
.test_set_length(10_000)
.finalize();
let mut train_data = Tensor::from(
trn_img
.iter()
.map(|x| integer_rep_to_felt::<F>(*x as IntegerRep / 16)),
);
let mut train_data = Tensor::from(trn_img.iter().map(|x| i32_to_felt::<F>(*x as i32 / 16)));
train_data.reshape(&[50_000, 28, 28]).unwrap();
let mut train_labels = Tensor::from(trn_lbl.iter().map(|x| *x as f32));
@@ -347,8 +343,8 @@ pub fn runconv() {
.map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
}),
);
@@ -359,8 +355,7 @@ pub fn runconv() {
let l0_kernels = l0_kernels.try_into().unwrap();
let mut l0_bias =
Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::integer_rep_to_felt(0)));
let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let l0_bias = l0_bias.try_into().unwrap();
@@ -368,8 +363,8 @@ pub fn runconv() {
let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
}));
l2_biases.set_visibility(&ezkl::graph::Visibility::Private);
l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
@@ -379,8 +374,8 @@ pub fn runconv() {
let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: IntegerRep = unsafe { rounded.to_int_unchecked() };
fieldutils::integer_rep_to_felt(integral)
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
}));
l2_weights.set_visibility(&ezkl::graph::Visibility::Private);
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
@@ -406,13 +401,13 @@ pub fn runconv() {
l2_params: [l2_weights, l2_biases],
};
let public_input: Tensor<IntegerRep> = vec![
-25124, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
let public_input: Tensor<i32> = vec![
-25124i32, -19304, -16668, -4399, -6209, -4548, -2317, -8349, -6117, -23461,
]
.into_iter()
.into();
let pi_inner: Tensor<F> = public_input.map(integer_rep_to_felt::<F>);
let pi_inner: Tensor<F> = public_input.map(i32_to_felt::<F>);
println!("MOCK PROVING");
let now = Instant::now();

View File

@@ -2,7 +2,7 @@ use ezkl::circuit::region::RegionCtx;
use ezkl::circuit::{
ops::lookup::LookupOp, ops::poly::PolyOp, BaseConfig as PolyConfig, CheckMode,
};
use ezkl::fieldutils::{integer_rep_to_felt, IntegerRep};
use ezkl::fieldutils::i32_to_felt;
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::{
@@ -23,8 +23,8 @@ struct MyConfig {
#[derive(Clone)]
struct MyCircuit<
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const LOOKUP_MIN: IntegerRep,
const LOOKUP_MAX: IntegerRep,
const LOOKUP_MIN: i64,
const LOOKUP_MAX: i64,
> {
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
@@ -34,7 +34,7 @@ struct MyCircuit<
_marker: PhantomData<F>,
}
impl<const LEN: usize, const LOOKUP_MIN: IntegerRep, const LOOKUP_MAX: IntegerRep> Circuit<F>
impl<const LEN: usize, const LOOKUP_MIN: i64, const LOOKUP_MAX: i64> Circuit<F>
for MyCircuit<LEN, LOOKUP_MIN, LOOKUP_MAX>
{
type Config = MyConfig;
@@ -215,33 +215,33 @@ pub fn runmlp() {
#[cfg(not(target_arch = "wasm32"))]
env_logger::init();
// parameters
let mut l0_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
&[4, 4],
)
.unwrap()
.map(integer_rep_to_felt);
.map(i32_to_felt);
l0_kernel.set_visibility(&ezkl::graph::Visibility::Private);
let mut l0_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(integer_rep_to_felt);
.map(i32_to_felt);
l0_bias.set_visibility(&ezkl::graph::Visibility::Private);
let mut l2_kernel: Tensor<F> = Tensor::<IntegerRep>::new(
let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
&[4, 4],
)
.unwrap()
.map(integer_rep_to_felt);
.map(i32_to_felt);
l2_kernel.set_visibility(&ezkl::graph::Visibility::Private);
// input data, with 1 padding to allow for bias
let input: Tensor<Value<F>> = Tensor::<IntegerRep>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
.unwrap()
.into();
let mut l2_bias: Tensor<F> = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(integer_rep_to_felt);
.map(i32_to_felt);
l2_bias.set_visibility(&ezkl::graph::Visibility::Private);
let circuit = MyCircuit::<4, -8192, 8192> {
@@ -251,12 +251,12 @@ pub fn runmlp() {
_marker: PhantomData,
};
let public_input: Vec<IntegerRep> = unsafe {
let public_input: Vec<i32> = unsafe {
vec![
(531f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(103f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(4469f32 / 128f32).round().to_int_unchecked::<IntegerRep>(),
(2849f32 / 128f32).to_int_unchecked::<IntegerRep>(),
(531f32 / 128f32).round().to_int_unchecked::<i32>(),
(103f32 / 128f32).round().to_int_unchecked::<i32>(),
(4469f32 / 128f32).round().to_int_unchecked::<i32>(),
(2849f32 / 128f32).to_int_unchecked::<i32>(),
]
};
@@ -265,10 +265,7 @@ pub fn runmlp() {
let prover = MockProver::run(
K as u32,
&circuit,
vec![public_input
.iter()
.map(|x| integer_rep_to_felt::<F>(*x))
.collect()],
vec![public_input.iter().map(|x| i32_to_felt::<F>(*x)).collect()],
)
.unwrap();
prover.assert_satisfied();

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2024-07-18"
channel = "nightly-2024-02-06"
components = ["rustfmt", "clippy"]

View File

@@ -22,7 +22,7 @@ use crate::{
table::{Range, RangeCheck, Table},
utils,
},
tensor::{Tensor, TensorType, ValTensor, VarTensor},
tensor::{IntoI64, Tensor, TensorType, ValTensor, VarTensor},
};
use std::{collections::BTreeMap, marker::PhantomData};
@@ -327,7 +327,7 @@ pub struct BaseConfig<F: PrimeField + TensorType + PartialOrd> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> BaseConfig<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> BaseConfig<F> {
/// Returns a new [BaseConfig] with no inputs, no selectors, and no tables.
pub fn dummy(col_size: usize, num_inner_cols: usize) -> Self {
Self {

View File

@@ -1,6 +1,6 @@
use std::convert::Infallible;
use crate::{fieldutils::IntegerRep, tensor::TensorError};
use crate::tensor::TensorError;
use halo2_proofs::plonk::Error as PlonkError;
use thiserror::Error;
@@ -57,7 +57,7 @@ pub enum CircuitError {
InvalidConversion(#[from] Infallible),
/// Invalid min/max lookup range
#[error("invalid min/max lookup range: min: {0}, max: {1}")]
InvalidMinMaxRange(IntegerRep, IntegerRep),
InvalidMinMaxRange(i64, i64),
/// Missing product in einsum
#[error("missing product in einsum")]
MissingEinsumProduct,
@@ -81,7 +81,7 @@ pub enum CircuitError {
MissingSelectors(String),
/// Table lookup error
#[error("value ({0}) out of range: ({1}, {2})")]
TableOOR(IntegerRep, IntegerRep, IntegerRep),
TableOOR(i64, i64, i64),
/// Loookup not configured
#[error("lookup not configured: {0}")]
LookupNotConfigured(String),

View File

@@ -1,7 +1,7 @@
use super::*;
use crate::{
circuit::{layouts, utils, Tolerance},
fieldutils::integer_rep_to_felt,
fieldutils::i64_to_felt,
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorType, ValTensor},
};
@@ -71,7 +71,7 @@ pub enum HybridOp {
},
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for HybridOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for HybridOp {
///
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
match self {
@@ -184,8 +184,8 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
config,
region,
values[..].try_into()?,
integer_rep_to_felt(input_scale.0 as i128),
integer_rep_to_felt(output_scale.0 as i128),
i64_to_felt(input_scale.0 as i64),
i64_to_felt(output_scale.0 as i64),
)?
} else {
layouts::nonlinearity(
@@ -209,7 +209,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash > Op<F> for Hybri
config,
region,
values[..].try_into()?,
integer_rep_to_felt(denom.0 as i128),
i64_to_felt(denom.0 as i64),
)?
} else {
layouts::nonlinearity(

File diff suppressed because it is too large Load Diff

View File

@@ -3,9 +3,9 @@ use serde::{Deserialize, Serialize};
use crate::{
circuit::{layouts, table::Range, utils},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep},
fieldutils::{felt_to_i64, i64_to_felt},
graph::multiplier_to_scale,
tensor::{self, Tensor, TensorError, TensorType},
tensor::{self, IntoI64, Tensor, TensorError, TensorType},
};
use super::Op;
@@ -131,16 +131,16 @@ impl LookupOp {
/// Returns the range of values that can be represented by the table
pub fn bit_range(max_len: usize) -> Range {
let range = (max_len - 1) as f64 / 2_f64;
let range = range as IntegerRep;
let range = range as i64;
(-range, range)
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>(
pub(crate) fn f<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
&self,
x: &[Tensor<F>],
) -> Result<ForwardResult<F>, TensorError> {
let x = x[0].clone().map(|x| felt_to_integer_rep(x));
let x = x[0].clone().map(|x| felt_to_i64(x));
let res = match &self {
LookupOp::Abs => Ok(tensor::ops::abs(&x)?),
LookupOp::Ceil { scale } => Ok(tensor::ops::nonlinearities::ceil(&x, scale.into())),
@@ -227,13 +227,13 @@ impl LookupOp {
}
}?;
let output = res.map(|x| integer_rep_to_felt(x));
let output = res.map(|x| i64_to_felt(x));
Ok(ForwardResult { output })
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for LookupOp {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self

View File

@@ -4,7 +4,7 @@ use serde::{Deserialize, Serialize};
use crate::{
graph::quantize_tensor,
tensor::{self, Tensor, TensorType, ValTensor},
tensor::{self, IntoI64, Tensor, TensorType, ValTensor},
};
use halo2curves::ff::PrimeField;
@@ -31,12 +31,12 @@ pub use errors::CircuitError;
/// A struct representing the result of a forward pass.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
pub(crate) output: Tensor<F>,
}
/// A trait representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>:
std::fmt::Debug + Send + Sync + Any
{
/// Returns a string representation of the operation.
@@ -75,7 +75,7 @@ pub trait Op<F: PrimeField + TensorType + PartialOrd + std::hash::Hash>:
fn as_any(&self) -> &dyn Any;
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Clone for Box<dyn Op<F>> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Clone for Box<dyn Op<F>> {
fn clone(&self) -> Self {
self.clone_dyn()
}
@@ -142,7 +142,7 @@ pub struct Input {
pub datum_type: InputType,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Input {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(self.scale)
}
@@ -197,7 +197,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Input
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknown {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Op<F> for Unknown {
fn out_scale(&self, _: Vec<crate::Scale>) -> Result<crate::Scale, CircuitError> {
Ok(0)
}
@@ -224,7 +224,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Op<F> for Unknow
///
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> {
///
pub quantized_values: Tensor<F>,
///
@@ -234,7 +234,7 @@ pub struct Constant<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Constant<F> {
///
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
@@ -267,7 +267,8 @@ impl<
+ PartialOrd
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>,
+ for<'de> Deserialize<'de>
+ IntoI64,
> Op<F> for Constant<F>
{
fn as_any(&self) -> &dyn Any {

View File

@@ -33,7 +33,6 @@ pub enum PolyOp {
Conv {
padding: Vec<(usize, usize)>,
stride: Vec<usize>,
group: usize,
},
Downsample {
axis: usize,
@@ -44,7 +43,6 @@ pub enum PolyOp {
padding: Vec<(usize, usize)>,
output_padding: Vec<usize>,
stride: Vec<usize>,
group: usize,
},
Add,
Sub,
@@ -100,7 +98,7 @@ impl<
+ std::hash::Hash
+ Serialize
+ for<'de> Deserialize<'de>
,
+ IntoI64,
> Op<F> for PolyOp
{
/// Returns a reference to the Any trait.
@@ -150,25 +148,17 @@ impl<
PolyOp::Sum { axes } => format!("SUM (axes={:?})", axes),
PolyOp::Prod { .. } => "PROD".into(),
PolyOp::Pow(_) => "POW".into(),
PolyOp::Conv {
stride,
padding,
group,
} => {
format!(
"CONV (stride={:?}, padding={:?}, group={})",
stride, padding, group
)
PolyOp::Conv { stride, padding } => {
format!("CONV (stride={:?}, padding={:?})", stride, padding)
}
PolyOp::DeConv {
stride,
padding,
output_padding,
group,
} => {
format!(
"DECONV (stride={:?}, padding={:?}, output_padding={:?}, group={})",
stride, padding, output_padding, group
"DECONV (stride={:?}, padding={:?}, output_padding={:?})",
stride, padding, output_padding
)
}
PolyOp::Concat { axis } => format!("CONCAT (axis={})", axis),
@@ -222,18 +212,9 @@ impl<
PolyOp::Prod { axes, .. } => {
layouts::prod_axes(config, region, values[..].try_into()?, axes)?
}
PolyOp::Conv {
padding,
stride,
group,
} => layouts::conv(
config,
region,
values[..].try_into()?,
padding,
stride,
*group,
)?,
PolyOp::Conv { padding, stride } => {
layouts::conv(config, region, values[..].try_into()?, padding, stride)?
}
PolyOp::GatherElements { dim, constant_idx } => {
if let Some(idx) = constant_idx {
tensor::ops::gather_elements(values[0].get_inner_tensor()?, idx, *dim)?.into()
@@ -280,7 +261,6 @@ impl<
padding,
output_padding,
stride,
group,
} => layouts::deconv(
config,
region,
@@ -288,7 +268,6 @@ impl<
padding,
output_padding,
stride,
*group,
)?,
PolyOp::Add => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?,
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,

View File

@@ -1,6 +1,5 @@
use crate::{
circuit::table::Range,
fieldutils::IntegerRep,
tensor::{Tensor, TensorType, ValTensor, ValType, VarTensor},
};
#[cfg(not(target_arch = "wasm32"))]
@@ -12,6 +11,7 @@ use halo2_proofs::{
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use portable_atomic::AtomicI64 as AtomicInt;
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
@@ -86,78 +86,6 @@ impl ShuffleIndex {
}
}
#[derive(Debug, Clone)]
/// Some settings for a region to differentiate it across the different phases of proof generation
pub struct RegionSettings {
/// whether we are in witness generation mode
pub witness_gen: bool,
/// whether we should check range checks for validity
pub check_range: bool,
}
#[allow(unsafe_code)]
unsafe impl Sync for RegionSettings {}
#[allow(unsafe_code)]
unsafe impl Send for RegionSettings {}
impl RegionSettings {
/// Create a new region settings
pub fn new(witness_gen: bool, check_range: bool) -> RegionSettings {
RegionSettings {
witness_gen,
check_range,
}
}
/// Create a new region settings with all true
pub fn all_true() -> RegionSettings {
RegionSettings {
witness_gen: true,
check_range: true,
}
}
/// Create a new region settings with all false
pub fn all_false() -> RegionSettings {
RegionSettings {
witness_gen: false,
check_range: false,
}
}
}
#[derive(Debug, Default, Clone)]
/// Region statistics
pub struct RegionStatistics {
/// the current maximum value of the lookup inputs
pub max_lookup_inputs: IntegerRep,
/// the current minimum value of the lookup inputs
pub min_lookup_inputs: IntegerRep,
/// the current maximum value of the range size
pub max_range_size: IntegerRep,
/// the current set of used lookups
pub used_lookups: HashSet<LookupOp>,
/// the current set of used range checks
pub used_range_checks: HashSet<Range>,
}
impl RegionStatistics {
/// update the statistics with another set of statistics
pub fn update(&mut self, other: &RegionStatistics) {
self.max_lookup_inputs = self.max_lookup_inputs.max(other.max_lookup_inputs);
self.min_lookup_inputs = self.min_lookup_inputs.min(other.min_lookup_inputs);
self.max_range_size = self.max_range_size.max(other.max_range_size);
self.used_lookups.extend(other.used_lookups.clone());
self.used_range_checks
.extend(other.used_range_checks.clone());
}
}
#[allow(unsafe_code)]
unsafe impl Sync for RegionStatistics {}
#[allow(unsafe_code)]
unsafe impl Send for RegionStatistics {}
#[derive(Debug)]
/// A context for a region
pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> {
@@ -167,8 +95,13 @@ pub struct RegionCtx<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Ha
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
statistics: RegionStatistics,
settings: RegionSettings,
used_lookups: HashSet<LookupOp>,
used_range_checks: HashSet<Range>,
max_lookup_inputs: i64,
min_lookup_inputs: i64,
max_range_size: i64,
witness_gen: bool,
check_lookup_range: bool,
assigned_constants: ConstantsMap<F>,
}
@@ -220,17 +153,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
///
pub fn witness_gen(&self) -> bool {
self.settings.witness_gen
self.witness_gen
}
///
pub fn check_range(&self) -> bool {
self.settings.check_range
}
///
pub fn statistics(&self) -> &RegionStatistics {
&self.statistics
pub fn check_lookup_range(&self) -> bool {
self.check_lookup_range
}
/// Create a new region context
@@ -245,8 +173,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
linear_coord,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
statistics: RegionStatistics::default(),
settings: RegionSettings::all_true(),
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: true,
check_lookup_range: true,
assigned_constants: HashMap::new(),
}
}
@@ -262,12 +195,39 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
new_self.assigned_constants = constants;
new_self
}
/// Create a new region context from a wrapped region
pub fn from_wrapped_region(
region: Option<RefCell<Region<'a, F>>>,
row: usize,
num_inner_cols: usize,
dynamic_lookup_index: DynamicLookupIndex,
shuffle_index: ShuffleIndex,
) -> RegionCtx<'a, F> {
let linear_coord = row * num_inner_cols;
RegionCtx {
region,
num_inner_cols,
linear_coord,
row,
dynamic_lookup_index,
shuffle_index,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen: false,
check_lookup_range: false,
assigned_constants: HashMap::new(),
}
}
/// Create a new region context
pub fn new_dummy(
row: usize,
num_inner_cols: usize,
settings: RegionSettings,
witness_gen: bool,
check_lookup_range: bool,
) -> RegionCtx<'a, F> {
let region = None;
let linear_coord = row * num_inner_cols;
@@ -279,8 +239,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
statistics: RegionStatistics::default(),
settings,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
check_lookup_range,
assigned_constants: HashMap::new(),
}
}
@@ -290,7 +255,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row: usize,
linear_coord: usize,
num_inner_cols: usize,
settings: RegionSettings,
witness_gen: bool,
check_lookup_range: bool,
) -> RegionCtx<'a, F> {
let region = None;
RegionCtx {
@@ -300,8 +266,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
row,
dynamic_lookup_index: DynamicLookupIndex::default(),
shuffle_index: ShuffleIndex::default(),
statistics: RegionStatistics::default(),
settings,
used_lookups: HashSet::new(),
used_range_checks: HashSet::new(),
max_lookup_inputs: 0,
min_lookup_inputs: 0,
max_range_size: 0,
witness_gen,
check_lookup_range,
assigned_constants: HashMap::new(),
}
}
@@ -352,9 +323,12 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
) -> Result<(), CircuitError> {
let row = AtomicUsize::new(self.row());
let linear_coord = AtomicUsize::new(self.linear_coord());
let statistics = Arc::new(Mutex::new(self.statistics.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let max_lookup_inputs = AtomicInt::new(self.max_lookup_inputs());
let min_lookup_inputs = AtomicInt::new(self.min_lookup_inputs());
let lookups = Arc::new(Mutex::new(self.used_lookups.clone()));
let range_checks = Arc::new(Mutex::new(self.used_range_checks.clone()));
let dynamic_lookup_index = Arc::new(Mutex::new(self.dynamic_lookup_index.clone()));
let shuffle_index = Arc::new(Mutex::new(self.shuffle_index.clone()));
let constants = Arc::new(Mutex::new(self.assigned_constants.clone()));
*output = output.par_enum_map(|idx, _| {
@@ -368,7 +342,8 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
starting_offset,
starting_linear_coord,
self.num_inner_cols,
self.settings.clone(),
self.witness_gen,
self.check_lookup_range,
);
let res = inner_loop_function(idx, &mut local_reg);
// we update the offset and constants
@@ -378,9 +353,14 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
Ordering::SeqCst,
);
max_lookup_inputs.fetch_max(local_reg.max_lookup_inputs(), Ordering::SeqCst);
min_lookup_inputs.fetch_min(local_reg.min_lookup_inputs(), Ordering::SeqCst);
// update the lookups
let mut statistics = statistics.lock().unwrap();
statistics.update(local_reg.statistics());
let mut lookups = lookups.lock().unwrap();
lookups.extend(local_reg.used_lookups());
// update the range checks
let mut range_checks = range_checks.lock().unwrap();
range_checks.extend(local_reg.used_range_checks());
// update the dynamic lookup index
let mut dynamic_lookup_index = dynamic_lookup_index.lock().unwrap();
dynamic_lookup_index.update(&local_reg.dynamic_lookup_index);
@@ -394,11 +374,20 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
res
})?;
self.linear_coord = linear_coord.into_inner();
#[allow(trivial_numeric_casts)]
{
self.max_lookup_inputs = max_lookup_inputs.into_inner();
self.min_lookup_inputs = min_lookup_inputs.into_inner();
}
self.row = row.into_inner();
self.statistics = Arc::try_unwrap(statistics)
self.used_lookups = Arc::try_unwrap(lookups)
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| CircuitError::GetLookupsError(format!("{:?}", e)))?;
self.used_range_checks = Arc::try_unwrap(range_checks)
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?
.into_inner()
.map_err(|e| CircuitError::GetRangeChecksError(format!("{:?}", e)))?;
self.dynamic_lookup_index = Arc::try_unwrap(dynamic_lookup_index)
.map_err(|e| CircuitError::GetDynamicLookupError(format!("{:?}", e)))?
.into_inner()
@@ -422,11 +411,11 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
) -> Result<(), CircuitError> {
let (mut min, mut max) = (0, 0);
for i in inputs {
max = max.max(i.int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.int_evals()?.into_iter().min().unwrap_or_default());
max = max.max(i.get_int_evals()?.into_iter().max().unwrap_or_default());
min = min.min(i.get_int_evals()?.into_iter().min().unwrap_or_default());
}
self.statistics.max_lookup_inputs = self.statistics.max_lookup_inputs.max(max);
self.statistics.min_lookup_inputs = self.statistics.min_lookup_inputs.min(min);
self.max_lookup_inputs = self.max_lookup_inputs.max(max);
self.min_lookup_inputs = self.min_lookup_inputs.min(min);
Ok(())
}
@@ -438,7 +427,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
let range_size = (range.1 - range.0).abs();
self.statistics.max_range_size = self.statistics.max_range_size.max(range_size);
self.max_range_size = self.max_range_size.max(range_size);
Ok(())
}
@@ -453,13 +442,13 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
lookup: LookupOp,
inputs: &[ValTensor<F>],
) -> Result<(), CircuitError> {
self.statistics.used_lookups.insert(lookup);
self.used_lookups.insert(lookup);
self.update_max_min_lookup_inputs(inputs)
}
/// add used range check
pub fn add_used_range_check(&mut self, range: Range) -> Result<(), CircuitError> {
self.statistics.used_range_checks.insert(range);
self.used_range_checks.insert(range);
self.update_max_min_lookup_range(range)
}
@@ -500,27 +489,27 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
/// get used lookups
pub fn used_lookups(&self) -> HashSet<LookupOp> {
self.statistics.used_lookups.clone()
self.used_lookups.clone()
}
/// get used range checks
pub fn used_range_checks(&self) -> HashSet<Range> {
self.statistics.used_range_checks.clone()
self.used_range_checks.clone()
}
/// max lookup inputs
pub fn max_lookup_inputs(&self) -> IntegerRep {
self.statistics.max_lookup_inputs
pub fn max_lookup_inputs(&self) -> i64 {
self.max_lookup_inputs
}
/// min lookup inputs
pub fn min_lookup_inputs(&self) -> IntegerRep {
self.statistics.min_lookup_inputs
pub fn min_lookup_inputs(&self) -> i64 {
self.min_lookup_inputs
}
/// max range check
pub fn max_range_size(&self) -> IntegerRep {
self.statistics.max_range_size
pub fn max_range_size(&self) -> i64 {
self.max_range_size
}
/// Assign a valtensor to a vartensor

View File

@@ -11,17 +11,17 @@ use maybe_rayon::prelude::{IntoParallelIterator, ParallelIterator};
use crate::{
circuit::CircuitError,
fieldutils::{integer_rep_to_felt, IntegerRep},
tensor::{Tensor, TensorType},
fieldutils::i64_to_felt,
tensor::{IntoI64, Tensor, TensorType},
};
use crate::circuit::lookup::LookupOp;
/// The range of the lookup table.
pub type Range = (IntegerRep, IntegerRep);
pub type Range = (i64, i64);
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: IntegerRep = 2;
pub const RANGE_MULTIPLIER: i64 = 2;
/// The safety factor offset for the number of rows in the lookup table.
pub const RESERVED_BLINDING_ROWS_PAD: usize = 3;
@@ -96,22 +96,21 @@ pub struct Table<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
/ (self.col_size as IntegerRep);
let chunk =
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
integer_rep_to_felt(chunk)
i64_to_felt(chunk)
}
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> (F, F) {
let chunk = chunk as IntegerRep;
let chunk = chunk as i64;
// we index from 1 to prevent soundness issues
let first_element =
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0);
let first_element = i64_to_felt(chunk * (self.col_size as i64) + self.range.0);
let op_f = self
.nonlinearity
.f(&[Tensor::from(vec![first_element].into_iter())])
@@ -131,12 +130,12 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
}
///
pub fn num_cols_required(range_len: IntegerRep, col_size: usize) -> usize {
pub fn num_cols_required(range_len: i64, col_size: usize) -> usize {
// number of cols needed to store the range
(range_len / (col_size as IntegerRep)) as usize + 1
(range_len / (col_size as i64)) as usize + 1
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> Table<F> {
/// Configures the table.
pub fn configure(
cs: &mut ConstraintSystem<F>,
@@ -203,7 +202,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> Table<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let evals = self.nonlinearity.f(&[inputs.clone()])?;
let chunked_inputs = inputs.chunks(self.col_size);
@@ -273,12 +272,12 @@ pub struct RangeCheck<F: PrimeField> {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
/// get first_element of column
pub fn get_first_element(&self, chunk: usize) -> F {
let chunk = chunk as IntegerRep;
let chunk = chunk as i64;
// we index from 1 to prevent soundness issues
integer_rep_to_felt(chunk * (self.col_size as IntegerRep) + self.range.0)
i64_to_felt(chunk * (self.col_size as i64) + self.range.0)
}
///
@@ -294,14 +293,14 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
/// get column index given input
pub fn get_col_index(&self, input: F) -> F {
// range is split up into chunks of size col_size, find the chunk that input is in
let chunk = (crate::fieldutils::felt_to_integer_rep(input) - self.range.0).abs()
/ (self.col_size as IntegerRep);
let chunk =
(crate::fieldutils::felt_to_i64(input) - self.range.0).abs() / (self.col_size as i64);
integer_rep_to_felt(chunk)
i64_to_felt(chunk)
}
}
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64> RangeCheck<F> {
/// Configures the table.
pub fn configure(cs: &mut ConstraintSystem<F>, range: Range, logrows: usize) -> RangeCheck<F> {
log::debug!("range check range: {:?}", range);
@@ -351,7 +350,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RangeCheck<F> {
let smallest = self.range.0;
let largest = self.range.1;
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| integer_rep_to_felt(x));
let inputs: Tensor<F> = Tensor::from(smallest..=largest).map(|x| i64_to_felt(x));
let chunked_inputs = inputs.chunks(self.col_size);
self.is_assigned = true;

View File

@@ -1050,7 +1050,6 @@ mod conv {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1201,7 +1200,6 @@ mod conv_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis)
@@ -1347,7 +1345,6 @@ mod conv_relu_col_ultra_overflow {
Box::new(PolyOp::Conv {
padding: vec![(1, 1); 2],
stride: vec![2; 2],
group: 1,
}),
)
.map_err(|_| Error::Synthesis);

View File

@@ -1,4 +1,3 @@
use crate::circuit::region::RegionSettings;
use crate::circuit::CheckMode;
#[cfg(not(target_arch = "wasm32"))]
use crate::commands::CalibrationTarget;
@@ -786,8 +785,6 @@ pub(crate) async fn gen_witness(
let commitment: Commitments = settings.run_args.commitment.into();
let region_settings = RegionSettings::all_true();
let start_time = Instant::now();
let witness = if settings.module_requires_polycommit() {
if get_srs_path(settings.run_args.logrows, srs_path.clone(), commitment).exists() {
@@ -802,7 +799,8 @@ pub(crate) async fn gen_witness(
&mut input,
vk.as_ref(),
Some(&srs),
region_settings,
true,
true,
)?
}
Commitments::IPA => {
@@ -816,7 +814,8 @@ pub(crate) async fn gen_witness(
&mut input,
vk.as_ref(),
Some(&srs),
region_settings,
true,
true,
)?
}
}
@@ -826,16 +825,12 @@ pub(crate) async fn gen_witness(
&mut input,
vk.as_ref(),
None,
region_settings,
true,
true,
)?
}
} else {
circuit.forward::<KZGCommitmentScheme<Bn256>>(
&mut input,
vk.as_ref(),
None,
region_settings,
)?
circuit.forward::<KZGCommitmentScheme<Bn256>>(&mut input, vk.as_ref(), None, true, true)?
};
// print each variable tuple (symbol, value) as symbol=value
@@ -1028,8 +1023,6 @@ pub(crate) async fn calibrate(
use std::collections::HashMap;
use tabled::Table;
use crate::fieldutils::IntegerRep;
let data = GraphData::from_path(data)?;
// load the pre-generated settings
let settings = GraphSettings::load(&settings_path)?;
@@ -1138,7 +1131,7 @@ pub(crate) async fn calibrate(
param_scale,
scale_rebase_multiplier,
div_rebasing,
lookup_range: (IntegerRep::MIN, IntegerRep::MAX),
lookup_range: (i64::MIN, i64::MAX),
..settings.run_args.clone()
};
@@ -1178,7 +1171,8 @@ pub(crate) async fn calibrate(
&mut data.clone(),
None,
None,
RegionSettings::all_true(),
true,
false,
)
.map_err(|e| format!("failed to forward: {}", e))?;

View File

@@ -2,11 +2,17 @@ use halo2_proofs::arithmetic::Field;
/// Utilities for converting from Halo2 PrimeField types to integers (and vice-versa).
use halo2curves::ff::PrimeField;
/// Integer representation of a PrimeField element.
pub type IntegerRep = i128;
/// Converts an i32 to a PrimeField element.
pub fn i32_to_felt<F: PrimeField>(x: i32) -> F {
if x >= 0 {
F::from(x as u64)
} else {
-F::from(x.unsigned_abs() as u64)
}
}
/// Converts an i64 to a PrimeField element.
pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
@@ -14,9 +20,24 @@ pub fn integer_rep_to_felt<F: PrimeField>(x: IntegerRep) -> F {
}
}
/// Converts a PrimeField element to an i32.
pub fn felt_to_i32<F: PrimeField + PartialOrd + Field>(x: F) -> i32 {
if x > F::from(i32::MAX as u64) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_32 = u32::from_le_bytes(negtmp[..4].try_into().unwrap());
-(lower_32 as i32)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_32 = u32::from_le_bytes(tmp[..4].try_into().unwrap());
lower_32 as i32
}
}
/// Converts a PrimeField element to an f64.
pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
if x > F::from_u128(IntegerRep::MAX as u128) {
if x > F::from_u128(i64::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
@@ -30,17 +51,17 @@ pub fn felt_to_f64<F: PrimeField + PartialOrd + Field>(x: F) -> f64 {
}
/// Converts a PrimeField element to an i64.
pub fn felt_to_integer_rep<F: PrimeField + PartialOrd + Field>(x: F) -> IntegerRep {
if x > F::from_u128(IntegerRep::MAX as u128) {
pub fn felt_to_i64<F: PrimeField + PartialOrd + Field>(x: F) -> i64 {
if x > F::from_u128(i64::MAX as u128) {
let rep = (-x).to_repr();
let negtmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(negtmp[..16].try_into().unwrap());
-(lower_128 as IntegerRep)
-(lower_128 as i64)
} else {
let rep = (x).to_repr();
let tmp: &[u8] = rep.as_ref();
let lower_128: u128 = u128::from_le_bytes(tmp[..16].try_into().unwrap());
lower_128 as IntegerRep
lower_128 as i64
}
}
@@ -52,24 +73,33 @@ mod test {
#[test]
fn test_conv() {
let res: F = integer_rep_to_felt(-15);
let res: F = i32_to_felt(-15i32);
assert_eq!(res, -F::from(15));
let res: F = integer_rep_to_felt(2_i128.pow(17));
let res: F = i32_to_felt(2_i32.pow(17));
assert_eq!(res, F::from(131072));
let res: F = integer_rep_to_felt(-15);
let res: F = i64_to_felt(-15i64);
assert_eq!(res, -F::from(15));
let res: F = integer_rep_to_felt(2_i128.pow(17));
let res: F = i64_to_felt(2_i64.pow(17));
assert_eq!(res, F::from(131072));
}
#[test]
fn felttointegerrep() {
for x in -(2_i128.pow(16))..(2_i128.pow(16)) {
let fieldx: F = integer_rep_to_felt::<F>(x);
let xf: i128 = felt_to_integer_rep::<F>(fieldx);
fn felttoi32() {
for x in -(2i32.pow(16))..(2i32.pow(16)) {
let fieldx: F = i32_to_felt::<F>(x);
let xf: i32 = felt_to_i32::<F>(fieldx);
assert_eq!(x, xf);
}
}
#[test]
fn felttoi64() {
for x in -(2i64.pow(20))..(2i64.pow(20)) {
let fieldx: F = i64_to_felt::<F>(x);
let xf: i64 = felt_to_i64::<F>(fieldx);
assert_eq!(x, xf);
}
}

View File

@@ -1,7 +1,7 @@
use super::errors::GraphError;
use super::quantize_float;
use crate::circuit::InputType;
use crate::fieldutils::integer_rep_to_felt;
use crate::fieldutils::i64_to_felt;
#[cfg(not(target_arch = "wasm32"))]
use crate::graph::postgres::Client;
#[cfg(not(target_arch = "wasm32"))]
@@ -128,7 +128,7 @@ impl FileSourceInner {
/// Convert to a field element
pub fn to_field(&self, scale: crate::Scale) -> Fp {
match self {
FileSourceInner::Float(f) => integer_rep_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Float(f) => i64_to_felt(quantize_float(f, 0.0, scale).unwrap()),
FileSourceInner::Bool(f) => {
if *f {
Fp::one()
@@ -150,7 +150,7 @@ impl FileSourceInner {
0.0
}
}
FileSourceInner::Field(f) => crate::fieldutils::felt_to_integer_rep(*f) as f64,
FileSourceInner::Field(f) => crate::fieldutils::felt_to_i64(*f) as f64,
}
}
}

View File

@@ -34,10 +34,10 @@ use self::input::{FileSource, GraphData};
use self::modules::{GraphModules, ModuleConfigs, ModuleForwardResult, ModuleSizes};
use crate::circuit::lookup::LookupOp;
use crate::circuit::modules::ModulePlanner;
use crate::circuit::region::{ConstantsMap, RegionSettings};
use crate::circuit::region::ConstantsMap;
use crate::circuit::table::{num_cols_required, Range, Table, RESERVED_BLINDING_ROWS_PAD};
use crate::circuit::{CheckMode, InputType};
use crate::fieldutils::{felt_to_f64, IntegerRep};
use crate::fieldutils::felt_to_f64;
use crate::pfsys::PrettyElements;
use crate::tensor::{Tensor, ValTensor};
use crate::{RunArgs, EZKL_BUF_CAPACITY};
@@ -69,14 +69,13 @@ pub use vars::*;
use crate::pfsys::field_to_string;
/// The safety factor for the range of the lookup table.
pub const RANGE_MULTIPLIER: IntegerRep = 2;
pub const RANGE_MULTIPLIER: i64 = 2;
/// The maximum number of columns in a lookup table.
pub const MAX_NUM_LOOKUP_COLS: usize = 12;
/// Max representation of a lookup table input
pub const MAX_LOOKUP_ABS: IntegerRep =
(MAX_NUM_LOOKUP_COLS as IntegerRep) * 2_i128.pow(MAX_PUBLIC_SRS);
pub const MAX_LOOKUP_ABS: i64 = (MAX_NUM_LOOKUP_COLS as i64) * 2_i64.pow(MAX_PUBLIC_SRS);
#[cfg(not(target_arch = "wasm32"))]
lazy_static! {
@@ -127,11 +126,11 @@ pub struct GraphWitness {
/// Any hashes of outputs generated during the forward pass
pub processed_outputs: Option<ModuleForwardResult>,
/// max lookup input
pub max_lookup_inputs: IntegerRep,
pub max_lookup_inputs: i64,
/// max lookup input
pub min_lookup_inputs: IntegerRep,
pub min_lookup_inputs: i64,
/// max range check size
pub max_range_size: IntegerRep,
pub max_range_size: i64,
}
impl GraphWitness {
@@ -805,26 +804,18 @@ impl GraphCircuit {
// the ordering here is important, we want the inputs to come before the outputs
// as they are configured in that order as Column<Instances>
let mut public_inputs: Vec<Fp> = vec![];
// we first process the inputs
if let Some(processed_inputs) = &data.processed_inputs {
if self.settings().run_args.input_visibility.is_public() {
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
} else if let Some(processed_inputs) = &data.processed_inputs {
public_inputs.extend(processed_inputs.get_instances().into_iter().flatten());
}
// we then process the params
if let Some(processed_params) = &data.processed_params {
public_inputs.extend(processed_params.get_instances().into_iter().flatten());
}
// if the inputs are public, we add them to the public inputs AFTER the processed params as they are configured in that order as Column<Instances>
if self.settings().run_args.input_visibility.is_public() {
public_inputs.extend(self.graph_witness.inputs.clone().into_iter().flatten())
}
// if the outputs are public, we add them to the public inputs
if self.settings().run_args.output_visibility.is_public() {
public_inputs.extend(self.graph_witness.outputs.clone().into_iter().flatten());
// if the outputs are processed, we add the processed outputs to the public inputs
} else if let Some(processed_outputs) = &data.processed_outputs {
public_inputs.extend(processed_outputs.get_instances().into_iter().flatten());
}
@@ -1045,12 +1036,12 @@ impl GraphCircuit {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
(
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as IntegerRep,
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as IntegerRep,
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as i64,
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as i64,
)
}
fn calc_num_cols(range_len: IntegerRep, max_logrows: u32) -> usize {
fn calc_num_cols(range_len: i64, max_logrows: u32) -> usize {
let max_col_size = Table::<Fp>::cal_col_size(max_logrows as usize, RESERVED_BLINDING_ROWS);
num_cols_required(range_len, max_col_size)
}
@@ -1058,7 +1049,7 @@ impl GraphCircuit {
fn table_size_logrows(
&self,
safe_lookup_range: Range,
max_range_size: IntegerRep,
max_range_size: i64,
) -> Result<u32, GraphError> {
// pick the range with the largest absolute size safe_lookup_range or max_range_size
let safe_range = std::cmp::max(
@@ -1077,7 +1068,7 @@ impl GraphCircuit {
pub fn calc_min_logrows(
&mut self,
min_max_lookup: Range,
max_range_size: IntegerRep,
max_range_size: i64,
max_logrows: Option<u32>,
lookup_safety_margin: f64,
) -> Result<(), GraphError> {
@@ -1095,7 +1086,7 @@ impl GraphCircuit {
(safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
// check if has overflowed max lookup input
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as IntegerRep {
if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as i64 {
return Err(GraphError::LookupRangeTooLarge(
lookup_size.unsigned_abs() as usize
));
@@ -1175,7 +1166,7 @@ impl GraphCircuit {
&self,
k: u32,
safe_lookup_range: Range,
max_range_size: IntegerRep,
max_range_size: i64,
) -> bool {
// if num cols is too large then the extended k is too large
if Self::calc_num_cols(safe_lookup_range.1 - safe_lookup_range.0, k) > MAX_NUM_LOOKUP_COLS
@@ -1233,7 +1224,8 @@ impl GraphCircuit {
inputs: &mut [Tensor<Fp>],
vk: Option<&VerifyingKey<G1Affine>>,
srs: Option<&Scheme::ParamsProver>,
region_settings: RegionSettings,
witness_gen: bool,
check_lookup: bool,
) -> Result<GraphWitness, GraphError> {
let original_inputs = inputs.to_vec();
@@ -1282,7 +1274,7 @@ impl GraphCircuit {
let mut model_results =
self.model()
.forward(inputs, &self.settings().run_args, region_settings)?;
.forward(inputs, &self.settings().run_args, witness_gen, check_lookup)?;
if visibility.output.requires_processing() {
let module_outlets = visibility.output.overwrites_inputs();

View File

@@ -7,12 +7,10 @@ use super::GraphSettings;
use crate::circuit::hybrid::HybridOp;
use crate::circuit::region::ConstantsMap;
use crate::circuit::region::RegionCtx;
use crate::circuit::region::RegionSettings;
use crate::circuit::table::Range;
use crate::circuit::Input;
use crate::circuit::InputType;
use crate::circuit::Unknown;
use crate::fieldutils::IntegerRep;
use crate::tensor::ValType;
use crate::{
circuit::{lookup::LookupOp, BaseConfig as PolyConfig, CheckMode, Op},
@@ -66,11 +64,11 @@ pub struct ForwardResult {
/// The outputs of the forward pass.
pub outputs: Vec<Tensor<Fp>>,
/// The maximum value of any input to a lookup operation.
pub max_lookup_inputs: IntegerRep,
pub max_lookup_inputs: i64,
/// The minimum value of any input to a lookup operation.
pub min_lookup_inputs: IntegerRep,
pub min_lookup_inputs: i64,
/// The max range check size
pub max_range_size: IntegerRep,
pub max_range_size: i64,
}
impl From<DummyPassRes> for ForwardResult {
@@ -118,11 +116,11 @@ pub struct DummyPassRes {
/// range checks
pub range_checks: HashSet<Range>,
/// max lookup inputs
pub max_lookup_inputs: IntegerRep,
pub max_lookup_inputs: i64,
/// min lookup inputs
pub min_lookup_inputs: IntegerRep,
pub min_lookup_inputs: i64,
/// min range check
pub max_range_size: IntegerRep,
pub max_range_size: i64,
/// outputs
pub outputs: Vec<Tensor<Fp>>,
}
@@ -547,7 +545,7 @@ impl Model {
})
.collect::<Result<Vec<_>, GraphError>>()?;
let res = self.dummy_layout(run_args, &inputs, RegionSettings::all_false())?;
let res = self.dummy_layout(run_args, &inputs, false, false)?;
// if we're using percentage tolerance, we need to add the necessary range check ops for it.
@@ -590,13 +588,14 @@ impl Model {
&self,
model_inputs: &[Tensor<Fp>],
run_args: &RunArgs,
region_settings: RegionSettings,
witness_gen: bool,
check_lookup: bool,
) -> Result<ForwardResult, GraphError> {
let valtensor_inputs: Vec<ValTensor<Fp>> = model_inputs
.iter()
.map(|x| x.map(|elem| ValType::Value(Value::known(elem))).into())
.collect();
let res = self.dummy_layout(run_args, &valtensor_inputs, region_settings)?;
let res = self.dummy_layout(run_args, &valtensor_inputs, witness_gen, check_lookup)?;
Ok(res.into())
}
@@ -1391,7 +1390,8 @@ impl Model {
&self,
run_args: &RunArgs,
inputs: &[ValTensor<Fp>],
region_settings: RegionSettings,
witness_gen: bool,
check_lookup: bool,
) -> Result<DummyPassRes, GraphError> {
debug!("calculating num of constraints using dummy model layout...");
@@ -1410,7 +1410,8 @@ impl Model {
vars: ModelVars::new_dummy(),
};
let mut region = RegionCtx::new_dummy(0, run_args.num_inner_cols, region_settings);
let mut region =
RegionCtx::new_dummy(0, run_args.num_inner_cols, witness_gen, check_lookup);
let outputs = self.layout_nodes(&mut model_config, &mut region, &mut results)?;

View File

@@ -9,7 +9,6 @@ use crate::circuit::lookup::LookupOp;
#[cfg(not(target_arch = "wasm32"))]
use crate::circuit::poly::PolyOp;
use crate::circuit::Op;
use crate::fieldutils::IntegerRep;
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
use halo2curves::ff::PrimeField;
@@ -51,20 +50,16 @@ use tract_onnx::tract_hir::{
/// * `dims` - the dimensionality of the resulting [Tensor].
/// * `shift` - offset used in the fixed point representation.
/// * `scale` - `2^scale` used in the fixed point representation.
pub fn quantize_float(
elem: &f64,
shift: f64,
scale: crate::Scale,
) -> Result<IntegerRep, TensorError> {
pub fn quantize_float(elem: &f64, shift: f64, scale: crate::Scale) -> Result<i64, TensorError> {
let mult = scale_to_multiplier(scale);
let max_value = ((IntegerRep::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
let max_value = ((i64::MAX as f64 - shift) / mult).round(); // the maximum value that can be represented w/o sig bit truncation
if *elem > max_value {
return Err(TensorError::SigBitTruncationError);
}
// we parallelize the quantization process as it seems to be quite slow at times
let scaled = (mult * *elem + shift).round() as IntegerRep;
let scaled = (mult * *elem + shift).round() as i64;
Ok(scaled)
}
@@ -75,7 +70,7 @@ pub fn quantize_float(
/// * `scale` - `2^scale` used in the fixed point representation.
/// * `shift` - offset used in the fixed point representation.
pub fn dequantize(felt: Fp, scale: crate::Scale, shift: f64) -> f64 {
let int_rep = crate::fieldutils::felt_to_integer_rep(felt);
let int_rep = crate::fieldutils::felt_to_i64(felt);
let multiplier = scale_to_multiplier(scale);
int_rep as f64 / multiplier - shift
}
@@ -288,7 +283,10 @@ pub fn new_op_from_onnx(
.flat_map(|x| x.out_scales())
.collect::<Vec<_>>();
let input_dims = inputs.iter().flat_map(|x| x.out_dims()).collect::<Vec<_>>();
let input_dims = inputs
.iter()
.flat_map(|x| x.out_dims())
.collect::<Vec<_>>();
let mut replace_const = |scale: crate::Scale,
index: usize,
@@ -1194,13 +1192,7 @@ pub fn new_op_from_onnx(
}
}
let group = conv_node.group;
SupportedOp::Linear(PolyOp::Conv {
padding,
stride,
group,
})
SupportedOp::Linear(PolyOp::Conv { padding, stride })
}
"Not" => SupportedOp::Linear(PolyOp::Not),
"And" => SupportedOp::Linear(PolyOp::And),
@@ -1255,7 +1247,6 @@ pub fn new_op_from_onnx(
padding,
output_padding: deconv_node.adjustments.to_vec(),
stride,
group: deconv_node.group,
})
}
"Downsample" => {
@@ -1429,7 +1420,7 @@ pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
visibility: &Visibility,
) -> Result<Tensor<F>, TensorError> {
let mut value: Tensor<F> = const_value.par_enum_map(|_, x| {
Ok::<_, TensorError>(crate::fieldutils::integer_rep_to_felt::<F>(quantize_float(
Ok::<_, TensorError>(crate::fieldutils::i64_to_felt::<F>(quantize_float(
&(x).into(),
0.0,
scale,

View File

@@ -85,7 +85,6 @@ use std::str::FromStr;
use circuit::{table::Range, CheckMode, Tolerance};
use clap::Args;
use fieldutils::IntegerRep;
use graph::Visibility;
use halo2_proofs::poly::{
ipa::commitment::IPACommitmentScheme, kzg::commitment::KZGCommitmentScheme,
@@ -244,7 +243,7 @@ pub struct RunArgs {
#[arg(long, default_value = "1", value_hint = clap::ValueHint::Other)]
pub scale_rebase_multiplier: u32,
/// The min and max elements in the lookup table input column
#[arg(short = 'B', long, value_parser = parse_key_val::<IntegerRep, IntegerRep>, default_value = "-32768->32768")]
#[arg(short = 'B', long, value_parser = parse_key_val::<i64, i64>, default_value = "-32768->32768")]
pub lookup_range: Range,
/// The log_2 number of rows
#[arg(short = 'K', long, default_value = "17", value_hint = clap::ValueHint::Other)]

View File

@@ -6,7 +6,7 @@ use crate::circuit::modules::poseidon::{
use crate::circuit::modules::Module;
use crate::circuit::{CheckMode, Tolerance};
use crate::commands::*;
use crate::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use crate::fieldutils::{felt_to_i64, i64_to_felt};
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::TestDataSource;
use crate::graph::{
@@ -331,9 +331,9 @@ fn felt_to_big_endian(felt: PyFelt) -> PyResult<String> {
#[pyfunction(signature = (
felt,
))]
fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
fn felt_to_int(felt: PyFelt) -> PyResult<i64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_integer_rep(felt);
let int_rep = felt_to_i64(felt);
Ok(int_rep)
}
@@ -357,7 +357,7 @@ fn felt_to_int(felt: PyFelt) -> PyResult<IntegerRep> {
))]
fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
let felt = crate::pfsys::string_to_field::<Fr>(&felt);
let int_rep = felt_to_integer_rep(felt);
let int_rep = felt_to_i64(felt);
let multiplier = scale_to_multiplier(scale);
let float_rep = int_rep as f64 / multiplier;
Ok(float_rep)
@@ -385,7 +385,7 @@ fn felt_to_float(felt: PyFelt, scale: crate::Scale) -> PyResult<f64> {
fn float_to_felt(input: f64, scale: crate::Scale) -> PyResult<PyFelt> {
let int_rep = quantize_float(&input, 0.0, scale)
.map_err(|_| PyIOError::new_err("Failed to quantize input"))?;
let felt = integer_rep_to_felt(int_rep);
let felt = i64_to_felt(int_rep);
Ok(crate::pfsys::field_to_string::<Fr>(&felt))
}

View File

@@ -9,7 +9,7 @@ pub mod var;
pub use errors::TensorError;
use halo2curves::ff::PrimeField;
use halo2curves::{bn256::Fr, ff::PrimeField};
use maybe_rayon::{
prelude::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator,
@@ -26,7 +26,7 @@ use instant::Instant;
use crate::{
circuit::utils,
fieldutils::{integer_rep_to_felt, IntegerRep},
fieldutils::{felt_to_i32, felt_to_i64, i32_to_felt, i64_to_felt},
graph::Visibility,
};
@@ -108,7 +108,7 @@ impl TensorType for f32 {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// f32 doesnt impl Ord so we cant just use max like we can for i32, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
@@ -131,7 +131,7 @@ impl TensorType for f64 {
Some(0.0)
}
// f32 doesnt impl Ord so we cant just use max like we can for IntegerRep, usize.
// f32 doesnt impl Ord so we cant just use max like we can for i32, usize.
// A comparison between f32s needs to handle NAN values.
fn tmax(&self, other: &Self) -> Option<Self> {
match (self.is_nan(), other.is_nan()) {
@@ -150,7 +150,8 @@ impl TensorType for f64 {
}
tensor_type!(bool, Bool, false, true);
tensor_type!(IntegerRep, IntegerRep, 0, 1);
tensor_type!(i64, Int64, 0, 1);
tensor_type!(i32, Int32, 0, 1);
tensor_type!(usize, USize, 0, 1);
tensor_type!((), Empty, (), ());
tensor_type!(utils::F32, F32, utils::F32(0.0), utils::F32(1.0));
@@ -315,6 +316,92 @@ impl<T: TensorType> DerefMut for Tensor<T> {
self.inner.deref_mut()
}
}
/// Convert to i64 trait
pub trait IntoI64 {
/// Convert to i64
fn into_i64(self) -> i64;
/// From i64
fn from_i64(i: i64) -> Self;
}
impl IntoI64 for i64 {
fn into_i64(self) -> i64 {
self
}
fn from_i64(i: i64) -> i64 {
i
}
}
impl IntoI64 for i32 {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as i32
}
}
impl IntoI64 for usize {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as usize
}
}
impl IntoI64 for f32 {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as f32
}
}
impl IntoI64 for f64 {
fn into_i64(self) -> i64 {
self as i64
}
fn from_i64(i: i64) -> Self {
i as f64
}
}
impl IntoI64 for () {
fn into_i64(self) -> i64 {
0
}
fn from_i64(_: i64) -> Self {}
}
impl IntoI64 for Fr {
fn into_i64(self) -> i64 {
felt_to_i64(self)
}
fn from_i64(i: i64) -> Self {
i64_to_felt::<Fr>(i)
}
}
impl<F: PrimeField + IntoI64> IntoI64 for Value<F> {
fn into_i64(self) -> i64 {
let mut res = vec![];
self.map(|x| res.push(x.into_i64()));
if res.is_empty() {
0
} else {
res[0]
}
}
fn from_i64(i: i64) -> Self {
Value::known(F::from_i64(i))
}
}
impl<T: PartialEq + TensorType> PartialEq for Tensor<T> {
fn eq(&self, other: &Tensor<T>) -> bool {
@@ -344,6 +431,42 @@ where
}
}
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
for Tensor<i32>
{
fn from(value: Tensor<AssignedCell<Assigned<F>, F>>) -> Tensor<i32> {
let mut output = Vec::new();
value.map(|x| {
x.evaluate().value().map(|y| {
let e = felt_to_i32(*y);
output.push(e);
e
})
});
Tensor::new(Some(&output), value.dims()).unwrap()
}
}
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>>
for Tensor<i32>
{
fn from(value: Tensor<AssignedCell<F, F>>) -> Tensor<i32> {
let mut output = Vec::new();
value.map(|x| {
let mut i = 0;
x.value().map(|y| {
let e = felt_to_i32(*y);
output.push(e);
i += 1;
});
if i == 0 {
output.push(0);
}
});
Tensor::new(Some(&output), value.dims()).unwrap()
}
}
impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<Assigned<F>, F>>>
for Tensor<Value<F>>
{
@@ -356,6 +479,24 @@ impl<F: PrimeField + Clone + TensorType + PartialOrd> From<Tensor<AssignedCell<A
}
}
impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>> for Tensor<i32> {
fn from(t: Tensor<Value<F>>) -> Tensor<i32> {
let mut output = Vec::new();
t.map(|x| {
let mut i = 0;
x.map(|y| {
let e = felt_to_i32(y);
output.push(e);
i += 1;
});
if i == 0 {
output.push(0);
}
});
Tensor::new(Some(&output), t.dims()).unwrap()
}
}
impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
for Tensor<Value<Assigned<F>>>
{
@@ -367,10 +508,20 @@ impl<F: PrimeField + TensorType + Clone + PartialOrd> From<Tensor<Value<F>>>
}
}
impl<F: PrimeField + TensorType + Clone> From<Tensor<IntegerRep>> for Tensor<Value<F>> {
fn from(t: Tensor<IntegerRep>) -> Tensor<Value<F>> {
impl<F: PrimeField + TensorType + Clone> From<Tensor<i32>> for Tensor<Value<F>> {
fn from(t: Tensor<i32>) -> Tensor<Value<F>> {
let mut ta: Tensor<Value<F>> =
Tensor::from((0..t.len()).map(|i| Value::known(integer_rep_to_felt::<F>(t[i]))));
Tensor::from((0..t.len()).map(|i| Value::known(i32_to_felt::<F>(t[i]))));
// safe to unwrap as we know the dims are correct
ta.reshape(t.dims()).unwrap();
ta
}
}
impl<F: PrimeField + TensorType + Clone> From<Tensor<i64>> for Tensor<Value<F>> {
fn from(t: Tensor<i64>) -> Tensor<Value<F>> {
let mut ta: Tensor<Value<F>> =
Tensor::from((0..t.len()).map(|i| Value::known(i64_to_felt::<F>(t[i]))));
// safe to unwrap as we know the dims are correct
ta.reshape(t.dims()).unwrap();
ta
@@ -482,8 +633,7 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(None, &[3, 3, 3]).unwrap();
/// let mut a = Tensor::<i32>::new(None, &[3, 3, 3]).unwrap();
///
/// a.set(&[0, 0, 1], 10);
/// assert_eq!(a[0 + 0 + 1], 10);
@@ -500,8 +650,7 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
///
/// a[1*15 + 1*5 + 1] = 5;
/// assert_eq!(a.get(&[1, 1, 1]), 5);
@@ -515,8 +664,7 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
///
/// a[1*15 + 1*5 + 1] = 5;
/// assert_eq!(a.get(&[1, 1, 1]), 5);
@@ -536,12 +684,11 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Pad to a length that is divisible by n
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1,2,3,4,5,6]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0]), &[8]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(4, 0).unwrap(), expected);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 0, 0, 0]), &[9]).unwrap();
/// assert_eq!(a.pad_to_zero_rem(9, 0).unwrap(), expected);
/// ```
pub fn pad_to_zero_rem(&self, n: usize, pad: T) -> Result<Tensor<T>, TensorError> {
@@ -557,8 +704,7 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(None, &[2, 3, 5]).unwrap();
/// let mut a = Tensor::<i32>::new(None, &[2, 3, 5]).unwrap();
///
/// let flat_index = 1*15 + 1*5 + 1;
/// a[1*15 + 1*5 + 1] = 5;
@@ -585,9 +731,8 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Get a slice from the Tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2]), &[2]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<i32>::new(Some(&[1, 2]), &[2]).unwrap();
///
/// assert_eq!(a.get_slice(&[0..2]).unwrap(), b);
/// ```
@@ -637,10 +782,9 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Set a slice of the Tensor.
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[1, 2, 3, 1, 2, 3]), &[2, 3]).unwrap();
/// a.set_slice(&[1..2], &Tensor::<i32>::new(Some(&[1, 2, 3]), &[1, 3]).unwrap()).unwrap();
/// assert_eq!(a, b);
/// ```
pub fn set_slice(
@@ -701,7 +845,6 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
///
/// assert_eq!(a.get_index(&[2, 2, 2]), 26);
@@ -725,13 +868,12 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 5, 6]), &[8]).unwrap();
/// assert_eq!(a.duplicate_every_n(3, 1, 0).unwrap(), expected);
/// assert_eq!(a.duplicate_every_n(7, 1, 0).unwrap(), a);
///
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 1, 2, 3, 3, 4, 5, 5, 6]), &[9]).unwrap();
/// assert_eq!(a.duplicate_every_n(3, 1, 2).unwrap(), expected);
///
/// ```
@@ -758,9 +900,8 @@ impl<T: Clone + TensorType> Tensor<T> {
///
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 4, 5, 6, 6]), &[8]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 3, 5, 6, 6]), &[7]).unwrap();
/// assert_eq!(a.remove_every_n(4, 1, 0).unwrap(), expected);
///
///
@@ -794,15 +935,14 @@ impl<T: Clone + TensorType> Tensor<T> {
/// WARN: assumes indices are in ascending order for speed
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
/// let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[6]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[1, 2, 3, 6]), &[4]).unwrap();
/// let mut indices = vec![3, 4];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), expected);
///
///
/// let a = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
/// let b = Tensor::<IntegerRep>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
/// let a = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, 1, -335, 244, 188, -4, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[82]).unwrap();
/// let b = Tensor::<i32>::new(Some(&[52, -245, 153, 13, -4, -56, -163, 249, -128, -172, 396, 143, 2, -96, 504, -44, -158, -393, 61, 95, 191, 74, 64, -219, 553, 104, 235, 222, 44, -216, 63, -251, 40, -140, 112, -355, 60, 123, 26, -116, -89, -200, -109, 168, 135, -34, -99, -54, 5, -81, 322, 87, 4, -139, 420, 92, -295, -12, 262, -1, 26, -48, 231, -335, 244, 188, 5, -362, 57, -198, -184, -117, 40, 305, 49, 30, -59, -26, -37, 96]), &[80]).unwrap();
/// let mut indices = vec![63, 67];
/// assert_eq!(a.remove_indices(&mut indices, true).unwrap(), b);
/// ```
@@ -832,7 +972,6 @@ impl<T: Clone + TensorType> Tensor<T> {
///Reshape the tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
/// a.reshape(&[9, 3]);
/// assert_eq!(a.dims(), &[9, 3]);
@@ -867,23 +1006,22 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Move axis of the tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
/// let b = a.move_axis(0, 2).unwrap();
/// assert_eq!(b.dims(), &[3, 3, 3]);
///
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[1, 2, 3]).unwrap();
/// let b = a.move_axis(0, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.move_axis(1, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.move_axis(2, 1).unwrap();
/// assert_eq!(b, expected);
/// ```
@@ -948,23 +1086,22 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Swap axes of the tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
/// let b = a.swap_axes(0, 2).unwrap();
/// assert_eq!(b.dims(), &[3, 3, 3]);
///
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[3, 1, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6]), &[2, 1, 3]).unwrap();
/// let b = a.swap_axes(0, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.swap_axes(1, 2).unwrap();
/// assert_eq!(b, expected);
///
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]), &[2, 3, 2]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 3, 5, 2, 4, 6, 7, 9, 11, 8, 10, 12]), &[2, 2, 3]).unwrap();
/// let b = a.swap_axes(2, 1).unwrap();
/// assert_eq!(b, expected);
/// ```
@@ -1011,10 +1148,9 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Broadcasts the tensor to a given shape
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
///
/// let mut expected = Tensor::<IntegerRep>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
/// let mut expected = Tensor::<i32>::new(Some(&[1, 1, 1, 2, 2, 2, 3, 3, 3]), &[3, 3]).unwrap();
/// assert_eq!(a.expand(&[3, 3]).unwrap(), expected);
///
/// ```
@@ -1068,7 +1204,6 @@ impl<T: Clone + TensorType> Tensor<T> {
///Flatten the tensor shape
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<f32>::new(None, &[3, 3, 3]).unwrap();
/// a.flatten();
/// assert_eq!(a.dims(), &[27]);
@@ -1082,9 +1217,8 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.map(|x| IntegerRep::pow(x,2));
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.map(|x| i32::pow(x,2));
/// assert_eq!(c, Tensor::from([1, 16].into_iter()))
/// ```
pub fn map<F: FnMut(T) -> G, G: TensorType>(&self, mut f: F) -> Tensor<G> {
@@ -1097,9 +1231,8 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors and enumerates
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn enum_map<F: FnMut(usize, T) -> Result<G, E>, G: TensorType, E: Error>(
@@ -1121,9 +1254,8 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map<
@@ -1152,9 +1284,8 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Get last elem from Tensor
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<IntegerRep>::new(Some(&[3]), &[1]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
/// let mut b = Tensor::<i32>::new(Some(&[3]), &[1]).unwrap();
///
/// assert_eq!(a.last().unwrap(), b);
/// ```
@@ -1177,9 +1308,8 @@ impl<T: Clone + TensorType> Tensor<T> {
/// Maps a function to tensors and enumerates in parallel
/// ```
/// use ezkl::tensor::{Tensor, TensorError};
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(IntegerRep::pow(x + i as IntegerRep, 2))).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 4]), &[2]).unwrap();
/// let mut c = a.par_enum_map::<_,_,TensorError>(|i, x| Ok(i32::pow(x + i as i32, 2))).unwrap();
/// assert_eq!(c, Tensor::from([1, 25].into_iter()));
/// ```
pub fn par_enum_map_mut_filtered<
@@ -1202,13 +1332,103 @@ impl<T: Clone + TensorType> Tensor<T> {
}
}
#[cfg(feature = "metal")]
#[allow(unsafe_code)]
/// Perform a tensor operation on the GPU using Metal.
pub fn metal_tensor_op<T: Clone + TensorType + IntoI64 + Send + Sync>(
v: &Tensor<T>,
w: &Tensor<T>,
op: &str,
) -> Tensor<T> {
assert_eq!(v.dims(), w.dims());
log::trace!("------------------------------------------------");
let start = Instant::now();
let v = v
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
.unwrap();
let w = w
.par_enum_map(|_, x| Ok::<_, TensorError>(x.into_i64()))
.unwrap();
log::trace!("Time to map tensors: {:?}", start.elapsed());
objc::rc::autoreleasepool(|| {
// create function pipeline.
// this compiles the function, so a pipline can't be created in performance sensitive code.
let pipeline = &PIPELINES[op];
let length = v.len() as u64;
let size = length * core::mem::size_of::<i64>() as u64;
assert_eq!(v.len(), w.len());
let start = Instant::now();
let buffer_a = DEVICE.new_buffer_with_data(
unsafe { std::mem::transmute(v.as_ptr()) },
size,
MTLResourceOptions::StorageModeShared,
);
let buffer_b = DEVICE.new_buffer_with_data(
unsafe { std::mem::transmute(w.as_ptr()) },
size,
MTLResourceOptions::StorageModeShared,
);
let buffer_result = DEVICE.new_buffer(
size, // the operation will return an array with the same size.
MTLResourceOptions::StorageModeShared,
);
log::trace!("Time to load buffers: {:?}", start.elapsed());
// for sending commands, a command buffer is needed.
let start = Instant::now();
let command_buffer = QUEUE.new_command_buffer();
log::trace!("Time to load command buffer: {:?}", start.elapsed());
// to write commands into a buffer an encoder is needed, in our case a compute encoder.
let start = Instant::now();
let compute_encoder = command_buffer.new_compute_command_encoder();
compute_encoder.set_compute_pipeline_state(&pipeline);
compute_encoder.set_buffers(
0,
&[Some(&buffer_a), Some(&buffer_b), Some(&buffer_result)],
&[0; 3],
);
log::trace!("Time to load compute encoder: {:?}", start.elapsed());
// specify thread count and organization
let start = Instant::now();
let grid_size = MTLSize::new(length, 1, 1);
let threadgroup_size = MTLSize::new(length, 1, 1);
compute_encoder.dispatch_threads(grid_size, threadgroup_size);
log::trace!("Time to dispatch threads: {:?}", start.elapsed());
// end encoding and execute commands
let start = Instant::now();
compute_encoder.end_encoding();
command_buffer.commit();
command_buffer.wait_until_completed();
log::trace!("Time to commit: {:?}", start.elapsed());
let start = Instant::now();
let ptr = buffer_result.contents() as *const i64;
let len = buffer_result.length() as usize / std::mem::size_of::<i64>();
let slice = unsafe { core::slice::from_raw_parts(ptr, len) };
let res = Tensor::new(Some(&slice.to_vec()), &v.dims()).unwrap();
log::trace!("Time to get result: {:?}", start.elapsed());
res.map(|x| T::from_i64(x))
})
}
impl<T: Clone + TensorType> Tensor<Tensor<T>> {
/// Flattens a tensor of tensors
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// let mut a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let mut b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
/// let mut b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
/// let mut c = Tensor::new(Some(&[a,b]), &[2]).unwrap();
/// let mut d = c.combine().unwrap();
/// assert_eq!(d.dims(), &[8]);
@@ -1224,7 +1444,9 @@ impl<T: Clone + TensorType> Tensor<Tensor<T>> {
}
}
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Add for Tensor<T> {
impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Add
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Adds tensors.
/// # Arguments
@@ -1234,43 +1456,42 @@ impl<T: TensorType + Add<Output = T> + std::marker::Send + std::marker::Sync> Ad
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Add;
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.add(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 4, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 1D casting
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2]),
/// &[1]).unwrap();
/// let result = x.add(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 3, 3, 3]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
///
/// // Now test 2D casting
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2, 3]),
/// &[2]).unwrap();
/// let result = x.add(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 4, 4, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn add(self, rhs: Self) -> Self::Output {
@@ -1304,14 +1525,13 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Neg;
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.neg();
/// let expected = Tensor::<IntegerRep>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[-2, -1, -2, -1, -1, -1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn neg(self) -> Self {
@@ -1324,7 +1544,9 @@ impl<T: TensorType + Neg<Output = T> + std::marker::Send + std::marker::Sync> Ne
}
}
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Sub for Tensor<T> {
impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Sub
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Subtracts tensors.
/// # Arguments
@@ -1334,44 +1556,43 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Sub;
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.sub(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -2, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 1D sub
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2]),
/// &[1],
/// ).unwrap();
/// let result = x.sub(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -1, -1, -1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 2D sub
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2, 3]),
/// &[2],
/// ).unwrap();
/// let result = x.sub(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, -1, 0, -2, -2, -2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn sub(self, rhs: Self) -> Self::Output {
@@ -1397,7 +1618,9 @@ impl<T: TensorType + Sub<Output = T> + std::marker::Send + std::marker::Sync> Su
}
}
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mul for Tensor<T> {
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Mul
for Tensor<T>
{
type Output = Result<Tensor<T>, TensorError>;
/// Elementwise multiplies tensors.
/// # Arguments
@@ -1407,42 +1630,41 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Mul;
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2, 3, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.mul(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 3, 4, 1, 1, 1]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 1D mult
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2]),
/// &[1]).unwrap();
/// let result = x.mul(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // Now test 2D mult
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let k = Tensor::<IntegerRep>::new(
/// let k = Tensor::<i32>::new(
/// Some(&[2, 2]),
/// &[2]).unwrap();
/// let result = x.mul(k).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[4, 2, 4, 2, 2, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn mul(self, rhs: Self) -> Self::Output {
@@ -1468,7 +1690,7 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Mu
}
}
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Tensor<T> {
impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync + IntoI64> Tensor<T> {
/// Elementwise raise a tensor to the nth power.
/// # Arguments
///
@@ -1477,14 +1699,13 @@ impl<T: TensorType + Mul<Output = T> + std::marker::Send + std::marker::Sync> Te
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Mul;
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[2, 15, 2, 1, 1, 0]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.pow(3).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[8, 3375, 8, 1, 1, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
pub fn pow(&self, mut exp: u32) -> Result<Self, TensorError> {
@@ -1518,31 +1739,30 @@ impl<T: TensorType + Div<Output = T> + std::marker::Send + std::marker::Sync> Di
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Div;
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[4, 1, 4, 1, 1, 4]),
/// &[2, 3],
/// ).unwrap();
/// let y = Tensor::<IntegerRep>::new(
/// let y = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.div(y).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[2, 1, 2, 1, 1, 4]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
///
/// // test 1D casting
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[4, 1, 4, 1, 1, 4]),
/// &[2, 3],
/// ).unwrap();
/// let y = Tensor::<IntegerRep>::new(
/// let y = Tensor::<i32>::new(
/// Some(&[2]),
/// &[1],
/// ).unwrap();
/// let result = x.div(y).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[2, 0, 2, 0, 0, 2]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn div(self, rhs: Self) -> Self::Output {
@@ -1569,18 +1789,17 @@ impl<T: TensorType + Rem<Output = T> + std::marker::Send + std::marker::Sync> Re
/// # Examples
/// ```
/// use ezkl::tensor::Tensor;
/// use ezkl::fieldutils::IntegerRep;
/// use std::ops::Rem;
/// let x = Tensor::<IntegerRep>::new(
/// let x = Tensor::<i32>::new(
/// Some(&[4, 1, 4, 1, 1, 4]),
/// &[2, 3],
/// ).unwrap();
/// let y = Tensor::<IntegerRep>::new(
/// let y = Tensor::<i32>::new(
/// Some(&[2, 1, 2, 1, 1, 1]),
/// &[2, 3],
/// ).unwrap();
/// let result = x.rem(y).unwrap();
/// let expected = Tensor::<IntegerRep>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// let expected = Tensor::<i32>::new(Some(&[0, 0, 0, 0, 0, 0]), &[2, 3]).unwrap();
/// assert_eq!(result, expected);
/// ```
fn rem(self, rhs: Self) -> Self::Output {
@@ -1664,25 +1883,25 @@ mod tests {
#[test]
fn tensor_clone() {
let x = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
let x = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
assert_eq!(x, x.clone());
}
#[test]
fn tensor_eq() {
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3]).unwrap();
let mut b = Tensor::<IntegerRep>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
let a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
let mut b = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3, 1]).unwrap();
b.reshape(&[3]).unwrap();
let c = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3]).unwrap();
let d = Tensor::<IntegerRep>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
let c = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3]).unwrap();
let d = Tensor::<i32>::new(Some(&[1, 2, 4]), &[3, 1]).unwrap();
assert_eq!(a, b);
assert_ne!(a, c);
assert_ne!(a, d);
}
#[test]
fn tensor_slice() {
let a = Tensor::<IntegerRep>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<IntegerRep>::new(Some(&[1, 4]), &[2, 1]).unwrap();
let a = Tensor::<i32>::new(Some(&[1, 2, 3, 4, 5, 6]), &[2, 3]).unwrap();
let b = Tensor::<i32>::new(Some(&[1, 4]), &[2, 1]).unwrap();
assert_eq!(a.get_slice(&[0..2, 0..1]).unwrap(), b);
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
use crate::{circuit::region::ConstantsMap, fieldutils::felt_to_integer_rep};
use crate::circuit::region::ConstantsMap;
use maybe_rayon::slice::Iter;
use super::{
@@ -54,44 +54,6 @@ pub enum ValType<F: PrimeField + TensorType + std::marker::Send + std::marker::S
AssignedConstant(AssignedCell<F, F>, F),
}
impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for IntegerRep {
fn from(val: ValType<F>) -> Self {
match val {
ValType::Value(v) => {
let mut output = 0;
let mut i = 0;
v.map(|y| {
let e = felt_to_integer_rep(y);
output = e;
i += 1;
});
output
}
ValType::AssignedValue(v) => {
let mut output = 0;
let mut i = 0;
v.evaluate().map(|y| {
let e = felt_to_integer_rep(y);
output = e;
i += 1;
});
output
}
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
let mut output = 0;
let mut i = 0;
v.value().map(|y| {
let e = felt_to_integer_rep(*y);
output = e;
i += 1;
});
output
}
ValType::Constant(v) => felt_to_integer_rep(v),
}
}
}
impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd> ValType<F> {
/// Returns the inner cell of the [ValType].
pub fn cell(&self) -> Option<Cell> {
@@ -159,6 +121,44 @@ impl<F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + Partia
}
}
impl<F: PrimeField + TensorType + PartialOrd> From<ValType<F>> for i32 {
fn from(val: ValType<F>) -> Self {
match val {
ValType::Value(v) => {
let mut output = 0_i32;
let mut i = 0;
v.map(|y| {
let e = felt_to_i32(y);
output = e;
i += 1;
});
output
}
ValType::AssignedValue(v) => {
let mut output = 0_i32;
let mut i = 0;
v.evaluate().map(|y| {
let e = felt_to_i32(y);
output = e;
i += 1;
});
output
}
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
let mut output = 0_i32;
let mut i = 0;
v.value().map(|y| {
let e = felt_to_i32(*y);
output = e;
i += 1;
});
output
}
ValType::Constant(v) => felt_to_i32(v),
}
}
}
impl<F: PrimeField + TensorType + PartialOrd> From<F> for ValType<F> {
fn from(t: F) -> ValType<F> {
ValType::Constant(t)
@@ -317,8 +317,8 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<AssignedCell<F, F>>> f
impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
/// Allocate a new [ValTensor::Value] from the given [Tensor] of [i64].
pub fn from_integer_rep_tensor(t: Tensor<IntegerRep>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(integer_rep_to_felt(x))));
pub fn from_i64_tensor(t: Tensor<i64>) -> ValTensor<F> {
let inner = t.map(|x| ValType::Value(Value::known(i64_to_felt(x))));
inner.into()
}
@@ -521,9 +521,9 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// Calls `int_evals` on the inner tensor.
pub fn int_evals(&self) -> Result<Tensor<IntegerRep>, TensorError> {
pub fn get_int_evals(&self) -> Result<Tensor<i64>, TensorError> {
// finally convert to vector of integers
let mut integer_evals: Vec<IntegerRep> = vec![];
let mut integer_evals: Vec<i64> = vec![];
match self {
ValTensor::Value {
inner: v, dims: _, ..
@@ -531,26 +531,25 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
// we have to push to an externally created vector or else vaf.map() returns an evaluation wrapped in Value<> (which we don't want)
let _ = v.map(|vaf| match vaf {
ValType::Value(v) => v.map(|f| {
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f));
integer_evals.push(crate::fieldutils::felt_to_i64(f));
}),
ValType::AssignedValue(v) => v.map(|f| {
integer_evals.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
}),
ValType::PrevAssigned(v) | ValType::AssignedConstant(v, ..) => {
v.value_field().map(|f| {
integer_evals
.push(crate::fieldutils::felt_to_integer_rep(f.evaluate()));
integer_evals.push(crate::fieldutils::felt_to_i64(f.evaluate()));
})
}
ValType::Constant(v) => {
integer_evals.push(crate::fieldutils::felt_to_integer_rep(v));
integer_evals.push(crate::fieldutils::felt_to_i64(v));
Value::unknown()
}
});
}
_ => return Err(TensorError::WrongMethod),
};
let mut tensor: Tensor<IntegerRep> = integer_evals.into_iter().into();
let mut tensor: Tensor<i64> = integer_evals.into_iter().into();
match tensor.reshape(self.dims()) {
_ => {}
};
@@ -1003,22 +1002,25 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
}
/// A [String] representation of the [ValTensor] for display, for example in showing intermediate values in a computational graph.
pub fn show(&self) -> String {
let r = match self.int_evals() {
Ok(v) => v,
Err(_) => return "ValTensor not PrevAssigned".into(),
};
if r.len() > 10 {
let start = r[..5].to_vec();
let end = r[r.len() - 5..].to_vec();
// print the two split by ... in the middle
format!(
"[{} ... {}]",
start.iter().map(|x| format!("{}", x)).join(", "),
end.iter().map(|x| format!("{}", x)).join(", ")
)
} else {
format!("{:?}", r)
match self.clone() {
ValTensor::Value {
inner: v, dims: _, ..
} => {
let r: Tensor<i32> = v.map(|x| x.into());
if r.len() > 10 {
let start = r[..5].to_vec();
let end = r[r.len() - 5..].to_vec();
// print the two split by ... in the middle
format!(
"[{} ... {}]",
start.iter().map(|x| format!("{}", x)).join(", "),
end.iter().map(|x| format!("{}", x)).join(", ")
)
} else {
format!("{:?}", r)
}
}
_ => "ValTensor not PrevAssigned".into(),
}
}
}

View File

@@ -368,7 +368,7 @@ impl VarTensor {
.sum::<usize>();
let dims = &dims[*idx];
// this should never ever fail
let t: Tensor<IntegerRep> = Tensor::new(None, dims).unwrap();
let t: Tensor<i32> = Tensor::new(None, dims).unwrap();
Ok(t.enum_map(|coord, _| {
let (x, y, z) = self.cartesian_coord(offset + coord);
region.assign_advice_from_instance(
@@ -497,7 +497,7 @@ impl VarTensor {
let (x, y, z) = self.cartesian_coord(offset + coord * step);
if matches!(check_mode, CheckMode::SAFE) && coord > 0 && z == 0 && y == 0 {
// assert that duplication occurred correctly
assert_eq!(Into::<IntegerRep>::into(k.clone()), Into::<IntegerRep>::into(v[coord - 1].clone()));
assert_eq!(Into::<i32>::into(k.clone()), Into::<i32>::into(v[coord - 1].clone()));
};
let cell = self.assign_value(region, offset, k.clone(), coord * step, constants)?;
@@ -533,14 +533,13 @@ impl VarTensor {
if matches!(check_mode, CheckMode::SAFE) {
// during key generation this will be 0 so we use this as a flag to check
// TODO: this isn't very safe and would be better to get the phase directly
let res_evals = res.int_evals().unwrap();
let is_assigned = res_evals
let is_assigned = !Into::<Tensor<i32>>::into(res.clone().get_inner().unwrap())
.iter()
.all(|&x| x == 0);
if !is_assigned {
if is_assigned {
assert_eq!(
values.int_evals().unwrap(),
res_evals
Into::<Tensor<i32>>::into(values.get_inner().unwrap()),
Into::<Tensor<i32>>::into(res.get_inner().unwrap())
)};
}

View File

@@ -1,52 +1,41 @@
use crate::{
circuit::{
modules::{
polycommit::PolyCommitChip,
poseidon::{
spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH},
PoseidonChip,
},
Module,
},
region::RegionSettings,
},
fieldutils::{felt_to_integer_rep, integer_rep_to_felt},
graph::{
modules::POSEIDON_LEN_GRAPH, quantize_float, scale_to_multiplier, GraphCircuit,
GraphSettings,
},
pfsys::{
create_proof_circuit,
evm::aggregation_kzg::{AggregationCircuit, PoseidonTranscript},
verify_proof_circuit, TranscriptType,
},
tensor::TensorType,
CheckMode, Commitments,
};
use crate::circuit::modules::polycommit::PolyCommitChip;
use crate::circuit::modules::poseidon::spec::{PoseidonSpec, POSEIDON_RATE, POSEIDON_WIDTH};
use crate::circuit::modules::poseidon::PoseidonChip;
use crate::circuit::modules::Module;
use crate::fieldutils::felt_to_i64;
use crate::fieldutils::i64_to_felt;
use crate::graph::modules::POSEIDON_LEN_GRAPH;
use crate::graph::quantize_float;
use crate::graph::scale_to_multiplier;
use crate::graph::{GraphCircuit, GraphSettings};
use crate::pfsys::create_proof_circuit;
use crate::pfsys::evm::aggregation_kzg::AggregationCircuit;
use crate::pfsys::evm::aggregation_kzg::PoseidonTranscript;
use crate::pfsys::verify_proof_circuit;
use crate::pfsys::TranscriptType;
use crate::tensor::TensorType;
use crate::CheckMode;
use crate::Commitments;
use console_error_panic_hook;
use halo2_proofs::{
plonk::*,
poly::{
commitment::{CommitmentScheme, ParamsProver},
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
multiopen::{ProverIPA, VerifierIPA},
strategy::SingleStrategy as IPASingleStrategy,
},
kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
multiopen::{ProverSHPLONK, VerifierSHPLONK},
strategy::SingleStrategy as KZGSingleStrategy,
},
VerificationStrategy,
},
use halo2_proofs::plonk::*;
use halo2_proofs::poly::commitment::{CommitmentScheme, ParamsProver};
use halo2_proofs::poly::ipa::multiopen::{ProverIPA, VerifierIPA};
use halo2_proofs::poly::ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
strategy::SingleStrategy as IPASingleStrategy,
};
use halo2_proofs::poly::kzg::multiopen::ProverSHPLONK;
use halo2_proofs::poly::kzg::multiopen::VerifierSHPLONK;
use halo2_proofs::poly::kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
strategy::SingleStrategy as KZGSingleStrategy,
};
use halo2_proofs::poly::VerificationStrategy;
use halo2_solidity_verifier::encode_calldata;
use halo2curves::{
bn256::{Bn256, Fr, G1Affine},
ff::{FromUniformBytes, PrimeField},
};
use snark_verifier::{loader::native::NativeLoader, system::halo2::transcript::evm::EvmTranscript};
use halo2curves::bn256::{Bn256, Fr, G1Affine};
use halo2curves::ff::{FromUniformBytes, PrimeField};
use snark_verifier::loader::native::NativeLoader;
use snark_verifier::system::halo2::transcript::evm::EvmTranscript;
use std::str::FromStr;
use wasm_bindgen::prelude::*;
use wasm_bindgen_console_logger::DEFAULT_LOGGER;
@@ -124,7 +113,7 @@ pub fn feltToInt(
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
Ok(wasm_bindgen::Clamped(
serde_json::to_vec(&felt_to_integer_rep(felt))
serde_json::to_vec(&felt_to_i64(felt))
.map_err(|e| JsError::new(&format!("Failed to serialize integer: {}", e)))?,
))
}
@@ -138,7 +127,7 @@ pub fn feltToFloat(
) -> Result<f64, JsError> {
let felt: Fr = serde_json::from_slice(&array[..])
.map_err(|e| JsError::new(&format!("Failed to deserialize field element: {}", e)))?;
let int_rep = felt_to_integer_rep(felt);
let int_rep = felt_to_i64(felt);
let multiplier = scale_to_multiplier(scale);
Ok(int_rep as f64 / multiplier)
}
@@ -152,7 +141,7 @@ pub fn floatToFelt(
) -> Result<wasm_bindgen::Clamped<Vec<u8>>, JsError> {
let int_rep =
quantize_float(&input, 0.0, scale).map_err(|e| JsError::new(&format!("{}", e)))?;
let felt = integer_rep_to_felt(int_rep);
let felt = i64_to_felt(int_rep);
let vec = crate::pfsys::field_to_string::<halo2curves::bn256::Fr>(&felt);
Ok(wasm_bindgen::Clamped(serde_json::to_vec(&vec).map_err(
|e| JsError::new(&format!("Failed to serialize a float to felt{}", e)),
@@ -286,7 +275,7 @@ pub fn genWitness(
.map_err(|e| JsError::new(&format!("{}", e)))?;
let witness = circuit
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, RegionSettings::all_false())
.forward::<KZGCommitmentScheme<Bn256>>(&mut input, None, None, false, false)
.map_err(|e| JsError::new(&format!("{}", e)))?;
serde_json::to_vec(&witness)

View File

@@ -3,7 +3,7 @@
mod native_tests {
use ezkl::circuit::Tolerance;
use ezkl::fieldutils::{felt_to_integer_rep, integer_rep_to_felt, IntegerRep};
use ezkl::fieldutils::{felt_to_i64, i64_to_felt};
// use ezkl::circuit::table::RESERVED_BLINDING_ROWS_PAD;
use ezkl::graph::input::{FileSource, FileSourceInner, GraphData};
use ezkl::graph::{DataSource, GraphSettings, GraphWitness};
@@ -646,15 +646,6 @@ mod native_tests {
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_hashed_params_public_inputs_(test: &str) {
crate::native_tests::init_binary();
let test_dir = TempDir::new(test).unwrap();
let path = test_dir.path().to_str().unwrap(); crate::native_tests::mv_test_(path, test);
mock(path, test.to_string(), "public", "hashed", "private", 1, "resources", None, 0.0);
test_dir.close().unwrap();
}
#(#[test_case(TESTS[N])])*
fn mock_fixed_inputs_(test: &str) {
crate::native_tests::init_binary();
@@ -1422,10 +1413,10 @@ mod native_tests {
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::zero()
} else {
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
i64_to_felt(
(felt_to_i64(*v) as f32
* (rand::thread_rng().gen_range(-0.01..0.01) * tolerance))
as IntegerRep,
as i64,
)
};
@@ -1445,10 +1436,10 @@ mod native_tests {
let perturbation = if v == &halo2curves::bn256::Fr::zero() {
halo2curves::bn256::Fr::from(2)
} else {
integer_rep_to_felt(
(felt_to_integer_rep(*v) as f32
i64_to_felt(
(felt_to_i64(*v) as f32
* (rand::thread_rng().gen_range(0.02..0.1) * tolerance))
as IntegerRep,
as i64,
)
};
*v + perturbation

Binary file not shown.

Binary file not shown.

Binary file not shown.