mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
feat: improve accuracy of p_error simulation in virtual circuits
This commit is contained in:
@@ -2,17 +2,21 @@
|
||||
Declaration of `Graph` class.
|
||||
"""
|
||||
|
||||
import math
|
||||
import re
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
import scipy.special
|
||||
|
||||
from ..dtypes import Float, Integer, UnsignedInteger
|
||||
from .node import Node
|
||||
from .operation import Operation
|
||||
|
||||
P_ERROR_PER_ERROR_SIZE_CACHE: Dict[float, Dict[int, float]] = {}
|
||||
|
||||
|
||||
class Graph:
|
||||
"""
|
||||
@@ -81,6 +85,8 @@ class Graph:
|
||||
nodes and their values during computation
|
||||
"""
|
||||
|
||||
# pylint: disable=no-member,too-many-nested-blocks
|
||||
|
||||
if p_error is None:
|
||||
p_error = 0.0
|
||||
|
||||
@@ -106,16 +112,46 @@ class Graph:
|
||||
if pred_node.operation != Operation.Input:
|
||||
dtype = node.inputs[index].dtype
|
||||
if isinstance(dtype, Integer):
|
||||
# this is not the real behavior of FHE
|
||||
# it's a simplified model, and it will be replaced at one point
|
||||
# see https://github.com/zama-ai/concrete-numpy/blob/main/docs/_static/p_error_simulation.pdf # noqa: E501 # pylint: disable=line-too-long
|
||||
# to learn more about the distribution of error
|
||||
|
||||
if p_error not in P_ERROR_PER_ERROR_SIZE_CACHE:
|
||||
std_score = math.sqrt(2) * scipy.special.erfcinv(p_error)
|
||||
p_error_per_error_size = {}
|
||||
|
||||
error_size = 1
|
||||
last_p = 1 - p_error
|
||||
while last_p != 1.0 or error_size == 1:
|
||||
new_std_score = (2 * error_size + 1) * std_score
|
||||
new_p = scipy.special.erf(new_std_score / math.sqrt(2))
|
||||
|
||||
p_error_per_error_size[error_size] = new_p - last_p
|
||||
|
||||
last_p = new_p
|
||||
error_size += 1
|
||||
|
||||
# ordering of `p_error_per_error_size` is relied on
|
||||
# during the introduction of the error below
|
||||
# thus we explicitly sort it to make sure it's ordered
|
||||
p_error_per_error_size = dict(
|
||||
sorted(p_error_per_error_size.items())
|
||||
)
|
||||
|
||||
P_ERROR_PER_ERROR_SIZE_CACHE[p_error] = p_error_per_error_size
|
||||
else: # pragma: no cover
|
||||
p_error_per_error_size = P_ERROR_PER_ERROR_SIZE_CACHE[p_error]
|
||||
|
||||
error = np.random.rand(*pred_results[index].shape)
|
||||
error = np.where(error < p_error**3, 3, error)
|
||||
error = np.where(error < p_error**2, 2, error)
|
||||
error = np.where(error < p_error, 1, np.where(error > 1, error, 0))
|
||||
|
||||
accumulated_p_error = 0.0
|
||||
for error_size, p_error_for_size in p_error_per_error_size.items():
|
||||
accumulated_p_error += p_error_for_size
|
||||
error = np.where(error < accumulated_p_error, error_size, error)
|
||||
|
||||
error = np.where(error < 1, 0, error).astype(np.int64)
|
||||
|
||||
error_sign = np.random.rand(*pred_results[index].shape)
|
||||
error_sign = np.where(error < 0.5, 1, -1)
|
||||
error_sign = np.where(error_sign < 0.5, 1, -1).astype(np.int64)
|
||||
|
||||
new_results = pred_results[index] + (error * error_sign)
|
||||
|
||||
|
||||
BIN
docs/_static/p_error_simulation.pdf
vendored
Normal file
BIN
docs/_static/p_error_simulation.pdf
vendored
Normal file
Binary file not shown.
@@ -15,6 +15,7 @@
|
||||
packaging 23.0 Apache Software License; BSD License
|
||||
pyparsing 3.0.9 MIT License
|
||||
python-dateutil 2.8.2 Apache Software License; BSD License
|
||||
scipy 1.7.3 BSD License
|
||||
six 1.16.0 MIT License
|
||||
torch 1.13.1 BSD License
|
||||
typing-extensions 3.10.0.2 Python Software Foundation License
|
||||
|
||||
44
poetry.lock
generated
44
poetry.lock
generated
@@ -3274,6 +3274,48 @@ files = [
|
||||
{file = "ruff-0.0.191.tar.gz", hash = "sha256:d698c4d5e3b2963cbbb7c2728f404091d5c47cdf8d94db3eb2f335e2a93a6b1b"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "scipy"
|
||||
version = "1.7.3"
|
||||
description = "SciPy: Scientific Library for Python"
|
||||
category = "main"
|
||||
optional = false
|
||||
python-versions = ">=3.7,<3.11"
|
||||
files = [
|
||||
{file = "scipy-1.7.3-1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:c9e04d7e9b03a8a6ac2045f7c5ef741be86727d8f49c45db45f244bdd2bcff17"},
|
||||
{file = "scipy-1.7.3-1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b0e0aeb061a1d7dcd2ed59ea57ee56c9b23dd60100825f98238c06ee5cc4467e"},
|
||||
{file = "scipy-1.7.3-1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b78a35c5c74d336f42f44106174b9851c783184a85a3fe3e68857259b37b9ffb"},
|
||||
{file = "scipy-1.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:173308efba2270dcd61cd45a30dfded6ec0085b4b6eb33b5eb11ab443005e088"},
|
||||
{file = "scipy-1.7.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:21b66200cf44b1c3e86495e3a436fc7a26608f92b8d43d344457c54f1c024cbc"},
|
||||
{file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceebc3c4f6a109777c0053dfa0282fddb8893eddfb0d598574acfb734a926168"},
|
||||
{file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7eaea089345a35130bc9a39b89ec1ff69c208efa97b3f8b25ea5d4c41d88094"},
|
||||
{file = "scipy-1.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:304dfaa7146cffdb75fbf6bb7c190fd7688795389ad060b970269c8576d038e9"},
|
||||
{file = "scipy-1.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:033ce76ed4e9f62923e1f8124f7e2b0800db533828c853b402c7eec6e9465d80"},
|
||||
{file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4d242d13206ca4302d83d8a6388c9dfce49fc48fdd3c20efad89ba12f785bf9e"},
|
||||
{file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8499d9dd1459dc0d0fe68db0832c3d5fc1361ae8e13d05e6849b358dc3f2c279"},
|
||||
{file = "scipy-1.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca36e7d9430f7481fc7d11e015ae16fbd5575615a8e9060538104778be84addf"},
|
||||
{file = "scipy-1.7.3-cp37-cp37m-win32.whl", hash = "sha256:e2c036492e673aad1b7b0d0ccdc0cb30a968353d2c4bf92ac8e73509e1bf212c"},
|
||||
{file = "scipy-1.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:866ada14a95b083dd727a845a764cf95dd13ba3dc69a16b99038001b05439709"},
|
||||
{file = "scipy-1.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:65bd52bf55f9a1071398557394203d881384d27b9c2cad7df9a027170aeaef93"},
|
||||
{file = "scipy-1.7.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f99d206db1f1ae735a8192ab93bd6028f3a42f6fa08467d37a14eb96c9dd34a3"},
|
||||
{file = "scipy-1.7.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5f2cfc359379c56b3a41b17ebd024109b2049f878badc1e454f31418c3a18436"},
|
||||
{file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb7ae2c4dbdb3c9247e07acc532f91077ae6dbc40ad5bd5dca0bb5a176ee9bda"},
|
||||
{file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c2d250074cfa76715d58830579c64dff7354484b284c2b8b87e5a38321672c"},
|
||||
{file = "scipy-1.7.3-cp38-cp38-win32.whl", hash = "sha256:87069cf875f0262a6e3187ab0f419f5b4280d3dcf4811ef9613c605f6e4dca95"},
|
||||
{file = "scipy-1.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:7edd9a311299a61e9919ea4192dd477395b50c014cdc1a1ac572d7c27e2207fa"},
|
||||
{file = "scipy-1.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eef93a446114ac0193a7b714ce67659db80caf940f3232bad63f4c7a81bc18df"},
|
||||
{file = "scipy-1.7.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb326658f9b73c07081300daba90a8746543b5ea177184daed26528273157294"},
|
||||
{file = "scipy-1.7.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93378f3d14fff07572392ce6a6a2ceb3a1f237733bd6dcb9eb6a2b29b0d19085"},
|
||||
{file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edad1cf5b2ce1912c4d8ddad20e11d333165552aba262c882e28c78bbc09dbf6"},
|
||||
{file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d1cc2c19afe3b5a546ede7e6a44ce1ff52e443d12b231823268019f608b9b12"},
|
||||
{file = "scipy-1.7.3-cp39-cp39-win32.whl", hash = "sha256:2c56b820d304dffcadbbb6cbfbc2e2c79ee46ea291db17e288e73cd3c64fefa9"},
|
||||
{file = "scipy-1.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:3f78181a153fa21c018d346f595edd648344751d7f03ab94b398be2ad083ed3e"},
|
||||
{file = "scipy-1.7.3.tar.gz", hash = "sha256:ab5875facfdef77e0a47d5fd39ea178b58e60e454a4c85aa1e52fcb80db7babf"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
numpy = ">=1.16.5,<1.23.0"
|
||||
|
||||
[[package]]
|
||||
name = "secretstorage"
|
||||
version = "3.3.3"
|
||||
@@ -3846,4 +3888,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.7,<3.11"
|
||||
content-hash = "1ba82f6d92e8cfd1636e23c6ad7374408f27cf86bb3d80c77ae9c2b109e8e8b8"
|
||||
content-hash = "34cb8b9ae8d4245e0d2f98b37cd5c11288e55e4968fa05313b853510019b0ef6"
|
||||
|
||||
@@ -44,6 +44,7 @@ numpy = "^1.21.0"
|
||||
Pillow = "^9.0.0"
|
||||
concrete-compiler = "^0.23.4"
|
||||
torch = "^1.13.1"
|
||||
scipy = "1.7.3"
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
isort = "^5.10.1"
|
||||
|
||||
@@ -302,9 +302,9 @@ def test_bad_server_save(helpers):
|
||||
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("p_error", [0.5, 0.1, 0.01])
|
||||
@pytest.mark.parametrize("p_error", [0.75, 0.5, 0.4, 0.25, 0.2, 0.1, 0.01, 0.001])
|
||||
@pytest.mark.parametrize("bit_width", [10])
|
||||
@pytest.mark.parametrize("sample_size", [100_000])
|
||||
@pytest.mark.parametrize("sample_size", [1_000_000])
|
||||
@pytest.mark.parametrize("tolerance", [0.075])
|
||||
def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
"""
|
||||
@@ -333,7 +333,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
|
||||
expected_number_of_errors_on_average - (expected_number_of_errors_on_average * tolerance),
|
||||
expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance),
|
||||
]
|
||||
assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1]
|
||||
assert acceptable_number_of_errors[0] <= errors <= acceptable_number_of_errors[1]
|
||||
|
||||
|
||||
def test_circuit_run_with_unused_arg(helpers):
|
||||
|
||||
Reference in New Issue
Block a user