mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-10 06:48:01 -05:00
chore: bump tract and h2 (#438)
This commit is contained in:
24
.github/workflows/rust.yml
vendored
24
.github/workflows/rust.yml
vendored
@@ -104,26 +104,6 @@ jobs:
|
||||
- name: Run wasm verifier tests
|
||||
run: wasm-pack test --chrome --headless -- -Z build-std="panic_abort,std" --features web
|
||||
|
||||
render-circuit:
|
||||
runs-on: ubuntu-latest-32-cores
|
||||
needs: [build, library-tests, docs]
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions-rs/toolchain@v1
|
||||
with:
|
||||
toolchain: nightly-2023-08-24
|
||||
override: true
|
||||
components: rustfmt, clippy
|
||||
- uses: mwilliamson/setup-wasmtime-action@v2
|
||||
with:
|
||||
wasmtime-version: "3.0.1"
|
||||
- uses: baptiste0928/cargo-install@v1
|
||||
with:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Circuit Render
|
||||
run: cargo nextest run --release --features render --verbose tests::render_circuit_
|
||||
|
||||
tutorial:
|
||||
runs-on: ubuntu-latest
|
||||
needs: [build, library-tests, docs]
|
||||
@@ -448,6 +428,8 @@ jobs:
|
||||
# # now dump the contents of the file into a file called kaggle.json
|
||||
# echo $KAGGLE_API_KEY > /home/ubuntu/.kaggle/kaggle.json
|
||||
# chmod 600 /home/ubuntu/.kaggle/kaggle.json
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: Mean tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_7_expects --no-capture
|
||||
- name: End to end demo tutorial
|
||||
@@ -460,8 +442,6 @@ jobs:
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_9_expects
|
||||
- name: Little transformer tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_8_expects --no-capture
|
||||
- name: Voice tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::voice_
|
||||
- name: Encrypted tutorial
|
||||
run: source .env/bin/activate; cargo nextest run py_tests::tests::run_notebook_::tests_2_expects
|
||||
- name: Hashed tutorial
|
||||
|
||||
55
Cargo.lock
generated
55
Cargo.lock
generated
@@ -1119,7 +1119,7 @@ checksum = "68b0cf012f1230e43cd00ebb729c6bb58707ecfa8ad08b52ef3a4ccd2697fc30"
|
||||
[[package]]
|
||||
name = "ecc"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#64e3a06b6be822fcfd4a117d331c6478a181e11e"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#1c49a066b0d0c38e302f0ff3b1f51c05ffe1486d"
|
||||
dependencies = [
|
||||
"integer",
|
||||
"num-bigint",
|
||||
@@ -1597,7 +1597,7 @@ dependencies = [
|
||||
"getrandom",
|
||||
"halo2_gadgets",
|
||||
"halo2_proofs",
|
||||
"halo2curves 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=2f322219b39b67da8979bf2b014b31145e7872b0)",
|
||||
"halo2curves 0.1.0",
|
||||
"hex",
|
||||
"indicatif",
|
||||
"instant",
|
||||
@@ -2043,14 +2043,14 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_gadgets"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=ac/update-h2curves#b154081f47f70fed41b04d67f7813487814f8dc6"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=ac/fix-gadget-diff#44c2ec6a03c4de060281b63a34c5ed4924f9097f"
|
||||
dependencies = [
|
||||
"arrayvec 0.7.4",
|
||||
"bitvec 1.0.1",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2_proofs",
|
||||
"halo2curves 0.1.0 (registry+https://github.com/rust-lang/crates.io-index)",
|
||||
"halo2curves 0.1.0",
|
||||
"lazy_static",
|
||||
"rand 0.8.5",
|
||||
"subtle",
|
||||
@@ -2060,12 +2060,12 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2_proofs"
|
||||
version = "0.2.0"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=ac/update-h2curves#b154081f47f70fed41b04d67f7813487814f8dc6"
|
||||
source = "git+https://github.com/zkonduit/halo2?branch=ac/fix-gadget-diff#44c2ec6a03c4de060281b63a34c5ed4924f9097f"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"ff",
|
||||
"group",
|
||||
"halo2curves 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=2f322219b39b67da8979bf2b014b31145e7872b0)",
|
||||
"halo2curves 0.1.0",
|
||||
"plotters",
|
||||
"rand_chacha",
|
||||
"rand_core 0.6.4",
|
||||
@@ -2080,25 +2080,6 @@ name = "halo2curves"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6b1142bd1059aacde1b477e0c80c142910f1ceae67fc619311d6a17428007ab"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"ff",
|
||||
"group",
|
||||
"lazy_static",
|
||||
"num-bigint",
|
||||
"num-traits",
|
||||
"pasta_curves",
|
||||
"paste",
|
||||
"rand 0.8.5",
|
||||
"rand_core 0.6.4",
|
||||
"static_assertions",
|
||||
"subtle",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "halo2curves"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/privacy-scaling-explorations/halo2curves?rev=2f322219b39b67da8979bf2b014b31145e7872b0#2f322219b39b67da8979bf2b014b31145e7872b0"
|
||||
dependencies = [
|
||||
"blake2b_simd",
|
||||
"ff",
|
||||
@@ -2137,7 +2118,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "halo2wrong"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#64e3a06b6be822fcfd4a117d331c6478a181e11e"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#1c49a066b0d0c38e302f0ff3b1f51c05ffe1486d"
|
||||
dependencies = [
|
||||
"halo2_proofs",
|
||||
"num-bigint",
|
||||
@@ -2477,7 +2458,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "integer"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#64e3a06b6be822fcfd4a117d331c6478a181e11e"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#1c49a066b0d0c38e302f0ff3b1f51c05ffe1486d"
|
||||
dependencies = [
|
||||
"maingate",
|
||||
"num-bigint",
|
||||
@@ -2720,7 +2701,7 @@ checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
|
||||
[[package]]
|
||||
name = "maingate"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#64e3a06b6be822fcfd4a117d331c6478a181e11e"
|
||||
source = "git+https://github.com/zkonduit/halo2wrong?branch=ac/send-sync-region#1c49a066b0d0c38e302f0ff3b1f51c05ffe1486d"
|
||||
dependencies = [
|
||||
"halo2wrong",
|
||||
"num-bigint",
|
||||
@@ -4500,11 +4481,11 @@ checksum = "a507befe795404456341dfab10cef66ead4c041f62b8b11bbb92bffe5d0953e0"
|
||||
[[package]]
|
||||
name = "snark-verifier"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/send-sync-region#9acce727c7b30225b2d4bd642910558e2ae5fd5c"
|
||||
source = "git+https://github.com/zkonduit/snark-verifier?branch=ac/send-sync-region#52394d8d326648dfa4d3bdc8cdf46a6c578870bb"
|
||||
dependencies = [
|
||||
"ecc",
|
||||
"halo2_proofs",
|
||||
"halo2curves 0.1.0 (git+https://github.com/privacy-scaling-explorations/halo2curves?rev=2f322219b39b67da8979bf2b014b31145e7872b0)",
|
||||
"halo2curves 0.1.0",
|
||||
"hex",
|
||||
"itertools",
|
||||
"lazy_static",
|
||||
@@ -5074,7 +5055,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tract-core"
|
||||
version = "0.20.7-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=f5901a9634ec2d3e3769146d8799e776f0d79f15#f5901a9634ec2d3e3769146d8799e776f0d79f15"
|
||||
source = "git+https://github.com/sonos/tract/?rev=dd39d4e#dd39d4e31f83200dfd0fa77cfb603517f81b5107"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bit-set",
|
||||
@@ -5097,7 +5078,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tract-data"
|
||||
version = "0.20.7-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=f5901a9634ec2d3e3769146d8799e776f0d79f15#f5901a9634ec2d3e3769146d8799e776f0d79f15"
|
||||
source = "git+https://github.com/sonos/tract/?rev=dd39d4e#dd39d4e31f83200dfd0fa77cfb603517f81b5107"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"half 2.2.1",
|
||||
@@ -5116,7 +5097,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tract-hir"
|
||||
version = "0.20.7-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=f5901a9634ec2d3e3769146d8799e776f0d79f15#f5901a9634ec2d3e3769146d8799e776f0d79f15"
|
||||
source = "git+https://github.com/sonos/tract/?rev=dd39d4e#dd39d4e31f83200dfd0fa77cfb603517f81b5107"
|
||||
dependencies = [
|
||||
"derive-new",
|
||||
"log",
|
||||
@@ -5126,7 +5107,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tract-linalg"
|
||||
version = "0.20.7-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=f5901a9634ec2d3e3769146d8799e776f0d79f15#f5901a9634ec2d3e3769146d8799e776f0d79f15"
|
||||
source = "git+https://github.com/sonos/tract/?rev=dd39d4e#dd39d4e31f83200dfd0fa77cfb603517f81b5107"
|
||||
dependencies = [
|
||||
"cc",
|
||||
"derive-new",
|
||||
@@ -5150,7 +5131,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tract-nnef"
|
||||
version = "0.20.7-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=f5901a9634ec2d3e3769146d8799e776f0d79f15#f5901a9634ec2d3e3769146d8799e776f0d79f15"
|
||||
source = "git+https://github.com/sonos/tract/?rev=dd39d4e#dd39d4e31f83200dfd0fa77cfb603517f81b5107"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"flate2",
|
||||
@@ -5164,7 +5145,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tract-onnx"
|
||||
version = "0.20.7-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=f5901a9634ec2d3e3769146d8799e776f0d79f15#f5901a9634ec2d3e3769146d8799e776f0d79f15"
|
||||
source = "git+https://github.com/sonos/tract/?rev=dd39d4e#dd39d4e31f83200dfd0fa77cfb603517f81b5107"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"derive-new",
|
||||
@@ -5181,7 +5162,7 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "tract-onnx-opl"
|
||||
version = "0.20.7-pre"
|
||||
source = "git+https://github.com/sonos/tract/?rev=f5901a9634ec2d3e3769146d8799e776f0d79f15#f5901a9634ec2d3e3769146d8799e776f0d79f15"
|
||||
source = "git+https://github.com/sonos/tract/?rev=dd39d4e#dd39d4e31f83200dfd0fa77cfb603517f81b5107"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
"log",
|
||||
|
||||
10
Cargo.toml
10
Cargo.toml
@@ -13,9 +13,9 @@ crate-type = ["cdylib", "rlib"]
|
||||
|
||||
|
||||
[dependencies]
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", default_features = false, branch= "ac/update-h2curves" }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "ac/update-h2curves", default_features = false, features = ["thread-safe-region"]}
|
||||
halo2curves = { git = "https://github.com/privacy-scaling-explorations/halo2curves", rev = "2f322219b39b67da8979bf2b014b31145e7872b0", package = "halo2curves", features = ["derive_serde"] }
|
||||
halo2_gadgets = { git = "https://github.com/zkonduit/halo2", branch= "ac/fix-gadget-diff", default_features = false }
|
||||
halo2_proofs = { git = "https://github.com/zkonduit/halo2", branch= "ac/fix-gadget-diff", default_features = false, features = ["thread-safe-region"]}
|
||||
halo2curves = { version = "0.1.0" }
|
||||
rand = { version = "0.8", default_features = false }
|
||||
itertools = { version = "0.10.3", default_features = false }
|
||||
clap = { version = "4.3.3", features = ["derive"]}
|
||||
@@ -24,7 +24,7 @@ serde_json = { version = "1.0.97", default_features = false, features = ["float_
|
||||
log = { version = "0.4.17", default_features = false, optional = true }
|
||||
thiserror = { version = "1.0.38", default_features = false }
|
||||
hex = { version = "0.4.3", default_features = false }
|
||||
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", default_features = false, package = "ecc", branch = "ac/send-sync-region"}
|
||||
halo2_wrong_ecc = { git = "https://github.com/zkonduit/halo2wrong", branch = "ac/send-sync-region", default_features=false, package = "ecc" }
|
||||
snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch = "ac/send-sync-region", features=["derive_serde"]}
|
||||
rayon = { version = "1.7.0", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
@@ -50,7 +50,7 @@ tokio = { version = "1.26.0", default_features = false, features = ["macros", "
|
||||
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
|
||||
pyo3-log = { version = "0.8.1", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "f5901a9634ec2d3e3769146d8799e776f0d79f15", default_features = false, optional = true }
|
||||
tract-onnx = { git = "https://github.com/sonos/tract/", rev= "dd39d4e", default_features = false, optional = true }
|
||||
tabled = { version = "0.12.0", optional = true}
|
||||
|
||||
|
||||
|
||||
@@ -395,7 +395,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", num_batches=2)"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
@@ -281,7 +281,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", num_batches=10)"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -474,6 +474,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -481,6 +482,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -511,6 +513,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -569,6 +572,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
|
||||
@@ -264,7 +264,7 @@
|
||||
"with open(cal_path, \"w\") as f:\n",
|
||||
" json.dump(cal_data, f)\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\", num_batches=10)"
|
||||
"res = await ezkl.calibrate_settings(cal_path, model_path, settings_path, \"resources\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -559,6 +559,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -41,6 +42,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -103,6 +105,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -141,6 +144,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -179,6 +183,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -225,6 +230,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -268,6 +274,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -298,6 +305,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -340,6 +348,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -413,6 +422,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -505,6 +515,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": []
|
||||
@@ -562,6 +573,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -597,6 +609,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -615,7 +628,7 @@
|
||||
"res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)\n",
|
||||
"assert res == True\n",
|
||||
"\n",
|
||||
"res = await ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\", num_batches=len(val))\n",
|
||||
"res = await ezkl.calibrate_settings(val_data, model_path, settings_path, \"resources\")\n",
|
||||
"assert res == True\n",
|
||||
"print(\"verified\")\n"
|
||||
]
|
||||
@@ -631,6 +644,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -651,6 +665,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -671,6 +686,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -689,6 +705,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -721,6 +738,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -753,6 +771,7 @@
|
||||
]
|
||||
},
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
@@ -885,7 +904,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.9.15"
|
||||
"version": "3.9.13"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -251,9 +251,6 @@ pub enum Commands {
|
||||
#[arg(long = "target", default_value = "resources")]
|
||||
/// Target for calibration.
|
||||
target: CalibrationTarget,
|
||||
/// Number of calibration batches to run.
|
||||
#[arg(long = "num-batches", default_value = "1")]
|
||||
num_batches: usize,
|
||||
},
|
||||
|
||||
/// Generates a dummy SRS
|
||||
|
||||
@@ -160,8 +160,7 @@ pub async fn run(cli: Cli) -> Result<(), Box<dyn Error>> {
|
||||
settings_path,
|
||||
data,
|
||||
target,
|
||||
num_batches,
|
||||
} => calibrate(model, data, settings_path, target, num_batches).await,
|
||||
} => calibrate(model, data, settings_path, target).await,
|
||||
Commands::GenWitness {
|
||||
data,
|
||||
compiled_model,
|
||||
@@ -518,6 +517,9 @@ pub(crate) fn init_bar(len: u64) -> ProgressBar {
|
||||
pb
|
||||
}
|
||||
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
use colored_json::ToColoredJson;
|
||||
|
||||
/// Calibrate the circuit parameters to a given a dataset
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
#[allow(trivial_casts)]
|
||||
@@ -526,26 +528,25 @@ pub(crate) async fn calibrate(
|
||||
data: PathBuf,
|
||||
settings_path: PathBuf,
|
||||
target: CalibrationTarget,
|
||||
num_batches: usize,
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let data = GraphData::from_path(data)?;
|
||||
// load the pre-generated settings
|
||||
let settings = GraphSettings::load(&settings_path)?;
|
||||
// now retrieve the run args
|
||||
|
||||
let pb = init_bar((2..16).len() as u64);
|
||||
|
||||
pb.set_message("calibrating...");
|
||||
// we load the model to get the input and output shapes
|
||||
let _r = Gag::stdout().unwrap();
|
||||
let model = Model::from_run_args(&settings.run_args, &model_path).unwrap();
|
||||
// drop the gag
|
||||
std::mem::drop(_r);
|
||||
|
||||
let chunks = data
|
||||
.split_into_batches(num_batches, model.graph.input_shapes())
|
||||
.unwrap();
|
||||
let chunks = data.split_into_batches(model.graph.input_shapes()).unwrap();
|
||||
|
||||
debug!("num of calibration batches: {}", chunks.len(),);
|
||||
info!("num of calibration batches: {}", chunks.len());
|
||||
|
||||
let pb = init_bar((2..16).len() as u64);
|
||||
|
||||
pb.set_message("calibrating...");
|
||||
|
||||
let mut found_params: Vec<GraphSettings> = vec![];
|
||||
|
||||
@@ -556,8 +557,9 @@ pub(crate) async fn calibrate(
|
||||
// vec of settings copied chunks.len() times
|
||||
let run_args_iterable = vec![settings.run_args.clone(); chunks.len()];
|
||||
|
||||
// let _r = Gag::stdout().unwrap();
|
||||
// Result<Vec<GraphSettings>, &str>
|
||||
let _r = Gag::stdout().unwrap();
|
||||
let _q = Gag::stderr().unwrap();
|
||||
|
||||
let tasks = chunks
|
||||
.iter()
|
||||
.zip(run_args_iterable)
|
||||
@@ -621,15 +623,23 @@ pub(crate) async fn calibrate(
|
||||
res.push(task);
|
||||
}
|
||||
}
|
||||
|
||||
// drop the gag
|
||||
std::mem::drop(_r);
|
||||
std::mem::drop(_q);
|
||||
|
||||
if let Some(best) = res
|
||||
.into_iter()
|
||||
.max_by_key(|p| (p.run_args.bits, p.run_args.scale))
|
||||
{
|
||||
// pick the one with the largest logrows
|
||||
found_params.push(best);
|
||||
found_params.push(best.clone());
|
||||
info!(
|
||||
"found settings: \n {}",
|
||||
best.as_json()?.to_colored_json_auto()?
|
||||
);
|
||||
}
|
||||
|
||||
// std::mem::drop(_r);
|
||||
pb.inc(1);
|
||||
}
|
||||
|
||||
|
||||
@@ -442,7 +442,6 @@ impl GraphData {
|
||||
///
|
||||
pub fn split_into_batches(
|
||||
&self,
|
||||
batch_size: usize,
|
||||
input_shapes: Vec<Vec<usize>>,
|
||||
) -> Result<Vec<Self>, Box<dyn std::error::Error>> {
|
||||
// split input data into batches
|
||||
@@ -466,15 +465,17 @@ impl GraphData {
|
||||
|
||||
for (i, input) in iterable.iter().enumerate() {
|
||||
// ensure the input is evenly divisible by batch_size
|
||||
if input.len() % batch_size != 0 {
|
||||
let input_size = input_shapes[i].clone().iter().product::<usize>();
|
||||
if input.len() % input_size != 0 {
|
||||
return Err(Box::new(GraphError::InvalidDims(
|
||||
0,
|
||||
"input data length must be evenly divisible by batch size".to_string(),
|
||||
"calibration data length must be evenly divisible by the original input_size"
|
||||
.to_string(),
|
||||
)));
|
||||
}
|
||||
let input_size = input_shapes[i].clone().iter().product::<usize>();
|
||||
let mut batches = vec![];
|
||||
for batch in input.chunks(batch_size * input_size) {
|
||||
for batch in input.chunks(input_size) {
|
||||
batches.push(batch.to_vec());
|
||||
}
|
||||
batched_inputs.push(batches);
|
||||
|
||||
@@ -356,6 +356,21 @@ impl GraphSettings {
|
||||
let res = serde_json::from_str(&data)?;
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
/// Parse an ezkl configuration from a json
|
||||
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_str(arg_json)
|
||||
}
|
||||
}
|
||||
|
||||
/// Configuration for a computational graph / model loaded from a `.onnx` file.
|
||||
|
||||
17
src/lib.rs
17
src/lib.rs
@@ -94,6 +94,23 @@ pub struct RunArgs {
|
||||
pub param_visibility: Visibility,
|
||||
}
|
||||
|
||||
impl RunArgs {
|
||||
/// Export the ezkl configuration as json
|
||||
pub fn as_json(&self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let serialized = match serde_json::to_string(&self) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
return Err(Box::new(e));
|
||||
}
|
||||
};
|
||||
Ok(serialized)
|
||||
}
|
||||
/// Parse an ezkl configuration from a json
|
||||
pub fn from_json(arg_json: &str) -> Result<Self, serde_json::Error> {
|
||||
serde_json::from_str(arg_json)
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single key-value pair
|
||||
fn parse_key_val<T, U>(
|
||||
s: &str,
|
||||
|
||||
@@ -251,7 +251,6 @@ fn gen_settings(
|
||||
model,
|
||||
settings,
|
||||
target,
|
||||
num_batches=None,
|
||||
))]
|
||||
fn calibrate_settings(
|
||||
py: Python,
|
||||
@@ -259,12 +258,10 @@ fn calibrate_settings(
|
||||
model: PathBuf,
|
||||
settings: PathBuf,
|
||||
target: Option<CalibrationTarget>,
|
||||
num_batches: Option<usize>,
|
||||
) -> PyResult<&pyo3::PyAny> {
|
||||
let target = target.unwrap_or(CalibrationTarget::Resources);
|
||||
let num_batches = num_batches.unwrap_or(1);
|
||||
pyo3_asyncio::tokio::future_into_py(py, async move {
|
||||
crate::execute::calibrate(model, data, settings, target, num_batches)
|
||||
crate::execute::calibrate(model, data, settings, target)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let err_str = format!("Failed to calibrate settings: {}", e);
|
||||
|
||||
@@ -359,6 +359,7 @@ mod native_tests {
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
#[ignore]
|
||||
fn render_circuit_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
@@ -2275,14 +2276,7 @@ mod native_tests {
|
||||
|
||||
fn build_ezkl() {
|
||||
let status = Command::new("cargo")
|
||||
.args([
|
||||
"build",
|
||||
"--release",
|
||||
"--features",
|
||||
"render",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
])
|
||||
.args(["build", "--release", "--bin", "ezkl"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
@@ -1 +1 @@
|
||||
[{"type":"function","name":"verify","inputs":[{"internalType":"uint256[3]","name":"pubInputs","type":"uint256[3]"},{"internalType":"bytes","name":"proof","type":"bytes"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"view"}]
|
||||
[{"type":"function","name":"verify","inputs":[{"internalType":"uint256[3]","name":"instances","type":"uint256[3]"},{"internalType":"bytes","name":"proof","type":"bytes"}],"outputs":[{"internalType":"bool","name":"","type":"bool"}],"stateMutability":"view"}]
|
||||
Reference in New Issue
Block a user