diff --git a/docs/core-features/extensions.md b/docs/core-features/extensions.md index 967c3ae2e..ef69c2e1c 100644 --- a/docs/core-features/extensions.md +++ b/docs/core-features/extensions.md @@ -258,6 +258,28 @@ for x in range(10): assert np.array_equal(circuit.encrypt_run_decrypt(x), np.array([[x, x, x], [x, x, x]]) + 1) ``` +## fhe.constant(value) + +Allows you to create an encrypted constant of a given value. + +```python +from concrete import fhe +import numpy as np + +@fhe.compiler({"x": "encrypted", "a":"clear"}) +def f(x, a): + z = fhe.constant(a) + return x + z + +inputset = range(10) +circuit = f.compile(inputset) + +for x in range(10): + assert circuit.encrypt_run_decrypt(x, 5) == x + 5 +``` + +This extension is also compatible with constant arrays. + ## fhe.hint(value, \*\*kwargs) Hint properties of a value. Imagine you have this circuit: diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index f82430727..70c13ed30 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -52,6 +52,7 @@ from .extensions import ( LookupTable, array, bits, + constant, conv, hint, identity, diff --git a/frontends/concrete-python/concrete/fhe/extensions/__init__.py b/frontends/concrete-python/concrete/fhe/extensions/__init__.py index e6437e1c4..95809224b 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/__init__.py +++ b/frontends/concrete-python/concrete/fhe/extensions/__init__.py @@ -6,6 +6,7 @@ from numpy import where as if_then_else from .array import array from .bits import bits +from .constant import constant from .convolution import conv from .hint import hint from .identity import identity, refresh diff --git a/frontends/concrete-python/concrete/fhe/extensions/constant.py b/frontends/concrete-python/concrete/fhe/extensions/constant.py new file mode 100644 index 000000000..24931c894 --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/constant.py @@ -0,0 +1,16 @@ +""" +Declaration of `constant` functions, to allow server side trivial encryption. +""" + +from typing import Any, Union + +from ..tracing import Tracer +from .zeros import zeros + + +def constant(x: Union[Tracer, Any]) -> Union[Tracer, Any]: + """ + Trivial encryption of a cleartext value. + """ + + return zeros(() if isinstance(x, int) else x.shape) + x diff --git a/frontends/concrete-python/tests/extensions/test_constant.py b/frontends/concrete-python/tests/extensions/test_constant.py new file mode 100644 index 000000000..48399ee99 --- /dev/null +++ b/frontends/concrete-python/tests/extensions/test_constant.py @@ -0,0 +1,44 @@ +""" +Tests of 'constant' extension. +""" + +import numpy as np + +from concrete import fhe + + +def test_constant_scalar(helpers): + """ + Test that fhe.constant works with scalars. + """ + configuration = helpers.configuration() + + @fhe.compiler({"x": "clear", "y": "encrypted"}) + def f(x, y): + return fhe.constant(x) + y + + inputset = [(np.random.randint(0, 2**5), np.random.randint(0, 2**5)) for _ in range(100)] + circuit = f.compile(inputset, configuration.fork()) + x = np.random.randint(0, 2**5) + y = np.random.randint(0, 2**5) + assert circuit.encrypt_run_decrypt(x, y) == x + y + + +def test_constant_tensor(helpers): + """ + Test that fhe.constant works with arrays. + """ + configuration = helpers.configuration() + + @fhe.compiler({"x": "clear", "y": "encrypted"}) + def f(x, y): + return fhe.constant(x) + y + + inputset = [ + (np.random.randint(0, 2**5, size=10), np.random.randint(0, 2**5, size=10)) + for _ in range(100) + ] + circuit = f.compile(inputset, configuration.fork()) + x = np.random.randint(0, 2**5, size=10) + y = np.random.randint(0, 2**5, size=10) + assert np.all(circuit.encrypt_run_decrypt(x, y) == x + y)