feat: improve accuracy of p_error simulation in virtual circuits

This commit is contained in:
Umut
2023-02-07 12:10:13 +01:00
parent 840c0eba8c
commit 656761346a
6 changed files with 90 additions and 10 deletions

View File

@@ -2,17 +2,21 @@
Declaration of `Graph` class.
"""
import math
import re
from copy import deepcopy
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import networkx as nx
import numpy as np
import scipy.special
from ..dtypes import Float, Integer, UnsignedInteger
from .node import Node
from .operation import Operation
P_ERROR_PER_ERROR_SIZE_CACHE: Dict[float, Dict[int, float]] = {}
class Graph:
"""
@@ -81,6 +85,8 @@ class Graph:
nodes and their values during computation
"""
# pylint: disable=no-member,too-many-nested-blocks
if p_error is None:
p_error = 0.0
@@ -106,16 +112,46 @@ class Graph:
if pred_node.operation != Operation.Input:
dtype = node.inputs[index].dtype
if isinstance(dtype, Integer):
# this is not the real behavior of FHE
# it's a simplified model, and it will be replaced at one point
# see https://github.com/zama-ai/concrete-numpy/blob/main/docs/_static/p_error_simulation.pdf # noqa: E501 # pylint: disable=line-too-long
# to learn more about the distribution of error
if p_error not in P_ERROR_PER_ERROR_SIZE_CACHE:
std_score = math.sqrt(2) * scipy.special.erfcinv(p_error)
p_error_per_error_size = {}
error_size = 1
last_p = 1 - p_error
while last_p != 1.0 or error_size == 1:
new_std_score = (2 * error_size + 1) * std_score
new_p = scipy.special.erf(new_std_score / math.sqrt(2))
p_error_per_error_size[error_size] = new_p - last_p
last_p = new_p
error_size += 1
# ordering of `p_error_per_error_size` is relied on
# during the introduction of the error below
# thus we explicitly sort it to make sure it's ordered
p_error_per_error_size = dict(
sorted(p_error_per_error_size.items())
)
P_ERROR_PER_ERROR_SIZE_CACHE[p_error] = p_error_per_error_size
else: # pragma: no cover
p_error_per_error_size = P_ERROR_PER_ERROR_SIZE_CACHE[p_error]
error = np.random.rand(*pred_results[index].shape)
error = np.where(error < p_error**3, 3, error)
error = np.where(error < p_error**2, 2, error)
error = np.where(error < p_error, 1, np.where(error > 1, error, 0))
accumulated_p_error = 0.0
for error_size, p_error_for_size in p_error_per_error_size.items():
accumulated_p_error += p_error_for_size
error = np.where(error < accumulated_p_error, error_size, error)
error = np.where(error < 1, 0, error).astype(np.int64)
error_sign = np.random.rand(*pred_results[index].shape)
error_sign = np.where(error < 0.5, 1, -1)
error_sign = np.where(error_sign < 0.5, 1, -1).astype(np.int64)
new_results = pred_results[index] + (error * error_sign)

BIN
docs/_static/p_error_simulation.pdf vendored Normal file

Binary file not shown.

View File

@@ -15,6 +15,7 @@
packaging 23.0 Apache Software License; BSD License
pyparsing 3.0.9 MIT License
python-dateutil 2.8.2 Apache Software License; BSD License
scipy 1.7.3 BSD License
six 1.16.0 MIT License
torch 1.13.1 BSD License
typing-extensions 3.10.0.2 Python Software Foundation License

44
poetry.lock generated
View File

@@ -3274,6 +3274,48 @@ files = [
{file = "ruff-0.0.191.tar.gz", hash = "sha256:d698c4d5e3b2963cbbb7c2728f404091d5c47cdf8d94db3eb2f335e2a93a6b1b"},
]
[[package]]
name = "scipy"
version = "1.7.3"
description = "SciPy: Scientific Library for Python"
category = "main"
optional = false
python-versions = ">=3.7,<3.11"
files = [
{file = "scipy-1.7.3-1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:c9e04d7e9b03a8a6ac2045f7c5ef741be86727d8f49c45db45f244bdd2bcff17"},
{file = "scipy-1.7.3-1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b0e0aeb061a1d7dcd2ed59ea57ee56c9b23dd60100825f98238c06ee5cc4467e"},
{file = "scipy-1.7.3-1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b78a35c5c74d336f42f44106174b9851c783184a85a3fe3e68857259b37b9ffb"},
{file = "scipy-1.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:173308efba2270dcd61cd45a30dfded6ec0085b4b6eb33b5eb11ab443005e088"},
{file = "scipy-1.7.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:21b66200cf44b1c3e86495e3a436fc7a26608f92b8d43d344457c54f1c024cbc"},
{file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceebc3c4f6a109777c0053dfa0282fddb8893eddfb0d598574acfb734a926168"},
{file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7eaea089345a35130bc9a39b89ec1ff69c208efa97b3f8b25ea5d4c41d88094"},
{file = "scipy-1.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:304dfaa7146cffdb75fbf6bb7c190fd7688795389ad060b970269c8576d038e9"},
{file = "scipy-1.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:033ce76ed4e9f62923e1f8124f7e2b0800db533828c853b402c7eec6e9465d80"},
{file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4d242d13206ca4302d83d8a6388c9dfce49fc48fdd3c20efad89ba12f785bf9e"},
{file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8499d9dd1459dc0d0fe68db0832c3d5fc1361ae8e13d05e6849b358dc3f2c279"},
{file = "scipy-1.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca36e7d9430f7481fc7d11e015ae16fbd5575615a8e9060538104778be84addf"},
{file = "scipy-1.7.3-cp37-cp37m-win32.whl", hash = "sha256:e2c036492e673aad1b7b0d0ccdc0cb30a968353d2c4bf92ac8e73509e1bf212c"},
{file = "scipy-1.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:866ada14a95b083dd727a845a764cf95dd13ba3dc69a16b99038001b05439709"},
{file = "scipy-1.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:65bd52bf55f9a1071398557394203d881384d27b9c2cad7df9a027170aeaef93"},
{file = "scipy-1.7.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f99d206db1f1ae735a8192ab93bd6028f3a42f6fa08467d37a14eb96c9dd34a3"},
{file = "scipy-1.7.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5f2cfc359379c56b3a41b17ebd024109b2049f878badc1e454f31418c3a18436"},
{file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb7ae2c4dbdb3c9247e07acc532f91077ae6dbc40ad5bd5dca0bb5a176ee9bda"},
{file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c2d250074cfa76715d58830579c64dff7354484b284c2b8b87e5a38321672c"},
{file = "scipy-1.7.3-cp38-cp38-win32.whl", hash = "sha256:87069cf875f0262a6e3187ab0f419f5b4280d3dcf4811ef9613c605f6e4dca95"},
{file = "scipy-1.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:7edd9a311299a61e9919ea4192dd477395b50c014cdc1a1ac572d7c27e2207fa"},
{file = "scipy-1.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eef93a446114ac0193a7b714ce67659db80caf940f3232bad63f4c7a81bc18df"},
{file = "scipy-1.7.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb326658f9b73c07081300daba90a8746543b5ea177184daed26528273157294"},
{file = "scipy-1.7.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93378f3d14fff07572392ce6a6a2ceb3a1f237733bd6dcb9eb6a2b29b0d19085"},
{file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edad1cf5b2ce1912c4d8ddad20e11d333165552aba262c882e28c78bbc09dbf6"},
{file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d1cc2c19afe3b5a546ede7e6a44ce1ff52e443d12b231823268019f608b9b12"},
{file = "scipy-1.7.3-cp39-cp39-win32.whl", hash = "sha256:2c56b820d304dffcadbbb6cbfbc2e2c79ee46ea291db17e288e73cd3c64fefa9"},
{file = "scipy-1.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:3f78181a153fa21c018d346f595edd648344751d7f03ab94b398be2ad083ed3e"},
{file = "scipy-1.7.3.tar.gz", hash = "sha256:ab5875facfdef77e0a47d5fd39ea178b58e60e454a4c85aa1e52fcb80db7babf"},
]
[package.dependencies]
numpy = ">=1.16.5,<1.23.0"
[[package]]
name = "secretstorage"
version = "3.3.3"
@@ -3846,4 +3888,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
[metadata]
lock-version = "2.0"
python-versions = ">=3.7,<3.11"
content-hash = "1ba82f6d92e8cfd1636e23c6ad7374408f27cf86bb3d80c77ae9c2b109e8e8b8"
content-hash = "34cb8b9ae8d4245e0d2f98b37cd5c11288e55e4968fa05313b853510019b0ef6"

View File

@@ -44,6 +44,7 @@ numpy = "^1.21.0"
Pillow = "^9.0.0"
concrete-compiler = "^0.23.4"
torch = "^1.13.1"
scipy = "1.7.3"
[tool.poetry.dev-dependencies]
isort = "^5.10.1"

View File

@@ -302,9 +302,9 @@ def test_bad_server_save(helpers):
assert str(excinfo.value) == "Just-in-Time compilation cannot be saved"
@pytest.mark.parametrize("p_error", [0.5, 0.1, 0.01])
@pytest.mark.parametrize("p_error", [0.75, 0.5, 0.4, 0.25, 0.2, 0.1, 0.01, 0.001])
@pytest.mark.parametrize("bit_width", [10])
@pytest.mark.parametrize("sample_size", [100_000])
@pytest.mark.parametrize("sample_size", [1_000_000])
@pytest.mark.parametrize("tolerance", [0.075])
def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
"""
@@ -333,7 +333,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers):
expected_number_of_errors_on_average - (expected_number_of_errors_on_average * tolerance),
expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance),
]
assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1]
assert acceptable_number_of_errors[0] <= errors <= acceptable_number_of_errors[1]
def test_circuit_run_with_unused_arg(helpers):