chore: serialize/deserialize processed graph (#371)

This commit is contained in:
dante
2023-07-24 22:48:49 +01:00
committed by GitHub
parent b08a4341f1
commit cd39e5564e
31 changed files with 767 additions and 636 deletions

View File

@@ -26,31 +26,12 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- name: Build
run: cargo build --verbose
build-wasm:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Build wasm
run: cargo build --release --bin ezkl --target=wasm32-wasi
- name: Run help
run: wasmtime run './target/wasm32-wasi/release/ezkl.wasm' -- --help
docs:
runs-on: ubuntu-latest
@@ -58,7 +39,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- name: Docs
@@ -70,7 +51,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -83,14 +64,30 @@ jobs:
- name: Library tests
run: cargo nextest run --lib --verbose -- --include-ignored
wasm32-tests:
runs-on: ubuntu-latest
needs: [build, build-wasm, library-tests, docs]
model-serialization:
runs-on: ubuntu-latest-32-cores
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Model serialization
run: cargo nextest run native_tests::tests::model_serialization_
wasm32-tests:
runs-on: ubuntu-latest
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: jetli/wasm-pack-action@v0.4.0
@@ -104,18 +101,18 @@ jobs:
- name: Install wasm runner
run: cargo install wasm-server-runner
- name: Add rust-src
run: rustup component add rust-src --toolchain nightly-2023-06-27-x86_64-unknown-linux-gnu
run: rustup component add rust-src --toolchain nightly-2023-04-17-x86_64-unknown-linux-gnu
- name: Run wasm verifier tests
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std"
render-circuit:
runs-on: ubuntu-latest-32-cores
needs: [build, build-wasm, library-tests, docs]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: mwilliamson/setup-wasmtime-action@v2
@@ -130,12 +127,12 @@ jobs:
tutorial:
runs-on: ubuntu-latest
needs: [build, build-wasm, library-tests, docs]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -147,12 +144,12 @@ jobs:
mock-proving-tests:
runs-on: ubuntu-latest-32-cores
needs: [build, build-wasm, library-tests, docs]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -169,6 +166,8 @@ jobs:
run: cargo nextest run --release --verbose tests::mock_public_params_ --test-threads 32
- name: Mock proving tests (hashed inputs)
run: cargo nextest run --release --verbose tests::mock_hashed_input_::t --test-threads 32
- name: Mock proving tests (hashed params)
run: cargo nextest run --release --verbose tests::mock_hashed_params_::t --test-threads 32
- name: Mock proving tests (hashed outputs)
run: cargo nextest run --release --verbose tests::mock_hashed_output_::t --test-threads 32
- name: Mock proving tests (encrypted inputs)
@@ -184,41 +183,13 @@ jobs:
- name: Mock proving tests (encrypted inputs + hashed params)
run: cargo nextest run --release --verbose tests::mock_encrypted_input_hashed_params_::t --test-threads 32
mock-proving-tests-wasi:
runs-on: ubuntu-latest-32-cores
needs: [build, build-wasm, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
override: true
components: rustfmt, clippy
- uses: mwilliamson/setup-wasmtime-action@v2
with:
wasmtime-version: "3.0.1"
- uses: baptiste0928/cargo-install@v1
with:
crate: cargo-nextest
locked: true
- name: Install wasm32-wasi
run: rustup target add wasm32-wasi
- name: Mock proving tests (WASI) (public outputs)
run: cargo nextest run --release --verbose tests_wasi::mock_public_outputs_ --test-threads 32
- name: Mock proving tests (WASI) (public inputs)
run: cargo nextest run --release --verbose tests_wasi::mock_public_inputs_ --test-threads 32
- name: Mock proving tests (WASI) (public params)
run: cargo nextest run --release --verbose tests_wasi::mock_public_params_ --test-threads 32
prove-and-verify-evm-tests:
runs-on: ubuntu-latest-16-cores
needs:
[
build,
build-wasm,
library-tests,
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
@@ -226,7 +197,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -257,10 +228,8 @@ jobs:
needs:
[
build,
build-wasm,
library-tests,
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
@@ -268,7 +237,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -291,10 +260,8 @@ jobs:
needs:
[
build,
build-wasm,
library-tests,
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
@@ -302,7 +269,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -323,14 +290,13 @@ jobs:
needs:
[
build,
build-wasm,
library-tests,
]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -345,10 +311,8 @@ jobs:
needs:
[
build,
build-wasm,
library-tests,
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
python-integration-tests,
]
@@ -356,7 +320,7 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -371,17 +335,15 @@ jobs:
needs:
[
build,
build-wasm,
library-tests,
mock-proving-tests,
mock-proving-tests-wasi,
python-tests,
]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -400,12 +362,12 @@ jobs:
examples:
runs-on: ubuntu-latest-32-cores
needs: [build, build-wasm, library-tests, docs]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -419,12 +381,12 @@ jobs:
neg-tests:
runs-on: ubuntu-latest
needs: [build, build-wasm, library-tests, docs]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1
@@ -436,7 +398,7 @@ jobs:
python-tests:
runs-on: ubuntu-latest-32-cores
needs: [build, build-wasm, library-tests, docs]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
@@ -444,7 +406,7 @@ jobs:
python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- name: Install solc
@@ -460,7 +422,7 @@ jobs:
python-integration-tests:
runs-on: 512gb
needs: [build, build-wasm, library-tests, docs]
needs: [build, library-tests, docs]
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
@@ -468,7 +430,7 @@ jobs:
python-version: "3.9"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2023-06-27
toolchain: nightly-2023-04-17
override: true
components: rustfmt, clippy
- uses: baptiste0928/cargo-install@v1

319
Cargo.lock generated
View File

@@ -24,7 +24,7 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac1f845298e95f983ff1944b728ae08b8cebab80d684f0a832ed0fc74dfa27e2"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"cipher",
"cpufeatures",
]
@@ -46,7 +46,7 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"once_cell",
"version_check",
]
@@ -90,7 +90,7 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "220044e6a1bb31ddee4e3db724d29767f352de47445a6cd75e1a173142136c83"
dependencies = [
"nom",
"nom 7.1.3",
"vte",
]
@@ -189,6 +189,18 @@ version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"
[[package]]
name = "as-slice"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45403b49e3954a4b8428a0ac21a4b7afadccf92bfd96273f1a58cd4812496ae0"
dependencies = [
"generic-array 0.12.4",
"generic-array 0.13.3",
"generic-array 0.14.7",
"stable_deref_trait",
]
[[package]]
name = "ascii-canvas"
version = "3.0.0"
@@ -349,7 +361,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4"
dependencies = [
"block-padding",
"generic-array",
"generic-array 0.14.7",
]
[[package]]
@@ -358,7 +370,7 @@ version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
dependencies = [
"generic-array",
"generic-array 0.14.7",
]
[[package]]
@@ -376,6 +388,19 @@ dependencies = [
"sha2 0.9.9",
]
[[package]]
name = "build_id"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6deb6795d8b4d2269c3fcf87a87bff9f4cd45a99e259806603ee8007077daf3"
dependencies = [
"byteorder",
"once_cell",
"palaver",
"twox-hash",
"uuid",
]
[[package]]
name = "bumpalo"
version = "3.13.0"
@@ -459,6 +484,12 @@ version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
[[package]]
name = "cfg-if"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
[[package]]
name = "cfg-if"
version = "1.0.0"
@@ -598,7 +629,7 @@ dependencies = [
"bech32",
"bs58",
"digest 0.10.7",
"generic-array",
"generic-array 0.14.7",
"hex",
"ripemd",
"serde",
@@ -662,7 +693,7 @@ version = "0.1.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"wasm-bindgen",
]
@@ -757,7 +788,7 @@ version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -802,7 +833,7 @@ version = "0.5.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-utils",
]
@@ -812,7 +843,7 @@ version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-epoch",
"crossbeam-utils",
]
@@ -824,7 +855,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
dependencies = [
"autocfg",
"cfg-if",
"cfg-if 1.0.0",
"crossbeam-utils",
"memoffset 0.9.0",
"scopeguard",
@@ -836,7 +867,7 @@ version = "0.8.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -851,7 +882,7 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf4c2f4e1afd912bc40bfd6fed5d9dc1f288e0ba01bfcc835cc5bc3eb13efe15"
dependencies = [
"generic-array",
"generic-array 0.14.7",
"rand_core 0.6.4",
"subtle",
"zeroize",
@@ -863,7 +894,7 @@ version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
dependencies = [
"generic-array",
"generic-array 0.14.7",
"typenum",
]
@@ -1003,7 +1034,7 @@ version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066"
dependencies = [
"generic-array",
"generic-array 0.14.7",
]
[[package]]
@@ -1033,7 +1064,7 @@ version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"dirs-sys-next",
]
@@ -1148,7 +1179,7 @@ dependencies = [
"crypto-bigint",
"digest 0.10.7",
"ff",
"generic-array",
"generic-array 0.14.7",
"group",
"pkcs8",
"rand_core 0.6.4",
@@ -1178,7 +1209,7 @@ version = "0.8.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
]
[[package]]
@@ -1225,6 +1256,15 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1"
[[package]]
name = "erased-serde"
version = "0.3.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c6984864d65d092d9e9ada107007a846a09f75d2e24046bcce9a38d14aa52052"
dependencies = [
"serde",
]
[[package]]
name = "errno"
version = "0.3.1"
@@ -1413,7 +1453,7 @@ dependencies = [
"chrono",
"elliptic-curve",
"ethabi",
"generic-array",
"generic-array 0.14.7",
"hex",
"k256",
"num_enum",
@@ -1534,7 +1574,7 @@ version = "2.0.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a81c89f121595cf8959e746045bb8b25a6a38d72588561e1a3b7992fc213f674"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"dunce",
"ethers-core",
"glob",
@@ -1607,6 +1647,7 @@ dependencies = [
"serde",
"serde-wasm-bindgen",
"serde_json",
"serde_traitobject",
"shellexpand",
"snark-verifier",
"tabled",
@@ -1667,7 +1708,7 @@ version = "0.2.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"libc",
"redox_syscall 0.2.16",
"windows-sys 0.48.0",
@@ -1923,6 +1964,24 @@ dependencies = [
"tempfile",
]
[[package]]
name = "generic-array"
version = "0.12.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ffdf9f34f1447443d37393cc6c2b8313aebddcd96906caf34e54c68d8e57d7bd"
dependencies = [
"typenum",
]
[[package]]
name = "generic-array"
version = "0.13.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f797e67af32588215eaaab8327027ee8e71b9dd0b2b26996aedf20c030fce309"
dependencies = [
"typenum",
]
[[package]]
name = "generic-array"
version = "0.14.7"
@@ -1940,7 +1999,7 @@ version = "0.2.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
@@ -2086,6 +2145,15 @@ dependencies = [
"num-traits",
]
[[package]]
name = "hash32"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d4041af86e63ac4298ce40e5cca669066e75b6f1aa3390fe2561ffa5e1d9f4cc"
dependencies = [
"byteorder",
]
[[package]]
name = "hashbrown"
version = "0.11.2"
@@ -2125,6 +2193,18 @@ dependencies = [
"fxhash",
]
[[package]]
name = "heapless"
version = "0.5.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "74911a68a1658cfcfb61bc0ccfbd536e3b6e906f8c2f7883ee50157e3e2184f1"
dependencies = [
"as-slice",
"generic-array 0.13.3",
"hash32",
"stable_deref_trait",
]
[[package]]
name = "heck"
version = "0.4.1"
@@ -2400,7 +2480,7 @@ version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
dependencies = [
"generic-array",
"generic-array 0.14.7",
]
[[package]]
@@ -2409,7 +2489,7 @@ version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"js-sys",
"wasm-bindgen",
"web-sys",
@@ -2493,7 +2573,7 @@ version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cadb76004ed8e97623117f3df85b17aaa6626ab0b0831e6573f104df16cd1bcc"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"ecdsa",
"elliptic-curve",
"once_cell",
@@ -2569,7 +2649,7 @@ version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d580318f95776505201b28cf98eb1fa5e4be3b689633ba6a3e6cd880ff22d8cb"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"windows-sys 0.48.0",
]
@@ -2658,6 +2738,15 @@ version = "0.4.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
[[package]]
name = "mach"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
dependencies = [
"libc",
]
[[package]]
name = "maingate"
version = "0.1.0"
@@ -2729,6 +2818,12 @@ dependencies = [
"autocfg",
]
[[package]]
name = "metatype"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23decce7c32638bcefbd5a5a5d79a5bb5b720c47b82ad5cb670a7eb912705946"
[[package]]
name = "mime"
version = "0.3.17"
@@ -2808,6 +2903,25 @@ version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54"
[[package]]
name = "nix"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b2e0b4f3320ed72aaedb9a5ac838690a8047c7b275da22711fddff4f8a14229"
dependencies = [
"bitflags",
"cc",
"cfg-if 0.1.10",
"libc",
"void",
]
[[package]]
name = "nom"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf51a729ecf40266a2368ad335a5fdde43471f545a967109cd62146ecf8b66ff"
[[package]]
name = "nom"
version = "7.1.3"
@@ -2983,7 +3097,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d"
dependencies = [
"bitflags",
"cfg-if",
"cfg-if 1.0.0",
"foreign-types",
"libc",
"once_cell",
@@ -3036,6 +3150,23 @@ version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
[[package]]
name = "palaver"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49dfc200733ac34dcd9a1e4a7e454b521723936010bef3710e2d8024a32d685f"
dependencies = [
"bitflags",
"heapless",
"lazy_static",
"libc",
"mach",
"nix",
"procinfo",
"typenum",
"winapi",
]
[[package]]
name = "papergrid"
version = "0.9.1"
@@ -3091,7 +3222,7 @@ version = "0.9.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"libc",
"redox_syscall 0.3.5",
"smallvec",
@@ -3494,6 +3625,18 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "procinfo"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ab1427f3d2635891f842892dda177883dca0639e05fe66796a62c9d2f23b49c"
dependencies = [
"byteorder",
"libc",
"nom 2.2.1",
"rustc_version 0.2.3",
]
[[package]]
name = "prost"
version = "0.11.9"
@@ -3523,7 +3666,7 @@ version = "0.18.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"indoc",
"libc",
"memoffset 0.8.0",
@@ -3791,6 +3934,17 @@ version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78"
[[package]]
name = "relative"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3401c189ee92c7028ba4863f3fdb92af815789993221af2fa186eed8115da304"
dependencies = [
"build_id",
"serde",
"uuid",
]
[[package]]
name = "remove_dir_all"
version = "0.5.3"
@@ -3965,6 +4119,15 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6"
[[package]]
name = "rustc_version"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
dependencies = [
"semver 0.9.0",
]
[[package]]
name = "rustc_version"
version = "0.3.3"
@@ -4048,7 +4211,7 @@ version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad560913365790f17cbf12479491169f01b9d46d29cfc7422bf8c64bdc61b731"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"derive_more",
"parity-scale-codec",
"scale-info-derive",
@@ -4116,7 +4279,7 @@ checksum = "f0aec48e813d6b90b15f0b8948af3c63483992dee44c03e9930b3eebdabe046e"
dependencies = [
"base16ct",
"der",
"generic-array",
"generic-array 0.14.7",
"pkcs8",
"subtle",
"zeroize",
@@ -4163,13 +4326,22 @@ dependencies = [
"libc",
]
[[package]]
name = "semver"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
dependencies = [
"semver-parser 0.7.0",
]
[[package]]
name = "semver"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6"
dependencies = [
"semver-parser",
"semver-parser 0.10.2",
]
[[package]]
@@ -4181,6 +4353,12 @@ dependencies = [
"serde",
]
[[package]]
name = "semver-parser"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
[[package]]
name = "semver-parser"
version = "0.10.2"
@@ -4238,6 +4416,28 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_closure"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9659437bcfbf4dd061a5e1f7994312990ac5b24d781f7ce577eefc3a27792da0"
dependencies = [
"rustversion",
"serde",
"serde_closure_derive",
]
[[package]]
name = "serde_closure_derive"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a6bb4d612b5caad466a9a09ee550445e34123a74075607cc0d882ff1ca28f46"
dependencies = [
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "serde_derive"
version = "1.0.164"
@@ -4269,6 +4469,19 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_traitobject"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c5ae15a5d31f7c57875a480ddd7be02314d264617d0294d961314a6d502e6b1"
dependencies = [
"erased-serde",
"metatype",
"relative",
"serde",
"serde_closure",
]
[[package]]
name = "serde_urlencoded"
version = "0.7.1"
@@ -4288,7 +4501,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800"
dependencies = [
"block-buffer 0.9.0",
"cfg-if",
"cfg-if 1.0.0",
"cpufeatures",
"digest 0.9.0",
"opaque-debug",
@@ -4300,7 +4513,7 @@ version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"cpufeatures",
"digest 0.10.7",
]
@@ -4440,6 +4653,12 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02a8428da277a8e3a15271d79943e80ccc2ef254e78813a166a08d65e4c3ece5"
[[package]]
name = "stable_deref_trait"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]]
name = "static_assertions"
version = "1.1.0"
@@ -4458,7 +4677,7 @@ version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"hashbrown 0.11.2",
"serde",
]
@@ -4628,7 +4847,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6"
dependencies = [
"autocfg",
"cfg-if",
"cfg-if 1.0.0",
"fastrand",
"redox_syscall 0.3.5",
"rustix",
@@ -4661,7 +4880,7 @@ version = "2.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e45b7bf6e19353ddd832745c8fcf77a17a93171df7151187f26623f2b75b5b26"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"proc-macro-error",
"proc-macro2",
"quote",
@@ -4867,7 +5086,7 @@ version = "0.1.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"pin-project-lite",
"tracing-attributes",
"tracing-core",
@@ -4937,7 +5156,7 @@ dependencies = [
"lazy_static",
"maplit",
"ndarray",
"nom",
"nom 7.1.3",
"num-integer",
"num-traits",
"scan_fmt",
@@ -4986,7 +5205,7 @@ dependencies = [
"byteorder",
"flate2",
"log",
"nom",
"nom 7.1.3",
"tar",
"tract-core",
"walkdir",
@@ -5044,6 +5263,16 @@ version = "0.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff"
[[package]]
name = "twox-hash"
version = "1.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
dependencies = [
"cfg-if 1.0.0",
"static_assertions",
]
[[package]]
name = "typenum"
version = "1.16.0"
@@ -5152,6 +5381,12 @@ version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "void"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
[[package]]
name = "vte"
version = "0.10.1"
@@ -5210,7 +5445,7 @@ version = "0.2.87"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"serde",
"serde_json",
"wasm-bindgen-macro",
@@ -5237,7 +5472,7 @@ version = "0.4.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03"
dependencies = [
"cfg-if",
"cfg-if 1.0.0",
"js-sys",
"wasm-bindgen",
"web-sys",

View File

@@ -37,6 +37,7 @@ tokio = { version = "1.26.0", default_features = false, features = ["macros", "
rayon = { version = "1.7.0", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
serde_traitobject = { version = "0.2.8", features = ["serde_closure"] }
# python binding related deps
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }

View File

@@ -68,7 +68,8 @@ https://user-images.githubusercontent.com/45801863/236771676-5bbbbfd1-ba6f-418a-
Note that the library requires a nightly version of the rust toolchain. You can change the default toolchain by running:
```bash
rustup override set nightly
# we set it to this version because of https://github.com/rust-lang/rust/issues/110829
rustup override set nightly-2023-04-17
```
After which you may build the library

View File

@@ -28,8 +28,8 @@ const K: usize = 17;
#[derive(Clone, Debug)]
struct MyCircuit {
image: ValTensor<Fr>,
kernel: ValTensor<Fr>,
bias: ValTensor<Fr>,
kernel: Tensor<Fr>,
bias: Tensor<Fr>,
}
impl Circuit<Fr> for MyCircuit {
@@ -100,18 +100,20 @@ fn runcnvrl(c: &mut Criterion) {
.map(|_| Value::known(Fr::random(OsRng))),
);
image.reshape(&[1, IN_CHANNELS, IMAGE_HEIGHT, IMAGE_WIDTH]);
let mut kernels = Tensor::from(
let mut kernel = Tensor::from(
(0..{ OUT_CHANNELS * IN_CHANNELS * KERNEL_HEIGHT * KERNEL_WIDTH })
.map(|_| Value::known(Fr::random(OsRng))),
.map(|_| Fr::random(OsRng)),
);
kernels.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH]);
kernel.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH]);
kernel.set_visibility(ezkl::graph::Visibility::Private);
let bias = Tensor::from((0..{ OUT_CHANNELS }).map(|_| Value::known(Fr::random(OsRng))));
let mut bias = Tensor::from((0..{ OUT_CHANNELS }).map(|_| Fr::random(OsRng)));
bias.set_visibility(ezkl::graph::Visibility::Private);
let circuit = MyCircuit {
image: ValTensor::from(image),
kernel: ValTensor::from(kernels),
bias: ValTensor::from(bias),
kernel,
bias,
};
group.throughput(Throughput::Elements(*size as u64));

View File

@@ -6,6 +6,7 @@ use ezkl::fieldutils;
use ezkl::fieldutils::i32_to_felt;
use ezkl::tensor::*;
use halo2_proofs::dev::MockProver;
use halo2_proofs::poly::kzg::multiopen::VerifierGWC;
use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{
@@ -14,20 +15,19 @@ use halo2_proofs::{
},
poly::{
commitment::ParamsProver,
ipa::{
commitment::{IPACommitmentScheme, ParamsIPA},
multiopen::ProverIPA,
kzg::{
commitment::{KZGCommitmentScheme, ParamsKZG},
multiopen::ProverGWC,
strategy::SingleStrategy,
},
VerificationStrategy,
},
transcript::{
Blake2bRead, Blake2bWrite, Challenge255, TranscriptReadBuffer, TranscriptWriterBuffer,
},
};
use halo2curves::ff::PrimeField;
use halo2curves::pasta::vesta;
use halo2curves::pasta::Fp as F;
use halo2curves::bn256::Bn256;
use halo2curves::bn256::Fr as F;
use instant::Instant;
use mnist::*;
use rand::rngs::OsRng;
@@ -39,7 +39,6 @@ const K: usize = 20;
#[derive(Clone)]
struct Config<
F: PrimeField + TensorType + PartialOrd,
const LEN: usize, //LEN = CHOUT x OH x OW flattened //not supported yet in rust stable
const CLASSES: usize,
const BITS: usize,
@@ -63,7 +62,6 @@ struct Config<
#[derive(Clone)]
struct MyCircuit<
F: PrimeField + TensorType + PartialOrd,
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const CLASSES: usize,
const BITS: usize,
@@ -82,12 +80,11 @@ struct MyCircuit<
// Given the stateless ConvConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
input: ValTensor<F>,
l0_params: [ValTensor<F>; 2],
l2_params: [ValTensor<F>; 2],
l0_params: [Tensor<F>; 2],
l2_params: [Tensor<F>; 2],
}
impl<
F: PrimeField + TensorType + PartialOrd,
const LEN: usize,
const CLASSES: usize,
const BITS: usize,
@@ -102,7 +99,6 @@ impl<
const PADDING: usize,
> Circuit<F>
for MyCircuit<
F,
LEN,
CLASSES,
BITS,
@@ -119,7 +115,6 @@ where
Value<F>: TensorType,
{
type Config = Config<
F,
LEN,
CLASSES,
BITS,
@@ -203,7 +198,7 @@ where
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone(), x],
&[self.l2_params[0].clone().into(), x],
Box::new(PolyOp::Einsum {
equation: "ij,j->ik".to_string(),
}),
@@ -295,7 +290,7 @@ pub fn runconv() {
input.reshape(&[1, 1, 28, 28]).unwrap();
let myparams = params::Params::new();
let mut l0_kernels: ValTensor<F> = Tensor::<Value<F>>::from(
let mut l0_kernels = Tensor::<F>::from(
myparams
.kernels
.clone()
@@ -307,47 +302,35 @@ pub fn runconv() {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
let felt = fieldutils::i32_to_felt(integral);
Value::known(felt)
fieldutils::i32_to_felt(integral)
}),
)
.into();
);
l0_kernels
.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH])
.unwrap();
l0_kernels.reshape(&[OUT_CHANNELS, IN_CHANNELS, KERNEL_HEIGHT, KERNEL_WIDTH]);
l0_kernels.set_visibility(ezkl::graph::Visibility::Private);
let l0_bias: ValTensor<F> = Tensor::<Value<F>>::from(
(0..OUT_CHANNELS).map(|_| Value::known(fieldutils::i32_to_felt(0))),
)
.into();
let mut l0_bias = Tensor::<F>::from((0..OUT_CHANNELS).map(|_| fieldutils::i32_to_felt(0)));
l0_bias.set_visibility(ezkl::graph::Visibility::Private);
let mut l2_biases: ValTensor<F> =
Tensor::<Value<F>>::from(myparams.biases.into_iter().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
let felt = fieldutils::i32_to_felt(integral);
Value::known(felt)
}))
.into();
let mut l2_biases = Tensor::<F>::from(myparams.biases.into_iter().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
}));
l2_biases.set_visibility(ezkl::graph::Visibility::Private);
l2_biases.reshape(&[l2_biases.len(), 1]);
l2_biases.reshape(&[l2_biases.len(), 1]).unwrap();
let mut l2_weights: ValTensor<F> =
Tensor::<Value<F>>::from(myparams.weights.into_iter().flatten().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
let felt = fieldutils::i32_to_felt(integral);
Value::known(felt)
}))
.into();
l2_weights.reshape(&[CLASSES, LEN]).unwrap();
let mut l2_weights = Tensor::<F>::from(myparams.weights.into_iter().flatten().map(|fl| {
let dx = fl * 32_f32;
let rounded = dx.round();
let integral: i32 = unsafe { rounded.to_int_unchecked() };
fieldutils::i32_to_felt(integral)
}));
l2_weights.set_visibility(ezkl::graph::Visibility::Private);
l2_weights.reshape(&[CLASSES, LEN]);
let circuit = MyCircuit::<
F,
LEN,
10,
16,
@@ -411,7 +394,7 @@ pub fn runconv() {
// Real proof
println!("SRS GENERATION");
let now = Instant::now();
let params: ParamsIPA<vesta::Affine> = ParamsIPA::new(K as u32);
let params: ParamsKZG<Bn256> = ParamsKZG::new(K as u32);
let elapsed = now.elapsed();
println!(
"SRS GENERATION took {}.{}",
@@ -446,7 +429,7 @@ pub fn runconv() {
let now = Instant::now();
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
let mut rng = OsRng;
create_proof::<IPACommitmentScheme<_>, ProverIPA<_>, _, _, _, _>(
create_proof::<KZGCommitmentScheme<_>, ProverGWC<_>, _, _, _, _>(
&params,
&pk,
&[circuit],
@@ -466,7 +449,7 @@ pub fn runconv() {
let now = Instant::now();
let strategy = SingleStrategy::new(&params);
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
let verify = verify_proof(
let verify = verify_proof::<_, VerifierGWC<_>, _, _, _>(
&params,
pk.get_vk(),
strategy,

View File

@@ -9,36 +9,32 @@ use halo2_proofs::{
circuit::{Layouter, SimpleFloorPlanner, Value},
plonk::{Circuit, Column, ConstraintSystem, Error, Instance},
};
use halo2curves::ff::PrimeField;
use halo2curves::pasta::Fp as F;
use halo2curves::bn256::Fr as F;
use std::marker::PhantomData;
const K: usize = 15;
// A columnar ReLu MLP
#[derive(Clone)]
struct MyConfig<F: PrimeField + TensorType + PartialOrd> {
struct MyConfig {
layer_config: PolyConfig<F>,
public_output: Column<Instance>,
}
#[derive(Clone)]
struct MyCircuit<
F: PrimeField + TensorType + PartialOrd,
const LEN: usize, //LEN = CHOUT x OH x OW flattened
const BITS: usize,
> {
// Given the stateless MyConfig type information, a DNN trace is determined by its input and the parameters of its layers.
// Computing the trace still requires a forward pass. The intermediate activations are stored only by the layouter.
input: ValTensor<F>,
l0_params: [ValTensor<F>; 2],
l2_params: [ValTensor<F>; 2],
l0_params: [Tensor<F>; 2],
l2_params: [Tensor<F>; 2],
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd, const LEN: usize, const BITS: usize> Circuit<F>
for MyCircuit<F, LEN, BITS>
{
type Config = MyConfig<F>;
impl<const LEN: usize, const BITS: usize> Circuit<F> for MyCircuit<LEN, BITS> {
type Config = MyConfig;
type FloorPlanner = SimpleFloorPlanner;
type Params = PhantomData<F>;
@@ -100,7 +96,7 @@ impl<F: PrimeField + TensorType + PartialOrd, const LEN: usize, const BITS: usiz
.layer_config
.layout(
&mut region,
&[self.l0_params[0].clone(), self.input.clone()],
&[self.l0_params[0].clone().into(), self.input.clone()],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -138,7 +134,7 @@ impl<F: PrimeField + TensorType + PartialOrd, const LEN: usize, const BITS: usiz
.layer_config
.layout(
&mut region,
&[self.l2_params[0].clone(), x],
&[self.l2_params[0].clone().into(), x],
Box::new(PolyOp::Einsum {
equation: "ab,bc->ac".to_string(),
}),
@@ -202,34 +198,39 @@ impl<F: PrimeField + TensorType + PartialOrd, const LEN: usize, const BITS: usiz
pub fn runmlp() {
env_logger::init();
// parameters
let l0_kernel: Tensor<Value<F>> = Tensor::<i32>::new(
let mut l0_kernel: Tensor<F> = Tensor::<i32>::new(
Some(&[10, 0, 0, -1, 0, 10, 1, 0, 0, 1, 10, 0, 1, 0, 0, 10]),
&[4, 4],
)
.unwrap()
.into();
let l0_bias: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.into();
.map(i32_to_felt);
l0_kernel.set_visibility(ezkl::graph::Visibility::Private);
let l2_kernel: Tensor<Value<F>> = Tensor::<i32>::new(
let mut l0_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.map(i32_to_felt);
l0_bias.set_visibility(ezkl::graph::Visibility::Private);
let mut l2_kernel: Tensor<F> = Tensor::<i32>::new(
Some(&[0, 3, 10, -1, 0, 10, 1, 0, 0, 1, 0, 12, 1, -2, 32, 0]),
&[4, 4],
)
.unwrap()
.into();
.map(i32_to_felt);
l2_kernel.set_visibility(ezkl::graph::Visibility::Private);
// input data, with 1 padding to allow for bias
let input: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[-30, -21, 11, 40]), &[4, 1])
.unwrap()
.into();
let l2_bias: Tensor<Value<F>> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
let mut l2_bias: Tensor<F> = Tensor::<i32>::new(Some(&[0, 0, 0, 1]), &[4, 1])
.unwrap()
.into();
.map(i32_to_felt);
l2_bias.set_visibility(ezkl::graph::Visibility::Private);
let circuit = MyCircuit::<F, 4, 14> {
let circuit = MyCircuit::<4, 14> {
input: input.into(),
l0_params: [l0_kernel.into(), l0_bias.into()],
l2_params: [l2_kernel.into(), l2_bias.into()],
l0_params: [l0_kernel, l0_bias],
l2_params: [l2_kernel, l2_bias],
_marker: PhantomData,
};

View File

@@ -1,3 +1,3 @@
[toolchain]
channel = "nightly-2023-06-27"
channel = "nightly-2023-04-16"
components = [ "rustfmt", "clippy" ]

View File

@@ -69,8 +69,8 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
meta,
hash_inputs.clone().try_into().unwrap(),
partial_sbox,
rc_a.try_into().unwrap(),
rc_b.try_into().unwrap(),
rc_a,
rc_b,
);
PoseidonConfig {
@@ -290,9 +290,9 @@ impl<S: Spec<Fp, WIDTH, RATE> + Sync, const WIDTH: usize, const RATE: usize, con
},
)?;
return Ok(assigned_input.into());
Ok(assigned_input.into())
} else {
return Ok(result.into());
Ok(result.into())
}
}

View File

@@ -1,5 +1,4 @@
use std::any::Any;
use super::*;
use crate::{
circuit::{self, layouts, Tolerance},
fieldutils::{felt_to_i128, i128_to_felt},
@@ -7,8 +6,6 @@ use crate::{
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
};
use serde::{Deserialize, Serialize};
use super::{lookup::LookupOp, region::RegionCtx, ForwardResult, Op};
use halo2curves::ff::PrimeField;
// import run args from model
@@ -34,10 +31,10 @@ pub enum HybridOp {
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for HybridOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, inputs: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let x = inputs[0].clone().map(|x| felt_to_i128(x));

View File

@@ -41,6 +41,7 @@ impl LookupOp {
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for LookupOp {
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}

View File

@@ -34,7 +34,9 @@ pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
}
/// An enum representing operations that can be represented as constraints in a circuit.
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
pub trait Op<F: PrimeField + TensorType + PartialOrd>:
std::fmt::Debug + Send + Sync + Any + serde_traitobject::Serialize + serde_traitobject::Deserialize
{
/// Matches a [Op] to an operation in the `tensor::ops` module.
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
/// Returns a string representation of the operation.
@@ -134,15 +136,16 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
}
/// A wrapper for an operation that has been rescaled.
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rescaled<F: PrimeField + TensorType + PartialOrd> {
/// The operation to be rescaled.
#[serde(with = "serde_traitobject")]
pub inner: Box<dyn Op<F>>,
/// The scale of the operation's inputs.
pub scale: Vec<(usize, u128)>,
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Rescaled<F> {
impl<F: PrimeField + TensorType + PartialOrd + Serialize> Op<F> for Rescaled<F> {
fn as_any(&self) -> &dyn Any {
self
}
@@ -214,7 +217,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Rescaled<F> {
}
/// An unknown operation.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize)]
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
pub struct Unknown;
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
@@ -246,32 +249,41 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Unknown {
}
///
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Constant<F: PrimeField + TensorType + PartialOrd> {
///
pub quantized_values: ValTensor<F>,
pub quantized_values: Tensor<F>,
///
pub raw_values: Tensor<f32>,
///
#[serde(skip)]
pub pre_assigned_val: Option<ValTensor<F>>,
}
impl<F: PrimeField + TensorType + PartialOrd> Constant<F> {
///
pub fn new(quantized_values: ValTensor<F>, raw_values: Tensor<f32>) -> Self {
pub fn new(quantized_values: Tensor<F>, raw_values: Tensor<f32>) -> Self {
Self {
quantized_values,
raw_values,
pre_assigned_val: None,
}
}
///
pub fn pre_assign(&mut self, val: ValTensor<F>) {
self.pre_assigned_val = Some(val)
}
}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Constant<F> {
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F>
for Constant<F>
{
fn as_any(&self) -> &dyn Any {
self
}
fn f(&self, _: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
let mut output = self.quantized_values.get_felt_evals().unwrap();
// make sure its the right shape
output.reshape(self.quantized_values.dims());
let output = self.quantized_values.clone();
Ok(ForwardResult {
output,
@@ -284,11 +296,20 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Constant<F> {
}
fn layout(
&self,
_: &mut crate::circuit::BaseConfig<F>,
_: &mut RegionCtx<F>,
config: &mut crate::circuit::BaseConfig<F>,
region: &mut RegionCtx<F>,
_: &[ValTensor<F>],
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
Ok(Some(self.quantized_values.clone()))
if let Some(value) = &self.pre_assigned_val {
Ok(Some(value.clone()))
} else {
// we gotta constrain it once
Ok(Some(layouts::identity(
config,
region,
&[self.quantized_values.clone().into()],
)?))
}
}
fn rescale(&self, _: Vec<u32>, _: u32) -> Box<dyn Op<F>> {
Box::new(self.clone())
@@ -299,7 +320,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Constant<F> {
}
}
fn homogenize_input_scales<F: PrimeField + TensorType + PartialOrd>(
fn homogenize_input_scales<F: PrimeField + TensorType + PartialOrd + Serialize>(
op: impl Op<F> + Clone,
input_scales: Vec<u32>,
inputs_to_scale: Vec<usize>,

View File

@@ -7,14 +7,14 @@ use super::{base::BaseOp, *};
#[allow(missing_docs)]
/// An enum representing the operations that can be expressed as arithmetic (non lookup) operations.
#[derive(Clone, Debug)]
pub enum PolyOp<F: PrimeField + TensorType + PartialOrd> {
#[derive(Clone, Debug, Serialize, Deserialize)]
pub enum PolyOp<F: PrimeField + TensorType + PartialOrd>{
Einsum {
equation: String,
},
Conv {
kernel: ValTensor<F>,
bias: Option<ValTensor<F>>,
kernel: Tensor<F>,
bias: Option<Tensor<F>>,
padding: (usize, usize),
stride: (usize, usize),
},
@@ -24,8 +24,8 @@ pub enum PolyOp<F: PrimeField + TensorType + PartialOrd> {
modulo: usize,
},
DeConv {
kernel: ValTensor<F>,
bias: Option<ValTensor<F>>,
kernel: Tensor<F>,
bias: Option<Tensor<F>>,
padding: (usize, usize),
output_padding: (usize, usize),
stride: (usize, usize),
@@ -36,11 +36,11 @@ pub enum PolyOp<F: PrimeField + TensorType + PartialOrd> {
kernel_shape: (usize, usize),
},
Add {
a: Option<ValTensor<F>>,
a: Option<Tensor<F>>,
},
Sub,
Mult {
a: Option<ValTensor<F>>,
a: Option<Tensor<F>>,
},
Identity,
Reshape(Vec<usize>),
@@ -76,10 +76,14 @@ pub enum PolyOp<F: PrimeField + TensorType + PartialOrd> {
impl<F: PrimeField + TensorType + PartialOrd> PolyOp<F> {}
impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<'de>> Op<F> for PolyOp<F>
{
/// Returns a reference to the Any trait.
fn as_any(&self) -> &dyn Any {
self
}
fn as_string(&self) -> String {
let name = match &self {
PolyOp::MoveAxis { .. } => "MOVEAXIS",
@@ -144,14 +148,14 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
}
PolyOp::Add { a } => {
if let Some(a) = a {
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
inputs.push(a.clone());
}
tensor::ops::add(&inputs)
}
PolyOp::Sub => tensor::ops::sub(&inputs),
PolyOp::Mult { a } => {
if let Some(a) = a {
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
inputs.push(a.clone());
}
tensor::ops::mult(&inputs)
}
@@ -161,9 +165,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
padding,
stride,
} => {
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
inputs.push(a.clone());
if let Some(b) = bias {
inputs.push(Tensor::new(Some(&b.get_felt_evals().unwrap()), b.dims())?);
inputs.push(b.clone());
}
tensor::ops::conv(&inputs, *padding, *stride)
}
@@ -174,9 +178,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
output_padding,
stride,
} => {
inputs.push(Tensor::new(Some(&a.get_felt_evals().unwrap()), a.dims())?);
inputs.push(a.clone());
if let Some(b) = bias {
inputs.push(Tensor::new(Some(&b.get_felt_evals().unwrap()), b.dims())?);
inputs.push(b.clone());
}
tensor::ops::deconv(&inputs, *padding, *output_padding, *stride)
}
@@ -260,9 +264,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
padding,
stride,
} => {
values.push(kernel.clone());
values.push(kernel.clone().into());
if let Some(bias) = bias {
values.push(bias.clone());
values.push(bias.clone().into());
}
layouts::conv(config, region, values[..].try_into()?, *padding, *stride)?
}
@@ -273,9 +277,9 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
output_padding,
stride,
} => {
values.push(kernel.clone());
values.push(kernel.clone().into());
if let Some(bias) = bias {
values.push(bias.clone());
values.push(bias.clone().into());
}
layouts::deconv(
config,
@@ -300,7 +304,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
)?,
PolyOp::Add { a } => {
if let Some(a) = a {
values.push(a.clone());
values.push(a.clone().into());
}
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Add)?
@@ -308,7 +312,7 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
PolyOp::Sub => layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Sub)?,
PolyOp::Mult { a } => {
if let Some(a) = a {
values.push(a.clone());
values.push(a.clone().into());
}
layouts::pairwise(config, region, values[..].try_into()?, BaseOp::Mult)?
}
@@ -356,16 +360,32 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
PolyOp::Sum { .. } => in_scales[0],
PolyOp::Conv { kernel, bias, .. } => {
let output_scale = in_scales[0] + kernel.scale();
let kernel_scale = match kernel.scale() {
Some(s) => s,
None => panic!("scale must be set for conv kernel"),
};
let output_scale = in_scales[0] + kernel_scale;
if let Some(b) = bias {
assert_eq!(output_scale, b.scale());
let bias_scale = match b.scale() {
Some(s) => s,
None => panic!("scale must be set for conv bias"),
};
assert_eq!(output_scale, bias_scale);
}
output_scale
}
PolyOp::DeConv { kernel, bias, .. } => {
let output_scale = in_scales[0] + kernel.scale();
let kernel_scale = match kernel.scale() {
Some(s) => s,
None => panic!("scale must be set for deconv kernel"),
};
let output_scale = in_scales[0] + kernel_scale;
if let Some(b) = bias {
assert_eq!(output_scale, b.scale());
let bias_scale = match b.scale() {
Some(s) => s,
None => panic!("scale must be set for deconv bias"),
};
assert_eq!(output_scale, bias_scale);
}
output_scale
}
@@ -374,7 +394,11 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
let mut scale_a = 0;
let scale_b = in_scales[0];
if let Some(a) = a {
scale_a += a.scale();
let a_scale = match a.scale() {
Some(s) => s,
None => panic!("scale must be set for add constant"),
};
scale_a += a_scale;
} else {
scale_a += in_scales[1];
}
@@ -385,7 +409,11 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for PolyOp<F> {
PolyOp::Mult { a } => {
let mut scale = in_scales[0];
if let Some(a) = a {
scale += a.scale();
let a_scale = match a.scale() {
Some(s) => s,
None => panic!("scale must be set for add constant"),
};
scale += a_scale;
} else {
scale += in_scales[1];
}

View File

@@ -7,13 +7,13 @@ use halo2_proofs::{
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::bn256::Fr as F;
use halo2curves::ff::{Field, PrimeField};
use halo2curves::pasta::pallas;
use halo2curves::pasta::Fp as F;
use ops::lookup::LookupOp;
use ops::region::RegionCtx;
use rand::rngs::OsRng;
use std::marker::PhantomData;
#[derive(Default)]
struct TestParams;
@@ -31,7 +31,7 @@ mod matmul {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MatmulCircuit<F> {
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -107,7 +107,7 @@ mod matmul_col_overflow {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MatmulCircuit<F> {
impl Circuit<F> for MatmulCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -183,7 +183,7 @@ mod dot {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -256,7 +256,7 @@ mod dot_col_overflow {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -329,7 +329,7 @@ mod sum {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -398,7 +398,7 @@ mod sum_col_overflow {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -468,7 +468,7 @@ mod composition {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -557,11 +557,11 @@ mod conv {
#[derive(Clone)]
struct ConvCircuit<F: PrimeField + TensorType + PartialOrd> {
inputs: Vec<ValTensor<F>>,
inputs: Vec<Tensor<F>>,
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for ConvCircuit<F> {
impl Circuit<F> for ConvCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -590,7 +590,7 @@ mod conv {
config
.layout(
&mut region,
&[self.inputs[0].clone()],
&[self.inputs[0].clone().into()],
Box::new(PolyOp::Conv {
kernel: self.inputs[1].clone(),
bias: None,
@@ -616,27 +616,23 @@ mod conv {
let in_channels = 3;
let out_channels = 2;
let mut image = Tensor::from(
(0..in_channels * image_height * image_width)
.map(|_| Value::known(pallas::Base::random(OsRng))),
);
let mut image =
Tensor::from((0..in_channels * image_height * image_width).map(|_| F::random(OsRng)));
image.reshape(&[1, in_channels, image_height, image_width]);
image.set_visibility(crate::graph::Visibility::Private);
let mut kernels = Tensor::from(
(0..{ out_channels * in_channels * kernel_height * kernel_width })
.map(|_| Value::known(pallas::Base::random(OsRng))),
.map(|_| F::random(OsRng)),
);
kernels.reshape(&[out_channels, in_channels, kernel_height, kernel_width]);
kernels.set_visibility(crate::graph::Visibility::Private);
let bias =
Tensor::from((0..{ out_channels }).map(|_| Value::known(pallas::Base::random(OsRng))));
let mut bias = Tensor::from((0..{ out_channels }).map(|_| F::random(OsRng)));
bias.set_visibility(crate::graph::Visibility::Private);
let circuit = ConvCircuit::<F> {
inputs: [
ValTensor::from(image),
ValTensor::from(kernels),
ValTensor::from(bias),
]
.to_vec(),
inputs: [image, kernels, bias].to_vec(),
_marker: PhantomData,
};
@@ -654,18 +650,20 @@ mod conv {
let in_channels = 3;
let out_channels = 2;
let mut image = Tensor::from(
(0..in_channels * image_height * image_width).map(|i| Value::known(F::from(i as u64))),
);
let mut image =
Tensor::from((0..in_channels * image_height * image_width).map(|i| F::from(i as u64)));
image.reshape(&[1, in_channels, image_height, image_width]);
image.set_visibility(crate::graph::Visibility::Private);
let mut kernels = Tensor::from(
(0..{ out_channels * in_channels * kernel_height * kernel_width })
.map(|i| Value::known(F::from(i as u64))),
.map(|i| F::from(i as u64)),
);
kernels.reshape(&[out_channels, in_channels, kernel_height, kernel_width]);
kernels.set_visibility(crate::graph::Visibility::Private);
let circuit = ConvCircuit::<F> {
inputs: [ValTensor::from(image), ValTensor::from(kernels)].to_vec(),
inputs: [image, kernels].to_vec(),
_marker: PhantomData,
};
@@ -688,7 +686,7 @@ mod sumpool {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for ConvCircuit<F> {
impl Circuit<F> for ConvCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -739,8 +737,7 @@ mod sumpool {
let in_channels = 1;
let mut image = Tensor::from(
(0..in_channels * image_height * image_width)
.map(|_| Value::known(pallas::Base::random(OsRng))),
(0..in_channels * image_height * image_width).map(|_| Value::known(F::random(OsRng))),
);
image.reshape(&[1, in_channels, image_height, image_width]);
@@ -767,7 +764,7 @@ mod add_w_shape_casting {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -838,7 +835,7 @@ mod add {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -909,7 +906,7 @@ mod add_with_overflow {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1118,7 +1115,7 @@ mod sub {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1185,7 +1182,7 @@ mod mult {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1256,7 +1253,7 @@ mod pow {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1321,7 +1318,7 @@ mod pack {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1390,7 +1387,7 @@ mod rescaled {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1486,7 +1483,7 @@ mod matmul_relu {
base_config: BaseConfig<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = MyConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1572,7 +1569,6 @@ mod rangecheckpercent {
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::pasta::Fp;
const RANGE: f32 = 1.0; // 1 percent error tolerance
const K: usize = 18;
@@ -1588,7 +1584,7 @@ mod rangecheckpercent {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for MyCircuit<F> {
impl Circuit<F> for MyCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1647,9 +1643,9 @@ mod rangecheckpercent {
fn test_range_check_percent() {
// Successful cases
{
let inp = Tensor::new(Some(&[Value::<Fp>::known(Fp::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<Fp>::known(Fp::from(101_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<Fp> {
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(101_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
@@ -1658,9 +1654,9 @@ mod rangecheckpercent {
prover.assert_satisfied();
}
{
let inp = Tensor::new(Some(&[Value::<Fp>::known(Fp::from(200_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<Fp>::known(Fp::from(199_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<Fp> {
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(200_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(199_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
@@ -1671,9 +1667,9 @@ mod rangecheckpercent {
// Unsuccessful case
{
let inp = Tensor::new(Some(&[Value::<Fp>::known(Fp::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<Fp>::known(Fp::from(102_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<Fp> {
let inp = Tensor::new(Some(&[Value::<F>::known(F::from(100_u64))]), &[1]).unwrap();
let out = Tensor::new(Some(&[Value::<F>::known(F::from(102_u64))]), &[1]).unwrap();
let circuit = MyCircuit::<F> {
input: ValTensor::from(inp),
output: ValTensor::from(out),
_marker: PhantomData,
@@ -1699,14 +1695,13 @@ mod relu {
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::pasta::Fp as F;
#[derive(Clone)]
struct ReLUCircuit<F: PrimeField + TensorType + PartialOrd> {
pub input: ValTensor<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for ReLUCircuit<F> {
impl Circuit<F> for ReLUCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;
@@ -1779,7 +1774,6 @@ mod softmax {
dev::MockProver,
plonk::{Circuit, ConstraintSystem, Error},
};
use halo2curves::pasta::Fp as F;
const K: usize = 18;
const LEN: usize = 3;
@@ -1791,7 +1785,7 @@ mod softmax {
_marker: PhantomData<F>,
}
impl<F: PrimeField + TensorType + PartialOrd> Circuit<F> for SoftmaxCircuit<F> {
impl Circuit<F> for SoftmaxCircuit<F> {
type Config = BaseConfig<F>;
type FloorPlanner = SimpleFloorPlanner;
type Params = TestParams;

View File

@@ -247,7 +247,7 @@ pub async fn verify_proof_via_solidity(
let encoded = func.encode_input(&[
Token::FixedArray(public_inputs.into_iter().map(Token::Uint).collect()),
Token::Bytes(proof.proof.into()),
Token::Bytes(proof.proof),
])?;
info!("encoded: {:#?}", hex::encode(&encoded));
@@ -371,7 +371,7 @@ pub async fn verify_proof_with_data_attestation(
let encoded = func.encode_input(&[
Token::FixedArray(public_inputs.into_iter().map(Token::Uint).collect()),
Token::Bytes(proof.proof.into()),
Token::Bytes(proof.proof),
])?;
info!("encoded: {:#?}", hex::encode(&encoded));

View File

@@ -860,7 +860,7 @@ pub(crate) fn create_evm_verifier(
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path.clone(), "Verifier", None)?;
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Verifier", None)?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
@@ -933,7 +933,7 @@ pub(crate) fn create_evm_data_attestation_verifier(
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) =
get_contract_artifacts(sol_code_path.clone(), "DataAttestationVerifier", None)?;
get_contract_artifacts(sol_code_path, "DataAttestationVerifier", None)?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;
} else {
@@ -1018,7 +1018,7 @@ pub(crate) fn create_evm_aggregate_verifier(
let settings: Vec<GraphSettings> = circuit_settings
.iter()
.map(|path| GraphSettings::load(&path).unwrap())
.map(|path| GraphSettings::load(path).unwrap())
.collect::<Vec<_>>();
let num_public_inputs: usize = settings
@@ -1053,7 +1053,7 @@ pub(crate) fn create_evm_aggregate_verifier(
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(sol_code_path.clone(), "Verifier", None)?;
let (abi, _, _) = get_contract_artifacts(sol_code_path, "Verifier", None)?;
// save abi to file
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;

View File

@@ -405,7 +405,7 @@ pub struct GraphConfig {
}
/// Defines the circuit for a computational graph / model loaded from a `.onnx` file.
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug, Default, Serialize)]
pub struct GraphCircuit {
/// The model / graph of computations.
pub model: Model,
@@ -796,13 +796,13 @@ impl GraphCircuit {
if visibility.params.requires_processing() {
let params = self.model.get_all_consts();
let flattened_params = flatten_valtensors(params)?;
let flattened_params = if !flattened_params.is_empty() {
vec![flattened_params[0].get_felt_evals()?.into_iter().into()]
} else {
vec![]
};
processed_params = Some(GraphModules::forward(&flattened_params, visibility.params)?);
if !params.is_empty() {
let flattened_params = Tensor::new(Some(&params), &[params.len()])?.combine()?;
processed_params = Some(GraphModules::forward(
&[flattened_params],
visibility.params,
)?);
}
}
let model_results = self.model.forward(inputs)?;
@@ -1003,40 +1003,53 @@ impl Circuit<Fp> for GraphCircuit {
&self.module_settings.input,
)?;
// now we need to flatten the params
let mut flattened_params =
flatten_valtensors(self.model.get_all_consts()).map_err(|_| {
log::error!("failed to flatten params");
// now we need to assign the flattened params to the model
let mut model = self.model.clone();
let param_visibility = self.settings.run_args.param_visibility;
trace!("running params module layout");
if !self.model.get_all_consts().is_empty() && param_visibility.requires_processing() {
// now we need to flatten the params
let consts = self.model.get_all_consts();
let mut flattened_params = {
let mut t = Tensor::new(Some(&consts), &[consts.len()])
.map_err(|_| {
log::error!("failed to flatten params");
PlonkError::Synthesis
})?
.combine()
.map_err(|_| {
log::error!("failed to combine params");
PlonkError::Synthesis
})?;
t.set_visibility(param_visibility);
vec![t.into()]
};
// now do stuff to the model params
GraphModules::layout(
&mut layouter,
&config.module_configs,
&mut flattened_params,
param_visibility,
&mut instance_offset,
&self.module_settings.params,
)?;
let shapes = self.model.const_shapes();
trace!("replacing processed consts");
let split_params = split_valtensor(&flattened_params[0], shapes).map_err(|_| {
log::error!("failed to split params");
PlonkError::Synthesis
})?;
// now do stuff to the model params
GraphModules::layout(
&mut layouter,
&config.module_configs,
&mut flattened_params,
self.settings.run_args.param_visibility,
&mut instance_offset,
&self.module_settings.params,
)?;
// now we need to assign the flattened params to the model
let mut model = self.model.clone();
if !self.model.get_all_consts().is_empty() {
// now the flattened_params have been assigned to and we-assign them to the model consts such that they are constrained to be equal
model.replace_consts(
split_valtensor(flattened_params[0].clone(), self.model.const_shapes()).map_err(
|_| {
log::error!("failed to replace params");
PlonkError::Synthesis
},
)?,
);
model.replace_consts(split_params);
}
// create a new module for the model (space 2)
layouter.assign_region(|| "_new_module", |_| Ok(()))?;
trace!("Laying out model");
trace!("laying out model");
let mut outputs = model
.layout(
config.model_config.clone(),

View File

@@ -17,6 +17,8 @@ use crate::{
use halo2curves::bn256::Fr as Fp;
use colored::Colorize;
use serde::Deserialize;
use serde::Serialize;
use tract_onnx::prelude::{
DatumExt, Graph, InferenceFact, InferenceModelExt, SymbolValues, TypedFact, TypedOp,
};
@@ -33,6 +35,9 @@ use log::{debug, info, trace};
use std::collections::BTreeMap;
use std::collections::HashSet;
use std::error::Error;
use std::fs;
use std::io::Read;
use std::path::PathBuf;
use tabled::Table;
use tract_onnx;
use tract_onnx::prelude::Framework;
@@ -59,7 +64,7 @@ pub struct ModelConfig {
pub type NodeGraph = BTreeMap<usize, NodeType>;
/// A struct for loading from an Onnx file and converting a computational graph to a circuit.
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
pub struct Model {
/// input indices
pub graph: ParsedNodes,
@@ -68,7 +73,7 @@ pub struct Model {
}
/// Enables model as subnode of other models
#[derive(Clone, Debug)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum NodeType {
/// A node in the model
Node(Node),
@@ -146,7 +151,7 @@ impl NodeType {
}
}
#[derive(Clone, Debug, Default)]
#[derive(Clone, Debug, Default, Serialize, Deserialize, PartialEq)]
/// A set of EZKL nodes that represent a computational graph.
pub struct ParsedNodes {
nodes: BTreeMap<usize, NodeType>,
@@ -225,6 +230,25 @@ impl Model {
Ok(om)
}
///
pub fn save(&self, path: PathBuf) -> Result<(), Box<dyn Error>> {
let f = std::fs::File::create(path)?;
let writer = std::io::BufWriter::new(f);
bincode::serialize_into(writer, &self)?;
Ok(())
}
///
pub fn load(path: PathBuf) -> Result<Self, Box<dyn Error>> {
// read bytes from file
let mut f = std::fs::File::open(&path).expect("no file found");
let metadata = fs::metadata(&path).expect("unable to read metadata");
let mut buffer = vec![0; metadata.len() as usize];
f.read_exact(&mut buffer).expect("buffer overflow");
let result = bincode::deserialize(&buffer)?;
Ok(result)
}
/// Generate model parameters for the circuit
pub fn gen_params(
&self,
@@ -793,7 +817,7 @@ impl Model {
}
/// Retrieves all constants from the model.
pub fn get_all_consts(&self) -> Vec<ValTensor<Fp>> {
pub fn get_all_consts(&self) -> Vec<Tensor<Fp>> {
let mut consts = vec![];
for node in self.graph.nodes.values() {
match node {
@@ -841,10 +865,13 @@ impl Model {
.as_any()
.downcast_ref::<crate::circuit::ops::Constant<Fp>>()
{
n.opkind = Box::new(crate::circuit::Constant::new(
consts[const_idx].clone(),
let mut op = crate::circuit::Constant::new(
constant.quantized_values.clone(),
constant.raw_values.clone(),
));
);
op.pre_assign(consts[const_idx].clone());
n.opkind = Box::new(op);
const_idx += 1;
};
}

View File

@@ -240,21 +240,21 @@ impl GraphModules {
module_res: &Option<ModuleForwardResult>,
instances: &mut ModuleInstances,
) {
if visibility.is_hashed() {
instances
.poseidon
.extend(module_res.clone().unwrap().poseidon_hash.unwrap());
} else if visibility.is_encrypted() {
instances.elgamal.extend(
module_res
.clone()
.unwrap()
.elgamal
.unwrap()
.ciphertexts
.into_iter()
.flatten(),
);
if let Some(res) = module_res {
if visibility.is_hashed() {
instances
.poseidon
.extend(res.poseidon_hash.clone().unwrap());
} else if visibility.is_encrypted() {
instances.elgamal.extend(
res.elgamal
.clone()
.unwrap()
.ciphertexts
.into_iter()
.flatten(),
);
}
}
}

View File

@@ -5,6 +5,8 @@ use crate::graph::new_op_from_onnx;
use crate::graph::GraphError;
use halo2curves::bn256::Fr as Fp;
use log::trace;
use serde::Deserialize;
use serde::Serialize;
use std::collections::BTreeMap;
use std::error::Error;
use std::fmt;
@@ -28,10 +30,11 @@ fn display_opkind(v: &Box<dyn Op<Fp>>) -> String {
}
/// A single operation in a [crate::graph::Model].
#[derive(Clone, Debug, Tabled)]
#[derive(Clone, Debug, Tabled, Serialize, Deserialize)]
pub struct Node {
/// [Op] i.e what operation this node represents.
#[tabled(display_with = "display_opkind")]
#[serde(with = "serde_traitobject")]
pub opkind: Box<dyn Op<Fp>>,
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
pub out_scale: u32,
@@ -47,6 +50,16 @@ pub struct Node {
pub idx: usize,
}
impl PartialEq for Node {
fn eq(&self, other: &Node) -> bool {
(self.out_scale == other.out_scale)
&& (self.inputs == other.inputs)
&& (self.out_dims == other.out_dims)
&& (self.idx == other.idx)
&& (self.opkind.as_string() == other.opkind.as_string())
}
}
impl Node {
/// Converts a tract [OnnxNode] into an ezkl [Node].
/// # Arguments:

View File

@@ -4,7 +4,7 @@ use super::{GraphError, Visibility};
use crate::circuit::hybrid::HybridOp;
use crate::circuit::lookup::LookupOp;
use crate::circuit::poly::PolyOp;
use crate::tensor::{Tensor, TensorError, TensorType, ValTensor, ValType};
use crate::tensor::{Tensor, TensorError, TensorType};
use halo2curves::bn256::Fr as Fp;
use halo2curves::ff::PrimeField;
use log::{debug, warn};
@@ -309,7 +309,7 @@ pub fn new_op_from_onnx(
let constant_scale = if dt == DatumType::Bool { 0 } else { scale };
// Quantize the raw value
let quantized_value =
tensor_to_valtensor(raw_value.clone(), constant_scale, param_visibility)?;
quantize_tensor(raw_value.clone(), constant_scale, param_visibility)?;
// Create a constant op
Box::new(crate::circuit::ops::Constant::new(
quantized_value,
@@ -412,7 +412,7 @@ pub fn new_op_from_onnx(
let boxed_op = inp.opkind();
if let Some(c) = extract_const_raw_values(boxed_op) {
inputs.remove(idx);
params = Some(tensor_to_valtensor(c, max_scale, param_visibility)?);
params = Some(quantize_tensor(c, max_scale, param_visibility)?);
}
}
@@ -557,13 +557,13 @@ pub fn new_op_from_onnx(
(padding[0], padding[1], stride[0], stride[1]);
let kernel = extract_tensor_value(conv_node.kernel.clone())?;
let kernel = tensor_to_valtensor(kernel, scale, param_visibility)?;
let kernel = quantize_tensor(kernel, scale, param_visibility)?;
let bias = match conv_node.bias.clone() {
Some(b) => {
let const_value = extract_tensor_value(b)?;
let val = tensor_to_valtensor(
let val = quantize_tensor(
const_value,
scale + inputs[0].out_scales()[0],
param_visibility,
@@ -621,13 +621,13 @@ pub fn new_op_from_onnx(
(padding[0], padding[1], stride[0], stride[1]);
let kernel = extract_tensor_value(deconv_node.kernel.clone())?;
let kernel = tensor_to_valtensor(kernel, scale, param_visibility)?;
let kernel = quantize_tensor(kernel, scale, param_visibility)?;
let bias = match deconv_node.bias.clone() {
Some(b) => {
let const_value = extract_tensor_value(b)?;
let val = tensor_to_valtensor(
let val = quantize_tensor(
const_value,
scale + inputs[0].out_scales()[0],
param_visibility,
@@ -805,7 +805,7 @@ pub fn extract_const_raw_values(boxed_op: Box<dyn crate::circuit::Op<Fp>>) -> Op
/// Extracts the quantized values from a [crate::circuit::ops::Constant] op.
pub fn extract_const_quantized_values(
boxed_op: Box<dyn crate::circuit::Op<Fp>>,
) -> Option<ValTensor<Fp>> {
) -> Option<Tensor<Fp>> {
boxed_op
.as_any()
.downcast_ref::<crate::circuit::ops::Constant<Fp>>()
@@ -813,58 +813,23 @@ pub fn extract_const_quantized_values(
}
/// Converts a tensor to a [ValTensor] with a given scale.
pub fn tensor_to_valtensor<F: PrimeField + TensorType + PartialOrd>(
pub fn quantize_tensor<F: PrimeField + TensorType + PartialOrd>(
const_value: Tensor<f32>,
scale: u32,
visibility: Visibility,
) -> Result<ValTensor<F>, Box<dyn std::error::Error>> {
let mut value: ValTensor<F> = match visibility {
Visibility::Public => const_value
.map(|x| {
crate::tensor::ValType::Constant(crate::fieldutils::i128_to_felt::<F>(
quantize_float(&x.into(), 0.0, scale).unwrap(),
))
})
.into(),
Visibility::Private | Visibility::Hashed | Visibility::Encrypted => const_value
.map(|x| {
crate::tensor::ValType::Value(halo2_proofs::circuit::Value::known(
crate::fieldutils::i128_to_felt::<F>(
quantize_float(&x.into(), 0.0, scale).unwrap(),
),
))
})
.into(),
};
) -> Result<Tensor<F>, Box<dyn std::error::Error>> {
let mut value: Tensor<F> = const_value.map(|x| {
crate::fieldutils::i128_to_felt::<F>(quantize_float(&x.into(), 0.0, scale).unwrap())
});
value.set_scale(scale);
value.set_visibility(visibility);
Ok(value)
}
/// Flatten a vector of [ValTensor]s into a single [ValTensor].
pub(crate) fn flatten_valtensors(
tensors: Vec<ValTensor<Fp>>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn std::error::Error>> {
if tensors.is_empty() {
return Ok(vec![]);
}
let mut merged: Vec<ValType<Fp>> = tensors[0]
.get_inner_tensor()?
.into_iter()
.collect::<Vec<_>>();
for tensor in tensors.iter().skip(1) {
let vals = tensor.get_inner_tensor()?.into_iter();
merged.extend(vals);
}
let tensor = Tensor::new(Some(&merged), &[merged.len()])?;
Ok(vec![tensor.into()])
}
use crate::tensor::ValTensor;
/// Split a [ValTensor] into a vector of [ValTensor]s.
pub(crate) fn split_valtensor(
values: ValTensor<Fp>,
values: &ValTensor<Fp>,
shapes: Vec<Vec<usize>>,
) -> Result<Vec<ValTensor<Fp>>, Box<dyn std::error::Error>> {
let mut tensors: Vec<ValTensor<Fp>> = Vec::new();
@@ -890,13 +855,18 @@ pub mod tests {
let tensor2: Tensor<Fp> = (10..20).map(|x| x.into()).into();
let tensor3: Tensor<Fp> = (20..30).map(|x| x.into()).into();
let flattened =
flatten_valtensors(vec![tensor1.into(), tensor2.into(), tensor3.into()]).unwrap();
let mut tensor = Tensor::new(Some(&[tensor1, tensor2, tensor3]), &[3])
.unwrap()
.combine()
.unwrap();
assert_eq!(flattened[0].len(), 30);
tensor.set_visibility(Visibility::Public);
let split =
split_valtensor(flattened[0].clone(), vec![vec![2, 5], vec![10], vec![5, 2]]).unwrap();
let flattened: ValTensor<Fp> = tensor.into();
assert_eq!(flattened.len(), 30);
let split = split_valtensor(&flattened, vec![vec![2, 5], vec![10], vec![5, 2]]).unwrap();
assert_eq!(split.len(), 3);
assert_eq!(split[0].len(), 10);

View File

@@ -101,7 +101,7 @@ impl std::fmt::Display for Visibility {
}
/// Represents whether the model input, model parameters, and model output are Public or Private to the prover.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[derive(Clone, Debug, Default, Deserialize, Serialize, PartialEq, PartialOrd)]
pub struct VarVisibility {
/// Input to the model or computational graph
pub input: Visibility,

View File

@@ -199,14 +199,14 @@ impl AggregationCircuit {
trace!("Aggregating with snark instances {:?}", snark.instances);
let mut transcript = PoseidonTranscript::<NativeLoader, _>::new(snark.proof.as_slice());
let proof = PlonkSuccinctVerifier::read_proof(
&svk,
svk,
snark.protocol.as_ref().unwrap(),
&snark.instances,
&mut transcript,
)
.map_err(|_| AggregationError::ProofRead)?;
let mut accum = PlonkSuccinctVerifier::verify(
&svk,
svk,
snark.protocol.as_ref().unwrap(),
&snark.instances,
&proof,
@@ -231,7 +231,7 @@ impl AggregationCircuit {
.concat();
Ok(Self {
svk: svk.clone(),
svk: *svk,
snarks: snarks.into_iter().map_into().collect(),
instances,
as_proof: Value::known(as_proof),
@@ -264,7 +264,7 @@ impl AggregationCircuit {
for instance in snark_instance.iter_mut() {
let mut felt_evals = vec![];
for value in instance.iter_mut() {
value.map(|v| felt_evals.push(v.clone()));
value.map(|v| felt_evals.push(v));
}
instances[0].extend(felt_evals);
}

View File

@@ -14,6 +14,7 @@ pub use var::*;
use crate::{
circuit::utils,
fieldutils::{felt_to_i32, i128_to_felt, i32_to_felt},
graph::Visibility,
};
use halo2_proofs::{
@@ -265,6 +266,8 @@ impl TensorType for halo2curves::bn256::Fr {
pub struct Tensor<T: TensorType> {
inner: Vec<T>,
dims: Vec<usize>,
scale: Option<u32>,
visibility: Option<Visibility>,
}
impl<T: TensorType> IntoIterator for Tensor<T> {
@@ -446,15 +449,39 @@ impl<T: Clone + TensorType> Tensor<T> {
Ok(Tensor {
inner: Vec::from(v),
dims: Vec::from(dims),
scale: None,
visibility: None,
})
}
None => Ok(Tensor {
inner: vec![T::zero().unwrap(); total_dims],
dims: Vec::from(dims),
scale: None,
visibility: None,
}),
}
}
/// set the tensor's (optional) scale parameter
pub fn set_scale(&mut self, scale: u32) {
self.scale = Some(scale)
}
/// set the tensor's (optional) visibility parameter
pub fn set_visibility(&mut self, visibility: Visibility) {
self.visibility = Some(visibility)
}
/// getter for scale
pub fn scale(&self) -> Option<u32> {
self.scale
}
/// getter for visibility
pub fn visibility(&self) -> Option<Visibility> {
self.visibility
}
/// Returns the number of elements in the tensor.
pub fn len(&self) -> usize {
self.dims().iter().product::<usize>()

View File

@@ -130,7 +130,19 @@ impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<ValType<F>>> for ValTe
impl<F: PrimeField + TensorType + PartialOrd> From<Tensor<F>> for ValTensor<F> {
fn from(t: Tensor<F>) -> ValTensor<F> {
ValTensor::Value {
inner: t.map(|x| x.into()),
inner: t.map(|x|
if let Some(vis) = t.visibility {
match vis {
Visibility::Public => x.into(),
Visibility::Private | Visibility::Hashed | Visibility::Encrypted => {
Value::known(x).into()
}
}
}
else {
panic!("visibility should be set to convert a tensor of field elements to a ValTensor.")
}
),
dims: t.dims().to_vec(),
scale: 1,
}

View File

@@ -4,7 +4,7 @@ mod native_tests {
use core::panic;
use ezkl::graph::input::{FileSource, GraphData};
use ezkl::graph::{DataSource, GraphSettings};
use ezkl::graph::{DataSource, GraphSettings, Visibility};
use lazy_static::lazy_static;
use std::env::var;
use std::process::Command;
@@ -302,6 +302,7 @@ mod native_tests {
use crate::native_tests::kzg_prove_and_verify;
use crate::native_tests::kzg_fuzz;
use crate::native_tests::render_circuit;
use crate::native_tests::model_serialization;
use crate::native_tests::tutorial as run_tutorial;
#[test]
@@ -312,8 +313,16 @@ mod native_tests {
}
seq!(N in 0..=36 {
#(#[test_case(TESTS[N])])*
fn model_serialization_(test: &str) {
crate::native_tests::mv_test_(test);
// percent tolerance test
model_serialization(test.to_string());
}
#(#[test_case(TESTS[N])])*
fn render_circuit_(test: &str) {
crate::native_tests::init_binary();
@@ -696,6 +705,26 @@ mod native_tests {
test_func_examples!();
test_neg_examples!();
fn model_serialization(example_name: String) {
let test_dir = TEST_DIR.path().to_str().unwrap();
let model_path = format!("{}/{}/network.onnx", test_dir, example_name);
let serialization_path = format!("{}/{}/network.ezkl", test_dir, example_name);
let run_args = ezkl::commands::RunArgs {
param_visibility: Visibility::Public,
batch_size: 1,
..Default::default()
};
let model =
ezkl::graph::Model::new(&mut std::fs::File::open(model_path).unwrap(), run_args)
.unwrap();
model.save(serialization_path.clone().into()).unwrap();
let loaded_model = ezkl::graph::Model::load(serialization_path.into()).unwrap();
assert_eq!(model, loaded_model)
}
// Mock prove (fast, but does not cover some potential issues)
fn neg_mock(example_name: String, counter_example: String) {
let test_dir = TEST_DIR.path().to_str().unwrap();

View File

@@ -75,7 +75,7 @@ mod py_tests {
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new("pip")
.args(["install", "numpy==1.23"])
.status()
@@ -168,7 +168,7 @@ mod py_tests {
"--to",
"notebook",
"--execute",
&path.to_str().unwrap(),
(path.to_str().unwrap()),
])
.status()
.expect("failed to execute process");

View File

@@ -1,186 +0,0 @@
#[cfg(test)]
mod wasi_tests {
use lazy_static::lazy_static;
use std::env::var;
use std::process::Command;
use std::sync::Once;
use tempdir::TempDir;
static COMPILE: Once = Once::new();
lazy_static! {
static ref CARGO_TARGET_DIR: String =
var("CARGO_TARGET_DIR").unwrap_or_else(|_| "./target".to_string());
static ref TEST_DIR: TempDir = TempDir::new("example").unwrap();
}
fn mv_test_(test: &str) {
let test_dir = TEST_DIR.path().to_str().unwrap();
let path: std::path::PathBuf = format!("{}/{}", test_dir, test).into();
if !path.exists() {
let status = Command::new("cp")
.args([
"-R",
&format!("./examples/onnx/{}", test),
&format!("{}/{}", test_dir, test),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
}
fn init() {
COMPILE.call_once(|| {
println!("using cargo target dir: {}", *CARGO_TARGET_DIR);
build_ezkl_wasm();
});
}
const TESTS: [&str; 19] = [
"1l_mlp",
"1l_flatten",
"1l_average",
"1l_div",
"1l_pad",
"1l_reshape",
"1l_sigmoid",
"1l_sqrt",
"1l_leakyrelu",
"1l_relu",
"2l_relu_sigmoid_small",
"2l_relu_fc",
"2l_relu_small",
"2l_relu_sigmoid",
"1l_conv",
"2l_sigmoid_small",
"2l_relu_sigmoid_conv",
"3l_relu_conv_fc",
"4l_relu_conv_fc",
];
macro_rules! wasi_test_func {
() => {
#[cfg(test)]
mod tests_wasi {
use seq_macro::seq;
use crate::wasi_tests::TESTS;
use test_case::test_case;
use crate::wasi_tests::mock;
seq!(N in 0..=18 {
#(#[test_case(TESTS[N])])*
fn mock_public_outputs_(test: &str) {
crate::wasi_tests::init();
crate::wasi_tests::mv_test_(test);
mock(test.to_string(), 7, 16, 17, "private", "private", "public", 1);
}
#(#[test_case(TESTS[N])])*
fn mock_public_inputs_(test: &str) {
crate::wasi_tests::init();
crate::wasi_tests::mv_test_(test);
mock(test.to_string(), 7, 16, 17, "public", "private", "private", 1);
}
#(#[test_case(TESTS[N])])*
fn mock_public_params_(test: &str) {
crate::wasi_tests::init();
crate::wasi_tests::mv_test_(test);
mock(test.to_string(), 7, 16, 17, "private", "public", "private", 1);
}
});
}
};
}
wasi_test_func!();
#[allow(clippy::too_many_arguments)]
// Mock prove (fast, but does not cover some potential issues)
fn mock(
example_name: String,
scale: usize,
bits: usize,
logrows: usize,
input_visibility: &str,
param_visibility: &str,
output_visibility: &str,
batch_size: usize,
) {
let test_dir = TEST_DIR.path().to_str().unwrap();
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"gen-settings",
"-M",
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
&format!(
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
&format!("--bits={}", bits),
&format!("--logrows={}", logrows),
&format!("--scale={}", scale),
&format!("--batch-size={}", batch_size),
&format!("--input-visibility={}", input_visibility),
&format!("--param-visibility={}", param_visibility),
&format!("--output-visibility={}", output_visibility),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"gen-witness",
"-D",
&format!("{}/{}/input.json", test_dir, example_name),
"-M",
&format!("{}/{}/network.onnx", test_dir, example_name),
"-O",
&format!("{}/{}/witness.json", test_dir, example_name),
&format!(
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
])
.status()
.expect("failed to execute process");
assert!(status.success());
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
.args([
"mock",
"-W",
format!("{}/{}/witness.json", test_dir, example_name).as_str(),
"-M",
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
&format!(
"--settings-path={}/{}/settings.json",
test_dir, example_name
),
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
fn build_ezkl_wasm() {
let status = Command::new("cargo")
.args([
"build",
"--release",
"--bin",
"ezkl",
"--target",
"wasm32-wasi",
])
.status()
.expect("failed to execute process");
assert!(status.success());
}
}

Binary file not shown.

File diff suppressed because one or more lines are too long

Binary file not shown.