diff --git a/docs/_static/rounding/approximate-off-by-one-error-approx-clipping.png b/docs/_static/rounding/approximate-off-by-one-error-approx-clipping.png new file mode 100644 index 000000000..9364d8f4a Binary files /dev/null and b/docs/_static/rounding/approximate-off-by-one-error-approx-clipping.png differ diff --git a/docs/_static/rounding/approximate-off-by-one-error-logical-clipping.png b/docs/_static/rounding/approximate-off-by-one-error-logical-clipping.png new file mode 100644 index 000000000..ca70ea3aa Binary files /dev/null and b/docs/_static/rounding/approximate-off-by-one-error-logical-clipping.png differ diff --git a/docs/_static/rounding/approximate-off-by-one-error.png b/docs/_static/rounding/approximate-off-by-one-error.png new file mode 100644 index 000000000..3983d3409 Binary files /dev/null and b/docs/_static/rounding/approximate-off-by-one-error.png differ diff --git a/docs/_static/rounding/approximate-off-centering-distribution.png b/docs/_static/rounding/approximate-off-centering-distribution.png new file mode 100644 index 000000000..447fba6f5 Binary files /dev/null and b/docs/_static/rounding/approximate-off-centering-distribution.png differ diff --git a/docs/_static/rounding/approximate-speedup.png b/docs/_static/rounding/approximate-speedup.png new file mode 100644 index 000000000..164ba7e05 Binary files /dev/null and b/docs/_static/rounding/approximate-speedup.png differ diff --git a/docs/howto/configure.md b/docs/howto/configure.md index 9d2fc7489..616ef274b 100644 --- a/docs/howto/configure.md +++ b/docs/howto/configure.md @@ -118,3 +118,12 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t * Chunk size of the ReLU extension when [fhe.bits](../tutorial/bit_extraction.md) implementation is used. * **if_then_else_chunk_size**: int = 3 * Chunk size to use when converting `fhe.if_then_else` extension. 
+* **rounding_exactness** : Exactness = `fhe.Exactness.EXACT` + * Set the default exactness mode for the rounding operation: + * `EXACT`: the threshold for rounding up or down is exactly centered between the upper and lower values, + * `APPROXIMATE`: faster, but the threshold for rounding up or down is only approximately centered, with a pseudo-random shift. + * The precise and more complete behavior is described in [fhe.round_bit_pattern](../tutorial/rounding.md). +* **approximate_rounding_config** : ApproximateRoundingConfig = `fhe.ApproximateRoundingConfig()`: + * Provide finer control over [approximate rounding](../tutorial/rounding.md#approximate-rounding-features): + * to enable exact clipping, + * and/or approximate clipping, which makes overflow protection faster. diff --git a/docs/tutorial/rounding.md b/docs/tutorial/rounding.md index 74967852b..948b9b666 100644 --- a/docs/tutorial/rounding.md +++ b/docs/tutorial/rounding.md @@ -97,7 +97,7 @@ prints: and displays: -![](../\_static/rounding/identity.png) +![](../_static/rounding/identity.png) {% hint style="info" %} If the rounded number is one of the last `2**(lsbs_to_remove - 1)` numbers in the input range `[0, 2**original_bit_width)`, an overflow **will** happen. @@ -194,7 +194,7 @@ The reason why the speed-up is not increasing with `lsbs_to_remove` is because t and displays: -![](../\_static/rounding/lsbs_to_remove.png) +![](../_static/rounding/lsbs_to_remove.png) {% hint style="info" %} Feel free to disable overflow protection and see what happens. @@ -289,8 +289,73 @@ target_msbs=1 => 2.34x speedup and displays: -![](../\_static/rounding/msbs_to_keep.png) +![](../_static/rounding/msbs_to_keep.png) {% hint style="warning" %} `AutoRounder`s should be defined outside the function that is being compiled. They are used to store the result of the adjustment process, so they shouldn't be created each time the function is called. Furthermore, each `AutoRounder` should be used with exactly one `round_bit_pattern` call. 
{% endhint %} + + +## Exactness + +One use of rounding is to speed up computation by ignoring the least significant bits. +For this usage, you can get even faster results if you accept the rounding itself being slightly inexact. +The speedup is usually around 2x-3x but can be higher for large precision reductions. +This also enables higher precision values that are not possible otherwise. + +| ![approximate-speedup.png](../_static/rounding/approximate-speedup.png) | +|:--:| +| *Using the default configuration in approximate mode, for 3, 4, 5 and 6 reduced precision bits and accumulator precision up to 32 bits.* | + + + +You can turn on this mode either globally on the configuration: +```python +configuration = fhe.Configuration( + ... + rounding_exactness=fhe.Exactness.APPROXIMATE +) +``` +or on/off locally: +```python +v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.APPROXIMATE) +v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.EXACT) +``` + +In approximate mode, the threshold for rounding up or down is not perfectly centered. +The off-centering: +* is bounded, i.e. at worst an off-by-one on the reduced precision value compared to the exact result, +* is pseudo-random, i.e. it will be different on each call, +* is almost symmetrically distributed, +* depends on cryptographic properties like the encryption mask, the encryption noise and the crypto-parameters. + +| ![approximate-off-by-one-error.png](../_static/rounding/approximate-off-by-one-error.png) | +|:--:| +| *In blue the exact value, the red dots are approximate values due to off-centered transition in approximate mode.* | + +| ![approximate-off-centering-distribution.png](../_static/rounding/approximate-off-centering-distribution.png) | +|:--:| +| *Histogram of transitions off-centering delta. 
Each count corresponds to a specific random mask and a specific encryption noise.* | + +## Approximate rounding features + +With approximate rounding, you can enable approximate clipping to further improve performance when overflow is handled. Approximate clipping discards the extra overflow-protection bit in the successor TLU. For consistency, logical clipping is available when this optimization is not suitable. + +### Logical clipping + +When fast approximate clipping is not suitable (i.e. slower), it's better to apply logical clipping for consistency and better resilience to code changes. +It has no extra cost since it's fused with the successor TLU. + +| ![logical-clipping.png](../_static/rounding/approximate-off-by-one-error-logical-clipping.png) | +|:--:| +| *Only the last step is clipped.* | + + +### Approximate clipping + +`approximate_clipping_start_precision` sets the first precision at which approximate clipping is enabled. Starting from this precision, an extra small-precision TLU is introduced to safely remove the extra precision bit used to contain overflow. This way the successor TLU is faster. +E.g. for a rounding to 7 bits that ends in an 8-bit TLU due to overflow, forcing a 7-bit TLU is 3x faster. 
+ +| ![approximate-clipping.png](../_static/rounding/approximate-off-by-one-error-approx-clipping.png) | +|:--:| +| *The last steps are decreased.* | diff --git a/frontends/concrete-python/concrete/fhe/__init__.py b/frontends/concrete-python/concrete/fhe/__init__.py index db965cbd3..0e8386ca9 100644 --- a/frontends/concrete-python/concrete/fhe/__init__.py +++ b/frontends/concrete-python/concrete/fhe/__init__.py @@ -9,6 +9,7 @@ from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, Public from .compilation import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, + ApproximateRoundingConfig, BitwiseStrategy, Circuit, Client, @@ -18,6 +19,7 @@ from .compilation import ( Configuration, DebugArtifacts, EncryptionStatus, + Exactness, Keys, MinMaxStrategy, MultiParameterStrategy, diff --git a/frontends/concrete-python/concrete/fhe/compilation/__init__.py b/frontends/concrete-python/concrete/fhe/compilation/__init__.py index 2043e751e..5f134c2a9 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/__init__.py +++ b/frontends/concrete-python/concrete/fhe/compilation/__init__.py @@ -9,9 +9,11 @@ from .compiler import Compiler, EncryptionStatus from .configuration import ( DEFAULT_GLOBAL_P_ERROR, DEFAULT_P_ERROR, + ApproximateRoundingConfig, BitwiseStrategy, ComparisonStrategy, Configuration, + Exactness, MinMaxStrategy, MultiParameterStrategy, MultivariateStrategy, diff --git a/frontends/concrete-python/concrete/fhe/compilation/configuration.py b/frontends/concrete-python/concrete/fhe/compilation/configuration.py index 016adb487..f01b34bc4 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/configuration.py +++ b/frontends/concrete-python/concrete/fhe/compilation/configuration.py @@ -3,7 +3,8 @@ Declaration of `Configuration` class. 
""" import platform -from enum import Enum +from dataclasses import dataclass +from enum import Enum, IntEnum from pathlib import Path from typing import List, Optional, Tuple, Union, get_type_hints @@ -73,6 +74,54 @@ class MultiParameterStrategy(str, Enum): raise ValueError(message) +class Exactness(IntEnum): + """ + Exactness, to specify for specific operator the implementation preference (default and local). + """ + + EXACT = 0 + APPROXIMATE = 1 + + +@dataclass +class ApproximateRoundingConfig: + """ + Controls the behavior of approximate rounding. + + In the following `k` is the ideal rounding output precision. + Often the precision used after rounding is `k`+1 to avoid overflow. + `logical_clipping`, `approximate_clipping_start_precision` can be used to stay at precision `k`, + either logically or physically at the successor TLU. + See examples in https://github.com/zama-ai/concrete/blob/main/docs/tutorial/rounding.md. + """ + + logical_clipping: bool = True + """ + Enable logical clipping to simulate a precision `k` in the successor TLU of precision `k`+1. + """ + + approximate_clipping_start_precision: int = 5 + """Actively avoid the overflow using a `k`-1 precision TLU. + This is similar to logical clipping but less accurate and faster. + Effect on: + * accuracy: the upper values of the rounding range are sligtly decreased, + * cost: adds an extra `k`-1 bits TLU to guarantee that the precision after rounding is `k`. + This is usually a win when `k` >= 5 . + This is enabled by default for `k` >= 5. + Due to the extra inaccuracy and cost, it is possible to disable it completely using False.""" + + reduce_precision_after_approximate_clipping: bool = True + """Enable the reduction to `k` bits in the TLU. + Can be disabled for debugging/testing purposes. + When disabled along with logical_clipping, the result of approximate clipping is accessible. + """ + + symetrize_deltas: bool = True + """Enable asymetry of correction of deltas w.r.t. 
the exact rounding computation. + Can be disabled for debugging/testing purposes. + """ + + class ComparisonStrategy(str, Enum): """ ComparisonStrategy, to specify implementation preference for comparisons. @@ -933,6 +982,8 @@ class Configuration: relu_on_bits_chunk_size: int if_then_else_chunk_size: int additional_processors: List[GraphProcessor] + rounding_exactness: Exactness + approximate_rounding_config: ApproximateRoundingConfig def __init__( self, @@ -990,6 +1041,8 @@ class Configuration: relu_on_bits_chunk_size: int = 3, if_then_else_chunk_size: int = 3, additional_processors: Optional[List[GraphProcessor]] = None, + rounding_exactness: Exactness = Exactness.EXACT, + approximate_rounding_config: Optional[ApproximateRoundingConfig] = None, ): self.verbose = verbose self.compiler_debug_mode = compiler_debug_mode @@ -1073,6 +1126,10 @@ class Configuration: self.relu_on_bits_chunk_size = relu_on_bits_chunk_size self.if_then_else_chunk_size = if_then_else_chunk_size self.additional_processors = [] if additional_processors is None else additional_processors + self.rounding_exactness = rounding_exactness + self.approximate_rounding_config = ( + approximate_rounding_config or ApproximateRoundingConfig() + ) self._validate() @@ -1134,6 +1191,8 @@ class Configuration: relu_on_bits_chunk_size: Union[Keep, int] = KEEP, if_then_else_chunk_size: Union[Keep, int] = KEEP, additional_processors: Union[Keep, Optional[List[GraphProcessor]]] = KEEP, + rounding_exactness: Union[Keep, Exactness] = KEEP, + approximate_rounding_config: Union[Keep, Optional[ApproximateRoundingConfig]] = KEEP, ) -> "Configuration": """ Get a new configuration from another one specified changes. 
diff --git a/frontends/concrete-python/concrete/fhe/compilation/server.py b/frontends/concrete-python/concrete/fhe/compilation/server.py index f6da3b1e7..8ec8ea007 100644 --- a/frontends/concrete-python/concrete/fhe/compilation/server.py +++ b/frontends/concrete-python/concrete/fhe/compilation/server.py @@ -4,7 +4,6 @@ Declaration of `Server` class. # pylint: disable=import-error,no-member,no-name-in-module -import json import shutil import tempfile from pathlib import Path @@ -12,6 +11,7 @@ from typing import Dict, List, Optional, Tuple, Union # mypy: disable-error-code=attr-defined import concrete.compiler +import jsonpickle from concrete.compiler import ( CompilationContext, CompilationOptions, @@ -253,7 +253,7 @@ class Server: f.write("1" if self.is_simulated else "0") with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f: - f.write(json.dumps(self._configuration.__dict__)) + f.write(jsonpickle.dumps(self._configuration.__dict__)) shutil.make_archive(path, "zip", tmp) @@ -300,7 +300,7 @@ class Server: mlir = f.read() with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f: - configuration = Configuration().fork(**json.load(f)) + configuration = Configuration().fork(**jsonpickle.loads(f.read())) return Server.create(mlir, configuration, is_simulated) diff --git a/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py b/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py index 48a3d9be2..c67380831 100644 --- a/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py +++ b/frontends/concrete-python/concrete/fhe/extensions/round_bit_pattern.py @@ -4,10 +4,11 @@ Declaration of `round_bit_pattern` function, to provide an interface for rounded import threading from copy import deepcopy -from typing import Any, Callable, Dict, Iterable, List, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np +from 
..compilation.configuration import Exactness from ..dtypes import Integer from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH from ..representation import Node @@ -158,6 +159,7 @@ def round_bit_pattern( x: Union[int, np.integer, List, np.ndarray, Tracer], lsbs_to_remove: Union[int, AutoRounder], overflow_protection: bool = True, + exactness: Optional[Exactness] = None, ) -> Union[int, np.integer, List, np.ndarray, Tracer]: """ Round the bit pattern of an integer. @@ -212,6 +214,11 @@ def round_bit_pattern( overflow_protection (bool, default = True) whether to adjust bit widths and lsbs to remove to avoid overflows + exactness (Optional[Exactness], default = None) + select the exactness of the operation, None means use the global exactness. + The global exactness default is EXACT. + It can be changed on the Configuration object. + Returns: Union[int, np.integer, np.ndarray, Tracer]: Tracer that respresents the operation during tracing @@ -240,6 +247,8 @@ def round_bit_pattern( def evaluator( x: Union[int, np.integer, np.ndarray], lsbs_to_remove: int, + overflow_protection: bool, # pylint: disable=unused-argument + exactness: Optional[Exactness], # pylint: disable=unused-argument ) -> Union[int, np.integer, np.ndarray]: if lsbs_to_remove == 0: return x @@ -255,8 +264,11 @@ def round_bit_pattern( [deepcopy(x.output)], deepcopy(x.output), evaluator, - kwargs={"lsbs_to_remove": lsbs_to_remove}, - attributes={"overflow_protection": overflow_protection}, + kwargs={ + "lsbs_to_remove": lsbs_to_remove, + "overflow_protection": overflow_protection, + "exactness": exactness, + }, ) return Tracer(computation, [x]) @@ -276,6 +288,6 @@ def round_bit_pattern( message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}" raise TypeError(message) - return evaluator(x, lsbs_to_remove) + return evaluator(x, lsbs_to_remove, overflow_protection, exactness) # pylint: enable=protected-access,too-many-branches diff --git 
a/frontends/concrete-python/concrete/fhe/mlir/context.py b/frontends/concrete-python/concrete/fhe/mlir/context.py index 2fcb4cae3..eaf242058 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/context.py +++ b/frontends/concrete-python/concrete/fhe/mlir/context.py @@ -28,6 +28,7 @@ from ..compilation.configuration import ( BitwiseStrategy, ComparisonStrategy, Configuration, + Exactness, MinMaxStrategy, ) from ..dtypes import Integer @@ -3143,6 +3144,8 @@ class Context: resulting_type: ConversionType, x: Conversion, lsbs_to_remove: int, + exactness: Exactness, + overflow_detected: bool, ) -> Conversion: if x.is_clear: highlights = { @@ -3153,19 +3156,121 @@ class Context: assert x.bit_width > lsbs_to_remove + if exactness is None: + exactness = self.configuration.rounding_exactness + + intermediate_bit_width = x.bit_width - lsbs_to_remove intermediate_type = self.typeof( ValueDescription( - dtype=Integer(is_signed=x.is_signed, bit_width=(x.bit_width - lsbs_to_remove)), + dtype=Integer(is_signed=x.is_signed, bit_width=intermediate_bit_width), shape=x.shape, is_encrypted=x.is_encrypted, ) ) + if exactness is Exactness.APPROXIMATE: + approx_conf = self.configuration.approximate_rounding_config + # 1. 
Unskew TLU's future error distribution on approximated value + # this balances against all leading zeros in the noise (ignoring symmetric noise) + unskewed = x + if approx_conf.symetrize_deltas: + highest_supported_precision = 62 + delta_precision = highest_supported_precision - x.type.bit_width + full_precision = x.type.bit_width + delta_precision + half_in_extra_precision = ( + 1 << (delta_precision - 1) + ) - 1 # slightly smaller than half + half_in_extra_precision = self.constant( + self.i(full_precision + 1), half_in_extra_precision + ) + x_high_precision = self.reinterpret(x, bit_width=full_precision) + unskewed = self.add( + x_high_precision.type, x_high_precision, half_in_extra_precision + ) - rounded = self.operation( - fhe.RoundEintOp if x.is_scalar else fhelinalg.RoundOp, - intermediate_type, - x.result, - ) + # 2. Cancel overflow to have a TLU at exactly target_precision + # starting from 5 bits, the extra overflow bit in the TLU is too costly + # a smaller precision TLU can detect approximately the overflow to cancel it + # this is only possible because of the extra-bit from overflow protection + target_precision = x.bit_width - lsbs_to_remove - overflow_detected + if ( + overflow_detected + and target_precision >= approx_conf.approximate_clipping_start_precision + and approx_conf.approximate_clipping_start_precision is not False + ): + unskew_pre_overflow = self.reinterpret(unskewed, bit_width=x.type.bit_width) + overflow_precision = max(2, target_precision - 1) + # The last half-cell values in overflow_precision will naturally overflow. + # But there can also be an off by minus 1 to the previous cell in the worst case + # and an overflow in the successor TLU. + # We slightly decrease the value of the rounding output on these cells. + # `realign_cell_by` defines where the decrease starts to apply. 
+ step_high = 1 << (x.type.bit_width - intermediate_bit_width) + step_wide = step_high + full_decrease_by = step_high + realign_cell_by = step_wide // 2 + realign_cell_by = self.constant(self.i(x.type.bit_width + 1), realign_cell_by) + overflow_candidate = self.sub( + unskew_pre_overflow.type, unskew_pre_overflow, realign_cell_by + ) + overflow_candidate = self.reinterpret( + overflow_candidate, bit_width=overflow_precision + ) + half_tlu_size = 2 ** (overflow_precision - 1) + if x.is_signed: + negative_size = half_tlu_size + positive_size = negative_size + used_positive_size = half_tlu_size // 2 + # this is oriented for precision higher than 3 + # it will work with smaller precision but with more invasive effects + prevent_overflow_positive = ( + # pre-overflow + [0] * used_positive_size + # overflow part + + [3 * full_decrease_by // 4, full_decrease_by] + # unused + + [0] * (positive_size - used_positive_size - 2) + )[:half_tlu_size] + prevent_overflow = prevent_overflow_positive + [0] * negative_size + else: + prevent_overflow = ( + # pre-overflow + [0] * half_tlu_size + # overflow part + + [3 * full_decrease_by // 4, full_decrease_by] + # unused + + [0] * (half_tlu_size - 2) + )[: 2 * half_tlu_size] + signed_type = self.to_signed(x).type + overflow_cancel = self.reinterpret( + self.tlu( + signed_type, + overflow_candidate, + table=prevent_overflow, + ), + bit_width=x.type.bit_width, + signed=x.is_signed, + ) + unskewed = self.sub(unskew_pre_overflow.type, unskew_pre_overflow, overflow_cancel) + if approx_conf.reduce_precision_after_approximate_clipping: + # a minimum bit width of 3 is required to multiply by 2 in signed case + if unskewed.bit_width < 3: + # pragma: no-cover + self.reinterpret(unskewed, bit_width=3) + unskewed = self.mul( + unskewed.type, unskewed, self.constant(self.i(unskewed.bit_width + 1), 2) + ) + rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width - 1) + # The TLU after may be adjusted to the right precision (see 
`Converter.tlu`) + else: + rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width) + else: + rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width) + else: + rounded = self.operation( + fhe.RoundEintOp if x.is_scalar else fhelinalg.RoundOp, + intermediate_type, + x.result, + ) return self.to_signedness(rounded, of=resulting_type) @@ -3593,13 +3698,16 @@ class Context: return x - def reinterpret(self, x: Conversion, *, bit_width: int) -> Conversion: + def reinterpret( + self, x: Conversion, *, bit_width: int, signed: Optional[bool] = None + ) -> Conversion: assert x.is_encrypted if x.bit_width == bit_width: return x - resulting_element_type = (self.eint if x.is_unsigned else self.esint)(bit_width) + result_unsigned = x.is_unsigned if signed is None else not signed + resulting_element_type = (self.eint if result_unsigned else self.esint)(bit_width) resulting_type = self.tensor(resulting_element_type, shape=x.shape) operation = ( diff --git a/frontends/concrete-python/concrete/fhe/mlir/converter.py b/frontends/concrete-python/concrete/fhe/mlir/converter.py index 3792c9402..8271cb742 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/converter.py +++ b/frontends/concrete-python/concrete/fhe/mlir/converter.py @@ -4,6 +4,7 @@ Declaration of `Converter` class. 
# pylint: disable=import-error,no-name-in-module +import math import sys from typing import Dict, List, Tuple, Union @@ -17,8 +18,7 @@ from mlir.ir import InsertionPoint as MlirInsertionPoint from mlir.ir import Location as MlirLocation from mlir.ir import Module as MlirModule -from concrete.fhe.compilation.configuration import Configuration - +from ..compilation.configuration import Configuration, Exactness from ..representation import Graph, Node, Operation from .context import Context from .conversion import Conversion @@ -195,7 +195,9 @@ class Converter: multivariate_strategy_preference=configuration.multivariate_strategy_preference, min_max_strategy_preference=configuration.min_max_strategy_preference, ), - ProcessRounding(), + ProcessRounding( + rounding_exactness=configuration.rounding_exactness, + ), ] + configuration.additional_processors for processor in pipeline: @@ -485,9 +487,9 @@ class Converter: assert len(preds) == 1 pred = preds[0] + overflow_detected = node.properties["overflow_detected"] if pred.is_encrypted and pred.bit_width != pred.original_bit_width: overflow_protection = node.properties["overflow_protection"] - overflow_detected = node.properties["overflow_detected"] shifter = 2 ** (pred.bit_width - pred.original_bit_width) if overflow_protection and overflow_detected: @@ -500,6 +502,8 @@ class Converter: ctx.typeof(node), pred, node.properties["final_lsbs_to_remove"], + node.properties["exactness"], + overflow_detected, ) def subtract(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: @@ -531,6 +535,61 @@ class Converter: # otherwise, a simple reshape would work as we already have the correct shape return ctx.reshape(preds[0], shape=node.output.shape) + @classmethod + def tlu_adjust(cls, table, variable_input, target_bit_width, clipping, reduce_precision): + target_bit_width = min( + variable_input.bit_width, target_bit_width + ) # inconsistency due to more precise bound vs precision + table_bit_width = 
math.log2(len(table)) + assert table_bit_width.is_integer() + table_bit_width = int(table_bit_width) + table_has_right_size = variable_input.bit_width == table_bit_width + if table_has_right_size and not clipping: + return table + half_rounded_bit_width = target_bit_width - 1 + if variable_input.is_signed: + # upper = positive part, lower = negative part + upper_clipping_index = 2**half_rounded_bit_width - 1 + lower_clipping_index = 2**table_bit_width - 2**half_rounded_bit_width + positive_clipped_card = 2 ** (table_bit_width - 1) - upper_clipping_index - 1 + negative_clipped_card = 2 ** (table_bit_width - 1) - 2**half_rounded_bit_width + else: + upper_clipping_index = 2**target_bit_width - 1 + lower_clipping_index = 0 + positive_clipped_card = 2**table_bit_width - upper_clipping_index - 1 + lower_clipping = table[lower_clipping_index] + upper_clipping = table[upper_clipping_index] + if table_has_right_size: + # value clipping + assert clipping + if variable_input.is_signed: + table = ( + list(table[: upper_clipping_index + 1]) + + [upper_clipping] * positive_clipped_card + + [lower_clipping] * negative_clipped_card + + list(table[lower_clipping_index:]) + ) + else: + table = ( + list(table[lower_clipping_index : upper_clipping_index + 1]) + + [upper_clipping] * positive_clipped_card + ) + assert len(table) == 2**table_bit_width, ( + len(table), + 2**table_bit_width, + table, + upper_clipping, + lower_clipping, + ) + return np.array(table, dtype=np.uint64) # negative value are in unsigned representation + + # adjust tlu size + assert reduce_precision + if variable_input.is_signed: + return np.concatenate((table[: upper_clipping_index + 1], table[lower_clipping_index:])) + + return table[lower_clipping_index : upper_clipping_index + 1] + def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: assert node.converted_to_table_lookup @@ -654,6 +713,33 @@ class Converter: variable_input = ctx.mul(variable_input.type, variable_input, shifter) 
variable_input = ctx.reinterpret(variable_input, bit_width=truncated_bit_width) + elif variable_input.origin.properties.get("name") == "round_bit_pattern": + exactness = ( + variable_input.origin.properties["exactness"] + or ctx.configuration.rounding_exactness + ) + if exactness == Exactness.APPROXIMATE: + # we clip values to enforce input precision exactly as queried + original_bit_width = variable_input.origin.properties["original_bit_width"] + lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"] + overflow = variable_input.origin.properties["overflow_detected"] + rounded_bit_width = original_bit_width - lsbs_to_remove - overflow + approx_config = ctx.configuration.approximate_rounding_config + clipping = approx_config.logical_clipping + reduce_precision = approx_config.reduce_precision_after_approximate_clipping + if len(tables) == 1: + lut_values = self.tlu_adjust( + lut_values, variable_input, rounded_bit_width, clipping, reduce_precision + ) + else: + for sub_i, sub_lut_values in enumerate(lut_values): + lut_values[sub_i] = self.tlu_adjust( + sub_lut_values, + variable_input, + rounded_bit_width, + clipping, + reduce_precision, + ) if len(tables) == 1: return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist()) diff --git a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py index 8871a0ae8..6466f0ac5 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/assign_bit_widths.py @@ -230,7 +230,7 @@ class AdditionalConstraints: return any(pred.output.is_clear for pred in preds) def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool: - return node.properties["attributes"]["overflow_protection"] is True + return node.properties["kwargs"]["overflow_protection"] is True # =========== # Constraints diff --git 
a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py index b60216a32..01abaaa8f 100644 --- a/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py +++ b/frontends/concrete-python/concrete/fhe/mlir/processors/process_rounding.py @@ -8,6 +8,7 @@ from typing import Optional import numpy as np +from ...compilation.configuration import Exactness from ...dtypes import Integer from ...extensions.table import LookupTable from ...representation import Graph, GraphProcessor, Node @@ -18,6 +19,14 @@ class ProcessRounding(GraphProcessor): ProcessRounding graph processor, to analyze rounding and support regular operations on it. """ + rounding_exactness: Exactness + + def __init__( + self, + rounding_exactness: Exactness, + ): + self.rounding_exactness = rounding_exactness + def apply(self, graph: Graph): rounding_nodes = graph.query_nodes(operation_filter="round_bit_pattern") for node in rounding_nodes: @@ -26,8 +35,13 @@ class ProcessRounding(GraphProcessor): original_lsbs_to_remove = node.properties["kwargs"]["lsbs_to_remove"] final_lsbs_to_remove = node.properties["final_lsbs_to_remove"] + exactness = node.properties["exactness"] + if exactness is None: + exactness = self.rounding_exactness + if original_lsbs_to_remove != 0 and final_lsbs_to_remove == 0: - self.replace_with_tlu(graph, node) + if exactness != Exactness.APPROXIMATE: + self.replace_with_tlu(graph, node) continue self.process_successors(graph, node) @@ -43,12 +57,14 @@ class ProcessRounding(GraphProcessor): pred = preds[0] assert isinstance(pred.output.dtype, Integer) - overflow_protection = node.properties["attributes"]["overflow_protection"] + exactness = node.properties["kwargs"]["exactness"] + overflow_protection = node.properties["kwargs"]["overflow_protection"] overflow_detected = ( overflow_protection and pred.properties["original_bit_width"] != node.properties["original_bit_width"] ) 
+ node.properties["exactness"] = exactness node.properties["overflow_protection"] = overflow_protection node.properties["overflow_detected"] = overflow_detected diff --git a/frontends/concrete-python/mypy.ini b/frontends/concrete-python/mypy.ini index 35159d2f1..62bf5c450 100644 --- a/frontends/concrete-python/mypy.ini +++ b/frontends/concrete-python/mypy.ini @@ -1,3 +1,4 @@ [mypy] plugins = numpy.typing.mypy_plugin disable_error_code = annotation-unchecked +allow_redefinition = True diff --git a/frontends/concrete-python/requirements.txt b/frontends/concrete-python/requirements.txt index 364f7d192..1d612353f 100644 --- a/frontends/concrete-python/requirements.txt +++ b/frontends/concrete-python/requirements.txt @@ -1,4 +1,5 @@ importlib-resources>=6.1 +jsonpickle>=3.0.3 networkx>=2.6 numpy>=1.23 scipy>=1.10 diff --git a/frontends/concrete-python/tests/execution/test_round_bit_pattern.py b/frontends/concrete-python/tests/execution/test_round_bit_pattern.py index 964472aa0..ef070514c 100644 --- a/frontends/concrete-python/tests/execution/test_round_bit_pattern.py +++ b/frontends/concrete-python/tests/execution/test_round_bit_pattern.py @@ -6,6 +6,7 @@ import numpy as np import pytest from concrete import fhe +from concrete.fhe.compilation.configuration import Exactness from concrete.fhe.representation.utils import format_constant @@ -432,14 +433,14 @@ def test_auto_rounding(helpers): helpers.check_str( f""" -%0 = x # EncryptedScalar -%1 = round_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar -%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar -%3 = round_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar -%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar +%0 = x # EncryptedScalar +%1 = round_bit_pattern(%0, lsbs_to_remove=3, overflow_protection=True, exactness=None) # EncryptedScalar +%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar +%3 = round_bit_pattern(%2, lsbs_to_remove=2, overflow_protection=True, exactness=None) 
# EncryptedScalar +%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar return %4 - """, + """, # noqa: E501 str(circuit3.graph.format(show_bounds=False)), ) @@ -611,3 +612,253 @@ def test_round_bit_pattern_overflow_to_sign_bit(helpers): for x in inputset: helpers.check_execution(circuit, function, x, retries=3) + + +def test_round_bit_pattern_approximate_enabling(helpers): + """ + Test round bit pattern various activation paths. + """ + + @fhe.compiler({"x": "encrypted"}) + def function_default(x): + return fhe.round_bit_pattern(x, lsbs_to_remove=8) + + @fhe.compiler({"x": "encrypted"}) + def function_exact(x): + return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.EXACT) + + @fhe.compiler({"x": "encrypted"}) + def function_approx(x): + return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.APPROXIMATE) + + inputset = [-(2**10), 2**10 - 1] + configuration = helpers.configuration() + + circuit_default_default = function_default.compile(inputset, configuration) + circuit_default_exact = function_default.compile( + inputset, configuration.fork(rounding_exactness=Exactness.EXACT) + ) + circuit_default_approx = function_default.compile( + inputset, configuration.fork(rounding_exactness=Exactness.APPROXIMATE) + ) + circuit_exact = function_exact.compile( + inputset, configuration.fork(rounding_exactness=Exactness.APPROXIMATE) + ) + circuit_approx = function_approx.compile( + inputset, configuration.fork(rounding_exactness=Exactness.EXACT) + ) + + assert circuit_approx.complexity < circuit_exact.complexity + assert circuit_exact.complexity == circuit_default_default.complexity + assert circuit_exact.complexity == circuit_default_exact.complexity + assert circuit_approx.complexity == circuit_default_approx.complexity + + +@pytest.mark.parametrize( + "accumulator_precision,reduced_precision,signed,conf", + [ + (8, 4, True, fhe.ApproximateRoundingConfig(False, 4)), + (7, 4, False, fhe.ApproximateRoundingConfig(False, 4)), + (9, 3, 
True, fhe.ApproximateRoundingConfig(True, False)), + (8, 3, False, fhe.ApproximateRoundingConfig(True, False)), + (7, 3, False, fhe.ApproximateRoundingConfig(True, 3)), + (7, 2, True, fhe.ApproximateRoundingConfig(False, 2)), + (7, 2, False, fhe.ApproximateRoundingConfig(False, False, False, False)), + (8, 1, True, fhe.ApproximateRoundingConfig(False, 1)), + (8, 1, False, fhe.ApproximateRoundingConfig(True, False)), + (6, 5, False, fhe.ApproximateRoundingConfig(True, 6)), + (6, 5, False, fhe.ApproximateRoundingConfig(True, 5)), + ], +) +def test_round_bit_pattern_approximate_off_by_one_errors( + accumulator_precision, reduced_precision, signed, conf, helpers +): + """ + Test round bit pattern off by 1 errors. + """ + lsbs_to_remove = accumulator_precision - reduced_precision + + @fhe.compiler({"x": "encrypted"}) + def function(x): + x = fhe.univariate(lambda x: x)(x) + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove) + x = x // 2**lsbs_to_remove + return x + + if signed: + inputset = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1) - 1] + else: + inputset = [0, 2**accumulator_precision - 1] + + configuration = helpers.configuration() + circuit_exact = function.compile(inputset, configuration) + circuit_approx = function.compile( + inputset, + configuration.fork( + approximate_rounding_config=conf, rounding_exactness=Exactness.APPROXIMATE + ), + ) + # check it's better even with bad conf + assert circuit_approx.complexity < circuit_exact.complexity + + testset = range(*inputset) + + nb_error = 0 + for x in testset: + approx = circuit_approx.encrypt_run_decrypt(x) + approx_simu = circuit_approx.simulate(x) + exact = circuit_exact.simulate(x) + assert abs(approx_simu - exact) <= 1 + assert abs(approx_simu - approx) <= 1 + delta = abs(approx - approx_simu) + assert delta <= 1 + nb_error += delta > 0 + + nb_transitions = 2 ** (accumulator_precision - reduced_precision) + assert nb_error <= 3 * nb_transitions # of the same order as 
transitions but small sample size + + +@pytest.mark.parametrize( + "signed,physical", + [(signed, physical) for signed in (True, False) for physical in (True, False)], +) +def test_round_bit_pattern_approximate_clippping(signed, physical, helpers): + """ + Test round bit pattern clipping. + """ + accumulator_precision = 6 + reduced_precision = 3 + lsbs_to_remove = accumulator_precision - reduced_precision + + @fhe.compiler({"x": "encrypted"}) + def function(x): + x = fhe.univariate(lambda x: x)(x) + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove) + x = x // 2**lsbs_to_remove + return x + + if signed: + input_domain = range(-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1)) + else: + input_domain = range(0, 2 ** (accumulator_precision)) + + configuration = helpers.configuration() + approx_conf = fhe.ApproximateRoundingConfig( + logical_clipping=not physical, + approximate_clipping_start_precision=physical and reduced_precision, + reduce_precision_after_approximate_clipping=False, + ) + no_clipping_conf = fhe.ApproximateRoundingConfig( + logical_clipping=False, approximate_clipping_start_precision=False + ) + assert approx_conf.logical_clipping or approx_conf.approximate_clipping_start_precision + circuit_clipping = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf, rounding_exactness=Exactness.APPROXIMATE + ), + ) + circuit_no_clipping = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=no_clipping_conf, rounding_exactness=Exactness.APPROXIMATE + ), + ) + + if signed: + clipped_output_domain = range(-(2 ** (reduced_precision - 1)), 2 ** (reduced_precision - 1)) + else: + clipped_output_domain = range(0, 2**reduced_precision) + + # With clipping + for x in input_domain: + assert ( + circuit_clipping.encrypt_run_decrypt(x) in clipped_output_domain + ), circuit_clipping.mlir # no overflow + assert circuit_clipping.simulate(x) in clipped_output_domain + + 
# Without clipping + # overflow + assert circuit_no_clipping.simulate(input_domain[-1]) not in clipped_output_domain + + +@pytest.mark.parametrize( + "signed,accumulator_precision", + [ + (signed, accumulator_precision) + for signed in (True, False) + for accumulator_precision in (13, 24) + ], +) +def test_round_bit_pattern_approximate_acc_to_6_costs(signed, accumulator_precision, helpers): + """ + Test round bit pattern speedup when approximation is activated. + """ + reduced_precision = 6 + lsbs_to_remove = accumulator_precision - reduced_precision + + @fhe.compiler({"x": "encrypted"}) + def function(x): + x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove, overflow_protection=True) + x = x // 2**lsbs_to_remove + return x + + # with overflow + if signed: + input_domain = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1) - 1] + else: + input_domain = [0, 2 ** (accumulator_precision) - 1] + + configuration = helpers.configuration().fork( + single_precision=False, + parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI, + composable=True, + ) + circuit_exact = function.compile(input_domain, configuration) + approx_conf_fastest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=6) + approx_conf_safest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=100) + circuit_approx_fastest = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_fastest, + rounding_exactness=Exactness.APPROXIMATE, + ), + ) + circuit_approx_safest = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_safest, rounding_exactness=Exactness.APPROXIMATE + ), + ) + assert circuit_approx_safest.complexity < circuit_exact.complexity + assert circuit_approx_fastest.complexity < circuit_approx_safest.complexity + + @fhe.compiler({"x": "encrypted"}) + def function(x): # pylint: disable=function-redefined + x = fhe.round_bit_pattern(x,
lsbs_to_remove=lsbs_to_remove, overflow_protection=False) + x = x // 2**lsbs_to_remove + return x + + # without overflow + if signed: + input_domain = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 2) - 2] + else: + input_domain = [0, 2 ** (accumulator_precision - 1) - 2] + + circuit_exact_no_ovf = function.compile(input_domain, configuration) + circuit_approx_fastest_no_ovf = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_fastest, + rounding_exactness=Exactness.APPROXIMATE, + ), + ) + circuit_approx_safest_no_ovf = function.compile( + input_domain, + configuration.fork( + approximate_rounding_config=approx_conf_safest, rounding_exactness=Exactness.APPROXIMATE + ), + ) + assert circuit_approx_fastest_no_ovf.complexity == circuit_approx_safest_no_ovf.complexity + assert circuit_approx_safest_no_ovf.complexity < circuit_exact_no_ovf.complexity + assert circuit_exact_no_ovf.complexity < circuit_exact.complexity diff --git a/frontends/concrete-python/tests/mlir/test_converter.py b/frontends/concrete-python/tests/mlir/test_converter.py index 99369f710..742693413 100644 --- a/frontends/concrete-python/tests/mlir/test_converter.py +++ b/frontends/concrete-python/tests/mlir/test_converter.py @@ -512,11 +512,11 @@ return %1 Function you are trying to compile cannot be compiled -%0 = x # ClearScalar ∈ [10, 30] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear -%1 = round_bit_pattern(%0, lsbs_to_remove=2) # ClearScalar ∈ [12, 32] -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported -%2 = reinterpret(%1) # ClearScalar +%0 = x # ClearScalar ∈ [10, 30] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear +%1 = round_bit_pattern(%0, lsbs_to_remove=2, 
overflow_protection=True, exactness=None) # ClearScalar ∈ [12, 32] +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported +%2 = reinterpret(%1) # ClearScalar return %2 """, # noqa: E501