Mirror of https://github.com/zkonduit/ezkl.git
synced 2026-01-13 00:08:12 -05:00
Compare commits
6 Commits
0bbab9a020
8197340985
6855ea1947
2ca57bde2c
390de88194
cd91f0af26
Cargo.lock (generated, 175 lines changed)
@@ -504,6 +504,12 @@ dependencies = [
"libc",
]

[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"

[[package]]
name = "ansi-str"
version = "0.8.0"
@@ -897,17 +903,6 @@ dependencies = [
"syn 2.0.53",
]

[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]

[[package]]
name = "auto_impl"
version = "1.2.0"
@@ -1176,14 +1171,30 @@ dependencies = [
]

[[package]]
name = "clap"
version = "2.34.0"
name = "ciborium"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a0610544180c38b88101fecf2dd634b174a62eef6946f84dfc6a7127512b381c"
checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e"
dependencies = [
"bitflags 1.3.2",
"textwrap",
"unicode-width",
"ciborium-io",
"ciborium-ll",
"serde",
]

[[package]]
name = "ciborium-io"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757"

[[package]]
name = "ciborium-ll"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9"
dependencies = [
"ciborium-io",
"half 2.4.1",
]

[[package]]
@@ -1214,7 +1225,7 @@ version = "4.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd79504325bf38b10165b02e89b4347300f855f273c4cb30c4a3209e6583275e"
dependencies = [
"clap 4.5.3",
"clap",
]

[[package]]
@@ -1373,24 +1384,24 @@ dependencies = [

[[package]]
name = "criterion"
version = "0.3.6"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b01d6de93b2b6c65e17c634a26653a29d107b3c98c607c765bf38d041531cd8f"
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
dependencies = [
"atty",
"anes",
"cast",
"clap 2.34.0",
"ciborium",
"clap",
"criterion-plot",
"csv",
"is-terminal",
"itertools 0.10.5",
"lazy_static",
"num-traits",
"once_cell",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_cbor",
"serde_derive",
"serde_json",
"tinytemplate",
@@ -1399,9 +1410,9 @@ dependencies = [

[[package]]
name = "criterion-plot"
version = "0.4.5"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2673cc8207403546f45f5fd319a974b1e6983ad1a3ee7e6041650013be041876"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools 0.10.5",
@@ -1469,27 +1480,6 @@ dependencies = [
"typenum",
]

[[package]]
name = "csv"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe"
dependencies = [
"csv-core",
"itoa",
"ryu",
"serde",
]

[[package]]
name = "csv-core"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70"
dependencies = [
"memchr",
]

[[package]]
name = "cuda-config"
version = "0.1.0"
@@ -1837,10 +1827,9 @@ name = "ezkl"
version = "0.0.0"
dependencies = [
"alloy",
"ark-std 0.3.0",
"bincode",
"chrono",
"clap 4.5.3",
"clap",
"clap_complete",
"colored",
"colored_json",
@@ -1850,7 +1839,6 @@ dependencies = [
"env_logger",
"ethabi",
"foundry-compilers",
"futures-util",
"gag",
"getrandom",
"halo2_gadgets",
@@ -1870,7 +1858,6 @@ dependencies = [
"objc",
"openssl",
"pg_bigdecimal",
"plotters",
"portable-atomic",
"pyo3",
"pyo3-asyncio",
@@ -1893,7 +1880,6 @@ dependencies = [
"thiserror",
"tokio",
"tokio-postgres",
"tokio-util",
"tosubcommand",
"tract-onnx",
"unzip-n",
@@ -2267,10 +2253,11 @@ checksum = "1b43ede17f21864e81be2fa654110bf1e793774238d86ef8555c37e6519c0403"

[[package]]
name = "half"
version = "2.2.1"
version = "2.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02b4af3693f1b705df946e9fe5631932443781d0aabb423b62fcd4d73f6d2fd0"
checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888"
dependencies = [
"cfg-if",
"crunchy",
"num-traits",
]
@@ -2295,7 +2282,7 @@ dependencies = [
[[package]]
name = "halo2_proofs"
version = "0.3.0"
source = "git+https://github.com/zkonduit/halo2?branch=ac/optional-selector-poly#54f54453cf186aa5d89579c4e7663f9a27cfb89a"
source = "git+https://github.com/zkonduit/halo2#8cfca221f53069a0374687654882b99e729041d7"
dependencies = [
"blake2b_simd",
"env_logger",
@@ -2456,15 +2443,6 @@ version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"

[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]

[[package]]
name = "hermit-abi"
version = "0.3.9"
@@ -2784,7 +2762,7 @@ version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b"
dependencies = [
"hermit-abi 0.3.9",
"hermit-abi",
"libc",
"windows-sys 0.52.0",
]
@@ -3309,7 +3287,7 @@ version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi 0.3.9",
"hermit-abi",
"libc",
]

@@ -3410,9 +3388,9 @@ checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"

[[package]]
name = "openssl-src"
version = "300.2.3+3.2.1"
version = "300.3.1+3.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5cff92b6f71555b61bb9315f7c64da3ca43d87531622120fea0195fc761b4843"
checksum = "7259953d42a81bf137fbbd73bd30a8e1914d6dce43c2b90ed575783a22608b91"
dependencies = [
"cc",
]
@@ -3711,9 +3689,9 @@ checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec"

[[package]]
name = "plotters"
version = "0.3.5"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45"
checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3"
dependencies = [
"num-traits",
"plotters-backend",
@@ -3724,15 +3702,15 @@ dependencies = [

[[package]]
name = "plotters-backend"
version = "0.3.5"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609"
checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7"

[[package]]
name = "plotters-svg"
version = "0.3.5"
version = "0.3.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab"
checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705"
dependencies = [
"plotters-backend",
]
@@ -4126,9 +4104,9 @@ checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3"

[[package]]
name = "rayon"
version = "1.9.0"
version = "1.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e4963ed1bc86e4f3ee217022bd855b297cef07fb9eac5dfa1f788b220b49b3bd"
checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa"
dependencies = [
"either",
"rayon-core",
@@ -5155,15 +5133,6 @@ dependencies = [
"syn 1.0.109",
]

[[package]]
name = "textwrap"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d326610f408c7a4eb6f51c37c330e496b08506c9457c9d34287ecc38809fb060"
dependencies = [
"unicode-width",
]

[[package]]
name = "thiserror"
version = "1.0.58"
@@ -5463,8 +5432,8 @@ dependencies = [

[[package]]
name = "tract-core"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50#7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50"
dependencies = [
"anyhow",
"bit-set",
@@ -5487,13 +5456,14 @@ dependencies = [

[[package]]
name = "tract-data"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50#7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50"
dependencies = [
"anyhow",
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half 2.2.1",
"half 2.4.1",
"itertools 0.12.1",
"lazy_static",
"maplit",
@@ -5508,8 +5478,8 @@ dependencies = [

[[package]]
name = "tract-hir"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50#7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50"
dependencies = [
"derive-new",
"log",
@@ -5518,21 +5488,22 @@ dependencies = [

[[package]]
name = "tract-linalg"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50#7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50"
dependencies = [
"cc",
"derive-new",
"downcast-rs",
"dyn-clone",
"dyn-hash",
"half 2.2.1",
"half 2.4.1",
"lazy_static",
"liquid",
"liquid-core",
"log",
"num-traits",
"paste",
"rayon",
"scan_fmt",
"smallvec",
"time",
@@ -5543,8 +5514,8 @@ dependencies = [

[[package]]
name = "tract-nnef"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50#7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50"
dependencies = [
"byteorder",
"flate2",
@@ -5557,8 +5528,8 @@ dependencies = [

[[package]]
name = "tract-onnx"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50#7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50"
dependencies = [
"bytes",
"derive-new",
@@ -5574,8 +5545,8 @@ dependencies = [

[[package]]
name = "tract-onnx-opl"
version = "0.21.5-pre"
source = "git+https://github.com/sonos/tract/?rev=05ebf550aa9922b221af4635c21a67a8d2af12a9#05ebf550aa9922b221af4635c21a67a8d2af12a9"
version = "0.21.6-pre"
source = "git+https://github.com/sonos/tract/?rev=7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50#7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50"
dependencies = [
"getrandom",
"log",
Cargo.toml (16 lines changed)
@@ -39,7 +39,6 @@ snark-verifier = { git = "https://github.com/zkonduit/snark-verifier", branch =
halo2_solidity_verifier = { git = "https://github.com/alexander-camuto/halo2-solidity-verifier", branch = "main" }
maybe-rayon = { version = "0.1.1", default_features = false }
bincode = { version = "1.3.3", default_features = false }
ark-std = { version = "^0.3.0", default-features = false }
unzip-n = "0.1.2"
num = "0.4.1"
portable-atomic = "1.6.0"
@@ -63,16 +62,13 @@ reqwest = { version = "0.12.4", default-features = false, features = [
openssl = { version = "0.10.55", features = ["vendored"] }
tokio-postgres = "0.7.10"
pg_bigdecimal = "0.1.5"
futures-util = "0.3.30"
lazy_static = "1.4.0"
colored_json = { version = "3.0.1", default_features = false, optional = true }
plotters = { version = "0.3.0", default_features = false, optional = true }
regex = { version = "1", default_features = false }
tokio = { version = "1.35", default_features = false, features = [
"macros",
"rt-multi-thread"
] }
tokio-util = { version = "0.7.9", features = ["codec"] }
pyo3 = { version = "0.21.2", features = [
"extension-module",
"abi3-py37",
@@ -83,9 +79,8 @@ pyo3-asyncio = { git = "https://github.com/jopemachine/pyo3-asyncio/", branch="m
"tokio-runtime",
], default_features = false, optional = true }
pyo3-log = { version = "0.10.0", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "05ebf550aa9922b221af4635c21a67a8d2af12a9", default_features = false, optional = true }
tract-onnx = { git = "https://github.com/sonos/tract/", rev = "7bf303b2ae9bddd5fa6951ae95848c0d52fb7f50", default_features = false, optional = true }
tabled = { version = "0.12.0", optional = true }

objc = { version = "0.2.4", optional = true }

@@ -108,8 +103,10 @@ console_error_panic_hook = "0.1.7"
wasm-bindgen-console-logger = "0.1.1"

[target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies]
criterion = { version = "0.5.1", features = ["html_reports"] }

[dev-dependencies]
criterion = { version = "0.3", features = ["html_reports"] }
tempfile = "3.3.0"
lazy_static = "1.4.0"
mnist = "0.5"
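The dev-dependency bump from criterion 0.3 to 0.5.1, now scoped to non-wasm targets, tracks the API reshuffle visible in Cargo.lock above (atty, clap 2, and csv replaced by anes, ciborium, and is-terminal); Criterion's bench harness cannot run on wasm32-unknown-unknown, hence the cfg gate. A minimal sketch against the 0.5 API; the bench name and `dot_product` helper are illustrative stand-ins, not ezkl code:

```
// benches/dot.rs (hypothetical): Criterion 0.5 benchmark skeleton.
use criterion::{black_box, criterion_group, criterion_main, Criterion};

// Stand-in workload; any pure function works here.
fn dot_product(a: &[u64], b: &[u64]) -> u64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

fn bench_dot(c: &mut Criterion) {
    let a = vec![3u64; 1024];
    let b = vec![7u64; 1024];
    // black_box keeps the optimizer from deleting the measured work.
    c.bench_function("dot 1024", |bench| {
        bench.iter(|| dot_product(black_box(&a), black_box(&b)))
    });
}

criterion_group!(benches, bench_dot);
criterion_main!(benches);
```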
@@ -180,7 +177,7 @@ required-features = ["ezkl"]

[features]
web = ["wasm-bindgen-rayon"]
default = ["ezkl", "mv-lookup", "no-banner"]
default = ["ezkl", "mv-lookup", "no-banner", "parallel-poly-read"]
onnx = ["dep:tract-onnx"]
python-bindings = ["pyo3", "pyo3-log", "pyo3-asyncio"]
ezkl = [
@@ -194,6 +191,7 @@ ezkl = [
"colored_json",
"halo2_proofs/circuit-params",
]
parallel-poly-read = ["halo2_proofs/parallel-poly-read"]
mv-lookup = [
"halo2_proofs/mv-lookup",
"snark-verifier/mv-lookup",
@@ -211,7 +209,7 @@ metal = ["dep:metal", "dep:objc"]
icicle = { git = "https://github.com/ingonyama-zk/icicle?rev=45b00fb", package = "icicle", branch = "fix/vhnat/ezkl-build-fix" }

[patch.'https://github.com/zkonduit/halo2']
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs", branch = "ac/optional-selector-poly" }
halo2_proofs = { git = "https://github.com/zkonduit/halo2", package = "halo2_proofs" }

[profile.release]
@@ -1,4 +1,4 @@
ezkl==0.0.0
ezkl==11.6.1
sphinx
sphinx-rtd-theme
sphinxcontrib-napoleon
@@ -1,7 +1,7 @@
import ezkl

project = 'ezkl'
release = '0.0.0'
release = '11.6.1'
version = release
@@ -39,7 +39,7 @@
"import json\n",
"import numpy as np\n",
"from sklearn.svm import SVC\n",
"import sk2torch\n",
"from hummingbird.ml import convert\n",
"import torch\n",
"import ezkl\n",
"import os\n",
@@ -59,11 +59,11 @@
"# Train an SVM on the data and wrap it in PyTorch.\n",
"sk_model = SVC(probability=True)\n",
"sk_model.fit(xs, ys)\n",
"model = sk2torch.wrap(sk_model)\n",
"\n",
"model = convert(sk_model, \"torch\").model\n",
"\n",
"\n",
"\n",
"model\n",
"\n"
]
},
@@ -84,33 +84,6 @@
"data_path = os.path.join('input.json')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7f0ca328",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"# Create a coordinate grid to compute a vector field on.\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"\n",
"\n",
"# Compute the gradients of the SVM output.\n",
"outputs = model.predict_proba(grid_xs)[:, 1]\n",
"(input_grads,) = torch.autograd.grad(outputs.sum(), (grid_xs,))\n",
"\n",
"\n",
"# Create a quiver plot of the vector field.\n",
"plt.quiver(\n",
" grid_xs[:, 0].detach().numpy(),\n",
" grid_xs[:, 1].detach().numpy(),\n",
" input_grads[:, 0].detach().numpy(),\n",
" input_grads[:, 1].detach().numpy(),\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -119,14 +92,14 @@
"outputs": [],
"source": [
"\n",
"\n",
"spaced = np.linspace(-2, 2, num=25)\n",
"grid_xs = torch.tensor([[x, y] for x in spaced for y in spaced], requires_grad=True)\n",
"# export to onnx format\n",
"# !!!!!!!!!!!!!!!!! This will flash a warning but it is fine !!!!!!!!!!!!!!!!!!!!!\n",
"\n",
"# Input to the model\n",
"shape = xs.shape[1:]\n",
"x = grid_xs[0:1]\n",
"torch_out = model.predict(x)\n",
"# Export the model\n",
"torch.onnx.export(model, # model being run\n",
" # model input (or a tuple for multiple inputs)\n",
@@ -143,9 +116,7 @@
"\n",
"d = ((x).detach().numpy()).reshape([-1]).tolist()\n",
"\n",
"data = dict(input_shapes=[shape],\n",
" input_data=[d],\n",
" output_data=[o.reshape([-1]).tolist() for o in torch_out])\n",
"data = dict(input_data=[d])\n",
"\n",
"# Serialize data into file:\n",
"json.dump(data, open(\"input.json\", 'w'))\n"
@@ -167,6 +138,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "0bee4d7f",
"metadata": {},
"outputs": [],
"source": [
@@ -220,7 +192,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": null,
"id": "b1c561a8",
"metadata": {},
"outputs": [],
@@ -441,9 +413,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
}
examples/onnx/lenet_5/input.json (new file, 1 line)
File diff suppressed because one or more lines are too long

examples/onnx/lenet_5/network.onnx (new binary file)
Binary file not shown.
@@ -250,6 +250,10 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
region: &mut RegionCtx<F>,
values: &[ValTensor<F>; 2],
) -> Result<ValTensor<F>, CircuitError> {
if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}

region.flush()?;
// time this entire function run
let global_start = instant::Instant::now();
@@ -257,12 +261,17 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let mut values = values.clone();

// this section has been optimized to death, don't mess with it
let mut removal_indices = values[0].get_const_zero_indices()?;
let second_zero_indices = values[1].get_const_zero_indices()?;
let mut removal_indices = values[0].get_const_zero_indices();
let second_zero_indices = values[1].get_const_zero_indices();
removal_indices.extend(second_zero_indices);
removal_indices.par_sort_unstable();
removal_indices.dedup();

// if empty return a const
if removal_indices.len() == values[0].len() {
return Ok(create_zero_tensor(1));
}

// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
values[1].remove_indices(&mut removal_indices, true)?;
@@ -270,15 +279,6 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);

if values[0].len() != values[1].len() {
return Err(TensorError::DimMismatch("dot".to_string()).into());
}

// if empty return a const
if values[0].is_empty() && values[1].is_empty() {
return Ok(create_zero_tensor(1));
}

let start = instant::Instant::now();
let mut inputs = vec![];
let block_width = config.custom_gates.output.num_inner_cols();
@@ -343,7 +343,7 @@ pub fn dot<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
.collect::<Result<Vec<_>, CircuitError>>()?;
}

let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;

region.increment(assigned_len);
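The reordering above also tightens the pruning: the length check now runs before any layout work, and positions where either operand is a known constant zero are dropped up front, since those terms cannot contribute to the dot product. A standalone sketch of the idea on plain vectors, not the ezkl tensor types:

```
// Sketch: drop index i whenever a[i] or b[i] is a known zero; the dot
// product over the surviving pairs is unchanged.
fn prune_zero_terms(a: &[i64], b: &[i64]) -> (Vec<i64>, Vec<i64>) {
    a.iter()
        .zip(b)
        .filter(|(x, y)| **x != 0 && **y != 0)
        .map(|(x, y)| (*x, *y))
        .unzip()
}

fn main() {
    let (a, b) = prune_zero_terms(&[1, 0, 3], &[4, 5, 0]);
    // only index 0 survives; 1 * 4 is the whole dot product either way
    assert_eq!((a, b), (vec![1], vec![4]));
}
```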
@@ -1779,12 +1779,7 @@ pub fn sum<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>(
let mut values = values.clone();

// this section has been optimized to death, don't mess with it
let mut removal_indices = values[0].get_const_zero_indices()?;
removal_indices.par_sort_unstable();
removal_indices.dedup();

// is already sorted
values[0].remove_indices(&mut removal_indices, true)?;
values[0].remove_const_zero_values();

let elapsed = global_start.elapsed();
trace!("filtering const zero indices took: {:?}", elapsed);
@@ -1841,7 +1836,7 @@
}
}

let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;

region.increment(assigned_len);

@@ -1884,7 +1879,7 @@ pub fn prod<F: PrimeField + TensorType + PartialOrd + std::hash::Hash + IntoI64>
let global_start = instant::Instant::now();

// this section has been optimized to death, don't mess with it
let removal_indices = values[0].get_const_zero_indices()?;
let removal_indices = values[0].get_const_zero_indices();

let elapsed = global_start.elapsed();
trace!("finding const zero indices took: {:?}", elapsed);
@@ -1945,7 +1940,7 @@
.collect::<Result<Vec<_>, CircuitError>>()?;
}

let last_elem = output.get_slice(&[output.len() - 1..output.len()])?;
let last_elem = output.last()?;

region.increment(assigned_len);

@@ -2256,22 +2251,22 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let orig_lhs = lhs.clone();
let orig_rhs = rhs.clone();

// get indices of zeros
let first_zero_indices = lhs.get_const_zero_indices()?;
let second_zero_indices = rhs.get_const_zero_indices()?;
let mut removal_indices = match op {
let start = instant::Instant::now();
let first_zero_indices = HashSet::from_iter(lhs.get_const_zero_indices());
let second_zero_indices = HashSet::from_iter(rhs.get_const_zero_indices());

let removal_indices = match op {
BaseOp::Add | BaseOp::Mult => {
let mut removal_indices = first_zero_indices.clone();
removal_indices.extend(second_zero_indices.clone());
removal_indices
// join the zero indices
first_zero_indices
.union(&second_zero_indices)
.cloned()
.collect()
}
BaseOp::Sub => second_zero_indices.clone(),
_ => return Err(CircuitError::UnsupportedOp),
};
removal_indices.dedup();

let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
trace!("setting up indices took {:?}", start.elapsed());

if lhs.len() != rhs.len() {
return Err(CircuitError::DimMismatch(format!(
@@ -2280,20 +2275,19 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
)));
}

let mut inputs = vec![];
for (i, input) in [lhs.clone(), rhs.clone()].iter().enumerate() {
let inp = {
let inputs = [lhs.clone(), rhs.clone()]
.iter()
.enumerate()
.map(|(i, input)| {
let res = region.assign_with_omissions(
&config.custom_gates.inputs[i],
input,
removal_indices_ptr,
&removal_indices,
)?;

res.get_inner()?
};

inputs.push(inp);
}
Ok(res.get_inner()?)
})
.collect::<Result<Vec<_>, CircuitError>>()?;

// Now we can assign the dot product
// time the calc
@@ -2308,15 +2302,20 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
error!("{}", e);
halo2_proofs::plonk::Error::Synthesis
})?;
let elapsed = start.elapsed();
trace!("pairwise {} calc took {:?}", op.as_str(), start.elapsed());

let assigned_len = inputs[0].len() - removal_indices.len();
let start = instant::Instant::now();
let assigned_len = op_result.len() - removal_indices.len();
let mut output = region.assign_with_omissions(
&config.custom_gates.output,
&op_result.into(),
removal_indices_ptr,
&removal_indices,
)?;
trace!("pairwise {} calc took {:?}", op.as_str(), elapsed);
trace!(
"pairwise {} input assign took {:?}",
op.as_str(),
start.elapsed()
);

// Enable the selectors
if !region.is_dummy() {
@@ -2337,16 +2336,11 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
let a_tensor = orig_lhs.get_inner_tensor()?;
let b_tensor = orig_rhs.get_inner_tensor()?;

let first_zero_indices: HashSet<&usize> = HashSet::from_iter(first_zero_indices.iter());
let second_zero_indices: HashSet<&usize> = HashSet::from_iter(second_zero_indices.iter());

trace!("setting up indices took {:?}", start.elapsed());

// infill the zero indices with the correct values from values[0] or values[1]
if !removal_indices_ptr.is_empty() {
if !removal_indices.is_empty() {
output
.get_inner_tensor_mut()?
.par_enum_map_mut_filtered(removal_indices_ptr, |i| {
.par_enum_map_mut_filtered(&removal_indices, |i| {
let val = match op {
BaseOp::Add => {
let a_is_null = first_zero_indices.contains(&i);
@@ -2386,6 +2380,7 @@ pub(crate) fn pairwise<F: PrimeField + TensorType + PartialOrd + std::hash::Hash
end,
region.row()
);
trace!("----------------------------");

Ok(output)
}
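Switching the removal bookkeeping from a sorted, deduped Vec to HashSets makes the Add/Mult case a set union and the later `contains` probes constant-time. A small sketch of the same dispatch on bare sets; the function name is hypothetical:

```
use std::collections::HashSet;

// Which positions can be skipped? For Sub only right-hand zeros (a - 0 = a);
// for Add and Mult a zero on either side lets the assignment be elided and
// infilled afterwards, so the two sets are unioned.
fn removal_indices(
    first_zeros: &HashSet<usize>,
    second_zeros: &HashSet<usize>,
    is_sub: bool,
) -> HashSet<usize> {
    if is_sub {
        second_zeros.clone()
    } else {
        first_zeros.union(second_zeros).cloned().collect()
    }
}

fn main() {
    let a: HashSet<usize> = [1, 2].into_iter().collect();
    let b: HashSet<usize> = [2, 3].into_iter().collect();
    assert_eq!(removal_indices(&a, &b, false).len(), 3);
    assert_eq!(removal_indices(&a, &b, true), b);
}
```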
@@ -3777,7 +3772,7 @@ pub(crate) fn boolean_identity<
values: &[ValTensor<F>; 1],
assign: bool,
) -> Result<ValTensor<F>, CircuitError> {
let output = if assign || !values[0].get_const_indices()?.is_empty() {
let output = if assign || !values[0].get_const_indices().is_empty() {
// get zero constants indices
let output = region.assign(&config.custom_gates.output, &values[0])?;
region.increment(output.len());
@@ -3942,11 +3937,10 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::

let x = values[0].clone();

let removal_indices = values[0].get_const_indices()?;
let removal_indices: HashSet<&usize> = HashSet::from_iter(removal_indices.iter());
let removal_indices_ptr = &removal_indices;
let removal_indices = values[0].get_const_indices();
let removal_indices: HashSet<usize> = HashSet::from_iter(removal_indices);

let w = region.assign_with_omissions(&config.static_lookups.input, &x, removal_indices_ptr)?;
let w = region.assign_with_omissions(&config.static_lookups.input, &x, &removal_indices)?;

let output = w.get_inner_tensor()?.par_enum_map(|i, e| {
Ok::<_, TensorError>(if let Some(f) = e.get_felt_eval() {
@@ -3964,7 +3958,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
let mut output = region.assign_with_omissions(
&config.static_lookups.output,
&output.into(),
removal_indices_ptr,
&removal_indices,
)?;

let is_dummy = region.is_dummy();
@@ -3994,11 +3988,7 @@ pub(crate) fn nonlinearity<F: PrimeField + TensorType + PartialOrd + std::hash::
})?
.into();

region.assign_with_omissions(
&config.static_lookups.index,
&table_index,
removal_indices_ptr,
)?;
region.assign_with_omissions(&config.static_lookups.index, &table_index, &removal_indices)?;

if !is_dummy {
(0..assigned_len)
@@ -9,6 +9,8 @@ use halo2_proofs::{
plonk::{Error, Selector},
};
use halo2curves::ff::PrimeField;
use itertools::Itertools;
use maybe_rayon::iter::ParallelExtend;
use portable_atomic::AtomicI64 as AtomicInt;
use std::{
cell::RefCell,
@@ -515,18 +517,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.linear_coord,
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
@@ -542,18 +544,18 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign(
Ok(var.assign(
&mut region.borrow_mut(),
self.combined_dynamic_shuffle_coord(),
values,
&mut self.assigned_constants,
)
)?)
} else {
if !values.is_instance() {
let values_map = values.create_constants_map_iterator();
self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);
}
Ok(values.clone())
}
@@ -564,7 +566,7 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
) -> Result<ValTensor<F>, Error> {
) -> Result<ValTensor<F>, CircuitError> {
self.assign_dynamic_lookup(var, values)
}

@@ -573,27 +575,24 @@ impl<'a, F: PrimeField + TensorType + PartialOrd + std::hash::Hash> RegionCtx<'a
&mut self,
var: &VarTensor,
values: &ValTensor<F>,
ommissions: &HashSet<&usize>,
) -> Result<ValTensor<F>, Error> {
ommissions: &HashSet<usize>,
) -> Result<ValTensor<F>, CircuitError> {
if let Some(region) = &self.region {
var.assign_with_omissions(
Ok(var.assign_with_omissions(
&mut region.borrow_mut(),
self.linear_coord,
values,
ommissions,
&mut self.assigned_constants,
)
)?)
} else {
let inner_tensor = values.get_inner_tensor().unwrap();
let mut values_map = values.create_constants_map();
let mut values_clone = values.clone();
let mut indices = ommissions.clone().into_iter().collect_vec();
values_clone.remove_indices(&mut indices, false)?;

for o in ommissions {
if let ValType::Constant(value) = inner_tensor.get_flat_index(**o) {
values_map.remove(&value);
}
}
let values_map = values.create_constants_map();

self.assigned_constants.extend(values_map);
self.assigned_constants.par_extend(values_map);

Ok(values.clone())
}
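Note the omission sets also change element type, from `HashSet<&usize>` to `HashSet<usize>`. Since `usize` is `Copy`, owning the keys costs nothing extra and removes both the double dereference (`**o`) and the `removal_indices_ptr` plumbing the old code needed. A tiny illustration of the difference:

```
use std::collections::HashSet;

// Owned keys: build once with `copied()`, then probe with a plain &usize.
fn demo(indices: &[usize]) {
    let owned: HashSet<usize> = indices.iter().copied().collect();
    assert!(owned.contains(&indices[0]));

    // The old shape forced callers to juggle references to references.
    let borrowed: HashSet<&usize> = indices.iter().collect();
    assert!(borrowed.contains(&&indices[0]));
}

fn main() {
    demo(&[4, 8, 15]);
}
```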
@@ -379,9 +379,9 @@ pub enum Commands {
#[arg(long = "target", default_value = DEFAULT_CALIBRATION_TARGET, value_hint = clap::ValueHint::Other)]
/// Target for calibration. Set to "resources" to optimize for computational resource. Otherwise, set to "accuracy" to optimize for accuracy.
target: CalibrationTarget,
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be 2^k * lookup_safety_margin. larger = safer but slower
/// the lookup safety margin to use for calibration. if the max lookup is 2^k, then the max lookup will be ceil(2^k * lookup_safety_margin). larger = safer but slower
#[arg(long, default_value = DEFAULT_LOOKUP_SAFETY_MARGIN, value_hint = clap::ValueHint::Other)]
lookup_safety_margin: i64,
lookup_safety_margin: f64,
/// Optional scales to specifically try for calibration. Example, --scales 0,4
#[arg(long, value_delimiter = ',', allow_hyphen_values = true, value_hint = clap::ValueHint::Other)]
scales: Option<Vec<crate::Scale>>,

@@ -1013,7 +1013,7 @@ pub(crate) async fn calibrate(
data: PathBuf,
settings_path: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: i64,
lookup_safety_margin: f64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
only_range_check_rebase: bool,
@@ -1502,10 +1502,10 @@ pub(crate) async fn create_evm_vk(
#[cfg(not(target_arch = "wasm32"))]
pub(crate) async fn create_evm_data_attestation(
settings_path: PathBuf,
_sol_code_path: PathBuf,
_abi_path: PathBuf,
_input: PathBuf,
_witness: Option<PathBuf>,
sol_code_path: PathBuf,
abi_path: PathBuf,
input: PathBuf,
witness: Option<PathBuf>,
) -> Result<String, EZKLError> {
#[allow(unused_imports)]
use crate::graph::{DataSource, VarVisibility};
@@ -1517,7 +1517,7 @@ pub(crate) async fn create_evm_data_attestation(
trace!("params computed");

// if input is not provided, we just instantiate dummy input data
let data = GraphData::from_path(_input).unwrap_or(GraphData::new(DataSource::File(vec![])));
let data = GraphData::from_path(input).unwrap_or(GraphData::new(DataSource::File(vec![])));

let output_data = if let Some(DataSource::OnChain(source)) = data.output_data {
if visibility.output.is_private() {
@@ -1552,7 +1552,7 @@ pub(crate) async fn create_evm_data_attestation(
|| settings.run_args.output_visibility == Visibility::KZGCommit
|| settings.run_args.param_visibility == Visibility::KZGCommit
{
let witness = GraphWitness::from_path(_witness.unwrap_or(DEFAULT_WITNESS.into()))?;
let witness = GraphWitness::from_path(witness.unwrap_or(DEFAULT_WITNESS.into()))?;
let commitments = witness.get_polycommitments();
let proof_first_bytes = get_proof_commitments::<
KZGCommitmentScheme<Bn256>,
@@ -1566,12 +1566,12 @@ pub(crate) async fn create_evm_data_attestation(
};

let output = fix_da_sol(input_data, output_data, commitment_bytes)?;
let mut f = File::create(_sol_code_path.clone())?;
let mut f = File::create(sol_code_path.clone())?;
let _ = f.write(output.as_bytes());
// fetch abi of the contract
let (abi, _, _) = get_contract_artifacts(_sol_code_path, "DataAttestation", 0).await?;
let (abi, _, _) = get_contract_artifacts(sol_code_path, "DataAttestation", 0).await?;
// save abi to file
serde_json::to_writer(std::fs::File::create(_abi_path)?, &abi)?;
serde_json::to_writer(std::fs::File::create(abi_path)?, &abi)?;

Ok(String::new())
}
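The parameter renames above just drop the leading underscores now that `sol_code_path`, `abi_path`, `input`, and `witness` are read on every path; in Rust the underscore prefix only silences the unused-variable lint, so keeping it on live parameters is misleading. For illustration:

```
// `_input` tells the compiler (and the reader) the value is deliberately
// unused; once it is consumed, the idiomatic fix is to drop the underscore.
fn before(_input: u32) {}

fn after(input: u32) -> u32 {
    input + 1
}

fn main() {
    before(1);
    assert_eq!(after(1), 2);
}
```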
@@ -16,7 +16,7 @@ pub fn i64_to_felt<F: PrimeField>(x: i64) -> F {
if x >= 0 {
F::from_u128(x as u128)
} else {
-F::from_u128((-x) as u128)
-F::from_u128(x.saturating_neg() as u128)
}
}
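The `i64_to_felt` fix targets the one input where plain negation is undefined: `-i64::MIN` does not fit in an `i64`, so `-x` panics in debug builds and wraps in release. `saturating_neg` clamps to `i64::MAX` instead, at the cost of mapping that single extreme value one short of its true magnitude. A check of the edge case:

```
fn main() {
    let x = i64::MIN;
    assert_eq!(x.checked_neg(), None); // plain `-x` would overflow here
    assert_eq!(x.saturating_neg(), i64::MAX); // clamped, no panic or wrap
}
```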
|
||||
@@ -50,7 +50,7 @@ pub enum GraphError {
|
||||
/// Tract error
|
||||
#[cfg(not(all(target_arch = "wasm32", target_os = "unknown")))]
|
||||
#[error("[tract] {0}")]
|
||||
TractError(#[from] tract_onnx::tract_core::anyhow::Error),
|
||||
TractError(#[from] tract_onnx::prelude::TractError),
|
||||
/// Packing exponent is too large
|
||||
#[error("largest packing exponent exceeds max. try reducing the scale")]
|
||||
PackingExponent,
|
||||
|
||||
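Narrowing the `#[from]` source from tract's re-exported `anyhow::Error` to the concrete `tract_onnx::prelude::TractError` keeps `?`-conversion working while tightening the type. A minimal sketch of what `#[from]` derives, with stand-in error types rather than the real tract ones:

```
use thiserror::Error;

#[derive(Debug, Error)]
#[error("inner failure")]
struct Inner;

#[derive(Debug, Error)]
enum MyError {
    // #[from] derives `impl From<Inner> for MyError`, which is what lets
    // `?` convert the source error automatically.
    #[error("[tract] {0}")]
    Tract(#[from] Inner),
}

fn run() -> Result<(), MyError> {
    Err(Inner)?
}

fn main() {
    assert!(matches!(run(), Err(MyError::Tract(_))));
}
```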
@@ -1034,10 +1034,10 @@ impl GraphCircuit {
Ok(data)
}

fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: i64) -> Range {
fn calc_safe_lookup_range(min_max_lookup: Range, lookup_safety_margin: f64) -> Range {
(
lookup_safety_margin * min_max_lookup.0,
lookup_safety_margin * min_max_lookup.1,
(lookup_safety_margin * min_max_lookup.0 as f64).floor() as i64,
(lookup_safety_margin * min_max_lookup.1 as f64).ceil() as i64,
)
}
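With the margin now fractional, the safe range is computed by scaling each bound and rounding outward, floor on the minimum and ceil on the maximum, so the padded range can only grow relative to the observed lookup range. A sketch mirroring the new arithmetic on a bare tuple type:

```
// Outward rounding: fractional results always round away from the range,
// so the margin never under-covers the observed lookup bounds.
fn calc_safe_lookup_range(min_max: (i64, i64), margin: f64) -> (i64, i64) {
    (
        (margin * min_max.0 as f64).floor() as i64,
        (margin * min_max.1 as f64).ceil() as i64,
    )
}

fn main() {
    assert_eq!(calc_safe_lookup_range((-100, 100), 1.2), (-120, 120));
    assert_eq!(calc_safe_lookup_range((-5, 5), 1.5), (-8, 8));
}
```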
@@ -1070,7 +1070,7 @@ impl GraphCircuit {
min_max_lookup: Range,
max_range_size: i64,
max_logrows: Option<u32>,
lookup_safety_margin: i64,
lookup_safety_margin: f64,
) -> Result<(), GraphError> {
// load the max logrows
let max_logrows = max_logrows.unwrap_or(MAX_PUBLIC_SRS);
@@ -1080,9 +1080,13 @@ impl GraphCircuit {

let safe_lookup_range = Self::calc_safe_lookup_range(min_max_lookup, lookup_safety_margin);

let lookup_size = (safe_lookup_range.1 - safe_lookup_range.0).abs();
// check if subtraction overflows

let lookup_size =
(safe_lookup_range.1.saturating_sub(safe_lookup_range.0)).saturating_abs();
// check if has overflowed max lookup input
if lookup_size > MAX_LOOKUP_ABS / lookup_safety_margin {

if lookup_size > (MAX_LOOKUP_ABS as f64 / lookup_safety_margin).floor() as i64 {
return Err(GraphError::LookupRangeTooLarge(
lookup_size.unsigned_abs() as usize
));
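The size computation is hardened in the same hunk: with bounds near the `i64` extremes, the plain difference `safe_lookup_range.1 - safe_lookup_range.0` can overflow, so the new code saturates the subtraction and the subsequent `abs`. A demonstration of the failure mode the comment warns about:

```
fn main() {
    let lo = i64::MIN / 2 - 1;
    let hi = i64::MAX / 2 + 1;
    assert_eq!(hi.checked_sub(lo), None); // `hi - lo` would overflow
    assert_eq!(hi.saturating_sub(lo), i64::MAX); // clamped instead
}
```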
@@ -85,6 +85,34 @@ pub fn multiplier_to_scale(mult: f64) -> crate::Scale {
mult.log2().round() as crate::Scale
}

#[cfg(not(target_arch = "wasm32"))]
/// Extracts the padding from an onnx node.
pub fn extract_padding(
pool_spec: &PoolSpec,
num_dims: usize,
) -> Result<Vec<(usize, usize)>, GraphError> {
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
PaddingSpec::Valid => vec![(0, 0); num_dims],
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
Ok(padding)
}

#[cfg(not(target_arch = "wasm32"))]
/// Extracts the strides from an onnx node.
pub fn extract_strides(pool_spec: &PoolSpec) -> Result<Vec<usize>, GraphError> {
Ok(pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?
.to_vec())
}

/// Gets the shape of an onnx node's outlets.
#[cfg(not(target_arch = "wasm32"))]
pub fn node_output_shapes(
@@ -255,6 +283,11 @@ pub fn new_op_from_onnx(
.flat_map(|x| x.out_scales())
.collect::<Vec<_>>();

let input_dims = inputs
.iter()
.flat_map(|x| x.out_dims())
.collect::<Vec<_>>();

let mut replace_const = |scale: crate::Scale,
index: usize,
default_op: SupportedOp|
@@ -1073,18 +1106,8 @@
));
}

let stride = pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
let kernel_shape = &pool_spec.kernel_shape;

SupportedOp::Hybrid(HybridOp::MaxPool {
@@ -1151,21 +1174,10 @@
));
}

let stride = match conv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let pool_spec = &conv_node.pool_spec;

let padding = match &conv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;

// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
@@ -1214,21 +1226,10 @@
));
}

let stride = match deconv_node.pool_spec.strides.clone() {
Some(s) => s.to_vec(),
None => {
return Err(GraphError::MissingParams("strides".to_string()));
}
};
let padding = match &deconv_node.pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let pool_spec = &deconv_node.pool_spec;

let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;
// if bias exists then rescale it to the input + kernel scale
if input_scales.len() == 3 {
let bias_scale = input_scales[2];
@@ -1339,18 +1340,8 @@
));
}

let stride = pool_spec
.strides
.clone()
.ok_or(GraphError::MissingParams("stride".to_string()))?;
let padding = match &pool_spec.padding {
PaddingSpec::Explicit(b, a) | PaddingSpec::ExplicitOnnxPool(b, a, _) => {
b.iter().zip(a.iter()).map(|(b, a)| (*b, *a)).collect()
}
_ => {
return Err(GraphError::MissingParams("padding".to_string()));
}
};
let stride = extract_strides(pool_spec)?;
let padding = extract_padding(pool_spec, input_dims[0].len())?;

SupportedOp::Hybrid(HybridOp::SumPool {
padding,
@@ -887,7 +887,7 @@ fn calibrate_settings(
model: PathBuf,
settings: PathBuf,
target: CalibrationTarget,
lookup_safety_margin: i64,
lookup_safety_margin: f64,
scales: Option<Vec<crate::Scale>>,
scale_rebase_multiplier: Vec<u32>,
max_logrows: Option<u32>,
@@ -1491,7 +1491,7 @@ fn encode_evm_calldata<'a>(
/// The path to the SRS file
///
/// render_vk_separately: bool
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create-evm-vk command
/// Whether the verifier key should be rendered as a separate contract. We recommend disabling selector compression if this is enabled. To save the verifier key as a separate contract, set this to true and then call the create_evm_vk command
///
/// Returns
/// -------
@@ -1533,6 +1533,56 @@ fn create_evm_verifier(
})
}

/// Creates an EVM verifier key. This command should be called after create_evm_verifier with the render_vk_separately arg set to true. By rendering a verification key separately you can reuse the same verifier for similar circuit setups with different verifying keys, helping to reduce the amount of state our verifiers store on the blockchain.
///
/// Arguments
/// ---------
/// vk_path: str
/// The path to the verification key file
///
/// settings_path: str
/// The path to the settings file
///
/// sol_code_path: str
/// The path at which to create the solidity verifying key.
///
/// abi_path: str
/// The path to create the ABI for the solidity verifier
///
/// srs_path: str
/// The path to the SRS file
///
/// Returns
/// -------
/// bool
///
#[pyfunction(signature = (
vk_path=PathBuf::from(DEFAULT_VK),
settings_path=PathBuf::from(DEFAULT_SETTINGS),
sol_code_path=PathBuf::from(DEFAULT_VK_SOL),
abi_path=PathBuf::from(DEFAULT_VERIFIER_ABI),
srs_path=None
))]
fn create_evm_vk(
py: Python,
vk_path: PathBuf,
settings_path: PathBuf,
sol_code_path: PathBuf,
abi_path: PathBuf,
srs_path: Option<PathBuf>,
) -> PyResult<Bound<'_, PyAny>> {
pyo3_asyncio::tokio::future_into_py(py, async move {
crate::execute::create_evm_vk(vk_path, srs_path, settings_path, sol_code_path, abi_path)
.await
.map_err(|e| {
let err_str = format!("Failed to run create_evm_vk: {}", e);
PyRuntimeError::new_err(err_str)
})?;

Ok(true)
})
}

/// Creates an EVM compatible data attestation verifier, you will need solc installed in your environment to run this
///
/// Arguments
@@ -1762,7 +1812,7 @@ fn deploy_da_evm(
/// Arguments
/// ---------
/// addr_verifier: str
/// The path to verifier contract's address
/// The verifier contract's address as a hex string
///
/// proof_path: str
/// The path to the proof file (generated using the prove command)
@@ -1774,7 +1824,7 @@ fn deploy_da_evm(
/// does the verifier use data attestation ?
///
/// addr_vk: str
///
/// The address of the separate VK contract (if the verifier key is rendered as a separate contract)
/// Returns
/// -------
/// bool
@@ -1925,6 +1975,7 @@ fn ezkl(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(compile_circuit, m)?)?;
m.add_function(wrap_pyfunction!(verify_aggr, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_verifier, m)?)?;
m.add_function(wrap_pyfunction!(create_evm_vk, m)?)?;
m.add_function(wrap_pyfunction!(deploy_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_vk_evm, m)?)?;
m.add_function(wrap_pyfunction!(deploy_da_evm, m)?)?;
|
||||
Ok(t)
|
||||
}
|
||||
|
||||
/// Get last elem from Tensor
|
||||
/// ```
|
||||
/// use ezkl::tensor::Tensor;
|
||||
/// let mut a = Tensor::<i32>::new(Some(&[1, 2, 3]), &[3]).unwrap();
|
||||
/// let mut b = Tensor::<i32>::new(Some(&[3]), &[1]).unwrap();
|
||||
///
|
||||
/// assert_eq!(a.last().unwrap(), b);
|
||||
/// ```
|
||||
pub fn last(&self) -> Result<Tensor<T>, TensorError>
|
||||
where
|
||||
T: Send + Sync,
|
||||
{
|
||||
let res = match self.inner.last() {
|
||||
Some(e) => e.clone(),
|
||||
None => {
|
||||
return Err(TensorError::DimError(
|
||||
"Cannot get last element of empty tensor".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Tensor::new(Some(&[res]), &[1])
|
||||
}
|
||||
|
||||
/// Maps a function to tensors and enumerates in parallel
|
||||
/// ```
|
||||
/// use ezkl::tensor::{Tensor, TensorError};
|
||||
@@ -1293,7 +1317,7 @@ impl<T: Clone + TensorType> Tensor<T> {
|
||||
E: Error + std::marker::Send + std::marker::Sync,
|
||||
>(
|
||||
&mut self,
|
||||
filter_indices: &std::collections::HashSet<&usize>,
|
||||
filter_indices: &std::collections::HashSet<usize>,
|
||||
f: F,
|
||||
) -> Result<(), E>
|
||||
where
|
||||
|
||||
@@ -1,12 +1,12 @@
use core::{iter::FilterMap, slice::Iter};

use crate::circuit::region::ConstantsMap;
use maybe_rayon::slice::Iter;

use super::{
ops::{intercalate_values, pad, resize},
*,
};
use halo2_proofs::{arithmetic::Field, circuit::Cell, plonk::Instance};
use maybe_rayon::iter::{FilterMap, IntoParallelIterator, ParallelIterator};

pub(crate) fn create_constant_tensor<
F: PrimeField + TensorType + std::marker::Send + std::marker::Sync + PartialOrd,
@@ -460,7 +460,7 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
&self,
) -> FilterMap<Iter<'_, ValType<F>>, fn(&ValType<F>) -> Option<(F, ValType<F>)>> {
match self {
ValTensor::Value { inner, .. } => inner.iter().filter_map(|x| {
ValTensor::Value { inner, .. } => inner.par_iter().filter_map(|x| {
if let ValType::Constant(v) = x {
Some((*v, x.clone()))
} else {
@@ -573,6 +573,27 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}

/// Returns the last element of the inner tensor as a one-element tensor.
pub fn last(&self) -> Result<ValTensor<F>, TensorError> {
let slice = match self {
ValTensor::Value {
inner: v,
dims: _,
scale,
} => {
let inner = v.last()?;
let dims = inner.dims().to_vec();
ValTensor::Value {
inner,
dims,
scale: *scale,
}
}
_ => return Err(TensorError::WrongMethod),
};
Ok(slice)
}

/// Calls `get_slice` on the inner tensor.
pub fn get_slice(&self, indices: &[Range<usize>]) -> Result<ValTensor<F>, TensorError> {
if indices.iter().map(|x| x.end - x.start).collect::<Vec<_>>() == self.dims() {
@@ -753,43 +774,72 @@ impl<F: PrimeField + TensorType + PartialOrd + std::hash::Hash> ValTensor<F> {
Ok(())
}

/// gets constants
pub fn get_const_zero_indices(&self) -> Result<Vec<usize>, TensorError> {
/// removes constant zero values
pub fn remove_const_zero_values(&mut self) {
match self {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
indices.push(i);
ValTensor::Value { inner: v, dims, .. } => {
*v = v
.clone()
.into_par_iter()
.filter_map(|e| {
if let ValType::Constant(r) = e {
if r == F::ZERO {
return None;
}
} else if let ValType::AssignedConstant(_, r) = e {
if r == F::ZERO {
return None;
}
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
indices.push(i);
}
}
}
Ok(indices)
Some(e)
})
.collect();
*dims = v.dims().to_vec();
}
ValTensor::Instance { .. } => Ok(vec![]),
ValTensor::Instance { .. } => {}
}
}

/// gets constants
pub fn get_const_indices(&self) -> Result<Vec<usize>, TensorError> {
pub fn get_const_zero_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => {
let mut indices = vec![];
for (i, e) in v.iter().enumerate() {
if let ValType::Constant(_) = e {
indices.push(i);
} else if let ValType::AssignedConstant(_, _) = e {
indices.push(i);
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(r) = e {
if *r == F::ZERO {
return Some(i);
}
} else if let ValType::AssignedConstant(_, r) = e {
if *r == F::ZERO {
return Some(i);
}
}
}
Ok(indices)
}
ValTensor::Instance { .. } => Ok(vec![]),
None
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
}

/// gets constants
pub fn get_const_indices(&self) -> Vec<usize> {
match self {
ValTensor::Value { inner: v, .. } => v
.par_iter()
.enumerate()
.filter_map(|(i, e)| {
if let ValType::Constant(_) = e {
Some(i)
} else if let ValType::AssignedConstant(_, _) = e {
Some(i)
} else {
None
}
})
.collect(),
ValTensor::Instance { .. } => vec![],
}
}
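The rewritten index helpers above trade sequential push loops for parallel `filter_map` pipelines, routed through ezkl's maybe-rayon shim so they degrade to plain iterators where rayon is unavailable. A sketch of the same shape with rayon directly, on a plain slice instead of `ValTensor`:

```
use rayon::prelude::*;

// Parallel equivalent of "push i whenever values[i] == 0"; rayon's collect
// into Vec preserves the original index order.
fn const_zero_indices(values: &[i64]) -> Vec<usize> {
    values
        .par_iter()
        .enumerate()
        .filter_map(|(i, v)| if *v == 0 { Some(i) } else { None })
        .collect()
}

fn main() {
    assert_eq!(const_zero_indices(&[5, 0, 7, 0]), vec![1, 3]);
}
```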
@@ -319,7 +319,7 @@ impl VarTensor {
region: &mut Region<F>,
offset: usize,
values: &ValTensor<F>,
omissions: &HashSet<&usize>,
omissions: &HashSet<usize>,
constants: &mut ConstantsMap<F>,
) -> Result<ValTensor<F>, halo2_proofs::plonk::Error> {
let mut assigned_coord = 0;
@@ -200,7 +200,7 @@ mod native_tests {
"1l_tiny_div",
];

const TESTS: [&str; 93] = [
const TESTS: [&str; 94] = [
"1l_mlp", //0
"1l_slice",
"1l_concat",
@@ -298,6 +298,7 @@ mod native_tests {
"1l_lppool",
"lstm_large", // 91
"lstm_medium", // 92
"lenet_5", // 93
];

const WASM_TESTS: [&str; 46] = [
@@ -536,7 +537,7 @@ mod native_tests {
}
});

seq!(N in 0..=92 {
seq!(N in 0..=93 {

#(#[test_case(TESTS[N])])*
#[ignore]
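The array length and the `seq!` bound move together: widening the range to `0..=93` is what makes the macro stamp out a test case for the new trailing `"lenet_5"` entry. A self-contained sketch of the same pattern with the seq-macro crate:

```
use seq_macro::seq;

const TESTS: [&str; 3] = ["a", "b", "c"];

// seq! substitutes N = 0, 1, 2 into the #(...)* group, generating one
// function per entry; widening the range picks up new trailing entries.
seq!(N in 0..=2 {
    #(
        fn check_~N() {
            assert!(!TESTS[N].is_empty());
        }
    )*
});

fn main() {
    check_0();
    check_1();
    check_2();
}
```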
@@ -423,6 +423,74 @@ async def test_create_evm_verifier():
    assert res == True
    assert os.path.isfile(sol_code_path)

async def test_create_evm_verifier_separate_vk():
    """
    Create an EVM verifier with solidity code and a separate vk
    In order to run this test you will need to install solc in your environment
    """
    vk_path = os.path.join(folder_path, 'test_evm.vk')
    settings_path = os.path.join(folder_path, 'settings.json')
    sol_code_path = os.path.join(folder_path, 'test_separate.sol')
    vk_code_path = os.path.join(folder_path, 'test_vk.sol')
    abi_path = os.path.join(folder_path, 'test_separate.abi')
    abi_vk_path = os.path.join(folder_path, 'test_vk_separate.abi')
    proof_path = os.path.join(folder_path, 'test_evm.pf')
    calldata_path = os.path.join(folder_path, 'calldata.bytes')

    # # res is now a vector of bytes
    # res = ezkl.encode_evm_calldata(proof_path, calldata_path)

    # assert os.path.isfile(calldata_path)
    # assert len(res) > 0

    res = await ezkl.create_evm_verifier(
        vk_path,
        settings_path,
        sol_code_path,
        abi_path,
        srs_path=srs_path,
        render_vk_seperately=True
    )

    res = await ezkl.create_evm_vk(
        vk_path,
        settings_path,
        vk_code_path,
        abi_vk_path,
        srs_path=srs_path,
    )

    assert res == True
    assert os.path.isfile(sol_code_path)


async def test_deploy_evm_separate_vk():
    """
    Test deployment of the separate verifier smart contract + vk
    In order to run this you will need to install solc in your environment
    """
    addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
    addr_path_vk = os.path.join(folder_path, 'address_vk.json')
    sol_code_path = os.path.join(folder_path, 'test_separate.sol')
    vk_code_path = os.path.join(folder_path, 'test_vk.sol')

    # TODO: without optimization there will be out of gas errors
    # sol_code_path = os.path.join(folder_path, 'test.sol')

    res = await ezkl.deploy_evm(
        addr_path_verifier,
        sol_code_path,
        rpc_url=anvil_url,
    )

    res = await ezkl.deploy_vk_evm(
        addr_path_vk,
        vk_code_path,
        rpc_url=anvil_url,
    )

    assert res == True

async def test_deploy_evm():
    """
@@ -503,6 +571,47 @@ async def test_verify_evm():

    assert res == True

async def test_verify_evm_separate_vk():
    """
    Verifies an evm proof
    In order to run this you will need to install solc in your environment
    """
    proof_path = os.path.join(folder_path, 'test_evm.pf')
    addr_path_verifier = os.path.join(folder_path, 'address_separate.json')
    addr_path_vk = os.path.join(folder_path, 'address_vk.json')
    proof_path = os.path.join(folder_path, 'test_evm.pf')
    calldata_path = os.path.join(folder_path, 'calldata_separate.bytes')

    with open(addr_path_verifier, 'r') as file:
        addr_verifier = file.read().rstrip()

    print(addr_verifier)

    with open(addr_path_vk, 'r') as file:
        addr_vk = file.read().rstrip()

    print(addr_vk)

    # res is now a vector of bytes
    res = ezkl.encode_evm_calldata(proof_path, calldata_path, addr_vk=addr_vk)

    assert os.path.isfile(calldata_path)
    assert len(res) > 0

    # TODO: without optimization there will be out of gas errors
    # sol_code_path = os.path.join(folder_path, 'test.sol')

    res = await ezkl.verify_evm(
        addr_verifier,
        proof_path,
        rpc_url=anvil_url,
        addr_vk=addr_vk,
        # sol_code_path
        # optimizer_runs
    )

    assert res == True


async def test_aggregate_and_verify_aggr():
    data_path = os.path.join(