feat(frontend-python): approximate mode for round_bit_pattern

This commit is contained in:
rudy
2024-01-08 15:51:26 +01:00
committed by rudy-6-4
parent 9b5a2e46da
commit 05bd8cc5f2
20 changed files with 649 additions and 37 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 43 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 47 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 30 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 89 KiB

View File

@@ -118,3 +118,12 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* Chunk size of the ReLU extension when [fhe.bits](../tutorial/bit_extraction.md) implementation is used.
* **if_then_else_chunk_size**: int = 3
* Chunk size to use when converting `fhe.if_then_else` extension.
* **rounding_exactness** : Exactness = `fhe.Exactness.EXACT`
* Set default exactness mode for the rounding operation:
* `EXACT`: threshold for rounding up or down is exactly centered between upper and lower value,
* `APPROXIMATE`: faster but threshold for rounding up or down is approximately centered with pseudo-random shift.
* Precise and more complete behavior is described in [fhe.rounding_bit_pattern](../tutorial/rounding.md).
* **approximate_rounding_config** : ApproximateRoundingConfig = `fhe.ApproximateRoundingConfig()`:
* Provide more fine control on [approximate rounding](../tutorial/rounding.md#approximate-rounding-features):
* to enable exact cliping,
* or/and approximate clipping which make overflow protection faster.

View File

@@ -97,7 +97,7 @@ prints:
and displays:
![](../\_static/rounding/identity.png)
![](../_static/rounding/identity.png)
{% hint style="info" %}
If the rounded number is one of the last `2**(lsbs_to_remove - 1)` numbers in the input range `[0, 2**original_bit_width)`, an overflow **will** happen.
@@ -194,7 +194,7 @@ The reason why the speed-up is not increasing with `lsbs_to_remove` is because t
and displays:
![](../\_static/rounding/lsbs_to_remove.png)
![](../_static/rounding/lsbs_to_remove.png)
{% hint style="info" %}
Feel free to disable overflow protection and see what happens.
@@ -289,8 +289,73 @@ target_msbs=1 => 2.34x speedup
and displays:
![](../\_static/rounding/msbs_to_keep.png)
![](../_static/rounding/msbs_to_keep.png)
{% hint style="warning" %}
`AutoRounder`s should be defined outside the function that is being compiled. They are used to store the result of the adjustment process, so they shouldn't be created each time the function is called. Furthermore, each `AutoRounder` should be used with exactly one `round_bit_pattern` call.
{% endhint %}
## Exactness
One use of rounding is doing faster computation by ignoring the lower significant bits.
For this usage, you can even get faster results if you accept the rounding it-self to be slighlty inexact.
The speedup is usually around 2x-3x but can be higher for big precision reduction.
This also enable higher precisions values that are not possible otherwise.
| ![approximate-speedup.png](../_static/rounding/approximate-speedup.png) |
|:--:|
| *Using the default configuration in approximate mode. For 3, 4, 5 and 6 reduced precision bits and accumulator precision up to 32bits |
You can turn on this mode either globally on the configuration:
```python
configuration = fhe.Configuration(
...
rounding_exactness=fhe.Exactness.APPROXIMATE
)
```
or on/off locally:
```python
v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.APPROXIMATE)
v = fhe.round_bit_pattern(v, lsbs_to_remove=2, exactness=fhe.Exactness.EXACT)
```
In approximate mode the rounding threshold up or down is not perfectly centered:
The off-centering is:
* is bounded, i.e. at worst an off-by-one on the reduced precision value compared to the exact result,
* is pseudo-random, i.e. it will be different on each call,
* almost symetrically distributed,
* depends on cryptographic properties like the encryption mask, the encryption noise and the crypto-parameters.
| ![approximate-off-by-one-error.png](../_static/rounding/approximate-off-by-one-error.png) |
|:--:|
| *In blue the exact value, the red dots are approximate values due to off-centered transition in approximate mode.* |
| ![approximate-off-centering-distribution.png](../_static/rounding/approximate-off-centering-distribution.png) |
|:--:|
| *Histogram of transitions off-centering delta. Each count correspond to a specific random mask and a specific encryption noise.* |
## Approximate rounding features
With approximate rounding, you can enable an approximate clipping to get further improve performance in the case of overflow handling. Approximate clipping enable to discard the extra bit of overflow protection bit in the successor TLU. For consistency a logical clipping is available when this optimization is not suitable.
### Logical clipping
When fast approximate clipping is not suitable (i.e. slower), it's better to apply logical clipping for consistency and better resilience to code change.
It has no extra cost since it's fuzed with the successor TLU.
| ![logical-clipping.png](../_static/rounding/approximate-off-by-one-error-logical-clipping.png) |
|:--:|
| *Only the last step is clipped.* |
### Approximate clipping
This set the first precision where approximate clipping is enabled, starting from this precision, an extra small precision TLU is introduced to safely remove the extra precision bit used to contain overflow. This way the successor TLU is faster.
E.g. for a rounding to 7bits, that finishes to a TLU of 8bits due to overflow, forcing to use a TLU of 7bits is 3x faster.
| ![approximate-clipping.png](../_static/rounding/approximate-off-by-one-error-approx-clipping.png) |
|:--:|
| *The last steps are decreased.* |

View File

@@ -9,6 +9,7 @@ from concrete.compiler import EvaluationKeys, Parameter, PublicArguments, Public
from .compilation import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
ApproximateRoundingConfig,
BitwiseStrategy,
Circuit,
Client,
@@ -18,6 +19,7 @@ from .compilation import (
Configuration,
DebugArtifacts,
EncryptionStatus,
Exactness,
Keys,
MinMaxStrategy,
MultiParameterStrategy,

View File

@@ -9,9 +9,11 @@ from .compiler import Compiler, EncryptionStatus
from .configuration import (
DEFAULT_GLOBAL_P_ERROR,
DEFAULT_P_ERROR,
ApproximateRoundingConfig,
BitwiseStrategy,
ComparisonStrategy,
Configuration,
Exactness,
MinMaxStrategy,
MultiParameterStrategy,
MultivariateStrategy,

View File

@@ -3,7 +3,8 @@ Declaration of `Configuration` class.
"""
import platform
from enum import Enum
from dataclasses import dataclass
from enum import Enum, IntEnum
from pathlib import Path
from typing import List, Optional, Tuple, Union, get_type_hints
@@ -73,6 +74,54 @@ class MultiParameterStrategy(str, Enum):
raise ValueError(message)
class Exactness(IntEnum):
"""
Exactness, to specify for specific operator the implementation preference (default and local).
"""
EXACT = 0
APPROXIMATE = 1
@dataclass
class ApproximateRoundingConfig:
"""
Controls the behavior of approximate rounding.
In the following `k` is the ideal rounding output precision.
Often the precision used after rounding is `k`+1 to avoid overflow.
`logical_clipping`, `approximate_clipping_start_precision` can be used to stay at precision `k`,
either logically or physically at the successor TLU.
See examples in https://github.com/zama-ai/concrete/blob/main/docs/tutorial/rounding.md.
"""
logical_clipping: bool = True
"""
Enable logical clipping to simulate a precision `k` in the successor TLU of precision `k`+1.
"""
approximate_clipping_start_precision: int = 5
"""Actively avoid the overflow using a `k`-1 precision TLU.
This is similar to logical clipping but less accurate and faster.
Effect on:
* accuracy: the upper values of the rounding range are sligtly decreased,
* cost: adds an extra `k`-1 bits TLU to guarantee that the precision after rounding is `k`.
This is usually a win when `k` >= 5 .
This is enabled by default for `k` >= 5.
Due to the extra inaccuracy and cost, it is possible to disable it completely using False."""
reduce_precision_after_approximate_clipping: bool = True
"""Enable the reduction to `k` bits in the TLU.
Can be disabled for debugging/testing purposes.
When disabled along with logical_clipping, the result of approximate clipping is accessible.
"""
symetrize_deltas: bool = True
"""Enable asymetry of correction of deltas w.r.t. the exact rounding computation.
Can be disabled for debugging/testing purposes.
"""
class ComparisonStrategy(str, Enum):
"""
ComparisonStrategy, to specify implementation preference for comparisons.
@@ -933,6 +982,8 @@ class Configuration:
relu_on_bits_chunk_size: int
if_then_else_chunk_size: int
additional_processors: List[GraphProcessor]
rounding_exactness: Exactness
approximate_rounding_config: ApproximateRoundingConfig
def __init__(
self,
@@ -990,6 +1041,8 @@ class Configuration:
relu_on_bits_chunk_size: int = 3,
if_then_else_chunk_size: int = 3,
additional_processors: Optional[List[GraphProcessor]] = None,
rounding_exactness: Exactness = Exactness.EXACT,
approximate_rounding_config: Optional[ApproximateRoundingConfig] = None,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
@@ -1073,6 +1126,10 @@ class Configuration:
self.relu_on_bits_chunk_size = relu_on_bits_chunk_size
self.if_then_else_chunk_size = if_then_else_chunk_size
self.additional_processors = [] if additional_processors is None else additional_processors
self.rounding_exactness = rounding_exactness
self.approximate_rounding_config = (
approximate_rounding_config or ApproximateRoundingConfig()
)
self._validate()
@@ -1134,6 +1191,8 @@ class Configuration:
relu_on_bits_chunk_size: Union[Keep, int] = KEEP,
if_then_else_chunk_size: Union[Keep, int] = KEEP,
additional_processors: Union[Keep, Optional[List[GraphProcessor]]] = KEEP,
rounding_exactness: Union[Keep, Exactness] = KEEP,
approximate_rounding_config: Union[Keep, Optional[ApproximateRoundingConfig]] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one specified changes.

View File

@@ -4,7 +4,6 @@ Declaration of `Server` class.
# pylint: disable=import-error,no-member,no-name-in-module
import json
import shutil
import tempfile
from pathlib import Path
@@ -12,6 +11,7 @@ from typing import Dict, List, Optional, Tuple, Union
# mypy: disable-error-code=attr-defined
import concrete.compiler
import jsonpickle
from concrete.compiler import (
CompilationContext,
CompilationOptions,
@@ -253,7 +253,7 @@ class Server:
f.write("1" if self.is_simulated else "0")
with open(Path(tmp) / "configuration.json", "w", encoding="utf-8") as f:
f.write(json.dumps(self._configuration.__dict__))
f.write(jsonpickle.dumps(self._configuration.__dict__))
shutil.make_archive(path, "zip", tmp)
@@ -300,7 +300,7 @@ class Server:
mlir = f.read()
with open(output_dir_path / "configuration.json", "r", encoding="utf-8") as f:
configuration = Configuration().fork(**json.load(f))
configuration = Configuration().fork(**jsonpickle.loads(f.read()))
return Server.create(mlir, configuration, is_simulated)

View File

@@ -4,10 +4,11 @@ Declaration of `round_bit_pattern` function, to provide an interface for rounded
import threading
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
from ..compilation.configuration import Exactness
from ..dtypes import Integer
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from ..representation import Node
@@ -158,6 +159,7 @@ def round_bit_pattern(
x: Union[int, np.integer, List, np.ndarray, Tracer],
lsbs_to_remove: Union[int, AutoRounder],
overflow_protection: bool = True,
exactness: Optional[Exactness] = None,
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
"""
Round the bit pattern of an integer.
@@ -212,6 +214,11 @@ def round_bit_pattern(
overflow_protection (bool, default = True)
whether to adjust bit widths and lsbs to remove to avoid overflows
exactness (Optional[Exactness], default = None)
select the exactness of the operation, None means use the global exactness.
The global exactnessdefault is EXACT.
It can be changed on the Configuration object.
Returns:
Union[int, np.integer, np.ndarray, Tracer]:
Tracer that respresents the operation during tracing
@@ -240,6 +247,8 @@ def round_bit_pattern(
def evaluator(
x: Union[int, np.integer, np.ndarray],
lsbs_to_remove: int,
overflow_protection: bool, # pylint: disable=unused-argument
exactness: Optional[Exactness], # pylint: disable=unused-argument
) -> Union[int, np.integer, np.ndarray]:
if lsbs_to_remove == 0:
return x
@@ -255,8 +264,11 @@ def round_bit_pattern(
[deepcopy(x.output)],
deepcopy(x.output),
evaluator,
kwargs={"lsbs_to_remove": lsbs_to_remove},
attributes={"overflow_protection": overflow_protection},
kwargs={
"lsbs_to_remove": lsbs_to_remove,
"overflow_protection": overflow_protection,
"exactness": exactness,
},
)
return Tracer(computation, [x])
@@ -276,6 +288,6 @@ def round_bit_pattern(
message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}"
raise TypeError(message)
return evaluator(x, lsbs_to_remove)
return evaluator(x, lsbs_to_remove, overflow_protection, exactness)
# pylint: enable=protected-access,too-many-branches

View File

@@ -28,6 +28,7 @@ from ..compilation.configuration import (
BitwiseStrategy,
ComparisonStrategy,
Configuration,
Exactness,
MinMaxStrategy,
)
from ..dtypes import Integer
@@ -3143,6 +3144,8 @@ class Context:
resulting_type: ConversionType,
x: Conversion,
lsbs_to_remove: int,
exactness: Exactness,
overflow_detected: bool,
) -> Conversion:
if x.is_clear:
highlights = {
@@ -3153,19 +3156,121 @@ class Context:
assert x.bit_width > lsbs_to_remove
if exactness is None:
exactness = self.configuration.rounding_exactness
intermediate_bit_width = x.bit_width - lsbs_to_remove
intermediate_type = self.typeof(
ValueDescription(
dtype=Integer(is_signed=x.is_signed, bit_width=(x.bit_width - lsbs_to_remove)),
dtype=Integer(is_signed=x.is_signed, bit_width=intermediate_bit_width),
shape=x.shape,
is_encrypted=x.is_encrypted,
)
)
if exactness is Exactness.APPROXIMATE:
approx_conf = self.configuration.approximate_rounding_config
# 1. Unskew TLU's futur error distribution on approximated value
# this balances agains all leading zeros in the noise (ignoring symetric noise)
unskewed = x
if approx_conf.symetrize_deltas:
highest_supported_precision = 62
delta_precision = highest_supported_precision - x.type.bit_width
full_precision = x.type.bit_width + delta_precision
half_in_extra_precision = (
1 << (delta_precision - 1)
) - 1 # slightly smaller then half
half_in_extra_precision = self.constant(
self.i(full_precision + 1), half_in_extra_precision
)
x_high_precision = self.reinterpret(x, bit_width=full_precision)
unskewed = self.add(
x_high_precision.type, x_high_precision, half_in_extra_precision
)
rounded = self.operation(
fhe.RoundEintOp if x.is_scalar else fhelinalg.RoundOp,
intermediate_type,
x.result,
)
# 2. Cancel overflow to have a TLU at exactly target_precision
# starting from 5 bits, the extra overflow bit in the TLU is too costly
# a smaller precision TLU can detect approximately the overflow to cancel it
# this is only possible because of the extra-bit from overflow protection
target_precision = x.bit_width - lsbs_to_remove - overflow_detected
if (
overflow_detected
and target_precision >= approx_conf.approximate_clipping_start_precision
and approx_conf.approximate_clipping_start_precision is not False
):
unskew_pre_overflow = self.reinterpret(unskewed, bit_width=x.type.bit_width)
overflow_precision = max(2, target_precision - 1)
# The last half-cell values in overflow_precision will naturally overflow.
# But there can also be an off by minus 1 to the previous cell in the worst case
# and an overflow in the successor TLU.
# We sliglty decrease the value of the rounding output on theses cells.
# `realign_cell_by` defines where the decrease starts to apply.
step_high = 1 << (x.type.bit_width - intermediate_bit_width)
step_wide = step_high
full_decrease_by = step_high
realign_cell_by = step_wide // 2
realign_cell_by = self.constant(self.i(x.type.bit_width + 1), realign_cell_by)
overflow_candidate = self.sub(
unskew_pre_overflow.type, unskew_pre_overflow, realign_cell_by
)
overflow_candidate = self.reinterpret(
overflow_candidate, bit_width=overflow_precision
)
half_tlu_size = 2 ** (overflow_precision - 1)
if x.is_signed:
negative_size = half_tlu_size
positive_size = negative_size
used_positive_size = half_tlu_size // 2
# this is oriented for precision higher than 3
# it will work with smaller precision but with more invasive effects
prevent_overflow_positive = (
# pre-overflow
[0] * used_positive_size
# overflow part
+ [3 * full_decrease_by // 4, full_decrease_by]
# unused
+ [0] * (positive_size - used_positive_size - 2)
)[:half_tlu_size]
prevent_overflow = prevent_overflow_positive + [0] * negative_size
else:
prevent_overflow = (
# pre-overflow
[0] * half_tlu_size
# overflow part
+ [3 * full_decrease_by // 4, full_decrease_by]
# unused
+ [0] * (half_tlu_size - 2)
)[: 2 * half_tlu_size]
signed_type = self.to_signed(x).type
overflow_cancel = self.reinterpret(
self.tlu(
signed_type,
overflow_candidate,
table=prevent_overflow,
),
bit_width=x.type.bit_width,
signed=x.is_signed,
)
unskewed = self.sub(unskew_pre_overflow.type, unskew_pre_overflow, overflow_cancel)
if approx_conf.reduce_precision_after_approximate_clipping:
# a minimum bitwith 3 is required to multiply by 2 in signed case
if unskewed.bit_width < 3:
# pragma: no-cover
self.reinterpret(unskewed, bit_width=3)
unskewed = self.mul(
unskewed.type, unskewed, self.constant(self.i(unskewed.bit_width + 1), 2)
)
rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width - 1)
# The TLU after may be adjusted to the right precision (see `Converter.tlu`)
else:
rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width)
else:
rounded = self.reinterpret(unskewed, bit_width=intermediate_type.bit_width)
else:
rounded = self.operation(
fhe.RoundEintOp if x.is_scalar else fhelinalg.RoundOp,
intermediate_type,
x.result,
)
return self.to_signedness(rounded, of=resulting_type)
@@ -3593,13 +3698,16 @@ class Context:
return x
def reinterpret(self, x: Conversion, *, bit_width: int) -> Conversion:
def reinterpret(
self, x: Conversion, *, bit_width: int, signed: Optional[bool] = None
) -> Conversion:
assert x.is_encrypted
if x.bit_width == bit_width:
return x
resulting_element_type = (self.eint if x.is_unsigned else self.esint)(bit_width)
result_signed = x.is_unsigned if signed is None else signed
resulting_element_type = (self.eint if result_signed else self.esint)(bit_width)
resulting_type = self.tensor(resulting_element_type, shape=x.shape)
operation = (

View File

@@ -4,6 +4,7 @@ Declaration of `Converter` class.
# pylint: disable=import-error,no-name-in-module
import math
import sys
from typing import Dict, List, Tuple, Union
@@ -17,8 +18,7 @@ from mlir.ir import InsertionPoint as MlirInsertionPoint
from mlir.ir import Location as MlirLocation
from mlir.ir import Module as MlirModule
from concrete.fhe.compilation.configuration import Configuration
from ..compilation.configuration import Configuration, Exactness
from ..representation import Graph, Node, Operation
from .context import Context
from .conversion import Conversion
@@ -195,7 +195,9 @@ class Converter:
multivariate_strategy_preference=configuration.multivariate_strategy_preference,
min_max_strategy_preference=configuration.min_max_strategy_preference,
),
ProcessRounding(),
ProcessRounding(
rounding_exactness=configuration.rounding_exactness,
),
] + configuration.additional_processors
for processor in pipeline:
@@ -485,9 +487,9 @@ class Converter:
assert len(preds) == 1
pred = preds[0]
overflow_detected = node.properties["overflow_detected"]
if pred.is_encrypted and pred.bit_width != pred.original_bit_width:
overflow_protection = node.properties["overflow_protection"]
overflow_detected = node.properties["overflow_detected"]
shifter = 2 ** (pred.bit_width - pred.original_bit_width)
if overflow_protection and overflow_detected:
@@ -500,6 +502,8 @@ class Converter:
ctx.typeof(node),
pred,
node.properties["final_lsbs_to_remove"],
node.properties["exactness"],
overflow_detected,
)
def subtract(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
@@ -531,6 +535,61 @@ class Converter:
# otherwise, a simple reshape would work as we already have the correct shape
return ctx.reshape(preds[0], shape=node.output.shape)
@classmethod
def tlu_adjust(cls, table, variable_input, target_bit_width, clipping, reduce_precision):
target_bit_width = min(
variable_input.bit_width, target_bit_width
) # inconsistency due to more precise bound vs precision
table_bit_width = math.log2(len(table))
assert table_bit_width.is_integer()
table_bit_width = int(table_bit_width)
table_has_right_size = variable_input.bit_width == table_bit_width
if table_has_right_size and not clipping:
return table
half_rounded_bit_width = target_bit_width - 1
if variable_input.is_signed:
# upper = positive part, lower = negative part
upper_clipping_index = 2**half_rounded_bit_width - 1
lower_clipping_index = 2**table_bit_width - 2**half_rounded_bit_width
positive_clipped_card = 2 ** (table_bit_width - 1) - upper_clipping_index - 1
negative_clipped_card = 2 ** (table_bit_width - 1) - 2**half_rounded_bit_width
else:
upper_clipping_index = 2**target_bit_width - 1
lower_clipping_index = 0
positive_clipped_card = 2**table_bit_width - upper_clipping_index - 1
lower_clipping = table[lower_clipping_index]
upper_clipping = table[upper_clipping_index]
if table_has_right_size:
# value clipping
assert clipping
if variable_input.is_signed:
table = (
list(table[: upper_clipping_index + 1])
+ [upper_clipping] * positive_clipped_card
+ [lower_clipping] * negative_clipped_card
+ list(table[lower_clipping_index:])
)
else:
table = (
list(table[lower_clipping_index : upper_clipping_index + 1])
+ [upper_clipping] * positive_clipped_card
)
assert len(table) == 2**table_bit_width, (
len(table),
2**table_bit_width,
table,
upper_clipping,
lower_clipping,
)
return np.array(table, dtype=np.uint64) # negative value are in unsigned representation
# adjust tlu size
assert reduce_precision
if variable_input.is_signed:
return np.concatenate((table[: upper_clipping_index + 1], table[lower_clipping_index:]))
return table[lower_clipping_index : upper_clipping_index + 1]
def tlu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert node.converted_to_table_lookup
@@ -654,6 +713,33 @@ class Converter:
variable_input = ctx.mul(variable_input.type, variable_input, shifter)
variable_input = ctx.reinterpret(variable_input, bit_width=truncated_bit_width)
elif variable_input.origin.properties.get("name") == "round_bit_pattern":
exactness = (
variable_input.origin.properties["exactness"]
or ctx.configuration.rounding_exactness
)
if exactness == Exactness.APPROXIMATE:
# we clip values to enforce input precision exactly as queried
original_bit_width = variable_input.origin.properties["original_bit_width"]
lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"]
overflow = variable_input.origin.properties["overflow_detected"]
rounded_bit_width = original_bit_width - lsbs_to_remove - overflow
approx_config = ctx.configuration.approximate_rounding_config
clipping = approx_config.logical_clipping
reduce_precision = approx_config.reduce_precision_after_approximate_clipping
if len(tables) == 1:
lut_values = self.tlu_adjust(
lut_values, variable_input, rounded_bit_width, clipping, reduce_precision
)
else:
for sub_i, sub_lut_values in enumerate(lut_values):
lut_values[sub_i] = self.tlu_adjust(
sub_lut_values,
variable_input,
rounded_bit_width,
clipping,
reduce_precision,
)
if len(tables) == 1:
return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist())

View File

@@ -230,7 +230,7 @@ class AdditionalConstraints:
return any(pred.output.is_clear for pred in preds)
def has_overflow_protection(self, node: Node, preds: List[Node]) -> bool:
return node.properties["attributes"]["overflow_protection"] is True
return node.properties["kwargs"]["overflow_protection"] is True
# ===========
# Constraints

View File

@@ -8,6 +8,7 @@ from typing import Optional
import numpy as np
from ...compilation.configuration import Exactness
from ...dtypes import Integer
from ...extensions.table import LookupTable
from ...representation import Graph, GraphProcessor, Node
@@ -18,6 +19,14 @@ class ProcessRounding(GraphProcessor):
ProcessRounding graph processor, to analyze rounding and support regular operations on it.
"""
rounding_exactness: Exactness
def __init__(
self,
rounding_exactness: Exactness,
):
self.rounding_exactness = rounding_exactness
def apply(self, graph: Graph):
rounding_nodes = graph.query_nodes(operation_filter="round_bit_pattern")
for node in rounding_nodes:
@@ -26,8 +35,13 @@ class ProcessRounding(GraphProcessor):
original_lsbs_to_remove = node.properties["kwargs"]["lsbs_to_remove"]
final_lsbs_to_remove = node.properties["final_lsbs_to_remove"]
exactness = node.properties["exactness"]
if exactness is None:
exactness = self.rounding_exactness
if original_lsbs_to_remove != 0 and final_lsbs_to_remove == 0:
self.replace_with_tlu(graph, node)
if exactness != Exactness.APPROXIMATE:
self.replace_with_tlu(graph, node)
continue
self.process_successors(graph, node)
@@ -43,12 +57,14 @@ class ProcessRounding(GraphProcessor):
pred = preds[0]
assert isinstance(pred.output.dtype, Integer)
overflow_protection = node.properties["attributes"]["overflow_protection"]
exactness = node.properties["kwargs"]["exactness"]
overflow_protection = node.properties["kwargs"]["overflow_protection"]
overflow_detected = (
overflow_protection
and pred.properties["original_bit_width"] != node.properties["original_bit_width"]
)
node.properties["exactness"] = exactness
node.properties["overflow_protection"] = overflow_protection
node.properties["overflow_detected"] = overflow_detected

View File

@@ -1,3 +1,4 @@
[mypy]
plugins = numpy.typing.mypy_plugin
disable_error_code = annotation-unchecked
allow_redefinition = True

View File

@@ -1,4 +1,5 @@
importlib-resources>=6.1
jsonpickle>=3.0.3
networkx>=2.6
numpy>=1.23
scipy>=1.10

View File

@@ -6,6 +6,7 @@ import numpy as np
import pytest
from concrete import fhe
from concrete.fhe.compilation.configuration import Exactness
from concrete.fhe.representation.utils import format_constant
@@ -432,14 +433,14 @@ def test_auto_rounding(helpers):
helpers.check_str(
f"""
%0 = x # EncryptedScalar<uint8>
%1 = round_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar<uint8>
%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar<uint4>
%3 = round_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar<uint4>
%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar<uint8>
%0 = x # EncryptedScalar<uint8>
%1 = round_bit_pattern(%0, lsbs_to_remove=3, overflow_protection=True, exactness=None) # EncryptedScalar<uint8>
%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar<uint4>
%3 = round_bit_pattern(%2, lsbs_to_remove=2, overflow_protection=True, exactness=None) # EncryptedScalar<uint4>
%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar<uint8>
return %4
""",
""", # noqa: E501
str(circuit3.graph.format(show_bounds=False)),
)
@@ -611,3 +612,253 @@ def test_round_bit_pattern_overflow_to_sign_bit(helpers):
for x in inputset:
helpers.check_execution(circuit, function, x, retries=3)
def test_round_bit_pattern_approximate_enabling(helpers):
"""
Test round bit pattern various activation paths.
"""
@fhe.compiler({"x": "encrypted"})
def function_default(x):
return fhe.round_bit_pattern(x, lsbs_to_remove=8)
@fhe.compiler({"x": "encrypted"})
def function_exact(x):
return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.EXACT)
@fhe.compiler({"x": "encrypted"})
def function_approx(x):
return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.APPROXIMATE)
inputset = [-(2**10), 2**10 - 1]
configuration = helpers.configuration()
circuit_default_default = function_default.compile(inputset, configuration)
circuit_default_exact = function_default.compile(
inputset, configuration.fork(rounding_exactness=Exactness.EXACT)
)
circuit_default_approx = function_default.compile(
inputset, configuration.fork(rounding_exactness=Exactness.APPROXIMATE)
)
circuit_exact = function_exact.compile(
inputset, configuration.fork(rounding_exactness=Exactness.APPROXIMATE)
)
circuit_approx = function_approx.compile(
inputset, configuration.fork(rounding_exactness=Exactness.EXACT)
)
assert circuit_approx.complexity < circuit_exact.complexity
assert circuit_exact.complexity == circuit_default_default.complexity
assert circuit_exact.complexity == circuit_default_exact.complexity
assert circuit_approx.complexity == circuit_default_approx.complexity
@pytest.mark.parametrize(
"accumulator_precision,reduced_precision,signed,conf",
[
(8, 4, True, fhe.ApproximateRoundingConfig(False, 4)),
(7, 4, False, fhe.ApproximateRoundingConfig(False, 4)),
(9, 3, True, fhe.ApproximateRoundingConfig(True, False)),
(8, 3, False, fhe.ApproximateRoundingConfig(True, False)),
(7, 3, False, fhe.ApproximateRoundingConfig(True, 3)),
(7, 2, True, fhe.ApproximateRoundingConfig(False, 2)),
(7, 2, False, fhe.ApproximateRoundingConfig(False, False, False, False)),
(8, 1, True, fhe.ApproximateRoundingConfig(False, 1)),
(8, 1, False, fhe.ApproximateRoundingConfig(True, False)),
(6, 5, False, fhe.ApproximateRoundingConfig(True, 6)),
(6, 5, False, fhe.ApproximateRoundingConfig(True, 5)),
],
)
def test_round_bit_pattern_approximate_off_by_one_errors(
accumulator_precision, reduced_precision, signed, conf, helpers
):
"""
Test round bit pattern off by 1 errors.
"""
lsbs_to_remove = accumulator_precision - reduced_precision
@fhe.compiler({"x": "encrypted"})
def function(x):
x = fhe.univariate(lambda x: x)(x)
x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
x = x // 2**lsbs_to_remove
return x
if signed:
inputset = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1) - 1]
else:
inputset = [0, 2**accumulator_precision - 1]
configuration = helpers.configuration()
circuit_exact = function.compile(inputset, configuration)
circuit_approx = function.compile(
inputset,
configuration.fork(
approximate_rounding_config=conf, rounding_exactness=Exactness.APPROXIMATE
),
)
# check it's better even with bad conf
assert circuit_approx.complexity < circuit_exact.complexity
testset = range(*inputset)
nb_error = 0
for x in testset:
approx = circuit_approx.encrypt_run_decrypt(x)
approx_simu = circuit_approx.simulate(x)
exact = circuit_exact.simulate(x)
assert abs(approx_simu - exact) <= 1
assert abs(approx_simu - approx) <= 1
delta = abs(approx - approx_simu)
assert delta <= 1
nb_error += delta > 0
nb_transitions = 2 ** (accumulator_precision - reduced_precision)
assert nb_error <= 3 * nb_transitions # of the same order as transitions but small sample size
@pytest.mark.parametrize(
"signed,physical",
[(signed, physical) for signed in (True, False) for physical in (True, False)],
)
def test_round_bit_pattern_approximate_clippping(signed, physical, helpers):
"""
Test round bit pattern clipping.
"""
accumulator_precision = 6
reduced_precision = 3
lsbs_to_remove = accumulator_precision - reduced_precision
@fhe.compiler({"x": "encrypted"})
def function(x):
x = fhe.univariate(lambda x: x)(x)
x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
x = x // 2**lsbs_to_remove
return x
if signed:
input_domain = range(-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1))
else:
input_domain = range(0, 2 ** (accumulator_precision))
configuration = helpers.configuration()
approx_conf = fhe.ApproximateRoundingConfig(
logical_clipping=not physical,
approximate_clipping_start_precision=physical and reduced_precision,
reduce_precision_after_approximate_clipping=False,
)
no_clipping_conf = fhe.ApproximateRoundingConfig(
logical_clipping=False, approximate_clipping_start_precision=False
)
assert approx_conf.logical_clipping or approx_conf.approximate_clipping_start_precision
circuit_clipping = function.compile(
input_domain,
configuration.fork(
approximate_rounding_config=approx_conf, rounding_exactness=Exactness.APPROXIMATE
),
)
circuit_no_clipping = function.compile(
input_domain,
configuration.fork(
approximate_rounding_config=no_clipping_conf, rounding_exactness=Exactness.APPROXIMATE
),
)
if signed:
clipped_output_domain = range(-(2 ** (reduced_precision - 1)), 2 ** (reduced_precision - 1))
else:
clipped_output_domain = range(0, 2**reduced_precision)
# With clipping
for x in input_domain:
assert (
circuit_clipping.encrypt_run_decrypt(x) in clipped_output_domain
), circuit_clipping.mlir # no overflow
assert circuit_clipping.simulate(x) in clipped_output_domain
# Without clipping
# overflow
assert circuit_no_clipping.simulate(input_domain[-1]) not in clipped_output_domain
@pytest.mark.parametrize(
"signed,accumulator_precision",
[
(signed, accumulator_precision)
for signed in (True, False)
for accumulator_precision in (13, 24)
],
)
def test_round_bit_pattern_approximate_acc_to_6_costs(signed, accumulator_precision, helpers):
"""
Test round bit pattern speedup when approximatipn is activated.
"""
reduced_precision = 6
lsbs_to_remove = accumulator_precision - reduced_precision
@fhe.compiler({"x": "encrypted"})
def function(x):
x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove, overflow_protection=True)
x = x // 2**lsbs_to_remove
return x
# with overflow
if signed:
input_domain = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 1) - 1]
else:
input_domain = [0, 2 ** (accumulator_precision) - 1]
configuration = helpers.configuration().fork(
single_precision=False,
parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
composable=True,
)
circuit_exact = function.compile(input_domain, configuration)
approx_conf_fastest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=6)
approx_conf_safest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=100)
circuit_approx_fastest = function.compile(
input_domain,
configuration.fork(
approximate_rounding_config=approx_conf_fastest,
rounding_exactness=Exactness.APPROXIMATE,
),
)
circuit_approx_safest = function.compile(
input_domain,
configuration.fork(
approximate_rounding_config=approx_conf_safest, rounding_exactness=Exactness.APPROXIMATE
),
)
assert circuit_approx_safest.complexity < circuit_exact.complexity
assert circuit_approx_fastest.complexity < circuit_approx_safest.complexity
@fhe.compiler({"x": "encrypted"})
def function(x): # pylint: disable=function-redefined
x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove, overflow_protection=False)
x = x // 2**lsbs_to_remove
return x
# without overflow
if signed:
input_domain = [-(2 ** (accumulator_precision - 1)), 2 ** (accumulator_precision - 2) - 2]
else:
input_domain = [0, 2 ** (accumulator_precision - 1) - 2]
circuit_exact_no_ovf = function.compile(input_domain, configuration)
circuit_approx_fastest_no_ovf = function.compile(
input_domain,
configuration.fork(
approximate_rounding_config=approx_conf_fastest,
rounding_exactness=Exactness.APPROXIMATE,
),
)
circuit_approx_safest_no_ovf = function.compile(
input_domain,
configuration.fork(
approximate_rounding_config=approx_conf_safest, rounding_exactness=Exactness.APPROXIMATE
),
)
assert circuit_approx_fastest_no_ovf.complexity == circuit_approx_safest_no_ovf.complexity
assert circuit_approx_safest_no_ovf.complexity < circuit_exact_no_ovf.complexity
assert circuit_exact_no_ovf.complexity < circuit_exact.complexity

View File

@@ -512,11 +512,11 @@ return %1
Function you are trying to compile cannot be compiled
%0 = x # ClearScalar<uint5> ∈ [10, 30]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = round_bit_pattern(%0, lsbs_to_remove=2) # ClearScalar<uint6> ∈ [12, 32]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported
%2 = reinterpret(%1) # ClearScalar<uint6>
%0 = x # ClearScalar<uint5> ∈ [10, 30]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = round_bit_pattern(%0, lsbs_to_remove=2, overflow_protection=True, exactness=None) # ClearScalar<uint6> ∈ [12, 32]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear round bit pattern is not supported
%2 = reinterpret(%1) # ClearScalar<uint6>
return %2
""", # noqa: E501