feat(frontend-python): add truncate bit pattern extension

This commit is contained in:
Umut
2023-11-22 09:11:45 +01:00
parent 07d6293ca8
commit 54d792c7bf
18 changed files with 1149 additions and 2 deletions

View File

@@ -23,6 +23,7 @@
* [Min/Max Operations](tutorial/minmax.md) * [Min/Max Operations](tutorial/minmax.md)
* [Bitwise Operations](tutorial/bitwise.md) * [Bitwise Operations](tutorial/bitwise.md)
* [Table Lookups](tutorial/table\_lookups.md) * [Table Lookups](tutorial/table\_lookups.md)
* [Truncating](tutorial/truncating.md)
* [Rounding](tutorial/rounding.md) * [Rounding](tutorial/rounding.md)
* [Floating Points](tutorial/floating\_points.md) * [Floating Points](tutorial/floating\_points.md)
* [Multi Precision](tutorial/multi\_precision.md) * [Multi Precision](tutorial/multi\_precision.md)

BIN
docs/_static/truncating/identity.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 86 KiB

BIN
docs/_static/truncating/msbs_to_keep.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 85 KiB

281
docs/tutorial/truncating.md Normal file
View File

@@ -0,0 +1,281 @@
# Truncating
Table lookups have a strict constraint on the number of bits they support. This can be limiting, especially if you don't need exact precision. As well as this, using larger bit-widths leads to slower table lookups.
To overcome these issues, truncated table lookups are introduced. This operation provides a way to zero the least significant bits of a large integer and then apply the table lookup on the resulting (smaller) value.
Imagine you have a 5-bit value, you can use `fhe.truncate_bit_pattern(value, lsbs_to_remove=2)` to truncate it (here the last 2 bits are discarded). Once truncated, value will remain in 5-bits (e.g., 22 = 0b10110 would be truncated to 20 = 0b10100), and the last 2 bits of it would be zero. Concrete uses this to optimize table lookups on the truncated value, the 5-bit table lookup gets optimized to a 3-bit table lookup, which is much faster!
Let's see how truncation works in practice:
```python
import matplotlib.pyplot as plt
import numpy as np
from concrete import fhe
original_bit_width = 5
lsbs_to_remove = 2
assert 0 < lsbs_to_remove < original_bit_width
original_values = list(range(2**original_bit_width))
truncated_values = [
fhe.truncate_bit_pattern(value, lsbs_to_remove)
for value in original_values
]
previous_truncated = truncated_values[0]
for original, truncated in zip(original_values, truncated_values):
if truncated != previous_truncated:
previous_truncated = truncated
print()
original_binary = np.binary_repr(original, width=(original_bit_width + 1))
truncated_binary = np.binary_repr(truncated, width=(original_bit_width + 1))
print(
f"{original:2} = 0b_{original_binary[:-lsbs_to_remove]}[{original_binary[-lsbs_to_remove:]}] "
f"=> "
f"0b_{truncated_binary[:-lsbs_to_remove]}[{truncated_binary[-lsbs_to_remove:]}] = {truncated}"
)
fig = plt.figure()
ax = fig.add_subplot()
plt.plot(original_values, original_values, label="original", color="black")
plt.plot(original_values, truncated_values, label="truncated", color="green")
plt.legend()
ax.set_aspect("equal", adjustable="box")
plt.show()
```
prints:
```
0 = 0b_0000[00] => 0b_0000[00] = 0
1 = 0b_0000[01] => 0b_0000[00] = 0
2 = 0b_0000[10] => 0b_0000[00] = 0
3 = 0b_0000[11] => 0b_0000[00] = 0
4 = 0b_0001[00] => 0b_0001[00] = 4
5 = 0b_0001[01] => 0b_0001[00] = 4
6 = 0b_0001[10] => 0b_0001[00] = 4
7 = 0b_0001[11] => 0b_0001[00] = 4
8 = 0b_0010[00] => 0b_0010[00] = 8
9 = 0b_0010[01] => 0b_0010[00] = 8
10 = 0b_0010[10] => 0b_0010[00] = 8
11 = 0b_0010[11] => 0b_0010[00] = 8
12 = 0b_0011[00] => 0b_0011[00] = 12
13 = 0b_0011[01] => 0b_0011[00] = 12
14 = 0b_0011[10] => 0b_0011[00] = 12
15 = 0b_0011[11] => 0b_0011[00] = 12
16 = 0b_0100[00] => 0b_0100[00] = 16
17 = 0b_0100[01] => 0b_0100[00] = 16
18 = 0b_0100[10] => 0b_0100[00] = 16
19 = 0b_0100[11] => 0b_0100[00] = 16
20 = 0b_0101[00] => 0b_0101[00] = 20
21 = 0b_0101[01] => 0b_0101[00] = 20
22 = 0b_0101[10] => 0b_0101[00] = 20
23 = 0b_0101[11] => 0b_0101[00] = 20
24 = 0b_0110[00] => 0b_0110[00] = 24
25 = 0b_0110[01] => 0b_0110[00] = 24
26 = 0b_0110[10] => 0b_0110[00] = 24
27 = 0b_0110[11] => 0b_0110[00] = 24
28 = 0b_0111[00] => 0b_0111[00] = 28
29 = 0b_0111[01] => 0b_0111[00] = 28
30 = 0b_0111[10] => 0b_0111[00] = 28
31 = 0b_0111[11] => 0b_0111[00] = 28
```
and displays:
![](../\_static/truncating/identity.png)
Now, let's see how truncating can be used in FHE.
```python
import itertools
import time
import matplotlib.pyplot as plt
import numpy as np
from concrete import fhe
configuration = fhe.Configuration(
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location=".keys",
)
input_bit_width = 6
input_range = np.array(range(2**input_bit_width))
timings = {}
results = {}
for lsbs_to_remove in range(input_bit_width):
@fhe.compiler({"x": "encrypted"})
def f(x):
return fhe.truncate_bit_pattern(x, lsbs_to_remove) ** 2
circuit = f.compile(inputset=[input_range], configuration=configuration)
circuit.keygen()
encrypted_sample = circuit.encrypt(input_range)
start = time.time()
encrypted_result = circuit.run(encrypted_sample)
end = time.time()
result = circuit.decrypt(encrypted_result)
took = end - start
timings[lsbs_to_remove] = took
results[lsbs_to_remove] = result
number_of_figures = len(results)
columns = 1
for i in range(2, number_of_figures):
if number_of_figures % i == 0:
columns = i
rows = number_of_figures // columns
fig, axs = plt.subplots(rows, columns)
axs = axs.flatten()
baseline = timings[0]
for lsbs_to_remove in range(input_bit_width):
timing = timings[lsbs_to_remove]
speedup = baseline / timing
print(f"lsbs_to_remove={lsbs_to_remove} => {speedup:.2f}x speedup")
axs[lsbs_to_remove].set_title(f"lsbs_to_remove={lsbs_to_remove}")
axs[lsbs_to_remove].plot(input_range, results[lsbs_to_remove])
plt.show()
```
prints:
```
lsbs_to_remove=0 => 1.00x speedup
lsbs_to_remove=1 => 1.69x speedup
lsbs_to_remove=2 => 3.48x speedup
lsbs_to_remove=3 => 3.06x speedup
lsbs_to_remove=4 => 3.46x speedup
lsbs_to_remove=5 => 3.14x speedup
```
{% hint style="info" %}
These speed-ups can vary from system to system.
{% endhint %}
{% hint style="info" %}
The reason why the speed-up is not increasing with `lsbs_to_remove` is because the truncating operation itself has a cost: each bit removal is a PBS. Therefore, if a lot of bits are removed, truncation itself could take longer than the bigger TLU which is evaluated afterwards.
{% endhint %}
and displays:
![](../\_static/truncating/lsbs_to_remove.png)
## Auto Truncators
Truncating is very useful but, in some cases, you don't know how many bits your input contains, so it's not reliable to specify `lsbs_to_remove` manually. For this reason, the `AutoTruncator` class is introduced.
`AutoTruncator` allows you to set how many of the most significant bits to keep, but they need to be adjusted using an inputset to determine how many of the least significant bits to remove. This can be done manually using `fhe.AutoTruncator.adjust(function, inputset)`, or by setting `auto_adjust_truncators` configuration to `True` during compilation.
Here is how auto truncators can be used in FHE:
```python
import itertools
import time
import matplotlib.pyplot as plt
import numpy as np
from concrete import fhe
configuration = fhe.Configuration(
enable_unsafe_features=True,
use_insecure_key_cache=True,
insecure_key_cache_location=".keys",
single_precision=False,
parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
)
input_bit_width = 6
input_range = np.array(range(2**input_bit_width))
timings = {}
results = {}
for target_msbs in reversed(range(1, input_bit_width + 1)):
truncator = fhe.AutoTruncator(target_msbs)
@fhe.compiler({"x": "encrypted"})
def f(x):
return fhe.truncate_bit_pattern(x, lsbs_to_remove=truncator) ** 2
fhe.AutoTruncator.adjust(f, inputset=[input_range])
circuit = f.compile(inputset=[input_range], configuration=configuration)
circuit.keygen()
encrypted_sample = circuit.encrypt(input_range)
start = time.time()
encrypted_result = circuit.run(encrypted_sample)
end = time.time()
result = circuit.decrypt(encrypted_result)
took = end - start
timings[target_msbs] = took
results[target_msbs] = result
number_of_figures = len(results)
columns = 1
for i in range(2, number_of_figures):
if number_of_figures % i == 0:
columns = i
rows = number_of_figures // columns
fig, axs = plt.subplots(rows, columns)
axs = axs.flatten()
baseline = timings[input_bit_width]
for i, target_msbs in enumerate(reversed(range(1, input_bit_width + 1))):
timing = timings[target_msbs]
speedup = baseline / timing
print(f"target_msbs={target_msbs} => {speedup:.2f}x speedup")
axs[i].set_title(f"target_msbs={target_msbs}")
axs[i].plot(input_range, results[target_msbs])
plt.show()
```
prints:
```
target_msbs=6 => 1.00x speedup
target_msbs=5 => 1.80x speedup
target_msbs=4 => 3.47x speedup
target_msbs=3 => 3.02x speedup
target_msbs=2 => 3.38x speedup
target_msbs=1 => 3.37x speedup
```
and displays:
![](../\_static/truncating/msbs_to_keep.png)
{% hint style="warning" %}
`AutoTruncator`s should be defined outside the function that is being compiled. They are used to store the result of the adjustment process, so they shouldn't be created each time the function is called. Furthermore, each `AutoTruncator` should be used with exactly one `truncate_bit_pattern` call.
{% endhint %}

View File

@@ -28,6 +28,7 @@ from .compilation import (
from .compilation.decorators import circuit, compiler from .compilation.decorators import circuit, compiler
from .extensions import ( from .extensions import (
AutoRounder, AutoRounder,
AutoTruncator,
LookupTable, LookupTable,
array, array,
conv, conv,
@@ -39,6 +40,7 @@ from .extensions import (
ones_like, ones_like,
round_bit_pattern, round_bit_pattern,
tag, tag,
truncate_bit_pattern,
univariate, univariate,
zero, zero,
zeros, zeros,

View File

@@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np import numpy as np
from concrete.compiler import CompilationContext from concrete.compiler import CompilationContext
from ..extensions import AutoRounder from ..extensions import AutoRounder, AutoTruncator
from ..mlir import GraphConverter from ..mlir import GraphConverter
from ..representation import Graph from ..representation import Graph
from ..tracing import Tracer from ..tracing import Tracer
@@ -279,6 +279,9 @@ class Compiler:
if self.configuration.auto_adjust_rounders: if self.configuration.auto_adjust_rounders:
AutoRounder.adjust(self.function, self.inputset) AutoRounder.adjust(self.function, self.inputset)
if self.configuration.auto_adjust_truncators:
AutoTruncator.adjust(self.function, self.inputset)
if self.graph is None: if self.graph is None:
try: try:
first_sample = next(iter(self.inputset)) first_sample = next(iter(self.inputset))

View File

@@ -880,6 +880,7 @@ class Configuration:
global_p_error: Optional[float] global_p_error: Optional[float]
insecure_key_cache_location: Optional[str] insecure_key_cache_location: Optional[str]
auto_adjust_rounders: bool auto_adjust_rounders: bool
auto_adjust_truncators: bool
single_precision: bool single_precision: bool
parameter_selection_strategy: ParameterSelectionStrategy parameter_selection_strategy: ParameterSelectionStrategy
show_progress: bool show_progress: bool
@@ -913,6 +914,7 @@ class Configuration:
p_error: Optional[float] = None, p_error: Optional[float] = None,
global_p_error: Optional[float] = None, global_p_error: Optional[float] = None,
auto_adjust_rounders: bool = False, auto_adjust_rounders: bool = False,
auto_adjust_truncators: bool = False,
single_precision: bool = False, single_precision: bool = False,
parameter_selection_strategy: Union[ parameter_selection_strategy: Union[
ParameterSelectionStrategy, str ParameterSelectionStrategy, str
@@ -959,6 +961,7 @@ class Configuration:
self.p_error = p_error self.p_error = p_error
self.global_p_error = global_p_error self.global_p_error = global_p_error
self.auto_adjust_rounders = auto_adjust_rounders self.auto_adjust_rounders = auto_adjust_rounders
self.auto_adjust_truncators = auto_adjust_truncators
self.single_precision = single_precision self.single_precision = single_precision
self.parameter_selection_strategy = ParameterSelectionStrategy.parse( self.parameter_selection_strategy = ParameterSelectionStrategy.parse(
parameter_selection_strategy parameter_selection_strategy
@@ -1035,6 +1038,7 @@ class Configuration:
p_error: Union[Keep, Optional[float]] = KEEP, p_error: Union[Keep, Optional[float]] = KEEP,
global_p_error: Union[Keep, Optional[float]] = KEEP, global_p_error: Union[Keep, Optional[float]] = KEEP,
auto_adjust_rounders: Union[Keep, bool] = KEEP, auto_adjust_rounders: Union[Keep, bool] = KEEP,
auto_adjust_truncators: Union[Keep, bool] = KEEP,
single_precision: Union[Keep, bool] = KEEP, single_precision: Union[Keep, bool] = KEEP,
parameter_selection_strategy: Union[Keep, Union[ParameterSelectionStrategy, str]] = KEEP, parameter_selection_strategy: Union[Keep, Union[ParameterSelectionStrategy, str]] = KEEP,
show_progress: Union[Keep, bool] = KEEP, show_progress: Union[Keep, bool] = KEEP,

View File

@@ -11,5 +11,6 @@ from .ones import one, ones, ones_like
from .round_bit_pattern import AutoRounder, round_bit_pattern from .round_bit_pattern import AutoRounder, round_bit_pattern
from .table import LookupTable from .table import LookupTable
from .tag import tag from .tag import tag
from .truncate_bit_pattern import AutoTruncator, truncate_bit_pattern
from .univariate import univariate from .univariate import univariate
from .zeros import zero, zeros, zeros_like from .zeros import zero, zeros, zeros_like

View File

@@ -0,0 +1,257 @@
"""
Declaration of `truncate_bit_pattern` extension.
"""
import threading
from copy import deepcopy
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
import numpy as np
from ..dtypes import Integer
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
from ..representation import Node
from ..tracing import Tracer
from ..values import ValueDescription
local = threading.local()
# pylint: disable=protected-access
local._is_adjusting = False
# pylint: enable=protected-access
class Adjusting(BaseException):
"""
Adjusting class, to be used as early stop signal during adjustment.
"""
truncator: "AutoTruncator"
input_min: int
input_max: int
def __init__(self, truncator: "AutoTruncator", input_min: int, input_max: int):
super().__init__()
self.truncator = truncator
self.input_min = input_min
self.input_max = input_max
class AutoTruncator:
"""
AutoTruncator class, to optimize for the number of msbs to keep during truncate operation.
"""
target_msbs: int
is_adjusted: bool
input_min: int
input_max: int
input_bit_width: int
lsbs_to_remove: int
def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH):
# pylint: disable=protected-access
if local._is_adjusting:
message = (
"AutoTruncators cannot be constructed during adjustment, "
"please construct AutoTruncators outside the function and reference it"
)
raise RuntimeError(message)
# pylint: enable=protected-access
self.target_msbs = target_msbs
self.is_adjusted = False
self.input_min = 0
self.input_max = 0
self.input_bit_width = 0
self.lsbs_to_remove = 0
@staticmethod
def adjust(function: Callable, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
"""
Adjust AutoTruncators in a function using an inputset.
"""
# pylint: disable=protected-access,too-many-branches
try: # extract underlying function for decorators
function = function.function # type: ignore
assert callable(function)
except AttributeError:
pass
if local._is_adjusting:
message = "AutoTruncators cannot be adjusted recursively"
raise RuntimeError(message)
try:
local._is_adjusting = True
# adjust the truncator using the inputset
# this loop continues until the return is reached in the loop body
# which only happens when ALL truncators are adjusted
# this condition is met if the function can be executed fully
# without `Adjusting` exception is raised
while True:
truncator = None
for sample in inputset:
if not isinstance(sample, tuple):
sample = (sample,)
try:
function(*sample)
except Adjusting as adjuster:
truncator = adjuster.truncator
truncator.input_min = min(truncator.input_min, adjuster.input_min)
truncator.input_max = max(truncator.input_max, adjuster.input_max)
input_value = ValueDescription.of(
[truncator.input_min, truncator.input_max]
)
assert isinstance(input_value.dtype, Integer)
truncator.input_bit_width = input_value.dtype.bit_width
if (
truncator.input_bit_width - truncator.lsbs_to_remove
> truncator.target_msbs
):
truncator.lsbs_to_remove = (
truncator.input_bit_width - truncator.target_msbs
)
else:
# this branch will be executed if there were no exceptions in the try block
return
if truncator is None:
message = "AutoTruncators cannot be adjusted with an empty inputset"
raise ValueError(message)
truncator.is_adjusted = True
finally:
local._is_adjusting = False
# pylint: enable=protected-access,too-many-branches
def dump_dict(self) -> Dict:
"""
Dump properties of the truncator to a dict.
"""
return {
"target_msbs": self.target_msbs,
"is_adjusted": self.is_adjusted,
"input_min": self.input_min,
"input_max": self.input_max,
"input_bit_width": self.input_bit_width,
"lsbs_to_remove": self.lsbs_to_remove,
}
@classmethod
def load_dict(cls, properties: Dict) -> "AutoTruncator":
"""
Load previously dumped truncator.
"""
result = AutoTruncator(target_msbs=properties["target_msbs"])
result.is_adjusted = properties["is_adjusted"]
result.input_min = properties["input_min"]
result.input_max = properties["input_max"]
result.lsbs_to_remove = properties["lsbs_to_remove"]
result.input_bit_width = properties["input_bit_width"]
return result
def truncate_bit_pattern(
x: Union[int, np.integer, List, np.ndarray, Tracer],
lsbs_to_remove: Union[int, AutoTruncator],
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
"""
Round the bit pattern of an integer.
If `lsbs_to_remove` is an `AutoTruncator`:
corresponding integer value will be determined by adjustment process.
x = 0b_0000 , lsbs_to_remove = 2 => 0b_0000
x = 0b_0001 , lsbs_to_remove = 2 => 0b_0000
x = 0b_0010 , lsbs_to_remove = 2 => 0b_0000
x = 0b_0100 , lsbs_to_remove = 2 => 0b_0100
x = 0b_0110 , lsbs_to_remove = 2 => 0b_0100
x = 0b_1100 , lsbs_to_remove = 2 => 0b_1100
x = 0b_abcd , lsbs_to_remove = 2 => 0b_ab00
Args:
x (Union[int, np.integer, np.ndarray, Tracer]):
input to truncate
lsbs_to_remove (Union[int, AutoTruncator]):
number of the least significant bits to clear
or an auto truncator object which will be used to determine the integer value
Returns:
Union[int, np.integer, np.ndarray, Tracer]:
Tracer that represents the operation during tracing
truncated value(s) otherwise
"""
# pylint: disable=protected-access,too-many-branches
if isinstance(lsbs_to_remove, AutoTruncator):
if local._is_adjusting:
if not lsbs_to_remove.is_adjusted:
raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x))) # type: ignore
elif not lsbs_to_remove.is_adjusted:
message = (
"AutoTruncators cannot be used before adjustment, "
"please call AutoTruncator.adjust with the function that will be compiled "
"and provide the exact inputset that will be used for compilation"
)
raise RuntimeError(message)
lsbs_to_remove = lsbs_to_remove.lsbs_to_remove
assert isinstance(lsbs_to_remove, int)
def evaluator(
x: Union[int, np.integer, np.ndarray],
lsbs_to_remove: int,
) -> Union[int, np.integer, np.ndarray]:
return (x >> lsbs_to_remove) << lsbs_to_remove
if isinstance(x, Tracer):
computation = Node.generic(
"truncate_bit_pattern",
[deepcopy(x.output)],
deepcopy(x.output),
evaluator,
kwargs={"lsbs_to_remove": lsbs_to_remove},
)
return Tracer(computation, [x])
if isinstance(x, list): # pragma: no cover
try:
x = np.array(x)
except Exception: # pylint: disable=broad-except
pass
if isinstance(x, np.ndarray):
if not np.issubdtype(x.dtype, np.integer):
message = (
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
)
raise TypeError(message)
elif not isinstance(x, (int, np.integer)):
message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}"
raise TypeError(message)
return evaluator(x, lsbs_to_remove)
# pylint: enable=protected-access,too-many-branches

View File

@@ -3218,6 +3218,45 @@ class Context:
return xs[0] return xs[0]
def truncate_bit_pattern(self, x: Conversion, lsbs_to_remove: int) -> Conversion:
if x.is_clear:
highlights = {
x.origin: "operand is clear",
self.converting: "but clear truncate bit pattern is not supported",
}
self.error(highlights)
assert x.bit_width > lsbs_to_remove
resulting_bit_width = x.bit_width
for i in range(lsbs_to_remove):
lsb = self.lsb(x.type, x)
cleared = self.sub(x.type, x, lsb)
new_bit_width = (x.bit_width - 1) if i != (lsbs_to_remove - 1) else resulting_bit_width
x = self.reinterpret(cleared, bit_width=new_bit_width)
return x
def lsb(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
assert resulting_type.shape == x.shape
assert resulting_type.is_signed == x.is_signed
assert resulting_type.is_encrypted and x.is_encrypted
operation = fhe.LsbEintOp if x.is_scalar else fhelinalg.LsbEintOp
return self.operation(operation, resulting_type, x.result)
def reinterpret(self, x: Conversion, *, bit_width: int) -> Conversion:
assert x.is_encrypted
resulting_element_type = (self.eint if x.is_unsigned else self.esint)(bit_width)
resulting_type = self.tensor(resulting_element_type, shape=x.shape)
operation = (
fhe.ReinterpretPrecisionEintOp if x.is_scalar else fhelinalg.ReinterpretPrecisionEintOp
)
return self.operation(operation, resulting_type, x.result)
def zeros(self, resulting_type: ConversionType) -> Conversion: def zeros(self, resulting_type: ConversionType) -> Conversion:
assert resulting_type.is_encrypted assert resulting_type.is_encrypted

View File

@@ -618,6 +618,20 @@ class Converter:
variable_input_index = variable_input_indices[0] variable_input_index = variable_input_indices[0]
variable_input = preds[variable_input_index] variable_input = preds[variable_input_index]
if variable_input.origin.properties.get("name") == "truncate_bit_pattern":
original_bit_width = variable_input.origin.properties["original_bit_width"]
lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"]
truncated_bit_width = original_bit_width - lsbs_to_remove
if variable_input.bit_width > original_bit_width:
bit_width_difference = variable_input.bit_width - original_bit_width
shifter = ctx.constant(
ctx.i(variable_input.bit_width + 1), 2**bit_width_difference
)
variable_input = ctx.mul(variable_input.type, variable_input, shifter)
variable_input = ctx.reinterpret(variable_input, bit_width=truncated_bit_width)
if len(tables) == 1: if len(tables) == 1:
return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist()) return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist())
@@ -637,6 +651,10 @@ class Converter:
axes=node.properties["kwargs"].get("axes", []), axes=node.properties["kwargs"].get("axes", []),
) )
def truncate_bit_pattern(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 1
return ctx.truncate_bit_pattern(preds[0], node.properties["kwargs"]["lsbs_to_remove"])
def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion: def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
assert len(preds) == 0 assert len(preds) == 0
return ctx.zeros(ctx.typeof(node)) return ctx.zeros(ctx.typeof(node))

View File

@@ -562,3 +562,7 @@ class AdditionalConstraints:
transpose = { transpose = {
inputs_and_output_share_precision, inputs_and_output_share_precision,
} }
truncate_bit_pattern = {
inputs_and_output_share_precision,
}

View File

@@ -162,7 +162,9 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
if pred.operation != Operation.Constant: if pred.operation != Operation.Constant:
variable_input_index = index variable_input_index = index
break break
assert_that(variable_input_index != -1) assert_that(variable_input_index != -1)
variable_input = preds[variable_input_index]
variable_input_dtype = node.inputs[variable_input_index].dtype variable_input_dtype = node.inputs[variable_input_index].dtype
variable_input_shape = node.inputs[variable_input_index].shape variable_input_shape = node.inputs[variable_input_index].shape
@@ -170,7 +172,6 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
assert_that(isinstance(variable_input_dtype, Integer)) assert_that(isinstance(variable_input_dtype, Integer))
variable_input_dtype = deepcopy(cast(Integer, variable_input_dtype)) variable_input_dtype = deepcopy(cast(Integer, variable_input_dtype))
variable_input = preds[variable_input_index]
if ( if (
variable_input.operation == Operation.Generic variable_input.operation == Operation.Generic
and variable_input.properties["name"] == "round_bit_pattern" and variable_input.properties["name"] == "round_bit_pattern"
@@ -186,6 +187,20 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
variable_input_dtype.bit_width += 1 variable_input_dtype.bit_width += 1
step = (2**variable_input_dtype.bit_width) // expected_number_of_elements step = (2**variable_input_dtype.bit_width) // expected_number_of_elements
elif (
variable_input.operation == Operation.Generic
and variable_input.properties["name"] == "truncate_bit_pattern"
):
original_bit_width = variable_input.properties["original_bit_width"]
lsbs_to_remove = variable_input.properties["kwargs"]["lsbs_to_remove"]
resulting_bit_width = original_bit_width - lsbs_to_remove
expected_number_of_elements = 2**resulting_bit_width
variable_input_dtype.bit_width = original_bit_width
step = (2**original_bit_width) // expected_number_of_elements
else: else:
step = 1 step = 1

View File

@@ -1,3 +1,4 @@
importlib-resources>=6.1
networkx>=2.6 networkx>=2.6
numpy>=1.23 numpy>=1.23
scipy>=1.10 scipy>=1.10

View File

@@ -0,0 +1,467 @@
"""
Tests of execution of truncate bit pattern operation.
"""
import numpy as np
import pytest
from concrete import fhe
from concrete.fhe.representation.utils import format_constant
@pytest.mark.parametrize(
"sample,lsbs_to_remove,expected_output",
[
(0b_0000_0011, 0, 0b_0000_0011),
(0b_0000_0100, 0, 0b_0000_0100),
(0b_0000_0000, 3, 0b_0000_0000),
(0b_0000_0001, 3, 0b_0000_0000),
(0b_0000_0010, 3, 0b_0000_0000),
(0b_0000_0011, 3, 0b_0000_0000),
(0b_0000_0100, 3, 0b_0000_0000),
(0b_0000_0101, 3, 0b_0000_0000),
(0b_0000_0110, 3, 0b_0000_0000),
(0b_0000_0111, 3, 0b_0000_0000),
(0b_0000_1000, 3, 0b_0000_1000),
(0b_0000_1001, 3, 0b_0000_1000),
(0b_0000_1010, 3, 0b_0000_1000),
(0b_0000_1011, 3, 0b_0000_1000),
(0b_0000_1100, 3, 0b_0000_1000),
(0b_0000_1101, 3, 0b_0000_1000),
(0b_0000_1110, 3, 0b_0000_1000),
(0b_0000_1111, 3, 0b_0000_1000),
],
)
def test_plain_truncate_bit_pattern(sample, lsbs_to_remove, expected_output):
"""
Test truncate bit pattern in evaluation context.
"""
assert fhe.truncate_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove) == expected_output
@pytest.mark.parametrize(
"sample,lsbs_to_remove,expected_error,expected_message",
[
(
np.array([3.2, 4.1]),
3,
TypeError,
f"Expected input elements to be integers but they are {type(np.array([3.2, 4.1]).dtype).__name__}", # noqa: E501
),
(
"foo",
3,
TypeError,
"Expected input to be an int or a numpy array but it's str",
),
],
)
def test_bad_plain_truncate_bit_pattern(
sample,
lsbs_to_remove,
expected_error,
expected_message,
):
"""
Test truncate bit pattern in evaluation context with bad parameters.
"""
with pytest.raises(expected_error) as excinfo:
fhe.truncate_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove)
assert str(excinfo.value) == expected_message
@pytest.mark.parametrize(
"input_bits,lsbs_to_remove",
[
(3, 1),
(3, 2),
(4, 1),
(4, 2),
(4, 3),
(5, 1),
(5, 2),
(5, 3),
(5, 4),
],
)
@pytest.mark.parametrize(
"mapper",
[
pytest.param(
lambda x: x,
id="x",
),
pytest.param(
lambda x: x + 10,
id="x + 10",
),
pytest.param(
lambda x: x**2,
id="x ** 2",
),
pytest.param(
lambda x: fhe.univariate(lambda x: x if x >= 0 else 0)(x),
id="relu",
),
],
)
def test_truncate_bit_pattern(input_bits, lsbs_to_remove, mapper, helpers):
"""
Test truncate bit pattern.
"""
@fhe.compiler({"x": "encrypted"})
def function(x):
x_truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
return mapper(x_truncated)
upper_bound = 2**input_bits
inputset = [0, upper_bound - 1]
circuit = function.compile(inputset, helpers.configuration())
helpers.check_execution(circuit, function, np.random.randint(0, upper_bound), retries=3)
for value in inputset:
helpers.check_execution(circuit, function, value, retries=3)
@pytest.mark.parametrize(
"input_bits,lsbs_to_remove",
[
(3, 1),
(3, 2),
(4, 1),
(4, 2),
(4, 3),
],
)
def test_truncate_bit_pattern_unsigned_range(input_bits, lsbs_to_remove, helpers):
"""
Test truncate bit pattern in unsigned range.
"""
@fhe.compiler({"x": "encrypted"})
def function(x):
return fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
inputset = range(0, 2**input_bits)
circuit = function.compile(inputset, helpers.configuration())
for value in inputset:
helpers.check_execution(circuit, function, value, retries=3)
@pytest.mark.parametrize(
"input_bits,lsbs_to_remove",
[
(3, 1),
(3, 2),
(4, 1),
(4, 2),
(4, 3),
],
)
def test_truncate_bit_pattern_signed_range(input_bits, lsbs_to_remove, helpers):
"""
Test truncate bit pattern in signed range.
"""
@fhe.compiler({"x": "encrypted"})
def function(x):
return fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
inputset = range(-(2 ** (input_bits - 1)), 2 ** (input_bits - 1))
circuit = function.compile(inputset, helpers.configuration())
for value in inputset:
helpers.check_execution(circuit, function, value, retries=3)
@pytest.mark.parametrize(
"input_bits,lsbs_to_remove",
[
(3, 1),
(3, 2),
(4, 1),
(4, 2),
(4, 3),
],
)
def test_truncate_bit_pattern_unsigned_range_assigned(input_bits, lsbs_to_remove, helpers):
"""
Test truncate bit pattern in unsigned range with a big bit-width assigned.
"""
@fhe.compiler({"x": "encrypted"})
def function(x):
truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
return (truncated**2) + (63 - x)
inputset = range(0, 2**input_bits)
circuit = function.compile(inputset, helpers.configuration())
for value in inputset:
helpers.check_execution(circuit, function, value, retries=3)
@pytest.mark.parametrize(
"input_bits,lsbs_to_remove",
[
(3, 1),
(3, 2),
(4, 1),
(4, 2),
(4, 3),
],
)
def test_truncate_bit_pattern_signed_range_assigned(input_bits, lsbs_to_remove, helpers):
"""
Test truncate bit pattern in signed range with a big bit-width assigned.
"""
@fhe.compiler({"x": "encrypted"})
def function(x):
truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
return (truncated**2) + (63 - x)
inputset = range(-(2 ** (input_bits - 1)), 2 ** (input_bits - 1))
circuit = function.compile(inputset, helpers.configuration())
for value in inputset:
helpers.check_execution(circuit, function, value, retries=3)
def test_truncate_bit_pattern_identity(helpers, pytestconfig):
"""
Test truncate bit pattern used multiple times outside TLUs.
"""
@fhe.compiler({"x": "encrypted"})
def function(x):
truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=2)
return truncated + truncated
inputset = range(-20, 20)
circuit = function.compile(inputset, helpers.configuration())
expected_mlir = (
"""
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>
%3 = "FHE.lsb"(%2) : (!FHE.esint<6>) -> !FHE.esint<6>
%4 = "FHE.sub_eint"(%2, %3) : (!FHE.esint<6>, !FHE.esint<6>) -> !FHE.esint<6>
%5 = "FHE.reinterpret_precision"(%4) : (!FHE.esint<6>) -> !FHE.esint<7>
%6 = "FHE.add_eint"(%5, %5) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
return %6 : !FHE.esint<7>
}
}
""" # noqa: E501
if pytestconfig.getoption("precision") == "multi"
else """
module {
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>
%3 = "FHE.lsb"(%2) : (!FHE.esint<6>) -> !FHE.esint<6>
%4 = "FHE.sub_eint"(%2, %3) : (!FHE.esint<6>, !FHE.esint<6>) -> !FHE.esint<6>
%5 = "FHE.reinterpret_precision"(%4) : (!FHE.esint<6>) -> !FHE.esint<7>
%6 = "FHE.add_eint"(%5, %5) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
return %6 : !FHE.esint<7>
}
}
""" # noqa: E501
)
helpers.check_str(expected_mlir, circuit.mlir)
def test_auto_truncating(helpers):
"""
Test truncate bit pattern with auto truncating.
"""
# with auto adjust truncators configuration
# ---------------------------------------
# y has the max value of 1999, so it's 11 bits
# our target msb is 5 bits, which means we need to remove 6 of the least significant bits
truncator1 = fhe.AutoTruncator(target_msbs=5)
@fhe.compiler({"x": "encrypted"})
def function1(x):
y = x + 1000
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator1)
return np.sqrt(z).astype(np.int64)
inputset1 = range(1000)
function1.trace(inputset1, helpers.configuration(), auto_adjust_truncators=True)
assert truncator1.lsbs_to_remove == 6
# manual
# ------
# y has the max value of 1999, so it's 11 bits
# our target msb is 3 bits, which means we need to remove 8 of the least significant bits
truncator2 = fhe.AutoTruncator(target_msbs=3)
@fhe.compiler({"x": "encrypted"})
def function2(x):
y = x + 1000
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator2)
return np.sqrt(z).astype(np.int64)
inputset2 = range(1000)
fhe.AutoTruncator.adjust(function2, inputset2)
assert truncator2.lsbs_to_remove == 8
# complicated case
# ----------------
# have 2 ** 8 entries during evaluation, it won't matter after compilation
entries3 = list(range(2**8))
# we have 8-bit inputs for this table, and we only want to use first 5-bits
for i in range(0, 2**8, 2**3):
# so we set every 8th entry to a 4-bit value
entries3[i] = np.random.randint(0, (2**4) - (2**2))
# when this tlu is applied to an 8-bit value with 5-bit msb truncating, result will be 4-bits
table3 = fhe.LookupTable(entries3)
# and this is the truncator for table1, which should have lsbs_to_remove of 3
truncator3 = fhe.AutoTruncator(target_msbs=5)
# have 2 ** 8 entries during evaluation, it won't matter after compilation
entries4 = list(range(2**8))
# we have 4-bit inputs for this table, and we only want to use first 2-bits
for i in range(0, 2**4, 2**2):
# so we set every 4th entry to an 8-bit value
entries4[i] = np.random.randint(2**7, 2**8)
# when this tlu is applied to a 4-bit value with 2-bit msb truncating, result will be 8-bits
table4 = fhe.LookupTable(entries4)
# and this is the truncator for table2, which should have lsbs_to_remove of 2
truncator4 = fhe.AutoTruncator(target_msbs=2)
@fhe.compiler({"x": "encrypted"})
def function3(x):
a = fhe.truncate_bit_pattern(x, lsbs_to_remove=truncator3)
b = table3[a]
c = fhe.truncate_bit_pattern(b, lsbs_to_remove=truncator4)
d = table4[c]
return d
inputset3 = range((2**8) - (2**3))
circuit3 = function3.compile(
inputset3,
helpers.configuration(),
auto_adjust_truncators=True,
)
assert truncator3.lsbs_to_remove == 3
assert truncator4.lsbs_to_remove == 2
table3_formatted_string = format_constant(table3.table, 25)
table4_formatted_string = format_constant(table4.table, 25)
helpers.check_str(
f"""
%0 = x # EncryptedScalar<uint8>
%1 = truncate_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar<uint8>
%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar<uint4>
%3 = truncate_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar<uint4>
%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar<uint8>
return %4
""",
str(circuit3.graph.format(show_bounds=False)),
)
def test_auto_truncating_without_adjustment():
"""
Test truncate bit pattern with auto truncating but without adjustment.
"""
truncator = fhe.AutoTruncator(target_msbs=5)
def function(x):
y = x + 1000
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator)
return np.sqrt(z).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
function(100)
assert str(excinfo.value) == (
"AutoTruncators cannot be used before adjustment, "
"please call AutoTruncator.adjust with the function that will be compiled "
"and provide the exact inputset that will be used for compilation"
)
def test_auto_truncating_with_empty_inputset():
"""
Test truncate bit pattern with auto truncating but with empty inputset.
"""
truncator = fhe.AutoTruncator(target_msbs=5)
def function(x):
y = x + 1000
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator)
return np.sqrt(z).astype(np.int64)
with pytest.raises(ValueError) as excinfo:
fhe.AutoTruncator.adjust(function, [])
assert str(excinfo.value) == "AutoTruncators cannot be adjusted with an empty inputset"
def test_auto_truncating_recursive_adjustment():
"""
Test truncate bit pattern with auto truncating but with recursive adjustment.
"""
truncator = fhe.AutoTruncator(target_msbs=5)
def function(x):
fhe.AutoTruncator.adjust(function, range(10))
y = x + 1000
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator)
return np.sqrt(z).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
fhe.AutoTruncator.adjust(function, range(10))
assert str(excinfo.value) == "AutoTruncators cannot be adjusted recursively"
def test_auto_truncating_construct_in_function():
"""
Test truncate bit pattern with auto truncating but truncator is constructed within the function.
"""
def function(x):
y = x + 1000
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=fhe.AutoTruncator(target_msbs=5))
return np.sqrt(z).astype(np.int64)
with pytest.raises(RuntimeError) as excinfo:
fhe.AutoTruncator.adjust(function, range(10))
assert str(excinfo.value) == (
"AutoTruncators cannot be constructed during adjustment, "
"please construct AutoTruncators outside the function and reference it"
)

View File

@@ -0,0 +1,37 @@
"""
Tests of 'truncate_bit_pattern' extension.
"""
from concrete import fhe
def test_dump_load_auto_truncator():
"""
Test 'dump_dict' and 'load_dict' methods of AutoTruncator.
"""
truncator = fhe.AutoTruncator(target_msbs=3)
truncator.is_adjusted = True
truncator.input_min = 10
truncator.input_max = 20
truncator.input_bit_width = 5
truncator.lsbs_to_remove = 2
dumped = truncator.dump_dict()
assert dumped == {
"target_msbs": 3,
"is_adjusted": True,
"input_min": 10,
"input_max": 20,
"input_bit_width": 5,
"lsbs_to_remove": 2,
}
loaded = fhe.AutoTruncator.load_dict(dumped)
assert loaded.target_msbs == 3
assert loaded.is_adjusted
assert loaded.input_min == 10
assert loaded.input_max == 20
assert loaded.input_bit_width == 5
assert loaded.lsbs_to_remove == 2

View File

@@ -988,6 +988,23 @@ Function you are trying to compile cannot be compiled
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit maximum operation is supported ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit maximum operation is supported
return %2 return %2
""", # noqa: E501
),
pytest.param(
lambda x: fhe.truncate_bit_pattern(x, lsbs_to_remove=2),
{"x": "clear"},
[10, 20, 30],
RuntimeError,
"""
Function you are trying to compile cannot be compiled
%0 = x # ClearScalar<uint5> ∈ [10, 30]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
%1 = truncate_bit_pattern(%0, lsbs_to_remove=2) # ClearScalar<uint5> ∈ [8, 28]
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear truncate bit pattern is not supported
return %1
""", # noqa: E501 """, # noqa: E501
), ),
], ],