chore: Move concrete-cpu backend to the concrete mono repository

The main contributors of concrete-cpu are:

Co-authored-by: Mayeul@Zama <mayeul.debellabre@zama.ai>
Co-authored-by: sarah <sarah.elkazdadi@zama.ai>
This commit is contained in:
Quentin Bourgerie
2023-03-03 16:01:38 +01:00
parent f65d6e0b44
commit 06d3c316e7
53 changed files with 9326 additions and 11 deletions

6
.gitmodules vendored
View File

@@ -13,8 +13,4 @@
[submodule "compiler/parameter-curves"]
path = compilers/concrete-compiler/compiler/parameter-curves
url = git@github.com:zama-ai/parameter-curves.git
shallow = true
[submodule "compiler/concrete-cpu"]
path = compilers/concrete-compiler/compiler/concrete-cpu
url = git@github.com:zama-ai/concrete-cpu.git
shallow = true
shallow = true

962
backends/concrete-cpu/Cargo.lock generated Normal file
View File

@@ -0,0 +1,962 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "aes-soft"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be14c7498ea50828a38d0e24a765ed2effe92a705885b57d029cd67d45744072"
dependencies = [
"cipher",
"opaque-debug",
]
[[package]]
name = "aligned-vec"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4aa90d7ce82d4be67b64039a3d588d38dbcc6736577de4a847025ce5b0c468d1"
[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
[[package]]
name = "atty"
version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [
"hermit-abi 0.1.19",
"libc",
"winapi",
]
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bumpalo"
version = "3.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0d261e256854913907f67ed06efbc3338dfe6179796deefc1ff763fc1aee5535"
[[package]]
name = "bytemuck"
version = "1.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17febce684fd15d89027105661fec94afb475cb995fbc59d2865198446ba2eea"
[[package]]
name = "cast"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
[[package]]
name = "cbindgen"
version = "0.24.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a6358dedf60f4d9b8db43ad187391afe959746101346fe51bb978126bec61dfb"
dependencies = [
"clap",
"heck",
"indexmap",
"log",
"proc-macro2",
"quote",
"serde",
"serde_json",
"syn",
"tempfile",
"toml",
]
[[package]]
name = "cc"
version = "1.0.79"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "ciborium"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0c137568cc60b904a7724001b35ce2630fd00d5d84805fbb608ab89509d788f"
dependencies = [
"ciborium-io",
"ciborium-ll",
"serde",
]
[[package]]
name = "ciborium-io"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346de753af073cc87b52b2083a506b38ac176a44cfb05497b622e27be899b369"
[[package]]
name = "ciborium-ll"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "213030a2b5a4e0c0892b6652260cf6ccac84827b83a85a534e178e3906c4cf1b"
dependencies = [
"ciborium-io",
"half",
]
[[package]]
name = "cipher"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "12f8e7987cbd042a63249497f41aed09f8e65add917ea6566effbc56578d6801"
dependencies = [
"generic-array",
]
[[package]]
name = "clap"
version = "3.2.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "71655c45cb9845d3270c9d6df84ebe72b4dad3c2ba3f7023ad47c144e4e473a5"
dependencies = [
"atty",
"bitflags",
"clap_lex",
"indexmap",
"strsim",
"termcolor",
"textwrap",
]
[[package]]
name = "clap_lex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2850f2f5a82cbf437dd5af4d49848fbdfc27c157c3d010345776f952765261c5"
dependencies = [
"os_str_bytes",
]
[[package]]
name = "concrete-cpu"
version = "0.1.0"
dependencies = [
"aligned-vec",
"bytemuck",
"cbindgen",
"concrete-csprng",
"concrete-fft",
"criterion",
"dyn-stack",
"libc",
"num-complex",
"once_cell",
"pulp",
"rayon",
"readonly",
]
[[package]]
name = "concrete-csprng"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36a67ae02f50fdf3b3b8a43feb8c83f0cd6554b6449086d3df13bbf330641f78"
dependencies = [
"aes-soft",
"libc",
]
[[package]]
name = "concrete-fft"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55f4c28b64f48194903ef4415b1cbcb0ba3dc3ea06c77a2a89110363d642033b"
dependencies = [
"aligned-vec",
"dyn-stack",
"num-complex",
]
[[package]]
name = "criterion"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7c76e09c1aae2bc52b3d2f29e13c6572553b30c4aa1b8a49fd70de6412654cb"
dependencies = [
"anes",
"atty",
"cast",
"ciborium",
"clap",
"criterion-plot",
"itertools",
"lazy_static",
"num-traits",
"oorandom",
"plotters",
"rayon",
"regex",
"serde",
"serde_derive",
"serde_json",
"tinytemplate",
"walkdir",
]
[[package]]
name = "criterion-plot"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
dependencies = [
"cast",
"itertools",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf2b3e8478797446514c91ef04bafcb59faba183e621ad488df88983cc14128c"
dependencies = [
"cfg-if",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
dependencies = [
"cfg-if",
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46bd5f3f85273295a9d14aedfb86f6aadbff6d8f5295c4a9edb08e819dcf5695"
dependencies = [
"autocfg",
"cfg-if",
"crossbeam-utils",
"memoffset",
"scopeguard",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c063cd8cc95f5c377ed0d4b49a4b21f632396ff690e8470c29b3359b346984b"
dependencies = [
"cfg-if",
]
[[package]]
name = "dyn-stack"
version = "0.8.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2c63c3199ff7026af3c8149907c5677a6d3431e1d3cd16b7fffbfcf158b788ca"
dependencies = [
"reborrow",
]
[[package]]
name = "either"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
[[package]]
name = "errno"
version = "0.2.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f639046355ee4f37944e44f60642c6f3a7efa3cf6b78c78a0d989a8ce6c396a1"
dependencies = [
"errno-dragonfly",
"libc",
"winapi",
]
[[package]]
name = "errno-dragonfly"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "aa68f1b12764fab894d2755d2518754e71b4fd80ecfb822714a1206c2aab39bf"
dependencies = [
"cc",
"libc",
]
[[package]]
name = "fastrand"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e51093e27b0797c359783294ca4f0a911c270184cb10f85783b118614a1501be"
dependencies = [
"instant",
]
[[package]]
name = "generic-array"
version = "0.14.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9"
dependencies = [
"typenum",
"version_check",
]
[[package]]
name = "half"
version = "1.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
[[package]]
name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "heck"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermit-abi"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62b467343b94ba476dcb2500d242dadbb39557df889310ac77c5d99100aaac33"
dependencies = [
"libc",
]
[[package]]
name = "hermit-abi"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee512640fe35acbfb4bb779db6f0d80704c2cacfa2e39b601ef3e3f47d1ae4c7"
dependencies = [
"libc",
]
[[package]]
name = "indexmap"
version = "1.9.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1885e79c1fc4b10f0e172c475f458b7f7b93061064d98c3293e98c5ba0c8b399"
dependencies = [
"autocfg",
"hashbrown",
]
[[package]]
name = "instant"
version = "0.1.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a5bbe824c507c5da5956355e86a746d82e0e1464f65d862cc5e71da70e94b2c"
dependencies = [
"cfg-if",
]
[[package]]
name = "io-lifetimes"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1abeb7a0dd0f8181267ff8adc397075586500b81b28a73e8a0208b00fc170fb3"
dependencies = [
"libc",
"windows-sys 0.45.0",
]
[[package]]
name = "itertools"
version = "0.10.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fad582f4b9e86b6caa621cabeb0963332d92eea04729ab12892c2533951e6440"
[[package]]
name = "js-sys"
version = "0.3.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "445dde2150c55e483f3d8416706b97ec8e8237c307e5b7b4b8dd15e6af2a0730"
dependencies = [
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646"
[[package]]
name = "libc"
version = "0.2.139"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "201de327520df007757c1f0adce6e827fe8562fbc28bfd9c15571c66ca1f5f79"
[[package]]
name = "linux-raw-sys"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f051f77a7c8e6957c0696eac88f26b0117e54f52d3fc682ab19397a8812846a4"
[[package]]
name = "log"
version = "0.4.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "abb12e687cfb44aa40f41fc3978ef76448f9b6038cad6aef4259d3c095a2382e"
dependencies = [
"cfg-if",
]
[[package]]
name = "memoffset"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d61c719bcfbcf5d62b3a09efa6088de8c54bc0bfcd3ea7ae39fcc186108b8de1"
dependencies = [
"autocfg",
]
[[package]]
name = "num-complex"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02e0d21255c828d6f128a1e41534206671e8c3ea0c62f32291e808dc82cff17d"
dependencies = [
"bytemuck",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "578ede34cf02f8924ab9447f50c28075b4d3e5b269972345e7e0372b38c6cdcd"
dependencies = [
"autocfg",
]
[[package]]
name = "num_cpus"
version = "1.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fac9e2da13b5eb447a6ce3d392f23a29d8694bff781bf03a16cd9ac8697593b"
dependencies = [
"hermit-abi 0.2.6",
"libc",
]
[[package]]
name = "once_cell"
version = "1.17.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7e5500299e16ebb147ae15a00a942af264cf3688f47923b8fc2cd5858f23ad3"
[[package]]
name = "oorandom"
version = "11.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
[[package]]
name = "opaque-debug"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "os_str_bytes"
version = "6.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b7820b9daea5457c9f21c69448905d723fbd21136ccf521748f23fd49e723ee"
[[package]]
name = "plotters"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2538b639e642295546c50fcd545198c9d64ee2a38620a628724a3b266d5fbf97"
dependencies = [
"num-traits",
"plotters-backend",
"plotters-svg",
"wasm-bindgen",
"web-sys",
]
[[package]]
name = "plotters-backend"
version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "193228616381fecdc1224c62e96946dfbc73ff4384fba576e052ff8c1bea8142"
[[package]]
name = "plotters-svg"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f9a81d2759aae1dae668f783c308bc5c8ebd191ff4184aaa1b37f65a6ae5a56f"
dependencies = [
"plotters-backend",
]
[[package]]
name = "proc-macro2"
version = "1.0.51"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d727cae5b39d21da60fa540906919ad737832fe0b1c165da3a34d6548c849d6"
dependencies = [
"unicode-ident",
]
[[package]]
name = "pulp"
version = "0.10.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "284c392c810680912400c6f70879a8cde404344db6b68ff52cc3990c020324d1"
dependencies = [
"bytemuck",
]
[[package]]
name = "quote"
version = "1.0.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8856d8364d252a14d474036ea1358d63c9e6965c8e5c1885c18f73d70bff9c7b"
dependencies = [
"proc-macro2",
]
[[package]]
name = "rayon"
version = "1.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6db3a213adf02b3bcfd2d3846bb41cb22857d131789e01df434fb7e7bc0759b7"
dependencies = [
"either",
"rayon-core",
]
[[package]]
name = "rayon-core"
version = "1.10.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "356a0625f1954f730c0201cdab48611198dc6ce21f4acff55089b5a78e6e835b"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-utils",
"num_cpus",
]
[[package]]
name = "readonly"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d78725e4e53781014168628ef49b2dc2fc6ae8d01a08769a5064685d34ee116c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "reborrow"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd0aeb369494c9cde5aaaccb26bb493d83339a9a3e0e22e15ed307715a2e68b4"
[[package]]
name = "redox_syscall"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5a58c1855b4b6819d59012155603f0b22ad30cad752600aadfcb695265519a"
dependencies = [
"bitflags",
]
[[package]]
name = "regex"
version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "48aaa5748ba571fb95cd2c85c09f629215d3a6ece942baa100950af03a34f733"
dependencies = [
"regex-syntax",
]
[[package]]
name = "regex-syntax"
version = "0.6.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "456c603be3e8d448b072f410900c09faf164fbce2d480456f50eea6e25f9c848"
[[package]]
name = "rustix"
version = "0.36.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f43abb88211988493c1abb44a70efa56ff0ce98f233b7b276146f1f3f7ba9644"
dependencies = [
"bitflags",
"errno",
"io-lifetimes",
"libc",
"linux-raw-sys",
"windows-sys 0.45.0",
]
[[package]]
name = "ryu"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde"
[[package]]
name = "same-file"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
dependencies = [
"winapi-util",
]
[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "serde"
version = "1.0.152"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.152"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af487d118eecd09402d70a5d72551860e788df87b464af30e5ea6a38c75c541e"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cad406b69c91885b5107daf2c29572f6c8cdb3c66826821e286c533490c0bc76"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "strsim"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "tempfile"
version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "af18f7ae1acd354b992402e9ec5864359d693cd8a79dcbef59f76891701c1e95"
dependencies = [
"cfg-if",
"fastrand",
"redox_syscall",
"rustix",
"windows-sys 0.42.0",
]
[[package]]
name = "termcolor"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6"
dependencies = [
"winapi-util",
]
[[package]]
name = "textwrap"
version = "0.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "222a222a5bfe1bba4a77b45ec488a741b3cb8872e5e499451fd7d0129c9c7c3d"
[[package]]
name = "tinytemplate"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
dependencies = [
"serde",
"serde_json",
]
[[package]]
name = "toml"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f4f7f0dd8d50a853a531c426359045b1998f04219d88799810762cd4ad314234"
dependencies = [
"serde",
]
[[package]]
name = "typenum"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
[[package]]
name = "unicode-ident"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "84a22b9f218b40614adcb3f4ff08b703773ad44fa9423e4e0d346d5db86e4ebc"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
[[package]]
name = "walkdir"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "808cf2735cd4b6866113f648b791c6adc5714537bc222d9347bb203386ffda56"
dependencies = [
"same-file",
"winapi",
"winapi-util",
]
[[package]]
name = "wasm-bindgen"
version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31f8dcbc21f30d9b8f2ea926ecb58f6b91192c17e9d33594b3df58b2007ca53b"
dependencies = [
"cfg-if",
"wasm-bindgen-macro",
]
[[package]]
name = "wasm-bindgen-backend"
version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95ce90fd5bcc06af55a641a86428ee4229e44e07033963a2290a8e241607ccb9"
dependencies = [
"bumpalo",
"log",
"once_cell",
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-macro"
version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c21f77c0bedc37fd5dc21f897894a5ca01e7bb159884559461862ae90c0b4c5"
dependencies = [
"quote",
"wasm-bindgen-macro-support",
]
[[package]]
name = "wasm-bindgen-macro-support"
version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2aff81306fcac3c7515ad4e177f521b5c9a15f2b08f4e32d823066102f35a5f6"
dependencies = [
"proc-macro2",
"quote",
"syn",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
[[package]]
name = "wasm-bindgen-shared"
version = "0.2.84"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0046fef7e28c3804e5e38bfa31ea2a0f73905319b677e57ebe37e49358989b5d"
[[package]]
name = "web-sys"
version = "0.3.61"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e33b99f4b23ba3eec1a53ac264e35a755f00e966e0065077d6027c0f575b0b97"
dependencies = [
"js-sys",
"wasm-bindgen",
]
[[package]]
name = "winapi"
version = "0.3.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
dependencies = [
"winapi-i686-pc-windows-gnu",
"winapi-x86_64-pc-windows-gnu",
]
[[package]]
name = "winapi-i686-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
dependencies = [
"winapi",
]
[[package]]
name = "winapi-x86_64-pc-windows-gnu"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
[[package]]
name = "windows-sys"
version = "0.42.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a3e1820f08b8513f676f7ab6c1f99ff312fb97b553d30ff4dd86f9f15728aa7"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows-sys"
version = "0.45.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75283be5efb2831d37ea142365f009c02ec203cd29a3ebecbc093d52315b66d0"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e2522491fbfcd58cc84d47aeb2958948c4b8982e9a2d8a2a35bbaed431390e7"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c9864e83243fdec7fc9c5444389dcbbfd258f745e7853198f365e3c4968a608"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c8b1b673ffc16c47a9ff48570a9d85e25d265735c503681332589af6253c6c7"
[[package]]
name = "windows_i686_gnu"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3887528ad530ba7bdbb1faa8275ec7a1155a45ffa57c37993960277145d640"
[[package]]
name = "windows_i686_msvc"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4d1122317eddd6ff351aa852118a2418ad4214e6613a50e0191f7004372605"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1040f221285e17ebccbc2591ffdc2d44ee1f9186324dd3e84e99ac68d699c45"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "628bfdf232daa22b0d64fdb62b09fcc36bb01f05a3939e20ab73aaf9470d0463"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "447660ad36a13288b1db4d4248e857b510e8c3a225c822ba4fb748c0aafecffd"

View File

@@ -0,0 +1,57 @@
# Manifest for the concrete-cpu crate.
[package]
name = "concrete-cpu"
version = "0.1.0"
edition = "2021"
license = "BSD-3-Clause-Clear"
# Built both as a regular Rust library and as a static library
# (the staticlib is what C/C++ consumers link against).
[lib]
name = "concrete_cpu"
crate-type = ["lib", "staticlib"]
[dependencies]
# Software CSPRNG; optional — only pulled in via the `csprng` feature below.
concrete-csprng = { version = "0.2", optional = true, features = [
"generator_soft",
] }
libc = { version = "0.2", default-features = false }
pulp = { version = "0.10", default-features = false }
dyn-stack = { version = "0.8", default-features = false }
readonly = "0.2"
aligned-vec = { version = "0.5", default-features = false }
concrete-fft = { version = "0.1", default-features = false }
bytemuck = "1.12"
num-complex = { version = "0.4", default-features = false, features = [
"bytemuck",
] }
# Optional crates backing the `parallel` and `std` features respectively.
rayon = { version = "1.6", optional = true }
once_cell = { version = "1.16", optional = true }
[features]
default = ["parallel", "std", "csprng"]
# `std` forwards the std feature to the no_std-capable dependencies.
std = [
"concrete-fft/std",
"aligned-vec/std",
"dyn-stack/std",
"pulp/std",
"once_cell",
]
csprng = ["concrete-csprng"]
parallel = ["rayon"]
nightly = ["pulp/nightly"]
# cbindgen is used by build.rs to generate the C header for the FFI surface.
[build-dependencies]
cbindgen = "0.24"
[dev-dependencies]
criterion = "0.4"
# Criterion benches provide their own main, so the default harness is off.
[[bench]]
name = "bench"
harness = false
# Keep integer overflow checks on even in optimized test/dev builds.
[profile.test]
overflow-checks = true
[profile.dev]
opt-level = 3
overflow-checks = true

View File

@@ -0,0 +1,67 @@
use concrete_cpu::c_api::linear_op::{
concrete_cpu_add_lwe_ciphertext_u64, concrete_cpu_add_plaintext_lwe_ciphertext_u64,
concrete_cpu_mul_cleartext_lwe_ciphertext_u64, concrete_cpu_negate_lwe_ciphertext_u64,
};
use criterion::{criterion_group, criterion_main, Criterion};
/// Criterion benchmarks for the linear LWE ciphertext operations exposed by
/// the C API, measured over a few representative LWE dimensions.
pub fn criterion_benchmark(c: &mut Criterion) {
    for lwe_dimension in [128, 256, 512] {
        // A ciphertext buffer holds `lwe_dimension` mask elements plus one body.
        let lwe_size = lwe_dimension + 1;
        // Fresh zero-initialized buffer of ciphertext size.
        let zeroed = || vec![0_u64; lwe_size];

        c.bench_function(&format!("add-lwe-ciphertext-u64-{lwe_dimension}"), |b| {
            let mut out = zeroed();
            let ct0 = zeroed();
            let ct1 = zeroed();
            b.iter(|| unsafe {
                concrete_cpu_add_lwe_ciphertext_u64(
                    out.as_mut_ptr(),
                    ct0.as_ptr(),
                    ct1.as_ptr(),
                    lwe_dimension,
                );
            });
        });

        c.bench_function(&format!("add-lwe-plaintext-u64-{lwe_dimension}"), |b| {
            let mut out = zeroed();
            let ct0 = zeroed();
            let plaintext = 0_u64;
            b.iter(|| unsafe {
                concrete_cpu_add_plaintext_lwe_ciphertext_u64(
                    out.as_mut_ptr(),
                    ct0.as_ptr(),
                    plaintext,
                    lwe_dimension,
                );
            });
        });

        c.bench_function(&format!("mul-lwe-cleartext-u64-{lwe_dimension}"), |b| {
            let mut out = zeroed();
            let ct0 = zeroed();
            let cleartext = 0_u64;
            b.iter(|| unsafe {
                concrete_cpu_mul_cleartext_lwe_ciphertext_u64(
                    out.as_mut_ptr(),
                    ct0.as_ptr(),
                    cleartext,
                    lwe_dimension,
                );
            });
        });

        c.bench_function(&format!("negate-lwe-ciphertext-u64-{lwe_dimension}"), |b| {
            let mut out = zeroed();
            let ct0 = zeroed();
            b.iter(|| unsafe {
                concrete_cpu_negate_lwe_ciphertext_u64(
                    out.as_mut_ptr(),
                    ct0.as_ptr(),
                    lwe_dimension,
                );
            });
        });
    }
}

criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);

View File

@@ -0,0 +1,14 @@
extern crate cbindgen;
use std::env;
/// Build script: generate this crate's C header with cbindgen.
///
/// The header is written to `include/<package-name>.h` inside the crate
/// directory, and a rerun-if-changed directive is emitted for it so Cargo
/// re-runs the script when the header is touched.
fn main() {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let pkg = env::var("CARGO_PKG_NAME").unwrap();
    let header_path = format!("include/{}.h", pkg);

    println!("cargo:rerun-if-changed={}", header_path);

    // cbindgen picks up its configuration from the crate directory
    // (presumably the cbindgen.toml shipped alongside this manifest).
    cbindgen::generate(manifest_dir)
        .unwrap()
        .write_to_file(header_path);
}

View File

@@ -0,0 +1,129 @@
# This is a template cbindgen.toml file with all of the default values.
# Some values are commented out because their absence is the real default.
#
# See https://github.com/eqrion/cbindgen/blob/master/docs.md#cbindgentoml
# for detailed documentation of every option here.
language = "C"
############## Options for Wrapping the Contents of the Header #################
header = "// Copyright © 2022 ZAMA.\n// All rights reserved."
# trailer = "/* Text to put at the end of the generated file */"
include_guard = "CONCRETE_CPU_FFI_H"
# pragma_once = true
autogen_warning = "// Warning, this file is autogenerated by cbindgen. Do not modify this manually."
include_version = false
#namespace = "concrete_cpu_ffi"
namespaces = []
using_namespaces = []
sys_includes = []
includes = []
no_includes = false
cpp_compat = true
after_includes = ""
############################ Code Style Options ################################
braces = "SameLine"
line_length = 100
tab_width = 2
documentation = false
documentation_style = "auto"
line_endings = "LF" # also "CR", "CRLF", "Native"
############################# Codegen Options ##################################
style = "both"
sort_by = "Name" # default for `fn.sort_by` and `const.sort_by`
usize_is_size_t = true
[defines]
# "target_os = freebsd" = "DEFINE_FREEBSD"
# "feature = serde" = "DEFINE_SERDE"
[export]
include = []
exclude = []
#prefix = "CAPI_"
item_types = []
renaming_overrides_prefixing = false
[export.rename]
[export.body]
[export.mangle]
[fn]
rename_args = "None"
# must_use = "MUST_USE_FUNC"
# no_return = "NO_RETURN"
# prefix = "START_FUNC"
# postfix = "END_FUNC"
args = "auto"
sort_by = "Name"
[struct]
rename_fields = "None"
# must_use = "MUST_USE_STRUCT"
derive_constructor = false
derive_eq = false
derive_neq = false
derive_lt = false
derive_lte = false
derive_gt = false
derive_gte = false
[enum]
rename_variants = "None"
# must_use = "MUST_USE_ENUM"
add_sentinel = false
prefix_with_name = false
derive_helper_methods = false
derive_const_casts = false
derive_mut_casts = false
# cast_assert_name = "ASSERT"
derive_tagged_enum_destructor = false
derive_tagged_enum_copy_constructor = false
enum_class = true
private_default_tagged_enum_constructor = false
[const]
allow_static_const = true
allow_constexpr = false
sort_by = "Name"
[macro_expansion]
bitflags = false
############## Options for How Your Rust library Should Be Parsed ##############
[parse]
parse_deps = true
include = ["concrete-cpu"]
exclude = []
clean = false
extra_bindings = []
[parse.expand]
crates = []
all_features = false
default_features = true
features = []

View File

@@ -0,0 +1,285 @@
// Copyright © 2022 ZAMA.
// All rights reserved.
#ifndef CONCRETE_CPU_FFI_H
#define CONCRETE_CPU_FFI_H
// Warning, this file is autogenerated by cbindgen. Do not modify this manually.
#include <stdarg.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>
#include <stdlib.h>
enum Parallelism
#ifdef __cplusplus
: uint32_t
#endif // __cplusplus
{
No = 0,
Rayon = 1,
};
#ifndef __cplusplus
typedef uint32_t Parallelism;
#endif // __cplusplus
enum ScratchStatus
#ifdef __cplusplus
: uint32_t
#endif // __cplusplus
{
Valid = 0,
SizeOverflow = 1,
};
#ifndef __cplusplus
typedef uint32_t ScratchStatus;
#endif // __cplusplus
typedef struct Csprng Csprng;
typedef struct Fft Fft;
typedef struct Uint128 {
uint8_t little_endian_bytes[16];
} Uint128;
typedef struct CsprngVtable {
struct Uint128 (*remaining_bytes)(const struct Csprng *csprng);
size_t (*next_bytes)(struct Csprng *csprng, uint8_t *byte_array, size_t byte_count);
} CsprngVtable;
#ifdef __cplusplus
extern "C" {
#endif // __cplusplus
extern const size_t CONCRETE_CSPRNG_ALIGN;
extern const size_t CONCRETE_CSPRNG_SIZE;
extern const struct CsprngVtable CONCRETE_CSPRNG_VTABLE;
extern const size_t CONCRETE_FFT_ALIGN;
extern const size_t CONCRETE_FFT_SIZE;
void concrete_cpu_add_lwe_ciphertext_u64(uint64_t *ct_out,
const uint64_t *ct_in0,
const uint64_t *ct_in1,
size_t lwe_dimension);
void concrete_cpu_add_plaintext_lwe_ciphertext_u64(uint64_t *ct_out,
const uint64_t *ct_in,
uint64_t plaintext,
size_t lwe_dimension);
void concrete_cpu_bootstrap_key_convert_u64_to_fourier(const uint64_t *standard_bsk,
double *fourier_bsk,
size_t decomposition_level_count,
size_t decomposition_base_log,
size_t glwe_dimension,
size_t polynomial_size,
size_t input_lwe_dimension,
const struct Fft *fft,
uint8_t *stack,
size_t stack_size);
ScratchStatus concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch(size_t *stack_size,
size_t *stack_align,
const struct Fft *fft);
size_t concrete_cpu_bootstrap_key_size_u64(size_t decomposition_level_count,
size_t glwe_dimension,
size_t polynomial_size,
size_t input_lwe_dimension);
void concrete_cpu_bootstrap_lwe_ciphertext_u64(uint64_t *ct_out,
const uint64_t *ct_in,
const uint64_t *accumulator,
const double *fourier_bsk,
size_t decomposition_level_count,
size_t decomposition_base_log,
size_t glwe_dimension,
size_t polynomial_size,
size_t input_lwe_dimension,
const struct Fft *fft,
uint8_t *stack,
size_t stack_size);
ScratchStatus concrete_cpu_bootstrap_lwe_ciphertext_u64_scratch(size_t *stack_size,
size_t *stack_align,
size_t glwe_dimension,
size_t polynomial_size,
const struct Fft *fft);
void concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(uint64_t *ct_out_vec,
const uint64_t *ct_in_vec,
const uint64_t *lut,
const double *fourier_bsk,
const uint64_t *fpksk,
size_t ct_out_dimension,
size_t ct_out_count,
size_t ct_in_dimension,
size_t ct_in_count,
size_t lut_size,
size_t lut_count,
size_t bsk_decomposition_level_count,
size_t bsk_decomposition_base_log,
size_t bsk_glwe_dimension,
size_t bsk_polynomial_size,
size_t bsk_input_lwe_dimension,
size_t fpksk_decomposition_level_count,
size_t fpksk_decomposition_base_log,
size_t fpksk_input_dimension,
size_t fpksk_output_glwe_dimension,
size_t fpksk_output_polynomial_size,
size_t fpksk_count,
size_t cbs_decomposition_level_count,
size_t cbs_decomposition_base_log,
const struct Fft *fft,
uint8_t *stack,
size_t stack_size);
ScratchStatus concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64_scratch(size_t *stack_size,
size_t *stack_align,
size_t ct_out_count,
size_t ct_in_dimension,
size_t ct_in_count,
size_t lut_size,
size_t lut_count,
size_t bsk_glwe_dimension,
size_t bsk_polynomial_size,
size_t fpksk_output_polynomial_size,
size_t cbs_decomposition_level_count,
const struct Fft *fft);
void concrete_cpu_construct_concrete_csprng(struct Csprng *mem, struct Uint128 seed);
void concrete_cpu_construct_concrete_fft(struct Fft *mem, size_t polynomial_size);
int concrete_cpu_crypto_secure_random_128(struct Uint128 *u128);
void concrete_cpu_decrypt_lwe_ciphertext_u64(const uint64_t *lwe_sk,
const uint64_t *lwe_ct_in,
size_t lwe_dimension,
uint64_t *plaintext);
void concrete_cpu_destroy_concrete_csprng(struct Csprng *mem);
void concrete_cpu_destroy_concrete_fft(struct Fft *mem);
void concrete_cpu_encrypt_lwe_ciphertext_u64(const uint64_t *lwe_sk,
uint64_t *lwe_out,
uint64_t input,
size_t lwe_dimension,
double variance,
struct Csprng *csprng,
const struct CsprngVtable *csprng_vtable);
void concrete_cpu_extract_bit_lwe_ciphertext_u64(uint64_t *ct_vec_out,
const uint64_t *ct_in,
const double *fourier_bsk,
const uint64_t *ksk,
size_t ct_out_dimension,
size_t ct_out_count,
size_t ct_in_dimension,
size_t number_of_bits,
size_t delta_log,
size_t bsk_decomposition_level_count,
size_t bsk_decomposition_base_log,
size_t bsk_glwe_dimension,
size_t bsk_polynomial_size,
size_t bsk_input_lwe_dimension,
size_t ksk_decomposition_level_count,
size_t ksk_decomposition_base_log,
size_t ksk_input_dimension,
size_t ksk_output_dimension,
const struct Fft *fft,
uint8_t *stack,
size_t stack_size);
ScratchStatus concrete_cpu_extract_bit_lwe_ciphertext_u64_scratch(size_t *stack_size,
size_t *stack_align,
size_t ct_out_dimension,
size_t ct_in_dimension,
size_t bsk_glwe_dimension,
size_t bsk_polynomial_size,
const struct Fft *fft);
void concrete_cpu_init_lwe_bootstrap_key_u64(uint64_t *lwe_bsk,
const uint64_t *input_lwe_sk,
const uint64_t *output_glwe_sk,
size_t input_lwe_dimension,
size_t output_polynomial_size,
size_t output_glwe_dimension,
size_t decomposition_level_count,
size_t decomposition_base_log,
double variance,
Parallelism parallelism,
struct Csprng *csprng,
const struct CsprngVtable *csprng_vtable);
void concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(uint64_t *lwe_pksk,
const uint64_t *input_lwe_sk,
const uint64_t *output_glwe_sk,
size_t input_lwe_dimension,
size_t output_polynomial_size,
size_t output_glwe_dimension,
size_t decomposition_level_count,
size_t decomposition_base_log,
double variance,
Parallelism parallelism,
struct Csprng *csprng,
const struct CsprngVtable *csprng_vtable);
void concrete_cpu_init_lwe_keyswitch_key_u64(uint64_t *lwe_ksk,
const uint64_t *input_lwe_sk,
const uint64_t *output_lwe_sk,
size_t input_lwe_dimension,
size_t output_lwe_dimension,
size_t decomposition_level_count,
size_t decomposition_base_log,
double variance,
struct Csprng *csprng,
const struct CsprngVtable *csprng_vtable);
void concrete_cpu_init_lwe_secret_key_u64(uint64_t *lwe_sk,
size_t lwe_dimension,
struct Csprng *csprng,
const struct CsprngVtable *csprng_vtable);
size_t concrete_cpu_keyswitch_key_size_u64(size_t decomposition_level_count,
size_t _decomposition_base_log,
size_t input_dimension,
size_t output_dimension);
void concrete_cpu_keyswitch_lwe_ciphertext_u64(uint64_t *ct_out,
const uint64_t *ct_in,
const uint64_t *keyswitch_key,
size_t decomposition_level_count,
size_t decomposition_base_log,
size_t input_dimension,
size_t output_dimension);
size_t concrete_cpu_lwe_packing_keyswitch_key_size(size_t glwe_dimension,
size_t polynomial_size,
size_t decomposition_level_count,
size_t input_dimension);
void concrete_cpu_mul_cleartext_lwe_ciphertext_u64(uint64_t *ct_out,
const uint64_t *ct_in,
uint64_t cleartext,
size_t lwe_dimension);
void concrete_cpu_negate_lwe_ciphertext_u64(uint64_t *ct_out,
const uint64_t *ct_in,
size_t lwe_dimension);
size_t concrete_cpu_secret_key_size_u64(size_t lwe_dimension);
#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
#endif /* CONCRETE_CPU_FFI_H */

View File

@@ -0,0 +1,44 @@
pub mod bootstrap;
#[cfg(feature = "csprng")]
pub mod csprng;
pub mod fft;
pub mod keyswitch;
pub mod linear_op;
pub mod secret_key;
pub mod types;
pub mod wop_pbs;
mod utils {
    /// Runs `f`, aborting the process if `f` unwinds.
    ///
    /// A guard value is armed before the call and disarmed after it. If `f`
    /// panics, the guard is dropped during unwinding and panics again; a
    /// panic raised while already unwinding forces an immediate process
    /// abort. This keeps unwinding from ever crossing the `extern "C"`
    /// boundary of this crate's FFI entry points.
    #[inline]
    pub fn nounwind<R>(f: impl FnOnce() -> R) -> R {
        struct UnwindGuard;
        impl Drop for UnwindGuard {
            #[inline]
            fn drop(&mut self) {
                // Reached only while unwinding: double panic => abort.
                panic!();
            }
        }
        let guard = UnwindGuard;
        let result = f();
        // Normal completion: disarm the guard without running its destructor.
        core::mem::forget(guard);
        result
    }

    // Compile-time check that `usize` and C's `size_t` are interchangeable,
    // which the FFI signatures in this crate rely on.
    const __ASSERT_USIZE_SAME_AS_SIZE_T: () = {
        let _: libc::size_t = 0_usize;
    };

    #[cfg(test)]
    mod tests {
        use super::*;

        #[test]
        fn test_unwind() {
            // A caught panic would abort the process, so only the
            // successful path is exercised here.
            assert_eq!(nounwind(|| 1), 1);
        }
    }
}

View File

@@ -0,0 +1,236 @@
use crate::{
c_api::types::{Parallelism, ScratchStatus},
implementation::{fft::Fft, types::*},
};
use core::slice;
use dyn_stack::DynStack;
use super::{
types::{Csprng, CsprngVtable},
utils::nounwind,
};
/// Generates a fresh LWE bootstrap key into `lwe_bsk`.
///
/// The key is built from `input_lwe_sk` (the key of the ciphertexts to be
/// bootstrapped) and `output_glwe_sk`, with the given decomposition
/// parameters and noise `variance`. `parallelism` selects sequential or
/// Rayon-parallel generation.
///
/// # Safety
///
/// - `lwe_bsk` must be valid for writes for the full key; its element count
///   presumably matches `concrete_cpu_bootstrap_key_size_u64` for the same
///   parameters — verify against callers.
/// - `input_lwe_sk` and `output_glwe_sk` must be valid for reads for their
///   respective key sizes.
/// - `csprng`/`csprng_vtable` must form a valid CSPRNG as described in
///   `CsprngVtable`.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_init_lwe_bootstrap_key_u64(
    // bootstrap key
    lwe_bsk: *mut u64,
    // secret keys
    input_lwe_sk: *const u64,
    output_glwe_sk: *const u64,
    // secret key dimensions
    input_lwe_dimension: usize,
    output_polynomial_size: usize,
    output_glwe_dimension: usize,
    // bootstrap key parameters
    decomposition_level_count: usize,
    decomposition_base_log: usize,
    // noise parameters
    variance: f64,
    // parallelism
    parallelism: Parallelism,
    // csprng
    csprng: *mut Csprng,
    csprng_vtable: *const CsprngVtable,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let glwe_params = GlweParams {
            dimension: output_glwe_dimension,
            polynomial_size: output_polynomial_size,
        };
        let decomp_params = DecompParams {
            level: decomposition_level_count,
            base_log: decomposition_base_log,
        };
        let bsk =
            BootstrapKey::from_raw_parts(lwe_bsk, glwe_params, input_lwe_dimension, decomp_params);
        let lwe_sk = LweSecretKey::from_raw_parts(input_lwe_sk, input_lwe_dimension);
        let glwe_sk = GlweSecretKey::from_raw_parts(output_glwe_sk, glwe_params);
        match parallelism {
            Parallelism::No => bsk.fill_with_new_key(
                lwe_sk,
                glwe_sk,
                variance,
                CsprngMut::new(csprng, csprng_vtable),
            ),
            Parallelism::Rayon => bsk.fill_with_new_key_par(
                lwe_sk,
                glwe_sk,
                variance,
                CsprngMut::new(csprng, csprng_vtable),
            ),
        }
    });
}
/// Queries the scratch-memory requirement of
/// `concrete_cpu_bootstrap_key_convert_u64_to_fourier`.
///
/// On success writes the byte size and alignment into `stack_size` /
/// `stack_align` and returns `Valid`; returns `SizeOverflow` when the
/// requirement cannot be computed.
#[no_mangle]
#[must_use]
pub unsafe extern "C" fn concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch(
    stack_size: *mut usize,
    stack_align: *mut usize,
    // side resources
    fft: *const Fft,
) -> ScratchStatus {
    nounwind(|| match (*fft).as_view().forward_scratch() {
        Ok(req) => {
            *stack_size = req.size_bytes();
            *stack_align = req.align_bytes();
            ScratchStatus::Valid
        }
        Err(_) => ScratchStatus::SizeOverflow,
    })
}
/// Converts a standard-domain bootstrap key into its Fourier-domain
/// representation, writing the result into `fourier_bsk`.
///
/// # Safety
///
/// - `standard_bsk` must be valid for reads and `fourier_bsk` for writes,
///   each for a key of the given parameters (presumably
///   `concrete_cpu_bootstrap_key_size_u64` elements — verify against callers).
/// - `stack` must point to `stack_size` writable bytes satisfying the
///   size/alignment reported by
///   `concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch`.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_bootstrap_key_convert_u64_to_fourier(
    // bootstrap key
    standard_bsk: *const u64,
    fourier_bsk: *mut f64,
    // bootstrap parameters
    decomposition_level_count: usize,
    decomposition_base_log: usize,
    glwe_dimension: usize,
    polynomial_size: usize,
    input_lwe_dimension: usize,
    // side resources
    fft: *const Fft,
    stack: *mut u8,
    stack_size: usize,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let glwe_params = GlweParams {
            dimension: glwe_dimension,
            polynomial_size,
        };
        let decomp_params = DecompParams {
            level: decomposition_level_count,
            base_log: decomposition_base_log,
        };
        // Both views share the same geometry; only the element type differs.
        let standard = BootstrapKey::from_raw_parts(
            standard_bsk,
            glwe_params,
            input_lwe_dimension,
            decomp_params,
        );
        let mut fourier = BootstrapKey::from_raw_parts(
            fourier_bsk,
            glwe_params,
            input_lwe_dimension,
            decomp_params,
        );
        fourier.fill_with_forward_fourier(
            standard,
            (*fft).as_view(),
            DynStack::new(slice::from_raw_parts_mut(stack as _, stack_size)),
        );
    })
}
/// Queries the scratch-memory requirement of
/// `concrete_cpu_bootstrap_lwe_ciphertext_u64` for the given GLWE geometry.
///
/// On success writes the byte size and alignment into `stack_size` /
/// `stack_align` and returns `Valid`; returns `SizeOverflow` otherwise.
#[no_mangle]
#[must_use]
pub unsafe extern "C" fn concrete_cpu_bootstrap_lwe_ciphertext_u64_scratch(
    stack_size: *mut usize,
    stack_align: *mut usize,
    // bootstrap parameters
    glwe_dimension: usize,
    polynomial_size: usize,
    // side resources
    fft: *const crate::implementation::fft::Fft,
) -> ScratchStatus {
    nounwind(|| {
        let glwe_params = GlweParams {
            dimension: glwe_dimension,
            polynomial_size,
        };
        match BootstrapKey::bootstrap_scratch(glwe_params, (*fft).as_view()) {
            Ok(req) => {
                *stack_size = req.size_bytes();
                *stack_align = req.align_bytes();
                ScratchStatus::Valid
            }
            Err(_) => ScratchStatus::SizeOverflow,
        }
    })
}
/// Performs a programmable bootstrap of `ct_in` with the lookup table
/// `accumulator`, writing the refreshed ciphertext to `ct_out`.
///
/// The output LWE dimension is `glwe_dimension * polynomial_size`.
///
/// # Safety
///
/// - `ct_in` must hold `input_lwe_dimension + 1` readable elements and
///   `ct_out` `glwe_dimension * polynomial_size + 1` writable elements.
/// - `fourier_bsk` must be a Fourier-domain bootstrap key matching the given
///   parameters.
/// - `stack` must point to `stack_size` writable bytes satisfying the
///   size/alignment reported by
///   `concrete_cpu_bootstrap_lwe_ciphertext_u64_scratch`.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_bootstrap_lwe_ciphertext_u64(
    // ciphertexts
    ct_out: *mut u64,
    ct_in: *const u64,
    // accumulator
    accumulator: *const u64,
    // bootstrap key
    fourier_bsk: *const f64,
    // bootstrap parameters
    decomposition_level_count: usize,
    decomposition_base_log: usize,
    glwe_dimension: usize,
    polynomial_size: usize,
    input_lwe_dimension: usize,
    // side resources
    fft: *const crate::implementation::fft::Fft,
    stack: *mut u8,
    stack_size: usize,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let glwe_params = GlweParams {
            dimension: glwe_dimension,
            polynomial_size,
        };
        let decomp_params = DecompParams {
            level: decomposition_level_count,
            base_log: decomposition_base_log,
        };
        // The output ciphertext lives under the flattened GLWE key.
        let output_lwe_dimension = glwe_dimension * polynomial_size;
        let fourier = BootstrapKey::from_raw_parts(
            fourier_bsk,
            glwe_params,
            input_lwe_dimension,
            decomp_params,
        );
        let lwe_in = LweCiphertext::from_raw_parts(ct_in, input_lwe_dimension);
        let lwe_out = LweCiphertext::from_raw_parts(ct_out, output_lwe_dimension);
        let accumulator = GlweCiphertext::from_raw_parts(accumulator, glwe_params);
        fourier.bootstrap(
            lwe_out,
            lwe_in,
            accumulator,
            (*fft).as_view(),
            DynStack::new(slice::from_raw_parts_mut(stack as _, stack_size)),
        );
    })
}
/// Returns the number of `u64` elements needed to store a standard-domain
/// bootstrap key with the given parameters.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_bootstrap_key_size_u64(
    decomposition_level_count: usize,
    glwe_dimension: usize,
    polynomial_size: usize,
    input_lwe_dimension: usize,
) -> usize {
    BootstrapKey::<&[u64]>::data_len(
        GlweParams {
            dimension: glwe_dimension,
            polynomial_size,
        },
        decomposition_level_count,
        input_lwe_dimension,
    )
}

View File

@@ -0,0 +1,95 @@
use std::io::Read;
use super::types::{Csprng, CsprngVtable, Uint128};
use concrete_csprng::{
generators::{RandomGenerator, SoftwareRandomGenerator},
seeders::Seed,
};
use libc::c_int;
// The concrete CSPRNG exported over FFI is the software generator.
type Generator = SoftwareRandomGenerator;

/// Vtable adapting `SoftwareRandomGenerator` to the type-erased `Csprng`
/// handle used across the FFI boundary. Both function pointers require the
/// `Csprng` pointer to actually point at a `Generator`.
#[no_mangle]
pub static CONCRETE_CSPRNG_VTABLE: CsprngVtable = CsprngVtable {
    remaining_bytes: {
        // Reports how many bytes the generator can still produce,
        // as a little-endian 128-bit integer.
        unsafe extern "C" fn remaining_bytes(csprng: *const Csprng) -> Uint128 {
            let csprng = &*(csprng as *const Generator);
            Uint128 {
                little_endian_bytes: csprng.remaining_bytes().0.to_le_bytes(),
            }
        }
        remaining_bytes
    },
    next_bytes: {
        // Fills `byte_array` with up to `byte_count` random bytes and returns
        // the number actually written (short when the generator runs out).
        unsafe extern "C" fn next_bytes(
            csprng: *mut Csprng,
            byte_array: *mut u8,
            byte_count: usize,
        ) -> usize {
            let csprng = &mut *(csprng as *mut Generator);
            let mut count = 0;
            while count < byte_count {
                if let Some(byte) = csprng.next() {
                    *byte_array.add(count) = byte;
                    count += 1;
                } else {
                    break;
                };
            }
            count
        }
        next_bytes
    },
};
/// Size in bytes of the buffer needed to hold a concrete CSPRNG instance.
#[no_mangle]
pub static CONCRETE_CSPRNG_SIZE: usize = core::mem::size_of::<Generator>();
/// Required alignment of the buffer holding a concrete CSPRNG instance.
#[no_mangle]
pub static CONCRETE_CSPRNG_ALIGN: usize = core::mem::align_of::<Generator>();
/// Constructs a CSPRNG in place at `mem`, seeded with `seed`.
///
/// # Safety
///
/// `mem` must point to uninitialized memory of at least
/// `CONCRETE_CSPRNG_SIZE` bytes aligned to `CONCRETE_CSPRNG_ALIGN`.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_construct_concrete_csprng(mem: *mut Csprng, seed: Uint128) {
    let mem = mem as *mut Generator;
    let seed = Seed(u128::from_le_bytes(seed.little_endian_bytes));
    mem.write(Generator::new(seed));
}
/// Drops a CSPRNG previously created by
/// `concrete_cpu_construct_concrete_csprng`, without freeing its storage.
///
/// # Safety
///
/// `mem` must point to a valid, initialized CSPRNG and must not be used
/// again afterwards.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_destroy_concrete_csprng(mem: *mut Csprng) {
    core::ptr::drop_in_place(mem as *mut Generator);
}
// Randomly fills a uint128.
// Returns 1 if the randomness is crypto secure, -1 if it is not secure, 0 on failure.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_crypto_secure_random_128(u128: *mut Uint128) -> c_int {
    // Prefer the hardware entropy source when available.
    if is_x86_feature_detected!("rdseed") {
        let mut rand: u64 = 0;
        // `_rdseed64_step` may transiently fail (returns 0), so retry until
        // it reports success for each 64-bit half.
        loop {
            if core::arch::x86_64::_rdseed64_step(&mut rand) == 1 {
                (*u128).little_endian_bytes[0..8].copy_from_slice(&rand.to_ne_bytes());
                break;
            }
        }
        loop {
            if core::arch::x86_64::_rdseed64_step(&mut rand) == 1 {
                (*u128).little_endian_bytes[8..16].copy_from_slice(&rand.to_ne_bytes());
                break;
            }
        }
        return 1;
    }
    // Fallback: /dev/random. Use `read_exact` so a legal short read (which
    // `read` may return) keeps retrying instead of spuriously reporting
    // failure with a partially-filled buffer.
    let buf = &mut (*u128).little_endian_bytes[0..16];
    if let Ok(mut random) = std::fs::File::open("/dev/random") {
        if random.read_exact(buf).is_ok() {
            return -1;
        }
    }
    0
}

View File

@@ -0,0 +1,20 @@
use crate::implementation::fft::Fft;
/// Size in bytes of the buffer needed to hold a concrete FFT plan.
#[no_mangle]
pub static CONCRETE_FFT_SIZE: usize = core::mem::size_of::<Fft>();
/// Required alignment of the buffer holding a concrete FFT plan.
#[no_mangle]
pub static CONCRETE_FFT_ALIGN: usize = core::mem::align_of::<Fft>();
/// Constructs an FFT plan for `polynomial_size` in place at `mem`.
///
/// # Safety
///
/// `mem` must point to uninitialized memory of at least `CONCRETE_FFT_SIZE`
/// bytes aligned to `CONCRETE_FFT_ALIGN`.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_construct_concrete_fft(
    mem: *mut Fft,
    polynomial_size: usize,
) {
    mem.write(Fft::new(polynomial_size));
}
/// Drops an FFT plan previously created by
/// `concrete_cpu_construct_concrete_fft`, without freeing its storage.
///
/// # Safety
///
/// `mem` must point to a valid, initialized FFT plan and must not be used
/// again afterwards.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_destroy_concrete_fft(mem: *mut Fft) {
    core::ptr::drop_in_place(mem);
}

View File

@@ -0,0 +1,91 @@
use super::{
types::{Csprng, CsprngVtable},
utils::nounwind,
};
use crate::implementation::types::*;
/// Generates a fresh LWE keyswitch key into `lwe_ksk`, enabling keyswitches
/// from ciphertexts under `input_lwe_sk` to ciphertexts under
/// `output_lwe_sk`.
///
/// # Safety
///
/// - `lwe_ksk` must be writable for
///   `concrete_cpu_keyswitch_key_size_u64(...)` elements for the same
///   parameters.
/// - Both secret keys must be valid for reads for their dimensions.
/// - `csprng`/`csprng_vtable` must form a valid CSPRNG.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_init_lwe_keyswitch_key_u64(
    // keyswitch key
    lwe_ksk: *mut u64,
    // secret keys
    input_lwe_sk: *const u64,
    output_lwe_sk: *const u64,
    // secret key dimensions
    input_lwe_dimension: usize,
    output_lwe_dimension: usize,
    // keyswitch key parameters
    decomposition_level_count: usize,
    decomposition_base_log: usize,
    // noise parameters
    variance: f64,
    // csprng
    csprng: *mut Csprng,
    csprng_vtable: *const CsprngVtable,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let input_key = LweSecretKey::from_raw_parts(input_lwe_sk, input_lwe_dimension);
        let output_key = LweSecretKey::from_raw_parts(output_lwe_sk, output_lwe_dimension);
        let ksk = LweKeyswitchKey::from_raw_parts(
            lwe_ksk,
            output_lwe_dimension,
            input_lwe_dimension,
            DecompParams {
                level: decomposition_level_count,
                base_log: decomposition_base_log,
            },
        );
        ksk.fill_with_keyswitch_key(
            input_key,
            output_key,
            variance,
            CsprngMut::new(csprng, csprng_vtable),
        );
    });
}
/// Keyswitches `ct_in` (under the keyswitch key's input key) into `ct_out`
/// (under its output key).
///
/// # Safety
///
/// `ct_in` must hold `input_dimension + 1` readable elements, `ct_out`
/// `output_dimension + 1` writable elements, and `keyswitch_key` must be a
/// valid keyswitch key matching the given parameters.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_keyswitch_lwe_ciphertext_u64(
    // ciphertexts
    ct_out: *mut u64,
    ct_in: *const u64,
    // keyswitch key
    keyswitch_key: *const u64,
    // keyswitch parameters
    decomposition_level_count: usize,
    decomposition_base_log: usize,
    input_dimension: usize,
    output_dimension: usize,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let ct_out = LweCiphertext::from_raw_parts(ct_out, output_dimension);
        let ct_in = LweCiphertext::from_raw_parts(ct_in, input_dimension);
        let keyswitch_key = LweKeyswitchKey::from_raw_parts(
            keyswitch_key,
            output_dimension,
            input_dimension,
            DecompParams {
                level: decomposition_level_count,
                base_log: decomposition_base_log,
            },
        );
        keyswitch_key.keyswitch_ciphertext(ct_out, ct_in);
    })
}
/// Returns the number of `u64` elements needed to store an LWE keyswitch key
/// with the given parameters.
///
/// `_decomposition_base_log` does not affect the size; it is kept so the
/// signature mirrors the other keyswitch entry points.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_keyswitch_key_size_u64(
    decomposition_level_count: usize,
    _decomposition_base_log: usize,
    input_dimension: usize,
    output_dimension: usize,
) -> usize {
    LweKeyswitchKey::<&[u64]>::data_len(
        output_dimension,
        decomposition_level_count,
        input_dimension,
    )
}

View File

@@ -0,0 +1,124 @@
use super::utils::nounwind;
use core::slice;
/// # Safety
///
/// `[ct_out, ct_out + lwe_dimension + 1[` must be a valid mutable range, and must not alias
/// `[ct_in0, ct_in0 + lwe_dimension + 1[` or `[ct_in1, ct_in1 + lwe_dimension + 1[`, both of which
/// must be valid ranges for reads.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_add_lwe_ciphertext_u64(
    ct_out: *mut u64,
    ct_in0: *const u64,
    ct_in1: *const u64,
    lwe_dimension: usize,
) {
    nounwind(|| {
        // Element-wise wrapping sum of the two ciphertexts.
        #[inline]
        fn add_slices(dst: &mut [u64], lhs: &[u64], rhs: &[u64]) {
            for (dst, (&l, &r)) in dst.iter_mut().zip(lhs.iter().zip(rhs)) {
                *dst = l.wrapping_add(r);
            }
        }
        // An LWE ciphertext spans `lwe_dimension + 1` words (mask + body).
        let lwe_size = lwe_dimension + 1;
        pulp::Arch::new().dispatch(|| {
            add_slices(
                slice::from_raw_parts_mut(ct_out, lwe_size),
                slice::from_raw_parts(ct_in0, lwe_size),
                slice::from_raw_parts(ct_in1, lwe_size),
            )
        });
    })
}
/// # Safety
///
/// `[ct_out, ct_out + lwe_dimension + 1[` must be a valid mutable range, and must not alias
/// `[ct_in, ct_in + lwe_dimension + 1[`, which must be a valid range for reads.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_add_plaintext_lwe_ciphertext_u64(
    ct_out: *mut u64,
    ct_in: *const u64,
    plaintext: u64,
    lwe_dimension: usize,
) {
    nounwind(|| {
        // Copies the ciphertext, then adds the plaintext into the body
        // (the last word) with wrapping arithmetic.
        #[inline]
        fn add_to_body(dst: &mut [u64], src: &[u64], plaintext: u64) {
            dst.copy_from_slice(src);
            let body = dst.len() - 1;
            dst[body] = dst[body].wrapping_add(plaintext);
        }
        // An LWE ciphertext spans `lwe_dimension + 1` words (mask + body).
        let lwe_size = lwe_dimension + 1;
        pulp::Arch::new().dispatch(|| {
            add_to_body(
                slice::from_raw_parts_mut(ct_out, lwe_size),
                slice::from_raw_parts(ct_in, lwe_size),
                plaintext,
            )
        });
    })
}
/// # Safety
///
/// `[ct_out, ct_out + lwe_dimension + 1[` must be a valid mutable range, and must not alias
/// `[ct_in, ct_in + lwe_dimension + 1[`, which must be a valid range for reads.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_mul_cleartext_lwe_ciphertext_u64(
    ct_out: *mut u64,
    ct_in: *const u64,
    cleartext: u64,
    lwe_dimension: usize,
) {
    nounwind(|| {
        // Scales every word of the ciphertext by the cleartext, wrapping.
        #[inline]
        fn scale_slice(dst: &mut [u64], src: &[u64], cleartext: u64) {
            dst.iter_mut()
                .zip(src)
                .for_each(|(dst, &src)| *dst = src.wrapping_mul(cleartext));
        }
        // An LWE ciphertext spans `lwe_dimension + 1` words (mask + body).
        let lwe_size = lwe_dimension + 1;
        pulp::Arch::new().dispatch(|| {
            scale_slice(
                slice::from_raw_parts_mut(ct_out, lwe_size),
                slice::from_raw_parts(ct_in, lwe_size),
                cleartext,
            )
        });
    })
}
/// # Safety
///
/// `[ct_out, ct_out + lwe_dimension + 1[` must be a valid mutable range, and must not alias
/// `[ct_in, ct_in + lwe_dimension + 1[`, which must be a valid range for reads.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_negate_lwe_ciphertext_u64(
    ct_out: *mut u64,
    ct_in: *const u64,
    lwe_dimension: usize,
) {
    nounwind(|| {
        // Wrapping two's-complement negation of every ciphertext word.
        #[inline]
        fn negate_slice(dst: &mut [u64], src: &[u64]) {
            dst.iter_mut()
                .zip(src)
                .for_each(|(dst, &src)| *dst = src.wrapping_neg());
        }
        // An LWE ciphertext spans `lwe_dimension + 1` words (mask + body).
        let lwe_size = lwe_dimension + 1;
        pulp::Arch::new().dispatch(|| {
            negate_slice(
                slice::from_raw_parts_mut(ct_out, lwe_size),
                slice::from_raw_parts(ct_in, lwe_size),
            )
        });
    })
}

View File

@@ -0,0 +1,72 @@
use super::{
types::{Csprng, CsprngVtable},
utils::nounwind,
};
use crate::implementation::types::{CsprngMut, LweCiphertext, LweSecretKey};
/// Returns the number of `u64` elements needed to store an LWE secret key of
/// the given dimension.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_secret_key_size_u64(lwe_dimension: usize) -> usize {
    LweSecretKey::<&[u64]>::data_len(lwe_dimension)
}
/// Generates a fresh LWE secret key into `lwe_sk`.
///
/// # Safety
///
/// `lwe_sk` must be writable for
/// `concrete_cpu_secret_key_size_u64(lwe_dimension)` elements, and
/// `csprng`/`csprng_vtable` must form a valid CSPRNG.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_init_lwe_secret_key_u64(
    lwe_sk: *mut u64,
    lwe_dimension: usize,
    csprng: *mut Csprng,
    csprng_vtable: *const CsprngVtable,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let csprng = CsprngMut::new(csprng, csprng_vtable);
        let sk = LweSecretKey::<&mut [u64]>::from_raw_parts(lwe_sk, lwe_dimension);
        sk.fill_with_new_key(csprng);
    })
}
/// Encrypts the plaintext `input` under `lwe_sk` into `lwe_out`, with noise
/// of the given `variance`.
///
/// # Safety
///
/// `lwe_sk` must hold a key of dimension `lwe_dimension`, `lwe_out` must be
/// writable for `lwe_dimension + 1` elements, and `csprng`/`csprng_vtable`
/// must form a valid CSPRNG.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_encrypt_lwe_ciphertext_u64(
    // secret key
    lwe_sk: *const u64,
    // ciphertext
    lwe_out: *mut u64,
    // plaintext
    input: u64,
    // lwe size
    lwe_dimension: usize,
    // encryption parameters
    variance: f64,
    // csprng
    csprng: *mut Csprng,
    csprng_vtable: *const CsprngVtable,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let lwe_sk = LweSecretKey::from_raw_parts(lwe_sk, lwe_dimension);
        let lwe_out = LweCiphertext::from_raw_parts(lwe_out, lwe_dimension);
        lwe_sk.encrypt_lwe(
            lwe_out,
            input,
            variance,
            CsprngMut::new(csprng, csprng_vtable),
        );
    });
}
/// Decrypts `lwe_ct_in` under `lwe_sk`, writing the decrypted value to
/// `plaintext`.
///
/// # Safety
///
/// `lwe_sk` must hold a key of dimension `lwe_dimension`, `lwe_ct_in` must
/// hold `lwe_dimension + 1` readable elements, and `plaintext` must be valid
/// for writes.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_decrypt_lwe_ciphertext_u64(
    // secret key
    lwe_sk: *const u64,
    // ciphertext
    lwe_ct_in: *const u64,
    // lwe size
    lwe_dimension: usize,
    // plaintext
    plaintext: *mut u64,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let lwe_sk = LweSecretKey::from_raw_parts(lwe_sk, lwe_dimension);
        let lwe_ct_in = LweCiphertext::from_raw_parts(lwe_ct_in, lwe_dimension);
        *plaintext = lwe_sk.decrypt_lwe(lwe_ct_in);
    });
}

View File

@@ -0,0 +1,58 @@
/// A 128-bit integer passed across the FFI boundary as raw bytes.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct Uint128 {
    // The value, least-significant byte first.
    pub little_endian_bytes: [u8; 16],
}

impl core::fmt::Debug for Uint128 {
    #[inline]
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Render as the numeric value rather than as a byte array.
        let value = u128::from_le_bytes(self.little_endian_bytes);
        core::fmt::Debug::fmt(&value, f)
    }
}
/// Opaque handle to a caller-provided CSPRNG; only ever used behind a
/// pointer, together with a `CsprngVtable`.
pub struct Csprng {
    __private: (),
}
/// Function table through which the library drives a caller-supplied CSPRNG
/// behind an opaque `Csprng` pointer.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct CsprngVtable {
    /// Returns the number of remaining bytes that this Csprng can generate.
    pub remaining_bytes: unsafe extern "C" fn(csprng: *const Csprng) -> Uint128,
    /// Fills the byte array with random bytes, up to the given count, and returns the count of
    /// successfully generated bytes.
    pub next_bytes:
        unsafe extern "C" fn(csprng: *mut Csprng, byte_array: *mut u8, byte_count: usize) -> usize,
}
/// Result of a `*_scratch` size query; `repr(u32)` so it crosses the FFI
/// boundary as a plain integer.
#[repr(u32)]
#[derive(Copy, Clone, Debug)]
pub enum ScratchStatus {
    /// The scratch requirement was computed and written to the out-params.
    Valid = 0,
    /// The requirement overflowed and could not be computed.
    SizeOverflow = 1,
}
/// Selects sequential or Rayon-parallel execution for key generation;
/// `repr(u32)` so it crosses the FFI boundary as a plain integer.
#[repr(u32)]
#[derive(Copy, Clone, Debug)]
pub enum Parallelism {
    /// Run on the calling thread only.
    No = 0,
    /// Parallelize with Rayon.
    Rayon = 1,
}
#[cfg(test)]
pub mod tests {
    use super::*;
    use crate::{c_api::csprng::CONCRETE_CSPRNG_VTABLE, implementation::types::CsprngMut};
    use concrete_csprng::generators::SoftwareRandomGenerator;

    /// Wraps a concrete software generator in the type-erased `CsprngMut`
    /// used by the implementation, for use in tests.
    pub fn to_generic(a: &mut SoftwareRandomGenerator) -> CsprngMut<'_, '_> {
        unsafe {
            CsprngMut::new(
                a as *mut SoftwareRandomGenerator as *mut Csprng,
                &CONCRETE_CSPRNG_VTABLE,
            )
        }
    }
}

View File

@@ -0,0 +1,375 @@
use crate::{
c_api::{types::*, utils::nounwind},
implementation::{
fft::Fft,
types::{
ciphertext_list::LweCiphertextList,
packing_keyswitch_key_list::PackingKeyswitchKeyList, polynomial_list::PolynomialList,
*,
},
wop::{
circuit_bootstrap_boolean_vertical_packing,
circuit_bootstrap_boolean_vertical_packing_scratch, extract_bits, extract_bits_scratch,
},
},
};
use core::slice;
use dyn_stack::DynStack;
/// Generates the private functional packing keyswitch keys used by circuit
/// bootstrapping, writing them into `lwe_pksk`.
///
/// # Safety
///
/// - `lwe_pksk` must be writable for the full key-list size (presumably
///   `concrete_cpu_lwe_packing_keyswitch_key_size(...)` elements times
///   `output_glwe_dimension + 1` keys — verify against callers).
/// - Both secret keys must be valid for reads for their dimensions.
/// - `csprng`/`csprng_vtable` must form a valid CSPRNG.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
    // packing keyswitch key
    lwe_pksk: *mut u64,
    // secret keys
    input_lwe_sk: *const u64,
    output_glwe_sk: *const u64,
    // secret key dimensions
    input_lwe_dimension: usize,
    output_polynomial_size: usize,
    output_glwe_dimension: usize,
    // circuit bootstrap parameters
    decomposition_level_count: usize,
    decomposition_base_log: usize,
    // noise parameters
    variance: f64,
    parallelism: Parallelism,
    // csprng
    csprng: *mut Csprng,
    csprng_vtable: *const CsprngVtable,
) {
    // Abort instead of unwinding across the FFI boundary.
    nounwind(|| {
        let glwe_params = GlweParams {
            dimension: output_glwe_dimension,
            polynomial_size: output_polynomial_size,
        };
        let decomp_params = DecompParams {
            level: decomposition_level_count,
            base_log: decomposition_base_log,
        };
        let input_key = LweSecretKey::<&[u64]>::from_raw_parts(input_lwe_sk, input_lwe_dimension);
        let output_key = GlweSecretKey::<&[u64]>::from_raw_parts(output_glwe_sk, glwe_params);
        let mut fpksk_list = PackingKeyswitchKeyList::<&mut [u64]>::from_raw_parts(
            lwe_pksk,
            glwe_params,
            input_lwe_dimension,
            decomp_params,
            // `glwe_dimension + 1` keys are generated — presumably one per
            // GLWE mask polynomial plus the body; confirm against
            // `fill_with_fpksk_for_circuit_bootstrap`.
            glwe_params.dimension + 1,
        );
        match parallelism {
            Parallelism::No => fpksk_list.fill_with_fpksk_for_circuit_bootstrap(
                &input_key,
                &output_key,
                variance,
                CsprngMut::new(csprng, csprng_vtable),
            ),
            Parallelism::Rayon => fpksk_list.fill_with_fpksk_for_circuit_bootstrap_par(
                &input_key,
                &output_key,
                variance,
                CsprngMut::new(csprng, csprng_vtable),
            ),
        }
    })
}
/// Queries the scratch-memory requirement of
/// `concrete_cpu_extract_bit_lwe_ciphertext_u64`.
///
/// On success writes the byte size and alignment into `stack_size` /
/// `stack_align` and returns `Valid`; returns `SizeOverflow` otherwise.
///
/// # Safety
///
/// `stack_size` and `stack_align` must be valid for writes, and `fft` must
/// point to a valid FFT plan.
#[no_mangle]
// NOTE(review): `#[must_use]` added for consistency with the other
// `*_scratch` queries returning `ScratchStatus`; ignoring the status would
// mean reading uninitialized out-params.
#[must_use]
pub unsafe extern "C" fn concrete_cpu_extract_bit_lwe_ciphertext_u64_scratch(
    stack_size: *mut usize,
    stack_align: *mut usize,
    // ciphertexts dimensions
    ct_out_dimension: usize,
    ct_in_dimension: usize,
    // bootstrap parameters
    bsk_glwe_dimension: usize,
    bsk_polynomial_size: usize,
    // side resources
    fft: *const Fft,
) -> ScratchStatus {
    nounwind(|| {
        if let Ok(scratch) = extract_bits_scratch(
            ct_in_dimension,
            // presumably the helper expects the output LWE *size*
            // (dimension + 1) — confirm against `extract_bits_scratch`.
            ct_out_dimension + 1,
            GlweParams {
                dimension: bsk_glwe_dimension,
                polynomial_size: bsk_polynomial_size,
            },
            (*fft).as_view(),
        ) {
            *stack_size = scratch.size_bytes();
            *stack_align = scratch.align_bytes();
            ScratchStatus::Valid
        } else {
            ScratchStatus::SizeOverflow
        }
    })
}
/// Extracts `number_of_bits` bits of `ct_in` into `ct_vec_out`, one output
/// ciphertext per bit (bit positions are determined by `delta_log`; see
/// `extract_bits`).
///
/// Panics — and therefore aborts, via `nounwind` — if the dimension
/// relations asserted below do not hold.
///
/// # Safety
///
/// All pointers must be valid for the sizes implied by the dimension
/// parameters, and `stack` must satisfy the size/alignment reported by
/// `concrete_cpu_extract_bit_lwe_ciphertext_u64_scratch`.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_extract_bit_lwe_ciphertext_u64(
    // ciphertexts
    ct_vec_out: *mut u64,
    ct_in: *const u64,
    // bootstrap key
    fourier_bsk: *const f64,
    // keyswitch key
    ksk: *const u64,
    // ciphertexts dimensions
    ct_out_dimension: usize,
    ct_out_count: usize,
    ct_in_dimension: usize,
    // extract bit parameters
    number_of_bits: usize,
    delta_log: usize,
    // bootstrap parameters
    bsk_decomposition_level_count: usize,
    bsk_decomposition_base_log: usize,
    bsk_glwe_dimension: usize,
    bsk_polynomial_size: usize,
    bsk_input_lwe_dimension: usize,
    // keyswitch_parameters
    ksk_decomposition_level_count: usize,
    ksk_decomposition_base_log: usize,
    ksk_input_dimension: usize,
    ksk_output_dimension: usize,
    // side resources
    fft: *const Fft,
    stack: *mut u8,
    stack_size: usize,
) {
    nounwind(|| {
        // Parameter consistency: input under the flattened GLWE key, the
        // keyswitch bridging from it to the bootstrap key's input key.
        assert_eq!(ct_in_dimension, bsk_glwe_dimension * bsk_polynomial_size);
        assert_eq!(ct_in_dimension, ksk_input_dimension);
        assert_eq!(ct_out_dimension, ksk_output_dimension);
        assert_eq!(ct_out_count, number_of_bits);
        assert_eq!(ksk_output_dimension, bsk_input_lwe_dimension);
        assert!(64 <= number_of_bits + delta_log);
        let lwe_list_out =
            LweCiphertextList::from_raw_parts(ct_vec_out, ct_out_dimension, ct_out_count);
        let lwe_in = LweCiphertext::from_raw_parts(ct_in, ct_in_dimension);
        let ksk = LweKeyswitchKey::from_raw_parts(
            ksk,
            ksk_output_dimension,
            ksk_input_dimension,
            DecompParams {
                level: ksk_decomposition_level_count,
                base_log: ksk_decomposition_base_log,
            },
        );
        let fourier_bsk = BootstrapKey::from_raw_parts(
            fourier_bsk,
            GlweParams {
                dimension: bsk_glwe_dimension,
                polynomial_size: bsk_polynomial_size,
            },
            bsk_input_lwe_dimension,
            DecompParams {
                level: bsk_decomposition_level_count,
                base_log: bsk_decomposition_base_log,
            },
        );
        extract_bits(
            lwe_list_out,
            lwe_in,
            ksk,
            fourier_bsk,
            delta_log,
            number_of_bits,
            (*fft).as_view(),
            DynStack::new(slice::from_raw_parts_mut(stack as _, stack_size)),
        );
    })
}
/// Queries the scratch-memory requirement of
/// `concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64`.
///
/// On success writes the byte size and alignment into `stack_size` /
/// `stack_align` and returns `Valid`; returns `SizeOverflow` otherwise.
/// Panics — and therefore aborts, via `nounwind` — if the parameter
/// consistency checks asserted below do not hold.
///
/// # Safety
///
/// `stack_size` and `stack_align` must be valid for writes, and `fft` must
/// point to a valid FFT plan.
#[no_mangle]
// NOTE(review): `#[must_use]` added for consistency with the other
// `*_scratch` queries returning `ScratchStatus`; ignoring the status would
// mean reading uninitialized out-params.
#[must_use]
pub unsafe extern "C" fn concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64_scratch(
    stack_size: *mut usize,
    stack_align: *mut usize,
    // ciphertext dimensions
    ct_out_count: usize,
    ct_in_dimension: usize,
    ct_in_count: usize,
    lut_size: usize,
    lut_count: usize,
    // bootstrap parameters
    bsk_glwe_dimension: usize,
    bsk_polynomial_size: usize,
    // keyswitch_parameters
    fpksk_output_polynomial_size: usize,
    // circuit bootstrap parameters
    cbs_decomposition_level_count: usize,
    // side resources
    fft: *const Fft,
) -> ScratchStatus {
    nounwind(|| {
        assert_eq!(ct_out_count, lut_count);
        let bsk_output_lwe_dimension = bsk_glwe_dimension * bsk_polynomial_size;
        // Each LUT must cover every combination of the boolean inputs.
        assert_eq!(lut_size, 1 << ct_in_count);
        assert_ne!(cbs_decomposition_level_count, 0);
        if let Ok(scratch) = circuit_bootstrap_boolean_vertical_packing_scratch(
            ct_in_count,
            ct_out_count,
            ct_in_dimension + 1,
            lut_count,
            bsk_output_lwe_dimension + 1,
            fpksk_output_polynomial_size,
            bsk_glwe_dimension + 1,
            cbs_decomposition_level_count,
            (*fft).as_view(),
        ) {
            *stack_size = scratch.size_bytes();
            *stack_align = scratch.align_bytes();
            ScratchStatus::Valid
        } else {
            ScratchStatus::SizeOverflow
        }
    })
}
/// Evaluates `lut_count` lookup tables over the boolean input ciphertexts
/// `ct_in_vec` via circuit bootstrapping plus vertical packing, writing one
/// output ciphertext per LUT into `ct_out_vec`.
///
/// Panics — and therefore aborts, via `nounwind` — if the parameter
/// consistency checks asserted below do not hold.
///
/// # Safety
///
/// All pointers must be valid for the sizes implied by the dimension/count
/// parameters, and `stack` must satisfy the size/alignment reported by
/// `concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64_scratch`.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
    // ciphertexts
    ct_out_vec: *mut u64,
    ct_in_vec: *const u64,
    // lookup table
    lut: *const u64,
    // bootstrap key
    fourier_bsk: *const f64,
    // packing keyswitch key
    fpksk: *const u64,
    // ciphertext dimensions
    ct_out_dimension: usize,
    ct_out_count: usize,
    ct_in_dimension: usize,
    ct_in_count: usize,
    lut_size: usize,
    lut_count: usize,
    // bootstrap parameters
    bsk_decomposition_level_count: usize,
    bsk_decomposition_base_log: usize,
    bsk_glwe_dimension: usize,
    bsk_polynomial_size: usize,
    bsk_input_lwe_dimension: usize,
    // keyswitch_parameters
    fpksk_decomposition_level_count: usize,
    fpksk_decomposition_base_log: usize,
    fpksk_input_dimension: usize,
    fpksk_output_glwe_dimension: usize,
    fpksk_output_polynomial_size: usize,
    fpksk_count: usize,
    // circuit bootstrap parameters
    cbs_decomposition_level_count: usize,
    cbs_decomposition_base_log: usize,
    // side resources
    fft: *const Fft,
    stack: *mut u8,
    stack_size: usize,
) {
    nounwind(|| {
        assert_eq!(ct_out_count, lut_count);
        let bsk_output_lwe_dimension = bsk_glwe_dimension * bsk_polynomial_size;
        assert_eq!(bsk_output_lwe_dimension, fpksk_input_dimension);
        assert_eq!(ct_in_dimension, bsk_input_lwe_dimension);
        assert_eq!(
            ct_out_dimension,
            fpksk_output_glwe_dimension * fpksk_output_polynomial_size
        );
        // Each LUT must cover every combination of the boolean inputs.
        assert_eq!(lut_size, 1 << ct_in_count);
        assert_ne!(cbs_decomposition_base_log, 0);
        assert_ne!(cbs_decomposition_level_count, 0);
        // The decomposition must fit within the 64-bit torus.
        assert!(cbs_decomposition_level_count * cbs_decomposition_base_log <= 64);
        let bsk_glwe_params = GlweParams {
            dimension: bsk_glwe_dimension,
            polynomial_size: bsk_polynomial_size,
        };
        let luts = PolynomialList::new(
            slice::from_raw_parts(lut, lut_size * lut_count),
            lut_size,
            lut_count,
        );
        let fourier_bsk = BootstrapKey::<&[f64]>::from_raw_parts(
            fourier_bsk,
            bsk_glwe_params,
            bsk_input_lwe_dimension,
            DecompParams {
                level: bsk_decomposition_level_count,
                base_log: bsk_decomposition_base_log,
            },
        );
        let lwe_list_out = LweCiphertextList::<&mut [u64]>::from_raw_parts(
            ct_out_vec,
            ct_out_dimension,
            ct_out_count,
        );
        let lwe_list_in =
            LweCiphertextList::<&[u64]>::from_raw_parts(ct_in_vec, ct_in_dimension, ct_in_count);
        let fpksk_list = PackingKeyswitchKeyList::new(
            slice::from_raw_parts(
                fpksk,
                // Total element count of the whole key list.
                fpksk_decomposition_level_count
                    * (fpksk_output_glwe_dimension + 1)
                    * fpksk_output_polynomial_size
                    * (fpksk_input_dimension + 1)
                    * fpksk_count,
            ),
            GlweParams {
                dimension: fpksk_output_glwe_dimension,
                polynomial_size: fpksk_output_polynomial_size,
            },
            fpksk_input_dimension,
            DecompParams {
                level: fpksk_decomposition_level_count,
                base_log: fpksk_decomposition_base_log,
            },
            fpksk_count,
        );
        circuit_bootstrap_boolean_vertical_packing(
            luts,
            fourier_bsk,
            lwe_list_out,
            lwe_list_in,
            fpksk_list,
            DecompParams {
                level: cbs_decomposition_level_count,
                base_log: cbs_decomposition_base_log,
            },
            (*fft).as_view(),
            DynStack::new(slice::from_raw_parts_mut(stack as _, stack_size)),
        );
    })
}
/// Returns the number of `u64` elements needed to store a single packing
/// keyswitch key with the given parameters.
#[no_mangle]
pub unsafe extern "C" fn concrete_cpu_lwe_packing_keyswitch_key_size(
    glwe_dimension: usize,
    polynomial_size: usize,
    decomposition_level_count: usize,
    input_dimension: usize,
) -> usize {
    PackingKeyswitchKey::<&[u64]>::data_len(
        GlweParams {
            dimension: glwe_dimension,
            polynomial_size,
        },
        decomposition_level_count,
        input_dimension,
    )
}

View File

@@ -0,0 +1,338 @@
use crate::implementation::cmux::cmux_scratch;
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::*;
use super::{
cmux::cmux,
fft::FftView,
polynomial::{update_with_wrapping_monic_monomial_mul, update_with_wrapping_unit_monomial_div},
types::*,
zip_eq, Split,
};
impl<'a> BootstrapKey<&'a [f64]> {
    /// Returns the scratch memory requirement of [`Self::blind_rotate`]:
    /// one cache-line-aligned GLWE-sized `u64` buffer (the rotated copy
    /// `ct1`) plus the scratch needed by a single cmux.
    pub fn blind_rotate_scratch(
        bsk_glwe_params: GlweParams,
        fft: FftView<'_>,
    ) -> Result<StackReq, SizeOverflow> {
        StackReq::try_all_of([
            StackReq::try_new_aligned::<u64>(
                (bsk_glwe_params.dimension + 1) * bsk_glwe_params.polynomial_size,
                CACHELINE_ALIGN,
            )?,
            cmux_scratch(bsk_glwe_params, fft)?,
        ])
    }
    /// Returns the scratch memory requirement of [`Self::bootstrap`]:
    /// one GLWE-sized buffer for the local accumulator plus the blind
    /// rotation scratch.
    pub fn bootstrap_scratch(
        bsk_glwe_params: GlweParams,
        fft: FftView<'_>,
    ) -> Result<StackReq, SizeOverflow> {
        StackReq::try_all_of([
            StackReq::try_new_aligned::<u64>(
                (bsk_glwe_params.dimension + 1) * bsk_glwe_params.polynomial_size,
                CACHELINE_ALIGN,
            )?,
            Self::blind_rotate_scratch(bsk_glwe_params, fft)?,
        ])
    }
    /// Blind rotation: rotates `lut` by the (modulus-switched) value held
    /// in `lwe`, running one cmux per non-zero mask element of `lwe`.
    /// The result is left in `lut`.
    pub fn blind_rotate(
        self,
        mut lut: GlweCiphertext<&mut [u64]>,
        lwe: LweCiphertext<&[u64]>,
        fft: FftView<'_>,
        mut stack: DynStack<'_>,
    ) {
        let (lwe_body, lwe_mask) = lwe.into_data().split_last().unwrap();
        let lut_poly_size = lut.glwe_params.polynomial_size;
        // Switch the LWE body to the blind-rotation domain
        // (no msb offset, no lut-count padding).
        let modulus_switched_body = pbs_modulus_switch(*lwe_body, lut_poly_size, 0, 0);
        // Pre-rotate each lut polynomial by the switched body
        // (wrapping monomial division).
        for polynomial in lut.as_mut_view().into_data().into_chunks(lut_poly_size) {
            update_with_wrapping_unit_monomial_div(polynomial, modulus_switched_body);
        }
        // We initialize the ct_0 used for the successive cmuxes
        let mut ct0 = lut;
        for (lwe_mask_element, bootstrap_key_ggsw) in zip_eq(lwe_mask.iter(), self.into_ggsw_iter())
        {
            // A zero mask element yields a rotation by X^0, i.e. a no-op:
            // skip the cmux entirely.
            if *lwe_mask_element != 0 {
                let stack = stack.rb_mut();
                // We copy ct_0 to ct_1
                let (mut ct1, stack) = stack
                    .collect_aligned(CACHELINE_ALIGN, ct0.as_view().into_data().iter().copied());
                let mut ct1 = GlweCiphertext::new(&mut *ct1, ct0.glwe_params);
                // We rotate ct_1 by performing ct_1 <- ct_1 * X^{modulus_switched_mask_element}
                let polynomial_size = ct1.glwe_params.polynomial_size;
                let modulus_switched_mask_element =
                    pbs_modulus_switch(*lwe_mask_element, lut_poly_size, 0, 0);
                for polynomial in ct1.as_mut_view().into_data().into_chunks(polynomial_size) {
                    update_with_wrapping_monic_monomial_mul(
                        polynomial,
                        modulus_switched_mask_element,
                    );
                }
                // The cmux driven by this bsk ggsw leaves its selection
                // result in ct_0 (see `cmux`).
                cmux(
                    ct0.as_mut_view(),
                    ct1.as_mut_view(),
                    bootstrap_key_ggsw,
                    fft,
                    stack,
                );
            }
        }
    }
    /// Programmable bootstrap: blind-rotates a local copy of `accumulator`
    /// by `lwe_in`, then sample-extracts coefficient 0 into `lwe_out`.
    pub fn bootstrap(
        self,
        lwe_out: LweCiphertext<&mut [u64]>,
        lwe_in: LweCiphertext<&[u64]>,
        accumulator: GlweCiphertext<&[u64]>,
        fft: FftView<'_>,
        stack: DynStack<'_>,
    ) {
        // Work on a stack-allocated copy so the caller's accumulator is
        // left untouched.
        let (mut local_accumulator_data, stack) = stack.collect_aligned(
            CACHELINE_ALIGN,
            accumulator.as_view().into_data().iter().copied(),
        );
        let mut local_accumulator =
            GlweCiphertext::new(&mut *local_accumulator_data, accumulator.glwe_params);
        self.blind_rotate(local_accumulator.as_mut_view(), lwe_in, fft, stack);
        local_accumulator
            .as_view()
            .fill_lwe_with_sample_extraction(lwe_out, 0);
    }
}
/// Switches the modulus of one ciphertext coefficient, only in the context
/// of a PBS.
///
/// `offset`: the number of most-significant bits discarded.
/// `lut_count_log`: the amount of right (lsb) padding.
pub fn pbs_modulus_switch(
    input: u64,
    poly_size: usize,
    offset: usize,
    lut_count_log: usize,
) -> usize {
    // Discard the `offset` most significant bits.
    let shifted = input << offset;
    // Keep log2(poly_size) + 2 bits (one extra for rounding), minus the
    // bits reserved for the lsb padding.
    let kept_bits = int_log2(poly_size) + 2;
    let mut out = shifted >> (u64::BITS as usize - kept_bits + lut_count_log);
    // Round to nearest: add the lsb, then drop it.
    out = (out + (out & 1)) >> 1;
    // Re-apply the lsb padding.
    (out << lut_count_log) as usize
}
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;
    use crate::{
        c_api::types::tests::to_generic,
        implementation::{
            fft::{Fft, FftView},
            types::*,
        },
    };
    use concrete_csprng::{
        generators::{RandomGenerator, SoftwareRandomGenerator},
        seeders::Seed,
    };
    use dyn_stack::DynStack;
    // Bundle of everything needed to run a bootstrap in tests: the two
    // secret keys, the Fourier-domain bootstrap key, an FFT plan and a
    // reusable scratch buffer.
    struct KeySet {
        in_dim: usize,
        glwe_params: GlweParams,
        decomp_params: DecompParams,
        in_sk: LweSecretKey<Vec<u64>>,
        out_sk: LweSecretKey<Vec<u64>>,
        fourier_bsk: BootstrapKey<Vec<f64>>,
        fft: Fft,
        stack: Vec<MaybeUninit<u8>>,
    }
    /// Generates a fresh bootstrap key for `in_sk` -> `out_sk` and returns
    /// its data already converted to the Fourier domain.
    #[allow(clippy::too_many_arguments)]
    fn new_bsk(
        csprng: CsprngMut,
        in_dim: usize,
        glwe_params: GlweParams,
        decomp_params: DecompParams,
        key_variance: f64,
        in_sk: LweSecretKey<&[u64]>,
        out_sk: GlweSecretKey<&[u64]>,
        fft: FftView,
        stack: DynStack,
    ) -> Vec<f64> {
        // Flat key length: N * (k+1)^2 * in_dim * level coefficients.
        let bsk_len = glwe_params.polynomial_size
            * (glwe_params.dimension + 1)
            * (glwe_params.dimension + 1)
            * in_dim
            * decomp_params.level;
        let mut bsk = vec![0_u64; bsk_len];
        BootstrapKey::new(bsk.as_mut_slice(), glwe_params, in_dim, decomp_params)
            .fill_with_new_key_par(in_sk, out_sk, key_variance, csprng);
        let standard = BootstrapKey::new(bsk.as_slice(), glwe_params, in_dim, decomp_params);
        let mut bsk_f = vec![0.; bsk_len];
        let mut fourier =
            BootstrapKey::new(bsk_f.as_mut_slice(), glwe_params, in_dim, decomp_params);
        fourier.fill_with_forward_fourier(standard, fft, stack);
        bsk_f
    }
    impl KeySet {
        /// Samples fresh secret keys and builds the matching Fourier
        /// bootstrap key.
        fn new(
            mut csprng: CsprngMut,
            in_dim: usize,
            glwe_params: GlweParams,
            decomp_params: DecompParams,
            key_variance: f64,
        ) -> Self {
            let in_sk = LweSecretKey::new_random(csprng.as_mut(), in_dim);
            let out_sk = LweSecretKey::new_random(csprng.as_mut(), glwe_params.lwe_dimension());
            let fft = Fft::new(glwe_params.polynomial_size);
            // Fixed-size scratch; large enough for the parameters used here.
            let mut stack = vec![MaybeUninit::new(0_u8); 100000];
            let bsk_f = new_bsk(
                csprng,
                in_dim,
                glwe_params,
                decomp_params,
                key_variance,
                in_sk.as_view(),
                GlweSecretKey::new(out_sk.data.as_slice(), glwe_params),
                fft.as_view(),
                DynStack::new(&mut stack),
            );
            // NOTE(review): an identical plan is rebuilt here; the first
            // `fft` could seemingly be reused once `new_bsk` returns.
            let fft = Fft::new(glwe_params.polynomial_size);
            let fourier_bsk = BootstrapKey::new(bsk_f, glwe_params, in_dim, decomp_params);
            Self {
                in_dim,
                glwe_params,
                decomp_params,
                in_sk,
                out_sk,
                fourier_bsk,
                fft,
                stack,
            }
        }
        /// Encrypts `pt` under the input key, bootstraps it against the
        /// accumulator `lut`, and returns the decrypted output.
        fn bootstrap(
            &mut self,
            csprng: CsprngMut,
            pt: u64,
            encryption_variance: f64,
            lut: &[u64],
        ) -> u64 {
            let mut input = LweCiphertext::zero(self.in_dim);
            let mut output = LweCiphertext::zero(self.glwe_params.lwe_dimension());
            self.in_sk
                .as_view()
                .encrypt_lwe(input.as_mut_view(), pt, encryption_variance, csprng);
            // The lut must cover a full GLWE ciphertext (mask + body).
            assert_eq!(
                lut.len(),
                (self.glwe_params.dimension + 1) * self.glwe_params.polynomial_size
            );
            let accumulator = GlweCiphertext::new(lut, self.glwe_params);
            self.fourier_bsk.as_view().bootstrap(
                output.as_mut_view(),
                input.as_view(),
                accumulator,
                self.fft.as_view(),
                DynStack::new(&mut self.stack),
            );
            self.out_sk.as_view().decrypt_lwe(output.as_view())
        }
    }
    /// End-to-end PBS check: encrypt a lut index, bootstrap against a
    /// random lut, and compare with the expected lut entry.
    #[test]
    fn bootstrap_correctness() {
        let mut csprng = SoftwareRandomGenerator::new(Seed(0));
        let glwe_dim = 1;
        let log2_poly_size = 10;
        let polynomial_size = 1 << log2_poly_size;
        let mut keyset = KeySet::new(
            to_generic(&mut csprng),
            600,
            GlweParams {
                dimension: glwe_dim,
                polynomial_size,
            },
            DecompParams {
                level: 3,
                base_log: 10,
            },
            0.0000000000000000000001,
        );
        let log2_precision = 4;
        let precision = 1 << log2_precision;
        let lut_case_number: u64 = precision;
        assert_eq!(polynomial_size as u64 % lut_case_number, 0);
        let lut_case_size = polynomial_size as u64 / lut_case_number;
        for _ in 0..100 {
            // Random index in [0, 2 * precision); the upper half exercises
            // the negated (wrap-around) copy of the lut.
            let lut_index: u64 =
                u64::from_le_bytes(std::array::from_fn(|_| csprng.next().unwrap()))
                    % (2 * precision);
            let lut: Vec<u64> = (0..lut_case_number)
                .map(|_| u64::from_le_bytes(std::array::from_fn(|_| csprng.next().unwrap())))
                .collect();
            // Accumulator = zero mask, then each lut value repeated over
            // its case of the body polynomial.
            let raw_lut: Vec<u64> = (0..glwe_dim)
                .flat_map(|_| (0..polynomial_size).map(|_| 0))
                .chain(
                    lut.iter()
                        .flat_map(|&lut_value| (0..lut_case_size).map(move |_| lut_value)),
                )
                .collect();
            // Indices past `precision` select the negated lut entry.
            let expected_image = if lut_index < precision {
                lut[lut_index as usize]
            } else {
                lut[(lut_index - precision) as usize].wrapping_neg()
            };
            // Encode the index in the middle of its case on the torus.
            let pt = (lut_index as f64 + 0.5) / (2. * lut_case_number as f64) * 2.0_f64.powi(64);
            let image =
                keyset.bootstrap(to_generic(&mut csprng), pt as u64, 0.0000000001, &raw_lut);
            let diff = image.wrapping_sub(expected_image) as i64;
            // Accept up to 1% of the torus as noise.
            assert!((diff as f64).abs() / 2.0_f64.powi(64) < 0.01);
        }
    }
}

View File

@@ -0,0 +1,35 @@
use super::{types::*, Split};
impl<'a> GlweCiphertext<&'a [u64]> {
    /// Sample extraction: fills `lwe` with an LWE encryption of the `n_th`
    /// coefficient of this GLWE ciphertext. The LWE body is copied from
    /// the `n_th` body coefficient; the LWE mask is a reordered/negated
    /// copy of the GLWE mask.
    pub fn fill_lwe_with_sample_extraction(self, lwe: LweCiphertext<&mut [u64]>, n_th: usize) {
        let polynomial_size = self.glwe_params.polynomial_size;
        // We retrieve the bodies and masks of the two ciphertexts.
        let (lwe_body, lwe_mask) = lwe.into_data().split_last_mut().unwrap();
        let glwe_index = self.glwe_params.dimension * polynomial_size;
        let (glwe_mask, glwe_body) = self.into_data().split_at(glwe_index);
        // We copy the body
        *lwe_body = glwe_body[n_th];
        // We copy the mask (each polynomial is in the wrong order)
        lwe_mask.copy_from_slice(glwe_mask);
        // We compute the number of elements which must be
        // turned into their opposite
        let opposite_count = polynomial_size - n_th - 1;
        // We loop through the polynomials (as mut tensors)
        for lwe_mask_poly in lwe_mask.into_chunks(polynomial_size) {
            // We reverse the polynomial
            lwe_mask_poly.reverse();
            // We compute the opposite of the proper coefficients
            for x in lwe_mask_poly[0..opposite_count].iter_mut() {
                *x = x.wrapping_neg()
            }
            // We rotate the polynomial properly
            lwe_mask_poly.rotate_left(opposite_count);
        }
    }
}

View File

@@ -0,0 +1,25 @@
use super::{external_product::external_product, fft::FftView, types::*, zip_eq};
use crate::implementation::external_product::external_product_scratch;
use dyn_stack::{DynStack, SizeOverflow, StackReq};
/// Returns the required memory for [`cmux`].
///
/// The cmux's scratch needs are exactly those of the external product it
/// performs, so this simply delegates.
pub fn cmux_scratch(
    ggsw_glwe_params: GlweParams,
    fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
    external_product_scratch(ggsw_glwe_params, fft)
}
/// Homomorphic multiplexer between `ct0` and `ct1`, driven by `fourier_ggsw`.
///
/// Both `ct0` and `ct1` are mutated: `ct1` is clobbered with the wrapping
/// difference `ct1 - ct0`, and the selection result ends up in `ct0`.
pub fn cmux(
    ct0: GlweCiphertext<&mut [u64]>,
    mut ct1: GlweCiphertext<&mut [u64]>,
    fourier_ggsw: GgswCiphertext<&[f64]>,
    fft: FftView<'_>,
    stack: DynStack<'_>,
) {
    // ct1 <- ct1 - ct0, coefficient-wise with wrapping arithmetic.
    for (diff, base) in zip_eq(ct1.as_mut_view().into_data(), ct0.as_view().into_data()) {
        *diff = diff.wrapping_sub(*base);
    }
    // Apply the external product of the ggsw with the difference; the
    // result lands in ct0 (see `external_product`).
    external_product(ct0, fourier_ggsw, ct1.as_view(), fft, stack);
}

View File

@@ -0,0 +1,267 @@
use crate::implementation::{assume_init_mut, from_torus};
use super::{
as_mut_uninit,
fft::{FftView, Twisties},
};
use bytemuck::cast_slice_mut;
use concrete_fft::c64;
use core::mem::MaybeUninit;
use dyn_stack::{DynStack, SizeOverflow, StackReq};
use pulp::{as_arrays, as_arrays_mut};
pub mod x86;
/// Scalar forward twist for integer (`u64`) inputs: multiplies each
/// (re, im) pair by its twisting factor and writes the products as
/// interleaved (re, im) `f64` pairs into `out`.
fn convert_forward_integer_u64_scalar(
    out: &mut [MaybeUninit<f64>],
    in_re: &[u64],
    in_im: &[u64],
    twisties: Twisties<&[f64]>,
) {
    debug_assert_eq!(out.len(), in_re.len() * 2);
    let (pairs, _) = as_arrays_mut::<2, _>(out);
    for (dst, &re, &im, &w_re, &w_im) in izip!(pairs, in_re, in_im, twisties.re, twisties.im) {
        // Don't remove the cast to i64: it can reduce the noise by up to
        // 10 bits.
        let re = re as i64 as f64;
        let im = im as i64 as f64;
        // Complex multiplication by the twisting factor (w_re + i*w_im).
        dst[0].write(re * w_re - im * w_im);
        dst[1].write(re * w_im + im * w_re);
    }
}
/// Scalar inverse-twist step: converts interleaved (re, im) `f64` pairs in
/// `inp` back to torus values and accumulates them (wrapping add) into
/// `out_re` / `out_im`.
fn convert_add_backward_torus_u64_scalar(
    out_re: &mut [u64],
    out_im: &mut [u64],
    inp: &[f64],
    twisties: Twisties<&[f64]>,
) {
    debug_assert_eq!(inp.len(), out_re.len() * 2);
    let (inp, _) = as_arrays::<2, _>(inp);
    // `inp` was just rebound to pairs, so `inp.len()` is the complex point
    // count (half the f64 count) — the value the FFT is normalized by.
    let normalization = 1.0 / inp.len() as f64;
    for (out_re, out_im, &inp, &w_re, &w_im) in izip!(out_re, out_im, inp, twisties.re, twisties.im)
    {
        // Fold the FFT normalization into the twisting factor.
        let w_re = w_re * normalization;
        let w_im = w_im * normalization;
        // Complex multiplication by conj(w_re + i*w_im).
        let tmp_re = inp[0] * w_re + inp[1] * w_im;
        let tmp_im = inp[1] * w_re - inp[0] * w_im;
        *out_re = out_re.wrapping_add(from_torus(tmp_re));
        *out_im = out_im.wrapping_add(from_torus(tmp_im));
    }
}
/// Scalar forward twist for torus (`u64`) inputs: rescales each value to
/// the real torus, multiplies by the twisting factor, and writes
/// interleaved (re, im) `f64` pairs into `out`.
fn convert_forward_torus_u64(
    out: &mut [MaybeUninit<f64>],
    in_re: &[u64],
    in_im: &[u64],
    twisties: Twisties<&[f64]>,
) {
    debug_assert_eq!(out.len(), in_re.len() * 2);
    // Scale factor 2^-64 maps u64 torus values into the unit interval.
    let scale = 2.0_f64.powi(-(u64::BITS as i32));
    let (pairs, _) = as_arrays_mut::<2, _>(out);
    for (dst, &re, &im, &w_re, &w_im) in izip!(pairs, in_re, in_im, twisties.re, twisties.im) {
        // Don't remove the cast to i64: it can reduce the noise by up to
        // 10 bits.
        let re = re as i64 as f64 * scale;
        let im = im as i64 as f64 * scale;
        // Complex multiplication by the twisting factor (w_re + i*w_im).
        dst[0].write(re * w_re - im * w_im);
        dst[1].write(re * w_im + im * w_re);
    }
}
/// Scalar inverse-twist step: converts interleaved (re, im) `f64` pairs in
/// `inp` to torus values, overwriting (initializing) `out_re` / `out_im`.
fn convert_backward_torus_u64(
    out_re: &mut [MaybeUninit<u64>],
    out_im: &mut [MaybeUninit<u64>],
    inp: &[f64],
    twisties: Twisties<&[f64]>,
) {
    debug_assert_eq!(inp.len(), out_re.len() * 2);
    let (inp, _) = as_arrays::<2, _>(inp);
    // `inp` now holds pairs, so its length is the complex point count —
    // the value the FFT is normalized by.
    let normalization = 1.0 / inp.len() as f64;
    for (out_re, out_im, inp, w_re, w_im) in izip!(out_re, out_im, inp, twisties.re, twisties.im) {
        // Fold the FFT normalization into the twisting factor.
        let w_re = w_re * normalization;
        let w_im = w_im * normalization;
        // Complex multiplication by conj(w_re + i*w_im).
        let tmp_re = inp[0] * w_re + inp[1] * w_im;
        let tmp_im = inp[1] * w_re - inp[0] * w_im;
        out_re.write(from_torus(tmp_re));
        out_im.write(from_torus(tmp_im));
    }
}
/// Forward twist for integer inputs: dispatches to the x86 SIMD
/// implementation when compiled for x86/x86_64, and to the scalar
/// fallback otherwise.
fn convert_forward_integer_u64(
    out: &mut [MaybeUninit<f64>],
    in_re: &[u64],
    in_im: &[u64],
    twisties: Twisties<&[f64]>,
) {
    // Exactly one of the two calls below is compiled in.
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    x86::convert_forward_integer_u64(out, in_re, in_im, twisties);
    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
    convert_forward_integer_u64_scalar(out, in_re, in_im, twisties);
}
/// Accumulating inverse twist: dispatches to the x86 SIMD implementation
/// when compiled for x86/x86_64, and to the scalar fallback otherwise.
fn convert_add_backward_torus_u64(
    out_re: &mut [u64],
    out_im: &mut [u64],
    inp: &[f64],
    twisties: Twisties<&[f64]>,
) {
    // Exactly one of the two calls below is compiled in.
    #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
    x86::convert_add_backward_torus_u64(out_re, out_im, inp, twisties);
    #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
    convert_add_backward_torus_u64_scalar(out_re, out_im, inp, twisties);
}
impl FftView<'_> {
    /// Returns the polynomial size that this FFT was made for.
    ///
    /// The negacyclic transform operates on `polynomial_size / 2` complex
    /// points, hence the factor two.
    pub fn polynomial_size(self) -> usize {
        2 * self.plan.fft_size()
    }
    /// Returns the memory required for a forward negacyclic FFT.
    pub fn forward_scratch(self) -> Result<StackReq, SizeOverflow> {
        self.plan.fft_scratch()
    }
    /// Returns the memory required for a backward negacyclic FFT.
    ///
    /// On top of the plan's own scratch, a cache-line-aligned temporary
    /// holding the `polynomial_size / 2` complex inputs is needed.
    pub fn backward_scratch(self) -> Result<StackReq, SizeOverflow> {
        self.plan
            .fft_scratch()?
            .try_and(StackReq::try_new_aligned::<c64>(
                self.polynomial_size() / 2,
                aligned_vec::CACHELINE_ALIGN,
            )?)
    }
    /// Performs a negacyclic real FFT of `standard`, viewed as torus elements, and stores the
    /// result in `fourier`.
    ///
    /// # Postconditions
    ///
    /// this function leaves all the elements of `fourier` in an initialized state.
    ///
    /// # Panics
    ///
    /// Panics if `standard`, `fourier` and `self` have different polynomial sizes.
    pub fn forward_as_torus(
        self,
        fourier: &mut [MaybeUninit<f64>],
        standard: &[u64],
        stack: DynStack<'_>,
    ) {
        // SAFETY: `convert_forward_torus` initializes the output slice that is passed to it
        unsafe { self.forward_with_conv(fourier, standard, convert_forward_torus_u64, stack) }
    }
    /// Performs a negacyclic real FFT of `standard`, viewed as integers, and stores the
    /// result in `fourier`.
    ///
    /// # Postconditions
    ///
    /// this function leaves all the elements of `fourier` in an initialized state.
    ///
    /// # Panics
    ///
    /// Panics if `standard`, `fourier` and `self` have different polynomial sizes.
    pub fn forward_as_integer(
        self,
        fourier: &mut [MaybeUninit<f64>],
        standard: &[u64],
        stack: DynStack<'_>,
    ) {
        // SAFETY: `convert_forward_integer` initializes the output slice that is passed to it
        unsafe { self.forward_with_conv(fourier, standard, convert_forward_integer_u64, stack) }
    }
    /// Performs an inverse negacyclic real FFT of `fourier` and stores the result in `standard`,
    /// viewed as torus elements.
    ///
    /// # Postconditions
    ///
    /// this function leaves all the elements of `standard` in an initialized state.
    ///
    /// # Panics
    ///
    /// Panics if `standard`, `fourier` and `self` have different polynomial sizes.
    pub fn backward_as_torus(
        self,
        standard: &mut [MaybeUninit<u64>],
        fourier: &[f64],
        stack: DynStack<'_>,
    ) {
        // SAFETY: `convert_backward_torus` initializes the output slices that are passed to it
        unsafe { self.backward_with_conv(standard, fourier, convert_backward_torus_u64, stack) }
    }
    /// Performs an inverse negacyclic real FFT of `fourier` and adds the result to `standard`,
    /// viewed as torus elements.
    ///
    /// # Panics
    ///
    /// Panics if `standard`, `fourier` and `self` have different polynomial sizes.
    pub fn add_backward_as_torus(self, standard: &mut [u64], fourier: &[f64], stack: DynStack<'_>) {
        // SAFETY: `convert_add_backward_torus` initializes the output slices that are passed to it
        unsafe {
            self.backward_with_conv(
                as_mut_uninit(standard),
                fourier,
                // `standard` is already initialized, so the uninit view may
                // be assumed init before the accumulating conversion.
                |out_re, out_im, inp, twisties| {
                    convert_add_backward_torus_u64(
                        assume_init_mut(out_re),
                        assume_init_mut(out_im),
                        inp,
                        twisties,
                    )
                },
                stack,
            )
        }
    }
    /// # Safety
    ///
    /// `conv_fn` must initialize the entirety of the mutable slice that it receives.
    unsafe fn forward_with_conv(
        self,
        fourier: &mut [MaybeUninit<f64>],
        standard: &[u64],
        conv_fn: impl Fn(&mut [MaybeUninit<f64>], &[u64], &[u64], Twisties<&[f64]>),
        stack: DynStack<'_>,
    ) {
        let n = standard.len();
        debug_assert_eq!(n, fourier.len());
        // The first half of `standard` feeds the real parts, the second
        // half the imaginary parts.
        let (standard_re, standard_im) = standard.split_at(n / 2);
        conv_fn(fourier, standard_re, standard_im, self.twisties);
        // `conv_fn` initialized every element, so the slice can be assumed
        // init and reinterpreted as complex numbers for the FFT.
        let fourier = cast_slice_mut(unsafe { assume_init_mut(fourier) });
        self.plan.fwd(fourier, stack);
    }
    /// # Safety
    ///
    /// `conv_fn` must initialize the entirety of the mutable slices that it receives.
    unsafe fn backward_with_conv(
        self,
        standard: &mut [MaybeUninit<u64>],
        fourier: &[f64],
        conv_fn: impl Fn(&mut [MaybeUninit<u64>], &mut [MaybeUninit<u64>], &[f64], Twisties<&[f64]>),
        stack: DynStack<'_>,
    ) {
        let n = standard.len();
        debug_assert_eq!(n, fourier.len());
        // The inverse FFT runs in place on an aligned copy of `fourier`.
        let (mut tmp, stack) =
            stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, fourier.iter().copied());
        self.plan.inv(cast_slice_mut(&mut tmp), stack);
        let (standard_re, standard_im) = standard.split_at_mut(n / 2);
        conv_fn(standard_re, standard_im, &tmp, self.twisties);
    }
}

View File

@@ -0,0 +1,706 @@
#[cfg(target_arch = "x86")]
use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
use core::mem::MaybeUninit;
use pulp::{as_arrays, as_arrays_mut, cast, simd_type};
use crate::implementation::fft::Twisties;
// SIMD capability tokens generated by `pulp::simd_type!`: each struct
// bundles the target features a code path needs, and its fields give
// access to the matching intrinsic sets (see `vectorize` call sites
// below). The AVX512 token is gated behind the nightly feature.
simd_type! {
    pub struct FusedMulAdd {
        pub sse2: "sse2",
        pub avx: "avx",
        pub avx2: "avx2",
        pub fma: "fma"
    }
    #[cfg(feature = "nightly")]
    pub struct Avx512 {
        pub avx512f: "avx512f",
        pub avx512dq: "avx512dq",
    }
}
/// Bit-level reinterpretation between same-size `Copy` types.
///
/// # Safety
///
/// Implementors assert that every bit pattern of `Self` is a valid `T`.
/// The size equality is only checked by a `debug_assert`, so a
/// mismatched impl is unsound in release builds (`transmute_copy` would
/// read out of bounds if `T` were larger than `Self`).
pub unsafe trait CastInto<T: Copy>: Copy {
    #[inline(always)]
    fn transmute(self) -> T {
        debug_assert_eq!(core::mem::size_of::<T>(), core::mem::size_of::<Self>());
        unsafe { core::mem::transmute_copy(&self) }
    }
}
// anything can be cast into MaybeUninit
// NOTE(review): this blanket impl is not size-checked at compile time;
// the debug_assert above is the only guard.
unsafe impl<const N: usize, T: Copy, U: Copy> CastInto<[MaybeUninit<T>; N]> for U {}
// 256-bit SIMD registers <-> arrays of 4 lanes.
unsafe impl CastInto<[f64; 4]> for __m256d {}
unsafe impl CastInto<__m256d> for [f64; 4] {}
unsafe impl CastInto<[u64; 4]> for __m256i {}
unsafe impl CastInto<__m256i> for [u64; 4] {}
unsafe impl CastInto<[i64; 4]> for __m256i {}
unsafe impl CastInto<__m256i> for [i64; 4] {}
// 512-bit equivalents, gated on the nightly feature.
#[cfg(feature = "nightly")]
mod nightly_impls {
    use super::*;
    unsafe impl CastInto<[f64; 8]> for __m512d {}
    unsafe impl CastInto<__m512d> for [f64; 8] {}
    unsafe impl CastInto<[u64; 8]> for __m512i {}
    unsafe impl CastInto<__m512i> for [u64; 8] {}
    unsafe impl CastInto<[i64; 8]> for __m512i {}
    unsafe impl CastInto<__m512i> for [i64; 8] {}
}
/// Free-function form of [`CastInto::transmute`]; the target type is
/// inferred from context at the call site.
#[inline(always)]
pub fn simd_cast<T: CastInto<U>, U: Copy>(t: T) -> U {
    t.transmute()
}
/// Converts a vector of f64 values to a vector of i64 values.
/// See `f64_to_i64_bit_twiddles` in `fft/tests.rs` for the scalar version.
#[inline(always)]
fn mm256_cvtpd_epi64(simd: FusedMulAdd, x: __m256d) -> __m256i {
    let FusedMulAdd { avx, avx2, .. } = simd;
    // reinterpret the bits as u64 values
    let bits = avx._mm256_castpd_si256(x);
    // mask that covers the first 52 bits (the stored mantissa)
    let mantissa_mask = avx._mm256_set1_epi64x(0xFFFFFFFFFFFFF_u64 as i64);
    // mask for the implicit mantissa bit (bit 52, zero-indexed)
    let explicit_mantissa_bit = avx._mm256_set1_epi64x(0x10000000000000_u64 as i64);
    // mask that covers the first 11 bits (the exponent width)
    let exp_mask = avx._mm256_set1_epi64x(0x7FF_u64 as i64);
    // extract the first 52 bits and add the implicit bit
    let mantissa = avx2._mm256_or_si256(
        avx2._mm256_and_si256(bits, mantissa_mask),
        explicit_mantissa_bit,
    );
    // extract bits 52..63 for the biased exponent
    let biased_exp = avx2._mm256_and_si256(avx2._mm256_srli_epi64::<52>(bits), exp_mask);
    // extract the sign bit (bit 63); 0 - signbit yields an all-ones lane
    // for negative values, which is what blendv_epi8 needs below
    let sign_is_negative_mask = avx2._mm256_sub_epi64(
        avx._mm256_setzero_si256(),
        avx2._mm256_srli_epi64::<63>(bits),
    );
    // we need to shift the mantissa by some value that may be negative, so we first shift
    // it to the left by the maximum amount, then shift it to the right by our
    // value plus the offset we just shifted by
    //
    // bit 52 is set to 1, so we shift to the left by 11 so the 63rd (last) bit is
    // set.
    let mantissa_lshift = avx2._mm256_slli_epi64::<11>(mantissa);
    // shift to the right and apply the exponent bias (1086 = 1023 + 52 + 11)
    let mantissa_shift = avx2._mm256_srlv_epi64(
        mantissa_lshift,
        avx2._mm256_sub_epi64(avx._mm256_set1_epi64x(1086), biased_exp),
    );
    // if the sign bit is unset, we keep our result
    let value_if_positive = mantissa_shift;
    // otherwise, we negate it
    let value_if_negative = avx2._mm256_sub_epi64(avx._mm256_setzero_si256(), value_if_positive);
    // if the biased exponent is all zeros, we have a subnormal value (or zero)
    // if it is not subnormal, we keep our results
    let value_if_non_subnormal =
        avx2._mm256_blendv_epi8(value_if_positive, value_if_negative, sign_is_negative_mask);
    // if it is subnormal, the conversion to i64 (rounding towards zero) returns zero
    let value_if_subnormal = avx._mm256_setzero_si256();
    // compare the biased exponent to a zero value
    let is_subnormal = avx2._mm256_cmpeq_epi64(biased_exp, avx._mm256_setzero_si256());
    // choose the result depending on subnormalness
    avx2._mm256_blendv_epi8(value_if_non_subnormal, value_if_subnormal, is_subnormal)
}
/// Converts a vector of f64 values to a vector of i64 values.
/// See `f64_to_i64_bit_twiddles` in `fft/tests.rs` for the scalar version.
#[cfg(feature = "nightly")]
#[inline(always)]
fn mm512_cvtpd_epi64(simd: Avx512, x: __m512d) -> __m512i {
    let simd = simd.avx512f;
    // reinterpret the bits as u64 values
    let bits = simd._mm512_castpd_si512(x);
    // mask that covers the first 52 bits (the stored mantissa)
    let mantissa_mask = simd._mm512_set1_epi64(0xFFFFFFFFFFFFF_u64 as i64);
    // mask for the implicit mantissa bit (bit 52, zero-indexed)
    let explicit_mantissa_bit = simd._mm512_set1_epi64(0x10000000000000_u64 as i64);
    // mask that covers the first 11 bits (the exponent width)
    let exp_mask = simd._mm512_set1_epi64(0x7FF_u64 as i64);
    // extract the first 52 bits and add the implicit bit
    let mantissa = simd._mm512_or_si512(
        simd._mm512_and_si512(bits, mantissa_mask),
        explicit_mantissa_bit,
    );
    // extract bits 52..63 for the biased exponent
    let biased_exp = simd._mm512_and_si512(simd._mm512_srli_epi64::<52>(bits), exp_mask);
    // extract the sign bit (bit 63): mask the lanes whose sign bit is set,
    // i.e. the lanes holding negative values.
    //
    // BUGFIX: this previously used `_mm512_cmpneq_epi64_mask`, which set
    // the mask on the *non-negative* lanes; the blend below then negated
    // positive values and left negative ones untouched — the opposite of
    // the AVX2 path (`mm256_cvtpd_epi64`). `cmpeq` restores the intended
    // polarity.
    let sign_is_negative_mask = simd._mm512_cmpeq_epi64_mask(
        simd._mm512_srli_epi64::<63>(bits),
        simd._mm512_set1_epi64(1),
    );
    // we need to shift the mantissa by some value that may be negative, so we first shift it to
    // the left by the maximum amount, then shift it to the right by our value plus the offset we
    // just shifted by
    //
    // bit 52 is set to 1, so we shift to the left by 11 so the 63rd (last) bit is set.
    let mantissa_lshift = simd._mm512_slli_epi64::<11>(mantissa);
    // shift to the right and apply the exponent bias (1086 = 1023 + 52 + 11)
    let mantissa_shift = simd._mm512_srlv_epi64(
        mantissa_lshift,
        simd._mm512_sub_epi64(simd._mm512_set1_epi64(1086), biased_exp),
    );
    // if the sign bit is unset, we keep our result
    let value_if_positive = mantissa_shift;
    // otherwise, we negate it
    let value_if_negative = simd._mm512_sub_epi64(simd._mm512_setzero_si512(), value_if_positive);
    // pick the negated value on the negative lanes only
    // (mask_blend selects the third operand where the mask is set)
    let value_if_non_subnormal =
        simd._mm512_mask_blend_epi64(sign_is_negative_mask, value_if_positive, value_if_negative);
    // if it is subnormal, the conversion to i64 (rounding towards zero) returns zero
    let value_if_subnormal = simd._mm512_setzero_si512();
    // compare the biased exponent to a zero value
    let is_subnormal = simd._mm512_cmpeq_epi64_mask(biased_exp, simd._mm512_setzero_si512());
    // choose the result depending on subnormalness
    simd._mm512_mask_blend_epi64(is_subnormal, value_if_non_subnormal, value_if_subnormal)
}
/// Converts a vector of i64 values to a vector of f64 values.
///
/// Ported from <https://stackoverflow.com/a/41148578>: the high and low
/// halves of each lane are embedded into doubles with fixed exponents
/// (biases 3*2^67 and 2^52), and the combined bias is subtracted to
/// recover the value — see the linked answer for the full derivation.
#[inline(always)]
fn mm256_cvtepi64_pd(simd: FusedMulAdd, x: __m256i) -> __m256d {
    let FusedMulAdd { avx, avx2, .. } = simd;
    // sign-extended high 32 bits of each lane
    let mut x_hi = avx2._mm256_srai_epi32::<16>(x);
    x_hi = avx2._mm256_blend_epi16::<0x33>(x_hi, avx._mm256_setzero_si256());
    x_hi = avx2._mm256_add_epi64(
        x_hi,
        avx._mm256_castpd_si256(avx._mm256_set1_pd(442721857769029238784.0)), // 3*2^67
    );
    // low bits packed into a double with exponent 2^52
    let x_lo = avx2._mm256_blend_epi16::<0x88>(
        x,
        avx._mm256_castpd_si256(avx._mm256_set1_pd(4503599627370496.0)),
    ); // 2^52
    // subtract the combined bias from the high part, then add the low part
    let f = avx._mm256_sub_pd(
        avx._mm256_castsi256_pd(x_hi),
        avx._mm256_set1_pd(442726361368656609280.0), // 3*2^67 + 2^52
    );
    avx._mm256_add_pd(f, avx._mm256_castsi256_pd(x_lo))
}
/// Converts a vector of 8 i64 lanes to 8 f64 lanes.
///
/// Written as a per-lane scalar cast inside `vectorize`; with AVX512DQ
/// available this is expected (hoped) to compile down to `vcvtqq2pd`.
#[cfg(feature = "nightly")]
#[inline(always)]
fn mm512_cvtepi64_pd(simd: Avx512, x: __m512i) -> __m512d {
    simd.vectorize(
        #[inline(always)]
        || {
            let lanes: [i64; 8] = simd_cast(x);
            let floats: [f64; 8] = core::array::from_fn(|i| lanes[i] as f64);
            simd_cast(floats)
        },
    )
}
/// AVX512 forward twist for integer (`u64`) inputs: converts each lane
/// through i64 to f64, multiplies by the twisting factor, and stores the
/// results as interleaved (re, im) pairs in `out`.
#[cfg(feature = "nightly")]
fn convert_forward_integer_u64_avx512(
    simd: Avx512,
    out: &mut [MaybeUninit<f64>],
    in_re: &[u64],
    in_im: &[u64],
    twisties: Twisties<&[f64]>,
) {
    let n = in_re.len();
    // Length must be a multiple of the 8-lane width: the `as_arrays`
    // calls below silently drop any remainder.
    debug_assert_eq!(n % 8, 0);
    debug_assert_eq!(2 * n, out.len());
    debug_assert_eq!(n, in_re.len());
    debug_assert_eq!(n, in_im.len());
    debug_assert_eq!(n, twisties.re.len());
    debug_assert_eq!(n, twisties.im.len());
    let (out, _) = as_arrays_mut::<16, _>(out);
    let (in_re, _) = as_arrays::<8, _>(in_re);
    let (in_im, _) = as_arrays::<8, _>(in_im);
    let (w_re, _) = as_arrays::<8, _>(twisties.re);
    let (w_im, _) = as_arrays::<8, _>(twisties.im);
    simd.vectorize(
        #[inline(always)]
        || {
            for (out, in_re, in_im, w_re, w_im) in izip!(
                out,
                in_re.iter().copied(),
                in_im.iter().copied(),
                w_re.iter().copied(),
                w_im.iter().copied(),
            ) {
                // convert to i64, then to f64
                // the intermediate conversion to i64 can reduce noise by up to 10 bits.
                let in_re = mm512_cvtepi64_pd(simd, simd_cast(in_re));
                let in_im = mm512_cvtepi64_pd(simd, simd_cast(in_im));
                let w_re = simd_cast(w_re);
                let w_im = simd_cast(w_im);
                let simd = simd.avx512f;
                // perform complex multiplication
                let out_re = simd._mm512_fmsub_pd(in_re, w_re, simd._mm512_mul_pd(in_im, w_im));
                let out_im = simd._mm512_fmadd_pd(in_re, w_im, simd._mm512_mul_pd(in_im, w_re));
                // we have
                // x0 x1 x2 x3 x4 x5 x6 x7
                // y0 y1 y2 y3 y4 y5 y6 y7
                //
                // we want
                // x0 y0 x1 y1 x2 y2 x3 y3
                // x4 y4 x5 y5 x6 y6 x7 y7
                // interleave real part and imaginary part
                {
                    // idx0/idx1 index into the 16-lane concatenation
                    // (out_re, out_im); bit 3 selects the second operand.
                    let idx0 = simd._mm512_setr_epi64(
                        0b0000, 0b1000, 0b0001, 0b1001, 0b0010, 0b1010, 0b0011, 0b1011,
                    );
                    let idx1 = simd._mm512_setr_epi64(
                        0b0100, 0b1100, 0b0101, 0b1101, 0b0110, 0b1110, 0b0111, 0b1111,
                    );
                    let out0 = simd._mm512_permutex2var_pd(out_re, idx0, out_im);
                    let out1 = simd._mm512_permutex2var_pd(out_re, idx1, out_im);
                    // store c64 values
                    *out = simd_cast([out0, out1]);
                }
            }
        },
    );
}
/// AVX2/FMA forward twist for integer (`u64`) inputs: converts each lane
/// through i64 to f64, multiplies by the twisting factor, and stores the
/// results as interleaved (re, im) pairs in `out`.
fn convert_forward_integer_u64_fma(
    simd: FusedMulAdd,
    out: &mut [MaybeUninit<f64>],
    in_re: &[u64],
    in_im: &[u64],
    twisties: Twisties<&[f64]>,
) {
    let n = in_re.len();
    // Length must be a multiple of the 4-lane width: the `as_arrays`
    // calls below silently drop any remainder.
    debug_assert_eq!(n % 4, 0);
    debug_assert_eq!(2 * n, out.len());
    debug_assert_eq!(n, in_re.len());
    debug_assert_eq!(n, in_im.len());
    debug_assert_eq!(n, twisties.re.len());
    debug_assert_eq!(n, twisties.im.len());
    let (out, _) = as_arrays_mut::<8, _>(out);
    let (in_re, _) = as_arrays::<4, _>(in_re);
    let (in_im, _) = as_arrays::<4, _>(in_im);
    let (w_re, _) = as_arrays::<4, _>(twisties.re);
    let (w_im, _) = as_arrays::<4, _>(twisties.im);
    simd.vectorize(
        #[inline(always)]
        move || {
            for (out, in_re, in_im, w_re, w_im) in izip!(
                out,
                in_re.iter().copied(),
                in_im.iter().copied(),
                w_re.iter().copied(),
                w_im.iter().copied(),
            ) {
                // convert to i64, then to f64
                // the intermediate conversion to i64 can reduce noise by up to 10 bits.
                let in_re = mm256_cvtepi64_pd(simd, cast(in_re));
                let in_im = mm256_cvtepi64_pd(simd, cast(in_im));
                let w_re = cast(w_re);
                let w_im = cast(w_im);
                let FusedMulAdd { avx, fma, .. } = simd;
                // perform complex multiplication
                let out_re = fma._mm256_fmsub_pd(in_re, w_re, avx._mm256_mul_pd(in_im, w_im));
                let out_im = fma._mm256_fmadd_pd(in_re, w_im, avx._mm256_mul_pd(in_im, w_re));
                // we have
                // x0 x1 x2 x3
                // y0 y1 y2 y3
                //
                // we want
                // x0 y0 x1 y1
                // x2 y2 x3 y3
                // interleave real part and imaginary part
                // unpacklo/unpackhi
                // x0 y0 x2 y2
                // x1 y1 x3 y3
                let lo = avx._mm256_unpacklo_pd(out_re, out_im);
                let hi = avx._mm256_unpackhi_pd(out_re, out_im);
                let out0 = avx._mm256_permute2f128_pd::<0b00100000>(lo, hi);
                let out1 = avx._mm256_permute2f128_pd::<0b00110001>(lo, hi);
                // store c64 values
                *out = simd_cast([out0, out1]);
            }
        },
    );
}
/// Performs common work for `u32` and `u64`, used by the backward torus transformation.
///
/// This deinterleaves two vectors of c64 values into two vectors of real part and imaginary part,
/// multiplies by the (normalized) conjugate twisting factor, then rounds to the nearest integer.
/// Returns the pair `(fract_re, fract_im)`, already scaled by `scaling` and rounded.
#[cfg(feature = "nightly")]
#[inline(always)]
fn convert_torus_prologue_avx512f(
    simd: Avx512,
    normalization: __m512d,
    w_re: __m512d,
    w_im: __m512d,
    input0: __m512d,
    input1: __m512d,
    scaling: __m512d,
) -> (__m512d, __m512d) {
    let simd = simd.avx512f;
    // fold the FFT normalization into the twisting factor
    let w_re = simd._mm512_mul_pd(normalization, w_re);
    let w_im = simd._mm512_mul_pd(normalization, w_im);
    // real indices
    let idx0 = simd._mm512_setr_epi64(
        0b0000, 0b0010, 0b0100, 0b0110, 0b1000, 0b1010, 0b1100, 0b1110,
    );
    // imaginary indices
    let idx1 = simd._mm512_setr_epi64(
        0b0001, 0b0011, 0b0101, 0b0111, 0b1001, 0b1011, 0b1101, 0b1111,
    );
    // re0 re1 re2 re3 re4 re5 re6 re7
    let inp_re = simd._mm512_permutex2var_pd(input0, idx0, input1);
    // im0 im1 im2 im3 im4 im5 im6 im7
    let inp_im = simd._mm512_permutex2var_pd(input0, idx1, input1);
    // perform complex multiplication with conj(w)
    let mul_re = simd._mm512_fmadd_pd(inp_re, w_re, simd._mm512_mul_pd(inp_im, w_im));
    let mul_im = simd._mm512_fnmadd_pd(inp_re, w_im, simd._mm512_mul_pd(inp_im, w_re));
    // round to nearest integer and suppress exceptions
    const ROUNDING: i32 = _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC;
    // get the fractional part (centered around zero) by subtracting rounded value
    let fract_re = simd._mm512_sub_pd(mul_re, simd._mm512_roundscale_pd::<ROUNDING>(mul_re));
    let fract_im = simd._mm512_sub_pd(mul_im, simd._mm512_roundscale_pd::<ROUNDING>(mul_im));
    // scale fractional part and round
    let fract_re = simd._mm512_roundscale_pd::<ROUNDING>(simd._mm512_mul_pd(scaling, fract_re));
    let fract_im = simd._mm512_roundscale_pd::<ROUNDING>(simd._mm512_mul_pd(scaling, fract_im));
    (fract_re, fract_im)
}
/// Performs common work for `u32` and `u64`, used by the backward torus transformation.
///
/// This deinterleaves two vectors of c64 values into two vectors of real part and imaginary part,
/// multiplies by the (normalized) conjugate twisting factor, then rounds to the nearest integer.
/// Returns the pair `(fract_re, fract_im)`, already scaled by `scaling` and rounded.
#[inline(always)]
fn convert_torus_prologue_fma(
    simd: FusedMulAdd,
    normalization: __m256d,
    w_re: __m256d,
    w_im: __m256d,
    input0: __m256d,
    input1: __m256d,
    scaling: __m256d,
) -> (__m256d, __m256d) {
    let FusedMulAdd { avx, fma, sse2, .. } = simd;
    // fold the FFT normalization into the twisting factor
    let w_re = avx._mm256_mul_pd(normalization, w_re);
    let w_im = avx._mm256_mul_pd(normalization, w_im);
    // re0 im0
    // re1 im1
    let [inp0, inp1] = cast::<__m256d, [__m128d; 2]>(input0);
    // re2 im2
    // re3 im3
    let [inp2, inp3] = cast::<__m256d, [__m128d; 2]>(input1);
    // re0 re1
    let inp_re01 = sse2._mm_unpacklo_pd(inp0, inp1);
    // im0 im1
    let inp_im01 = sse2._mm_unpackhi_pd(inp0, inp1);
    // re2 re3
    let inp_re23 = sse2._mm_unpacklo_pd(inp2, inp3);
    // im2 im3
    let inp_im23 = sse2._mm_unpackhi_pd(inp2, inp3);
    // re0 re1 re2 re3
    let inp_re = avx._mm256_insertf128_pd::<0b1>(avx._mm256_castpd128_pd256(inp_re01), inp_re23);
    // im0 im1 im2 im3
    let inp_im = avx._mm256_insertf128_pd::<0b1>(avx._mm256_castpd128_pd256(inp_im01), inp_im23);
    // perform complex multiplication with conj(w)
    let mul_re = fma._mm256_fmadd_pd(inp_re, w_re, avx._mm256_mul_pd(inp_im, w_im));
    let mul_im = fma._mm256_fnmadd_pd(inp_re, w_im, avx._mm256_mul_pd(inp_im, w_re));
    // round to nearest integer and suppress exceptions
    const ROUNDING: i32 = _MM_FROUND_NINT | _MM_FROUND_NO_EXC;
    // get the fractional part (centered around zero) by subtracting rounded value
    let fract_re = avx._mm256_sub_pd(mul_re, avx._mm256_round_pd::<ROUNDING>(mul_re));
    let fract_im = avx._mm256_sub_pd(mul_im, avx._mm256_round_pd::<ROUNDING>(mul_im));
    // scale fractional part and round
    let fract_re = avx._mm256_round_pd::<ROUNDING>(avx._mm256_mul_pd(scaling, fract_re));
    let fract_im = avx._mm256_round_pd::<ROUNDING>(avx._mm256_mul_pd(scaling, fract_im));
    (fract_re, fract_im)
}
/// AVX512 path of the backward torus conversion for `u64`, with accumulation.
///
/// For each coefficient, the twisted/normalized input is reduced to a rounded fractional
/// part scaled by 2^64, converted to `i64`, and added in place (wrapping `i64` lanes) to
/// `out_re`/`out_im`.
///
/// Length invariants (checked in debug builds): `out_re`, `out_im` and both twistie slices
/// have `n` elements with `n % 8 == 0`, and `input` has `2 * n` elements.
#[cfg(feature = "nightly")]
fn convert_add_backward_torus_u64_avx512f(
    simd: Avx512,
    out_re: &mut [u64],
    out_im: &mut [u64],
    input: &[f64],
    twisties: Twisties<&[f64]>,
) {
    let n = out_re.len();
    debug_assert_eq!(n % 8, 0);
    debug_assert_eq!(n, out_re.len());
    debug_assert_eq!(n, out_im.len());
    debug_assert_eq!(2 * n, input.len());
    debug_assert_eq!(n, twisties.re.len());
    debug_assert_eq!(n, twisties.im.len());
    // Reinterpret the flat slices as arrays of SIMD-width chunks (8 lanes; 16 for the
    // interleaved complex input).
    let (out_re, _) = as_arrays_mut::<8, _>(out_re);
    let (out_im, _) = as_arrays_mut::<8, _>(out_im);
    let (inp, _) = as_arrays::<16, _>(input);
    let (w_re, _) = as_arrays::<8, _>(twisties.re);
    let (w_im, _) = as_arrays::<8, _>(twisties.im);
    simd.vectorize(
        #[inline(always)]
        || {
            // 1/n normalization of the inverse FFT, and 2^64 scaling to map the torus to u64.
            let normalization = simd.avx512f._mm512_set1_pd(1.0 / n as f64);
            let scaling = simd.avx512f._mm512_set1_pd(2.0_f64.powi(u64::BITS as i32));
            for (out_re, out_im, inp, w_re, w_im) in izip!(
                out_re,
                out_im,
                inp.iter().copied(),
                w_re.iter().copied(),
                w_im.iter().copied(),
            ) {
                let [input0, input1]: [[f64; 8]; 2] = cast(inp);
                let (fract_re, fract_im) = convert_torus_prologue_avx512f(
                    simd,
                    normalization,
                    simd_cast(w_re),
                    simd_cast(w_im),
                    simd_cast(input0),
                    simd_cast(input1),
                    scaling,
                );
                // convert f64 to i64
                let fract_re = mm512_cvtpd_epi64(simd, fract_re);
                let fract_im = mm512_cvtpd_epi64(simd, fract_im);
                // add to input and store
                *out_re = simd_cast(simd.avx512f._mm512_add_epi64(fract_re, simd_cast(*out_re)));
                *out_im = simd_cast(simd.avx512f._mm512_add_epi64(fract_im, simd_cast(*out_im)));
            }
        },
    );
}
/// FMA/AVX2 path of the backward torus conversion for `u64`, with accumulation.
///
/// Same computation as the AVX512 variant, processing 4 lanes at a time: each twisted,
/// normalized input is reduced to a rounded fractional part scaled by 2^64, converted to
/// `i64`, and added in place (wrapping `i64` lanes) to `out_re`/`out_im`.
///
/// Length invariants (checked in debug builds): `out_re`, `out_im` and both twistie slices
/// have `n` elements with `n % 8 == 0`, and `input` has `2 * n` elements.
fn convert_add_backward_torus_u64_fma(
    simd: FusedMulAdd,
    out_re: &mut [u64],
    out_im: &mut [u64],
    input: &[f64],
    twisties: Twisties<&[f64]>,
) {
    let n = out_re.len();
    debug_assert_eq!(n % 8, 0);
    debug_assert_eq!(n, out_re.len());
    debug_assert_eq!(n, out_im.len());
    debug_assert_eq!(2 * n, input.len());
    debug_assert_eq!(n, twisties.re.len());
    debug_assert_eq!(n, twisties.im.len());
    // Chunk the flat slices by the AVX2 width (4 lanes; 8 for the interleaved complex input).
    let (out_re, _) = as_arrays_mut::<4, _>(out_re);
    let (out_im, _) = as_arrays_mut::<4, _>(out_im);
    let (inp, _) = as_arrays::<8, _>(input);
    let (w_re, _) = as_arrays::<4, _>(twisties.re);
    let (w_im, _) = as_arrays::<4, _>(twisties.im);
    simd.vectorize(
        #[inline(always)]
        || {
            // 1/n normalization of the inverse FFT, and 2^64 scaling to map the torus to u64.
            let normalization = simd.avx._mm256_set1_pd(1.0 / n as f64);
            let scaling = simd.avx._mm256_set1_pd(2.0_f64.powi(u64::BITS as i32));
            for (out_re, out_im, inp, w_re, w_im) in izip!(
                out_re,
                out_im,
                inp.iter().copied(),
                w_re.iter().copied(),
                w_im.iter().copied(),
            ) {
                let [input0, input1]: [[f64; 4]; 2] = cast(inp);
                let (fract_re, fract_im) = convert_torus_prologue_fma(
                    simd,
                    normalization,
                    simd_cast(w_re),
                    simd_cast(w_im),
                    simd_cast(input0),
                    simd_cast(input1),
                    scaling,
                );
                // convert f64 to i64
                let fract_re = mm256_cvtpd_epi64(simd, fract_re);
                let fract_im = mm256_cvtpd_epi64(simd, fract_im);
                // add to input and store
                *out_re = simd_cast(simd.avx2._mm256_add_epi64(fract_re, simd_cast(*out_re)));
                *out_im = simd_cast(simd.avx2._mm256_add_epi64(fract_im, simd_cast(*out_im)));
            }
        },
    );
}
/// Converts the `u64` values `in_re`/`in_im` into twisted `f64` values written to `out`,
/// dispatching at runtime to the best available implementation: AVX512 (nightly feature
/// only), then FMA/AVX2, then the scalar fallback.
pub fn convert_forward_integer_u64(
    out: &mut [MaybeUninit<f64>],
    in_re: &[u64],
    in_im: &[u64],
    twisties: Twisties<&[f64]>,
) {
    // Try the widest vector extension first.
    #[cfg(feature = "nightly")]
    {
        if let Some(simd) = Avx512::try_new() {
            convert_forward_integer_u64_avx512(simd, out, in_re, in_im, twisties);
            return;
        }
    }
    match FusedMulAdd::try_new() {
        Some(simd) => convert_forward_integer_u64_fma(simd, out, in_re, in_im, twisties),
        None => super::convert_forward_integer_u64_scalar(out, in_re, in_im, twisties),
    }
}
/// Accumulates the backward torus conversion of `inp` into `out_re`/`out_im`, dispatching
/// at runtime to the best available implementation: AVX512 (nightly feature only), then
/// FMA/AVX2, then the scalar fallback.
pub fn convert_add_backward_torus_u64(
    out_re: &mut [u64],
    out_im: &mut [u64],
    inp: &[f64],
    twisties: Twisties<&[f64]>,
) {
    // Try the widest vector extension first.
    #[cfg(feature = "nightly")]
    {
        if let Some(simd) = Avx512::try_new() {
            convert_add_backward_torus_u64_avx512f(simd, out_re, out_im, inp, twisties);
            return;
        }
    }
    match FusedMulAdd::try_new() {
        Some(simd) => convert_add_backward_torus_u64_fma(simd, out_re, out_im, inp, twisties),
        None => super::convert_add_backward_torus_u64_scalar(out_re, out_im, inp, twisties),
    }
}
#[cfg(test)]
mod tests {
    /// Checks that converting `f64` to `i64` by direct manipulation of the IEEE-754 bit
    /// representation agrees with the compiler's `as i64` cast, for values representable
    /// as `i64` (the result is unspecified otherwise).
    #[test]
    fn f64_to_i64_bit_twiddles() {
        for x in [
            0.0,
            -0.0,
            37.1242161_f64,
            -37.1242161_f64,
            0.1,
            -0.1,
            1.0,
            -1.0,
            0.9,
            -0.9,
            2.0,
            -2.0,
            1e-310,
            -1e-310,
            2.0_f64.powi(62),
            -(2.0_f64.powi(62)),
            1.1 * 2.0_f64.powi(62),
            1.1 * -(2.0_f64.powi(62)),
            -(2.0_f64.powi(63)),
        ] {
            // this test checks the correctness of converting from f64 to i64 by manipulating the
            // bits of the ieee754 representation of the floating point values.
            //
            // if the value is not representable as an i64, the result is unspecified.
            //
            // https://en.wikipedia.org/wiki/Double-precision_floating-point_format
            let bits = x.to_bits();
            let implicit_mantissa = bits & 0xFFFFFFFFFFFFF;
            let explicit_mantissa = implicit_mantissa | 0x10000000000000;
            let biased_exp = ((bits >> 52) & 0x7FF) as i64;
            let sign = bits >> 63;
            let explicit_mantissa_lshift = explicit_mantissa << 11;
            // equivalent to:
            //
            // let exp = biased_exp - 1023;
            // let explicit_mantissa_shift = explicit_mantissa_lshift >> (63 - exp.max(0));
            let right_shift_amount = (1086 - biased_exp) as u64;
            let explicit_mantissa_shift = if right_shift_amount < 64 {
                explicit_mantissa_lshift >> right_shift_amount
            } else {
                0
            };
            let value = if sign == 0 {
                explicit_mantissa_shift as i64
            } else {
                (explicit_mantissa_shift as i64).wrapping_neg()
            };
            // biased_exp == 0 covers zero and subnormals, which all truncate to 0.
            let value = if biased_exp == 0 { 0 } else { value };
            // `assert_eq!` rather than `debug_assert_eq!`: the latter is compiled out when
            // the test suite runs in release mode (`cargo test --release`), silently turning
            // this test into a no-op.
            assert_eq!(value, x as i64);
        }
    }
}

View File

@@ -0,0 +1,46 @@
use super::{decomposition::SignedDecompositionIter, types::DecompParams};
/// A decomposer for the signed decomposition of 64-bit unsigned integers into `level`
/// digits of `base_log` bits each.
#[derive(Copy, Clone, Debug)]
#[readonly::make]
pub struct SignedDecomposer {
    // Decomposition parameters (base_log, level); `#[readonly::make]` is meant to make the
    // field read-only outside this module — confirm against the `readonly` crate docs.
    pub decomp_params: DecompParams,
}
impl SignedDecomposer {
    /// Creates a new decomposer.
    ///
    /// Debug-asserts that `base_log * level` is strictly smaller than 64 bits, i.e. that
    /// the decomposition does not cover the whole integer.
    pub fn new(decomp_params: DecompParams) -> SignedDecomposer {
        debug_assert!(
            u64::BITS as usize > decomp_params.base_log * decomp_params.level,
            "Decomposed bits exceeds the size of the integer to be decomposed"
        );
        SignedDecomposer { decomp_params }
    }
    /// Returns the closest value representable by the decomposition.
    #[inline]
    pub fn closest_representable(&self, input: u64) -> u64 {
        // The closest number representable by the decomposition can be computed by performing
        // the rounding at the appropriate bit.
        // We compute the number of least significant bits which can not be represented by the
        // decomposition
        let non_rep_bit_count: usize =
            u64::BITS as usize - self.decomp_params.level * self.decomp_params.base_log;
        // We generate a mask which captures the most significant bit of the non-representable
        // part (the bit that decides the rounding direction)
        let non_rep_mask = 1_u64 << (non_rep_bit_count - 1);
        // We retrieve that rounding bit
        let non_rep_bits = input & non_rep_mask;
        // We move it down to position 0 so it can act as a +0/+1 rounding increment
        let non_rep_msb = non_rep_bits >> (non_rep_bit_count - 1);
        // We remove the non-representable bits and perform the rounding
        let res = input >> non_rep_bit_count;
        let res = res + non_rep_msb;
        // Shift back into place, leaving the non-representable bits zeroed
        res << non_rep_bit_count
    }
    /// Returns an iterator over the signed digits of `input`, highest level first.
    pub fn decompose(&self, input: u64) -> SignedDecompositionIter {
        // Note that there would be no sense of making the decomposition on an input which was
        // not rounded to the closest representable first. We then perform it before decomposing.
        SignedDecompositionIter::new(self.closest_representable(input), self.decomp_params)
    }
}

View File

@@ -0,0 +1,191 @@
use core::{iter::Map, slice::IterMut};
use dyn_stack::{DynArray, DynStack};
use super::types::DecompParams;
/// An iterator that yields the terms of the signed decomposition of an integer.
///
/// # Warning
///
/// This iterator yields the decomposition in reverse order. That means that the highest level
/// will be yielded first.
pub struct SignedDecompositionIter {
    // The value being decomposed
    input: u64,
    // Decomposition parameters (base_log, level)
    decomp_params: DecompParams,
    // The internal state of the decomposition: the not-yet-emitted digits, already shifted
    // down so the next digit sits in the low `base_log` bits
    state: u64,
    // The current level
    current_level: usize,
    // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form:
    // ...0001111
    mod_b_mask: u64,
    // A flag which store whether the iterator is a fresh one (for the recompose method)
    fresh: bool,
}
/// One term of a signed decomposition: the digit `value` produced at `level`, for a
/// decomposition in base `B = 2^base_log`.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct DecompositionTerm {
    // Level of this digit; smaller levels carry more significant digits (see the shift in
    // `as_recomposition_summand`).
    level: usize,
    // Logarithm of the decomposition base.
    base_log: usize,
    // The digit itself (signed value in wrapping u64 representation).
    value: u64,
}
impl DecompositionTerm {
    // Creates a new decomposition term.
    pub(crate) fn new(level: usize, base_log: usize, value: u64) -> Self {
        Self {
            level,
            base_log,
            value,
        }
    }
    /// Turns this term into a summand.
    ///
    /// If our member represents one $\tilde{\theta}\_i$ of the decomposition, this method returns
    /// $\tilde{\theta}\_i\frac{q}{B^i}$.
    pub fn as_recomposition_summand(&self) -> u64 {
        // Shift the digit back to its weight in the original integer; level 1 sits in the
        // most significant `base_log` bits.
        let shift: usize = u64::BITS as usize - self.base_log * self.level;
        self.value << shift
    }
    /// Same as [`Self::as_recomposition_summand`]; kept for callers using the `to_` name.
    #[allow(clippy::wrong_self_convention)]
    pub fn to_recomposition_summand(&self) -> u64 {
        // Delegate so the shift logic lives in a single place (the two methods previously
        // duplicated it).
        self.as_recomposition_summand()
    }
    /// Returns the value of the term.
    ///
    /// If our member represents one $\tilde{\theta}\_i$, this returns its actual value.
    pub fn value(&self) -> u64 {
        self.value
    }
    /// Returns the level of the term.
    ///
    /// If our member represents one $\tilde{\theta}\_i$, this returns the value of $i$.
    pub fn level(&self) -> usize {
        self.level
    }
}
impl SignedDecompositionIter {
    /// Builds a fresh iterator over the signed decomposition of `input`.
    ///
    /// The state is seeded with the representable part of `input`: the value shifted right
    /// so that only its `base_log * level` most significant bits remain.
    pub(crate) fn new(input: u64, decomp_params: DecompParams) -> Self {
        let rep_bits = decomp_params.base_log * decomp_params.level;
        let initial_state = input >> (u64::BITS as usize - rep_bits);
        // Mask selecting one digit: the low `base_log` bits.
        let digit_mask = (1u64 << decomp_params.base_log) - 1;
        Self {
            input,
            decomp_params,
            state: initial_state,
            current_level: decomp_params.level,
            mod_b_mask: digit_mask,
            fresh: true,
        }
    }
}
impl Iterator for SignedDecompositionIter {
    type Item = DecompositionTerm;
    /// Yields the term for the highest remaining level, or `None` once all `level` digits
    /// have been produced.
    fn next(&mut self) -> Option<Self::Item> {
        // Any call invalidates freshness (consumed by the recompose method).
        self.fresh = false;
        if self.current_level == 0 {
            // The decomposition is over.
            None
        } else {
            // Extract one digit from the running state.
            let digit = decompose_one_level(
                self.decomp_params.base_log,
                &mut self.state,
                self.mod_b_mask,
            );
            let level = self.current_level;
            self.current_level = level - 1;
            Some(DecompositionTerm::new(
                level,
                self.decomp_params.base_log,
                digit,
            ))
        }
    }
    fn size_hint(&self) -> (usize, Option<usize>) {
        // Exactly `current_level` terms remain.
        let remaining = self.current_level;
        (remaining, Some(remaining))
    }
}
/// Lending-iterator analogue of [`SignedDecompositionIter`] operating on many values at
/// once: each call to `next_term` yields one decomposition level for every stored state.
pub struct TensorSignedDecompositionLendingIter<'buffers> {
    // The base log of the decomposition
    base_log: usize,
    // The current level
    current_level: usize,
    // A mask which allows to compute the mod B of a value. For B=2^4, this guy is of the form:
    // ...0001111
    mod_b_mask: u64,
    // The internal states of each decomposition
    states: DynArray<'buffers, u64>,
    // A flag which stores whether the iterator is a fresh one (for the recompose method).
    fresh: bool,
}
impl<'buffers> TensorSignedDecompositionLendingIter<'buffers> {
    /// Builds the iterator, collecting the shifted initial states into cacheline-aligned
    /// storage taken from `stack`; returns the iterator and the remaining stack.
    #[inline]
    pub(crate) fn new(
        input: impl Iterator<Item = u64>,
        base_log: usize,
        level: usize,
        stack: DynStack<'buffers>,
    ) -> (Self, DynStack<'buffers>) {
        // Keep only the `base_log * level` most significant (representable) bits.
        let shift = u64::BITS as usize - base_log * level;
        let (states, stack) =
            stack.collect_aligned(aligned_vec::CACHELINE_ALIGN, input.map(|i| i >> shift));
        (
            TensorSignedDecompositionLendingIter {
                base_log,
                current_level: level,
                mod_b_mask: (1_u64 << base_log) - 1_u64,
                states,
                fresh: true,
            },
            stack,
        )
    }
    // inlining this improves perf of external product by about 25%, even in LTO builds
    /// Yields `(level, base_log, digits)` for the highest remaining level, where `digits`
    /// lazily produces one signed digit per stored value, mutating the shared states as it
    /// is consumed (hence the lending `'short` borrow).
    #[inline]
    pub fn next_term<'short>(
        &'short mut self,
    ) -> Option<(
        usize,
        usize,
        Map<IterMut<'short, u64>, impl FnMut(&'short mut u64) -> u64>,
    )> {
        // The iterator is not fresh anymore.
        self.fresh = false;
        // We check if the decomposition is over
        if self.current_level == 0 {
            return None;
        }
        let current_level = self.current_level;
        // Copy the fields the closure needs, so it does not borrow `self`.
        let base_log = self.base_log;
        let mod_b_mask = self.mod_b_mask;
        self.current_level -= 1;
        Some((
            current_level,
            self.base_log,
            self.states
                .iter_mut()
                .map(move |state| decompose_one_level(base_log, state, mod_b_mask)),
        ))
    }
}
/// Extracts one balanced (centered around zero) digit from `state` for a decomposition in
/// base `B = 2^base_log`, updating `state` for the next level.
///
/// The raw digit is taken modulo `B`; when it should round up, a carry is propagated into
/// `state` and `B` is subtracted from the digit (wrapping), yielding its signed form.
#[inline]
fn decompose_one_level(base_log: usize, state: &mut u64, mod_b_mask: u64) -> u64 {
    // Raw digit in [0, B).
    let digit = *state & mod_b_mask;
    *state >>= base_log;
    // Branchless rounding decision: the carry is the high bit of the digit, except that a
    // digit of exactly B/2 with nothing else set rounds down.
    let carry = ((digit.wrapping_sub(1) | *state) & digit) >> (base_log - 1);
    *state += carry;
    // A set carry turns the digit into `digit - B`.
    digit.wrapping_sub(carry << base_log)
}

View File

@@ -0,0 +1,691 @@
use super::{
decomposition::DecompositionTerm,
fpks::LweKeyBitDecomposition,
from_torus,
polynomial::{update_with_wrapping_add_mul, update_with_wrapping_sub_mul},
types::*,
zip_eq,
};
use core::slice;
use rayon::{
prelude::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
};
use std::cmp::Ordering;
/// Random bytes needed to draw one uniform mask coefficient (one `u64`).
pub fn mask_bytes_per_coef() -> usize {
    core::mem::size_of::<u64>()
}
/// Uniform bytes for every coefficient of a polynomial of the given size.
pub fn mask_bytes_per_polynomial(polynomial_size: usize) -> usize {
    mask_bytes_per_coef() * polynomial_size
}
/// Uniform bytes for the mask of one GLWE ciphertext (`dimension` polynomials).
pub fn mask_bytes_per_glwe(glwe_params: GlweParams) -> usize {
    mask_bytes_per_polynomial(glwe_params.polynomial_size) * glwe_params.dimension
}
/// Uniform bytes for one decomposition level of a GGSW ciphertext
/// (`dimension + 1` GLWE ciphertexts).
pub fn mask_bytes_per_ggsw_level(glwe_params: GlweParams) -> usize {
    mask_bytes_per_glwe(glwe_params) * (glwe_params.dimension + 1)
}
/// Uniform bytes for the mask of one LWE ciphertext.
pub fn mask_bytes_per_lwe(lwe_dimension: usize) -> usize {
    mask_bytes_per_coef() * lwe_dimension
}
/// Uniform bytes for one decomposition level of a GSW ciphertext
/// (`lwe_dimension + 1` LWE ciphertexts).
pub fn mask_bytes_per_gsw_level(lwe_dimension: usize) -> usize {
    mask_bytes_per_lwe(lwe_dimension) * (lwe_dimension + 1)
}
/// Uniform bytes for a full GGSW ciphertext (all decomposition levels).
pub fn mask_bytes_per_ggsw(decomposition_level_count: usize, glwe_params: GlweParams) -> usize {
    mask_bytes_per_ggsw_level(glwe_params) * decomposition_level_count
}
/// Uniform bytes for one chunk (one key bit) of a PFPKSK.
pub fn mask_bytes_per_pfpksk_chunk(
    decomposition_level_count: usize,
    glwe_params: GlweParams,
) -> usize {
    mask_bytes_per_glwe(glwe_params) * decomposition_level_count
}
/// Uniform bytes for a full PFPKSK (`lwe_dimension + 1` chunks).
pub fn mask_bytes_per_pfpksk(
    decomposition_level_count: usize,
    glwe_params: GlweParams,
    lwe_dimension: usize,
) -> usize {
    mask_bytes_per_pfpksk_chunk(decomposition_level_count, glwe_params) * (lwe_dimension + 1)
}
/// Random bytes budgeted to draw one gaussian noise coefficient.
pub fn noise_bytes_per_coef() -> usize {
    // Noise is sampled with the polar form of the Box-Muller algorithm, which accepts a
    // candidate pair of uniform values with probability pi/4 and rejects it otherwise. On
    // average ~4/pi uniform pairs are thus needed per pair of gaussian values; budgeting 32
    // u64-sized inputs (far more than 2 * 4/pi) per pair leaves a comfortable safety margin.
    8 * 32
}
/// Gaussian-noise bytes for every coefficient of a polynomial of the given size.
pub fn noise_bytes_per_polynomial(polynomial_size: usize) -> usize {
    noise_bytes_per_coef() * polynomial_size
}
/// Gaussian-noise bytes for one GLWE ciphertext (one polynomial's worth).
pub fn noise_bytes_per_glwe(polynomial_size: usize) -> usize {
    noise_bytes_per_polynomial(polynomial_size)
}
/// Gaussian-noise bytes for one decomposition level of a GGSW ciphertext.
pub fn noise_bytes_per_ggsw_level(glwe_params: GlweParams) -> usize {
    noise_bytes_per_glwe(glwe_params.polynomial_size) * (glwe_params.dimension + 1)
}
/// Gaussian-noise bytes for one LWE ciphertext.
pub fn noise_bytes_per_lwe() -> usize {
    // Here we take 3 to keep a safety margin
    noise_bytes_per_coef() * 3
}
/// Gaussian-noise bytes for one decomposition level of a GSW ciphertext.
pub fn noise_bytes_per_gsw_level(lwe_dimension: usize) -> usize {
    noise_bytes_per_lwe() * (lwe_dimension + 1)
}
/// Gaussian-noise bytes for a full GGSW ciphertext (all decomposition levels).
pub fn noise_bytes_per_ggsw(decomposition_level_count: usize, glwe_params: GlweParams) -> usize {
    noise_bytes_per_ggsw_level(glwe_params) * decomposition_level_count
}
/// Gaussian-noise bytes for one chunk (one key bit) of a PFPKSK.
pub fn noise_bytes_per_pfpksk_chunk(
    decomposition_level_count: usize,
    polynomial_size: usize,
) -> usize {
    noise_bytes_per_glwe(polynomial_size) * decomposition_level_count
}
/// Gaussian-noise bytes for a full PFPKSK (`lwe_dimension + 1` chunks).
pub fn noise_bytes_per_pfpksk(
    decomposition_level_count: usize,
    polynomial_size: usize,
    lwe_dimension: usize,
) -> usize {
    noise_bytes_per_pfpksk_chunk(decomposition_level_count, polynomial_size) * (lwe_dimension + 1)
}
/// Fills `buffer` with uniformly random `u64`s drawn from `csprng`.
///
/// Values are assembled from the byte stream in little-endian order, so the sequence of
/// integers produced from a given csprng state is the same on little- and big-endian
/// targets.
///
/// # Panics
/// Panics if the csprng yields fewer bytes than requested.
pub fn fill_with_random_uniform(buffer: &mut [u64], mut csprng: CsprngMut<'_, '_>) {
    #[cfg(target_endian = "little")]
    {
        // Fast path: view the u64 buffer as raw bytes and fill it with a single call.
        let len = buffer.len() * core::mem::size_of::<u64>();
        let random_bytes = csprng
            .as_mut()
            .next_bytes(unsafe { slice::from_raw_parts_mut(buffer.as_mut_ptr() as _, len) });
        assert_eq!(len, random_bytes);
    }
    #[cfg(target_endian = "big")]
    {
        // Portable path: fill one u64 at a time, converting from little-endian bytes so the
        // output matches the little-endian fast path.
        let mut little_endian = [0u8; core::mem::size_of::<u64>()];
        for e in buffer {
            let random_bytes = csprng.as_mut().next_bytes(&mut little_endian);
            assert_eq!(little_endian.len(), random_bytes);
            *e = u64::from_le_bytes(little_endian);
        }
    }
}
/// Draws one pair of independent zero-mean gaussian samples with the given `variance`,
/// using the polar (rejection) form of the Box-Muller transform.
fn random_gaussian_pair(variance: f64, mut csprng: CsprngMut<'_, '_>) -> (f64, f64) {
    loop {
        // Two uniform samples mapped into (-1, 1).
        let mut words = [0_u64, 0_u64];
        fill_with_random_uniform(&mut words, csprng.as_mut());
        let [x, y] = words.map(|w| (w as i64 as f64) * 2.0_f64.powi(1 - u64::BITS as i32));
        let norm2 = x * x + y * y;
        // Accept only points strictly inside the unit disk, excluding the origin.
        if norm2 > 0.0 && norm2 < 1.0 {
            let factor = (-2.0 * variance * norm2.ln() / norm2).sqrt();
            return (x * factor, y * factor);
        }
    }
}
/// Fills `buffer` with gaussian noise of the given `variance`, converting each `f64`
/// sample to its `u64` torus representation via `from_torus`.
///
/// Samples are drawn in pairs (polar Box-Muller); for an odd-length buffer the last slot
/// receives the first element of a final pair.
pub fn fill_with_random_gaussian(buffer: &mut [u64], variance: f64, mut csprng: CsprngMut<'_, '_>) {
    // `chunks_mut` (not `chunks_exact_mut`): `chunks_exact_mut(2)` silently skips the
    // trailing element of an odd-length buffer, leaving it without noise, whereas the
    // `get_mut(0)`/`get_mut(1)` guards below are written precisely to handle a final
    // one-element chunk.
    for chunk in buffer.chunks_mut(2) {
        let (g0, g1) = random_gaussian_pair(variance, csprng.as_mut());
        if let Some(first) = chunk.get_mut(0) {
            *first = from_torus(g0);
        }
        if let Some(second) = chunk.get_mut(1) {
            *second = from_torus(g1);
        }
    }
}
impl BootstrapKey<&mut [u64]> {
    /// Fills this bootstrap key: each coefficient of `lwe_sk` is encrypted as a constant
    /// GGSW ciphertext under `glwe_sk`.
    pub fn fill_with_new_key(
        self,
        lwe_sk: LweSecretKey<&[u64]>,
        glwe_sk: GlweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        for (mut ggsw, sk_scalar) in
            zip_eq(self.into_ggsw_iter(), lwe_sk.into_data().iter().copied())
        {
            let encoded = sk_scalar;
            // Random material first, then the deterministic encryption step on top of it.
            glwe_sk.gen_noise_ggsw(ggsw.as_mut_view(), variance, csprng.as_mut());
            glwe_sk.encrypt_constant_ggsw_noise_full(ggsw, encoded);
        }
    }
}
#[cfg(feature = "parallel")]
impl BootstrapKey<&mut [u64]> {
    /// Parallel variant of `fill_with_new_key`: noise is drawn sequentially from the single
    /// csprng stream first, then the deterministic encryption of each GGSW runs in parallel.
    pub fn fill_with_new_key_par(
        mut self,
        lwe_sk: LweSecretKey<&[u64]>,
        glwe_sk: GlweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        // Sequential pass: consume randomness in a deterministic order.
        for ggsw in self.as_mut_view().into_ggsw_iter() {
            glwe_sk.gen_noise_ggsw(ggsw, variance, csprng.as_mut());
        }
        // Parallel pass: purely deterministic, no shared csprng needed.
        self.into_ggsw_iter_par()
            .zip_eq(lwe_sk.data)
            .for_each(|(ggsw, sk_scalar)| {
                let encoded = *sk_scalar;
                glwe_sk.encrypt_constant_ggsw_noise_full(ggsw, encoded);
            });
    }
}
impl LweKeyswitchKey<&mut [u64]> {
    /// Fills this keyswitch key: each coefficient of `input_key` is encrypted under
    /// `output_key` once per decomposition level, scaled by that level's weight
    /// (`bit << (64 - base_log * level)`).
    pub fn fill_with_keyswitch_key(
        self,
        input_key: LweSecretKey<&[u64]>,
        output_key: LweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        let decomposition_level_count = self.decomp_params.level;
        let decomposition_base_log = self.decomp_params.base_log;
        // loop over the before key blocks
        for (input_key_bit, keyswitch_key_block) in zip_eq(
            input_key.into_data().iter().copied(),
            self.into_lev_ciphertexts(),
        ) {
            // we encrypt the buffer
            for (lwe, message) in zip_eq(
                keyswitch_key_block.into_ciphertext_iter(),
                (1..(decomposition_level_count + 1)).map(|level| {
                    // Weight of this decomposition level applied to the key coefficient.
                    let shift = u64::BITS as usize - decomposition_base_log * level;
                    input_key_bit << shift
                }),
            ) {
                output_key.encrypt_lwe(lwe, message, variance, csprng.as_mut());
            }
        }
    }
}
impl PackingKeyswitchKey<&mut [u64]> {
    /// Fills this packing keyswitch key: each coefficient of `input_lwe_key` is encrypted,
    /// once per decomposition level, into the constant coefficient of a GLWE ciphertext
    /// under `output_glwe_key`.
    pub fn fill_with_packing_keyswitch_key(
        self,
        input_lwe_key: LweSecretKey<&[u64]>,
        output_glwe_key: GlweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        let decomposition_level_count = self.decomp_params.level;
        let decomposition_base_log = self.decomp_params.base_log;
        for (input_key_bit, keyswitch_key_block) in zip_eq(
            input_lwe_key.into_data().iter(),
            self.into_glev_ciphertexts(),
        ) {
            for (mut glwe, message) in zip_eq(
                keyswitch_key_block.into_ciphertext_iter(),
                (1..(decomposition_level_count + 1)).map(|level| {
                    // Weight of this decomposition level applied to the key coefficient.
                    let shift = u64::BITS as usize - decomposition_base_log * level;
                    input_key_bit << shift
                }),
            ) {
                output_glwe_key.encrypt_zero_glwe(glwe.as_mut_view(), variance, csprng.as_mut());
                let (_, body) = glwe.into_mask_and_body();
                let body = body.into_data();
                // The message lands only in the constant coefficient of the body polynomial.
                let first = body.first_mut().unwrap();
                *first = first.wrapping_add(message);
            }
        }
    }
    /// Fills this key as a private functional packing keyswitch key for the function `f`
    /// and the given `polynomial`.
    pub fn fill_with_private_functional_packing_keyswitch_key(
        &mut self,
        input_lwe_key: &LweSecretKey<&[u64]>,
        output_glwe_key: &GlweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut,
        f: impl Fn(u64) -> u64,
        polynomial: &[u64],
    ) {
        // We instantiate a buffer
        let mut messages = vec![0_u64; self.decomp_params.level * self.glwe_params.polynomial_size];
        // We retrieve decomposition arguments
        let decomp_level_count = self.decomp_params.level;
        let decomp_base_log = self.decomp_params.base_log;
        let polynomial_size = self.glwe_params.polynomial_size;
        // add minus one for the function which will be applied to the decomposed body
        // ( Scalar::MAX = -Scalar::ONE )
        let input_key_bit_iter = input_lwe_key.data.iter().chain(std::iter::once(&u64::MAX));
        // loop over the before key blocks
        for (&input_key_bit, keyswitch_key_block) in
            zip_eq(input_key_bit_iter, self.bit_decomp_iter_mut())
        {
            // We reset the buffer
            messages.fill(0);
            // We fill the buffer with the powers of the key bits
            for (level, message) in zip_eq(
                1..=decomp_level_count,
                messages.chunks_exact_mut(polynomial_size),
            ) {
                // Per-level weight: f(1) * key_bit placed at the level's position.
                let multiplier = DecompositionTerm::new(
                    level,
                    decomp_base_log,
                    f(1).wrapping_mul(input_key_bit),
                )
                .to_recomposition_summand();
                for (self_i, other_i) in zip_eq(message, polynomial) {
                    *self_i = (*self_i).wrapping_add(other_i.wrapping_mul(multiplier));
                }
            }
            // We encrypt the buffer
            for (mut glwe, message) in zip_eq(
                keyswitch_key_block.into_glwe_list().into_glwe_iter(),
                messages.chunks_exact(polynomial_size),
            ) {
                output_glwe_key.encrypt_zero_glwe(glwe.as_mut_view(), variance, csprng.as_mut());
                let (_, body) = glwe.into_mask_and_body();
                for (r, e) in zip_eq(body.into_data().iter_mut(), message) {
                    *r = r.wrapping_add(*e)
                }
            }
        }
    }
}
impl<'a> PackingKeyswitchKey<&'a mut [u64]> {
    /// Parallel variant of `fill_with_private_functional_packing_keyswitch_key`: noise is
    /// drawn sequentially from the csprng first, then the deterministic encryption of each
    /// key-bit block runs in parallel.
    pub fn fill_with_private_functional_packing_keyswitch_key_par(
        &'a mut self,
        input_lwe_key: &LweSecretKey<&[u64]>,
        output_glwe_key: &GlweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut,
        f: impl Sync + Fn(u64) -> u64,
        polynomial: &[u64],
    ) {
        // We retrieve decomposition arguments
        let decomp_level_count = self.decomp_params.level;
        let decomp_base_log = self.decomp_params.base_log;
        let polynomial_size = self.glwe_params.polynomial_size;
        // loop over the before key blocks
        for keyswitch_key_block in self.bit_decomp_iter_mut() {
            // We encrypt the buffer
            for mut glwe in keyswitch_key_block.into_glwe_list().into_glwe_iter() {
                output_glwe_key.gen_noise_glwe(glwe.as_mut_view(), variance, csprng.as_mut());
            }
        }
        let input_dimension = self.input_dimension;
        // loop over the before key blocks
        self.bit_decomp_iter_mut_par()
            .enumerate()
            .for_each(|(i, keyswitch_key_block)| {
                // add minus one for the function which will be applied to the decomposed body
                // ( Scalar::MAX = -Scalar::ONE )
                let input_key_bit = match i.cmp(&input_dimension) {
                    Ordering::Less => input_lwe_key.data[i],
                    Ordering::Equal => u64::MAX,
                    Ordering::Greater => unreachable!(),
                };
                // We instantiate a buffer
                let mut messages = vec![0_u64; decomp_level_count * polynomial_size];
                // We reset the buffer
                messages.fill(0);
                // We fill the buffer with the powers of the key bits
                for (level, message) in zip_eq(
                    1..=decomp_level_count,
                    messages.chunks_exact_mut(polynomial_size),
                ) {
                    let multiplier = DecompositionTerm::new(
                        level,
                        decomp_base_log,
                        f(1).wrapping_mul(input_key_bit),
                    )
                    .to_recomposition_summand();
                    for (self_i, other_i) in zip_eq(message, polynomial) {
                        *self_i = (*self_i).wrapping_add(other_i.wrapping_mul(multiplier));
                    }
                }
                // We encrypt the buffer
                for (mut glwe, message) in zip_eq(
                    keyswitch_key_block.into_glwe_list().into_glwe_iter(),
                    messages.chunks_exact(polynomial_size),
                ) {
                    // Noise was generated in the sequential pass; only the deterministic
                    // part remains here.
                    output_glwe_key.encrypt_zero_glwe_noise_full(glwe.as_mut_view());
                    let (_, body) = glwe.into_mask_and_body();
                    for (r, e) in zip_eq(body.into_data().iter_mut(), message) {
                        *r = r.wrapping_add(*e)
                    }
                }
            });
    }
    /// Iterates mutably over the per-key-bit decomposition blocks of this key.
    pub fn bit_decomp_iter_mut(
        &mut self,
    ) -> impl Iterator<Item = LweKeyBitDecomposition<&mut [u64]>> {
        let glwe_params = self.glwe_params;
        let level = self.decomp_params.level;
        // One block per key bit: `level` GLWE ciphertexts of `dimension + 1` polynomials each.
        let chunks_size = level * (glwe_params.dimension + 1) * glwe_params.polynomial_size;
        self.as_mut_view()
            .into_data()
            .chunks_exact_mut(chunks_size)
            .map(move |sub| LweKeyBitDecomposition::new(sub, glwe_params, level))
    }
}
#[cfg(feature = "parallel")]
impl<'a> PackingKeyswitchKey<&'a mut [u64]> {
    /// Parallel (rayon) counterpart of `bit_decomp_iter_mut`, chunking the data with the
    /// same per-key-bit block layout.
    pub fn bit_decomp_iter_mut_par(
        &'a mut self,
    ) -> impl 'a + IndexedParallelIterator<Item = LweKeyBitDecomposition<&'a mut [u64]>> {
        let glwe_params = self.glwe_params;
        let level = self.decomp_params.level;
        let chunks_size = level * (glwe_params.dimension + 1) * glwe_params.polynomial_size;
        self.as_mut_view()
            .into_data()
            .par_chunks_exact_mut(chunks_size)
            .map(move |sub| LweKeyBitDecomposition::new(sub, glwe_params, level))
    }
}
impl PackingKeyswitchKey<&[u64]> {
    /// Iterates over the per-key-bit decomposition blocks of this key (shared view),
    /// using the same block layout as `bit_decomp_iter_mut`.
    pub fn bit_decomp_iter(&self) -> impl Iterator<Item = LweKeyBitDecomposition<&[u64]>> {
        let glwe_params = self.glwe_params;
        let level = self.decomp_params.level;
        // Reuse the `level` local (same value, same naming as the mutable iterators)
        // instead of re-reading `self.decomp_params.level` under a different name.
        let chunks_size = level * (glwe_params.dimension + 1) * glwe_params.polynomial_size;
        self.data
            .chunks_exact(chunks_size)
            .map(move |sub| LweKeyBitDecomposition::new(sub, glwe_params, level))
    }
}
impl GlweSecretKey<&[u64]> {
    /// Encrypts the constant plaintext `encoded` as a GGSW ciphertext, drawing fresh
    /// randomness for every row.
    pub fn encrypt_constant_ggsw(
        self,
        ggsw: GgswCiphertext<&mut [u64]>,
        encoded: u64,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        let base_log = ggsw.decomp_params.base_log;
        let glwe_params = ggsw.glwe_params;
        for matrix in ggsw.into_level_matrices_iter() {
            // -encoded shifted to this level's weight, i.e. -encoded * q / B^level.
            let factor = encoded.wrapping_neg()
                << (u64::BITS as usize - (base_log * matrix.decomposition_level));
            let last_row_index = matrix.glwe_params.dimension;
            for (row_index, row) in matrix.into_rows_iter().enumerate() {
                self.encrypt_constant_ggsw_row(
                    (row_index, last_row_index),
                    factor,
                    GlweCiphertext::new(row.into_data(), glwe_params),
                    variance,
                    csprng.as_mut(),
                );
            }
        }
    }
    /// Encrypts one row of a GGSW level matrix: non-final rows add `factor * sk_poly` to
    /// the body, the final row adds `-factor` to the body's constant coefficient.
    pub fn encrypt_constant_ggsw_row(
        self,
        (row_index, last_row_index): (usize, usize),
        factor: u64,
        mut row: GlweCiphertext<&mut [u64]>,
        variance: f64,
        csprng: CsprngMut<'_, '_>,
    ) {
        if row_index < last_row_index {
            // Not the last row
            let sk_poly = self.get_polynomial(row_index);
            let encoded = sk_poly.iter().map(|&e| e.wrapping_mul(factor));
            self.encrypt_zero_glwe(row.as_mut_view(), variance, csprng);
            let (_, body) = row.into_mask_and_body();
            for (r, e) in zip_eq(body.into_data().iter_mut(), encoded) {
                *r = r.wrapping_add(e)
            }
        } else {
            // The last row needs a slightly different treatment
            self.encrypt_zero_glwe(row.as_mut_view(), variance, csprng);
            let (_, body) = row.into_mask_and_body();
            let first = body.into_data().first_mut().unwrap();
            *first = first.wrapping_add(factor.wrapping_neg());
        }
    }
    /// Encrypts zero: uniform random mask, gaussian body noise, then `body += <mask, sk>`
    /// (polynomial multiply-accumulate per mask polynomial).
    pub fn encrypt_zero_glwe(
        self,
        encrypted: GlweCiphertext<&mut [u64]>,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        let (mut mask, mut body) = encrypted.into_mask_and_body();
        fill_with_random_uniform(mask.as_mut_view().into_data(), csprng.as_mut());
        fill_with_random_gaussian(body.as_mut_view().into_data(), variance, csprng);
        let mask = mask.as_view();
        let body = body.into_data();
        for idx in 0..mask.glwe_params.dimension {
            let poly = mask.get_polynomial(idx);
            let bin_poly = self.get_polynomial(idx);
            update_with_wrapping_add_mul(body, poly, bin_poly)
        }
    }
    /// Decrypts a GLWE ciphertext: returns `body - <mask, sk>` (plaintext plus noise).
    pub fn decrypt_glwe(self, encrypted: GlweCiphertext<&[u64]>) -> Vec<u64> {
        let (mask, body) = encrypted.into_mask_and_body();
        let mask = mask.as_view();
        let mut out = body.into_data().to_owned();
        for idx in 0..mask.glwe_params.dimension {
            let poly = mask.get_polynomial(idx);
            let bin_poly = self.get_polynomial(idx);
            update_with_wrapping_sub_mul(&mut out, poly, bin_poly)
        }
        out
    }
    /// Draws the random material (uniform masks, gaussian bodies) for a whole GGSW
    /// ciphertext, without the deterministic encryption step.
    pub fn gen_noise_ggsw(
        self,
        ggsw: GgswCiphertext<&mut [u64]>,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        let glwe_params = ggsw.glwe_params;
        for matrix in ggsw.into_level_matrices_iter() {
            for row in matrix.into_rows_iter() {
                self.gen_noise_glwe(
                    GlweCiphertext::new(row.into_data(), glwe_params),
                    variance,
                    csprng.as_mut(),
                );
            }
        }
    }
    /// Draws the random material for one GLWE ciphertext: uniform mask, gaussian body.
    pub fn gen_noise_glwe(
        self,
        encrypted: GlweCiphertext<&mut [u64]>,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        let (mut mask, mut body) = encrypted.into_mask_and_body();
        fill_with_random_uniform(mask.as_mut_view().into_data(), csprng.as_mut());
        fill_with_random_gaussian(body.as_mut_view().into_data(), variance, csprng);
    }
    /// Deterministic part of `encrypt_constant_ggsw`; assumes `gen_noise_ggsw` has already
    /// filled the ciphertext with random material.
    pub fn encrypt_constant_ggsw_noise_full(self, ggsw: GgswCiphertext<&mut [u64]>, encoded: u64) {
        let base_log = ggsw.decomp_params.base_log;
        let glwe_params = ggsw.glwe_params;
        for matrix in ggsw.into_level_matrices_iter() {
            // -encoded shifted to this level's weight, i.e. -encoded * q / B^level.
            let factor = encoded.wrapping_neg()
                << (u64::BITS as usize - (base_log * matrix.decomposition_level));
            let last_row_index = matrix.glwe_params.dimension;
            for (row_index, row) in matrix.into_rows_iter().enumerate() {
                self.encrypt_constant_ggsw_row_noise_full(
                    (row_index, last_row_index),
                    factor,
                    GlweCiphertext::new(row.into_data(), glwe_params),
                );
            }
        }
    }
    /// Deterministic counterpart of `encrypt_constant_ggsw_row` (noise already present).
    pub fn encrypt_constant_ggsw_row_noise_full(
        self,
        (row_index, last_row_index): (usize, usize),
        factor: u64,
        mut row: GlweCiphertext<&mut [u64]>,
    ) {
        if row_index < last_row_index {
            // Not the last row
            let sk_poly = self.get_polynomial(row_index);
            let encoded = sk_poly.iter().map(|&e| e.wrapping_mul(factor));
            self.encrypt_zero_glwe_noise_full(row.as_mut_view());
            let (_, body) = row.into_mask_and_body();
            for (r, e) in zip_eq(body.into_data().iter_mut(), encoded) {
                *r = r.wrapping_add(e)
            }
        } else {
            // The last row needs a slightly different treatment
            self.encrypt_zero_glwe_noise_full(row.as_mut_view());
            let (_, body) = row.into_mask_and_body();
            let first = body.into_data().first_mut().unwrap();
            *first = first.wrapping_add(factor.wrapping_neg());
        }
    }
    /// Deterministic counterpart of `encrypt_zero_glwe`: adds `<mask, sk>` to the body,
    /// assuming the mask/body randomness is already in place.
    pub fn encrypt_zero_glwe_noise_full(self, encrypted: GlweCiphertext<&mut [u64]>) {
        let (mask, body) = encrypted.into_mask_and_body();
        let mask = mask.as_view();
        let body = body.into_data();
        for idx in 0..mask.glwe_params.dimension {
            let poly = mask.get_polynomial(idx);
            let bin_poly = self.get_polynomial(idx);
            update_with_wrapping_add_mul(body, poly, bin_poly)
        }
    }
}
impl LweSecretKey<&[u64]> {
    /// Encrypts `plaintext` into `encrypted` under this key.
    ///
    /// The last element of the ciphertext is the body, the rest is the uniformly random
    /// mask: `body = <mask, sk> + noise + plaintext`, all arithmetic wrapping mod 2^64.
    pub fn encrypt_lwe(
        self,
        encrypted: LweCiphertext<&mut [u64]>,
        plaintext: u64,
        variance: f64,
        mut csprng: CsprngMut<'_, '_>,
    ) {
        let (body, mask) = encrypted.into_data().split_last_mut().unwrap();
        fill_with_random_uniform(mask, csprng.as_mut());
        // Gaussian noise for the body (only the first element of the sampled pair is used).
        *body = from_torus(random_gaussian_pair(variance, csprng.as_mut()).0);
        // `wrapping_mul` (was a plain `*`): the dot product is modular arithmetic, and mask
        // coefficients are full-range u64, so an unchecked multiply would panic on overflow
        // in debug builds whenever a key coefficient exceeds 1.
        *body = body.wrapping_add(
            zip_eq(mask.iter().copied(), self.into_data().iter().copied())
                .fold(0_u64, |acc, (lhs, rhs)| acc.wrapping_add(lhs.wrapping_mul(rhs))),
        );
        *body = body.wrapping_add(plaintext);
    }
    /// Decrypts `encrypted`: returns `body - <mask, sk>` (wrapping), i.e. the plaintext
    /// plus the encryption noise.
    pub fn decrypt_lwe(self, encrypted: LweCiphertext<&[u64]>) -> u64 {
        let (body, mask) = encrypted.into_data().split_last().unwrap();
        body.wrapping_sub(
            zip_eq(mask.iter().copied(), self.into_data().iter().copied())
                .fold(0_u64, |acc, (lhs, rhs)| acc.wrapping_add(lhs.wrapping_mul(rhs))),
        )
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        c_api::types::tests::to_generic,
        implementation::types::{CsprngMut, LweCiphertext, LweSecretKey},
    };
    use concrete_csprng::{
        generators::{RandomGenerator, SoftwareRandomGenerator},
        seeders::Seed,
    };
    /// Encrypts `pt` under a fresh random key of dimension `dim`, decrypts it, and returns
    /// the decrypted value (plaintext plus encryption noise).
    fn encrypt_decrypt(
        mut csprng: CsprngMut,
        pt: u64,
        dim: usize,
        encryption_variance: f64,
    ) -> u64 {
        let mut ct = LweCiphertext::zero(dim);
        let sk = LweSecretKey::new_random(csprng.as_mut(), dim);
        sk.as_view()
            .encrypt_lwe(ct.as_mut_view(), pt, encryption_variance, csprng);
        sk.as_view().decrypt_lwe(ct.as_view())
    }
    /// Round-trips 100 random plaintexts and checks the decryption error stays small
    /// relative to the torus: |error| / 2^64 < 1e-4.
    #[test]
    fn encryption_decryption_correctness() {
        let mut csprng = SoftwareRandomGenerator::new(Seed(0));
        for _ in 0..100 {
            let a: u64 = u64::from_le_bytes(std::array::from_fn(|_| csprng.next().unwrap()));
            let b = encrypt_decrypt(to_generic(&mut csprng), a, 1024, 0.0000000001);
            // Signed wrapping difference: the wrap-around distance on the torus.
            let diff = b.wrapping_sub(a) as i64;
            assert!((diff as f64).abs() / 2.0_f64.powi(64) < 0.0001);
        }
    }
}

View File

@@ -0,0 +1,412 @@
use core::mem::MaybeUninit;
use aligned_vec::CACHELINE_ALIGN;
use concrete_fft::c64;
use dyn_stack::{DynArray, DynStack, ReborrowMut, SizeOverflow, StackReq};
use pulp::{as_arrays, as_arrays_mut};
use crate::implementation::{
assume_init_mut, decomposer::SignedDecomposer,
decomposition::TensorSignedDecompositionLendingIter, Split,
};
use super::{as_mut_uninit, fft::FftView, types::*, zip_eq};
impl GgswCiphertext<&mut [f64]> {
    /// Fills this Fourier-domain GGSW ciphertext with the forward FFT of
    /// `standard`, transforming one polynomial at a time.
    pub fn fill_with_forward_fourier(
        self,
        standard: GgswCiphertext<&[u64]>,
        fft: FftView<'_>,
        stack: DynStack<'_>,
    ) {
        let polynomial_size = standard.glwe_params.polynomial_size;
        let mut stack = stack;
        for (fourier_polynomial, standard_polynomial) in zip_eq(
            self.into_data().into_chunks(polynomial_size),
            standard.into_data().into_chunks(polynomial_size),
        ) {
            fft.forward_as_torus(
                // NOTE(review): viewing the output as uninitialized assumes
                // forward_as_torus fully overwrites it — confirm in its impl.
                unsafe { as_mut_uninit(fourier_polynomial) },
                standard_polynomial,
                stack.rb_mut(),
            );
        }
    }
    /// Returns the scratch memory needed by [`Self::fill_with_forward_fourier`].
    pub fn fill_with_forward_fourier_scratch(fft: FftView<'_>) -> Result<StackReq, SizeOverflow> {
        fft.forward_scratch()
    }
}
/// Returns the required memory for [`external_product`].
///
/// Mirrors the stack usage of `external_product`: a persistent Fourier
/// accumulator, plus the transient decomposition and per-polynomial buffers.
pub fn external_product_scratch(
    ggsw_glwe_params: GlweParams,
    fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
    let glwe_dimension = ggsw_glwe_params.dimension;
    let polynomial_size = ggsw_glwe_params.polynomial_size;
    let align = CACHELINE_ALIGN;
    // One standard-domain GLWE: (k + 1) polynomials of u64.
    let standard_scratch =
        StackReq::try_new_aligned::<u64>((glwe_dimension + 1) * polynomial_size, align)?;
    // One Fourier-domain GLWE: (k + 1) polynomials of f64.
    let fourier_scratch =
        StackReq::try_new_aligned::<f64>((glwe_dimension + 1) * polynomial_size, align)?;
    // A single Fourier polynomial (per-row forward transform buffer).
    let fourier_scratch_single = StackReq::try_new_aligned::<f64>(polynomial_size, align)?;
    // Nested requirements match the nesting of substacks in external_product.
    let substack3 = fft.forward_scratch()?;
    let substack2 = substack3.try_and(fourier_scratch_single)?;
    let substack1 = substack2.try_and(standard_scratch)?;
    // The decomposition phase and the backward FFT never run concurrently,
    // so the two alternatives share the same region.
    let substack0 = StackReq::try_any_of([
        substack1.try_and(standard_scratch)?,
        fft.backward_scratch()?,
    ])?;
    substack0.try_and(fourier_scratch)
}
/// Performs the external product of `ggsw` and `glwe`, and stores the result in `out`.
///
/// The GLWE mask/body polynomials are signed-decomposed, multiplied against
/// the GGSW level matrices in the Fourier domain, accumulated, and finally
/// transformed back and added into `out`.
pub fn external_product(
    mut out: GlweCiphertext<&mut [u64]>,
    ggsw: GgswCiphertext<&[f64]>,
    glwe: GlweCiphertext<&[u64]>,
    fft: FftView<'_>,
    stack: DynStack<'_>,
) {
    debug_assert_eq!(ggsw.glwe_params, glwe.glwe_params);
    debug_assert_eq!(ggsw.glwe_params, out.glwe_params);
    let align = CACHELINE_ALIGN;
    let polynomial_size = ggsw.glwe_params.polynomial_size;
    let decomposer = SignedDecomposer::new(ggsw.decomp_params);
    let (mut output_fft_buffer, mut substack0) =
        stack.make_aligned_uninit::<f64>(polynomial_size * (ggsw.glwe_params.dimension + 1), align);
    // output_fft_buffer is initially uninitialized, considered to be implicitly zero, to avoid
    // the cost of filling it up with zeros. `is_output_uninit` is set to `false` once
    // it has been fully initialized for the first time.
    let output_fft_buffer = &mut *output_fft_buffer;
    let mut is_output_uninit = true;
    {
        // ------------------------------------------------------ EXTERNAL PRODUCT IN FOURIER DOMAIN
        // In this section, we perform the external product in the fourier domain, and accumulate
        // the result in the output_fft_buffer variable.
        let (mut decomposition, mut substack1) = TensorSignedDecompositionLendingIter::new(
            glwe.into_data()
                .iter()
                .map(|s| decomposer.closest_representable(*s)),
            decomposer.decomp_params.base_log,
            decomposer.decomp_params.level,
            substack0.rb_mut(),
        );
        // We loop through the levels (we reverse to match the order of the decomposition iterator.)
        for ggsw_decomposition_matrix in ggsw.into_level_matrices_iter().rev() {
            // We retrieve the decomposition of this level.
            let (glwe_level, glwe_decomposition_term, mut substack2) =
                collect_next_term(&mut decomposition, &mut substack1, align);
            let glwe_decomposition_term =
                GlweCiphertext::new(&*glwe_decomposition_term, ggsw.glwe_params);
            debug_assert_eq!(ggsw_decomposition_matrix.decomposition_level, glwe_level);
            // For each level we have to add the result of the vector-matrix product between the
            // decomposition of the glwe, and the ggsw level matrix to the output. To do so, we
            // iteratively add to the output, the product between every line of the matrix, and
            // the corresponding (scalar) polynomial in the glwe decomposition:
            //
            //                ggsw_mat                        ggsw_mat
            //   glwe_dec   | - - - - | <        glwe_dec   | - - - - |
            //  | - - - | x | - - - - |         | - - - | x | - - - - | <
            //    ^         | - - - - |             ^       | - - - - |
            //
            //        t = 1                           t = 2                     ...
            for (ggsw_row, glwe_poly) in zip_eq(
                ggsw_decomposition_matrix.into_rows_iter(),
                glwe_decomposition_term
                    .into_data()
                    .into_chunks(polynomial_size),
            ) {
                let (mut fourier, substack3) = substack2
                    .rb_mut()
                    .make_aligned_uninit::<f64>(polynomial_size, align);
                // We perform the forward fft transform for the glwe polynomial
                fft.forward_as_integer(&mut fourier, glwe_poly, substack3);
                let fourier = unsafe { assume_init_mut(&mut fourier) };
                // Now we loop through the polynomials of the output, and add the
                // corresponding product of polynomials.
                // SAFETY: see comment above definition of `output_fft_buffer`
                unsafe {
                    update_with_fmadd(
                        output_fft_buffer,
                        ggsw_row,
                        fourier,
                        is_output_uninit,
                        polynomial_size,
                    )
                };
                // we initialized `output_fft_buffer`, so we can set this to false
                is_output_uninit = false;
            }
        }
    }
    // -------------------------------------------- TRANSFORMATION OF RESULT TO STANDARD DOMAIN
    // In this section, we bring the result from the fourier domain, back to the standard
    // domain, and add it to the output.
    //
    // We iterate over the polynomials in the output.
    if !is_output_uninit {
        // SAFETY: output_fft_buffer is initialized, since `is_output_uninit` is false
        let output_fft_buffer = &*unsafe { assume_init_mut(output_fft_buffer) };
        for (out, fourier) in zip_eq(
            out.as_mut_view().into_data().into_chunks(polynomial_size),
            output_fft_buffer.into_chunks(polynomial_size),
        ) {
            fft.add_backward_as_torus(out, fourier, substack0.rb_mut());
        }
    }
}
/// Pops the next decomposition term from the lending iterator and copies it
/// into aligned storage carved out of `substack1`.
///
/// Returns the term's decomposition level, the copied coefficients, and the
/// remaining stack for further allocations.
#[cfg_attr(__profiling, inline(never))]
fn collect_next_term<'a>(
    decomposition: &mut TensorSignedDecompositionLendingIter<'_>,
    substack1: &'a mut DynStack,
    align: usize,
) -> (usize, DynArray<'a, u64>, DynStack<'a>) {
    let (glwe_level, _, glwe_decomposition_term) = decomposition.next_term().unwrap();
    let (glwe_decomposition_term, substack2) = substack1
        .rb_mut()
        .collect_aligned(align, glwe_decomposition_term);
    (glwe_level, glwe_decomposition_term, substack2)
}
#[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
mod x86 {
    use crate::implementation::convert::x86::*;
    use core::mem::MaybeUninit;
    use pulp::{as_arrays, as_arrays_mut};
    /// AVX512 fused multiply-add of interleaved complex values.
    ///
    /// # Postconditions
    ///
    /// this function leaves all the elements of `output_fourier` in an initialized state.
    ///
    /// # Safety
    ///
    /// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
    #[cfg(feature = "nightly")]
    unsafe fn update_with_fmadd_avx512(
        simd: Avx512,
        output_fourier: &mut [MaybeUninit<f64>],
        ggsw_polynomial: &[f64],
        fourier: &[f64],
        is_output_uninit: bool,
    ) {
        use crate::implementation::assume_init_mut;
        let n = output_fourier.len();
        debug_assert_eq!(n, ggsw_polynomial.len());
        debug_assert_eq!(n, fourier.len());
        debug_assert_eq!(n % 8, 0);
        // 8×f64 => 4×c64: each 512-bit register holds four interleaved (re, im) pairs
        let (ggsw_polynomial, _) = as_arrays::<8, _>(ggsw_polynomial);
        let (fourier, _) = as_arrays::<8, _>(fourier);
        simd.vectorize(|| {
            let simd = simd.avx512f;
            if is_output_uninit {
                // First write: plain complex multiply, no accumulation needed.
                let (output_fourier, _) = as_arrays_mut::<8, _>(output_fourier);
                for (out, lhs, rhs) in izip!(output_fourier, ggsw_polynomial, fourier) {
                    // Complex product (a+bi)(x+yi) = (ax - by) + (ay + bx)i,
                    // realized with fmaddsub over shuffled lanes.
                    let ab = simd_cast(*lhs);
                    let xy = simd_cast(*rhs);
                    let aa = simd._mm512_unpacklo_pd(ab, ab);
                    let bb = simd._mm512_unpackhi_pd(ab, ab);
                    let yx = simd._mm512_permute_pd::<0b01010101>(xy);
                    *out = simd_cast(simd._mm512_fmaddsub_pd(aa, xy, simd._mm512_mul_pd(bb, yx)));
                }
            } else {
                // Accumulating write: the previous output is folded into the
                // inner fmaddsub so `*out += lhs * rhs` (complex).
                let (output_fourier, _) =
                    as_arrays_mut::<8, _>(unsafe { assume_init_mut(output_fourier) });
                for (out, lhs, rhs) in izip!(output_fourier, ggsw_polynomial, fourier) {
                    let ab = simd_cast(*lhs);
                    let xy = simd_cast(*rhs);
                    let aa = simd._mm512_unpacklo_pd(ab, ab);
                    let bb = simd._mm512_unpackhi_pd(ab, ab);
                    let yx = simd._mm512_permute_pd::<0b01010101>(xy);
                    *out = simd_cast(simd._mm512_fmaddsub_pd(
                        aa,
                        xy,
                        simd._mm512_fmaddsub_pd(bb, yx, simd_cast(*out)),
                    ));
                }
            }
        });
    }
    /// AVX2+FMA fused multiply-add of interleaved complex values.
    ///
    /// # Postconditions
    ///
    /// this function leaves all the elements of `output_fourier` in an initialized state.
    ///
    /// # Safety
    ///
    /// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
    unsafe fn update_with_fmadd_fma(
        simd: FusedMulAdd,
        output_fourier: &mut [MaybeUninit<f64>],
        ggsw_polynomial: &[f64],
        fourier: &[f64],
        is_output_uninit: bool,
    ) {
        use crate::implementation::assume_init_mut;
        let n = output_fourier.len();
        debug_assert_eq!(n, ggsw_polynomial.len());
        debug_assert_eq!(n, fourier.len());
        debug_assert_eq!(n % 4, 0);
        // 4×f64 => 2×c64: each 256-bit register holds two interleaved (re, im) pairs
        let (ggsw_polynomial, _) = as_arrays::<4, _>(ggsw_polynomial);
        let (fourier, _) = as_arrays::<4, _>(fourier);
        simd.vectorize(|| {
            let FusedMulAdd { avx, fma, .. } = simd;
            if is_output_uninit {
                // First write: plain complex multiply, no accumulation needed.
                let (output_fourier, _) = as_arrays_mut::<4, _>(output_fourier);
                for (out, lhs, rhs) in izip!(output_fourier, ggsw_polynomial, fourier) {
                    let ab = simd_cast(*lhs);
                    let xy = simd_cast(*rhs);
                    let aa = avx._mm256_unpacklo_pd(ab, ab);
                    let bb = avx._mm256_unpackhi_pd(ab, ab);
                    let yx = avx._mm256_permute_pd::<0b0101>(xy);
                    *out = simd_cast(fma._mm256_fmaddsub_pd(aa, xy, avx._mm256_mul_pd(bb, yx)));
                }
            } else {
                // Accumulating write: `*out += lhs * rhs` (complex).
                let (output_fourier, _) =
                    as_arrays_mut::<4, _>(unsafe { assume_init_mut(output_fourier) });
                for (out, lhs, rhs) in izip!(output_fourier, ggsw_polynomial, fourier) {
                    let ab = simd_cast(*lhs);
                    let xy = simd_cast(*rhs);
                    let aa = avx._mm256_unpacklo_pd(ab, ab);
                    let bb = avx._mm256_unpackhi_pd(ab, ab);
                    let yx = avx._mm256_permute_pd::<0b0101>(xy);
                    *out = simd_cast(fma._mm256_fmaddsub_pd(
                        aa,
                        xy,
                        fma._mm256_fmaddsub_pd(bb, yx, simd_cast(*out)),
                    ));
                }
            }
        });
    }
    /// Runtime dispatch: AVX512 (nightly) -> AVX2+FMA -> scalar fallback.
    ///
    /// # Safety
    ///
    /// Same contract as the per-ISA functions above.
    pub unsafe fn update_with_fmadd(
        output_fourier: &mut [MaybeUninit<f64>],
        ggsw_polynomial: &[f64],
        fourier: &[f64],
        is_output_uninit: bool,
    ) {
        #[cfg(feature = "nightly")]
        if let Some(simd) = Avx512::try_new() {
            return unsafe {
                update_with_fmadd_avx512(
                    simd,
                    output_fourier,
                    ggsw_polynomial,
                    fourier,
                    is_output_uninit,
                )
            };
        }
        if let Some(simd) = FusedMulAdd::try_new() {
            return unsafe {
                update_with_fmadd_fma(
                    simd,
                    output_fourier,
                    ggsw_polynomial,
                    fourier,
                    is_output_uninit,
                )
            };
        }
        unsafe {
            super::update_with_fmadd_scalar(
                output_fourier,
                ggsw_polynomial,
                fourier,
                is_output_uninit,
            )
        }
    }
}
/// Scalar fallback: complex fused multiply-add over interleaved (re, im) pairs.
///
/// # Postconditions
///
/// this function leaves all the elements of `output_fourier` in an initialized state.
///
/// # Safety
///
/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
unsafe fn update_with_fmadd_scalar(
    output_fourier: &mut [MaybeUninit<f64>],
    ggsw_polynomial: &[f64],
    fourier: &[f64],
    is_output_uninit: bool,
) {
    // Group pairs: even index = real part, odd index = imaginary part.
    let (output_fourier, _) = as_arrays_mut::<2, _>(output_fourier);
    let (ggsw_polynomial, _) = as_arrays::<2, _>(ggsw_polynomial);
    let (fourier, _) = as_arrays::<2, _>(fourier);
    if is_output_uninit {
        // we're writing to output_fft_buffer for the first time
        // so its contents are uninitialized
        for (out_fourier, lhs, rhs) in izip!(output_fourier, ggsw_polynomial, fourier) {
            let lhs = c64::new(lhs[0], lhs[1]);
            let rhs = c64::new(rhs[0], rhs[1]);
            let result = lhs * rhs;
            out_fourier[0].write(result.re);
            out_fourier[1].write(result.im);
        }
    } else {
        // we already wrote to output_fft_buffer, so we can assume its contents are
        // initialized.
        for (out_fourier, lhs, rhs) in izip!(output_fourier, ggsw_polynomial, fourier) {
            let lhs = c64::new(lhs[0], lhs[1]);
            let rhs = c64::new(rhs[0], rhs[1]);
            let result = lhs * rhs;
            // SAFETY: is_output_uninit == false, so both halves are initialized.
            *unsafe { out_fourier[0].assume_init_mut() } += result.re;
            *unsafe { out_fourier[1].assume_init_mut() } += result.im;
        }
    }
}
/// Accumulates, polynomial by polynomial, the complex product of `fourier`
/// with each polynomial of `ggsw_row` into `output_fft_buffer`, dispatching
/// to SIMD kernels on x86.
///
/// # Postconditions
///
/// this function leaves all the elements of `output_fourier` in an initialized state.
///
/// # Safety
///
/// - if `is_output_uninit` is false, `output_fourier` must not hold any uninitialized values.
#[cfg_attr(__profiling, inline(never))]
unsafe fn update_with_fmadd(
    output_fft_buffer: &mut [MaybeUninit<f64>],
    ggsw_row: GgswLevelRow<&[f64]>,
    fourier: &[f64],
    is_output_uninit: bool,
    polynomial_size: usize,
) {
    for (output_fourier, ggsw_poly) in zip_eq(
        output_fft_buffer.into_chunks(polynomial_size),
        ggsw_row.data.into_chunks(polynomial_size),
    ) {
        unsafe {
            #[cfg(any(target_arch = "x86_64", target_arch = "x86"))]
            x86::update_with_fmadd(output_fourier, ggsw_poly, fourier, is_output_uninit);
            #[cfg(not(any(target_arch = "x86_64", target_arch = "x86")))]
            update_with_fmadd_scalar(output_fourier, ggsw_poly, fourier, is_output_uninit);
        }
    }
}

View File

@@ -0,0 +1,157 @@
use alloc::sync::Arc;
use aligned_vec::{avec, ABox};
use concrete_fft::unordered::Plan;
use crate::implementation::zip_eq;
use super::Container;
/// Twisting factors from the paper:
/// [Fast and Error-Free Negacyclic Integer Convolution using Extended Fourier Transform][paper]
///
/// The real and imaginary parts form (the first `N/2`) `2N`-th roots of unity.
///
/// [paper]: https://eprint.iacr.org/2021/480
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct Twisties<C: Container<Item = f64>> {
    /// Real parts (cosines) of the twisting factors.
    pub re: C,
    /// Imaginary parts (sines) of the twisting factors.
    pub im: C,
}
impl<C: Container<Item = f64>> Twisties<C> {
    /// Returns a borrowed view over the twisting factors.
    pub fn as_view(&self) -> Twisties<&[f64]> {
        Twisties {
            re: self.re.as_ref(),
            im: self.im.as_ref(),
        }
    }
}
impl Twisties<ABox<[f64]>> {
    /// Creates a new [`Twisties`] containing the `2N`-th roots of unity with `n = N/2`.
    ///
    /// # Panics
    ///
    /// Panics (in debug builds) if `n` is not a power of two.
    pub fn new(n: usize) -> Self {
        debug_assert!(n.is_power_of_two());
        let mut re = avec![0.0; n].into_boxed_slice();
        let mut im = avec![0.0; n].into_boxed_slice();
        // Angle step between consecutive roots: pi / (2n) = 2*pi / (2N).
        let unit = core::f64::consts::PI / (2.0 * n as f64);
        for (i, (re, im)) in zip_eq(&mut *re, &mut *im).enumerate() {
            // sin_cos returns (sin, cos): im gets the sine, re the cosine.
            (*im, *re) = (i as f64 * unit).sin_cos();
        }
        Twisties { re, im }
    }
}
/// Negacyclic Fast Fourier Transform. See [`FftView`] for transform functions.
///
/// This structure contains the twisting factors as well as the
/// FFT plan needed for the negacyclic convolution over the reals.
#[derive(Clone, Debug)]
pub struct Fft {
    // Shared, immutable (twisties, plan) pair: cloning an `Fft` is cheap.
    plan: Arc<(Twisties<ABox<[f64]>>, Plan)>,
}
/// View type for [`Fft`].
#[derive(Clone, Copy, Debug)]
#[readonly::make]
pub struct FftView<'a> {
    pub plan: &'a Plan,
    pub twisties: Twisties<&'a [f64]>,
}
impl Fft {
    /// Returns a borrowed view suitable for performing transforms.
    #[inline]
    pub fn as_view(&self) -> FftView<'_> {
        FftView {
            plan: &self.plan.1,
            twisties: self.plan.0.as_view(),
        }
    }
}
#[cfg(feature = "std")]
mod std_only {
    use super::*;
    use concrete_fft::unordered::Method;
    use core::time::Duration;
    use once_cell::sync::OnceCell;
    use std::{
        collections::{hash_map::Entry, HashMap},
        sync::RwLock,
    };
    // Global cache mapping polynomial size -> lazily-built (twisties, plan).
    // The inner OnceCell lets the expensive plan measurement run outside the
    // map's write lock, so concurrent callers don't serialize on it.
    type PlanMap = RwLock<HashMap<usize, Arc<OnceCell<Arc<(Twisties<ABox<[f64]>>, Plan)>>>>>;
    static PLANS: OnceCell<PlanMap> = OnceCell::new();
    /// Returns the global plan cache, creating it on first use.
    fn plans() -> &'static PlanMap {
        PLANS.get_or_init(|| RwLock::new(HashMap::new()))
    }
    impl Fft {
        /// Real polynomial of size `size`.
        pub fn new(size: usize) -> Self {
            let global_plans = plans();
            let n = size;
            // Looks up the cache entry for `n` (if present) and initializes it
            // on first access; the FFT plan is measured for ~10ms.
            let get_plan = || {
                let plans = global_plans.read().unwrap();
                let plan = plans.get(&n).cloned();
                drop(plans);
                plan.map(|p| {
                    p.get_or_init(|| {
                        Arc::new((
                            Twisties::new(n / 2),
                            Plan::new(n / 2, Method::Measure(Duration::from_millis(10))),
                        ))
                    })
                    .clone()
                })
            };
            // Make sure the cache holds an entry for `n` (inserting an empty
            // OnceCell if needed), then fetch/initialize it. The unwrap cannot
            // fail: the entry is guaranteed to exist after the insertion above.
            let mut plans = global_plans.write().unwrap();
            if let Entry::Vacant(v) = plans.entry(n) {
                v.insert(Arc::new(OnceCell::new()));
            }
            drop(plans);
            Self {
                plan: get_plan().unwrap(),
            }
        }
    }
}
#[cfg(not(feature = "std"))]
mod no_std {
    use concrete_fft::{ordered::FftAlgo, unordered::Method};
    use super::*;
    impl Fft {
        /// Real polynomial of size `size`.
        ///
        /// Without `std` we cannot time candidate plans, so a fixed,
        /// user-provided FFT configuration is used instead of a measured one.
        pub fn new(size: usize) -> Self {
            // BUGFIX: `size` is a plain `usize`, not a newtype, so `size.0`
            // did not compile. Mirror the `std` version, which uses `size`
            // directly.
            let n = size;
            Self {
                plan: Arc::new((
                    Twisties::new(n / 2),
                    Plan::new(
                        n / 2,
                        Method::UserProvided {
                            base_algo: FftAlgo::Dif4,
                            base_n: 512,
                        },
                    ),
                )),
            }
        }
    }
}

View File

@@ -0,0 +1,91 @@
use super::{
decomposer::SignedDecomposer,
types::{GlweCiphertext, GlweParams, LweCiphertext, PackingKeyswitchKey},
wop::GlweCiphertextList,
zip_eq, Container,
};
impl PackingKeyswitchKey<&[u64]> {
    /// Private functional keyswitch of one LWE ciphertext into a GLWE
    /// ciphertext, overwriting `after` with the result.
    pub fn private_functional_keyswitch_ciphertext(
        &self,
        mut after: GlweCiphertext<&mut [u64]>,
        before: LweCiphertext<&[u64]>,
    ) {
        debug_assert_eq!(self.glwe_params, after.glwe_params);
        debug_assert_eq!(self.input_dimension, before.lwe_dimension);
        // We reset the output
        after.as_mut_view().into_data().fill_with(|| 0);
        // We instantiate a decomposer
        let decomposer = SignedDecomposer::new(self.decomp_params);
        // One key block per input LWE coefficient (mask and body alike).
        for (block, input_lwe) in
            zip_eq(self.bit_decomp_iter(), before.as_view().into_data().iter())
        {
            // We decompose
            let rounded = decomposer.closest_representable(*input_lwe);
            let decomp = decomposer.decompose(rounded);
            // Loop over the number of levels:
            // We compute the multiplication of a ciphertext from the private functional
            // keyswitching key with a piece of the decomposition and subtract it to the buffer
            for (level_key_cipher, decomposed) in zip_eq(
                block
                    .data
                    .chunks_exact(
                        (self.glwe_params.dimension + 1) * self.glwe_params.polynomial_size,
                    )
                    // NOTE(review): `rev` pairs stored levels with the order the
                    // decomposition iterator yields terms — presumably levels are
                    // stored in the opposite order; confirm against the key layout.
                    .rev(),
                decomp,
            ) {
                after
                    .as_mut_view()
                    .update_with_wrapping_sub_element_mul(level_key_cipher, decomposed.value());
            }
        }
    }
}
/// A list of `count` GLWE ciphertexts holding the decomposition of LWE key bits.
pub struct LweKeyBitDecomposition<C: Container> {
    /// Flat storage: `count` GLWE ciphertexts back to back.
    pub data: C,
    pub glwe_params: GlweParams,
    pub count: usize,
}
impl<C: Container> LweKeyBitDecomposition<C> {
    /// Wraps `data`; debug-asserts the length matches `count` GLWE ciphertexts.
    pub fn new(data: C, glwe_params: GlweParams, count: usize) -> Self {
        debug_assert_eq!(
            data.len(),
            (glwe_params.dimension + 1) * glwe_params.polynomial_size * count
        );
        LweKeyBitDecomposition {
            data,
            glwe_params,
            count,
        }
    }
    /// Reinterprets the same storage as a plain GLWE ciphertext list.
    pub fn into_glwe_list(self) -> GlweCiphertextList<C> {
        GlweCiphertextList {
            data: self.data,
            glwe_params: self.glwe_params,
            count: self.count,
        }
    }
}
impl LweKeyBitDecomposition<&[u64]> {
    /// Iterates over the `count` GLWE ciphertexts by shared reference.
    pub fn ciphertext_iter(&self) -> impl Iterator<Item = GlweCiphertext<&[u64]>> {
        self.data
            .chunks_exact((self.glwe_params.dimension + 1) * self.glwe_params.polynomial_size)
            .map(move |sub| GlweCiphertext::new(sub, self.glwe_params))
    }
}
impl LweKeyBitDecomposition<&mut [u64]> {
    /// Iterates over the `count` GLWE ciphertexts by mutable reference.
    pub fn ciphertext_iter_mut(&mut self) -> impl Iterator<Item = GlweCiphertext<&mut [u64]>> {
        // Copy the params out so the closure doesn't borrow `self`.
        let glwe_params = self.glwe_params;
        self.data
            .chunks_exact_mut((glwe_params.dimension + 1) * glwe_params.polynomial_size)
            .map(move |sub| GlweCiphertext::new(sub, glwe_params))
    }
}

View File

@@ -0,0 +1,132 @@
use super::{decomposer::SignedDecomposer, types::*, zip_eq};
impl LweKeyswitchKey<&[u64]> {
    /// Switches `before` (encrypted under the input key) to `after`
    /// (encrypted under the output key) using this keyswitch key.
    pub fn keyswitch_ciphertext(
        self,
        after: LweCiphertext<&mut [u64]>,
        before: LweCiphertext<&[u64]>,
    ) {
        let after = after.into_data();
        let before = before.into_data();
        // We reset the output
        after.fill(0);
        // We copy the body
        *after.last_mut().unwrap() = *before.last().unwrap();
        // We instantiate a decomposer
        let decomposer = SignedDecomposer::new(self.decomp_params);
        let mask_len = before.len() - 1;
        // One Lev ciphertext of the key per input mask coefficient.
        for (block, before_mask) in zip_eq(self.into_lev_ciphertexts(), &before[..mask_len]) {
            let mask_rounded = decomposer.closest_representable(*before_mask);
            let decomp = decomposer.decompose(mask_rounded);
            // loop over the number of levels
            for (level_key_cipher, decomposed) in zip_eq(
                // NOTE(review): `rev` pairs stored levels with the order the
                // decomposition iterator yields terms — confirm storage order.
                block.into_data().chunks(self.output_dimension + 1).rev(),
                decomp,
            ) {
                let val = decomposed.value();
                // after -= val * level_key_cipher, element-wise, wrapping.
                for (a, &b) in zip_eq(after.iter_mut(), level_key_cipher) {
                    *a = a.wrapping_sub(b.wrapping_mul(val))
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{c_api::types::tests::to_generic, implementation::types::*};
    use concrete_csprng::{
        generators::{RandomGenerator, SoftwareRandomGenerator},
        seeders::Seed,
    };
    /// Bundle of input/output secret keys and the keyswitch key between them.
    struct KeySet {
        in_dim: usize,
        out_dim: usize,
        in_sk: LweSecretKey<Vec<u64>>,
        out_sk: LweSecretKey<Vec<u64>>,
        ksk: LweKeyswitchKey<Vec<u64>>,
    }
    impl KeySet {
        /// Generates fresh random keys and fills the keyswitch key.
        fn new(
            mut csprng: CsprngMut,
            in_dim: usize,
            out_dim: usize,
            decomp_params: DecompParams,
            key_variance: f64,
        ) -> Self {
            let in_sk = LweSecretKey::new_random(csprng.as_mut(), in_dim);
            let out_sk = LweSecretKey::new_random(csprng.as_mut(), out_dim);
            // One Lev ciphertext (out_dim + 1 coeffs x levels) per input coeff.
            let ksk_len = (out_dim + 1) * in_dim * decomp_params.level;
            let mut ksk =
                LweKeyswitchKey::new(vec![0_u64; ksk_len], out_dim, in_dim, decomp_params);
            ksk.as_mut_view().fill_with_keyswitch_key(
                in_sk.as_view(),
                out_sk.as_view(),
                key_variance,
                csprng,
            );
            Self {
                in_dim,
                out_dim,
                in_sk,
                out_sk,
                ksk,
            }
        }
        /// Encrypts `pt` under the input key, keyswitches to the output key,
        /// and decrypts with the output key.
        fn keyswitch(&self, csprng: CsprngMut, pt: u64, encryption_variance: f64) -> u64 {
            let mut input = LweCiphertext::zero(self.in_dim);
            let mut output = LweCiphertext::zero(self.out_dim);
            self.in_sk
                .as_view()
                .encrypt_lwe(input.as_mut_view(), pt, encryption_variance, csprng);
            self.ksk
                .as_view()
                .keyswitch_ciphertext(output.as_mut_view(), input.as_view());
            self.out_sk.as_view().decrypt_lwe(output.as_view())
        }
    }
    /// A keyswitched plaintext must survive within 1% of the torus.
    #[test]
    fn keyswitch_correctness() {
        let mut csprng = SoftwareRandomGenerator::new(Seed(0));
        let keyset = KeySet::new(
            to_generic(&mut csprng),
            1024,
            600,
            DecompParams {
                level: 3,
                base_log: 10,
            },
            0.0000000000000001,
        );
        for _ in 0..100 {
            let input: u64 = u64::from_le_bytes(std::array::from_fn(|_| csprng.next().unwrap()));
            let output = keyset.keyswitch(to_generic(&mut csprng), input, 0.00000000001);
            // Signed torus difference measures noise added by the switch.
            let diff = output.wrapping_sub(input) as i64;
            assert!((diff as f64).abs() / 2.0_f64.powi(64) < 0.01);
        }
    }
}

View File

@@ -0,0 +1,272 @@
use core::mem::MaybeUninit;
use aligned_vec::{ABox, AVec, CACHELINE_ALIGN};
/// Zips up to 16 iterators into flat tuples, debug-asserting equal lengths.
#[allow(unused_macros)]
macro_rules! izip {
    // no one should need to zip more than 16 iterators, right?
    (@ __closure @ ($a:expr)) => { |a| (a,) };
    (@ __closure @ ($a:expr, $b:expr)) => { |(a, b)| (a, b) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr)) => { |((a, b), c)| (a, b, c) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr)) => { |(((a, b), c), d)| (a, b, c, d) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr)) => { |((((a, b), c), d), e)| (a, b, c, d, e) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr)) => { |(((((a, b), c), d), e), f)| (a, b, c, d, e, f) };
    // BUGFIX: the 7-ary arm previously flattened to `(a, b, c, d, e, f, e)`,
    // duplicating `e` and dropping the 7th iterator's item.
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr)) => { |((((((a, b), c), d), e), f), g)| (a, b, c, d, e, f, g) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr)) => { |(((((((a, b), c), d), e), f), g), h)| (a, b, c, d, e, f, g, h) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr)) => { |((((((((a, b), c), d), e), f), g), h), i)| (a, b, c, d, e, f, g, h, i) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr)) => { |(((((((((a, b), c), d), e), f), g), h), i), j)| (a, b, c, d, e, f, g, h, i, j) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr)) => { |((((((((((a, b), c), d), e), f), g), h), i), j), k)| (a, b, c, d, e, f, g, h, i, j, k) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr)) => { |(((((((((((a, b), c), d), e), f), g), h), i), j), k), l)| (a, b, c, d, e, f, g, h, i, j, k, l) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr)) => { |((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m)| (a, b, c, d, e, f, g, h, i, j, k, l, m) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr)) => { |(((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n) };
    (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr, $o:expr)) => { |((((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n), o)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) };
    ( $first:expr $(,)?) => {
        {
            #[allow(unused_imports)]
            use $crate::implementation::ZipEq;
            ::core::iter::IntoIterator::into_iter($first)
        }
    };
    ( $first:expr, $($rest:expr),+ $(,)?) => {
        {
            #[allow(unused_imports)]
            use $crate::implementation::ZipEq;
            ::core::iter::IntoIterator::into_iter($first)
                $(.zip_eq($rest))*
                .map(izip!(@ __closure @ ($first, $($rest),*)))
        }
    };
}
mod ciphertext;
mod convert;
mod decomposer;
mod decomposition;
pub mod fft;
pub mod fpks;
mod polynomial;
pub mod types;
pub mod bootstrap;
pub mod cmux;
pub mod encrypt;
pub mod external_product;
pub mod keyswitch;
pub mod wop;
/// Convert a mutable slice reference to an uninitialized mutable slice reference.
///
/// # Safety
///
/// No uninitialized values must be written into the output slice by the time the borrow ends
#[inline]
pub unsafe fn as_mut_uninit<T>(slice: &mut [T]) -> &mut [MaybeUninit<T>] {
    let n = slice.len();
    // SAFETY: `MaybeUninit<T>` has the same size, alignment and ABI as `T`,
    // so reinterpreting the element type keeps the slice layout intact.
    unsafe { core::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<MaybeUninit<T>>(), n) }
}
/// Convert an uninitialized mutable slice reference to an initialized mutable slice reference.
///
/// # Safety
///
/// All the elements of the input slice must be initialized and in a valid state.
#[inline]
pub unsafe fn assume_init_mut<T>(slice: &mut [MaybeUninit<T>]) -> &mut [T] {
    let n = slice.len();
    // SAFETY: the caller guarantees every element is initialized, and
    // `MaybeUninit<T>` shares its layout with `T`.
    unsafe { core::slice::from_raw_parts_mut(slice.as_mut_ptr().cast::<T>(), n) }
}
/// Debug-asserts that two iterator size hints are exact and equal.
#[inline]
fn debug_assert_same_len(a: (usize, Option<usize>), b: (usize, Option<usize>)) {
    let (a_len, a_upper) = a;
    let (b_len, b_upper) = b;
    // Each hint must be exact (upper bound present and equal to the lower)…
    debug_assert_eq!(a_upper, Some(a_len));
    debug_assert_eq!(b_upper, Some(b_len));
    // …and the two exact lengths must agree.
    debug_assert_eq!(a_len, b_len);
}
/// Returns a Zip iterator, but checks that the two components have the same length.
trait ZipEq: IntoIterator + Sized {
    /// Like [`Iterator::zip`], but debug-asserts that both sides report the
    /// same exact length before zipping.
    #[inline]
    fn zip_eq<B: IntoIterator>(
        self,
        b: B,
    ) -> core::iter::Zip<<Self as IntoIterator>::IntoIter, <B as IntoIterator>::IntoIter> {
        let a = self.into_iter();
        let b = b.into_iter();
        debug_assert_same_len(a.size_hint(), b.size_hint());
        core::iter::zip(a, b)
    }
}
/// Zips two iterables, debug-asserting that both advertise the same exact length.
pub fn zip_eq<T, U>(
    a: impl IntoIterator<Item = T>,
    b: impl IntoIterator<Item = U>,
) -> impl Iterator<Item = (T, U)> {
    let lhs = a.into_iter();
    let rhs = b.into_iter();
    let (lhs_len, lhs_upper) = lhs.size_hint();
    let (rhs_len, rhs_upper) = rhs.size_hint();
    // Both size hints must be exact, and the two lengths must agree.
    debug_assert_eq!(lhs_upper, Some(lhs_len));
    debug_assert_eq!(rhs_upper, Some(rhs_len));
    debug_assert_eq!(lhs_len, rhs_len);
    lhs.zip(rhs)
}
// Blanket implementation: every `IntoIterator` gets the `zip_eq` method.
impl<A: IntoIterator> ZipEq for A {}
/// A sized type that dereferences to a slice of `Item`s.
pub trait Container: Sized + AsRef<[Self::Item]> {
    type Item;
    /// Number of elements in the container.
    fn len(&self) -> usize {
        self.as_ref().len()
    }
}
/// A [`Container`] with mutable slice access.
pub trait ContainerMut: Container + AsMut<[Self::Item]> {}
/// A mutable [`Container`] that owns its storage and can be built from an iterator.
pub trait ContainerOwned: Container + AsMut<[Self::Item]> {
    fn collect(iter: impl Iterator<Item = Self::Item>) -> Self;
}
impl<'a, T> Container for &'a [T] {
    type Item = T;
}
impl<'a, T> Container for &'a mut [T] {
    type Item = T;
}
impl<'a, T> ContainerMut for &'a mut [T] {}
impl<T> Container for ABox<[T]> {
    type Item = T;
}
impl<T> ContainerMut for ABox<[T]> {}
impl<T> ContainerOwned for ABox<[T]> {
    fn collect(iter: impl Iterator<Item = Self::Item>) -> Self {
        // Aligned allocation so SIMD kernels can assume cacheline alignment.
        AVec::from_iter(CACHELINE_ALIGN, iter).into_boxed_slice()
    }
}
impl<T> Container for AVec<T> {
    type Item = T;
}
impl<T> ContainerMut for AVec<T> {}
impl<T> ContainerOwned for AVec<T> {
    fn collect(iter: impl Iterator<Item = Self::Item>) -> Self {
        AVec::from_iter(CACHELINE_ALIGN, iter)
    }
}
impl<T> Container for Vec<T> {
    type Item = T;
}
impl<T> ContainerMut for Vec<T> {}
impl<T> ContainerOwned for Vec<T> {
    fn collect(iter: impl Iterator<Item = Self::Item>) -> Self {
        iter.collect()
    }
}
/// Containers that can be split and chunked by value, preserving mutability.
pub trait Split: Container {
    /// Raw pointer type matching the container's mutability.
    type Pointer: Copy;
    /// Exact-size, double-ended chunk iterator over `Self`.
    type Chunks: DoubleEndedIterator<Item = Self> + ExactSizeIterator<Item = Self>;
    /// # Safety
    ///
    /// `data` must be valid for `len` elements for the lifetime of `Self`.
    unsafe fn from_raw_parts(data: Self::Pointer, len: usize) -> Self;
    fn split_at(self, mid: usize) -> (Self, Self);
    /// Returns the sub-container covering the half-open range `start..end`.
    fn chunk(self, start: usize, end: usize) -> Self {
        self.split_at(end).0.split_at(start).1
    }
    /// Splits into chunks of `chunk_size`; the length must divide evenly.
    fn into_chunks(self, chunk_size: usize) -> Self::Chunks;
    /// Splits into exactly `chunk_count` equally-sized chunks.
    fn split_into(self, chunk_count: usize) -> Self::Chunks {
        debug_assert_ne!(chunk_count, 0);
        let len = self.len();
        debug_assert_eq!(len % chunk_count, 0);
        self.into_chunks(len / chunk_count)
    }
}
impl<'a, T> Split for &'a [T] {
    type Pointer = *const T;
    type Chunks = core::slice::ChunksExact<'a, T>;
    /// # Safety
    ///
    /// See [`Split::from_raw_parts`].
    unsafe fn from_raw_parts(data: Self::Pointer, len: usize) -> Self {
        unsafe { core::slice::from_raw_parts(data, len) }
    }
    fn split_at(self, mid: usize) -> (Self, Self) {
        (*self).split_at(mid)
    }
    fn into_chunks(self, chunk_size: usize) -> Self::Chunks {
        debug_assert_ne!(chunk_size, 0);
        debug_assert_eq!(self.len() % chunk_size, 0);
        self.chunks_exact(chunk_size)
    }
}
impl<'a, T> Split for &'a mut [T] {
    type Pointer = *mut T;
    type Chunks = core::slice::ChunksExactMut<'a, T>;
    /// # Safety
    ///
    /// See [`Split::from_raw_parts`].
    unsafe fn from_raw_parts(data: Self::Pointer, len: usize) -> Self {
        unsafe { core::slice::from_raw_parts_mut(data, len) }
    }
    fn split_at(self, mid: usize) -> (Self, Self) {
        (*self).split_at_mut(mid)
    }
    fn into_chunks(self, chunk_size: usize) -> Self::Chunks {
        debug_assert_ne!(chunk_size, 0);
        debug_assert_eq!(self.len() % chunk_size, 0);
        self.chunks_exact_mut(chunk_size)
    }
}
#[cfg(feature = "parallel")]
pub mod parallel {
    use super::*;
    use rayon::prelude::*;
    /// Parallel counterpart of [`Split`] using rayon's indexed parallel iterators.
    pub trait ParSplit: Split + Send {
        /// Indexed parallel chunk iterator over `Self`.
        type ParChunks: IndexedParallelIterator<Item = Self>;
        fn into_par_chunks(self, chunk_size: usize) -> Self::ParChunks;
        /// Splits into exactly `chunk_count` equally-sized chunks.
        /// A `chunk_count` of zero yields an empty iterator rather than dividing by zero.
        fn par_split_into(self, chunk_count: usize) -> Self::ParChunks {
            if chunk_count == 0 {
                // Empty prefix => empty parallel iterator.
                self.split_at(0).0.into_par_chunks(1)
            } else {
                let len = self.len();
                debug_assert_eq!(len % chunk_count, 0);
                self.into_par_chunks(len / chunk_count)
            }
        }
    }
    impl<'a, T: Sync> ParSplit for &'a [T] {
        type ParChunks = rayon::slice::ChunksExact<'a, T>;
        fn into_par_chunks(self, chunk_size: usize) -> Self::ParChunks {
            self.par_chunks_exact(chunk_size)
        }
    }
    impl<'a, T: Send> ParSplit for &'a mut [T] {
        type ParChunks = rayon::slice::ChunksExactMut<'a, T>;
        fn into_par_chunks(self, chunk_size: usize) -> Self::ParChunks {
            self.par_chunks_exact_mut(chunk_size)
        }
    }
}
/// Maps a real torus value (interpreted modulo 1) to its closest
/// representative on the discretized 64-bit integer torus.
pub fn from_torus(input: f64) -> u64 {
    // Reduce to the centered fractional part, in [-0.5, 0.5].
    let centered = input - input.round();
    // Scale up to the 2^64-wide torus and round to the nearest integer.
    let scaled = (centered * 2.0_f64.powi(u64::BITS as i32)).round();
    // Cast through i64 so negative values wrap onto the upper half of u64.
    scaled as i64 as u64
}

View File

@@ -0,0 +1,83 @@
/// Divides the polynomial (mod `X^N + 1`) by the monomial `X^degree`, in place.
///
/// The division is a left rotation of the coefficients; the coefficients that
/// wrap around the boundary pick up a sign flip, and every full cycle of `N`
/// negates the whole polynomial once.
pub fn update_with_wrapping_unit_monomial_div(polynomial: &mut [u64], monomial_degree: usize) {
    let len = polynomial.len();
    // An odd number of full cycles negates every coefficient.
    if (monomial_degree / len) % 2 == 1 {
        polynomial.iter_mut().for_each(|c| *c = c.wrapping_neg());
    }
    let shift = monomial_degree % len;
    polynomial.rotate_left(shift);
    // The `shift` coefficients that wrapped around change sign.
    let tail_start = len - shift;
    for c in &mut polynomial[tail_start..] {
        *c = c.wrapping_neg();
    }
}
/// Multiplies the polynomial (mod `X^N + 1`) by the monomial `X^degree`, in place.
///
/// The multiplication is a right rotation of the coefficients; the
/// coefficients that wrap around the boundary pick up a sign flip, and every
/// full cycle of `N` negates the whole polynomial once.
pub fn update_with_wrapping_monic_monomial_mul(polynomial: &mut [u64], monomial_degree: usize) {
    let len = polynomial.len();
    // An odd number of full cycles negates every coefficient.
    if (monomial_degree / len) % 2 == 1 {
        polynomial.iter_mut().for_each(|c| *c = c.wrapping_neg());
    }
    let shift = monomial_degree % len;
    polynomial.rotate_right(shift);
    // The `shift` coefficients that wrapped around change sign.
    for c in &mut polynomial[..shift] {
        *c = c.wrapping_neg();
    }
}
/// Adds the negacyclic product `lhs_polynomial * rhs_bin_polynomial`
/// (mod `X^N + 1`) into `polynomial`, with wrapping arithmetic.
///
/// TODO: optimize performance, while keeping constant time, so as not to leak
/// information about the secret key.
pub fn update_with_wrapping_add_mul(
    polynomial: &mut [u64],
    lhs_polynomial: &[u64],
    rhs_bin_polynomial: &[u64],
) {
    debug_assert_eq!(polynomial.len(), lhs_polynomial.len());
    debug_assert_eq!(polynomial.len(), rhs_bin_polynomial.len());
    let dim = polynomial.len();
    // Schoolbook product: each coefficient pair contributes at degree i + j,
    // reduced mod X^N + 1 (wrap-around flips the sign).
    for (i, &lhs) in lhs_polynomial.iter().enumerate() {
        for (j, &rhs) in rhs_bin_polynomial.iter().enumerate() {
            let degree = i + j;
            let prod = lhs * rhs;
            let (idx, forward) = if degree < dim {
                (degree, true)
            } else {
                (degree - dim, false)
            };
            polynomial[idx] = if forward {
                polynomial[idx].wrapping_add(prod)
            } else {
                polynomial[idx].wrapping_sub(prod)
            };
        }
    }
}
/// Subtracts the negacyclic product `lhs_polynomial * rhs_bin_polynomial`
/// (mod `X^N + 1`) from `polynomial`, with wrapping arithmetic.
///
/// TODO: optimize performance, while keeping constant time, so as not to leak
/// information about the secret key.
pub fn update_with_wrapping_sub_mul(
    polynomial: &mut [u64],
    lhs_polynomial: &[u64],
    rhs_bin_polynomial: &[u64],
) {
    debug_assert_eq!(polynomial.len(), lhs_polynomial.len());
    debug_assert_eq!(polynomial.len(), rhs_bin_polynomial.len());
    let dim = polynomial.len();
    // Schoolbook product with signs mirrored relative to the add variant:
    // in-range terms subtract, wrapped-around terms add.
    for (i, &lhs) in lhs_polynomial.iter().enumerate() {
        for (j, &rhs) in rhs_bin_polynomial.iter().enumerate() {
            let degree = i + j;
            let prod = lhs * rhs;
            let (idx, forward) = if degree < dim {
                (degree, true)
            } else {
                (degree - dim, false)
            };
            polynomial[idx] = if forward {
                polynomial[idx].wrapping_sub(prod)
            } else {
                polynomial[idx].wrapping_add(prod)
            };
        }
    }
}

View File

@@ -0,0 +1,139 @@
use super::{DecompParams, GgswCiphertext, GlweParams};
use crate::implementation::{fft::FftView, zip_eq, Container, ContainerMut, Split};
use dyn_stack::{DynStack, ReborrowMut};
#[cfg(feature = "parallel")]
use rayon::{
prelude::{IndexedParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
};
/// LWE bootstrap key: one GGSW ciphertext per bit of the input LWE secret
/// key, stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct BootstrapKey<C: Container> {
    // Flat backing storage for `input_lwe_dimension` GGSW ciphertexts.
    pub data: C,
    pub glwe_params: GlweParams,
    pub input_lwe_dimension: usize,
    pub decomp_params: DecompParams,
}
impl<C: Container> BootstrapKey<C> {
    /// Number of scalars needed to back a bootstrap key: one GGSW ciphertext
    /// (a (k+1)x(k+1) matrix of polynomials per decomposition level) for each
    /// input LWE secret-key bit.
    pub fn data_len(
        glwe_params: GlweParams,
        decomposition_level_count: usize,
        input_lwe_dimension: usize,
    ) -> usize {
        glwe_params.polynomial_size
            * (glwe_params.dimension + 1)
            * (glwe_params.dimension + 1)
            * decomposition_level_count
            * input_lwe_dimension
    }
    /// Wraps `data` as a bootstrap key; `data.len()` must match
    /// [`Self::data_len`] (debug-checked).
    pub fn new(
        data: C,
        glwe_params: GlweParams,
        input_lwe_dimension: usize,
        decomp_params: DecompParams,
    ) -> Self {
        debug_assert_eq!(
            data.len(),
            Self::data_len(glwe_params, decomp_params.level, input_lwe_dimension),
        );
        Self {
            data,
            glwe_params,
            input_lwe_dimension,
            decomp_params,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for the container
    /// type `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        glwe_params: GlweParams,
        input_lwe_dimension: usize,
        decomp_params: DecompParams,
    ) -> Self
    where
        C: Split,
    {
        let data = C::from_raw_parts(
            data,
            Self::data_len(glwe_params, decomp_params.level, input_lwe_dimension),
        );
        Self {
            data,
            glwe_params,
            input_lwe_dimension,
            decomp_params,
        }
    }
    /// Borrowing view over the same key.
    pub fn as_view(&self) -> BootstrapKey<&[C::Item]> {
        BootstrapKey {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
            input_lwe_dimension: self.input_lwe_dimension,
            decomp_params: self.decomp_params,
        }
    }
    /// Mutable borrowing view over the same key.
    pub fn as_mut_view(&mut self) -> BootstrapKey<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        BootstrapKey {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
            input_lwe_dimension: self.input_lwe_dimension,
            decomp_params: self.decomp_params,
        }
    }
    /// Iterates over the key as one GGSW ciphertext per input secret-key bit.
    pub fn into_ggsw_iter(self) -> impl DoubleEndedIterator<Item = GgswCiphertext<C>>
    where
        C: Split,
    {
        self.data
            .split_into(self.input_lwe_dimension)
            .map(move |slice| GgswCiphertext::new(slice, self.glwe_params, self.decomp_params))
    }
    /// Dimension of the LWE ciphertexts produced by bootstrapping with this
    /// key (the flattened GLWE dimension).
    pub fn output_lwe_dimension(&self) -> usize {
        self.glwe_params.lwe_dimension()
    }
}
#[cfg(feature = "parallel")]
impl<'a> BootstrapKey<&'a mut [u64]> {
    /// Parallel counterpart of `into_ggsw_iter`: yields one mutable GGSW
    /// ciphertext per input LWE secret-key bit, using rayon.
    pub fn into_ggsw_iter_par(
        self,
    ) -> impl 'a + IndexedParallelIterator<Item = GgswCiphertext<&'a mut [u64]>> {
        // `data` holds exactly one GGSW per key bit, so the split is exact.
        debug_assert_eq!(self.data.len() % self.input_lwe_dimension, 0);
        let chunk_size = self.data.len() / self.input_lwe_dimension;
        self.data
            .par_chunks_exact_mut(chunk_size)
            .map(move |slice| GgswCiphertext::new(slice, self.glwe_params, self.decomp_params))
    }
}
impl BootstrapKey<&mut [f64]> {
    /// Fills this Fourier-domain bootstrap key with the forward DFT of the
    /// standard-domain key `coef_bsk`, converting GGSW ciphertexts one by one
    /// using scratch memory from `stack`.
    pub fn fill_with_forward_fourier(
        &mut self,
        coef_bsk: BootstrapKey<&[u64]>,
        fft: FftView<'_>,
        mut stack: DynStack<'_>,
    ) {
        // Both keys must describe the same parameter set.
        debug_assert_eq!(self.decomp_params, coef_bsk.decomp_params);
        debug_assert_eq!(self.glwe_params, coef_bsk.glwe_params);
        debug_assert_eq!(self.input_lwe_dimension, coef_bsk.input_lwe_dimension);
        for (a, b) in zip_eq(
            self.as_mut_view().into_ggsw_iter(),
            coef_bsk.into_ggsw_iter(),
        ) {
            a.fill_with_forward_fourier(b, fft, stack.rb_mut());
        }
    }
}

View File

@@ -0,0 +1,63 @@
use crate::implementation::{Container, ContainerMut, Split};
/// An LWE ciphertext: `lwe_dimension` mask elements followed by one body
/// element, stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct LweCiphertext<C: Container> {
    pub data: C,
    pub lwe_dimension: usize,
}
impl<C: Container> LweCiphertext<C> {
    /// Number of scalars backing a ciphertext: the mask plus one body element.
    pub fn data_len(lwe_dimension: usize) -> usize {
        lwe_dimension + 1
    }
    /// Wraps `data` as a ciphertext; the length is debug-checked.
    pub fn new(data: C, lwe_dimension: usize) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(lwe_dimension));
        Self {
            data,
            lwe_dimension,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(data: C::Pointer, lwe_dimension: usize) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(data, Self::data_len(lwe_dimension)),
            lwe_dimension,
        }
    }
    /// Borrowing view over the same ciphertext.
    pub fn as_view(&self) -> LweCiphertext<&[C::Item]> {
        LweCiphertext {
            data: self.data.as_ref(),
            lwe_dimension: self.lwe_dimension,
        }
    }
    /// Mutable borrowing view over the same ciphertext.
    pub fn as_mut_view(&mut self) -> LweCiphertext<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        LweCiphertext {
            data: self.data.as_mut(),
            lwe_dimension: self.lwe_dimension,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
}
// NOTE(review): not gated behind #[cfg(test)], so this helper is also
// available outside of test builds.
pub mod test {
    use super::*;
    impl LweCiphertext<Vec<u64>> {
        /// Allocates an all-zero ciphertext of LWE dimension `dim`.
        pub fn zero(dim: usize) -> Self {
            LweCiphertext::new(vec![0; dim + 1], dim)
        }
    }
}

View File

@@ -0,0 +1,78 @@
use crate::implementation::{Container, ContainerMut, Split};
use super::LweCiphertext;
/// A contiguous list of `count` LWE ciphertexts, each of `lwe_dimension + 1`
/// scalars (mask plus body).
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct LweCiphertextList<C: Container> {
    pub data: C,
    pub lwe_dimension: usize,
    pub count: usize,
}
impl<C: Container> LweCiphertextList<C> {
    /// Length of a single ciphertext in the list: `lwe_dimension` mask
    /// elements plus one body element. The full list occupies
    /// `data_len(lwe_dimension) * count` scalars.
    pub fn data_len(lwe_dimension: usize) -> usize {
        lwe_dimension + 1
    }
    /// Wraps `data` as a list of `count` ciphertexts; the length is
    /// debug-checked against the parameters.
    pub fn new(data: C, lwe_dimension: usize, count: usize) -> Self {
        // Reuse `data_len` so the size formula lives in one place, matching
        // the convention of the sibling ciphertext types.
        debug_assert_eq!(data.len(), Self::data_len(lwe_dimension) * count);
        Self {
            data,
            lwe_dimension,
            count,
        }
    }
    /// # Safety
    /// `data` must be valid for `data_len(lwe_dimension) * count` elements
    /// for container `C`.
    pub unsafe fn from_raw_parts(data: C::Pointer, lwe_dimension: usize, count: usize) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(data, Self::data_len(lwe_dimension) * count),
            lwe_dimension,
            count,
        }
    }
    /// Borrowing view over the same list.
    pub fn as_view(&self) -> LweCiphertextList<&[C::Item]> {
        LweCiphertextList {
            data: self.data.as_ref(),
            lwe_dimension: self.lwe_dimension,
            count: self.count,
        }
    }
    /// Mutable borrowing view over the same list.
    pub fn as_mut_view(&mut self) -> LweCiphertextList<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        LweCiphertextList {
            data: self.data.as_mut(),
            lwe_dimension: self.lwe_dimension,
            count: self.count,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
}
impl LweCiphertextList<&mut [u64]> {
    /// Iterates over the list as mutable single-ciphertext views.
    pub fn ciphertext_iter_mut(
        &mut self,
    ) -> impl DoubleEndedIterator<Item = LweCiphertext<&mut [u64]>> {
        self.data
            .chunks_exact_mut(self.lwe_dimension + 1)
            .map(|data| LweCiphertext::new(data, self.lwe_dimension))
    }
}
impl LweCiphertextList<&[u64]> {
    /// Iterates over the list as shared single-ciphertext views.
    pub fn ciphertext_iter(&self) -> impl DoubleEndedIterator<Item = LweCiphertext<&[u64]>> {
        self.data
            .chunks_exact(self.lwe_dimension + 1)
            .map(|data| LweCiphertext::new(data, self.lwe_dimension))
    }
}

View File

@@ -0,0 +1,34 @@
use crate::c_api::types::{Csprng, CsprngVtable};
use core::marker::PhantomData;
/// Mutable handle over a caller-provided CSPRNG passed through the C API as an
/// opaque pointer plus a vtable of function pointers.
pub struct CsprngMut<'value, 'vtable: 'value> {
    ptr: *mut Csprng,
    vtable: &'vtable CsprngVtable,
    // Ties the handle's lifetime to the borrow of the underlying generator.
    __marker: PhantomData<&'value mut ()>,
}
impl<'value, 'vtable: 'value> CsprngMut<'value, 'vtable> {
    /// # Safety
    /// `ptr` and `vtable` must be valid for the constructed lifetimes, and
    /// `vtable`'s function pointers must be sound to call with `ptr`.
    #[inline]
    pub unsafe fn new(ptr: *mut Csprng, vtable: *const CsprngVtable) -> Self {
        Self {
            ptr,
            vtable: &*vtable,
            __marker: PhantomData,
        }
    }
    /// Reborrows the handle for a shorter lifetime.
    #[inline]
    pub fn as_mut<'this>(&'this mut self) -> CsprngMut<'this, 'vtable> {
        Self {
            ptr: self.ptr,
            vtable: self.vtable,
            __marker: PhantomData,
        }
    }
    /// Fills `slice` with random bytes via the vtable; returns the number of
    /// bytes actually produced (0 signals generator failure).
    #[inline]
    pub fn next_bytes(&self, slice: &mut [u8]) -> usize {
        let byte_count = slice.len();
        unsafe { (self.vtable.next_bytes)(self.ptr, slice.as_mut_ptr(), byte_count) }
    }
}

View File

@@ -0,0 +1,213 @@
use crate::implementation::{Container, ContainerMut, Split};
use super::{DecompParams, GlweParams};
/// One row of a GGSW level matrix: a GLWE ciphertext tagged with its
/// decomposition level.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GgswLevelRow<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
    pub decomposition_level: usize,
}
/// One decomposition level of a GGSW ciphertext: `dimension + 1` rows.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GgswLevelMatrix<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
    pub decomposition_level: usize,
}
/// A GGSW ciphertext: `decomp_params.level` level matrices stored
/// contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GgswCiphertext<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
    pub decomp_params: DecompParams,
}
impl<C: Container> GgswLevelRow<C> {
    /// Scalars per row: one GLWE ciphertext, i.e. `(dimension + 1)` polynomials.
    pub fn data_len(glwe_params: GlweParams) -> usize {
        glwe_params.polynomial_size * (glwe_params.dimension + 1)
    }
    /// Wraps `data` as a level row; the length is debug-checked.
    pub fn new(data: C, glwe_params: GlweParams, decomposition_level: usize) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(glwe_params));
        Self {
            data,
            glwe_params,
            decomposition_level,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        glwe_params: GlweParams,
        decomposition_level: usize,
    ) -> Self
    where
        C: Split,
    {
        let data = C::from_raw_parts(data, Self::data_len(glwe_params));
        Self {
            data,
            glwe_params,
            decomposition_level,
        }
    }
    /// Borrowing view over the same row.
    pub fn as_view(&self) -> GgswLevelRow<&[C::Item]> {
        GgswLevelRow {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
            decomposition_level: self.decomposition_level,
        }
    }
    /// Mutable borrowing view over the same row.
    pub fn as_mut_view(&mut self) -> GgswLevelRow<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GgswLevelRow {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
            decomposition_level: self.decomposition_level,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
}
impl<C: Container> GgswLevelMatrix<C> {
    /// Scalars per level matrix: `(dimension + 1)` rows of `(dimension + 1)`
    /// polynomials each.
    pub fn data_len(glwe_params: GlweParams) -> usize {
        glwe_params.polynomial_size * (glwe_params.dimension + 1) * (glwe_params.dimension + 1)
    }
    /// Wraps `data` as a level matrix; the length is debug-checked.
    pub fn new(data: C, glwe_params: GlweParams, decomposition_level: usize) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(glwe_params));
        Self {
            data,
            glwe_params,
            decomposition_level,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        glwe_params: GlweParams,
        decomposition_level: usize,
    ) -> Self
    where
        C: Split,
    {
        let data = C::from_raw_parts(data, Self::data_len(glwe_params));
        Self {
            data,
            glwe_params,
            decomposition_level,
        }
    }
    /// Borrowing view over the same matrix.
    pub fn as_view(&self) -> GgswLevelMatrix<&[C::Item]> {
        GgswLevelMatrix {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
            decomposition_level: self.decomposition_level,
        }
    }
    /// Mutable borrowing view over the same matrix.
    pub fn as_mut_view(&mut self) -> GgswLevelMatrix<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GgswLevelMatrix {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
            decomposition_level: self.decomposition_level,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Iterates over the `dimension + 1` rows of the matrix; each row keeps
    /// this matrix's decomposition level.
    pub fn into_rows_iter(self) -> impl DoubleEndedIterator<Item = GgswLevelRow<C>>
    where
        C: Split,
    {
        self.data
            .split_into(self.glwe_params.dimension + 1)
            .map(move |slice| GgswLevelRow::new(slice, self.glwe_params, self.decomposition_level))
    }
}
impl<C: Container> GgswCiphertext<C> {
    /// Scalars per GGSW ciphertext: `decomposition_level_count` level matrices
    /// of `(dimension + 1)^2` polynomials each.
    pub fn data_len(glwe_params: GlweParams, decomposition_level_count: usize) -> usize {
        glwe_params.polynomial_size
            * (glwe_params.dimension + 1)
            * (glwe_params.dimension + 1)
            * decomposition_level_count
    }
    /// Wraps `data` as a GGSW ciphertext; the length is debug-checked.
    pub fn new(data: C, glwe_params: GlweParams, decomp_params: DecompParams) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(glwe_params, decomp_params.level));
        Self {
            data,
            glwe_params,
            decomp_params,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        glwe_params: GlweParams,
        decomposition_level_count: usize,
        decomp_params: DecompParams,
    ) -> Self
    where
        C: Split,
    {
        // The level count is passed separately from `decomp_params`; the two
        // must agree, otherwise `into_level_matrices_iter` (which splits by
        // `decomp_params.level`) would run on inconsistently sized data.
        debug_assert_eq!(decomposition_level_count, decomp_params.level);
        let data = C::from_raw_parts(data, Self::data_len(glwe_params, decomposition_level_count));
        Self {
            data,
            glwe_params,
            decomp_params,
        }
    }
    /// Borrowing view over the same ciphertext.
    pub fn as_view(&self) -> GgswCiphertext<&[C::Item]> {
        GgswCiphertext {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
            decomp_params: self.decomp_params,
        }
    }
    /// Mutable borrowing view over the same ciphertext.
    pub fn as_mut_view(&mut self) -> GgswCiphertext<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GgswCiphertext {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
            decomp_params: self.decomp_params,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Iterates over the level matrices, with 1-based decomposition levels.
    pub fn into_level_matrices_iter(self) -> impl DoubleEndedIterator<Item = GgswLevelMatrix<C>>
    where
        C: Split,
    {
        self.data
            .split_into(self.decomp_params.level)
            .enumerate()
            .map(move |(i, slice)| GgswLevelMatrix::new(slice, self.glwe_params, i + 1))
    }
}

View File

@@ -0,0 +1,195 @@
use crate::implementation::{zip_eq, Container, ContainerMut, Split};
use super::GlweParams;
/// A GLWE ciphertext: `dimension` mask polynomials followed by one body
/// polynomial, stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GlweCiphertext<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
}
/// The mask part of a GLWE ciphertext: `dimension` polynomials.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GlweMask<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
}
/// The body part of a GLWE ciphertext: a single polynomial.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GlweBody<C: Container> {
    pub data: C,
    pub polynomial_size: usize,
}
impl<C: Container> GlweCiphertext<C> {
    /// Scalars per ciphertext: `(dimension + 1)` polynomials.
    pub fn data_len(glwe_params: GlweParams) -> usize {
        glwe_params.polynomial_size * (glwe_params.dimension + 1)
    }
    /// Wraps `data` as a GLWE ciphertext; the length is debug-checked.
    pub fn new(data: C, glwe_params: GlweParams) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(glwe_params));
        Self { data, glwe_params }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(data: C::Pointer, glwe_params: GlweParams) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(data, Self::data_len(glwe_params)),
            glwe_params,
        }
    }
    /// Borrowing view over the same ciphertext.
    pub fn as_view(&self) -> GlweCiphertext<&[C::Item]> {
        GlweCiphertext {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
        }
    }
    /// Mutable borrowing view over the same ciphertext.
    pub fn as_mut_view(&mut self) -> GlweCiphertext<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GlweCiphertext {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Splits the ciphertext into its mask (first `dimension` polynomials)
    /// and body (last polynomial).
    pub fn into_mask_and_body(self) -> (GlweMask<C>, GlweBody<C>)
    where
        C: Split,
    {
        let (mask, body) = self
            .data
            .split_at(self.glwe_params.polynomial_size * self.glwe_params.dimension);
        (
            GlweMask {
                data: mask,
                glwe_params: self.glwe_params,
            },
            GlweBody {
                data: body,
                polynomial_size: self.glwe_params.polynomial_size,
            },
        )
    }
    /// Returns only the body polynomial, discarding the mask.
    pub fn into_body(self) -> GlweBody<C>
    where
        C: Split,
    {
        self.into_mask_and_body().1
    }
}
impl GlweCiphertext<&mut [u64]> {
    /// Subtracts `other * multiplier` element-wise from this ciphertext, with
    /// wrapping arithmetic modulo 2^64. `other` must have the same length as
    /// `self.data` (enforced by `zip_eq`).
    pub fn update_with_wrapping_sub_element_mul(self, other: &[u64], multiplier: u64) {
        for (a, b) in zip_eq(self.data, other) {
            *a = a.wrapping_sub(b.wrapping_mul(multiplier));
        }
    }
}
impl<C: Container> GlweMask<C> {
    /// Scalars in the mask: `dimension` polynomials.
    pub fn data_len(glwe_params: GlweParams) -> usize {
        glwe_params.polynomial_size * glwe_params.dimension
    }
    /// Wraps `data` as a GLWE mask; the length is debug-checked.
    pub fn new(data: C, glwe_params: GlweParams) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(glwe_params));
        Self { data, glwe_params }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(data: C::Pointer, glwe_params: GlweParams) -> Self
    where
        C: Split,
    {
        let data = C::from_raw_parts(data, Self::data_len(glwe_params));
        Self { data, glwe_params }
    }
    /// Borrowing view over the same mask.
    pub fn as_view(&self) -> GlweMask<&[C::Item]> {
        GlweMask {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
        }
    }
    /// Mutable borrowing view over the same mask.
    pub fn as_mut_view(&mut self) -> GlweMask<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GlweMask {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Returns the `idx`-th mask polynomial as a sub-container.
    pub fn get_polynomial(self, idx: usize) -> C
    where
        C: Split,
    {
        self.data.chunk(
            idx * self.glwe_params.polynomial_size,
            (idx + 1) * self.glwe_params.polynomial_size,
        )
    }
}
impl<C: Container> GlweBody<C> {
    /// Scalars in the body: a single polynomial.
    pub fn data_len(polynomial_size: usize) -> usize {
        polynomial_size
    }
    /// Wraps `data` as a GLWE body; the length is debug-checked.
    pub fn new(data: C, polynomial_size: usize) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(polynomial_size));
        Self {
            data,
            polynomial_size,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(data: C::Pointer, polynomial_size: usize) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(data, Self::data_len(polynomial_size)),
            polynomial_size,
        }
    }
    /// Borrowing view over the same body.
    pub fn as_view(&self) -> GlweBody<&[C::Item]> {
        GlweBody {
            data: self.data.as_ref(),
            polynomial_size: self.polynomial_size,
        }
    }
    /// Mutable borrowing view over the same body.
    pub fn as_mut_view(&mut self) -> GlweBody<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GlweBody {
            data: self.data.as_mut(),
            polynomial_size: self.polynomial_size,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
}

View File

@@ -0,0 +1,66 @@
use super::{GlweParams, LweSecretKey};
use crate::implementation::{Container, ContainerMut, Split};
/// A GLWE secret key: `dimension` polynomials of `polynomial_size`
/// coefficients, stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GlweSecretKey<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
}
impl<C: Container> GlweSecretKey<C> {
    /// Scalars in the key: `dimension * polynomial_size`.
    pub fn data_len(glwe_params: GlweParams) -> usize {
        glwe_params.lwe_dimension()
    }
    /// Wraps `data` as a GLWE secret key; the length is debug-checked.
    pub fn new(data: C, glwe_params: GlweParams) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(glwe_params));
        Self { data, glwe_params }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(data: C::Pointer, glwe_params: GlweParams) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(data, Self::data_len(glwe_params)),
            glwe_params,
        }
    }
    /// Borrowing view over the same key.
    pub fn as_view(&self) -> GlweSecretKey<&[C::Item]> {
        GlweSecretKey {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
        }
    }
    /// Mutable borrowing view over the same key.
    pub fn as_mut_view(&mut self) -> GlweSecretKey<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GlweSecretKey {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Returns the `idx`-th key polynomial as a sub-container.
    pub fn get_polynomial(self, idx: usize) -> C
    where
        C: Split,
    {
        self.data.chunk(
            idx * self.glwe_params.polynomial_size,
            (idx + 1) * self.glwe_params.polynomial_size,
        )
    }
    /// Reinterprets the key as an LWE secret key of the flattened dimension
    /// (same underlying data).
    #[allow(clippy::wrong_self_convention)]
    pub fn as_lwe(self) -> LweSecretKey<C> {
        LweSecretKey::new(self.data, self.glwe_params.lwe_dimension())
    }
}

View File

@@ -0,0 +1,167 @@
use super::*;
use crate::implementation::{Container, ContainerMut, Split};
/// An LWE-to-LWE keyswitching key: one Lev ciphertext per input secret-key
/// bit, stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct LweKeyswitchKey<C: Container> {
    pub data: C,
    pub output_dimension: usize,
    pub input_dimension: usize,
    pub decomp_params: DecompParams,
}
/// A Lev ciphertext: `ciphertext_count` LWE ciphertexts (one per
/// decomposition level), stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct LevCiphertext<C: Container> {
    pub data: C,
    pub lwe_dimension: usize,
    pub ciphertext_count: usize,
}
impl<C: Container> LweKeyswitchKey<C> {
    /// Scalars in the key: one Lev ciphertext (`decomposition_level_count`
    /// LWE ciphertexts of `output_dimension + 1` scalars) per input key bit.
    pub fn data_len(
        output_dimension: usize,
        decomposition_level_count: usize,
        input_dimension: usize,
    ) -> usize {
        input_dimension * decomposition_level_count * (output_dimension + 1)
    }
    /// Wraps `data` as a keyswitching key; the length is debug-checked.
    pub fn new(
        data: C,
        output_dimension: usize,
        input_dimension: usize,
        decomp_params: DecompParams,
    ) -> Self {
        debug_assert_eq!(
            data.len(),
            Self::data_len(output_dimension, decomp_params.level, input_dimension),
        );
        Self {
            data,
            output_dimension,
            input_dimension,
            decomp_params,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        output_dimension: usize,
        input_dimension: usize,
        decomp_params: DecompParams,
    ) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(
                data,
                Self::data_len(output_dimension, decomp_params.level, input_dimension),
            ),
            output_dimension,
            input_dimension,
            decomp_params,
        }
    }
    /// Borrowing view over the same key.
    pub fn as_view(&self) -> LweKeyswitchKey<&[C::Item]> {
        LweKeyswitchKey {
            data: self.data.as_ref(),
            output_dimension: self.output_dimension,
            input_dimension: self.input_dimension,
            decomp_params: self.decomp_params,
        }
    }
    /// Mutable borrowing view over the same key.
    pub fn as_mut_view(&mut self) -> LweKeyswitchKey<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        LweKeyswitchKey {
            data: self.data.as_mut(),
            output_dimension: self.output_dimension,
            input_dimension: self.input_dimension,
            decomp_params: self.decomp_params,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Iterates over the key as one Lev ciphertext per input key bit.
    pub fn into_lev_ciphertexts(self) -> impl DoubleEndedIterator<Item = LevCiphertext<C>>
    where
        C: Split,
    {
        self.data
            .split_into(self.input_dimension)
            .map(move |slice| {
                LevCiphertext::new(slice, self.output_dimension, self.decomp_params.level)
            })
    }
}
impl<C: Container> LevCiphertext<C> {
    /// Scalars in the Lev ciphertext: `ciphertext_count` LWE ciphertexts.
    pub fn data_len(lwe_dimension: usize, ciphertext_count: usize) -> usize {
        ciphertext_count * (lwe_dimension + 1)
    }
    /// Wraps `data` as a Lev ciphertext; the length is debug-checked.
    pub fn new(data: C, lwe_dimension: usize, ciphertext_count: usize) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(lwe_dimension, ciphertext_count));
        Self {
            data,
            lwe_dimension,
            ciphertext_count,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        lwe_dimension: usize,
        ciphertext_count: usize,
    ) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(data, Self::data_len(lwe_dimension, ciphertext_count)),
            lwe_dimension,
            ciphertext_count,
        }
    }
    /// Borrowing view over the same ciphertext.
    pub fn as_view(&self) -> LevCiphertext<&[C::Item]> {
        LevCiphertext {
            data: self.data.as_ref(),
            lwe_dimension: self.lwe_dimension,
            ciphertext_count: self.ciphertext_count,
        }
    }
    /// Mutable borrowing view over the same ciphertext.
    pub fn as_mut_view(&mut self) -> LevCiphertext<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        LevCiphertext {
            data: self.data.as_mut(),
            lwe_dimension: self.lwe_dimension,
            ciphertext_count: self.ciphertext_count,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Iterates over the constituent LWE ciphertexts, one per level.
    pub fn into_ciphertext_iter(self) -> impl DoubleEndedIterator<Item = LweCiphertext<C>>
    where
        C: Split,
    {
        self.data
            .split_into(self.ciphertext_count)
            .map(move |slice| LweCiphertext::new(slice, self.lwe_dimension))
    }
}

View File

@@ -0,0 +1,84 @@
use super::CsprngMut;
use crate::implementation::{Container, ContainerMut, Split};
/// An LWE secret key: `lwe_dimension` scalars (one per key bit) in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct LweSecretKey<C: Container> {
    pub data: C,
    pub lwe_dimension: usize,
}
impl<C: Container> LweSecretKey<C> {
    /// Scalars in the key: one per LWE dimension.
    pub fn data_len(lwe_dimension: usize) -> usize {
        lwe_dimension
    }
    /// Wraps `data` as an LWE secret key; the length is debug-checked.
    pub fn new(data: C, lwe_dimension: usize) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(lwe_dimension));
        Self {
            data,
            lwe_dimension,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(data: C::Pointer, lwe_dimension: usize) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(data, Self::data_len(lwe_dimension)),
            lwe_dimension,
        }
    }
    /// Borrowing view over the same key.
    pub fn as_view(&self) -> LweSecretKey<&[C::Item]> {
        LweSecretKey {
            data: self.data.as_ref(),
            lwe_dimension: self.lwe_dimension,
        }
    }
    /// Mutable borrowing view over the same key.
    pub fn as_mut_view(&mut self) -> LweSecretKey<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        LweSecretKey {
            data: self.data.as_mut(),
            lwe_dimension: self.lwe_dimension,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
}
impl LweSecretKey<&mut [u64]> {
    /// Fills the key with fresh uniformly random binary values drawn from
    /// `csprng`. Panics if the generator fails to produce bytes.
    pub fn fill_with_new_key(self, mut csprng: CsprngMut<'_, '_>) {
        for sk_bit in self.data {
            // One random byte per key bit; only its lowest bit is kept.
            let mut bytes = [0_u8; 1];
            let success_count = csprng.as_mut().next_bytes(&mut bytes);
            if success_count == 0 {
                panic!("Csprng failed to generate random bytes");
            }
            *sk_bit = (bytes[0] & 1) as u64;
        }
    }
}
// NOTE(review): not gated behind #[cfg(test)], so this helper is also
// available outside of test builds.
pub mod test {
    use super::*;
    use crate::implementation::types::CsprngMut;
    impl LweSecretKey<Vec<u64>> {
        /// Allocates a key of dimension `dim` and fills it with fresh random
        /// binary values from `csprng`.
        pub fn new_random(csprng: CsprngMut, dim: usize) -> Self {
            let mut sk = LweSecretKey::new(vec![0; dim], dim);
            sk.as_mut_view().fill_with_new_key(csprng);
            sk
        }
    }
}

View File

@@ -0,0 +1,54 @@
/// Gadget-decomposition parameters: `level` levels of `base_log` bits each.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct DecompParams {
    pub level: usize,
    pub base_log: usize,
}
/// GLWE parameters: `dimension` mask polynomials of `polynomial_size`
/// coefficients.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct GlweParams {
    pub dimension: usize,
    pub polynomial_size: usize,
}
impl GlweParams {
    /// Dimension of the LWE key obtained by flattening a GLWE key with these
    /// parameters.
    pub fn lwe_dimension(self) -> usize {
        self.dimension * self.polynomial_size
    }
}
/// Returns `log2(a)` for a power-of-two `a` (debug-checked).
///
/// For a power of two the logarithm is exactly the number of trailing zero
/// bits, so integer bit arithmetic is used instead of the previous
/// `f64::log2().ceil()` round-trip, which cannot represent every `usize`
/// exactly above 2^53.
pub fn int_log2(a: usize) -> usize {
    debug_assert!(a.is_power_of_two());
    a.trailing_zeros() as usize
}
mod ciphertext;
pub use ciphertext::*;
mod glwe_ciphertext;
pub use glwe_ciphertext::*;
mod lwe_secret_key;
pub use lwe_secret_key::*;
mod glwe_secret_key;
pub use glwe_secret_key::*;
mod ggsw_ciphertext;
pub use ggsw_ciphertext::*;
mod bootstrap_key;
pub use bootstrap_key::*;
mod keyswitch_key;
pub use keyswitch_key::*;
mod packing_keyswitch_key;
pub use packing_keyswitch_key::*;
pub mod packing_keyswitch_key_list;
mod csprng;
pub use csprng::*;
pub mod ciphertext_list;
pub mod polynomial_list;

View File

@@ -0,0 +1,171 @@
use super::*;
use crate::implementation::{Container, ContainerMut, Split};
/// A (private functional) packing keyswitching key: `input_dimension + 1`
/// GLEV ciphertexts (one per input LWE mask element, plus one for the body),
/// stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct PackingKeyswitchKey<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
    pub input_dimension: usize,
    pub decomp_params: DecompParams,
}
/// A GLEV ciphertext: `ciphertext_count` GLWE ciphertexts (one per
/// decomposition level), stored contiguously in `data`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct GlevCiphertext<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
    pub ciphertext_count: usize,
}
impl<C: Container> PackingKeyswitchKey<C> {
    /// Scalars in the key: `(input_dimension + 1)` GLEV ciphertexts of
    /// `decomposition_level_count` GLWE ciphertexts each.
    pub fn data_len(
        glwe_params: GlweParams,
        decomposition_level_count: usize,
        input_dimension: usize,
    ) -> usize {
        (input_dimension + 1)
            * decomposition_level_count
            * (glwe_params.dimension + 1)
            * glwe_params.polynomial_size
    }
    /// Wraps `data` as a packing keyswitching key; the length is debug-checked.
    pub fn new(
        data: C,
        glwe_params: GlweParams,
        input_dimension: usize,
        decomp_params: DecompParams,
    ) -> Self {
        debug_assert_eq!(
            data.len(),
            Self::data_len(glwe_params, decomp_params.level, input_dimension),
        );
        Self {
            data,
            glwe_params,
            input_dimension,
            decomp_params,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        glwe_params: GlweParams,
        input_dimension: usize,
        decomp_params: DecompParams,
    ) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(
                data,
                Self::data_len(glwe_params, decomp_params.level, input_dimension),
            ),
            glwe_params,
            input_dimension,
            decomp_params,
        }
    }
    /// Borrowing view over the same key.
    pub fn as_view(&self) -> PackingKeyswitchKey<&[C::Item]> {
        PackingKeyswitchKey {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
            input_dimension: self.input_dimension,
            decomp_params: self.decomp_params,
        }
    }
    /// Mutable borrowing view over the same key.
    pub fn as_mut_view(&mut self) -> PackingKeyswitchKey<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        PackingKeyswitchKey {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
            input_dimension: self.input_dimension,
            decomp_params: self.decomp_params,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Iterates over the key's GLEV ciphertexts.
    pub fn into_glev_ciphertexts(self) -> impl DoubleEndedIterator<Item = GlevCiphertext<C>>
    where
        C: Split,
    {
        // The key holds `input_dimension + 1` GLEV ciphertexts (see
        // `data_len`), so it must be split into that many chunks; splitting
        // into `input_dimension` would yield wrongly sized GLEVs and trip
        // `GlevCiphertext::new`'s length check.
        self.data
            .split_into(self.input_dimension + 1)
            .map(move |slice| {
                GlevCiphertext::new(slice, self.glwe_params, self.decomp_params.level)
            })
    }
}
impl<C: Container> GlevCiphertext<C> {
    /// Scalars in the GLEV ciphertext: `ciphertext_count` GLWE ciphertexts.
    pub fn data_len(glwe_params: GlweParams, ciphertext_count: usize) -> usize {
        glwe_params.polynomial_size * (glwe_params.dimension + 1) * ciphertext_count
    }
    /// Wraps `data` as a GLEV ciphertext; the length is debug-checked.
    pub fn new(data: C, glwe_params: GlweParams, ciphertext_count: usize) -> Self {
        debug_assert_eq!(data.len(), Self::data_len(glwe_params, ciphertext_count));
        Self {
            data,
            glwe_params,
            ciphertext_count,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        glwe_params: GlweParams,
        ciphertext_count: usize,
    ) -> Self
    where
        C: Split,
    {
        let data = C::from_raw_parts(data, Self::data_len(glwe_params, ciphertext_count));
        Self {
            data,
            glwe_params,
            ciphertext_count,
        }
    }
    /// Borrowing view over the same ciphertext.
    pub fn as_view(&self) -> GlevCiphertext<&[C::Item]> {
        GlevCiphertext {
            data: self.data.as_ref(),
            glwe_params: self.glwe_params,
            ciphertext_count: self.ciphertext_count,
        }
    }
    /// Mutable borrowing view over the same ciphertext.
    pub fn as_mut_view(&mut self) -> GlevCiphertext<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        GlevCiphertext {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
            ciphertext_count: self.ciphertext_count,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Iterates over the constituent GLWE ciphertexts, one per level.
    pub fn into_ciphertext_iter(self) -> impl DoubleEndedIterator<Item = GlweCiphertext<C>>
    where
        C: Split,
    {
        self.data
            .split_into(self.ciphertext_count)
            .map(move |slice| GlweCiphertext::new(slice, self.glwe_params))
    }
}

View File

@@ -0,0 +1,171 @@
use super::*;
use crate::implementation::{zip_eq, Container, ContainerMut, Split};
/// A contiguous list of `count` packing keyswitching keys sharing the same
/// parameters.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[readonly::make]
pub struct PackingKeyswitchKeyList<C: Container> {
    pub data: C,
    pub glwe_params: GlweParams,
    pub input_dimension: usize,
    pub decomp_params: DecompParams,
    pub count: usize,
}
impl<C: Container> PackingKeyswitchKeyList<C> {
    /// Scalars in the list: `count` packing keyswitching keys.
    pub fn data_len(
        glwe_params: GlweParams,
        decomposition_level_count: usize,
        input_dimension: usize,
        count: usize,
    ) -> usize {
        (input_dimension + 1)
            * decomposition_level_count
            * (glwe_params.dimension + 1)
            * glwe_params.polynomial_size
            * count
    }
    /// Wraps `data` as a list of keys; the length is debug-checked.
    pub fn new(
        data: C,
        glwe_params: GlweParams,
        input_dimension: usize,
        decomp_params: DecompParams,
        count: usize,
    ) -> Self {
        debug_assert_eq!(
            data.len(),
            Self::data_len(glwe_params, decomp_params.level, input_dimension, count),
        );
        Self {
            data,
            glwe_params,
            input_dimension,
            decomp_params,
            count,
        }
    }
    /// # Safety
    /// `data` must be valid for [`Self::data_len`] elements for container `C`.
    pub unsafe fn from_raw_parts(
        data: C::Pointer,
        glwe_params: GlweParams,
        input_dimension: usize,
        decomp_params: DecompParams,
        count: usize,
    ) -> Self
    where
        C: Split,
    {
        Self {
            data: C::from_raw_parts(
                data,
                Self::data_len(glwe_params, decomp_params.level, input_dimension, count),
            ),
            glwe_params,
            input_dimension,
            decomp_params,
            count,
        }
    }
    /// Consumes the wrapper and returns the underlying container.
    pub fn into_data(self) -> C {
        self.data
    }
    /// Iterates over the list as individual packing keyswitching keys.
    pub fn into_ppksk_key(self) -> impl DoubleEndedIterator<Item = PackingKeyswitchKey<C>>
    where
        C: Split,
    {
        // Copy the parameters out first so the closure does not borrow `self`.
        let glwe_params = self.glwe_params;
        let input_dimension = self.input_dimension;
        let decomp_params = self.decomp_params;
        let count = self.count;
        self.into_data().split_into(count).map(move |slice| {
            PackingKeyswitchKey::new(slice, glwe_params, input_dimension, decomp_params)
        })
    }
    /// Mutable borrowing view over the same list.
    pub fn as_mut_view(&mut self) -> PackingKeyswitchKeyList<&mut [C::Item]>
    where
        C: ContainerMut,
    {
        PackingKeyswitchKeyList {
            data: self.data.as_mut(),
            glwe_params: self.glwe_params,
            input_dimension: self.input_dimension,
            decomp_params: self.decomp_params,
            count: self.count,
        }
    }
}
impl PackingKeyswitchKeyList<&mut [u64]> {
    /// Generates the `dimension + 1` private functional packing keyswitching
    /// keys used by circuit bootstrapping: one per output GLWE key polynomial,
    /// plus one for the "body" (a constant polynomial).
    pub fn fill_with_fpksk_for_circuit_bootstrap(
        &mut self,
        input_lwe_key: &LweSecretKey<&[u64]>,
        output_glwe_key: &GlweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut,
    ) {
        let glwe_params = output_glwe_key.glwe_params;
        let polynomial_size = glwe_params.polynomial_size;
        debug_assert_eq!(self.count, output_glwe_key.glwe_params.dimension + 1);
        let mut last_polynomial = vec![0; polynomial_size];
        // We apply the x -> -x function so instead of putting one in the first coeff of the
        // polynomial, we put Scalar::MAX == - Scalar::One so that we can use a single function in
        // the loop avoiding branching
        last_polynomial[0] = u64::MAX;
        for (mut fpksk, polynomial_to_encrypt) in zip_eq(
            self.as_mut_view().into_ppksk_key(),
            output_glwe_key
                .data
                .chunks_exact(polynomial_size)
                .chain(std::iter::once(last_polynomial.as_slice())),
        ) {
            fpksk.fill_with_private_functional_packing_keyswitch_key(
                input_lwe_key,
                output_glwe_key,
                variance,
                csprng.as_mut(),
                |x: u64| x.wrapping_neg(),
                polynomial_to_encrypt,
            );
        }
    }
    /// Parallel variant of `fill_with_fpksk_for_circuit_bootstrap`; identical
    /// layout and semantics, but each key is generated with the `_par` filler.
    pub fn fill_with_fpksk_for_circuit_bootstrap_par(
        &mut self,
        input_lwe_key: &LweSecretKey<&[u64]>,
        output_glwe_key: &GlweSecretKey<&[u64]>,
        variance: f64,
        mut csprng: CsprngMut,
    ) {
        let glwe_params = output_glwe_key.glwe_params;
        let polynomial_size = glwe_params.polynomial_size;
        debug_assert_eq!(self.count, output_glwe_key.glwe_params.dimension + 1);
        let mut last_polynomial = vec![0; polynomial_size];
        // We apply the x -> -x function so instead of putting one in the first coeff of the
        // polynomial, we put Scalar::MAX == - Scalar::One so that we can use a single function in
        // the loop avoiding branching
        last_polynomial[0] = u64::MAX;
        for (mut fpksk, polynomial_to_encrypt) in zip_eq(
            self.as_mut_view().into_ppksk_key(),
            output_glwe_key
                .data
                .chunks_exact(polynomial_size)
                .chain(std::iter::once(last_polynomial.as_slice())),
        ) {
            fpksk.fill_with_private_functional_packing_keyswitch_key_par(
                input_lwe_key,
                output_glwe_key,
                variance,
                csprng.as_mut(),
                |x: u64| x.wrapping_neg(),
                polynomial_to_encrypt,
            );
        }
    }
}

View File

@@ -0,0 +1,47 @@
use crate::implementation::Container;
/// A contiguous list of `count` polynomials of `polynomial_size` coefficients.
#[derive(Debug, Clone)]
pub struct PolynomialList<C: Container> {
    pub data: C,
    pub count: usize,
    pub polynomial_size: usize,
}
impl<C: Container> PolynomialList<C> {
    /// Wraps `data` as a polynomial list; the length is debug-checked.
    pub fn new(data: C, polynomial_size: usize, count: usize) -> Self {
        debug_assert_eq!(data.len(), polynomial_size * count);
        Self {
            data,
            count,
            polynomial_size,
        }
    }
    // Total number of scalars in the backing container.
    fn container_len(&self) -> usize {
        self.data.len()
    }
}
impl PolynomialList<&[u64]> {
    /// Iterates over the polynomials as coefficient slices.
    pub fn iter_polynomial(&self) -> impl DoubleEndedIterator<Item = &'_ [u64]> {
        self.data.chunks_exact(self.polynomial_size)
    }
    // Creates an iterator over borrowed sub-lists.
    /// Splits the list into sub-lists of `count` polynomials each; `count`
    /// must divide the total polynomial count (debug-checked).
    pub fn sublist_iter(
        &self,
        count: usize,
    ) -> impl DoubleEndedIterator<Item = PolynomialList<&[u64]>> {
        let polynomial_size = self.polynomial_size;
        debug_assert_eq!(self.count % count, 0);
        self.data
            .chunks_exact(count * polynomial_size)
            .map(move |sub| PolynomialList {
                data: sub,
                polynomial_size,
                count,
            })
    }
}

View File

@@ -0,0 +1,963 @@
#![allow(clippy::too_many_arguments)]
use std::cmp::Ordering;
use aligned_vec::CACHELINE_ALIGN;
use dyn_stack::{DynStack, ReborrowMut, SizeOverflow, StackReq};
use crate::implementation::{external_product::external_product, types::GlweCiphertext, zip_eq};
use super::{
cmux::{cmux, cmux_scratch},
external_product::external_product_scratch,
fft::FftView,
polynomial::update_with_wrapping_unit_monomial_div,
types::{
ciphertext_list::LweCiphertextList, packing_keyswitch_key_list::PackingKeyswitchKeyList,
polynomial_list::PolynomialList, BootstrapKey, DecompParams, GgswCiphertext, GlweParams,
LweCiphertext, LweKeyswitchKey,
},
Container, Split,
};
/// Computes the scratch-memory requirement of `extract_bits` for the given
/// parameters: the four cacheline-aligned working buffers plus, disjointly,
/// either the bit-shift buffer or the bootstrap scratch (they are never live
/// at the same time).
pub fn extract_bits_scratch(
    lwe_dimension: usize,
    ksk_after_key_size: usize,
    glwe_params: GlweParams,
    fft: FftView,
) -> Result<StackReq, SizeOverflow> {
    let align = CACHELINE_ALIGN;
    let GlweParams {
        dimension,
        polynomial_size,
    } = glwe_params;
    // Copy of the input ciphertext.
    let lwe_in_buffer = StackReq::try_new_aligned::<u64>(lwe_dimension + 1, align)?;
    // Keyswitched output ciphertext.
    let lwe_out_ks_buffer = StackReq::try_new_aligned::<u64>(ksk_after_key_size + 1, align)?;
    // GLWE accumulator used by the programmable bootstrap.
    let pbs_accumulator =
        StackReq::try_new_aligned::<u64>((dimension + 1) * polynomial_size, align)?;
    // Bootstrapped output ciphertext (flattened GLWE dimension + body).
    let lwe_out_pbs_buffer =
        StackReq::try_new_aligned::<u64>(dimension * polynomial_size + 1, align)?;
    // Same size as the input buffer; only one of the two is alive at a time.
    let lwe_bit_left_shift_buffer = lwe_in_buffer;
    let bootstrap_scratch = BootstrapKey::bootstrap_scratch(glwe_params, fft)?;
    lwe_in_buffer
        .try_and(lwe_out_ks_buffer)?
        .try_and(pbs_accumulator)?
        .try_and(lwe_out_pbs_buffer)?
        .try_and(StackReq::try_any_of([
            lwe_bit_left_shift_buffer,
            bootstrap_scratch,
        ])?)
}
/// Function to extract `number_of_bits_to_extract` from an [`LweCiphertext`] starting at the bit
/// number `delta_log` (0-indexed) included.
///
/// Output bits are ordered from the MSB to the LSB. Each one of them is output in a distinct LWE
/// ciphertext, containing the encryption of the bit scaled by q/2 (i.e., the most significant bit
/// in the plaintext representation).
pub fn extract_bits(
    mut lwe_list_out: LweCiphertextList<&mut [u64]>,
    lwe_in: LweCiphertext<&[u64]>,
    ksk: LweKeyswitchKey<&[u64]>,
    fourier_bsk: BootstrapKey<&[f64]>,
    delta_log: usize,
    number_of_bits_to_extract: usize,
    fft: FftView<'_>,
    stack: DynStack<'_>,
) {
    let ciphertext_n_bits = u64::BITS as usize;
    let number_of_bits_to_extract = number_of_bits_to_extract;
    debug_assert!(
        ciphertext_n_bits >= number_of_bits_to_extract + delta_log,
        "Tried to extract {} bits, while the maximum number of extractable bits for {} bits
        ciphertexts and a scaling factor of 2^{} is {}",
        number_of_bits_to_extract,
        ciphertext_n_bits,
        delta_log,
        ciphertext_n_bits - delta_log,
    );
    debug_assert_eq!(lwe_list_out.lwe_dimension, ksk.output_dimension,);
    debug_assert_eq!(lwe_list_out.count, number_of_bits_to_extract,);
    debug_assert_eq!(lwe_in.lwe_dimension, fourier_bsk.output_lwe_dimension(),);
    let polynomial_size = fourier_bsk.glwe_params.polynomial_size;
    let glwe_dimension = fourier_bsk.glwe_params.dimension;
    let align = CACHELINE_ALIGN;
    // Working copy of the input; each extracted bit is subtracted from it as we go.
    let (mut lwe_in_buffer_data, stack) =
        stack.collect_aligned(align, lwe_in.into_data().iter().copied());
    let mut lwe_in_buffer = LweCiphertext::new(&mut *lwe_in_buffer_data, lwe_in.lwe_dimension);
    // Keyswitch output, in the small output LWE dimension.
    let (mut lwe_out_ks_buffer_data, stack) =
        stack.make_aligned_with(ksk.output_dimension + 1, align, |_| 0_u64);
    let mut lwe_out_ks_buffer =
        LweCiphertext::new(&mut *lwe_out_ks_buffer_data, lwe_list_out.lwe_dimension);
    // GLWE accumulator used as a trivially encrypted LUT for the PBS.
    let (mut pbs_accumulator_data, stack) =
        stack.make_aligned_with((glwe_dimension + 1) * polynomial_size, align, |_| 0_u64);
    let mut pbs_accumulator =
        GlweCiphertext::new(&mut *pbs_accumulator_data, fourier_bsk.glwe_params);
    let lwe_size = glwe_dimension * polynomial_size + 1;
    // PBS output, in the big (GLWE-derived) LWE dimension.
    let (mut lwe_out_pbs_buffer_data, mut stack) =
        stack.make_aligned_with(lwe_size, align, |_| 0_u64);
    let mut lwe_out_pbs_buffer = LweCiphertext::new(
        &mut *lwe_out_pbs_buffer_data,
        glwe_dimension * polynomial_size,
    );
    // We iterate on the list in reverse as we want to store the extracted MSB at index 0
    for (bit_idx, mut output_ct) in lwe_list_out
        .as_mut_view()
        .ciphertext_iter_mut()
        .rev()
        .enumerate()
    {
        // Shift on padding bit
        let (lwe_bit_left_shift_buffer_data, _) = stack.rb_mut().collect_aligned(
            align,
            lwe_in_buffer
                .as_view()
                .data
                .iter()
                .map(|s| *s << (ciphertext_n_bits - delta_log - bit_idx - 1)),
        );
        // Key switch to input PBS key
        ksk.keyswitch_ciphertext(
            lwe_out_ks_buffer.as_mut_view(),
            LweCiphertext::new(&*lwe_bit_left_shift_buffer_data, ksk.input_dimension),
        );
        drop(lwe_bit_left_shift_buffer_data);
        // Store the keyswitch output unmodified to the output list (as we need to do other
        // computations on the output of the keyswitch)
        output_ct
            .as_mut_view()
            .into_data()
            .copy_from_slice(lwe_out_ks_buffer.as_view().into_data());
        // If this was the last extracted bit, break
        // we subtract 1 because if the number_of_bits_to_extract is 1 we want to stop right away
        if bit_idx == number_of_bits_to_extract - 1 {
            break;
        }
        // Add q/4 to center the error while computing a negacyclic LUT
        let out_ks_body = lwe_out_ks_buffer
            .as_mut_view()
            .into_data()
            .last_mut()
            .unwrap();
        *out_ks_body = out_ks_body.wrapping_add(1_u64 << (ciphertext_n_bits - 2));
        // Fill lut for the current bit (equivalent to trivial encryption as mask is 0s)
        // The LUT is filled with -alpha in each coefficient where alpha = delta*2^{bit_idx-1}
        for poly_coeff in pbs_accumulator.as_mut_view().into_body().into_data() {
            *poly_coeff = (1_u64 << (delta_log - 1 + bit_idx)).wrapping_neg();
        }
        fourier_bsk.bootstrap(
            lwe_out_pbs_buffer.as_mut_view(),
            lwe_out_ks_buffer.as_view(),
            pbs_accumulator.as_view(),
            fft,
            stack.rb_mut(),
        );
        // Add alpha where alpha = delta*2^{bit_idx-1} to end up with an encryption of 0 if the
        // extracted bit was 0 and 1 in the other case
        let out_pbs_body = lwe_out_pbs_buffer
            .as_mut_view()
            .into_data()
            .last_mut()
            .unwrap();
        *out_pbs_body = out_pbs_body.wrapping_add(1_u64 << (delta_log + bit_idx - 1));
        // Remove the extracted bit from the initial LWE to get a 0 at the extracted bit location.
        for (out, inp) in zip_eq(
            lwe_in_buffer.as_mut_view().into_data(),
            lwe_out_pbs_buffer.as_view().into_data(),
        ) {
            *out = out.wrapping_sub(*inp);
        }
    }
}
/// Returns the scratch memory requirement of [`circuit_bootstrap_boolean`].
pub fn circuit_bootstrap_boolean_scratch(
    lwe_in_size: usize,
    bsk_output_lwe_size: usize,
    glwe_params: GlweParams,
    fft: FftView,
) -> Result<StackReq, SizeOverflow> {
    // Buffer holding the bootstrap output LWE ciphertext.
    let bs_out_buffer = StackReq::try_new_aligned::<u64>(bsk_output_lwe_size, CACHELINE_ALIGN)?;
    // Plus everything the homomorphic shift itself needs.
    let shift_req = homomorphic_shift_boolean_scratch(lwe_in_size, glwe_params, fft)?;
    bs_out_buffer.try_and(shift_req)
}
/// Circuit bootstrapping for boolean messages, i.e. containing only one bit of message
///
/// The output GGSW ciphertext `ggsw_out` decomposition base log and level count are used as the
/// circuit_bootstrap_boolean decomposition base log and level count.
pub fn circuit_bootstrap_boolean(
    fourier_bsk: BootstrapKey<&[f64]>,
    lwe_in: LweCiphertext<&[u64]>,
    ggsw_out: GgswCiphertext<&mut [u64]>,
    delta_log: usize,
    fpksk_list: PackingKeyswitchKeyList<&[u64]>,
    fft: FftView<'_>,
    stack: DynStack<'_>,
) {
    // The output GGSW's decomposition parameters drive the whole circuit bootstrap.
    let level_cbs = ggsw_out.decomp_params.level;
    let base_log_cbs = ggsw_out.decomp_params.base_log;
    debug_assert_ne!(level_cbs, 0);
    debug_assert_ne!(base_log_cbs, 0);
    let fpksk_input_lwe_key_dimension = fpksk_list.input_dimension;
    let fourier_bsk_output_lwe_dimension = fourier_bsk.output_lwe_dimension();
    let glwe_params = fpksk_list.glwe_params;
    debug_assert_eq!(glwe_params, ggsw_out.glwe_params);
    debug_assert_eq!(
        fpksk_input_lwe_key_dimension,
        fourier_bsk_output_lwe_dimension,
    );
    // One pfksk per GLWE mask element plus one for the body.
    debug_assert_eq!(glwe_params.dimension + 1, fpksk_list.count);
    // Output for every bootstrapping
    let (mut lwe_out_bs_buffer_data, mut stack) = stack.make_aligned_with(
        fourier_bsk_output_lwe_dimension + 1,
        CACHELINE_ALIGN,
        |_| 0_u64,
    );
    let mut lwe_out_bs_buffer = LweCiphertext::new(
        &mut *lwe_out_bs_buffer_data,
        fourier_bsk_output_lwe_dimension,
    );
    // Output for every pfksk that comes from the output GGSW
    let mut out_pfksk_buffer_iter = ggsw_out
        .into_data()
        .chunks_exact_mut((glwe_params.dimension + 1) * glwe_params.polynomial_size)
        .map(|data| GlweCiphertext::new(data, glwe_params));
    for decomposition_level in 1..=level_cbs {
        // One bootstrap per decomposition level: the PBS outputs the message bit scaled
        // for this level's decomposition term.
        homomorphic_shift_boolean(
            fourier_bsk,
            lwe_out_bs_buffer.as_mut_view(),
            lwe_in,
            decomposition_level,
            base_log_cbs,
            delta_log,
            fft,
            stack.rb_mut(),
        );
        // Pack the bootstrapped LWE into the GLWE rows of the output GGSW, one per pfksk.
        for pfksk in fpksk_list.into_ppksk_key() {
            let glwe_out = out_pfksk_buffer_iter.next().unwrap();
            pfksk.private_functional_keyswitch_ciphertext(glwe_out, lwe_out_bs_buffer.as_view());
        }
    }
}
/// Returns the scratch memory requirement of [`homomorphic_shift_boolean`].
pub fn homomorphic_shift_boolean_scratch(
    lwe_in_size: usize,
    glwe_params: GlweParams,
    fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
    let align = CACHELINE_ALIGN;
    // Shifted copy of the input LWE ciphertext.
    let shift_buffer = StackReq::try_new_aligned::<u64>(lwe_in_size, align)?;
    // GLWE accumulator holding the (trivially encrypted) LUT.
    let accumulator = StackReq::try_new_aligned::<u64>(
        (glwe_params.dimension + 1) * glwe_params.polynomial_size,
        align,
    )?;
    // Plus the bootstrap working memory.
    let bootstrap = BootstrapKey::bootstrap_scratch(glwe_params, fft)?;
    shift_buffer.try_and(accumulator)?.try_and(bootstrap)
}
/// Homomorphic shift for LWE without padding bit
///
/// Starts by shifting the message bit at bit #delta_log to the padding bit and then shifts it to
/// the right by base_log * level.
pub fn homomorphic_shift_boolean(
    fourier_bsk: BootstrapKey<&[f64]>,
    mut lwe_out: LweCiphertext<&mut [u64]>,
    lwe_in: LweCiphertext<&[u64]>,
    level_cbs: usize,
    base_log_cbs: usize,
    delta_log: usize,
    fft: FftView<'_>,
    stack: DynStack<'_>,
) {
    let ciphertext_n_bits = u64::BITS;
    let lwe_in_size = lwe_in.lwe_dimension + 1;
    let polynomial_size = fourier_bsk.glwe_params.polynomial_size;
    let (mut lwe_left_shift_buffer_data, stack) =
        stack.make_aligned_with(lwe_in_size, CACHELINE_ALIGN, |_| 0_u64);
    let mut lwe_left_shift_buffer =
        LweCiphertext::new(&mut *lwe_left_shift_buffer_data, lwe_in.lwe_dimension);
    // Shift message LSB on padding bit, at this point we expect to have messages with only 1 bit
    // of information
    let shift = 1 << (ciphertext_n_bits - delta_log as u32 - 1);
    // NOTE(review): this assert forces delta_log == ciphertext_n_bits - 1 (message already in the
    // MSB), which makes the multiplication below a no-op — confirm this restriction is intended.
    debug_assert_eq!(shift, 1);
    for (a, b) in zip_eq(
        lwe_left_shift_buffer.as_mut_view().into_data(),
        lwe_in.into_data(),
    ) {
        *a = b.wrapping_mul(shift);
    }
    // Add q/4 to center the error while computing a negacyclic LUT
    let shift_buffer_body = lwe_left_shift_buffer
        .as_mut_view()
        .into_data()
        .last_mut()
        .unwrap();
    *shift_buffer_body = shift_buffer_body.wrapping_add(1_u64 << (ciphertext_n_bits - 2));
    let (mut pbs_accumulator_data, stack) = stack.make_aligned_with(
        polynomial_size * (fourier_bsk.glwe_params.dimension + 1),
        CACHELINE_ALIGN,
        |_| 0_u64,
    );
    let mut pbs_accumulator =
        GlweCiphertext::new(&mut *pbs_accumulator_data, fourier_bsk.glwe_params);
    // Fill lut (equivalent to trivial encryption as mask is 0s)
    // The LUT is filled with -alpha in each coefficient where
    // alpha = 2^{log(q) - 1 - base_log * level}
    let alpha = 1_u64 << (ciphertext_n_bits - 1 - base_log_cbs as u32 * level_cbs as u32);
    for body in pbs_accumulator.as_mut_view().into_body().into_data() {
        *body = alpha.wrapping_neg();
    }
    // Applying a negacyclic LUT on a ciphertext with one bit of message in the MSB and no bit
    // of padding
    fourier_bsk.bootstrap(
        lwe_out.as_mut_view(),
        lwe_left_shift_buffer.as_view(),
        pbs_accumulator.as_view(),
        fft,
        stack,
    );
    // Add alpha where alpha = 2^{log(q) - 1 - base_log * level}
    // To end up with an encryption of 0 if the message bit was 0 and 1 in the other case
    let out_body = lwe_out.as_mut_view().into_data().last_mut().unwrap();
    *out_body = out_body
        .wrapping_add(1_u64 << (ciphertext_n_bits - 1 - base_log_cbs as u32 * level_cbs as u32));
}
// Convenience aliases for borrowed (read-only / mutable) views over the list types below.
pub type FourierGgswCiphertextListView<'a> = FourierGgswCiphertextList<&'a [f64]>;
pub type FourierGgswCiphertextListMutView<'a> = FourierGgswCiphertextList<&'a mut [f64]>;
pub type GlweCiphertextListView<'a, Scalar> = GlweCiphertextList<&'a [Scalar]>;
pub type GlweCiphertextListMutView<'a, Scalar> = GlweCiphertextList<&'a mut [Scalar]>;
/// A list of `count` GLWE ciphertexts stored contiguously in `data`.
///
/// Each ciphertext occupies `(dimension + 1) * polynomial_size` scalars (see `new`).
#[derive(PartialEq, Eq, Debug, Clone, Copy)]
pub struct GlweCiphertextList<C: Container> {
    // Flat storage for all the ciphertexts, one after the other.
    pub data: C,
    // Number of ciphertexts in the list.
    pub count: usize,
    // GLWE dimension / polynomial size shared by every ciphertext.
    pub glwe_params: GlweParams,
}
/// A list of `count` GGSW ciphertexts in the Fourier domain, stored as a flat polynomial list.
#[derive(Debug, Clone)]
pub struct FourierGgswCiphertextList<C: Container<Item = f64>> {
    // Fourier coefficients of every GGSW, flattened (see `new` for the size formula).
    pub fourier: PolynomialList<C>,
    // Number of GGSW ciphertexts in the list.
    pub count: usize,
    // GLWE parameters shared by every ciphertext.
    pub glwe_params: GlweParams,
    // Decomposition parameters shared by every ciphertext.
    pub decomp_params: DecompParams,
}
impl<C: Container> GlweCiphertextList<C> {
    /// Builds a list from flat storage holding `count` GLWE ciphertexts of
    /// `(dimension + 1) * polynomial_size` scalars each.
    pub fn new(data: C, count: usize, glwe_params: GlweParams) -> Self {
        let per_ct = glwe_params.polynomial_size * (glwe_params.dimension + 1);
        debug_assert_eq!(data.len(), count * per_ct,);
        Self {
            data,
            count,
            glwe_params,
        }
    }
    /// Returns a read-only borrowed view of the list.
    pub fn as_view(&self) -> GlweCiphertextListView<'_, C::Item> {
        GlweCiphertextListView {
            data: self.data.as_ref(),
            count: self.count,
            glwe_params: self.glwe_params,
        }
    }
    /// Returns a mutable borrowed view of the list.
    pub fn as_mut_view(&mut self) -> GlweCiphertextListMutView<'_, C::Item>
    where
        C: AsMut<[C::Item]>,
    {
        GlweCiphertextListMutView {
            data: self.data.as_mut(),
            count: self.count,
            glwe_params: self.glwe_params,
        }
    }
    /// Consumes the list, yielding one [`GlweCiphertext`] per entry.
    pub fn into_glwe_iter(self) -> impl DoubleEndedIterator<Item = GlweCiphertext<C>>
    where
        C: Split,
    {
        let glwe_params = self.glwe_params;
        self.data
            .split_into(self.count)
            .map(move |chunk| GlweCiphertext::new(chunk, glwe_params))
    }
}
impl<C: Container<Item = f64>> FourierGgswCiphertextList<C> {
    /// Builds a list of `count` Fourier GGSW ciphertexts from flat storage.
    ///
    /// Each GGSW holds `(dimension + 1) * level` rows of `(dimension + 1)` polynomials.
    pub fn new(
        data: C,
        count: usize,
        glwe_params: GlweParams,
        decomp_params: DecompParams,
    ) -> Self {
        debug_assert_eq!(
            data.len(),
            count
                * glwe_params.polynomial_size
                * (glwe_params.dimension + 1)
                * (glwe_params.dimension + 1)
                * decomp_params.level
        );
        Self {
            fourier: PolynomialList {
                data,
                polynomial_size: glwe_params.polynomial_size,
                count,
            },
            count,
            glwe_params,
            decomp_params,
        }
    }
    /// Returns a read-only borrowed view of the list.
    pub fn as_view(&self) -> FourierGgswCiphertextListView<'_> {
        let fourier = PolynomialList {
            data: self.fourier.data.as_ref(),
            polynomial_size: self.fourier.polynomial_size,
            count: self.count,
        };
        FourierGgswCiphertextListView {
            fourier,
            count: self.count,
            decomp_params: self.decomp_params,
            glwe_params: self.glwe_params,
        }
    }
    /// Returns a mutable borrowed view of the list.
    pub fn as_mut_view(&mut self) -> FourierGgswCiphertextListMutView<'_>
    where
        C: AsMut<[f64]>,
    {
        let fourier = PolynomialList {
            data: self.fourier.data.as_mut(),
            polynomial_size: self.fourier.polynomial_size,
            count: self.count,
        };
        FourierGgswCiphertextListMutView {
            fourier,
            count: self.count,
            decomp_params: self.decomp_params,
            glwe_params: self.glwe_params,
        }
    }
    /// Consumes the list, yielding one [`GgswCiphertext`] per entry.
    pub fn into_ggsw_iter(self) -> impl DoubleEndedIterator<Item = GgswCiphertext<C>>
    where
        C: Split,
    {
        self.fourier
            .data
            .split_into(self.count)
            .map(move |slice| GgswCiphertext::new(slice, self.glwe_params, self.decomp_params))
    }
    /// Splits the list into its first `mid` ciphertexts and the remaining ones.
    pub fn split_at(self, mid: usize) -> (Self, Self)
    where
        C: Split,
    {
        let glwe_dim = self.glwe_params.dimension;
        let polynomial_size = self.fourier.polynomial_size;
        let (left, right) = self.fourier.data.split_at(
            mid * polynomial_size * (glwe_dim + 1) * (glwe_dim + 1) * self.decomp_params.level,
        );
        (
            Self::new(left, mid, self.glwe_params, self.decomp_params),
            Self::new(
                right,
                self.count - mid,
                self.glwe_params,
                self.decomp_params,
            ),
        )
    }
}
/// Returns the scratch memory requirement of [`cmux_tree_memory_optimized`].
pub fn cmux_tree_memory_optimized_scratch(
    glwe_params: GlweParams,
    nb_layer: usize,
    fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
    // One GLWE ciphertext per tree layer.
    let t_scratch = StackReq::try_new_aligned::<u64>(
        (glwe_params.dimension + 1) * glwe_params.polynomial_size * nb_layer,
        CACHELINE_ALIGN,
    )?;
    StackReq::try_all_of([
        t_scratch,                             // t_0
        t_scratch,                             // t_1
        StackReq::try_new::<usize>(nb_layer)?, // t_fill
        t_scratch,                             // diff
        external_product_scratch(glwe_params, fft)?,
    ])
}
/// Performs a tree of cmux in a way that limits the total allocated memory to avoid issues for
/// bigger trees.
pub fn cmux_tree_memory_optimized(
    mut output_glwe: GlweCiphertext<&mut [u64]>,
    lut_per_layer: PolynomialList<&[u64]>,
    ggsw_list: FourierGgswCiphertextListView<'_>,
    fft: FftView<'_>,
    stack: DynStack<'_>,
) {
    // One LUT leaf per possible selector value: 2^nb_layer leaves in total.
    debug_assert_eq!(lut_per_layer.count, 1 << ggsw_list.count);
    debug_assert!(ggsw_list.count > 0);
    let glwe_dim = ggsw_list.glwe_params.dimension;
    let polynomial_size = ggsw_list.glwe_params.polynomial_size;
    let nb_layer = ggsw_list.count;
    debug_assert!(stack.can_hold(
        cmux_tree_memory_optimized_scratch(output_glwe.glwe_params, nb_layer, fft).unwrap()
    ));
    // These are accumulators that will be used to propagate the result from layer to layer
    // At index 0 you have the lut that will be loaded, and then the result for each layer gets
    // computed at the next index, last layer result gets stored in `result`.
    // This allows to use memory space in C * nb_layer instead of C' * 2 ^ nb_layer
    let (mut t_0_data, stack) = stack.make_aligned_with(
        polynomial_size * (glwe_dim + 1) * nb_layer,
        CACHELINE_ALIGN,
        |_| 0_u64,
    );
    let (mut t_1_data, stack) = stack.make_aligned_with(
        polynomial_size * (glwe_dim + 1) * nb_layer,
        CACHELINE_ALIGN,
        |_| 0_u64,
    );
    let mut t_0 = GlweCiphertextList::new(t_0_data.as_mut(), nb_layer, ggsw_list.glwe_params);
    let mut t_1 = GlweCiphertextList::new(t_1_data.as_mut(), nb_layer, ggsw_list.glwe_params);
    // t_fill[j] counts how many of (t_0[j], t_1[j]) currently hold a pending value (0, 1 or 2).
    let (mut t_fill, mut stack) = stack.make_with(nb_layer, |_| 0_usize);
    let mut lut_polynomial_iter = lut_per_layer.iter_polynomial();
    loop {
        // Load the next pair of LUT leaves into layer 0.
        let even = lut_polynomial_iter.next();
        let odd = lut_polynomial_iter.next();
        let (lut_2i, lut_2i_plus_1) = match (even, odd) {
            (Some(even), Some(odd)) => (even, odd),
            _ => break,
        };
        let mut t_iter = zip_eq(
            t_0.as_mut_view().into_glwe_iter(),
            t_1.as_mut_view().into_glwe_iter(),
        )
        .enumerate();
        let (mut j_counter, (mut t0_j, mut t1_j)) = t_iter.next().unwrap();
        t0_j.as_mut_view()
            .into_body()
            .into_data()
            .copy_from_slice(lut_2i);
        t1_j.as_mut_view()
            .into_body()
            .into_data()
            .copy_from_slice(lut_2i_plus_1);
        t_fill[0] = 2;
        for (j, ggsw) in ggsw_list.as_view().into_ggsw_iter().rev().enumerate() {
            if t_fill[j] == 2 {
                // Both inputs of layer j are ready: cmux them and push the result one
                // layer up (or into the final output when j is the last layer).
                let (diff_data, stack) = stack.rb_mut().collect_aligned(
                    CACHELINE_ALIGN,
                    zip_eq(t1_j.as_view().into_data(), t0_j.as_view().data)
                        .map(|(a, b)| a.wrapping_sub(*b)),
                );
                let diff = GlweCiphertext::new(&*diff_data, ggsw_list.glwe_params);
                if j != nb_layer - 1 {
                    let (j_counter_plus_1, (mut t_0_j_plus_1, mut t_1_j_plus_1)) =
                        t_iter.next().unwrap();
                    debug_assert_eq!(j_counter, j);
                    debug_assert_eq!(j_counter_plus_1, j + 1);
                    // Fill the upper layer's t_0 slot first, then its t_1 slot.
                    let mut output = if t_fill[j + 1] == 0 {
                        t_0_j_plus_1.as_mut_view()
                    } else {
                        t_1_j_plus_1.as_mut_view()
                    };
                    output
                        .as_mut_view()
                        .into_data()
                        .copy_from_slice(t0_j.as_view().data);
                    external_product(output, ggsw, diff, fft, stack);
                    t_fill[j + 1] += 1;
                    t_fill[j] = 0;
                    drop(diff_data);
                    (j_counter, t0_j, t1_j) = (j_counter_plus_1, t_0_j_plus_1, t_1_j_plus_1);
                } else {
                    let mut output = output_glwe.as_mut_view();
                    output
                        .as_mut_view()
                        .into_data()
                        .copy_from_slice(t0_j.as_view().data);
                    external_product(output, ggsw, diff, fft, stack);
                }
            } else {
                break;
            }
        }
    }
}
/// Returns the scratch memory requirement of [`circuit_bootstrap_boolean_vertical_packing`].
pub fn circuit_bootstrap_boolean_vertical_packing_scratch(
    lwe_list_in_count: usize,
    lwe_list_out_count: usize,
    lwe_in_size: usize,
    big_lut_polynomial_count: usize,
    bsk_output_lwe_size: usize,
    fpksk_output_polynomial_size: usize,
    glwe_dimension: usize,
    level_cbs: usize,
    fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
    // We deduce the number of luts in the vec_lut from the number of ciphertexts in lwe_list_out
    let number_of_luts = lwe_list_out_count;
    let small_lut_size = big_lut_polynomial_count / number_of_luts;
    StackReq::try_all_of([
        // Fourier-domain GGSW list: one GGSW per input LWE ciphertext.
        StackReq::try_new_aligned::<f64>(
            lwe_list_in_count
                * fpksk_output_polynomial_size
                * (glwe_dimension + 1)
                * (glwe_dimension + 1)
                * level_cbs,
            CACHELINE_ALIGN,
        )?,
        // Standard-domain GGSW used as circuit-bootstrap output before the FFT.
        StackReq::try_new_aligned::<u64>(
            fpksk_output_polynomial_size * (glwe_dimension + 1) * (glwe_dimension + 1) * level_cbs,
            CACHELINE_ALIGN,
        )?,
        // The three phases below run one after the other, so only the max is needed.
        StackReq::try_any_of([
            circuit_bootstrap_boolean_scratch(
                lwe_in_size,
                bsk_output_lwe_size,
                GlweParams {
                    dimension: glwe_dimension,
                    polynomial_size: fpksk_output_polynomial_size,
                },
                fft,
            )?,
            fft.forward_scratch()?,
            vertical_packing_scratch(
                GlweParams {
                    dimension: glwe_dimension,
                    polynomial_size: fpksk_output_polynomial_size,
                },
                small_lut_size,
                lwe_list_in_count,
                fft,
            )?,
        ])?,
    ])
}
/// Perform a circuit bootstrap followed by a vertical packing on ciphertexts encrypting boolean
/// messages.
///
/// The circuit bootstrapping uses the private functional packing key switch.
///
/// This is supposed to be used only with boolean (1 bit of message) LWE ciphertexts.
pub fn circuit_bootstrap_boolean_vertical_packing(
    luts: PolynomialList<&[u64]>,
    fourier_bsk: BootstrapKey<&[f64]>,
    mut lwe_list_out: LweCiphertextList<&mut [u64]>,
    lwe_list_in: LweCiphertextList<&[u64]>,
    fpksk_list: PackingKeyswitchKeyList<&[u64]>,
    cbs_dp: DecompParams,
    fft: FftView<'_>,
    stack: DynStack<'_>,
) {
    debug_assert!(stack.can_hold(
        circuit_bootstrap_boolean_vertical_packing_scratch(
            lwe_list_in.count,
            lwe_list_out.count,
            lwe_list_in.lwe_dimension + 1,
            luts.count,
            fourier_bsk.output_lwe_dimension() + 1,
            fpksk_list.glwe_params.polynomial_size,
            // The scratch function takes the GLWE *dimension* and adds 1 internally;
            // the previous `fourier_bsk.glwe_params.dimension + 1` double-counted and
            // overestimated the requirement. Use the fpksk GLWE params, which are the
            // ones sizing the GGSW buffers allocated below.
            fpksk_list.glwe_params.dimension,
            cbs_dp.level,
            fft
        )
        .unwrap()
    ));
    debug_assert_ne!(lwe_list_in.count, 0);
    debug_assert_eq!(
        lwe_list_out.lwe_dimension,
        fourier_bsk.output_lwe_dimension(),
    );
    let glwe_dim = fpksk_list.glwe_params.dimension;
    // Fourier-domain GGSW list: one GGSW per boolean input ciphertext.
    let (mut ggsw_list_data, stack) = stack.make_aligned_with(
        lwe_list_in.count
            * fpksk_list.glwe_params.polynomial_size
            * (glwe_dim + 1)
            * (glwe_dim + 1)
            * cbs_dp.level,
        CACHELINE_ALIGN,
        |_| f64::default(),
    );
    // Standard-domain GGSW reused as circuit-bootstrap output for each input.
    let (mut ggsw_res_data, mut stack) = stack.make_aligned_with(
        fpksk_list.glwe_params.polynomial_size * (glwe_dim + 1) * (glwe_dim + 1) * cbs_dp.level,
        CACHELINE_ALIGN,
        |_| 0_u64,
    );
    let mut ggsw_list = FourierGgswCiphertextList::new(
        &mut *ggsw_list_data,
        lwe_list_in.count,
        fpksk_list.glwe_params,
        cbs_dp,
    );
    let mut ggsw_res = GgswCiphertext::new(&mut *ggsw_res_data, fpksk_list.glwe_params, cbs_dp);
    for (lwe_in, ggsw) in zip_eq(
        lwe_list_in.ciphertext_iter(),
        ggsw_list.as_mut_view().into_ggsw_iter(),
    ) {
        // Boolean message lives in the MSB: delta_log = bit width - 1.
        circuit_bootstrap_boolean(
            fourier_bsk,
            lwe_in,
            ggsw_res.as_mut_view(),
            u64::BITS as usize - 1,
            fpksk_list,
            fft,
            stack.rb_mut(),
        );
        ggsw.fill_with_forward_fourier(ggsw_res.as_view(), fft, stack.rb_mut());
    }
    // We deduce the number of luts in the vec_lut from the number of ciphertexts in lwe_list_out
    // debug_assert_eq!(lwe_list_out.count, small_lut_count);
    debug_assert_eq!(lwe_list_out.count, luts.count);
    for (lut, lwe_out) in zip_eq(luts.iter_polynomial(), lwe_list_out.ciphertext_iter_mut()) {
        vertical_packing(lut, lwe_out, ggsw_list.as_view(), fft, stack.rb_mut());
    }
}
/// Debug helper: prints the value obtained by rounding `ct` at bit 53 and reducing
/// it modulo 2^10 (a 10-bit decoded message).
fn print_ct(ct: u64) {
    let rounded = ((ct >> 53) + 1) >> 1;
    let decoded = rounded % (1 << 10);
    print!("{}", decoded);
}
/// Returns the scratch memory requirement of [`vertical_packing`].
pub fn vertical_packing_scratch(
    glwe_params: GlweParams,
    lut_polynomial_count: usize,
    ggsw_list_count: usize,
    fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
    let bits = core::mem::size_of::<u64>() * 8;
    // Get the base 2 logarithm (rounded down) of the number of polynomials in the list i.e. if
    // there is one polynomial, the number will be 0
    let log_lut_number: usize = bits - 1 - lut_polynomial_count.leading_zeros() as usize;
    let log_number_of_luts_for_cmux_tree = if log_lut_number > ggsw_list_count {
        // this means that we don't have enough GGSW to perform the CMux tree, we can only do the
        // Blind rotation
        0
    } else {
        log_lut_number
    };
    StackReq::try_all_of([
        // cmux_tree_lut_res
        StackReq::try_new_aligned::<u64>(
            (glwe_params.dimension + 1) * glwe_params.polynomial_size,
            CACHELINE_ALIGN,
        )?,
        // The blind rotation and the cmux tree run one after the other: take the max.
        StackReq::try_any_of([
            blind_rotate_scratch(glwe_params, fft)?,
            cmux_tree_memory_optimized_scratch(glwe_params, log_number_of_luts_for_cmux_tree, fft)?,
        ])?,
    ])
}
/// Returns the base-2 logarithm of `a`, which must be a power of two
/// (enforced by a debug assertion).
///
/// Uses `usize::BITS` rather than `u64::BITS` so the computation is correct on
/// targets where `usize` is not 64 bits wide (the original mixed the two widths).
fn log2(a: usize) -> usize {
    let result = usize::BITS as usize - 1 - a.leading_zeros() as usize;
    debug_assert_eq!(a, 1 << result);
    result
}
// GGSW ciphertexts are stored from the msb (vec_ggsw[0]) to the lsb (vec_ggsw[last])
/// Evaluates a (possibly multi-polynomial) LUT homomorphically: a cmux tree driven by the
/// msb GGSWs selects a LUT chunk, a blind rotation driven by the lsb GGSWs selects the
/// coefficient, and a sample extraction produces the output LWE ciphertext.
pub fn vertical_packing(
    lut: &[u64],
    lwe_out: LweCiphertext<&mut [u64]>,
    ggsw_list: FourierGgswCiphertextListView,
    fft: FftView,
    stack: DynStack<'_>,
) {
    let glwe_params = ggsw_list.glwe_params;
    let polynomial_size = glwe_params.polynomial_size;
    let glwe_dimension = glwe_params.dimension;
    debug_assert_eq!(lwe_out.lwe_dimension, polynomial_size * glwe_dimension);
    // `lut.len()` must be a power of two (checked inside `log2`).
    let log_lut_number = log2(lut.len());
    debug_assert_eq!(ggsw_list.count, log_lut_number);
    let log_poly_size = log2(polynomial_size);
    let (mut cmux_tree_lut_res_data, mut stack) = stack.make_aligned_with(
        polynomial_size * (glwe_dimension + 1),
        CACHELINE_ALIGN,
        |_| 0_u64,
    );
    let mut cmux_tree_lut_res = GlweCiphertext::new(&mut *cmux_tree_lut_res_data, glwe_params);
    let br_ggsw = match log_lut_number.cmp(&log_poly_size) {
        Ordering::Less => {
            // LUT smaller than one polynomial: load it zero-padded and use every GGSW
            // for the blind rotation.
            cmux_tree_lut_res
                .as_mut_view()
                .into_data()
                .fill_with(|| 0_u64);
            cmux_tree_lut_res.as_mut_view().into_body().into_data()[0..lut.len()]
                .copy_from_slice(lut);
            ggsw_list
        }
        Ordering::Equal => {
            // LUT fills exactly one polynomial: no cmux tree needed.
            cmux_tree_lut_res
                .as_mut_view()
                .into_data()
                .fill_with(|| 0_u64);
            cmux_tree_lut_res
                .as_mut_view()
                .into_body()
                .into_data()
                .copy_from_slice(lut);
            ggsw_list
        }
        Ordering::Greater => {
            let log_number_of_luts_for_cmux_tree = log_lut_number - log_poly_size;
            // split the vec of GGSW in two, the msb GGSW is for the CMux tree and the lsb GGSW is
            // for the last blind rotation.
            let (cmux_ggsw, br_ggsw) = ggsw_list.split_at(log_number_of_luts_for_cmux_tree);
            debug_assert_eq!(br_ggsw.count, log_poly_size);
            let small_luts =
                PolynomialList::new(lut, polynomial_size, 1 << (log_lut_number - log_poly_size));
            cmux_tree_memory_optimized(
                cmux_tree_lut_res.as_mut_view(),
                small_luts,
                cmux_ggsw,
                fft,
                stack.rb_mut(),
            );
            br_ggsw
        }
    };
    blind_rotate(
        cmux_tree_lut_res.as_mut_view(),
        br_ggsw,
        fft,
        stack.rb_mut(),
    );
    // sample extract of the RLWE of the Vertical packing
    cmux_tree_lut_res
        .as_view()
        .fill_lwe_with_sample_extraction(lwe_out, 0);
}
/// Returns the scratch memory requirement of [`blind_rotate`].
pub fn blind_rotate_scratch(
    glwe_params: GlweParams,
    fft: FftView<'_>,
) -> Result<StackReq, SizeOverflow> {
    // Rotated copy of the accumulator GLWE ciphertext.
    let rotated_ct = StackReq::try_new_aligned::<u64>(
        (glwe_params.dimension + 1) * glwe_params.polynomial_size,
        CACHELINE_ALIGN,
    )?;
    StackReq::try_all_of([rotated_ct, cmux_scratch(glwe_params, fft)?])
}
/// Homomorphically rotates `lut` by the amount whose bits are encrypted in `ggsw_list`
/// (GGSWs ordered msb first; iteration below starts from the lsb).
pub fn blind_rotate(
    mut lut: GlweCiphertext<&mut [u64]>,
    ggsw_list: FourierGgswCiphertextListView<'_>,
    fft: FftView<'_>,
    mut stack: DynStack<'_>,
) {
    let mut monomial_degree = 1;
    for ggsw in ggsw_list.into_ggsw_iter().rev() {
        let ct_0 = lut.as_mut_view();
        // ct_1 = copy of ct_0, then divided by X^monomial_degree (negacyclic rotation).
        let (mut ct1_data, stack) = stack
            .rb_mut()
            .collect_aligned(CACHELINE_ALIGN, ct_0.as_view().into_data().iter().copied());
        let mut ct_1 = GlweCiphertext::new(&mut *ct1_data, ct_0.glwe_params);
        for a in ct_1
            .as_mut_view()
            .into_data()
            .chunks_exact_mut(ct_0.glwe_params.polynomial_size)
        {
            update_with_wrapping_unit_monomial_div(a, monomial_degree);
        }
        // Each successive bit doubles the rotation amount.
        monomial_degree <<= 1;
        // Write into ct_0 (the lut) either ct_0 or ct_1, selected by the encrypted bit.
        cmux(ct_0, ct_1, ggsw, fft, stack);
    }
}

View File

@@ -0,0 +1,8 @@
#![allow(clippy::missing_safety_doc, dead_code)]
#![cfg_attr(feature = "nightly", feature(stdsimd))]
#![cfg_attr(feature = "nightly", feature(avx512_target_feature))]
extern crate alloc;
pub mod c_api;
mod implementation;

View File

@@ -0,0 +1,19 @@
# Rebuild the Rust library and refresh the generated C header used by the Zig tests.
# NOTE(review): target name keeps its historical misspelling ("regenererate") since it is
# a public make entry point; presumably the cargo build re-emits the header — confirm.
regenererate_lib:
	rm -rf ../include/concrete-cpu.h
	cargo build
# Each test_* target compiles and runs one Zig test file against the freshly built
# library in ../../target/debug/.
test_encryption: regenererate_lib
	zig test test_encryption.zig -I.. -lc -lconcrete_cpu -lunwind -L../../target/debug/
test_bootstrap: regenererate_lib
	zig test test_bootstrap.zig -I.. -lc -lconcrete_cpu -lunwind -L../../target/debug/
test_vertical_packing: regenererate_lib
	zig test test_vertical_packing.zig -I.. -lc -lconcrete_cpu -lunwind -L../../target/debug/
test_bit_extract: regenererate_lib
	zig test test_bit_extract.zig -I.. -lc -lconcrete_cpu -lunwind -L../../target/debug/
# Run the whole Zig test suite.
test: test_encryption test_bootstrap test_bit_extract test_vertical_packing

View File

@@ -0,0 +1,107 @@
const c = @cImport({
@cInclude("stdlib.h");
});
const std = @import("std");
const allocator = std.heap.page_allocator;
const cpu = @cImport({
@cInclude("include/concrete-cpu.h");
});
/// Generates a bootstrap key for `in_sk` -> `out_sk` and returns its Fourier-domain
/// representation. The caller owns (and must free) the returned slice.
pub fn new_bsk(
    csprng: *cpu.Csprng,
    in_dim: usize,
    glwe_dim: u64,
    polynomial_size: u64,
    level: u64,
    base_log: u64,
    key_variance: f64,
    in_sk: []u64,
    out_sk: []u64,
    fft: *cpu.Fft,
) ![]f64 {
    const bsk_size = cpu.concrete_cpu_bootstrap_key_size_u64(level, glwe_dim, polynomial_size, in_dim);
    // The standard-domain key is only needed until it is converted to Fourier.
    const bsk = try allocator.alloc(u64, bsk_size);
    defer allocator.free(bsk);
    cpu.concrete_cpu_init_lwe_bootstrap_key_u64(
        bsk.ptr,
        in_sk.ptr,
        out_sk.ptr,
        in_dim,
        polynomial_size,
        glwe_dim,
        level,
        base_log,
        key_variance,
        1,
        csprng,
        &cpu.CONCRETE_CSPRNG_VTABLE,
    );
    const bsk_f = try allocator.alloc(f64, bsk_size);
    {
        // Query the scratch size/alignment required by the conversion, then run it.
        var stack_size: u64 = 0;
        var stack_align: u64 = 0;
        try std.testing.expect(
            cpu.concrete_cpu_bootstrap_key_convert_u64_to_fourier_scratch(
                &stack_size,
                &stack_align,
                fft,
            ) == 0,
        );
        const stack = @ptrCast([*]u8, c.aligned_alloc(stack_align, stack_size) orelse unreachable)[0..stack_size];
        defer c.free(stack.ptr);
        cpu.concrete_cpu_bootstrap_key_convert_u64_to_fourier(
            bsk.ptr,
            bsk_f.ptr,
            level,
            base_log,
            glwe_dim,
            polynomial_size,
            in_dim,
            fft,
            stack.ptr,
            stack.len,
        );
    }
    return bsk_f;
}
/// Rounds `input` to the closest value representable by a decomposition with
/// `level_count` levels of `base_log` bits each: the `64 - level_count * base_log`
/// low-order bits are rounded away (round to nearest, ties up).
pub fn closest_representable(input: u64, level_count: u64, base_log: u64) u64 {
    // Number of least significant bits the decomposition cannot represent.
    const dropped_bits: u64 = 64 - (level_count * base_log);
    const shift = @intCast(u6, dropped_bits);
    // The rounding bit is the most significant of the dropped bits.
    const rounding_bit = (input >> @intCast(u6, dropped_bits - 1)) & 1;
    // Truncate, round to nearest, and scale back up.
    const truncated = input >> shift;
    return (truncated + rounding_bit) << shift;
}
/// Formats the `precision` most significant bits of `encoded` (after rounding) as a
/// binary fraction string "0.bbbbbbbbbbb". The caller owns the returned buffer.
pub fn highest_bits(encoded: u64) ![]u8 {
    const precision = 11;
    // "0." prefix plus `precision` binary digits.
    var buffer = try allocator.alloc(u8, precision + 2);
    const one: u64 = 1;
    // NOTE(review): the rounding offset is a full unit at the kept precision
    // (1 << (64 - precision)) rather than the usual half (1 << (64 - precision - 1));
    // display-only helper, but confirm the intended rounding.
    const high_bits = (encoded +% (one << @intCast(u6, 64 - precision))) >> @intCast(u6, 64 - precision);
    return std.fmt.bufPrint(buffer, "0.{b:0>11}", .{high_bits});
}

View File

@@ -0,0 +1,164 @@
const c = @cImport({
@cInclude("stdlib.h");
});
const std = @import("std");
const allocator = std.heap.page_allocator;
const common = @import("common.zig");
const cpu = @cImport({
@cInclude("include/concrete-cpu.h");
});
/// End-to-end bit extraction check: encrypt a 5-bit message under the big key, extract
/// every bit into a small-dimension LWE ciphertext, decrypt and compare bit by bit.
fn test3(csprng: *cpu.Csprng) !void {
    const polynomial_size: usize = 1024;
    const glwe_dim: usize = 1;
    const small_dim: usize = 585;
    const level_bsk: usize = 2;
    const base_log_bsk: usize = 10;
    const level_ksk: usize = 7;
    const base_log_ksk: usize = 4;
    // Essentially noiseless encryption so decoding is deterministic.
    const variance = std.math.pow(f64, 2, -2 * 60);
    const number_of_bits_of_message = 5;
    // NOTE(review): raw_fft is neither null-checked nor freed — acceptable for a test,
    // but worth confirming.
    var raw_fft = c.aligned_alloc(cpu.CONCRETE_FFT_ALIGN, cpu.CONCRETE_FFT_SIZE);
    const fft = @ptrCast(*cpu.Fft, raw_fft);
    cpu.concrete_cpu_construct_concrete_fft(fft, polynomial_size);
    const big_dim = glwe_dim * polynomial_size;
    // Small (post-keyswitch) and big (GLWE-derived) secret keys.
    const small_sk = try allocator.alloc(u64, small_dim);
    cpu.concrete_cpu_init_lwe_secret_key_u64(small_sk.ptr, small_dim, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
    const big_sk = try allocator.alloc(u64, big_dim);
    cpu.concrete_cpu_init_lwe_secret_key_u64(big_sk.ptr, big_dim, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
    // Bootstrap key small_sk -> big_sk, in the Fourier domain.
    const bsk_f = try common.new_bsk(
        csprng,
        small_dim,
        glwe_dim,
        polynomial_size,
        level_bsk,
        base_log_bsk,
        variance,
        small_sk,
        big_sk,
        fft,
    );
    defer allocator.free(bsk_f);
    // Keyswitch key big_sk -> small_sk.
    const ksk_size = cpu.concrete_cpu_keyswitch_key_size_u64(level_ksk, base_log_ksk, big_dim, small_dim);
    const ksk = try allocator.alloc(u64, ksk_size);
    defer allocator.free(ksk);
    cpu.concrete_cpu_init_lwe_keyswitch_key_u64(
        ksk.ptr,
        big_sk.ptr,
        small_sk.ptr,
        big_dim,
        small_dim,
        level_ksk,
        base_log_ksk,
        variance,
        csprng,
        &cpu.CONCRETE_CSPRNG_VTABLE,
    );
    const delta_log = 64 - number_of_bits_of_message;
    // 19 in binary is 10011, so has the high bit, low bit set and is not symmetrical
    const val: u64 = 19;
    std.debug.assert(1 << number_of_bits_of_message > val);
    // Encode the message in the top `number_of_bits_of_message` bits.
    const message = val << delta_log;
    // We will extract all bits
    const number_of_bits_to_extract = number_of_bits_of_message;
    const in_ct = try allocator.alloc(u64, big_dim + 1);
    defer allocator.free(in_ct);
    cpu.concrete_cpu_encrypt_lwe_ciphertext_u64(
        big_sk.ptr,
        in_ct.ptr,
        message,
        big_dim,
        variance,
        csprng,
        &cpu.CONCRETE_CSPRNG_VTABLE,
    );
    // One small-dimension output ciphertext per extracted bit.
    const out_cts = try allocator.alloc(u64, (small_dim + 1) * number_of_bits_to_extract);
    defer allocator.free(out_cts);
    // Query the scratch requirement for the extraction, then run it.
    var stack_align: usize = 0;
    var stack_size: usize = 0;
    try std.testing.expect(cpu.concrete_cpu_extract_bit_lwe_ciphertext_u64_scratch(
        &stack_size,
        &stack_align,
        small_dim,
        big_dim,
        glwe_dim,
        polynomial_size,
        fft,
    ) == 0);
    const stack = @ptrCast([*]u8, c.aligned_alloc(stack_align, stack_size) orelse unreachable)[0..stack_size];
    defer c.free(stack.ptr);
    cpu.concrete_cpu_extract_bit_lwe_ciphertext_u64(
        out_cts.ptr,
        in_ct.ptr,
        bsk_f.ptr,
        ksk.ptr,
        small_dim,
        number_of_bits_to_extract,
        big_dim,
        number_of_bits_to_extract,
        delta_log,
        level_bsk,
        base_log_bsk,
        glwe_dim,
        polynomial_size,
        small_dim,
        level_ksk,
        base_log_ksk,
        big_dim,
        small_dim,
        fft,
        stack.ptr,
        stack.len,
    );
    // Bits come out MSB first; each decrypts to 0 or q/2.
    var i: u64 = 0;
    while (i < number_of_bits_to_extract) {
        const expected = (val >> @intCast(u6, number_of_bits_of_message - 1 - i)) & 1;
        var decrypted: u64 = 0;
        cpu.concrete_cpu_decrypt_lwe_ciphertext_u64(small_sk.ptr, out_cts[(small_dim + 1) * i ..].ptr, small_dim, &decrypted);
        // Round to the top bit and compare against the expected plaintext bit.
        const rounded = common.closest_representable(decrypted, 1, 1);
        const decoded = rounded >> 63;
        std.debug.assert(decoded == expected);
        i += 1;
    }
}
// Driver for the bit extraction scenario (see test3). The label previously read
// "encryption" — a copy-paste from test_encryption.zig; this file (per the Makefile's
// test_bit_extract target) exercises bit extraction.
test "bit extract" {
    var raw_csprng = c.aligned_alloc(cpu.CONCRETE_CSPRNG_ALIGN, cpu.CONCRETE_CSPRNG_SIZE);
    defer c.free(raw_csprng);
    const csprng = @ptrCast(*cpu.Csprng, raw_csprng);
    cpu.concrete_cpu_construct_concrete_csprng(
        csprng,
        cpu.Uint128{ .little_endian_bytes = [_]u8{1} ** 16 },
    );
    // Defers run LIFO: the csprng is destroyed before its backing memory is freed.
    defer cpu.concrete_cpu_destroy_concrete_csprng(csprng);
    try test3(csprng);
}

View File

@@ -0,0 +1,243 @@
const c = @cImport({
@cInclude("stdlib.h");
});
const std = @import("std");
const allocator = std.heap.page_allocator;
const random = std.rand.Random;
const common = @import("common.zig");
const cpu = @cImport({
@cInclude("include/concrete-cpu.h");
});
// Holds everything needed to run a programmable bootstrap through the
// concrete-cpu C API: both LWE secret keys, the Fourier-domain bootstrap
// key, the FFT plan and a scratch buffer sized by the library.
const KeySet = struct {
    in_dim: u64,
    glwe_dim: u64,
    polynomial_size: u64,
    level: u64,
    base_log: u64,
    // Secret key for the (small) input LWE dimension.
    in_sk: []u64,
    // Secret key for the output dimension (glwe_dim * polynomial_size).
    out_sk: []u64,
    // Bootstrap key in the Fourier domain.
    bsk_f: []f64,
    fft: *cpu.Fft,
    // Scratch memory for concrete_cpu_bootstrap_lwe_ciphertext_u64, sized by
    // the *_scratch query in init().
    stack: []u8,

    // Generates fresh keys and allocates the FFT plan and scratch buffer.
    // On error, everything allocated so far is released via errdefer, so a
    // failed init does not leak (the original leaked in_sk/out_sk/fft/bsk_f
    // when a later allocation failed).
    pub fn init(
        csprng: *cpu.Csprng,
        in_dim: u64,
        glwe_dim: u64,
        polynomial_size: u64,
        level: u64,
        base_log: u64,
        key_variance: f64,
    ) !KeySet {
        var raw_fft = c.aligned_alloc(cpu.CONCRETE_FFT_ALIGN, cpu.CONCRETE_FFT_SIZE);
        const fft = @ptrCast(*cpu.Fft, raw_fft);
        errdefer c.free(raw_fft);
        cpu.concrete_cpu_construct_concrete_fft(fft, polynomial_size);
        errdefer cpu.concrete_cpu_destroy_concrete_fft(fft);
        const out_dim = glwe_dim * polynomial_size;
        const in_sk = try allocator.alloc(u64, in_dim);
        errdefer allocator.free(in_sk);
        cpu.concrete_cpu_init_lwe_secret_key_u64(in_sk.ptr, in_dim, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
        const out_sk = try allocator.alloc(u64, out_dim);
        errdefer allocator.free(out_sk);
        cpu.concrete_cpu_init_lwe_secret_key_u64(out_sk.ptr, out_dim, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
        const bsk_f = try common.new_bsk(
            csprng,
            in_dim,
            glwe_dim,
            polynomial_size,
            level,
            base_log,
            key_variance,
            in_sk,
            out_sk,
            fft,
        );
        errdefer allocator.free(bsk_f);
        // Ask the library how much scratch memory the bootstrap needs.
        var stack_size: usize = 0;
        var stack_align: usize = 0;
        try std.testing.expect(
            cpu.concrete_cpu_bootstrap_lwe_ciphertext_u64_scratch(
                &stack_size,
                &stack_align,
                glwe_dim,
                polynomial_size,
                fft,
            ) == 0,
        );
        // `orelse unreachable` makes an allocation failure panic instead of
        // casting a null pointer, matching the other aligned_alloc call
        // sites in these tests.
        const stack = @ptrCast([*]u8, c.aligned_alloc(stack_align, stack_size) orelse unreachable)[0..stack_size];
        return KeySet{
            .in_dim = in_dim,
            .glwe_dim = glwe_dim,
            .polynomial_size = polynomial_size,
            .level = level,
            .base_log = base_log,
            .in_sk = in_sk,
            .out_sk = out_sk,
            .bsk_f = bsk_f,
            .fft = fft,
            .stack = stack,
        };
    }

    // Encrypts `pt` under in_sk, bootstraps it through `lut`, decrypts the
    // result under out_sk and returns the (noisy) image.
    pub fn bootstrap(
        self: *KeySet,
        pt: u64,
        encryption_variance: f64,
        lut: []u64,
        csprng: *cpu.Csprng,
    ) !u64 {
        const out_dim = self.glwe_dim * self.polynomial_size;
        const in_ct = try allocator.alloc(u64, self.in_dim + 1);
        defer allocator.free(in_ct);
        const out_ct = try allocator.alloc(u64, out_dim + 1);
        defer allocator.free(out_ct);
        cpu.concrete_cpu_encrypt_lwe_ciphertext_u64(
            self.in_sk.ptr,
            in_ct.ptr,
            pt,
            self.in_dim,
            encryption_variance,
            csprng,
            &cpu.CONCRETE_CSPRNG_VTABLE,
        );
        cpu.concrete_cpu_bootstrap_lwe_ciphertext_u64(
            out_ct.ptr,
            in_ct.ptr,
            lut.ptr,
            self.bsk_f.ptr,
            self.level,
            self.base_log,
            self.glwe_dim,
            self.polynomial_size,
            self.in_dim,
            self.fft,
            self.stack.ptr,
            self.stack.len,
        );
        var image: u64 = 0;
        cpu.concrete_cpu_decrypt_lwe_ciphertext_u64(self.out_sk.ptr, out_ct.ptr, out_dim, &image);
        return image;
    }

    // Releases the keys, the FFT plan and the scratch buffer.
    pub fn deinit(
        self: *KeySet,
    ) void {
        allocator.free(self.in_sk);
        allocator.free(self.out_sk);
        allocator.free(self.bsk_f);
        cpu.concrete_cpu_destroy_concrete_fft(self.fft);
        c.free(self.fft);
        c.free(self.stack.ptr);
    }
};
// Expands a small lookup table into a trivial GLWE accumulator of
// (glwe_dim + 1) * polynomial_size words: the mask part is all zeros and
// the body polynomial repeats each table entry over its mega-case.
// Caller owns (and must free) the returned buffer.
fn expand_lut(lut: []u64, glwe_dim: u64, polynomial_size: u64) ![]u64 {
    const raw_lut = try allocator.alloc(u64, (glwe_dim + 1) * polynomial_size);
    // Each table entry must cover a whole number of coefficients.
    std.debug.assert(polynomial_size % lut.len == 0);
    const case_size = polynomial_size / lut.len;
    // Zero the mask part of the trivial GLWE.
    std.mem.set(u64, raw_lut[0 .. glwe_dim * polynomial_size], 0);
    // Fill the body: entry k is replicated over its case_size-wide slot.
    const body = raw_lut[glwe_dim * polynomial_size ..];
    var entry: usize = 0;
    while (entry < lut.len) : (entry += 1) {
        std.mem.set(u64, body[entry * case_size .. (entry + 1) * case_size], lut[entry]);
    }
    return raw_lut;
}
// Encodes `lut_index` as a torus plaintext centered in its lut slot,
// encrypts it, bootstraps it through `lut`, and returns the decrypted
// (noisy) image.
fn encrypt_bootstrap_decrypt(
    csprng: *cpu.Csprng,
    lut: []u64,
    lut_index: u64,
    in_dim: u64,
    glwe_dim: u64,
    polynomial_size: u64,
    level: u64,
    base_log: u64,
    key_variance: f64,
    encryption_variance: f64,
) !u64 {
    var keys = try KeySet.init(
        csprng,
        in_dim,
        glwe_dim,
        polynomial_size,
        level,
        base_log,
        key_variance,
    );
    defer keys.deinit();
    const glwe_lut = try expand_lut(lut, glwe_dim, polynomial_size);
    defer allocator.free(glwe_lut);
    // Center the plaintext in its slot: (index + 1/2) / (2 * precision),
    // scaled to the 64-bit torus.
    const precision = lut.len;
    const slot = (@intToFloat(f64, lut_index) + 0.5) / (2.0 * @intToFloat(f64, precision));
    const pt = slot * std.math.pow(f64, 2.0, 64);
    return try keys.bootstrap(@floatToInt(u64, pt), encryption_variance, glwe_lut, csprng);
}
// Bootstraps a random index through a random 16-entry lut and checks the
// decrypted image is within 0.1% of the expected torus value.
test "bootstrap" {
    const csprng_mem = c.aligned_alloc(cpu.CONCRETE_CSPRNG_ALIGN, cpu.CONCRETE_CSPRNG_SIZE);
    defer c.free(csprng_mem);
    const csprng = @ptrCast(*cpu.Csprng, csprng_mem);
    // Deterministic seed: sixteen 0x01 bytes.
    cpu.concrete_cpu_construct_concrete_csprng(
        csprng,
        cpu.Uint128{ .little_endian_bytes = [_]u8{1} ** 16 },
    );
    defer cpu.concrete_cpu_destroy_concrete_csprng(csprng);
    // Random table over 2^4 entries, random index in [0, 2 * precision).
    const log2_precision = 4;
    const precision = 1 << log2_precision;
    const lut = try allocator.alloc(u64, precision);
    defer allocator.free(lut);
    try std.os.getrandom(std.mem.sliceAsBytes(lut));
    var lut_index: u64 = 0;
    try std.os.getrandom(std.mem.asBytes(&lut_index));
    lut_index %= 2 * precision;
    // Small toy parameters with very low noise.
    const in_dim = 3;
    const glwe_dim = 1;
    const log2_poly_size = 10;
    const polynomial_size = 1 << log2_poly_size;
    const level = 3;
    const base_log = 10;
    const key_variance = 0.0000000000000000000001;
    const encryption_variance = 0.0000000000000000000001;
    const image = try encrypt_bootstrap_decrypt(csprng, lut, lut_index, in_dim, glwe_dim, polynomial_size, level, base_log, key_variance, encryption_variance);
    // Indices in the second half of the torus read the negated table.
    const expected_image = if (lut_index < precision) lut[lut_index] else -%lut[lut_index - precision];
    // Signed distance between image and expectation, as a torus fraction.
    const err = @intToFloat(f64, @bitCast(i64, image -% expected_image)) / std.math.pow(f64, 2.0, 64);
    try std.testing.expect(@fabs(err) < 0.001);
}

View File

@@ -0,0 +1,48 @@
const c = @cImport({
@cInclude("stdlib.h");
});
const std = @import("std");
const cpu = @cImport({
@cInclude("include/concrete-cpu.h");
});
const allocator = std.heap.page_allocator;
// Round-trips `pt` through LWE encrypt/decrypt with a fresh secret key of
// dimension `dim`; returns the noisy decryption.
fn test_encrypt_decrypt(csprng: *cpu.Csprng, pt: u64, dim: u64) !u64 {
    const secret_key = try allocator.alloc(u64, dim);
    defer allocator.free(secret_key);
    cpu.concrete_cpu_init_lwe_secret_key_u64(secret_key.ptr, dim, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
    // An LWE ciphertext is dim mask words plus one body word.
    const ciphertext = try allocator.alloc(u64, dim + 1);
    defer allocator.free(ciphertext);
    cpu.concrete_cpu_encrypt_lwe_ciphertext_u64(secret_key.ptr, ciphertext.ptr, pt, dim, 0.000000000000001, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
    var decrypted: u64 = 0;
    cpu.concrete_cpu_decrypt_lwe_ciphertext_u64(secret_key.ptr, ciphertext.ptr, dim, &decrypted);
    return decrypted;
}
// Round-trips a plaintext at the top of the torus (1 << 63) and checks the
// relative decryption error stays below 0.1%.
test "encryption" {
    const csprng_mem = c.aligned_alloc(cpu.CONCRETE_CSPRNG_ALIGN, cpu.CONCRETE_CSPRNG_SIZE);
    defer c.free(csprng_mem);
    const csprng = @ptrCast(*cpu.Csprng, csprng_mem);
    // Deterministic seed: sixteen 0x01 bytes.
    cpu.concrete_cpu_construct_concrete_csprng(
        csprng,
        cpu.Uint128{ .little_endian_bytes = [_]u8{1} ** 16 },
    );
    defer cpu.concrete_cpu_destroy_concrete_csprng(csprng);
    const pt = 1 << 63;
    const noisy = try test_encrypt_decrypt(csprng, pt, 1024);
    // Signed distance between decryption and plaintext, as a torus fraction.
    const err = @intToFloat(f64, @bitCast(i64, noisy -% pt)) / std.math.pow(f64, 2.0, 64);
    try std.testing.expect(@fabs(err) < 0.001);
}

View File

@@ -0,0 +1,200 @@
const c = @cImport({
@cInclude("stdlib.h");
});
const std = @import("std");
const allocator = std.heap.page_allocator;
const common = @import("common.zig");
const cpu = @cImport({
@cInclude("include/concrete-cpu.h");
});
// End-to-end test of circuit bootstrapping + vertical packing through the
// concrete-cpu C API: ten individually encrypted bits (the value 610) are
// recombined through a lookup table computing x + 1, then the single output
// ciphertext is decrypted and decoded.
// NOTE(review): small_sk, big_sk and the fft plan allocated here are never
// freed; tolerable in a test using page_allocator, but worth confirming.
fn test3(csprng: *cpu.Csprng, polynomial_size: usize) !void {
    // Fixed toy parameters; only polynomial_size varies between calls so the
    // caller can exercise different lut-evaluation code paths.
    const glwe_dim: usize = 1;
    const small_dim: usize = 4;
    const level_bsk: usize = 4;
    const base_log_bsk: usize = 9;
    const level_pksk: usize = 2;
    const base_log_pksk: usize = 15;
    const level_cbs: usize = 4;
    const base_log_cbs: usize = 6;
    const variance: f64 = std.math.pow(f64, 2.0, -100);
    const big_dim = glwe_dim * polynomial_size;
    // Small LWE key (input side) and big LWE key (flattened GLWE key).
    const small_sk = try allocator.alloc(u64, small_dim);
    cpu.concrete_cpu_init_lwe_secret_key_u64(small_sk.ptr, small_dim, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
    const big_sk = try allocator.alloc(u64, big_dim);
    cpu.concrete_cpu_init_lwe_secret_key_u64(big_sk.ptr, big_dim, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
    var raw_fft = c.aligned_alloc(cpu.CONCRETE_FFT_ALIGN, cpu.CONCRETE_FFT_SIZE);
    const fft = @ptrCast(*cpu.Fft, raw_fft);
    cpu.concrete_cpu_construct_concrete_fft(fft, polynomial_size);
    // Fourier-domain bootstrap key from the small key to the big key.
    const bsk_f = try common.new_bsk(
        csprng,
        small_dim,
        glwe_dim,
        polynomial_size,
        level_bsk,
        base_log_bsk,
        variance,
        small_sk,
        big_sk,
        fft,
    );
    defer allocator.free(bsk_f);
    // Private functional packing keyswitch keys for circuit bootstrapping:
    // one key per GLWE mask element plus one for the body (glwe_dim + 1).
    const cbs_pfpksk_size = cpu.concrete_cpu_lwe_packing_keyswitch_key_size(glwe_dim, polynomial_size, level_pksk, big_dim);
    const cbs_pfpksk = try allocator.alloc(u64, cbs_pfpksk_size * (glwe_dim + 1));
    defer allocator.free(cbs_pfpksk);
    cpu.concrete_cpu_init_lwe_circuit_bootstrap_private_functional_packing_keyswitch_keys_u64(
        cbs_pfpksk.ptr,
        big_sk.ptr,
        big_sk.ptr,
        big_dim,
        polynomial_size,
        glwe_dim,
        level_pksk,
        base_log_pksk,
        variance,
        1,
        csprng,
        &cpu.CONCRETE_CSPRNG_VTABLE,
    );
    // We are going to encrypt two ciphertexts with 5 bits each.
    const number_of_input_bits: usize = 10;
    // Test on 610, binary representation 10011 00010.
    const val: u64 = 610;
    const one: u64 = 1;
    const extract_bits_output_buffer = try allocator.alloc(u64, number_of_input_bits * (small_dim + 1));
    defer allocator.free(extract_bits_output_buffer);
    var i: u64 = 0;
    // Encrypt each bit of `val` separately (most significant bit first),
    // mimicking the output layout of the bit-extraction operation. Each bit
    // is placed in the most significant torus position (<< 63).
    while (i < number_of_input_bits) {
        const bit: u64 =
            (val >> @intCast(u6, number_of_input_bits - i - 1)) % 2;
        cpu.concrete_cpu_encrypt_lwe_ciphertext_u64(small_sk.ptr, extract_bits_output_buffer[(small_dim + 1) * i ..].ptr, bit << 63, small_dim, variance, csprng, &cpu.CONCRETE_CSPRNG_VTABLE);
        i += 1;
    }
    // We'll apply a single table look-up computing x + 1 to our 10-bit input
    // integer that was represented over two 5-bit ciphertexts.
    const number_of_luts_and_output_cts: usize = 1;
    var cbs_vp_output_buffer = try allocator.alloc(u64, (big_dim + 1) * number_of_luts_and_output_cts);
    defer allocator.free(cbs_vp_output_buffer);
    // A single lut with one entry per possible 10-bit input, producing a
    // single output ciphertext.
    const luts_length = number_of_luts_and_output_cts * (1 << number_of_input_bits);
    var luts = try allocator.alloc(u64, luts_length);
    defer allocator.free(luts);
    // Messages live in the top number_of_input_bits bits of the torus.
    const delta_log_lut = 64 - number_of_input_bits;
    i = 0;
    // lut[i] = (i + 1) mod 2^10, shifted into the message position.
    while (i < luts_length) {
        luts[i] = ((i + 1) % (one << number_of_input_bits)) << delta_log_lut;
        i += 1;
    }
    {
        // Query the scratch size/alignment required by the combined circuit
        // bootstrap + vertical packing call, then run it.
        var stack_align: usize = 0;
        var stack_size: usize = 0;
        try std.testing.expect(cpu.concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64_scratch(
            &stack_size,
            &stack_align,
            number_of_luts_and_output_cts,
            small_dim,
            number_of_input_bits,
            1 << number_of_input_bits,
            number_of_luts_and_output_cts,
            glwe_dim,
            polynomial_size,
            polynomial_size,
            level_cbs,
            fft,
        ) == 0);
        const stack = @ptrCast([*]u8, c.aligned_alloc(stack_align, stack_size) orelse unreachable)[0..stack_size];
        defer c.free(stack.ptr);
        cpu.concrete_cpu_circuit_bootstrap_boolean_vertical_packing_lwe_ciphertext_u64(
            cbs_vp_output_buffer.ptr,
            extract_bits_output_buffer.ptr,
            luts.ptr,
            bsk_f.ptr,
            cbs_pfpksk.ptr,
            big_dim,
            number_of_luts_and_output_cts,
            small_dim,
            number_of_input_bits,
            1 << number_of_input_bits,
            number_of_luts_and_output_cts,
            level_bsk,
            base_log_bsk,
            glwe_dim,
            polynomial_size,
            small_dim,
            level_pksk,
            base_log_pksk,
            big_dim,
            glwe_dim,
            polynomial_size,
            glwe_dim + 1,
            level_cbs,
            base_log_cbs,
            fft,
            stack.ptr,
            stack_size,
        );
    }
    // The lut computes x + 1, so decrypting the output under the big key and
    // rounding to the top 10 bits should decode to val + 1.
    const expected = val + 1;
    var decrypted: u64 = 0;
    cpu.concrete_cpu_decrypt_lwe_ciphertext_u64(big_sk.ptr, cbs_vp_output_buffer.ptr, big_dim, &decrypted);
    const rounded =
        common.closest_representable(decrypted, 1, number_of_input_bits);
    const decoded = rounded >> delta_log_lut;
    std.debug.assert(decoded == expected);
}
// Runs the circuit-bootstrap + vertical-packing test for three polynomial
// sizes, covering the code paths named in the inline comments below.
test "encryption" {
    const csprng_mem = c.aligned_alloc(cpu.CONCRETE_CSPRNG_ALIGN, cpu.CONCRETE_CSPRNG_SIZE);
    defer c.free(csprng_mem);
    const csprng = @ptrCast(*cpu.Csprng, csprng_mem);
    // Deterministic seed: sixteen 0x01 bytes.
    cpu.concrete_cpu_construct_concrete_csprng(
        csprng,
        cpu.Uint128{ .little_endian_bytes = [_]u8{1} ** 16 },
    );
    defer cpu.concrete_cpu_destroy_concrete_csprng(csprng);
    // CMUX tree
    try test3(csprng, 512);
    // No CMUX tree
    try test3(csprng, 1024);
    // Expanded lut
    try test3(csprng, 2048);
}

View File

@@ -60,22 +60,33 @@ set(CONCRETELANG_BINARY_DIR ${CMAKE_CURRENT_BINARY_DIR})
# -------------------------------------------------------------------------------
include_directories(${PROJECT_SOURCE_DIR}/parameter-curves/concrete-security-curves-cpp/include)
# -------------------------------------------------------------------------------
# Concrete Backends
# -------------------------------------------------------------------------------
# Root of the in-repo backend implementations (concrete-cpu now lives in the
# mono repository rather than a submodule).
set(CONCRETE_BACKENDS_DIR "${PROJECT_SOURCE_DIR}/../../../backends")
# -------------------------------------------------------------------------------
# Concrete CPU Configuration
# -------------------------------------------------------------------------------
# NOTE(review): the region below appears to be an interleaved diff rendering —
# both the old and the new versions of several lines are present:
# CONCRETE_CPU_STATIC_LIB is set twice, ExternalProject_Add carries two
# BINARY_DIR arguments and both "BUILD_COMMAND cargo build" and
# "COMMAND cargo build --release", and a stray "LOG_BUILD ON)" closes the call
# early. This is not valid standalone CMake as-is; confirm against the actual
# CMakeLists.txt (presumably only the CONCRETE_CPU_DIR / CONCRETE_CPU_RELEASE_DIR
# variants survive) before building.
set(CONCRETE_CPU_STATIC_LIB "${PROJECT_SOURCE_DIR}/concrete-cpu/target/release/libconcrete_cpu.a")
set(CONCRETE_CPU_DIR "${CONCRETE_BACKENDS_DIR}/concrete-cpu")
set(CONCRETE_CPU_RELEASE_DIR "${CONCRETE_CPU_DIR}/target/release")
set(CONCRETE_CPU_INCLUDE_DIR "${CONCRETE_CPU_DIR}/include")
set(CONCRETE_CPU_STATIC_LIB "${CONCRETE_CPU_RELEASE_DIR}/libconcrete_cpu.a")
# Build the Rust static library with cargo; the produced archive is imported
# below as the `concrete_cpu` target.
ExternalProject_Add(
  concrete_cpu_rust
  DOWNLOAD_COMMAND ""
  CONFIGURE_COMMAND "" OUTPUT "${CONCRETE_CPU_STATIC_LIB}"
  BUILD_COMMAND cargo build
  COMMAND cargo build --release
  BINARY_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu"
  BINARY_DIR "${CONCRETE_CPU_DIR}"
  INSTALL_COMMAND ""
  LOG_BUILD ON)
  LOG_BUILD ON
  LOG_OUTPUT_ON_FAILURE ON)
add_library(concrete_cpu STATIC IMPORTED)
# TODO - Change that to a location in the release dir
set(CONCRETE_CPU_INCLUDE_DIR "${PROJECT_SOURCE_DIR}/concrete-cpu/concrete-cpu")
set_target_properties(concrete_cpu PROPERTIES IMPORTED_LOCATION "${CONCRETE_CPU_STATIC_LIB}")
# Ensure cargo has produced the archive before anything links concrete_cpu
# (add_dependencies orders the build; the actual linking uses the imported
# location above).
add_dependencies(concrete_cpu concrete_cpu_rust)