From 000ca60062e5856861fb7d6159f803ac95e9faeb Mon Sep 17 00:00:00 2001 From: Umut Date: Thu, 25 Jan 2024 15:07:05 +0300 Subject: [PATCH] feat(frontend-python): if then else extension --- docs/howto/configure.md | 2 + docs/tutorial/extensions.md | 34 +++- .../concrete-python/concrete/fhe/__init__.py | 1 + .../concrete/fhe/compilation/configuration.py | 4 + .../concrete/fhe/compilation/utils.py | 3 + .../concrete/fhe/extensions/__init__.py | 2 + .../concrete/fhe/mlir/context.py | 179 ++++++++++++++++++ .../concrete/fhe/mlir/converter.py | 4 + .../concrete/fhe/representation/node.py | 3 + .../tests/execution/test_if_then_else.py | 132 +++++++++++++ .../tests/mlir/test_converter.py | 76 ++++++++ 11 files changed, 439 insertions(+), 1 deletion(-) create mode 100644 frontends/concrete-python/tests/execution/test_if_then_else.py diff --git a/docs/howto/configure.md b/docs/howto/configure.md index 9921b803d..9d2fc7489 100644 --- a/docs/howto/configure.md +++ b/docs/howto/configure.md @@ -116,3 +116,5 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * Bit-width to start implementing the ReLU extension with [fhe.bits](../tutorial/bit_extraction.md). * **relu_on_bits_chunk_size**: int = 3, * Chunk size of the ReLU extension when [fhe.bits](../tutorial/bit_extraction.md) implementation is used. +* **if_then_else_chunk_size**: int = 3 + * Chunk size to use when converting `fhe.if_then_else` extension. diff --git a/docs/tutorial/extensions.md b/docs/tutorial/extensions.md index 814be6822..4ba066d49 100644 --- a/docs/tutorial/extensions.md +++ b/docs/tutorial/extensions.md @@ -351,7 +351,6 @@ def is_vectors_same(x, y): return is_same ``` - ## fhe.relu(value) Allows you to perform ReLU operation, with the same semantic as `x if x >= 0 else 0`: @@ -481,3 +480,36 @@ The default values of these options are set based on simple circuits. How they a {% hint style="warning" %} Conversion with the second method (i.e., using chunks) only works in `Native` encoding, which is usually selected when all table lookups in the circuit are below or equal to 8 bits. {% endhint %} + + +## fhe.if_then_else(condition, x, y) + +Allows you to perform ternary if operation, with the same semantic as `x if condition else y`: + +```python +import numpy as np +from concrete import fhe + +@fhe.compiler({"condition": "encrypted", "x": "encrypted", "y": "encrypted"}) +def f(condition, x, y): + return fhe.if_then_else(condition, x, y) + +inputset = [ + ( + np.random.randint(0, 2**1), + np.random.randint(0, 2**5), + np.random.randint(-2**3, 2**3), + ) + for _ in range(10) +] +circuit = f.compile(inputset) + +assert circuit.encrypt_run_decrypt(1, 3, 5) == 3 +assert circuit.encrypt_run_decrypt(0, 3, 5) == 5 +assert circuit.encrypt_run_decrypt(1, 3, -5) == 3 +assert circuit.encrypt_run_decrypt(0, 3, -5) == -5 +``` + +{% hint style="info" %} +`fhe.if_then_else` is just an alias for [np.where](https://numpy.org/doc/stable/reference/generated/numpy.where.html). +{% endhint %} diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index e46f2dec4..879198b1f 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -35,6 +35,7 @@ from .extensions import ( bits, conv, hint, + if_then_else, maxpool, multivariate, one, diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 52a2095b6..6065b58f0 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -930,6 +930,7 @@ class Configuration: use_gpu: bool relu_on_bits_threshold: int relu_on_bits_chunk_size: int + if_then_else_chunk_size: int def __init__( self, @@ -985,6 +986,7 @@ class Configuration: use_gpu: bool = False, relu_on_bits_threshold: int = 7, relu_on_bits_chunk_size: int = 3, + if_then_else_chunk_size: int = 3, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1066,6 +1068,7 @@ class Configuration: self.use_gpu = use_gpu self.relu_on_bits_threshold = relu_on_bits_threshold self.relu_on_bits_chunk_size = relu_on_bits_chunk_size + self.if_then_else_chunk_size = if_then_else_chunk_size self._validate() @@ -1125,6 +1128,7 @@ class Configuration: use_gpu: Union[Keep, bool] = KEEP, relu_on_bits_threshold: Union[Keep, int] = KEEP, relu_on_bits_chunk_size: Union[Keep, int] = KEEP, + if_then_else_chunk_size: Union[Keep, int] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. diff --git a/frontends/concrete-python/concrete/fhe/compilation/utils.py b/frontends/concrete-python/concrete/fhe/compilation/utils.py index d68105be0..ffe14a9d0 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/utils.py +++ b/frontends/concrete-python/concrete/fhe/compilation/utils.py @@ -666,6 +666,9 @@ def convert_subgraph_to_subgraph_node( variable_input_node.location, ] + if terminal_node.properties["name"] == "where": + return None + raise RuntimeError( "A subgraph within the function you are trying to compile cannot be fused " "because it has multiple input nodes\n\n" diff --git a/frontends/concrete-python/concrete/fhe/extensions/__init__.py b/frontends/concrete-python/concrete/fhe/extensions/__init__.py index 36c006888..cdb0ae328 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/__init__.py +++ b/frontends/concrete-python/concrete/fhe/extensions/__init__.py @@ -2,6 +2,8 @@ Provide additional features that are not present in numpy. """ +from numpy import where as if_then_else + from .array import array from .bits import bits from .convolution import conv diff --git a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 784f9f0cb..3739b0498 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -1408,6 +1408,124 @@ class Context: # add contributions of x and y to compute the result return self.add(resulting_type, x_contribution, y_contribution) + def multiplication_with_boolean( + self, + boolean: Conversion, + value: Conversion, + *, + resulting_bit_width: int, + chunk_size: int, + inverted: bool = False, + ): + """ + Calculate boolean * value using bits. + """ + + assert boolean.is_encrypted and boolean.is_unsigned and boolean.bit_width == 1 + assert value.is_encrypted + + boolean = self.reinterpret(boolean, bit_width=(chunk_size + 1)) + + chunks = [] + intermediate_type = self.tensor( + self.eint(resulting_bit_width), + shape=(np.zeros(boolean.shape) + np.zeros(value.shape)).shape, + ) + + cursor = 0 + while cursor < value.original_bit_width: + start = cursor + end = min(start + chunk_size, value.original_bit_width) + cursor += chunk_size + + chunk = self.extract_bits( + self.tensor(self.eint(chunk_size + 1), shape=value.shape), + value, + slice(start, end), + ) + packed_boolean_and_chunk = self.add( + self.tensor(self.eint(chunk_size + 1), shape=intermediate_type.shape), + boolean, + chunk, + ) + + chunks.append( + self.tlu( + intermediate_type, + packed_boolean_and_chunk, + ( + ( + [0 for _ in range(2**chunk_size)] + + [x << start for x in range(2**chunk_size)] + ) + if not inverted + else ( + [x << start for x in range(2**chunk_size)] + + [0 for _ in range(2**chunk_size)] + ) + ), + ) + ) + + result = self.tree_add(intermediate_type, chunks) + if value.is_signed: + # bit extraction results in unsigned result + # so if you have -2 in 3-bits for example + # you have the following bit pattern + # 110 which is 6 but we want -2 + # it's simple to get it back to -2 + # if the value is signed and negative + # we need to apply -(2**original_bit_width) + result + # for the case above, it's -8 + 6 == -2 + + sign = self.extract_bits( + self.tensor(self.eint(2), shape=value.shape), + value, + bits=(value.original_bit_width - 1), + ) + packed_boolean_and_sign = self.add( + self.tensor(self.eint(2), shape=result.shape), + self.reinterpret(boolean, bit_width=2), + sign, + ) + + result_signed_type = self.tensor(self.esint(result.bit_width), shape=result.shape) + result_base = self.tlu( + result_signed_type, + packed_boolean_and_sign, + ( + ( + [ + # boolean=0, sign=0 + 0, + # boolean=0, sign=1 + 0, + # boolean=1, sign=0 + 0, + # boolean=1, sign=1 + -(2**value.original_bit_width), + ] + ) + if not inverted + else ( + [ + # boolean=0, sign=0 + 0, + # boolean=0, sign=1 + -(2**value.original_bit_width), + # boolean=1, sign=0 + 0, + # boolean=1, sign=1 + 0, + ] + ) + ), + ) + + result = self.add(result_signed_type, result_base, result) + + return result + # operations # each operation is checked for compatibility @@ -3447,6 +3565,9 @@ class Context: def reinterpret(self, x: Conversion, *, bit_width: int) -> Conversion: assert x.is_encrypted + if x.bit_width == bit_width: + return x + resulting_element_type = (self.eint if x.is_unsigned else self.esint)(bit_width) resulting_type = self.tensor(resulting_element_type, shape=x.shape) @@ -3455,6 +3576,64 @@ class Context: ) return self.operation(operation, resulting_type, x.result) + def where( + self, + resulting_type: ConversionType, + condition: Conversion, + when_true: Conversion, + when_false: Conversion, + ) -> Conversion: + if condition.is_clear: + highlights = { + condition.origin: "condition is not encrypted", + self.converting: "but it needs to be for where operation", + } + self.error(highlights) + + if when_true.is_clear: + highlights = { + when_true.origin: "outcome of true condition is not encrypted", + self.converting: "but it needs to be for where operation", + } + self.error(highlights) + + if when_false.is_clear: + highlights = { + when_false.origin: "outcome of false condition is not encrypted", + self.converting: "but it needs to be for where operation", + } + self.error(highlights) + + if condition.original_bit_width != 1 or condition.is_signed: + highlights = { + condition.origin: "condition is not uint1", + self.converting: "but it needs to be for where operation", + } + self.error(highlights) + + if condition.bit_width != 1: + shifter = self.constant(self.i(condition.bit_width + 1), 2 ** (condition.bit_width - 1)) + condition = self.reinterpret(self.mul(condition.type, condition, shifter), bit_width=1) + + chunk_size = self.configuration.if_then_else_chunk_size + + when_true_contribution = self.multiplication_with_boolean( + condition, + when_true, + resulting_bit_width=resulting_type.bit_width, + chunk_size=chunk_size, + inverted=False, + ) + when_false_contribution = self.multiplication_with_boolean( + condition, + when_false, + resulting_bit_width=resulting_type.bit_width, + chunk_size=chunk_size, + inverted=True, + ) + + return self.add(resulting_type, when_true_contribution, when_false_contribution) + def zeros(self, resulting_type: ConversionType) -> Conversion: assert resulting_type.is_encrypted diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 433a8dc0f..75d91c569 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -671,6 +671,10 @@ class Converter: assert len(preds) == 1 return ctx.truncate_bit_pattern(preds[0], node.properties["kwargs"]["lsbs_to_remove"]) + def where(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: + assert len(preds) == 3 + return ctx.where(ctx.typeof(node), *preds) + def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert len(preds) == 0 return ctx.zeros(ctx.typeof(node)) diff --git a/frontends/concrete-python/concrete/fhe/representation/node.py b/frontends/concrete-python/concrete/fhe/representation/node.py index bdc9f53d6..380fca1b7 100644 --- a/frontends/concrete-python/concrete/fhe/representation/node.py +++ b/frontends/concrete-python/concrete/fhe/representation/node.py @@ -325,6 +325,9 @@ class Node: elements = [format_indexing_element(element) for element in index] return f"bits({predecessors[0]})[{', '.join(elements)}]" + if name == "where" and len(predecessors) == 3: + return f"{predecessors[1]} if {predecessors[0]} else {predecessors[2]}" + args.extend( format_constant(value, maximum_constant_length) for value in self.properties["args"] ) diff --git a/frontends/concrete-python/tests/execution/test_if_then_else.py b/frontends/concrete-python/tests/execution/test_if_then_else.py new file mode 100644 index 000000000..60859bde0 --- /dev/null +++ b/frontends/concrete-python/tests/execution/test_if_then_else.py @@ -0,0 +1,132 @@ +""" +Tests of execution of `if_then_else` extension. +""" + +import random + +import numpy as np +import pytest + +from concrete import fhe +from concrete.fhe.dtypes import Integer +from concrete.fhe.values import EncryptedScalar, EncryptedTensor + +# pylint: disable=redefined-outer-name + +functions = [ + lambda condition, when_true, when_false: np.where(condition, when_true, when_false), + lambda condition, when_true, when_false: np.where(condition, when_true, when_false) + 100, +] +condition_descriptions = [ + EncryptedTensor(Integer(is_signed=False, bit_width=1), shape=shape) + for shape in [(), (2,), (3, 2)] +] +when_true_descriptions = [ + EncryptedTensor(Integer(is_signed, bit_width), shape=shape) + for is_signed in [False, True] + for bit_width in [3, 4, 5] + for shape in [(), (2,), (3, 2)] +] +when_false_descriptions = [ + EncryptedTensor(Integer(is_signed, bit_width), shape=shape) + for is_signed in [False, True] + for bit_width in [3, 4, 5] + for shape in [(), (2,), (3, 2)] +] +chunk_sizes = [ + 2, + 3, +] + +cases = [] +for function in functions: + for condition_description in condition_descriptions: + for when_true_description in when_true_descriptions: + for when_false_description in when_false_descriptions: + for chunk_size in chunk_sizes: + cases.append( + ( + function, + condition_description, + when_true_description, + when_false_description, + chunk_size, + ) + ) + +cases = random.sample(cases, 100) +cases.append( + ( + # special case of increased bit-width for condition + lambda condition, when_true, when_false: ( + np.where(condition, when_true, when_false) + (condition + 100) + ), + EncryptedScalar(Integer(is_signed=False, bit_width=1)), + EncryptedScalar(Integer(is_signed=False, bit_width=4)), + EncryptedScalar(Integer(is_signed=False, bit_width=4)), + 2, + ) +) + + +@pytest.mark.parametrize( + "function,condition_description,when_true_description,when_false_description,chunk_size", + cases, +) +def test_if_then_else( + function, + condition_description, + when_true_description, + when_false_description, + chunk_size, + helpers, +): + """ + Test encrypted evaluation of `if_then_else` extension. + """ + + print() + print() + print( + f"[{when_true_description}] " + f"if [{condition_description}] " + f"else [{when_false_description}] " + f"{{{chunk_size=}}}" + ) + print() + print() + + inputset = [ + ( + np.random.randint( + condition_description.dtype.min(), + condition_description.dtype.max() + 1, + size=condition_description.shape, + ), + np.random.randint( + when_true_description.dtype.min(), + when_true_description.dtype.max() + 1, + size=when_true_description.shape, + ), + np.random.randint( + when_false_description.dtype.min(), + when_false_description.dtype.max() + 1, + size=when_false_description.shape, + ), + ) + for _ in range(100) + ] + configuration = helpers.configuration().fork(if_then_else_chunk_size=chunk_size) + + compiler = fhe.Compiler( + function, + { + "condition": "encrypted", + "when_true": "encrypted", + "when_false": "encrypted", + }, + ) + circuit = compiler.compile(inputset, configuration) + + for sample in random.sample(inputset, 8): + helpers.check_execution(circuit, function, list(sample), retries=3) diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index 3af9af07a..d2219a5df 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -1023,6 +1023,82 @@ Function you are trying to compile cannot be compiled ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear bit extraction is not supported return %1 + """, # noqa: E501 + ), + pytest.param( + lambda x, y, z: np.where(x, y, z), + {"x": "encrypted", "y": "encrypted", "z": "encrypted"}, + [(10, 2, 3), (20, 1, 5)], + RuntimeError, + """ + +Function you are trying to compile cannot be compiled + +%0 = x # EncryptedScalar ∈ [10, 20] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ condition is not uint1 +%1 = y # EncryptedScalar ∈ [1, 2] +%2 = z # EncryptedScalar ∈ [3, 5] +%3 = %1 if %0 else %2 # EncryptedScalar ∈ [1, 2] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but it needs to be for where operation +return %3 + + """, # noqa: E501 + ), + pytest.param( + lambda x, y, z: np.where(x, y, z), + {"x": "encrypted", "y": "clear", "z": "encrypted"}, + [(1, 2, 3), (0, 1, 5)], + RuntimeError, + """ + +Function you are trying to compile cannot be compiled + +%0 = x # EncryptedScalar ∈ [0, 1] +%1 = y # ClearScalar ∈ [1, 2] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ outcome of true condition is not encrypted +%2 = z # EncryptedScalar ∈ [3, 5] +%3 = %1 if %0 else %2 # EncryptedScalar ∈ [2, 5] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but it needs to be for where operation +return %3 + + """, # noqa: E501 + ), + pytest.param( + lambda x, y, z: np.where(x, y, z), + {"x": "encrypted", "y": "encrypted", "z": "clear"}, + [(1, 2, 3), (0, 1, 5)], + RuntimeError, + """ + +Function you are trying to compile cannot be compiled + +%0 = x # EncryptedScalar ∈ [0, 1] +%1 = y # EncryptedScalar ∈ [1, 2] +%2 = z # ClearScalar ∈ [3, 5] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ outcome of false condition is not encrypted +%3 = %1 if %0 else %2 # EncryptedScalar ∈ [2, 5] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but it needs to be for where operation +return %3 + + """, # noqa: E501 + ), + pytest.param( + lambda x, y, z: np.where(x, y, z), + {"x": "clear", "y": "encrypted", "z": "encrypted"}, + [(1, 2, 3), (0, 1, 5)], + RuntimeError, + """ + +Function you are trying to compile cannot be compiled + +%0 = x # ClearScalar ∈ [0, 1] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ condition is not encrypted +%1 = y # EncryptedScalar ∈ [1, 2] +%2 = z # EncryptedScalar ∈ [3, 5] +%3 = %1 if %0 else %2 # EncryptedScalar ∈ [2, 5] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but it needs to be for where operation +return %3 + """, # noqa: E501 ), ],