diff --git a/concrete/numpy/representation/graph.py b/concrete/numpy/representation/graph.py index 8c61cb7c5..1231214c7 100644 --- a/concrete/numpy/representation/graph.py +++ b/concrete/numpy/representation/graph.py @@ -2,17 +2,21 @@ Declaration of `Graph` class. """ +import math import re from copy import deepcopy from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import networkx as nx import numpy as np +import scipy.special from ..dtypes import Float, Integer, UnsignedInteger from .node import Node from .operation import Operation +P_ERROR_PER_ERROR_SIZE_CACHE: Dict[float, Dict[int, float]] = {} + class Graph: """ @@ -81,6 +85,8 @@ class Graph: nodes and their values during computation """ + # pylint: disable=no-member,too-many-nested-blocks + if p_error is None: p_error = 0.0 @@ -106,16 +112,46 @@ class Graph: if pred_node.operation != Operation.Input: dtype = node.inputs[index].dtype if isinstance(dtype, Integer): - # this is not the real behavior of FHE - # it's a simplified model, and it will be replaced at one point + # see https://github.com/zama-ai/concrete-numpy/blob/main/docs/_static/p_error_simulation.pdf # noqa: E501 # pylint: disable=line-too-long + # to learn more about the distribution of error + + if p_error not in P_ERROR_PER_ERROR_SIZE_CACHE: + std_score = math.sqrt(2) * scipy.special.erfcinv(p_error) + p_error_per_error_size = {} + + error_size = 1 + last_p = 1 - p_error + while last_p != 1.0 or error_size == 1: + new_std_score = (2 * error_size + 1) * std_score + new_p = scipy.special.erf(new_std_score / math.sqrt(2)) + + p_error_per_error_size[error_size] = new_p - last_p + + last_p = new_p + error_size += 1 + + # ordering of `p_error_per_error_size` is relied on + # during the introduction of the error below + # thus we explicitly sort it to make sure it's ordered + p_error_per_error_size = dict( + sorted(p_error_per_error_size.items()) + ) + + P_ERROR_PER_ERROR_SIZE_CACHE[p_error] = p_error_per_error_size + else: # pragma: no cover + p_error_per_error_size = P_ERROR_PER_ERROR_SIZE_CACHE[p_error] error = np.random.rand(*pred_results[index].shape) - error = np.where(error < p_error**3, 3, error) - error = np.where(error < p_error**2, 2, error) - error = np.where(error < p_error, 1, np.where(error > 1, error, 0)) + + accumulated_p_error = 0.0 + for error_size, p_error_for_size in p_error_per_error_size.items(): + accumulated_p_error += p_error_for_size + error = np.where(error < accumulated_p_error, error_size, error) + + error = np.where(error < 1, 0, error).astype(np.int64) error_sign = np.random.rand(*pred_results[index].shape) - error_sign = np.where(error < 0.5, 1, -1) + error_sign = np.where(error_sign < 0.5, 1, -1).astype(np.int64) new_results = pred_results[index] + (error * error_sign) diff --git a/docs/_static/p_error_simulation.pdf b/docs/_static/p_error_simulation.pdf new file mode 100644 index 000000000..c9ff8e655 Binary files /dev/null and b/docs/_static/p_error_simulation.pdf differ diff --git a/docs/linux.dependency.licenses.txt b/docs/linux.dependency.licenses.txt index de36ad750..f68e9254c 100644 --- a/docs/linux.dependency.licenses.txt +++ b/docs/linux.dependency.licenses.txt @@ -15,6 +15,7 @@ packaging 23.0 Apache Software License; BSD License pyparsing 3.0.9 MIT License python-dateutil 2.8.2 Apache Software License; BSD License + scipy 1.7.3 BSD License six 1.16.0 MIT License torch 1.13.1 BSD License typing-extensions 3.10.0.2 Python Software Foundation License diff --git a/poetry.lock b/poetry.lock index 540fbdbc6..3921708cf 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3274,6 +3274,48 @@ files = [ {file = "ruff-0.0.191.tar.gz", hash = "sha256:d698c4d5e3b2963cbbb7c2728f404091d5c47cdf8d94db3eb2f335e2a93a6b1b"}, ] +[[package]] +name = "scipy" +version = "1.7.3" +description = "SciPy: Scientific Library for Python" +category = "main" +optional = false +python-versions = ">=3.7,<3.11" +files = [ + {file = "scipy-1.7.3-1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:c9e04d7e9b03a8a6ac2045f7c5ef741be86727d8f49c45db45f244bdd2bcff17"}, + {file = "scipy-1.7.3-1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:b0e0aeb061a1d7dcd2ed59ea57ee56c9b23dd60100825f98238c06ee5cc4467e"}, + {file = "scipy-1.7.3-1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:b78a35c5c74d336f42f44106174b9851c783184a85a3fe3e68857259b37b9ffb"}, + {file = "scipy-1.7.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:173308efba2270dcd61cd45a30dfded6ec0085b4b6eb33b5eb11ab443005e088"}, + {file = "scipy-1.7.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:21b66200cf44b1c3e86495e3a436fc7a26608f92b8d43d344457c54f1c024cbc"}, + {file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceebc3c4f6a109777c0053dfa0282fddb8893eddfb0d598574acfb734a926168"}, + {file = "scipy-1.7.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f7eaea089345a35130bc9a39b89ec1ff69c208efa97b3f8b25ea5d4c41d88094"}, + {file = "scipy-1.7.3-cp310-cp310-win_amd64.whl", hash = "sha256:304dfaa7146cffdb75fbf6bb7c190fd7688795389ad060b970269c8576d038e9"}, + {file = "scipy-1.7.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:033ce76ed4e9f62923e1f8124f7e2b0800db533828c853b402c7eec6e9465d80"}, + {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4d242d13206ca4302d83d8a6388c9dfce49fc48fdd3c20efad89ba12f785bf9e"}, + {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:8499d9dd1459dc0d0fe68db0832c3d5fc1361ae8e13d05e6849b358dc3f2c279"}, + {file = "scipy-1.7.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca36e7d9430f7481fc7d11e015ae16fbd5575615a8e9060538104778be84addf"}, + {file = "scipy-1.7.3-cp37-cp37m-win32.whl", hash = "sha256:e2c036492e673aad1b7b0d0ccdc0cb30a968353d2c4bf92ac8e73509e1bf212c"}, + {file = "scipy-1.7.3-cp37-cp37m-win_amd64.whl", hash = "sha256:866ada14a95b083dd727a845a764cf95dd13ba3dc69a16b99038001b05439709"}, + {file = "scipy-1.7.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:65bd52bf55f9a1071398557394203d881384d27b9c2cad7df9a027170aeaef93"}, + {file = "scipy-1.7.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:f99d206db1f1ae735a8192ab93bd6028f3a42f6fa08467d37a14eb96c9dd34a3"}, + {file = "scipy-1.7.3-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:5f2cfc359379c56b3a41b17ebd024109b2049f878badc1e454f31418c3a18436"}, + {file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb7ae2c4dbdb3c9247e07acc532f91077ae6dbc40ad5bd5dca0bb5a176ee9bda"}, + {file = "scipy-1.7.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95c2d250074cfa76715d58830579c64dff7354484b284c2b8b87e5a38321672c"}, + {file = "scipy-1.7.3-cp38-cp38-win32.whl", hash = "sha256:87069cf875f0262a6e3187ab0f419f5b4280d3dcf4811ef9613c605f6e4dca95"}, + {file = "scipy-1.7.3-cp38-cp38-win_amd64.whl", hash = "sha256:7edd9a311299a61e9919ea4192dd477395b50c014cdc1a1ac572d7c27e2207fa"}, + {file = "scipy-1.7.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:eef93a446114ac0193a7b714ce67659db80caf940f3232bad63f4c7a81bc18df"}, + {file = "scipy-1.7.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:eb326658f9b73c07081300daba90a8746543b5ea177184daed26528273157294"}, + {file = "scipy-1.7.3-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:93378f3d14fff07572392ce6a6a2ceb3a1f237733bd6dcb9eb6a2b29b0d19085"}, + {file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edad1cf5b2ce1912c4d8ddad20e11d333165552aba262c882e28c78bbc09dbf6"}, + {file = "scipy-1.7.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5d1cc2c19afe3b5a546ede7e6a44ce1ff52e443d12b231823268019f608b9b12"}, + {file = "scipy-1.7.3-cp39-cp39-win32.whl", hash = "sha256:2c56b820d304dffcadbbb6cbfbc2e2c79ee46ea291db17e288e73cd3c64fefa9"}, + {file = "scipy-1.7.3-cp39-cp39-win_amd64.whl", hash = "sha256:3f78181a153fa21c018d346f595edd648344751d7f03ab94b398be2ad083ed3e"}, + {file = "scipy-1.7.3.tar.gz", hash = "sha256:ab5875facfdef77e0a47d5fd39ea178b58e60e454a4c85aa1e52fcb80db7babf"}, +] + +[package.dependencies] +numpy = ">=1.16.5,<1.23.0" + [[package]] name = "secretstorage" version = "3.3.3" @@ -3846,4 +3888,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools" [metadata] lock-version = "2.0" python-versions = ">=3.7,<3.11" -content-hash = "1ba82f6d92e8cfd1636e23c6ad7374408f27cf86bb3d80c77ae9c2b109e8e8b8" +content-hash = "34cb8b9ae8d4245e0d2f98b37cd5c11288e55e4968fa05313b853510019b0ef6" diff --git a/pyproject.toml b/pyproject.toml index c078a1db3..c6cb0d64e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ numpy = "^1.21.0" Pillow = "^9.0.0" concrete-compiler = "^0.23.4" torch = "^1.13.1" +scipy = "1.7.3" [tool.poetry.dev-dependencies] isort = "^5.10.1" diff --git a/tests/compilation/test_circuit.py b/tests/compilation/test_circuit.py index 87ece5235..641e406e3 100644 --- a/tests/compilation/test_circuit.py +++ b/tests/compilation/test_circuit.py @@ -302,9 +302,9 @@ def test_bad_server_save(helpers): assert str(excinfo.value) == "Just-in-Time compilation cannot be saved" -@pytest.mark.parametrize("p_error", [0.5, 0.1, 0.01]) +@pytest.mark.parametrize("p_error", [0.75, 0.5, 0.4, 0.25, 0.2, 0.1, 0.01, 0.001]) @pytest.mark.parametrize("bit_width", [10]) -@pytest.mark.parametrize("sample_size", [100_000]) +@pytest.mark.parametrize("sample_size", [1_000_000]) @pytest.mark.parametrize("tolerance", [0.075]) def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers): """ @@ -333,7 +333,7 @@ def test_virtual_p_error(p_error, bit_width, sample_size, tolerance, helpers): expected_number_of_errors_on_average - (expected_number_of_errors_on_average * tolerance), expected_number_of_errors_on_average + (expected_number_of_errors_on_average * tolerance), ] - assert acceptable_number_of_errors[0] < errors < acceptable_number_of_errors[1] + assert acceptable_number_of_errors[0] <= errors <= acceptable_number_of_errors[1] def test_circuit_run_with_unused_arg(helpers):