diff --git a/docs/tutorial/extensions.md b/docs/tutorial/extensions.md index 4ba066d49..db65da2aa 100644 --- a/docs/tutorial/extensions.md +++ b/docs/tutorial/extensions.md @@ -513,3 +513,37 @@ assert circuit.encrypt_run_decrypt(0, 3, -5) == -5 {% hint style="info" %} `fhe.if_then_else` is just an alias for [np.where](https://numpy.org/doc/stable/reference/generated/numpy.where.html). {% endhint %} + +## fhe.identity(value) + +Allows you to copy the value: + +```python +import numpy as np +from concrete import fhe + +@fhe.compiler({"x": "encrypted"}) +def f(x): + return fhe.identity(x) + +inputset = [np.random.randint(-10, 10) for _ in range(10)] +circuit = f.compile(inputset) + +assert circuit.encrypt_run_decrypt(0) == 0 +assert circuit.encrypt_run_decrypt(1) == 1 +assert circuit.encrypt_run_decrypt(-1) == -1 +assert circuit.encrypt_run_decrypt(-3) == -3 +assert circuit.encrypt_run_decrypt(5) == 5 +``` + +{% hint style="info" %} +Identity extension can be used to clone an input while changing its bit-width. Imagine you +have `return x**2, x+100` where `x` is 2-bits. Because of `x+100`, `x` will be assigned 7-bits +and `x**2` would be more expensive than it needs to be. If `return x**2, fhe.identity(x)+100` +is used instead, `x` will be assigned 2-bits as it should and `fhe.identity(x)` will be assigned +7-bits as necessary. +{% endhint %} + +{% hint style="warning" %} +Identity extension only works in `Native` encoding, which is usually selected when all table lookups in the circuit are below or equal to 8 bits. +{% endhint %} diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index 879198b1f..b38103742 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -35,6 +35,7 @@ from .extensions import ( bits, conv, hint, + identity, if_then_else, maxpool, multivariate, diff --git a/frontends/concrete-python/concrete/fhe/extensions/__init__.py b/frontends/concrete-python/concrete/fhe/extensions/__init__.py index cdb0ae328..825074896 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/__init__.py +++ b/frontends/concrete-python/concrete/fhe/extensions/__init__.py @@ -8,6 +8,7 @@ from .array import array from .bits import bits from .convolution import conv from .hint import hint +from .identity import identity from .maxpool import maxpool from .multivariate import multivariate from .ones import one, ones, ones_like diff --git a/frontends/concrete-python/concrete/fhe/extensions/identity.py b/frontends/concrete-python/concrete/fhe/extensions/identity.py new file mode 100644 index 000000000..ab13fc7db --- /dev/null +++ b/frontends/concrete-python/concrete/fhe/extensions/identity.py @@ -0,0 +1,37 @@ +""" +Declaration of `identity` extension. +""" + +from copy import deepcopy +from typing import Any, Union + +from ..representation import Node +from ..tracing import Tracer + + +def identity(x: Union[Tracer, Any]) -> Union[Tracer, Any]: + """ + Apply identity function to x. + + Bit-width of the input and the output can be different. + + Args: + x (Union[Tracer, Any]): + input to identity + + Returns: + Union[Tracer, Any]: + identity tracer if called with a tracer + deepcopy of the input otherwise + """ + + if not isinstance(x, Tracer): + return deepcopy(x) + + computation = Node.generic( + "identity", + [deepcopy(x.output)], + x.output, + lambda x: deepcopy(x), # pylint: disable=unnecessary-lambda + ) + return Tracer(computation, [x]) diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 3739b0498..2d8f33938 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -2269,6 +2269,38 @@ class Context: ) -> Conversion: return self.comparison(resulting_type, x, y, accept={Comparison.GREATER, Comparison.EQUAL}) + def identity(self, resulting_type: ConversionType, x: Conversion) -> Conversion: + assert ( + x.is_encrypted + and resulting_type.is_encrypted + and x.shape == resulting_type.shape + and x.is_signed == resulting_type.is_signed + ) + + if resulting_type.bit_width == x.bit_width: + return x + + result = self.extract_bits( + self.tensor(self.eint(resulting_type.bit_width), shape=x.shape), + x, + bits=slice(0, x.original_bit_width), + ) + + if x.is_signed: + sign = self.extract_bits( + self.tensor(self.eint(resulting_type.bit_width), shape=x.shape), + x, + bits=(x.original_bit_width - 1), + ) + base = self.mul( + resulting_type, + sign, + self.constant(self.i(sign.bit_width + 1), -(2**x.original_bit_width)), + ) + result = self.add(resulting_type, base, result) + + return result + def index_static( self, resulting_type: ConversionType, diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 3204399fc..a8bb13fad 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -359,6 +359,10 @@ class Converter: return self.tlu(ctx, node, preds) + def identity(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + assert len(preds) == 1 + return ctx.identity(ctx.typeof(node), preds[0]) + def index_static(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 1 return ctx.index_static( diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py index b84d80ea1..6aace8bda 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py @@ -151,7 +151,7 @@ class ProcessRounding(GraphProcessor): return identity identity = Node.generic( - "identity", + "reinterpret", [deepcopy(node.output)], deepcopy(node.output), lambda x: x, diff --git a/frontends/concrete-python/tests/execution/test_identity.py b/frontends/concrete-python/tests/execution/test_identity.py new file mode 100644 index 000000000..b58b65012 --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_identity.py @@ -0,0 +1,71 @@ +""" +Tests of execution of identity extension. +""" + +import random + +import numpy as np +import pytest + +from concrete import fhe +from concrete.fhe.dtypes import Integer + +# pylint: disable=redefined-outer-name + + +@pytest.mark.parametrize( + "sample,expected_output", + [ + (0, 0), + (1, 1), + (-1, -1), + (10, 10), + (-10, -10), + ], +) +def test_plain_identity(sample, expected_output): + """ + Test plain evaluation of identity extension. + """ + assert fhe.identity(sample) == expected_output + + +operations = [ + lambda x: fhe.identity(x), + lambda x: fhe.identity(x) + 100, +] + +cases = [] +for function in operations: + for bit_width in [1, 2, 3, 4, 5, 8, 12]: + for is_signed in [False, True]: + for shape in [(), (3,), (2, 3)]: + cases += [ + [ + function, + bit_width, + is_signed, + shape, + ] + ] + + +@pytest.mark.parametrize( + "function,bit_width,is_signed,shape", + cases, +) +def test_identity(function, bit_width, is_signed, shape, helpers): + """ + Test encrypted evaluation of identity extension. + """ + + dtype = Integer(is_signed, bit_width) + + inputset = [np.random.randint(dtype.min(), dtype.max() + 1, size=shape) for _ in range(100)] + configuration = helpers.configuration() + + compiler = fhe.Compiler(function, {"x": "encrypted"}) + circuit = compiler.compile(inputset, configuration) + + for value in random.sample(inputset, 8): + helpers.check_execution(circuit, function, value, retries=3) diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index d2219a5df..99369f710 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -516,7 +516,7 @@ Function you are trying to compile cannot be compiled ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear %1 = round_bit_pattern(%0, lsbs_to_remove=2) # ClearScalar ∈ [12, 32] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported -%2 = identity(%1) # ClearScalar +%2 = reinterpret(%1) # ClearScalar return %2 """, # noqa: E501