mirror of
https://github.com/zkonduit/ezkl.git
synced 2026-01-09 14:28:00 -05:00
fix: compiled model serialization with neq binaries (#395)
This commit is contained in:
4
.github/workflows/rust.yml
vendored
4
.github/workflows/rust.yml
vendored
@@ -78,7 +78,9 @@ jobs:
|
||||
crate: cargo-nextest
|
||||
locked: true
|
||||
- name: Model serialization
|
||||
run: cargo nextest run native_tests::tests::model_serialization_
|
||||
run: cargo nextest run native_tests::tests::model_serialization_::t
|
||||
- name: Model serialization different binary ID
|
||||
run: cargo nextest run native_tests::tests::model_serialization_different_binaries_::t --test-threads 1
|
||||
|
||||
wasm32-tests:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -39,4 +39,5 @@ var/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
.vscode/
|
||||
*.whl
|
||||
*.whl
|
||||
*.bak
|
||||
319
Cargo.lock
generated
319
Cargo.lock
generated
@@ -24,7 +24,7 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac1f845298e95f983ff1944b728ae08b8cebab80d684f0a832ed0fc74dfa27e2"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"cipher",
|
||||
"cpufeatures",
|
||||
]
|
||||
@@ -46,7 +46,7 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2c99f64d1e06488f620f932677e24bc6e2897582980441ae90a671415bd7ec2f"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"once_cell",
|
||||
"version_check",
|
||||
]
|
||||
@@ -90,7 +90,7 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "220044e6a1bb31ddee4e3db724d29767f352de47445a6cd75e1a173142136c83"
|
||||
dependencies = [
|
||||
"nom 7.1.3",
|
||||
"nom",
|
||||
"vte",
|
||||
]
|
||||
|
||||
@@ -189,18 +189,6 @@ version = "0.7.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"
|
||||
|
||||
[[package]]
|
||||
name = "as-slice"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "45403b49e3954a4b8428a0ac21a4b7afadccf92bfd96273f1a58cd4812496ae0"
|
||||
dependencies = [
|
||||
"generic-array 0.12.4",
|
||||
"generic-array 0.13.3",
|
||||
"generic-array 0.14.7",
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ascii-canvas"
|
||||
version = "3.0.0"
|
||||
@@ -361,7 +349,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4152116fd6e9dadb291ae18fc1ec3575ed6d84c29642d97890f4b4a3417297e4"
|
||||
dependencies = [
|
||||
"block-padding",
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -370,7 +358,7 @@ version = "0.10.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71"
|
||||
dependencies = [
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -388,19 +376,6 @@ dependencies = [
|
||||
"sha2 0.9.9",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "build_id"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6deb6795d8b4d2269c3fcf87a87bff9f4cd45a99e259806603ee8007077daf3"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"once_cell",
|
||||
"palaver",
|
||||
"twox-hash",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.13.0"
|
||||
@@ -484,12 +459,6 @@ version = "1.0.79"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "0.1.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4785bdd1c96b2a846b2bd7cc02e86b6b3dbf14e7e53446c4f54c92a361040822"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
@@ -629,7 +598,7 @@ dependencies = [
|
||||
"bech32",
|
||||
"bs58",
|
||||
"digest 0.10.7",
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
"hex",
|
||||
"ripemd",
|
||||
"serde",
|
||||
@@ -693,7 +662,7 @@ version = "0.1.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
@@ -788,7 +757,7 @@ version = "1.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b540bd8bc810d3885c6ea91e2018302f68baba2129ab3e88f32389ee9370880d"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -833,7 +802,7 @@ version = "0.5.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
@@ -843,7 +812,7 @@ version = "0.8.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
@@ -855,7 +824,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
"memoffset 0.9.0",
|
||||
"scopeguard",
|
||||
@@ -867,7 +836,7 @@ version = "0.8.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -882,7 +851,7 @@ version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf4c2f4e1afd912bc40bfd6fed5d9dc1f288e0ba01bfcc835cc5bc3eb13efe15"
|
||||
dependencies = [
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
"rand_core 0.6.4",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
@@ -894,7 +863,7 @@ version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3"
|
||||
dependencies = [
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
"typenum",
|
||||
]
|
||||
|
||||
@@ -1034,7 +1003,7 @@ version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d3dd60d1080a57a05ab032377049e0591415d2b31afd7028356dbf3cc6dcb066"
|
||||
dependencies = [
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1064,7 +1033,7 @@ version = "2.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b98cf8ebf19c3d1b223e151f99a4f9f0690dca41414773390fc824184ac833e1"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"dirs-sys-next",
|
||||
]
|
||||
|
||||
@@ -1179,7 +1148,7 @@ dependencies = [
|
||||
"crypto-bigint",
|
||||
"digest 0.10.7",
|
||||
"ff",
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
"group",
|
||||
"pkcs8",
|
||||
"rand_core 0.6.4",
|
||||
@@ -1209,7 +1178,7 @@ version = "0.8.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "071a31f4ee85403370b58aca746f01041ede6f0da2730960ad001edc2b71b394"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1256,15 +1225,6 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "88bffebc5d80432c9b140ee17875ff173a8ab62faad5b257da912bd2f6c1c0a1"
|
||||
|
||||
[[package]]
|
||||
name = "erased-serde"
|
||||
version = "0.3.26"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c6984864d65d092d9e9ada107007a846a09f75d2e24046bcce9a38d14aa52052"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.1"
|
||||
@@ -1453,7 +1413,7 @@ dependencies = [
|
||||
"chrono",
|
||||
"elliptic-curve",
|
||||
"ethabi",
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
"hex",
|
||||
"k256",
|
||||
"num_enum",
|
||||
@@ -1574,7 +1534,7 @@ version = "2.0.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a81c89f121595cf8959e746045bb8b25a6a38d72588561e1a3b7992fc213f674"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"dunce",
|
||||
"ethers-core",
|
||||
"glob",
|
||||
@@ -1647,7 +1607,6 @@ dependencies = [
|
||||
"serde",
|
||||
"serde-wasm-bindgen",
|
||||
"serde_json",
|
||||
"serde_traitobject",
|
||||
"shellexpand",
|
||||
"snark-verifier",
|
||||
"tabled",
|
||||
@@ -1708,7 +1667,7 @@ version = "0.2.21"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5cbc844cecaee9d4443931972e1289c8ff485cb4cc2767cb03ca139ed6885153"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall 0.2.16",
|
||||
"windows-sys 0.48.0",
|
||||
@@ -1964,24 +1923,6 @@ dependencies = [
|
||||
"tempfile",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.12.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ffdf9f34f1447443d37393cc6c2b8313aebddcd96906caf34e54c68d8e57d7bd"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.13.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f797e67af32588215eaaab8327027ee8e71b9dd0b2b26996aedf20c030fce309"
|
||||
dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "generic-array"
|
||||
version = "0.14.7"
|
||||
@@ -1999,7 +1940,7 @@ version = "0.2.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"libc",
|
||||
"wasi 0.11.0+wasi-snapshot-preview1",
|
||||
@@ -2185,15 +2126,6 @@ dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hash32"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d4041af86e63ac4298ce40e5cca669066e75b6f1aa3390fe2561ffa5e1d9f4cc"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.11.2"
|
||||
@@ -2233,18 +2165,6 @@ dependencies = [
|
||||
"fxhash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heapless"
|
||||
version = "0.5.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "74911a68a1658cfcfb61bc0ccfbd536e3b6e906f8c2f7883ee50157e3e2184f1"
|
||||
dependencies = [
|
||||
"as-slice",
|
||||
"generic-array 0.13.3",
|
||||
"hash32",
|
||||
"stable_deref_trait",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.4.1"
|
||||
@@ -2520,7 +2440,7 @@ version = "0.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5"
|
||||
dependencies = [
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -2529,7 +2449,7 @@ version = "0.1.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
@@ -2613,7 +2533,7 @@ version = "0.13.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cadb76004ed8e97623117f3df85b17aaa6626ab0b0831e6573f104df16cd1bcc"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"ecdsa",
|
||||
"elliptic-curve",
|
||||
"once_cell",
|
||||
@@ -2689,7 +2609,7 @@ version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d580318f95776505201b28cf98eb1fa5e4be3b689633ba6a3e6cd880ff22d8cb"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"windows-sys 0.48.0",
|
||||
]
|
||||
|
||||
@@ -2778,15 +2698,6 @@ version = "0.4.19"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b06a4cde4c0f271a446782e3eff8de789548ce57dbc8eca9292c27f4a42004b4"
|
||||
|
||||
[[package]]
|
||||
name = "mach"
|
||||
version = "0.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b823e83b2affd8f40a9ee8c29dbc56404c1e34cd2710921f2801e2cf29527afa"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "maingate"
|
||||
version = "0.1.0"
|
||||
@@ -2858,12 +2769,6 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "metatype"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "23decce7c32638bcefbd5a5a5d79a5bb5b720c47b82ad5cb670a7eb912705946"
|
||||
|
||||
[[package]]
|
||||
name = "mime"
|
||||
version = "0.3.17"
|
||||
@@ -2943,25 +2848,6 @@ version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e4a24736216ec316047a1fc4252e27dabb04218aa4a3f37c6e7ddbf1f9782b54"
|
||||
|
||||
[[package]]
|
||||
name = "nix"
|
||||
version = "0.15.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b2e0b4f3320ed72aaedb9a5ac838690a8047c7b275da22711fddff4f8a14229"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cc",
|
||||
"cfg-if 0.1.10",
|
||||
"libc",
|
||||
"void",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "2.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cf51a729ecf40266a2368ad335a5fdde43471f545a967109cd62146ecf8b66ff"
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
@@ -3137,7 +3023,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "345df152bc43501c5eb9e4654ff05f794effb78d4efe3d53abc158baddc0703d"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"foreign-types",
|
||||
"libc",
|
||||
"once_cell",
|
||||
@@ -3190,23 +3076,6 @@ version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d"
|
||||
|
||||
[[package]]
|
||||
name = "palaver"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49dfc200733ac34dcd9a1e4a7e454b521723936010bef3710e2d8024a32d685f"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"heapless",
|
||||
"lazy_static",
|
||||
"libc",
|
||||
"mach",
|
||||
"nix",
|
||||
"procinfo",
|
||||
"typenum",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "papergrid"
|
||||
version = "0.9.1"
|
||||
@@ -3262,7 +3131,7 @@ version = "0.9.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93f00c865fe7cabf650081affecd3871070f26767e7b2070a3ffae14c654b447"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"redox_syscall 0.3.5",
|
||||
"smallvec",
|
||||
@@ -3665,18 +3534,6 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "procinfo"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ab1427f3d2635891f842892dda177883dca0639e05fe66796a62c9d2f23b49c"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"libc",
|
||||
"nom 2.2.1",
|
||||
"rustc_version 0.2.3",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prost"
|
||||
version = "0.11.9"
|
||||
@@ -3706,7 +3563,7 @@ version = "0.18.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e3b1ac5b3731ba34fdaa9785f8d74d17448cd18f30cf19e0c7e7b1fdb5272109"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"indoc",
|
||||
"libc",
|
||||
"memoffset 0.8.0",
|
||||
@@ -3974,17 +3831,6 @@ version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "436b050e76ed2903236f032a59761c1eb99e1b0aead2c257922771dab1fc8c78"
|
||||
|
||||
[[package]]
|
||||
name = "relative"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3401c189ee92c7028ba4863f3fdb92af815789993221af2fa186eed8115da304"
|
||||
dependencies = [
|
||||
"build_id",
|
||||
"serde",
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "remove_dir_all"
|
||||
version = "0.5.3"
|
||||
@@ -4159,15 +4005,6 @@ version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3e75f6a532d0fd9f7f13144f392b6ad56a32696bfcd9c78f797f16bbb6f072d6"
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "138e3e0acb6c9fb258b19b67cb8abd63c00679d2851805ea151465464fe9030a"
|
||||
dependencies = [
|
||||
"semver 0.9.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rustc_version"
|
||||
version = "0.3.3"
|
||||
@@ -4251,7 +4088,7 @@ version = "2.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad560913365790f17cbf12479491169f01b9d46d29cfc7422bf8c64bdc61b731"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"derive_more",
|
||||
"parity-scale-codec",
|
||||
"scale-info-derive",
|
||||
@@ -4319,7 +4156,7 @@ checksum = "f0aec48e813d6b90b15f0b8948af3c63483992dee44c03e9930b3eebdabe046e"
|
||||
dependencies = [
|
||||
"base16ct",
|
||||
"der",
|
||||
"generic-array 0.14.7",
|
||||
"generic-array",
|
||||
"pkcs8",
|
||||
"subtle",
|
||||
"zeroize",
|
||||
@@ -4366,22 +4203,13 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "0.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1d7eb9ef2c18661902cc47e535f9bc51b78acd254da71d375c2f6720d9a40403"
|
||||
dependencies = [
|
||||
"semver-parser 0.7.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "0.11.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f301af10236f6df4160f7c3f04eec6dbc70ace82d23326abad5edee88801c6b6"
|
||||
dependencies = [
|
||||
"semver-parser 0.10.2",
|
||||
"semver-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4393,12 +4221,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "semver-parser"
|
||||
version = "0.7.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "388a1df253eca08550bef6c72392cfe7c30914bf41df5269b68cbd6ff8f570a3"
|
||||
|
||||
[[package]]
|
||||
name = "semver-parser"
|
||||
version = "0.10.2"
|
||||
@@ -4465,28 +4287,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_closure"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9659437bcfbf4dd061a5e1f7994312990ac5b24d781f7ce577eefc3a27792da0"
|
||||
dependencies = [
|
||||
"rustversion",
|
||||
"serde",
|
||||
"serde_closure_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_closure_derive"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a6bb4d612b5caad466a9a09ee550445e34123a74075607cc0d882ff1ca28f46"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.164"
|
||||
@@ -4518,19 +4318,6 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_traitobject"
|
||||
version = "0.2.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c5ae15a5d31f7c57875a480ddd7be02314d264617d0294d961314a6d502e6b1"
|
||||
dependencies = [
|
||||
"erased-serde",
|
||||
"metatype",
|
||||
"relative",
|
||||
"serde",
|
||||
"serde_closure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_urlencoded"
|
||||
version = "0.7.1"
|
||||
@@ -4550,7 +4337,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4d58a1e1bf39749807d89cf2d98ac2dfa0ff1cb3faa38fbb64dd88ac8013d800"
|
||||
dependencies = [
|
||||
"block-buffer 0.9.0",
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest 0.9.0",
|
||||
"opaque-debug",
|
||||
@@ -4562,7 +4349,7 @@ version = "0.10.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"cpufeatures",
|
||||
"digest 0.10.7",
|
||||
]
|
||||
@@ -4702,12 +4489,6 @@ version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02a8428da277a8e3a15271d79943e80ccc2ef254e78813a166a08d65e4c3ece5"
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
@@ -4726,7 +4507,7 @@ version = "0.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "91e2531d8525b29b514d25e275a43581320d587b86db302b9a7e464bac579648"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"hashbrown 0.11.2",
|
||||
"serde",
|
||||
]
|
||||
@@ -4896,7 +4677,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "31c0432476357e58790aaa47a8efb0c5138f137343f3b5f23bd36a27e3b0a6d6"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"fastrand",
|
||||
"redox_syscall 0.3.5",
|
||||
"rustix",
|
||||
@@ -4929,7 +4710,7 @@ version = "2.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e45b7bf6e19353ddd832745c8fcf77a17a93171df7151187f26623f2b75b5b26"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
@@ -5135,7 +4916,7 @@ version = "0.1.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8ce8c33a8d48bd45d624a6e523445fd21ec13d3653cd51f681abf67418f54eb8"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"pin-project-lite",
|
||||
"tracing-attributes",
|
||||
"tracing-core",
|
||||
@@ -5205,7 +4986,7 @@ dependencies = [
|
||||
"lazy_static",
|
||||
"maplit",
|
||||
"ndarray",
|
||||
"nom 7.1.3",
|
||||
"nom",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
"scan_fmt",
|
||||
@@ -5254,7 +5035,7 @@ dependencies = [
|
||||
"byteorder",
|
||||
"flate2",
|
||||
"log",
|
||||
"nom 7.1.3",
|
||||
"nom",
|
||||
"tar",
|
||||
"tract-core",
|
||||
"walkdir",
|
||||
@@ -5312,16 +5093,6 @@ version = "0.17.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "375812fa44dab6df41c195cd2f7fecb488f6c09fbaafb62807488cefab642bff"
|
||||
|
||||
[[package]]
|
||||
name = "twox-hash"
|
||||
version = "1.6.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.16.0"
|
||||
@@ -5430,12 +5201,6 @@ version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "void"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d"
|
||||
|
||||
[[package]]
|
||||
name = "vte"
|
||||
version = "0.10.1"
|
||||
@@ -5494,7 +5259,7 @@ version = "0.2.87"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7706a72ab36d8cb1f80ffbf0e071533974a60d0a308d01a5d0375bf60499a342"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"wasm-bindgen-macro",
|
||||
@@ -5521,7 +5286,7 @@ version = "0.4.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c02dbc21516f9f1f04f187958890d7e6026df8d16540b7ad9492bc34a67cea03"
|
||||
dependencies = [
|
||||
"cfg-if 1.0.0",
|
||||
"cfg-if",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
|
||||
@@ -38,8 +38,8 @@ tokio = { version = "1.26.0", default_features = false, features = ["macros", "
|
||||
rayon = { version = "1.7.0", default_features = false }
|
||||
bincode = { version = "1.3.3", default_features = false }
|
||||
ark-std = { version = "^0.3.0", default-features = false }
|
||||
serde_traitobject = { version = "0.2.8", features = ["serde_closure"] }
|
||||
lazy_static = "1.4.0"
|
||||
|
||||
# python binding related deps
|
||||
pyo3 = { version = "0.18.3", features = ["extension-module", "abi3-py37", "macros"], default_features = false, optional = true }
|
||||
pyo3-asyncio = { version = "0.18.0", features = ["attributes", "tokio-runtime"], default_features = false, optional = true }
|
||||
@@ -73,7 +73,7 @@ console_error_panic_hook = "0.1.7"
|
||||
[dev-dependencies]
|
||||
criterion = {version = "0.3", features = ["html_reports"]}
|
||||
tempfile = "3.3.0"
|
||||
|
||||
lazy_static = "1.4.0"
|
||||
mnist = "0.5"
|
||||
seq-macro = "0.3.1"
|
||||
test-case = "2.2.2"
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
use std::{any::Any, error::Error};
|
||||
|
||||
use itertools::Itertools;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
fieldutils::{felt_to_i128, i128_to_felt},
|
||||
tensor::{self, Tensor, TensorError, TensorType, ValTensor},
|
||||
};
|
||||
use crate::tensor::{self, Tensor, TensorError, TensorType, ValTensor};
|
||||
use halo2curves::ff::PrimeField;
|
||||
|
||||
use self::{lookup::LookupOp, region::RegionCtx};
|
||||
@@ -34,9 +30,7 @@ pub struct ForwardResult<F: PrimeField + TensorType + PartialOrd> {
|
||||
}
|
||||
|
||||
/// An enum representing operations that can be represented as constraints in a circuit.
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd>:
|
||||
std::fmt::Debug + Send + Sync + Any + serde_traitobject::Serialize + serde_traitobject::Deserialize
|
||||
{
|
||||
pub trait Op<F: PrimeField + TensorType + PartialOrd>: std::fmt::Debug + Send + Sync + Any {
|
||||
/// Matches a [Op] to an operation in the `tensor::ops` module.
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError>;
|
||||
/// Returns a string representation of the operation.
|
||||
@@ -144,86 +138,6 @@ impl<F: PrimeField + TensorType + PartialOrd> Op<F> for Input {
|
||||
}
|
||||
}
|
||||
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Rescaled<F: PrimeField + TensorType + PartialOrd> {
|
||||
/// The operation to be rescaled.
|
||||
#[serde(with = "serde_traitobject")]
|
||||
pub inner: Box<dyn Op<F>>,
|
||||
/// The scale of the operation's inputs.
|
||||
pub scale: Vec<(usize, u128)>,
|
||||
}
|
||||
|
||||
impl<F: PrimeField + TensorType + PartialOrd + Serialize> Op<F> for Rescaled<F> {
|
||||
fn as_any(&self) -> &dyn Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, x: &[Tensor<F>]) -> Result<ForwardResult<F>, TensorError> {
|
||||
if self.scale.len() != x.len() {
|
||||
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
|
||||
}
|
||||
|
||||
let mut rescaled_inputs = vec![];
|
||||
let inputs = &mut x.to_vec();
|
||||
for (i, ri) in inputs.iter_mut().enumerate() {
|
||||
let ri = ri.map(|x| felt_to_i128(x));
|
||||
let res = tensor::ops::nonlinearities::const_div(&ri, self.scale[i].1 as f64);
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
rescaled_inputs.push(output);
|
||||
}
|
||||
Op::<F>::f(&*self.inner, &rescaled_inputs)
|
||||
}
|
||||
|
||||
fn rescale(&self, _: Vec<u32>, _: u32) -> Box<dyn Op<F>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!("RESCALED {}", self.inner.as_string())
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<u32>, _g: u32) -> u32 {
|
||||
let in_scales = in_scales
|
||||
.into_iter()
|
||||
.zip(self.scale.iter())
|
||||
.map(|(a, b)| a - crate::graph::mult_to_scale(b.1 as f64))
|
||||
.collect();
|
||||
|
||||
Op::<F>::out_scale(&*self.inner, in_scales, _g)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
let mut required_lookups = vec![];
|
||||
for scale in &self.scale {
|
||||
if scale.1 != 0 {
|
||||
required_lookups.push(LookupOp::Div {
|
||||
denom: (scale.1 as f32).into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
required_lookups
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<F>,
|
||||
region: &mut RegionCtx<F>,
|
||||
values: &[ValTensor<F>],
|
||||
) -> Result<Option<ValTensor<F>>, Box<dyn Error>> {
|
||||
if self.scale.len() != values.len() {
|
||||
return Err(Box::new(TensorError::DimMismatch(
|
||||
"rescaled inputs".to_string(),
|
||||
)));
|
||||
}
|
||||
let res = &layouts::rescale(config, region, values[..].try_into()?, &self.scale)?[..];
|
||||
self.inner.layout(config, region, res)
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<F>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
}
|
||||
|
||||
/// An unknown operation.
|
||||
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, Serialize, Deserialize)]
|
||||
pub struct Unknown;
|
||||
@@ -347,42 +261,3 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
}
|
||||
|
||||
fn homogenize_input_scales<F: PrimeField + TensorType + PartialOrd + Serialize>(
|
||||
op: impl Op<F> + Clone,
|
||||
input_scales: Vec<u32>,
|
||||
inputs_to_scale: Vec<usize>,
|
||||
) -> Result<Box<dyn Op<F>>, Box<dyn Error>> {
|
||||
if inputs_to_scale.is_empty() {
|
||||
return Ok(Box::new(op));
|
||||
}
|
||||
|
||||
let mut dividers: Vec<u128> = vec![1; input_scales.len()];
|
||||
if !input_scales.windows(2).all(|w| w[0] == w[1]) {
|
||||
let min_scale = input_scales.iter().min().unwrap();
|
||||
let _ = input_scales
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, input_scale)| {
|
||||
if !inputs_to_scale.contains(&idx) {
|
||||
return;
|
||||
}
|
||||
let scale_diff = input_scale - min_scale;
|
||||
if scale_diff > 0 {
|
||||
let mult = crate::graph::scale_to_multiplier(scale_diff);
|
||||
dividers[idx] = mult as u128;
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
}
|
||||
|
||||
// only rescale if need to
|
||||
if dividers.iter().any(|&x| x > 1) {
|
||||
Ok(Box::new(crate::circuit::Rescaled {
|
||||
inner: Box::new(op),
|
||||
scale: (0..input_scales.len()).zip(dividers).collect_vec(),
|
||||
}))
|
||||
} else {
|
||||
Ok(Box::new(op))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -396,10 +396,8 @@ impl<F: PrimeField + TensorType + PartialOrd + Serialize + for<'de> Deserialize<
|
||||
}
|
||||
}
|
||||
|
||||
fn rescale(&self, input_scales: Vec<u32>, _: u32) -> Box<dyn Op<F>> {
|
||||
let inputs_to_scale = self.requires_homogenous_input_scales();
|
||||
// creates a rescaled op if the inputs are not homogenous
|
||||
homogenize_input_scales::<F>(self.clone(), input_scales, inputs_to_scale).unwrap()
|
||||
fn rescale(&self, _: Vec<u32>, _: u32) -> Box<dyn Op<F>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
|
||||
@@ -1358,95 +1358,6 @@ mod pack {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod rescaled {
|
||||
use super::*;
|
||||
|
||||
const K: usize = 8;
|
||||
const LEN: usize = 4;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct MyCircuit<F: PrimeField + TensorType + PartialOrd> {
|
||||
inputs: [ValTensor<F>; 1],
|
||||
_marker: PhantomData<F>,
|
||||
}
|
||||
|
||||
impl Circuit<F> for MyCircuit<F> {
|
||||
type Config = BaseConfig<F>;
|
||||
type FloorPlanner = SimpleFloorPlanner;
|
||||
type Params = TestParams;
|
||||
|
||||
fn without_witnesses(&self) -> Self {
|
||||
self.clone()
|
||||
}
|
||||
|
||||
fn configure(cs: &mut ConstraintSystem<F>) -> Self::Config {
|
||||
let a = VarTensor::new_advice(cs, K, LEN);
|
||||
let b = VarTensor::new_advice(cs, K, LEN);
|
||||
let output = VarTensor::new_advice(cs, K, LEN);
|
||||
|
||||
let mut config = Self::Config::configure(cs, &[a, b.clone()], &output, CheckMode::SAFE);
|
||||
|
||||
config
|
||||
.configure_lookup(
|
||||
cs,
|
||||
&b,
|
||||
&output,
|
||||
7,
|
||||
&LookupOp::Div {
|
||||
denom: (5.0).into(),
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
fn synthesize(
|
||||
&self,
|
||||
mut config: Self::Config,
|
||||
mut layouter: impl Layouter<F>,
|
||||
) -> Result<(), Error> {
|
||||
config.layout_tables(&mut layouter).unwrap();
|
||||
|
||||
layouter
|
||||
.assign_region(
|
||||
|| "",
|
||||
|region| {
|
||||
let mut region = RegionCtx::new(region, 0);
|
||||
config
|
||||
.layout(
|
||||
&mut region,
|
||||
&self.inputs.clone(),
|
||||
Box::new(Rescaled {
|
||||
inner: Box::new(PolyOp::Sum { axes: vec![0] }),
|
||||
scale: vec![(0, 5)],
|
||||
}),
|
||||
)
|
||||
.map_err(|_| Error::Synthesis)
|
||||
},
|
||||
)
|
||||
.unwrap();
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rescaledcircuit() {
|
||||
// parameters
|
||||
let mut a = Tensor::from((0..LEN).map(|i| Value::known(F::from(i as u64))));
|
||||
a.reshape(&[LEN, 1]);
|
||||
|
||||
let circuit = MyCircuit::<F> {
|
||||
inputs: [ValTensor::from(a)],
|
||||
_marker: PhantomData,
|
||||
};
|
||||
|
||||
let prover = MockProver::run(K as u32, &circuit, vec![]).unwrap();
|
||||
prover.assert_satisfied();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod matmul_relu {
|
||||
use super::*;
|
||||
|
||||
@@ -152,7 +152,7 @@ impl NodeType {
|
||||
.downcast_ref::<crate::circuit::Constant<Fp>>()
|
||||
{
|
||||
if c.num_uses > 0 {
|
||||
n.opkind = Box::new(crate::circuit::Constant {
|
||||
n.opkind = SupportedOp::Constant(crate::circuit::Constant {
|
||||
num_uses: c.num_uses - 1,
|
||||
..c.clone()
|
||||
});
|
||||
@@ -172,7 +172,7 @@ impl NodeType {
|
||||
}
|
||||
|
||||
/// Replace the operation kind of the node.
|
||||
pub fn replace_opkind(&mut self, opkind: Box<dyn Op<Fp>>) {
|
||||
pub fn replace_opkind(&mut self, opkind: SupportedOp) {
|
||||
match self {
|
||||
NodeType::Node(n) => n.opkind = opkind,
|
||||
NodeType::SubGraph { .. } => log::warn!("Cannot replace opkind of subgraph"),
|
||||
@@ -383,7 +383,7 @@ impl Model {
|
||||
|
||||
match n {
|
||||
NodeType::Node(n) => {
|
||||
let res = Op::<Fp>::f(&*n.opkind, &inputs)?;
|
||||
let res = Op::<Fp>::f(&n.opkind, &inputs)?;
|
||||
// see if any of the intermediate lookup calcs are the max
|
||||
if !res.intermediate_lookups.is_empty() {
|
||||
let mut max = 0;
|
||||
@@ -594,7 +594,7 @@ impl Model {
|
||||
i,
|
||||
)?;
|
||||
if n.opkind.is_input() {
|
||||
n.opkind = Box::new(Input {
|
||||
n.opkind = SupportedOp::Input(Input {
|
||||
scale: input_scales[input_idx],
|
||||
});
|
||||
n.out_scale = n.opkind.out_scale(vec![], 0);
|
||||
@@ -954,7 +954,7 @@ impl Model {
|
||||
constant.raw_values.clone(),
|
||||
);
|
||||
op.pre_assign(consts[const_idx].clone());
|
||||
n.opkind = Box::new(op);
|
||||
n.opkind = SupportedOp::Constant(op);
|
||||
|
||||
const_idx += 1;
|
||||
}
|
||||
|
||||
@@ -1,7 +1,17 @@
|
||||
use super::utilities::node_output_shapes;
|
||||
use super::Visibility;
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Constant;
|
||||
use crate::circuit::Input;
|
||||
use crate::circuit::Op;
|
||||
use crate::circuit::Unknown;
|
||||
use crate::fieldutils::felt_to_i128;
|
||||
use crate::fieldutils::i128_to_felt;
|
||||
use crate::graph::new_op_from_onnx;
|
||||
use crate::tensor::Tensor;
|
||||
use crate::tensor::TensorError;
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
use log::trace;
|
||||
use serde::Deserialize;
|
||||
@@ -24,17 +34,278 @@ fn display_vector<T: fmt::Debug>(v: &Vec<T>) -> String {
|
||||
}
|
||||
|
||||
#[allow(clippy::borrowed_box)]
|
||||
fn display_opkind(v: &Box<dyn Op<Fp>>) -> String {
|
||||
fn display_opkind(v: &SupportedOp) -> String {
|
||||
v.as_string()
|
||||
}
|
||||
|
||||
/// A wrapper for an operation that has been rescaled.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Rescaled {
|
||||
/// The operation that has to be rescaled.
|
||||
pub inner: Box<SupportedOp>,
|
||||
/// The scale of the operation's inputs.
|
||||
pub scale: Vec<(usize, u128)>,
|
||||
}
|
||||
|
||||
impl Op<Fp> for Rescaled {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
fn f(&self, x: &[Tensor<Fp>]) -> Result<crate::circuit::ForwardResult<Fp>, TensorError> {
|
||||
if self.scale.len() != x.len() {
|
||||
return Err(TensorError::DimMismatch("rescaled inputs".to_string()));
|
||||
}
|
||||
|
||||
let mut rescaled_inputs = vec![];
|
||||
let inputs = &mut x.to_vec();
|
||||
for (i, ri) in inputs.iter_mut().enumerate() {
|
||||
let ri = ri.map(|x| felt_to_i128(x));
|
||||
let res = crate::tensor::ops::nonlinearities::const_div(&ri, self.scale[i].1 as f64);
|
||||
let output = res.map(|x| i128_to_felt(x));
|
||||
rescaled_inputs.push(output);
|
||||
}
|
||||
Op::<Fp>::f(&*self.inner, &rescaled_inputs)
|
||||
}
|
||||
|
||||
fn rescale(&self, _: Vec<u32>, _: u32) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone())
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
format!("RESCALED {}", self.inner.as_string())
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<u32>, _g: u32) -> u32 {
|
||||
let in_scales = in_scales
|
||||
.into_iter()
|
||||
.zip(self.scale.iter())
|
||||
.map(|(a, b)| a - crate::graph::mult_to_scale(b.1 as f64))
|
||||
.collect();
|
||||
|
||||
Op::<Fp>::out_scale(&*self.inner, in_scales, _g)
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
let mut required_lookups = vec![];
|
||||
for scale in &self.scale {
|
||||
if scale.1 != 0 {
|
||||
required_lookups.push(LookupOp::Div {
|
||||
denom: (scale.1 as f32).into(),
|
||||
});
|
||||
}
|
||||
}
|
||||
required_lookups
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
|
||||
if self.scale.len() != values.len() {
|
||||
return Err(Box::new(TensorError::DimMismatch(
|
||||
"rescaled inputs".to_string(),
|
||||
)));
|
||||
}
|
||||
|
||||
let res =
|
||||
&crate::circuit::layouts::rescale(config, region, values[..].try_into()?, &self.scale)?
|
||||
[..];
|
||||
self.inner.layout(config, region, res)
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
Box::new(self.clone()) // Forward to the derive(Clone) impl
|
||||
}
|
||||
}
|
||||
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
#[derive(Clone, Debug, Tabled, Serialize, Deserialize)]
|
||||
pub enum SupportedOp {
|
||||
/// A linear operation.
|
||||
Linear(PolyOp<Fp>),
|
||||
/// A nonlinear operation.
|
||||
Nonlinear(LookupOp),
|
||||
/// A hybrid operation.
|
||||
Hybrid(HybridOp),
|
||||
///
|
||||
Input(Input),
|
||||
///
|
||||
Constant(Constant<Fp>),
|
||||
///
|
||||
Unknown(Unknown),
|
||||
///
|
||||
Rescaled(Rescaled),
|
||||
}
|
||||
|
||||
impl From<Box<dyn Op<Fp>>> for SupportedOp {
|
||||
fn from(value: Box<dyn Op<Fp>>) -> Self {
|
||||
match value.as_any().downcast_ref::<PolyOp<Fp>>() {
|
||||
Some(op) => return SupportedOp::Linear(op.clone()),
|
||||
None => {}
|
||||
};
|
||||
match value.as_any().downcast_ref::<LookupOp>() {
|
||||
Some(op) => return SupportedOp::Nonlinear(op.clone()),
|
||||
None => {}
|
||||
};
|
||||
match value.as_any().downcast_ref::<HybridOp>() {
|
||||
Some(op) => return SupportedOp::Hybrid(op.clone()),
|
||||
None => {}
|
||||
};
|
||||
match value.as_any().downcast_ref::<Input>() {
|
||||
Some(op) => return SupportedOp::Input(op.clone()),
|
||||
None => {}
|
||||
};
|
||||
match value.as_any().downcast_ref::<Constant<Fp>>() {
|
||||
Some(op) => return SupportedOp::Constant(op.clone()),
|
||||
None => {}
|
||||
};
|
||||
match value.as_any().downcast_ref::<Unknown>() {
|
||||
Some(op) => return SupportedOp::Unknown(op.clone()),
|
||||
None => {}
|
||||
};
|
||||
match value.as_any().downcast_ref::<Rescaled>() {
|
||||
Some(op) => return SupportedOp::Rescaled(op.clone()),
|
||||
None => {}
|
||||
};
|
||||
panic!("Unsupported op type")
|
||||
}
|
||||
}
|
||||
|
||||
impl Op<Fp> for SupportedOp {
|
||||
fn f(
|
||||
&self,
|
||||
inputs: &[Tensor<Fp>],
|
||||
) -> Result<crate::circuit::ForwardResult<Fp>, crate::tensor::TensorError> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => op.f(inputs),
|
||||
SupportedOp::Nonlinear(op) => op.f(inputs),
|
||||
SupportedOp::Hybrid(op) => op.f(inputs),
|
||||
SupportedOp::Input(op) => op.f(inputs),
|
||||
SupportedOp::Constant(op) => op.f(inputs),
|
||||
SupportedOp::Unknown(op) => op.f(inputs),
|
||||
SupportedOp::Rescaled(op) => op.f(inputs),
|
||||
}
|
||||
}
|
||||
|
||||
fn layout(
|
||||
&self,
|
||||
config: &mut crate::circuit::BaseConfig<Fp>,
|
||||
region: &mut crate::circuit::region::RegionCtx<Fp>,
|
||||
values: &[crate::tensor::ValTensor<Fp>],
|
||||
) -> Result<Option<crate::tensor::ValTensor<Fp>>, Box<dyn Error>> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => op.layout(config, region, values),
|
||||
SupportedOp::Nonlinear(op) => op.layout(config, region, values),
|
||||
SupportedOp::Hybrid(op) => op.layout(config, region, values),
|
||||
SupportedOp::Input(op) => op.layout(config, region, values),
|
||||
SupportedOp::Constant(op) => op.layout(config, region, values),
|
||||
SupportedOp::Unknown(op) => op.layout(config, region, values),
|
||||
SupportedOp::Rescaled(op) => op.layout(config, region, values),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_input(&self) -> bool {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => Op::<Fp>::is_input(op),
|
||||
SupportedOp::Nonlinear(op) => Op::<Fp>::is_input(op),
|
||||
SupportedOp::Hybrid(op) => Op::<Fp>::is_input(op),
|
||||
SupportedOp::Input(op) => Op::<Fp>::is_input(op),
|
||||
SupportedOp::Constant(op) => Op::<Fp>::is_input(op),
|
||||
SupportedOp::Unknown(op) => Op::<Fp>::is_input(op),
|
||||
SupportedOp::Rescaled(op) => Op::<Fp>::is_input(op),
|
||||
}
|
||||
}
|
||||
|
||||
fn requires_homogenous_input_scales(&self) -> Vec<usize> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => Op::<Fp>::requires_homogenous_input_scales(op),
|
||||
SupportedOp::Nonlinear(op) => Op::<Fp>::requires_homogenous_input_scales(op),
|
||||
SupportedOp::Hybrid(op) => Op::<Fp>::requires_homogenous_input_scales(op),
|
||||
SupportedOp::Input(op) => Op::<Fp>::requires_homogenous_input_scales(op),
|
||||
SupportedOp::Constant(op) => Op::<Fp>::requires_homogenous_input_scales(op),
|
||||
SupportedOp::Unknown(op) => Op::<Fp>::requires_homogenous_input_scales(op),
|
||||
SupportedOp::Rescaled(op) => Op::<Fp>::requires_homogenous_input_scales(op),
|
||||
}
|
||||
}
|
||||
|
||||
fn clone_dyn(&self) -> Box<dyn Op<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => Box::new(op.clone()),
|
||||
SupportedOp::Nonlinear(op) => Box::new(op.clone()),
|
||||
SupportedOp::Hybrid(op) => Box::new(op.clone()),
|
||||
SupportedOp::Input(op) => Box::new(op.clone()),
|
||||
SupportedOp::Constant(op) => Box::new(op.clone()),
|
||||
SupportedOp::Unknown(op) => Box::new(op.clone()),
|
||||
SupportedOp::Rescaled(op) => Box::new(op.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn as_string(&self) -> String {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => Op::<Fp>::as_string(op),
|
||||
SupportedOp::Nonlinear(op) => Op::<Fp>::as_string(op),
|
||||
SupportedOp::Hybrid(op) => Op::<Fp>::as_string(op),
|
||||
SupportedOp::Input(op) => Op::<Fp>::as_string(op),
|
||||
SupportedOp::Constant(op) => Op::<Fp>::as_string(op),
|
||||
SupportedOp::Unknown(op) => Op::<Fp>::as_string(op),
|
||||
SupportedOp::Rescaled(op) => Op::<Fp>::as_string(op),
|
||||
}
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
fn required_lookups(&self) -> Vec<LookupOp> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => Op::<Fp>::required_lookups(op),
|
||||
SupportedOp::Nonlinear(op) => Op::<Fp>::required_lookups(op),
|
||||
SupportedOp::Hybrid(op) => Op::<Fp>::required_lookups(op),
|
||||
SupportedOp::Input(op) => Op::<Fp>::required_lookups(op),
|
||||
SupportedOp::Constant(op) => Op::<Fp>::required_lookups(op),
|
||||
SupportedOp::Unknown(op) => Op::<Fp>::required_lookups(op),
|
||||
SupportedOp::Rescaled(op) => Op::<Fp>::required_lookups(op),
|
||||
}
|
||||
}
|
||||
|
||||
fn rescale(&self, in_scales: Vec<u32>, out_scale: u32) -> Box<dyn Op<Fp>> {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => {
|
||||
let inputs_to_scale = self.requires_homogenous_input_scales();
|
||||
// creates a rescaled op if the inputs are not homogenous
|
||||
super::homogenize_input_scales(Box::new(op.clone()), in_scales, inputs_to_scale)
|
||||
.unwrap()
|
||||
}
|
||||
SupportedOp::Nonlinear(op) => op.rescale(in_scales, out_scale),
|
||||
SupportedOp::Hybrid(op) => op.rescale(in_scales, out_scale),
|
||||
SupportedOp::Input(op) => op.rescale(in_scales, out_scale),
|
||||
SupportedOp::Constant(op) => op.rescale(in_scales, out_scale),
|
||||
SupportedOp::Unknown(op) => op.rescale(in_scales, out_scale),
|
||||
SupportedOp::Rescaled(op) => op.rescale(in_scales, out_scale),
|
||||
}
|
||||
}
|
||||
|
||||
fn out_scale(&self, in_scales: Vec<u32>, global: u32) -> u32 {
|
||||
match self {
|
||||
SupportedOp::Linear(op) => Op::<Fp>::out_scale(op, in_scales, global),
|
||||
SupportedOp::Nonlinear(op) => Op::<Fp>::out_scale(op, in_scales, global),
|
||||
SupportedOp::Hybrid(op) => Op::<Fp>::out_scale(op, in_scales, global),
|
||||
SupportedOp::Input(op) => Op::<Fp>::out_scale(op, in_scales, global),
|
||||
SupportedOp::Constant(op) => Op::<Fp>::out_scale(op, in_scales, global),
|
||||
SupportedOp::Unknown(op) => Op::<Fp>::out_scale(op, in_scales, global),
|
||||
SupportedOp::Rescaled(op) => Op::<Fp>::out_scale(op, in_scales, global),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A single operation in a [crate::graph::Model].
|
||||
#[derive(Clone, Debug, Tabled, Serialize, Deserialize)]
|
||||
pub struct Node {
|
||||
/// [Op] i.e what operation this node represents.
|
||||
#[tabled(display_with = "display_opkind")]
|
||||
#[serde(with = "serde_traitobject")]
|
||||
pub opkind: Box<dyn Op<Fp>>,
|
||||
pub opkind: SupportedOp,
|
||||
/// The denominator in the fixed point representation for the node's output. Tensors of differing scales should not be combined.
|
||||
pub out_scale: u32,
|
||||
// Usually there is a simple in and out shape of the node as an operator. For example, an Affine node has three input_shapes (one for the input, weight, and bias),
|
||||
@@ -108,7 +379,7 @@ impl Node {
|
||||
inputs[idx].out_scales()[0]
|
||||
})
|
||||
.collect();
|
||||
opkind = opkind.rescale(in_scales.clone(), scale);
|
||||
opkind = opkind.rescale(in_scales.clone(), scale).into();
|
||||
let out_scale = match in_scales.len() {
|
||||
0 => scale,
|
||||
_ => opkind.out_scale(in_scales, scale),
|
||||
|
||||
@@ -1,12 +1,15 @@
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::{GraphError, Visibility};
|
||||
use super::{GraphError, Rescaled, SupportedOp, Visibility};
|
||||
use crate::circuit::hybrid::HybridOp;
|
||||
use crate::circuit::lookup::LookupOp;
|
||||
use crate::circuit::poly::PolyOp;
|
||||
use crate::circuit::Op;
|
||||
use crate::tensor::{Tensor, TensorError, TensorType};
|
||||
use halo2curves::bn256::Fr as Fp;
|
||||
use halo2curves::ff::PrimeField;
|
||||
use itertools::Itertools;
|
||||
use log::{debug, warn};
|
||||
use tract_onnx::prelude::{DatumType, Node as OnnxNode, TypedFact, TypedOp};
|
||||
use tract_onnx::tract_core::ops::array::Gather;
|
||||
@@ -244,7 +247,7 @@ pub fn new_op_from_onnx(
|
||||
param_visibility: Visibility,
|
||||
node: OnnxNode<TypedFact, Box<dyn TypedOp>>,
|
||||
inputs: &mut Vec<&mut super::NodeType>,
|
||||
) -> Result<Box<dyn crate::circuit::Op<Fp>>, Box<dyn std::error::Error>> {
|
||||
) -> Result<SupportedOp, Box<dyn std::error::Error>> {
|
||||
debug!("Loading node: {:?}", node);
|
||||
Ok(match node.op().name().as_ref() {
|
||||
"Gather" => {
|
||||
@@ -273,7 +276,7 @@ pub fn new_op_from_onnx(
|
||||
inputs.pop();
|
||||
}
|
||||
|
||||
Box::new(crate::circuit::ops::poly::PolyOp::Gather { dim: axis, index })
|
||||
SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::Gather { dim: axis, index })
|
||||
}
|
||||
"MoveAxis" => {
|
||||
let op = load_axis_op(node.op(), idx, node.op().name().to_string())?;
|
||||
@@ -281,7 +284,7 @@ pub fn new_op_from_onnx(
|
||||
AxisOp::Move(from, to) => {
|
||||
let source = from.to_usize()?;
|
||||
let destination = to.to_usize()?;
|
||||
Box::new(crate::circuit::ops::poly::PolyOp::MoveAxis {
|
||||
SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::MoveAxis {
|
||||
source,
|
||||
destination,
|
||||
})
|
||||
@@ -292,7 +295,7 @@ pub fn new_op_from_onnx(
|
||||
"Concat" | "InferenceConcat" => {
|
||||
let op = load_concat_op(node.op(), idx, node.op().name().to_string())?;
|
||||
let axis = op.axis;
|
||||
Box::new(crate::circuit::ops::poly::PolyOp::Concat { axis })
|
||||
SupportedOp::Linear(crate::circuit::ops::poly::PolyOp::Concat { axis })
|
||||
}
|
||||
"Slice" => {
|
||||
let slice = load_slice_op(node.op(), node.op().name().to_string())?;
|
||||
@@ -301,7 +304,7 @@ pub fn new_op_from_onnx(
|
||||
let start = slice.start.to_usize()?;
|
||||
let end = slice.end.to_usize()?;
|
||||
|
||||
Box::new(PolyOp::Slice { axis, start, end })
|
||||
SupportedOp::Linear(PolyOp::Slice { axis, start, end })
|
||||
}
|
||||
"Const" => {
|
||||
let op: Const = load_const(node.op(), idx, node.op().name().to_string())?;
|
||||
@@ -317,7 +320,7 @@ pub fn new_op_from_onnx(
|
||||
let mut c = crate::circuit::ops::Constant::new(quantized_value, raw_value);
|
||||
c.num_uses += node.outputs.len();
|
||||
// Create a constant op
|
||||
Box::new(c)
|
||||
SupportedOp::Constant(c)
|
||||
}
|
||||
"Reduce<Min>" => {
|
||||
if inputs.len() != 1 {
|
||||
@@ -326,7 +329,7 @@ pub fn new_op_from_onnx(
|
||||
let op = load_reduce_op(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
|
||||
Box::new(HybridOp::Min { axes })
|
||||
SupportedOp::Hybrid(HybridOp::Min { axes })
|
||||
}
|
||||
"Reduce<Max>" => {
|
||||
if inputs.len() != 1 {
|
||||
@@ -335,7 +338,7 @@ pub fn new_op_from_onnx(
|
||||
let op = load_reduce_op(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
|
||||
Box::new(HybridOp::Max { axes })
|
||||
SupportedOp::Hybrid(HybridOp::Max { axes })
|
||||
}
|
||||
"Reduce<Sum>" => {
|
||||
if inputs.len() != 1 {
|
||||
@@ -344,7 +347,7 @@ pub fn new_op_from_onnx(
|
||||
let op = load_reduce_op(node.op(), idx, node.op().name().to_string())?;
|
||||
let axes = op.axes.into_iter().collect();
|
||||
|
||||
Box::new(PolyOp::Sum { axes })
|
||||
SupportedOp::Linear(PolyOp::Sum { axes })
|
||||
}
|
||||
"Max" => {
|
||||
// Extract the slope layer hyperparams
|
||||
@@ -364,7 +367,7 @@ pub fn new_op_from_onnx(
|
||||
node.decrement_const();
|
||||
inputs.pop();
|
||||
}
|
||||
Box::new(LookupOp::ReLU {
|
||||
SupportedOp::Nonlinear(LookupOp::ReLU {
|
||||
scale: inputs[0].out_scales()[0] as usize,
|
||||
})
|
||||
} else {
|
||||
@@ -373,7 +376,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Recip" => {
|
||||
// Extract the slope layer hyperparams
|
||||
Box::new(LookupOp::Recip { scale: 1 })
|
||||
SupportedOp::Nonlinear(LookupOp::Recip { scale: 1 })
|
||||
}
|
||||
|
||||
"LeakyRelu" => {
|
||||
@@ -390,7 +393,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
Box::new(LookupOp::LeakyReLU {
|
||||
SupportedOp::Nonlinear(LookupOp::LeakyReLU {
|
||||
scale: 1,
|
||||
slope: crate::circuit::utils::F32(leaky_op.alpha),
|
||||
})
|
||||
@@ -398,27 +401,27 @@ pub fn new_op_from_onnx(
|
||||
"Scan" => {
|
||||
panic!("should never reach here")
|
||||
}
|
||||
"Abs" => Box::new(HybridOp::Abs),
|
||||
"Neg" => Box::new(PolyOp::Neg),
|
||||
"Sigmoid" => Box::new(LookupOp::Sigmoid { scales: (1, 1) }),
|
||||
"Sqrt" => Box::new(LookupOp::Sqrt { scales: (1, 1) }),
|
||||
"Rsqrt" => Box::new(LookupOp::Rsqrt { scales: (1, 1) }),
|
||||
"Exp" => Box::new(LookupOp::Exp { scales: (1, 1) }),
|
||||
"Ln" => Box::new(LookupOp::Ln { scales: (1, 1) }),
|
||||
"Sin" => Box::new(LookupOp::Sin { scales: (1, 1) }),
|
||||
"Cos" => Box::new(LookupOp::Cos { scales: (1, 1) }),
|
||||
"Tan" => Box::new(LookupOp::Tan { scales: (1, 1) }),
|
||||
"Asin" => Box::new(LookupOp::ASin { scales: (1, 1) }),
|
||||
"Acos" => Box::new(LookupOp::ACos { scales: (1, 1) }),
|
||||
"Atan" => Box::new(LookupOp::ATan { scales: (1, 1) }),
|
||||
"Sinh" => Box::new(LookupOp::Sinh { scales: (1, 1) }),
|
||||
"Cosh" => Box::new(LookupOp::Cosh { scales: (1, 1) }),
|
||||
"Tanh" => Box::new(LookupOp::Tanh { scales: (1, 1) }),
|
||||
"Asinh" => Box::new(LookupOp::ASinh { scales: (1, 1) }),
|
||||
"Acosh" => Box::new(LookupOp::ACosh { scales: (1, 1) }),
|
||||
"Atanh" => Box::new(LookupOp::ATanh { scales: (1, 1) }),
|
||||
"Erf" => Box::new(LookupOp::Erf { scales: (1, 1) }),
|
||||
"Source" => Box::new(crate::circuit::ops::Input { scale }),
|
||||
"Abs" => SupportedOp::Hybrid(HybridOp::Abs),
|
||||
"Neg" => SupportedOp::Linear(PolyOp::Neg),
|
||||
"Sigmoid" => SupportedOp::Nonlinear(LookupOp::Sigmoid { scales: (1, 1) }),
|
||||
"Sqrt" => SupportedOp::Nonlinear(LookupOp::Sqrt { scales: (1, 1) }),
|
||||
"Rsqrt" => SupportedOp::Nonlinear(LookupOp::Rsqrt { scales: (1, 1) }),
|
||||
"Exp" => SupportedOp::Nonlinear(LookupOp::Exp { scales: (1, 1) }),
|
||||
"Ln" => SupportedOp::Nonlinear(LookupOp::Ln { scales: (1, 1) }),
|
||||
"Sin" => SupportedOp::Nonlinear(LookupOp::Sin { scales: (1, 1) }),
|
||||
"Cos" => SupportedOp::Nonlinear(LookupOp::Cos { scales: (1, 1) }),
|
||||
"Tan" => SupportedOp::Nonlinear(LookupOp::Tan { scales: (1, 1) }),
|
||||
"Asin" => SupportedOp::Nonlinear(LookupOp::ASin { scales: (1, 1) }),
|
||||
"Acos" => SupportedOp::Nonlinear(LookupOp::ACos { scales: (1, 1) }),
|
||||
"Atan" => SupportedOp::Nonlinear(LookupOp::ATan { scales: (1, 1) }),
|
||||
"Sinh" => SupportedOp::Nonlinear(LookupOp::Sinh { scales: (1, 1) }),
|
||||
"Cosh" => SupportedOp::Nonlinear(LookupOp::Cosh { scales: (1, 1) }),
|
||||
"Tanh" => SupportedOp::Nonlinear(LookupOp::Tanh { scales: (1, 1) }),
|
||||
"Asinh" => SupportedOp::Nonlinear(LookupOp::ASinh { scales: (1, 1) }),
|
||||
"Acosh" => SupportedOp::Nonlinear(LookupOp::ACosh { scales: (1, 1) }),
|
||||
"Atanh" => SupportedOp::Nonlinear(LookupOp::ATanh { scales: (1, 1) }),
|
||||
"Erf" => SupportedOp::Nonlinear(LookupOp::Erf { scales: (1, 1) }),
|
||||
"Source" => SupportedOp::Input(crate::circuit::ops::Input { scale }),
|
||||
"Add" => {
|
||||
// get the max scale of inputs
|
||||
let max_scale = inputs
|
||||
@@ -434,12 +437,12 @@ pub fn new_op_from_onnx(
|
||||
log::debug!("requantizing #{} to {} for add", inp.idx(), max_scale);
|
||||
n.requantize(max_scale)?;
|
||||
inp.bump_scale(max_scale);
|
||||
inp.replace_opkind(Box::new(n));
|
||||
inp.replace_opkind(SupportedOp::Constant(n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(PolyOp::Add)
|
||||
SupportedOp::Linear(PolyOp::Add)
|
||||
}
|
||||
"Sub" => {
|
||||
// get the max scale of inputs
|
||||
@@ -456,15 +459,15 @@ pub fn new_op_from_onnx(
|
||||
log::debug!("requantizing #{} to {} fro sub", inp.idx(), max_scale);
|
||||
n.requantize(max_scale)?;
|
||||
inp.bump_scale(max_scale);
|
||||
inp.replace_opkind(Box::new(n));
|
||||
inp.replace_opkind(SupportedOp::Constant(n));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Box::new(PolyOp::Sub)
|
||||
SupportedOp::Linear(PolyOp::Sub)
|
||||
}
|
||||
"Mul" => Box::new(PolyOp::Mult),
|
||||
"Iff" => Box::new(PolyOp::Iff),
|
||||
"Mul" => SupportedOp::Linear(PolyOp::Mult),
|
||||
"Iff" => SupportedOp::Linear(PolyOp::Iff),
|
||||
"Less" => {
|
||||
// Extract the slope layer hyperparams
|
||||
let boxed_op = inputs[0].clone().opkind();
|
||||
@@ -480,7 +483,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
if inputs.len() == 2 {
|
||||
inputs.remove(0);
|
||||
Box::new(LookupOp::LessThan {
|
||||
SupportedOp::Nonlinear(LookupOp::LessThan {
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
} else {
|
||||
@@ -502,7 +505,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
if inputs.len() == 2 {
|
||||
inputs.remove(0);
|
||||
Box::new(LookupOp::GreaterThan {
|
||||
SupportedOp::Nonlinear(LookupOp::GreaterThan {
|
||||
a: crate::circuit::utils::F32(unit),
|
||||
})
|
||||
} else {
|
||||
@@ -519,7 +522,7 @@ pub fn new_op_from_onnx(
|
||||
};
|
||||
|
||||
let axes = &op.axes;
|
||||
Box::new(PolyOp::Einsum {
|
||||
SupportedOp::Linear(PolyOp::Einsum {
|
||||
equation: axes.to_string(),
|
||||
})
|
||||
}
|
||||
@@ -540,7 +543,7 @@ pub fn new_op_from_onnx(
|
||||
)));
|
||||
}
|
||||
|
||||
Box::new(HybridOp::Softmax { scales: (1, 1) })
|
||||
SupportedOp::Hybrid(HybridOp::Softmax { scales: (1, 1) })
|
||||
}
|
||||
"MaxPool" => {
|
||||
// Extract the padding and stride layer hyperparams
|
||||
@@ -574,7 +577,7 @@ pub fn new_op_from_onnx(
|
||||
(padding[0], padding[1], stride[0], stride[1]);
|
||||
let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]);
|
||||
|
||||
Box::new(HybridOp::MaxPool2d {
|
||||
SupportedOp::Hybrid(HybridOp::MaxPool2d {
|
||||
padding: (padding_h, padding_w),
|
||||
stride: (stride_h, stride_w),
|
||||
pool_dims: (kernel_height, kernel_width),
|
||||
@@ -582,11 +585,11 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
"Ceil" | "Floor" | "Round" | "RoundHalfToEven" => {
|
||||
warn!("using a round op in the circuit which does not make sense in Field arithmetic");
|
||||
Box::new(PolyOp::Identity)
|
||||
SupportedOp::Linear(PolyOp::Identity)
|
||||
}
|
||||
"Sign" => Box::new(LookupOp::Sign),
|
||||
"Cube" => Box::new(PolyOp::Pow(3)),
|
||||
"Square" => Box::new(PolyOp::Pow(2)),
|
||||
"Sign" => SupportedOp::Nonlinear(LookupOp::Sign),
|
||||
"Cube" => SupportedOp::Linear(PolyOp::Pow(3)),
|
||||
"Square" => SupportedOp::Linear(PolyOp::Pow(2)),
|
||||
"ConvUnary" => {
|
||||
let conv_node: &ConvUnary = match node.op().downcast_ref::<ConvUnary>() {
|
||||
Some(b) => b,
|
||||
@@ -646,7 +649,7 @@ pub fn new_op_from_onnx(
|
||||
None => None,
|
||||
};
|
||||
|
||||
Box::new(PolyOp::Conv {
|
||||
SupportedOp::Linear(PolyOp::Conv {
|
||||
kernel,
|
||||
bias,
|
||||
padding: (padding_h, padding_w),
|
||||
@@ -712,7 +715,7 @@ pub fn new_op_from_onnx(
|
||||
|
||||
let output_padding = (deconv_node.adjustments[0], deconv_node.adjustments[1]);
|
||||
|
||||
Box::new(PolyOp::DeConv {
|
||||
SupportedOp::Linear(PolyOp::DeConv {
|
||||
kernel,
|
||||
bias,
|
||||
padding: (padding_h, padding_w),
|
||||
@@ -731,7 +734,7 @@ pub fn new_op_from_onnx(
|
||||
}
|
||||
};
|
||||
|
||||
Box::new(PolyOp::Downsample {
|
||||
SupportedOp::Linear(PolyOp::Downsample {
|
||||
axis: downsample_node.axis,
|
||||
stride: downsample_node.stride as usize,
|
||||
modulo: downsample_node.modulo,
|
||||
@@ -768,7 +771,7 @@ pub fn new_op_from_onnx(
|
||||
inputs.pop();
|
||||
}
|
||||
|
||||
Box::new(PolyOp::Resize { scale_factor })
|
||||
SupportedOp::Linear(PolyOp::Resize { scale_factor })
|
||||
}
|
||||
|
||||
"SumPool" => {
|
||||
@@ -803,13 +806,13 @@ pub fn new_op_from_onnx(
|
||||
(padding[0], padding[1], stride[0], stride[1]);
|
||||
let (kernel_height, kernel_width) = (kernel_shape[0], kernel_shape[1]);
|
||||
|
||||
Box::new(PolyOp::SumPool {
|
||||
SupportedOp::Linear(PolyOp::SumPool {
|
||||
padding: (padding_h, padding_w),
|
||||
stride: (stride_h, stride_w),
|
||||
kernel_shape: (kernel_height, kernel_width),
|
||||
})
|
||||
}
|
||||
"GlobalAvgPool" => Box::new(PolyOp::SumPool {
|
||||
"GlobalAvgPool" => SupportedOp::Linear(PolyOp::SumPool {
|
||||
padding: (0, 0),
|
||||
stride: (1, 1),
|
||||
kernel_shape: (inputs[0].out_dims()[0][1], inputs[0].out_dims()[0][2]),
|
||||
@@ -853,22 +856,22 @@ pub fn new_op_from_onnx(
|
||||
pad_node.pads[padding_len - 2].0,
|
||||
pad_node.pads[padding_len - 1].0,
|
||||
);
|
||||
Box::new(PolyOp::Pad(padding_h, padding_w))
|
||||
SupportedOp::Linear(PolyOp::Pad(padding_h, padding_w))
|
||||
}
|
||||
"RmAxis" | "Reshape" | "AddAxis" => {
|
||||
// Extract the slope layer hyperparams
|
||||
let shapes = node_output_shapes(&node)?;
|
||||
let output_shape = shapes[0].as_ref().unwrap().clone();
|
||||
|
||||
Box::new(PolyOp::Reshape(output_shape))
|
||||
SupportedOp::Linear(PolyOp::Reshape(output_shape))
|
||||
}
|
||||
"Flatten" => {
|
||||
let new_dims: Vec<usize> = vec![inputs[0].out_dims()[0].iter().product::<usize>()];
|
||||
Box::new(PolyOp::Flatten(new_dims))
|
||||
SupportedOp::Linear(PolyOp::Flatten(new_dims))
|
||||
}
|
||||
c => {
|
||||
warn!("Unknown op: {}", c);
|
||||
Box::new(crate::circuit::ops::Unknown)
|
||||
SupportedOp::Unknown(crate::circuit::ops::Unknown)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -879,7 +882,8 @@ pub fn downcast_const_op(
|
||||
) -> Option<crate::circuit::ops::Constant<Fp>> {
|
||||
boxed_op
|
||||
.as_any()
|
||||
.downcast_ref::<crate::circuit::ops::Constant<Fp>>().cloned()
|
||||
.downcast_ref::<crate::circuit::ops::Constant<Fp>>()
|
||||
.cloned()
|
||||
}
|
||||
|
||||
/// Extracts the raw values from a [crate::circuit::ops::Constant] op.
|
||||
@@ -947,6 +951,46 @@ pub(crate) fn split_valtensor(
|
||||
Ok(tensors)
|
||||
}
|
||||
|
||||
///
|
||||
pub fn homogenize_input_scales(
|
||||
op: Box<dyn Op<Fp>>,
|
||||
input_scales: Vec<u32>,
|
||||
inputs_to_scale: Vec<usize>,
|
||||
) -> Result<Box<dyn Op<Fp>>, Box<dyn Error>> {
|
||||
if inputs_to_scale.is_empty() {
|
||||
return Ok(op);
|
||||
}
|
||||
|
||||
let mut dividers: Vec<u128> = vec![1; input_scales.len()];
|
||||
if !input_scales.windows(2).all(|w| w[0] == w[1]) {
|
||||
let min_scale = input_scales.iter().min().unwrap();
|
||||
let _ = input_scales
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, input_scale)| {
|
||||
if !inputs_to_scale.contains(&idx) {
|
||||
return;
|
||||
}
|
||||
let scale_diff = input_scale - min_scale;
|
||||
if scale_diff > 0 {
|
||||
let mult = crate::graph::scale_to_multiplier(scale_diff);
|
||||
dividers[idx] = mult as u128;
|
||||
}
|
||||
})
|
||||
.collect_vec();
|
||||
}
|
||||
|
||||
// only rescale if need to
|
||||
if dividers.iter().any(|&x| x > 1) {
|
||||
Ok(Box::new(Rescaled {
|
||||
inner: Box::new(op.into()),
|
||||
scale: (0..input_scales.len()).zip(dividers).collect_vec(),
|
||||
}))
|
||||
} else {
|
||||
Ok(op)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod tests {
|
||||
|
||||
|
||||
@@ -6,7 +6,9 @@ mod native_tests {
|
||||
use ezkl::graph::input::{FileSource, GraphData};
|
||||
use ezkl::graph::{DataSource, GraphSettings, Visibility};
|
||||
use lazy_static::lazy_static;
|
||||
use rand::Rng;
|
||||
use std::env::var;
|
||||
use std::io::{Read, Write};
|
||||
use std::process::{Child, Command};
|
||||
use std::sync::Once;
|
||||
static COMPILE: Once = Once::new();
|
||||
@@ -304,6 +306,7 @@ mod native_tests {
|
||||
use crate::native_tests::kzg_fuzz;
|
||||
use crate::native_tests::render_circuit;
|
||||
use crate::native_tests::model_serialization;
|
||||
use crate::native_tests::model_serialization_different_binaries;
|
||||
use crate::native_tests::tutorial as run_tutorial;
|
||||
use tempdir::TempDir;
|
||||
|
||||
@@ -331,6 +334,18 @@ mod native_tests {
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn model_serialization_different_binaries_(test: &str) {
|
||||
let test_dir = TempDir::new(test).unwrap();
|
||||
let path = test_dir.path().to_str().unwrap();
|
||||
crate::native_tests::mv_test_(path, test);
|
||||
// percent tolerance test
|
||||
model_serialization_different_binaries(path, test.to_string());
|
||||
test_dir.close().unwrap();
|
||||
}
|
||||
|
||||
|
||||
|
||||
#(#[test_case(TESTS[N])])*
|
||||
fn render_circuit_(test: &str) {
|
||||
crate::native_tests::init_binary();
|
||||
@@ -817,6 +832,116 @@ mod native_tests {
|
||||
assert_eq!(model, loaded_model)
|
||||
}
|
||||
|
||||
fn model_serialization_different_binaries(test_dir: &str, example_name: String) {
|
||||
let status = Command::new("cargo")
|
||||
.args([
|
||||
"run",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--",
|
||||
"gen-settings",
|
||||
"-M",
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
&format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new("cargo")
|
||||
.args([
|
||||
"run",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--",
|
||||
"compile-model",
|
||||
"-M",
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
"--compiled-model",
|
||||
format!("{}/{}/network.onnx", test_dir, example_name).as_str(),
|
||||
&format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// now alter binary slightly
|
||||
// create new temp cargo.toml with a different version
|
||||
// cpy old cargo.toml to cargo.toml.bak
|
||||
let status = Command::new("cp")
|
||||
.args(["Cargo.toml", "Cargo.toml.bak"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let mut cargo_toml = std::fs::File::open("Cargo.toml").unwrap();
|
||||
let mut cargo_toml_contents = String::new();
|
||||
cargo_toml.read_to_string(&mut cargo_toml_contents).unwrap();
|
||||
let mut cargo_toml_contents = cargo_toml_contents.split("\n").collect::<Vec<_>>();
|
||||
|
||||
// draw a random version number from 0.0.0 to 0.100.100
|
||||
let mut rng = rand::thread_rng();
|
||||
let version = &format!(
|
||||
"version = \"0.{}.{}-test\"",
|
||||
rng.gen_range(0..100),
|
||||
rng.gen_range(0..100)
|
||||
);
|
||||
let cargo_toml_contents = cargo_toml_contents
|
||||
.iter_mut()
|
||||
.map(|line| {
|
||||
if line.starts_with("version") {
|
||||
*line = version;
|
||||
}
|
||||
*line
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let mut cargo_toml = std::fs::File::create("Cargo.toml").unwrap();
|
||||
cargo_toml
|
||||
.write_all(cargo_toml_contents.join("\n").as_bytes())
|
||||
.unwrap();
|
||||
|
||||
let status = Command::new("cargo")
|
||||
.args([
|
||||
"run",
|
||||
"--bin",
|
||||
"ezkl",
|
||||
"--",
|
||||
"gen-witness",
|
||||
"-D",
|
||||
&format!("{}/{}/input.json", test_dir, example_name),
|
||||
"-M",
|
||||
&format!("{}/{}/network.onnx", test_dir, example_name),
|
||||
"-O",
|
||||
&format!("{}/{}/witness.json", test_dir, example_name),
|
||||
&format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
// now delete cargo.toml and move cargo.toml.bak to cargo.toml
|
||||
let status = Command::new("rm")
|
||||
.args(["Cargo.toml"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new("mv")
|
||||
.args(["Cargo.toml.bak", "Cargo.toml"])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
}
|
||||
|
||||
// Mock prove (fast, but does not cover some potential issues)
|
||||
fn neg_mock(test_dir: &str, example_name: String, counter_example: String) {
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
@@ -849,11 +974,29 @@ mod native_tests {
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"gen-witness",
|
||||
"-D",
|
||||
&format!("{}/{}/input.json", test_dir, example_name),
|
||||
"-M",
|
||||
&format!("{}/{}/network.onnx", test_dir, example_name),
|
||||
"-O",
|
||||
&format!("{}/{}/witness.json", test_dir, example_name),
|
||||
&format!(
|
||||
"--settings-path={}/{}/settings.json",
|
||||
test_dir, example_name
|
||||
),
|
||||
])
|
||||
.status()
|
||||
.expect("failed to execute process");
|
||||
assert!(status.success());
|
||||
|
||||
let status = Command::new(format!("{}/release/ezkl", *CARGO_TARGET_DIR))
|
||||
.args([
|
||||
"mock",
|
||||
"-W",
|
||||
format!("{}/{}/input.json", test_dir, counter_example).as_str(),
|
||||
format!("{}/{}/witness.json", test_dir, counter_example).as_str(),
|
||||
"-M",
|
||||
format!("{}/{}/network.compiled", test_dir, example_name).as_str(),
|
||||
&format!(
|
||||
|
||||
Reference in New Issue
Block a user