mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-09 12:57:55 -05:00
feat(frontend-python): add truncate bit pattern extension
This commit is contained in:
@@ -23,6 +23,7 @@
|
|||||||
* [Min/Max Operations](tutorial/minmax.md)
|
* [Min/Max Operations](tutorial/minmax.md)
|
||||||
* [Bitwise Operations](tutorial/bitwise.md)
|
* [Bitwise Operations](tutorial/bitwise.md)
|
||||||
* [Table Lookups](tutorial/table\_lookups.md)
|
* [Table Lookups](tutorial/table\_lookups.md)
|
||||||
|
* [Truncating](tutorial/truncating.md)
|
||||||
* [Rounding](tutorial/rounding.md)
|
* [Rounding](tutorial/rounding.md)
|
||||||
* [Floating Points](tutorial/floating\_points.md)
|
* [Floating Points](tutorial/floating\_points.md)
|
||||||
* [Multi Precision](tutorial/multi\_precision.md)
|
* [Multi Precision](tutorial/multi\_precision.md)
|
||||||
|
|||||||
BIN
docs/_static/truncating/identity.png
vendored
Normal file
BIN
docs/_static/truncating/identity.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 14 KiB |
BIN
docs/_static/truncating/lsbs_to_remove.png
vendored
Normal file
BIN
docs/_static/truncating/lsbs_to_remove.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 86 KiB |
BIN
docs/_static/truncating/msbs_to_keep.png
vendored
Normal file
BIN
docs/_static/truncating/msbs_to_keep.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 85 KiB |
281
docs/tutorial/truncating.md
Normal file
281
docs/tutorial/truncating.md
Normal file
@@ -0,0 +1,281 @@
|
|||||||
|
# Truncating
|
||||||
|
|
||||||
|
Table lookups have a strict constraint on the number of bits they support. This can be limiting, especially if you don't need exact precision. As well as this, using larger bit-widths leads to slower table lookups.
|
||||||
|
|
||||||
|
To overcome these issues, truncated table lookups are introduced. This operation provides a way to zero the least significant bits of a large integer and then apply the table lookup on the resulting (smaller) value.
|
||||||
|
|
||||||
|
Imagine you have a 5-bit value, you can use `fhe.truncate_bit_pattern(value, lsbs_to_remove=2)` to truncate it (here the last 2 bits are discarded). Once truncated, value will remain in 5-bits (e.g., 22 = 0b10110 would be truncated to 20 = 0b10100), and the last 2 bits of it would be zero. Concrete uses this to optimize table lookups on the truncated value, the 5-bit table lookup gets optimized to a 3-bit table lookup, which is much faster!
|
||||||
|
|
||||||
|
Let's see how truncation works in practice:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from concrete import fhe
|
||||||
|
|
||||||
|
original_bit_width = 5
|
||||||
|
lsbs_to_remove = 2
|
||||||
|
|
||||||
|
assert 0 < lsbs_to_remove < original_bit_width
|
||||||
|
|
||||||
|
original_values = list(range(2**original_bit_width))
|
||||||
|
truncated_values = [
|
||||||
|
fhe.truncate_bit_pattern(value, lsbs_to_remove)
|
||||||
|
for value in original_values
|
||||||
|
]
|
||||||
|
|
||||||
|
previous_truncated = truncated_values[0]
|
||||||
|
for original, truncated in zip(original_values, truncated_values):
|
||||||
|
if truncated != previous_truncated:
|
||||||
|
previous_truncated = truncated
|
||||||
|
print()
|
||||||
|
|
||||||
|
original_binary = np.binary_repr(original, width=(original_bit_width + 1))
|
||||||
|
truncated_binary = np.binary_repr(truncated, width=(original_bit_width + 1))
|
||||||
|
|
||||||
|
print(
|
||||||
|
f"{original:2} = 0b_{original_binary[:-lsbs_to_remove]}[{original_binary[-lsbs_to_remove:]}] "
|
||||||
|
f"=> "
|
||||||
|
f"0b_{truncated_binary[:-lsbs_to_remove]}[{truncated_binary[-lsbs_to_remove:]}] = {truncated}"
|
||||||
|
)
|
||||||
|
|
||||||
|
fig = plt.figure()
|
||||||
|
ax = fig.add_subplot()
|
||||||
|
|
||||||
|
plt.plot(original_values, original_values, label="original", color="black")
|
||||||
|
plt.plot(original_values, truncated_values, label="truncated", color="green")
|
||||||
|
plt.legend()
|
||||||
|
|
||||||
|
ax.set_aspect("equal", adjustable="box")
|
||||||
|
plt.show()
|
||||||
|
```
|
||||||
|
|
||||||
|
prints:
|
||||||
|
|
||||||
|
```
|
||||||
|
0 = 0b_0000[00] => 0b_0000[00] = 0
|
||||||
|
1 = 0b_0000[01] => 0b_0000[00] = 0
|
||||||
|
2 = 0b_0000[10] => 0b_0000[00] = 0
|
||||||
|
3 = 0b_0000[11] => 0b_0000[00] = 0
|
||||||
|
|
||||||
|
4 = 0b_0001[00] => 0b_0001[00] = 4
|
||||||
|
5 = 0b_0001[01] => 0b_0001[00] = 4
|
||||||
|
6 = 0b_0001[10] => 0b_0001[00] = 4
|
||||||
|
7 = 0b_0001[11] => 0b_0001[00] = 4
|
||||||
|
|
||||||
|
8 = 0b_0010[00] => 0b_0010[00] = 8
|
||||||
|
9 = 0b_0010[01] => 0b_0010[00] = 8
|
||||||
|
10 = 0b_0010[10] => 0b_0010[00] = 8
|
||||||
|
11 = 0b_0010[11] => 0b_0010[00] = 8
|
||||||
|
|
||||||
|
12 = 0b_0011[00] => 0b_0011[00] = 12
|
||||||
|
13 = 0b_0011[01] => 0b_0011[00] = 12
|
||||||
|
14 = 0b_0011[10] => 0b_0011[00] = 12
|
||||||
|
15 = 0b_0011[11] => 0b_0011[00] = 12
|
||||||
|
|
||||||
|
16 = 0b_0100[00] => 0b_0100[00] = 16
|
||||||
|
17 = 0b_0100[01] => 0b_0100[00] = 16
|
||||||
|
18 = 0b_0100[10] => 0b_0100[00] = 16
|
||||||
|
19 = 0b_0100[11] => 0b_0100[00] = 16
|
||||||
|
|
||||||
|
20 = 0b_0101[00] => 0b_0101[00] = 20
|
||||||
|
21 = 0b_0101[01] => 0b_0101[00] = 20
|
||||||
|
22 = 0b_0101[10] => 0b_0101[00] = 20
|
||||||
|
23 = 0b_0101[11] => 0b_0101[00] = 20
|
||||||
|
|
||||||
|
24 = 0b_0110[00] => 0b_0110[00] = 24
|
||||||
|
25 = 0b_0110[01] => 0b_0110[00] = 24
|
||||||
|
26 = 0b_0110[10] => 0b_0110[00] = 24
|
||||||
|
27 = 0b_0110[11] => 0b_0110[00] = 24
|
||||||
|
|
||||||
|
28 = 0b_0111[00] => 0b_0111[00] = 28
|
||||||
|
29 = 0b_0111[01] => 0b_0111[00] = 28
|
||||||
|
30 = 0b_0111[10] => 0b_0111[00] = 28
|
||||||
|
31 = 0b_0111[11] => 0b_0111[00] = 28
|
||||||
|
```
|
||||||
|
|
||||||
|
and displays:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
Now, let's see how truncating can be used in FHE.
|
||||||
|
|
||||||
|
```python
|
||||||
|
import itertools
|
||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from concrete import fhe
|
||||||
|
|
||||||
|
configuration = fhe.Configuration(
|
||||||
|
enable_unsafe_features=True,
|
||||||
|
use_insecure_key_cache=True,
|
||||||
|
insecure_key_cache_location=".keys",
|
||||||
|
)
|
||||||
|
|
||||||
|
input_bit_width = 6
|
||||||
|
input_range = np.array(range(2**input_bit_width))
|
||||||
|
|
||||||
|
timings = {}
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for lsbs_to_remove in range(input_bit_width):
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def f(x):
|
||||||
|
return fhe.truncate_bit_pattern(x, lsbs_to_remove) ** 2
|
||||||
|
|
||||||
|
circuit = f.compile(inputset=[input_range], configuration=configuration)
|
||||||
|
circuit.keygen()
|
||||||
|
|
||||||
|
encrypted_sample = circuit.encrypt(input_range)
|
||||||
|
start = time.time()
|
||||||
|
encrypted_result = circuit.run(encrypted_sample)
|
||||||
|
end = time.time()
|
||||||
|
result = circuit.decrypt(encrypted_result)
|
||||||
|
|
||||||
|
took = end - start
|
||||||
|
|
||||||
|
timings[lsbs_to_remove] = took
|
||||||
|
results[lsbs_to_remove] = result
|
||||||
|
|
||||||
|
number_of_figures = len(results)
|
||||||
|
|
||||||
|
columns = 1
|
||||||
|
for i in range(2, number_of_figures):
|
||||||
|
if number_of_figures % i == 0:
|
||||||
|
columns = i
|
||||||
|
rows = number_of_figures // columns
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(rows, columns)
|
||||||
|
axs = axs.flatten()
|
||||||
|
|
||||||
|
baseline = timings[0]
|
||||||
|
for lsbs_to_remove in range(input_bit_width):
|
||||||
|
timing = timings[lsbs_to_remove]
|
||||||
|
speedup = baseline / timing
|
||||||
|
print(f"lsbs_to_remove={lsbs_to_remove} => {speedup:.2f}x speedup")
|
||||||
|
|
||||||
|
axs[lsbs_to_remove].set_title(f"lsbs_to_remove={lsbs_to_remove}")
|
||||||
|
axs[lsbs_to_remove].plot(input_range, results[lsbs_to_remove])
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
```
|
||||||
|
|
||||||
|
prints:
|
||||||
|
|
||||||
|
```
|
||||||
|
lsbs_to_remove=0 => 1.00x speedup
|
||||||
|
lsbs_to_remove=1 => 1.69x speedup
|
||||||
|
lsbs_to_remove=2 => 3.48x speedup
|
||||||
|
lsbs_to_remove=3 => 3.06x speedup
|
||||||
|
lsbs_to_remove=4 => 3.46x speedup
|
||||||
|
lsbs_to_remove=5 => 3.14x speedup
|
||||||
|
```
|
||||||
|
|
||||||
|
{% hint style="info" %}
|
||||||
|
These speed-ups can vary from system to system.
|
||||||
|
{% endhint %}
|
||||||
|
|
||||||
|
{% hint style="info" %}
|
||||||
|
The reason why the speed-up is not increasing with `lsbs_to_remove` is because the truncating operation itself has a cost: each bit removal is a PBS. Therefore, if a lot of bits are removed, truncation itself could take longer than the bigger TLU which is evaluated afterwards.
|
||||||
|
{% endhint %}
|
||||||
|
|
||||||
|
and displays:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
## Auto Truncators
|
||||||
|
|
||||||
|
Truncating is very useful but, in some cases, you don't know how many bits your input contains, so it's not reliable to specify `lsbs_to_remove` manually. For this reason, the `AutoTruncator` class is introduced.
|
||||||
|
|
||||||
|
`AutoTruncator` allows you to set how many of the most significant bits to keep, but they need to be adjusted using an inputset to determine how many of the least significant bits to remove. This can be done manually using `fhe.AutoTruncator.adjust(function, inputset)`, or by setting `auto_adjust_truncators` configuration to `True` during compilation.
|
||||||
|
|
||||||
|
Here is how auto truncators can be used in FHE:
|
||||||
|
|
||||||
|
```python
|
||||||
|
import itertools
|
||||||
|
import time
|
||||||
|
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
from concrete import fhe
|
||||||
|
|
||||||
|
configuration = fhe.Configuration(
|
||||||
|
enable_unsafe_features=True,
|
||||||
|
use_insecure_key_cache=True,
|
||||||
|
insecure_key_cache_location=".keys",
|
||||||
|
single_precision=False,
|
||||||
|
parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
|
||||||
|
)
|
||||||
|
|
||||||
|
input_bit_width = 6
|
||||||
|
input_range = np.array(range(2**input_bit_width))
|
||||||
|
|
||||||
|
timings = {}
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
for target_msbs in reversed(range(1, input_bit_width + 1)):
|
||||||
|
truncator = fhe.AutoTruncator(target_msbs)
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def f(x):
|
||||||
|
return fhe.truncate_bit_pattern(x, lsbs_to_remove=truncator) ** 2
|
||||||
|
|
||||||
|
fhe.AutoTruncator.adjust(f, inputset=[input_range])
|
||||||
|
|
||||||
|
circuit = f.compile(inputset=[input_range], configuration=configuration)
|
||||||
|
circuit.keygen()
|
||||||
|
|
||||||
|
encrypted_sample = circuit.encrypt(input_range)
|
||||||
|
start = time.time()
|
||||||
|
encrypted_result = circuit.run(encrypted_sample)
|
||||||
|
end = time.time()
|
||||||
|
result = circuit.decrypt(encrypted_result)
|
||||||
|
|
||||||
|
took = end - start
|
||||||
|
|
||||||
|
timings[target_msbs] = took
|
||||||
|
results[target_msbs] = result
|
||||||
|
|
||||||
|
number_of_figures = len(results)
|
||||||
|
|
||||||
|
columns = 1
|
||||||
|
for i in range(2, number_of_figures):
|
||||||
|
if number_of_figures % i == 0:
|
||||||
|
columns = i
|
||||||
|
rows = number_of_figures // columns
|
||||||
|
|
||||||
|
fig, axs = plt.subplots(rows, columns)
|
||||||
|
axs = axs.flatten()
|
||||||
|
|
||||||
|
baseline = timings[input_bit_width]
|
||||||
|
for i, target_msbs in enumerate(reversed(range(1, input_bit_width + 1))):
|
||||||
|
timing = timings[target_msbs]
|
||||||
|
speedup = baseline / timing
|
||||||
|
print(f"target_msbs={target_msbs} => {speedup:.2f}x speedup")
|
||||||
|
|
||||||
|
axs[i].set_title(f"target_msbs={target_msbs}")
|
||||||
|
axs[i].plot(input_range, results[target_msbs])
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
```
|
||||||
|
|
||||||
|
prints:
|
||||||
|
|
||||||
|
```
|
||||||
|
target_msbs=6 => 1.00x speedup
|
||||||
|
target_msbs=5 => 1.80x speedup
|
||||||
|
target_msbs=4 => 3.47x speedup
|
||||||
|
target_msbs=3 => 3.02x speedup
|
||||||
|
target_msbs=2 => 3.38x speedup
|
||||||
|
target_msbs=1 => 3.37x speedup
|
||||||
|
```
|
||||||
|
|
||||||
|
and displays:
|
||||||
|
|
||||||
|

|
||||||
|
|
||||||
|
{% hint style="warning" %}
|
||||||
|
`AutoTruncator`s should be defined outside the function that is being compiled. They are used to store the result of the adjustment process, so they shouldn't be created each time the function is called. Furthermore, each `AutoTruncator` should be used with exactly one `truncate_bit_pattern` call.
|
||||||
|
{% endhint %}
|
||||||
@@ -28,6 +28,7 @@ from .compilation import (
|
|||||||
from .compilation.decorators import circuit, compiler
|
from .compilation.decorators import circuit, compiler
|
||||||
from .extensions import (
|
from .extensions import (
|
||||||
AutoRounder,
|
AutoRounder,
|
||||||
|
AutoTruncator,
|
||||||
LookupTable,
|
LookupTable,
|
||||||
array,
|
array,
|
||||||
conv,
|
conv,
|
||||||
@@ -39,6 +40,7 @@ from .extensions import (
|
|||||||
ones_like,
|
ones_like,
|
||||||
round_bit_pattern,
|
round_bit_pattern,
|
||||||
tag,
|
tag,
|
||||||
|
truncate_bit_pattern,
|
||||||
univariate,
|
univariate,
|
||||||
zero,
|
zero,
|
||||||
zeros,
|
zeros,
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from concrete.compiler import CompilationContext
|
from concrete.compiler import CompilationContext
|
||||||
|
|
||||||
from ..extensions import AutoRounder
|
from ..extensions import AutoRounder, AutoTruncator
|
||||||
from ..mlir import GraphConverter
|
from ..mlir import GraphConverter
|
||||||
from ..representation import Graph
|
from ..representation import Graph
|
||||||
from ..tracing import Tracer
|
from ..tracing import Tracer
|
||||||
@@ -279,6 +279,9 @@ class Compiler:
|
|||||||
if self.configuration.auto_adjust_rounders:
|
if self.configuration.auto_adjust_rounders:
|
||||||
AutoRounder.adjust(self.function, self.inputset)
|
AutoRounder.adjust(self.function, self.inputset)
|
||||||
|
|
||||||
|
if self.configuration.auto_adjust_truncators:
|
||||||
|
AutoTruncator.adjust(self.function, self.inputset)
|
||||||
|
|
||||||
if self.graph is None:
|
if self.graph is None:
|
||||||
try:
|
try:
|
||||||
first_sample = next(iter(self.inputset))
|
first_sample = next(iter(self.inputset))
|
||||||
|
|||||||
@@ -880,6 +880,7 @@ class Configuration:
|
|||||||
global_p_error: Optional[float]
|
global_p_error: Optional[float]
|
||||||
insecure_key_cache_location: Optional[str]
|
insecure_key_cache_location: Optional[str]
|
||||||
auto_adjust_rounders: bool
|
auto_adjust_rounders: bool
|
||||||
|
auto_adjust_truncators: bool
|
||||||
single_precision: bool
|
single_precision: bool
|
||||||
parameter_selection_strategy: ParameterSelectionStrategy
|
parameter_selection_strategy: ParameterSelectionStrategy
|
||||||
show_progress: bool
|
show_progress: bool
|
||||||
@@ -913,6 +914,7 @@ class Configuration:
|
|||||||
p_error: Optional[float] = None,
|
p_error: Optional[float] = None,
|
||||||
global_p_error: Optional[float] = None,
|
global_p_error: Optional[float] = None,
|
||||||
auto_adjust_rounders: bool = False,
|
auto_adjust_rounders: bool = False,
|
||||||
|
auto_adjust_truncators: bool = False,
|
||||||
single_precision: bool = False,
|
single_precision: bool = False,
|
||||||
parameter_selection_strategy: Union[
|
parameter_selection_strategy: Union[
|
||||||
ParameterSelectionStrategy, str
|
ParameterSelectionStrategy, str
|
||||||
@@ -959,6 +961,7 @@ class Configuration:
|
|||||||
self.p_error = p_error
|
self.p_error = p_error
|
||||||
self.global_p_error = global_p_error
|
self.global_p_error = global_p_error
|
||||||
self.auto_adjust_rounders = auto_adjust_rounders
|
self.auto_adjust_rounders = auto_adjust_rounders
|
||||||
|
self.auto_adjust_truncators = auto_adjust_truncators
|
||||||
self.single_precision = single_precision
|
self.single_precision = single_precision
|
||||||
self.parameter_selection_strategy = ParameterSelectionStrategy.parse(
|
self.parameter_selection_strategy = ParameterSelectionStrategy.parse(
|
||||||
parameter_selection_strategy
|
parameter_selection_strategy
|
||||||
@@ -1035,6 +1038,7 @@ class Configuration:
|
|||||||
p_error: Union[Keep, Optional[float]] = KEEP,
|
p_error: Union[Keep, Optional[float]] = KEEP,
|
||||||
global_p_error: Union[Keep, Optional[float]] = KEEP,
|
global_p_error: Union[Keep, Optional[float]] = KEEP,
|
||||||
auto_adjust_rounders: Union[Keep, bool] = KEEP,
|
auto_adjust_rounders: Union[Keep, bool] = KEEP,
|
||||||
|
auto_adjust_truncators: Union[Keep, bool] = KEEP,
|
||||||
single_precision: Union[Keep, bool] = KEEP,
|
single_precision: Union[Keep, bool] = KEEP,
|
||||||
parameter_selection_strategy: Union[Keep, Union[ParameterSelectionStrategy, str]] = KEEP,
|
parameter_selection_strategy: Union[Keep, Union[ParameterSelectionStrategy, str]] = KEEP,
|
||||||
show_progress: Union[Keep, bool] = KEEP,
|
show_progress: Union[Keep, bool] = KEEP,
|
||||||
|
|||||||
@@ -11,5 +11,6 @@ from .ones import one, ones, ones_like
|
|||||||
from .round_bit_pattern import AutoRounder, round_bit_pattern
|
from .round_bit_pattern import AutoRounder, round_bit_pattern
|
||||||
from .table import LookupTable
|
from .table import LookupTable
|
||||||
from .tag import tag
|
from .tag import tag
|
||||||
|
from .truncate_bit_pattern import AutoTruncator, truncate_bit_pattern
|
||||||
from .univariate import univariate
|
from .univariate import univariate
|
||||||
from .zeros import zero, zeros, zeros_like
|
from .zeros import zero, zeros, zeros_like
|
||||||
|
|||||||
@@ -0,0 +1,257 @@
|
|||||||
|
"""
|
||||||
|
Declaration of `truncate_bit_pattern` extension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Any, Callable, Dict, Iterable, List, Tuple, Union
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from ..dtypes import Integer
|
||||||
|
from ..mlir.utils import MAXIMUM_TLU_BIT_WIDTH
|
||||||
|
from ..representation import Node
|
||||||
|
from ..tracing import Tracer
|
||||||
|
from ..values import ValueDescription
|
||||||
|
|
||||||
|
local = threading.local()
|
||||||
|
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
local._is_adjusting = False
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
|
||||||
|
class Adjusting(BaseException):
|
||||||
|
"""
|
||||||
|
Adjusting class, to be used as early stop signal during adjustment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
truncator: "AutoTruncator"
|
||||||
|
input_min: int
|
||||||
|
input_max: int
|
||||||
|
|
||||||
|
def __init__(self, truncator: "AutoTruncator", input_min: int, input_max: int):
|
||||||
|
super().__init__()
|
||||||
|
self.truncator = truncator
|
||||||
|
self.input_min = input_min
|
||||||
|
self.input_max = input_max
|
||||||
|
|
||||||
|
|
||||||
|
class AutoTruncator:
|
||||||
|
"""
|
||||||
|
AutoTruncator class, to optimize for the number of msbs to keep during truncate operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
target_msbs: int
|
||||||
|
|
||||||
|
is_adjusted: bool
|
||||||
|
input_min: int
|
||||||
|
input_max: int
|
||||||
|
input_bit_width: int
|
||||||
|
lsbs_to_remove: int
|
||||||
|
|
||||||
|
def __init__(self, target_msbs: int = MAXIMUM_TLU_BIT_WIDTH):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
if local._is_adjusting:
|
||||||
|
message = (
|
||||||
|
"AutoTruncators cannot be constructed during adjustment, "
|
||||||
|
"please construct AutoTruncators outside the function and reference it"
|
||||||
|
)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
self.target_msbs = target_msbs
|
||||||
|
|
||||||
|
self.is_adjusted = False
|
||||||
|
self.input_min = 0
|
||||||
|
self.input_max = 0
|
||||||
|
self.input_bit_width = 0
|
||||||
|
self.lsbs_to_remove = 0
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def adjust(function: Callable, inputset: Union[Iterable[Any], Iterable[Tuple[Any, ...]]]):
|
||||||
|
"""
|
||||||
|
Adjust AutoTruncators in a function using an inputset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=protected-access,too-many-branches
|
||||||
|
|
||||||
|
try: # extract underlying function for decorators
|
||||||
|
function = function.function # type: ignore
|
||||||
|
assert callable(function)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
if local._is_adjusting:
|
||||||
|
message = "AutoTruncators cannot be adjusted recursively"
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
try:
|
||||||
|
local._is_adjusting = True
|
||||||
|
|
||||||
|
# adjust the truncator using the inputset
|
||||||
|
# this loop continues until the return is reached in the loop body
|
||||||
|
# which only happens when ALL truncators are adjusted
|
||||||
|
# this condition is met if the function can be executed fully
|
||||||
|
# without `Adjusting` exception is raised
|
||||||
|
|
||||||
|
while True:
|
||||||
|
truncator = None
|
||||||
|
|
||||||
|
for sample in inputset:
|
||||||
|
if not isinstance(sample, tuple):
|
||||||
|
sample = (sample,)
|
||||||
|
|
||||||
|
try:
|
||||||
|
function(*sample)
|
||||||
|
except Adjusting as adjuster:
|
||||||
|
truncator = adjuster.truncator
|
||||||
|
|
||||||
|
truncator.input_min = min(truncator.input_min, adjuster.input_min)
|
||||||
|
truncator.input_max = max(truncator.input_max, adjuster.input_max)
|
||||||
|
|
||||||
|
input_value = ValueDescription.of(
|
||||||
|
[truncator.input_min, truncator.input_max]
|
||||||
|
)
|
||||||
|
assert isinstance(input_value.dtype, Integer)
|
||||||
|
truncator.input_bit_width = input_value.dtype.bit_width
|
||||||
|
|
||||||
|
if (
|
||||||
|
truncator.input_bit_width - truncator.lsbs_to_remove
|
||||||
|
> truncator.target_msbs
|
||||||
|
):
|
||||||
|
truncator.lsbs_to_remove = (
|
||||||
|
truncator.input_bit_width - truncator.target_msbs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# this branch will be executed if there were no exceptions in the try block
|
||||||
|
return
|
||||||
|
|
||||||
|
if truncator is None:
|
||||||
|
message = "AutoTruncators cannot be adjusted with an empty inputset"
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
|
truncator.is_adjusted = True
|
||||||
|
|
||||||
|
finally:
|
||||||
|
local._is_adjusting = False
|
||||||
|
|
||||||
|
# pylint: enable=protected-access,too-many-branches
|
||||||
|
|
||||||
|
def dump_dict(self) -> Dict:
|
||||||
|
"""
|
||||||
|
Dump properties of the truncator to a dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
return {
|
||||||
|
"target_msbs": self.target_msbs,
|
||||||
|
"is_adjusted": self.is_adjusted,
|
||||||
|
"input_min": self.input_min,
|
||||||
|
"input_max": self.input_max,
|
||||||
|
"input_bit_width": self.input_bit_width,
|
||||||
|
"lsbs_to_remove": self.lsbs_to_remove,
|
||||||
|
}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load_dict(cls, properties: Dict) -> "AutoTruncator":
|
||||||
|
"""
|
||||||
|
Load previously dumped truncator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
result = AutoTruncator(target_msbs=properties["target_msbs"])
|
||||||
|
|
||||||
|
result.is_adjusted = properties["is_adjusted"]
|
||||||
|
result.input_min = properties["input_min"]
|
||||||
|
result.input_max = properties["input_max"]
|
||||||
|
result.lsbs_to_remove = properties["lsbs_to_remove"]
|
||||||
|
result.input_bit_width = properties["input_bit_width"]
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def truncate_bit_pattern(
|
||||||
|
x: Union[int, np.integer, List, np.ndarray, Tracer],
|
||||||
|
lsbs_to_remove: Union[int, AutoTruncator],
|
||||||
|
) -> Union[int, np.integer, List, np.ndarray, Tracer]:
|
||||||
|
"""
|
||||||
|
Round the bit pattern of an integer.
|
||||||
|
|
||||||
|
If `lsbs_to_remove` is an `AutoTruncator`:
|
||||||
|
corresponding integer value will be determined by adjustment process.
|
||||||
|
|
||||||
|
x = 0b_0000 , lsbs_to_remove = 2 => 0b_0000
|
||||||
|
x = 0b_0001 , lsbs_to_remove = 2 => 0b_0000
|
||||||
|
x = 0b_0010 , lsbs_to_remove = 2 => 0b_0000
|
||||||
|
x = 0b_0100 , lsbs_to_remove = 2 => 0b_0100
|
||||||
|
x = 0b_0110 , lsbs_to_remove = 2 => 0b_0100
|
||||||
|
x = 0b_1100 , lsbs_to_remove = 2 => 0b_1100
|
||||||
|
x = 0b_abcd , lsbs_to_remove = 2 => 0b_ab00
|
||||||
|
|
||||||
|
Args:
|
||||||
|
x (Union[int, np.integer, np.ndarray, Tracer]):
|
||||||
|
input to truncate
|
||||||
|
|
||||||
|
lsbs_to_remove (Union[int, AutoTruncator]):
|
||||||
|
number of the least significant bits to clear
|
||||||
|
or an auto truncator object which will be used to determine the integer value
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[int, np.integer, np.ndarray, Tracer]:
|
||||||
|
Tracer that represents the operation during tracing
|
||||||
|
truncated value(s) otherwise
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=protected-access,too-many-branches
|
||||||
|
|
||||||
|
if isinstance(lsbs_to_remove, AutoTruncator):
|
||||||
|
if local._is_adjusting:
|
||||||
|
if not lsbs_to_remove.is_adjusted:
|
||||||
|
raise Adjusting(lsbs_to_remove, int(np.min(x)), int(np.max(x))) # type: ignore
|
||||||
|
|
||||||
|
elif not lsbs_to_remove.is_adjusted:
|
||||||
|
message = (
|
||||||
|
"AutoTruncators cannot be used before adjustment, "
|
||||||
|
"please call AutoTruncator.adjust with the function that will be compiled "
|
||||||
|
"and provide the exact inputset that will be used for compilation"
|
||||||
|
)
|
||||||
|
raise RuntimeError(message)
|
||||||
|
|
||||||
|
lsbs_to_remove = lsbs_to_remove.lsbs_to_remove
|
||||||
|
|
||||||
|
assert isinstance(lsbs_to_remove, int)
|
||||||
|
|
||||||
|
def evaluator(
|
||||||
|
x: Union[int, np.integer, np.ndarray],
|
||||||
|
lsbs_to_remove: int,
|
||||||
|
) -> Union[int, np.integer, np.ndarray]:
|
||||||
|
return (x >> lsbs_to_remove) << lsbs_to_remove
|
||||||
|
|
||||||
|
if isinstance(x, Tracer):
|
||||||
|
computation = Node.generic(
|
||||||
|
"truncate_bit_pattern",
|
||||||
|
[deepcopy(x.output)],
|
||||||
|
deepcopy(x.output),
|
||||||
|
evaluator,
|
||||||
|
kwargs={"lsbs_to_remove": lsbs_to_remove},
|
||||||
|
)
|
||||||
|
return Tracer(computation, [x])
|
||||||
|
|
||||||
|
if isinstance(x, list): # pragma: no cover
|
||||||
|
try:
|
||||||
|
x = np.array(x)
|
||||||
|
except Exception: # pylint: disable=broad-except
|
||||||
|
pass
|
||||||
|
|
||||||
|
if isinstance(x, np.ndarray):
|
||||||
|
if not np.issubdtype(x.dtype, np.integer):
|
||||||
|
message = (
|
||||||
|
f"Expected input elements to be integers but they are {type(x.dtype).__name__}"
|
||||||
|
)
|
||||||
|
raise TypeError(message)
|
||||||
|
elif not isinstance(x, (int, np.integer)):
|
||||||
|
message = f"Expected input to be an int or a numpy array but it's {type(x).__name__}"
|
||||||
|
raise TypeError(message)
|
||||||
|
|
||||||
|
return evaluator(x, lsbs_to_remove)
|
||||||
|
|
||||||
|
# pylint: enable=protected-access,too-many-branches
|
||||||
@@ -3218,6 +3218,45 @@ class Context:
|
|||||||
|
|
||||||
return xs[0]
|
return xs[0]
|
||||||
|
|
||||||
|
def truncate_bit_pattern(self, x: Conversion, lsbs_to_remove: int) -> Conversion:
|
||||||
|
if x.is_clear:
|
||||||
|
highlights = {
|
||||||
|
x.origin: "operand is clear",
|
||||||
|
self.converting: "but clear truncate bit pattern is not supported",
|
||||||
|
}
|
||||||
|
self.error(highlights)
|
||||||
|
|
||||||
|
assert x.bit_width > lsbs_to_remove
|
||||||
|
|
||||||
|
resulting_bit_width = x.bit_width
|
||||||
|
for i in range(lsbs_to_remove):
|
||||||
|
lsb = self.lsb(x.type, x)
|
||||||
|
cleared = self.sub(x.type, x, lsb)
|
||||||
|
|
||||||
|
new_bit_width = (x.bit_width - 1) if i != (lsbs_to_remove - 1) else resulting_bit_width
|
||||||
|
x = self.reinterpret(cleared, bit_width=new_bit_width)
|
||||||
|
|
||||||
|
return x
|
||||||
|
|
||||||
|
def lsb(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
|
||||||
|
assert resulting_type.shape == x.shape
|
||||||
|
assert resulting_type.is_signed == x.is_signed
|
||||||
|
assert resulting_type.is_encrypted and x.is_encrypted
|
||||||
|
|
||||||
|
operation = fhe.LsbEintOp if x.is_scalar else fhelinalg.LsbEintOp
|
||||||
|
return self.operation(operation, resulting_type, x.result)
|
||||||
|
|
||||||
|
def reinterpret(self, x: Conversion, *, bit_width: int) -> Conversion:
|
||||||
|
assert x.is_encrypted
|
||||||
|
|
||||||
|
resulting_element_type = (self.eint if x.is_unsigned else self.esint)(bit_width)
|
||||||
|
resulting_type = self.tensor(resulting_element_type, shape=x.shape)
|
||||||
|
|
||||||
|
operation = (
|
||||||
|
fhe.ReinterpretPrecisionEintOp if x.is_scalar else fhelinalg.ReinterpretPrecisionEintOp
|
||||||
|
)
|
||||||
|
return self.operation(operation, resulting_type, x.result)
|
||||||
|
|
||||||
def zeros(self, resulting_type: ConversionType) -> Conversion:
|
def zeros(self, resulting_type: ConversionType) -> Conversion:
|
||||||
assert resulting_type.is_encrypted
|
assert resulting_type.is_encrypted
|
||||||
|
|
||||||
|
|||||||
@@ -618,6 +618,20 @@ class Converter:
|
|||||||
variable_input_index = variable_input_indices[0]
|
variable_input_index = variable_input_indices[0]
|
||||||
variable_input = preds[variable_input_index]
|
variable_input = preds[variable_input_index]
|
||||||
|
|
||||||
|
if variable_input.origin.properties.get("name") == "truncate_bit_pattern":
|
||||||
|
original_bit_width = variable_input.origin.properties["original_bit_width"]
|
||||||
|
lsbs_to_remove = variable_input.origin.properties["kwargs"]["lsbs_to_remove"]
|
||||||
|
truncated_bit_width = original_bit_width - lsbs_to_remove
|
||||||
|
|
||||||
|
if variable_input.bit_width > original_bit_width:
|
||||||
|
bit_width_difference = variable_input.bit_width - original_bit_width
|
||||||
|
shifter = ctx.constant(
|
||||||
|
ctx.i(variable_input.bit_width + 1), 2**bit_width_difference
|
||||||
|
)
|
||||||
|
variable_input = ctx.mul(variable_input.type, variable_input, shifter)
|
||||||
|
|
||||||
|
variable_input = ctx.reinterpret(variable_input, bit_width=truncated_bit_width)
|
||||||
|
|
||||||
if len(tables) == 1:
|
if len(tables) == 1:
|
||||||
return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist())
|
return ctx.tlu(ctx.typeof(node), on=variable_input, table=lut_values.tolist())
|
||||||
|
|
||||||
@@ -637,6 +651,10 @@ class Converter:
|
|||||||
axes=node.properties["kwargs"].get("axes", []),
|
axes=node.properties["kwargs"].get("axes", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def truncate_bit_pattern(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||||
|
assert len(preds) == 1
|
||||||
|
return ctx.truncate_bit_pattern(preds[0], node.properties["kwargs"]["lsbs_to_remove"])
|
||||||
|
|
||||||
def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
def zeros(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||||
assert len(preds) == 0
|
assert len(preds) == 0
|
||||||
return ctx.zeros(ctx.typeof(node))
|
return ctx.zeros(ctx.typeof(node))
|
||||||
|
|||||||
@@ -562,3 +562,7 @@ class AdditionalConstraints:
|
|||||||
transpose = {
|
transpose = {
|
||||||
inputs_and_output_share_precision,
|
inputs_and_output_share_precision,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
truncate_bit_pattern = {
|
||||||
|
inputs_and_output_share_precision,
|
||||||
|
}
|
||||||
|
|||||||
@@ -162,7 +162,9 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
|
|||||||
if pred.operation != Operation.Constant:
|
if pred.operation != Operation.Constant:
|
||||||
variable_input_index = index
|
variable_input_index = index
|
||||||
break
|
break
|
||||||
|
|
||||||
assert_that(variable_input_index != -1)
|
assert_that(variable_input_index != -1)
|
||||||
|
variable_input = preds[variable_input_index]
|
||||||
|
|
||||||
variable_input_dtype = node.inputs[variable_input_index].dtype
|
variable_input_dtype = node.inputs[variable_input_index].dtype
|
||||||
variable_input_shape = node.inputs[variable_input_index].shape
|
variable_input_shape = node.inputs[variable_input_index].shape
|
||||||
@@ -170,7 +172,6 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
|
|||||||
assert_that(isinstance(variable_input_dtype, Integer))
|
assert_that(isinstance(variable_input_dtype, Integer))
|
||||||
variable_input_dtype = deepcopy(cast(Integer, variable_input_dtype))
|
variable_input_dtype = deepcopy(cast(Integer, variable_input_dtype))
|
||||||
|
|
||||||
variable_input = preds[variable_input_index]
|
|
||||||
if (
|
if (
|
||||||
variable_input.operation == Operation.Generic
|
variable_input.operation == Operation.Generic
|
||||||
and variable_input.properties["name"] == "round_bit_pattern"
|
and variable_input.properties["name"] == "round_bit_pattern"
|
||||||
@@ -186,6 +187,20 @@ def construct_table(node: Node, preds: List[Node]) -> List[Any]:
|
|||||||
variable_input_dtype.bit_width += 1
|
variable_input_dtype.bit_width += 1
|
||||||
|
|
||||||
step = (2**variable_input_dtype.bit_width) // expected_number_of_elements
|
step = (2**variable_input_dtype.bit_width) // expected_number_of_elements
|
||||||
|
|
||||||
|
elif (
|
||||||
|
variable_input.operation == Operation.Generic
|
||||||
|
and variable_input.properties["name"] == "truncate_bit_pattern"
|
||||||
|
):
|
||||||
|
original_bit_width = variable_input.properties["original_bit_width"]
|
||||||
|
lsbs_to_remove = variable_input.properties["kwargs"]["lsbs_to_remove"]
|
||||||
|
|
||||||
|
resulting_bit_width = original_bit_width - lsbs_to_remove
|
||||||
|
expected_number_of_elements = 2**resulting_bit_width
|
||||||
|
|
||||||
|
variable_input_dtype.bit_width = original_bit_width
|
||||||
|
step = (2**original_bit_width) // expected_number_of_elements
|
||||||
|
|
||||||
else:
|
else:
|
||||||
step = 1
|
step = 1
|
||||||
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
importlib-resources>=6.1
|
||||||
networkx>=2.6
|
networkx>=2.6
|
||||||
numpy>=1.23
|
numpy>=1.23
|
||||||
scipy>=1.10
|
scipy>=1.10
|
||||||
|
|||||||
@@ -0,0 +1,467 @@
|
|||||||
|
"""
|
||||||
|
Tests of execution of truncate bit pattern operation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from concrete import fhe
|
||||||
|
from concrete.fhe.representation.utils import format_constant
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample,lsbs_to_remove,expected_output",
|
||||||
|
[
|
||||||
|
(0b_0000_0011, 0, 0b_0000_0011),
|
||||||
|
(0b_0000_0100, 0, 0b_0000_0100),
|
||||||
|
(0b_0000_0000, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_0001, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_0010, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_0011, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_0100, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_0101, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_0110, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_0111, 3, 0b_0000_0000),
|
||||||
|
(0b_0000_1000, 3, 0b_0000_1000),
|
||||||
|
(0b_0000_1001, 3, 0b_0000_1000),
|
||||||
|
(0b_0000_1010, 3, 0b_0000_1000),
|
||||||
|
(0b_0000_1011, 3, 0b_0000_1000),
|
||||||
|
(0b_0000_1100, 3, 0b_0000_1000),
|
||||||
|
(0b_0000_1101, 3, 0b_0000_1000),
|
||||||
|
(0b_0000_1110, 3, 0b_0000_1000),
|
||||||
|
(0b_0000_1111, 3, 0b_0000_1000),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_plain_truncate_bit_pattern(sample, lsbs_to_remove, expected_output):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern in evaluation context.
|
||||||
|
"""
|
||||||
|
assert fhe.truncate_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove) == expected_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"sample,lsbs_to_remove,expected_error,expected_message",
|
||||||
|
[
|
||||||
|
(
|
||||||
|
np.array([3.2, 4.1]),
|
||||||
|
3,
|
||||||
|
TypeError,
|
||||||
|
f"Expected input elements to be integers but they are {type(np.array([3.2, 4.1]).dtype).__name__}", # noqa: E501
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"foo",
|
||||||
|
3,
|
||||||
|
TypeError,
|
||||||
|
"Expected input to be an int or a numpy array but it's str",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_bad_plain_truncate_bit_pattern(
|
||||||
|
sample,
|
||||||
|
lsbs_to_remove,
|
||||||
|
expected_error,
|
||||||
|
expected_message,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern in evaluation context with bad parameters.
|
||||||
|
"""
|
||||||
|
|
||||||
|
with pytest.raises(expected_error) as excinfo:
|
||||||
|
fhe.truncate_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove)
|
||||||
|
|
||||||
|
assert str(excinfo.value) == expected_message
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_bits,lsbs_to_remove",
|
||||||
|
[
|
||||||
|
(3, 1),
|
||||||
|
(3, 2),
|
||||||
|
(4, 1),
|
||||||
|
(4, 2),
|
||||||
|
(4, 3),
|
||||||
|
(5, 1),
|
||||||
|
(5, 2),
|
||||||
|
(5, 3),
|
||||||
|
(5, 4),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"mapper",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
lambda x: x,
|
||||||
|
id="x",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
lambda x: x + 10,
|
||||||
|
id="x + 10",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
lambda x: x**2,
|
||||||
|
id="x ** 2",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
lambda x: fhe.univariate(lambda x: x if x >= 0 else 0)(x),
|
||||||
|
id="relu",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_truncate_bit_pattern(input_bits, lsbs_to_remove, mapper, helpers):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function(x):
|
||||||
|
x_truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
|
||||||
|
return mapper(x_truncated)
|
||||||
|
|
||||||
|
upper_bound = 2**input_bits
|
||||||
|
inputset = [0, upper_bound - 1]
|
||||||
|
|
||||||
|
circuit = function.compile(inputset, helpers.configuration())
|
||||||
|
helpers.check_execution(circuit, function, np.random.randint(0, upper_bound), retries=3)
|
||||||
|
|
||||||
|
for value in inputset:
|
||||||
|
helpers.check_execution(circuit, function, value, retries=3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_bits,lsbs_to_remove",
|
||||||
|
[
|
||||||
|
(3, 1),
|
||||||
|
(3, 2),
|
||||||
|
(4, 1),
|
||||||
|
(4, 2),
|
||||||
|
(4, 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_truncate_bit_pattern_unsigned_range(input_bits, lsbs_to_remove, helpers):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern in unsigned range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function(x):
|
||||||
|
return fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
|
||||||
|
|
||||||
|
inputset = range(0, 2**input_bits)
|
||||||
|
circuit = function.compile(inputset, helpers.configuration())
|
||||||
|
|
||||||
|
for value in inputset:
|
||||||
|
helpers.check_execution(circuit, function, value, retries=3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_bits,lsbs_to_remove",
|
||||||
|
[
|
||||||
|
(3, 1),
|
||||||
|
(3, 2),
|
||||||
|
(4, 1),
|
||||||
|
(4, 2),
|
||||||
|
(4, 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_truncate_bit_pattern_signed_range(input_bits, lsbs_to_remove, helpers):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern in signed range.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function(x):
|
||||||
|
return fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
|
||||||
|
|
||||||
|
inputset = range(-(2 ** (input_bits - 1)), 2 ** (input_bits - 1))
|
||||||
|
circuit = function.compile(inputset, helpers.configuration())
|
||||||
|
|
||||||
|
for value in inputset:
|
||||||
|
helpers.check_execution(circuit, function, value, retries=3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_bits,lsbs_to_remove",
|
||||||
|
[
|
||||||
|
(3, 1),
|
||||||
|
(3, 2),
|
||||||
|
(4, 1),
|
||||||
|
(4, 2),
|
||||||
|
(4, 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_truncate_bit_pattern_unsigned_range_assigned(input_bits, lsbs_to_remove, helpers):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern in unsigned range with a big bit-width assigned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function(x):
|
||||||
|
truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
|
||||||
|
return (truncated**2) + (63 - x)
|
||||||
|
|
||||||
|
inputset = range(0, 2**input_bits)
|
||||||
|
circuit = function.compile(inputset, helpers.configuration())
|
||||||
|
|
||||||
|
for value in inputset:
|
||||||
|
helpers.check_execution(circuit, function, value, retries=3)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"input_bits,lsbs_to_remove",
|
||||||
|
[
|
||||||
|
(3, 1),
|
||||||
|
(3, 2),
|
||||||
|
(4, 1),
|
||||||
|
(4, 2),
|
||||||
|
(4, 3),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_truncate_bit_pattern_signed_range_assigned(input_bits, lsbs_to_remove, helpers):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern in signed range with a big bit-width assigned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function(x):
|
||||||
|
truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
|
||||||
|
return (truncated**2) + (63 - x)
|
||||||
|
|
||||||
|
inputset = range(-(2 ** (input_bits - 1)), 2 ** (input_bits - 1))
|
||||||
|
circuit = function.compile(inputset, helpers.configuration())
|
||||||
|
|
||||||
|
for value in inputset:
|
||||||
|
helpers.check_execution(circuit, function, value, retries=3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_truncate_bit_pattern_identity(helpers, pytestconfig):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern used multiple times outside TLUs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function(x):
|
||||||
|
truncated = fhe.truncate_bit_pattern(x, lsbs_to_remove=2)
|
||||||
|
return truncated + truncated
|
||||||
|
|
||||||
|
inputset = range(-20, 20)
|
||||||
|
circuit = function.compile(inputset, helpers.configuration())
|
||||||
|
|
||||||
|
expected_mlir = (
|
||||||
|
"""
|
||||||
|
|
||||||
|
module {
|
||||||
|
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||||
|
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
|
||||||
|
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
|
||||||
|
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>
|
||||||
|
%3 = "FHE.lsb"(%2) : (!FHE.esint<6>) -> !FHE.esint<6>
|
||||||
|
%4 = "FHE.sub_eint"(%2, %3) : (!FHE.esint<6>, !FHE.esint<6>) -> !FHE.esint<6>
|
||||||
|
%5 = "FHE.reinterpret_precision"(%4) : (!FHE.esint<6>) -> !FHE.esint<7>
|
||||||
|
%6 = "FHE.add_eint"(%5, %5) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
|
||||||
|
return %6 : !FHE.esint<7>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
if pytestconfig.getoption("precision") == "multi"
|
||||||
|
else """
|
||||||
|
|
||||||
|
module {
|
||||||
|
func.func @main(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
|
||||||
|
%0 = "FHE.lsb"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<7>
|
||||||
|
%1 = "FHE.sub_eint"(%arg0, %0) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
|
||||||
|
%2 = "FHE.reinterpret_precision"(%1) : (!FHE.esint<7>) -> !FHE.esint<6>
|
||||||
|
%3 = "FHE.lsb"(%2) : (!FHE.esint<6>) -> !FHE.esint<6>
|
||||||
|
%4 = "FHE.sub_eint"(%2, %3) : (!FHE.esint<6>, !FHE.esint<6>) -> !FHE.esint<6>
|
||||||
|
%5 = "FHE.reinterpret_precision"(%4) : (!FHE.esint<6>) -> !FHE.esint<7>
|
||||||
|
%6 = "FHE.add_eint"(%5, %5) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
|
||||||
|
return %6 : !FHE.esint<7>
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
""" # noqa: E501
|
||||||
|
)
|
||||||
|
|
||||||
|
helpers.check_str(expected_mlir, circuit.mlir)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_truncating(helpers):
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern with auto truncating.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# with auto adjust truncators configuration
|
||||||
|
# ---------------------------------------
|
||||||
|
|
||||||
|
# y has the max value of 1999, so it's 11 bits
|
||||||
|
# our target msb is 5 bits, which means we need to remove 6 of the least significant bits
|
||||||
|
|
||||||
|
truncator1 = fhe.AutoTruncator(target_msbs=5)
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function1(x):
|
||||||
|
y = x + 1000
|
||||||
|
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator1)
|
||||||
|
return np.sqrt(z).astype(np.int64)
|
||||||
|
|
||||||
|
inputset1 = range(1000)
|
||||||
|
function1.trace(inputset1, helpers.configuration(), auto_adjust_truncators=True)
|
||||||
|
|
||||||
|
assert truncator1.lsbs_to_remove == 6
|
||||||
|
|
||||||
|
# manual
|
||||||
|
# ------
|
||||||
|
|
||||||
|
# y has the max value of 1999, so it's 11 bits
|
||||||
|
# our target msb is 3 bits, which means we need to remove 8 of the least significant bits
|
||||||
|
|
||||||
|
truncator2 = fhe.AutoTruncator(target_msbs=3)
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function2(x):
|
||||||
|
y = x + 1000
|
||||||
|
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator2)
|
||||||
|
return np.sqrt(z).astype(np.int64)
|
||||||
|
|
||||||
|
inputset2 = range(1000)
|
||||||
|
fhe.AutoTruncator.adjust(function2, inputset2)
|
||||||
|
|
||||||
|
assert truncator2.lsbs_to_remove == 8
|
||||||
|
|
||||||
|
# complicated case
|
||||||
|
# ----------------
|
||||||
|
|
||||||
|
# have 2 ** 8 entries during evaluation, it won't matter after compilation
|
||||||
|
entries3 = list(range(2**8))
|
||||||
|
# we have 8-bit inputs for this table, and we only want to use first 5-bits
|
||||||
|
for i in range(0, 2**8, 2**3):
|
||||||
|
# so we set every 8th entry to a 4-bit value
|
||||||
|
entries3[i] = np.random.randint(0, (2**4) - (2**2))
|
||||||
|
# when this tlu is applied to an 8-bit value with 5-bit msb truncating, result will be 4-bits
|
||||||
|
table3 = fhe.LookupTable(entries3)
|
||||||
|
# and this is the truncator for table1, which should have lsbs_to_remove of 3
|
||||||
|
truncator3 = fhe.AutoTruncator(target_msbs=5)
|
||||||
|
|
||||||
|
# have 2 ** 8 entries during evaluation, it won't matter after compilation
|
||||||
|
entries4 = list(range(2**8))
|
||||||
|
# we have 4-bit inputs for this table, and we only want to use first 2-bits
|
||||||
|
for i in range(0, 2**4, 2**2):
|
||||||
|
# so we set every 4th entry to an 8-bit value
|
||||||
|
entries4[i] = np.random.randint(2**7, 2**8)
|
||||||
|
# when this tlu is applied to a 4-bit value with 2-bit msb truncating, result will be 8-bits
|
||||||
|
table4 = fhe.LookupTable(entries4)
|
||||||
|
# and this is the truncator for table2, which should have lsbs_to_remove of 2
|
||||||
|
truncator4 = fhe.AutoTruncator(target_msbs=2)
|
||||||
|
|
||||||
|
@fhe.compiler({"x": "encrypted"})
|
||||||
|
def function3(x):
|
||||||
|
a = fhe.truncate_bit_pattern(x, lsbs_to_remove=truncator3)
|
||||||
|
b = table3[a]
|
||||||
|
c = fhe.truncate_bit_pattern(b, lsbs_to_remove=truncator4)
|
||||||
|
d = table4[c]
|
||||||
|
return d
|
||||||
|
|
||||||
|
inputset3 = range((2**8) - (2**3))
|
||||||
|
circuit3 = function3.compile(
|
||||||
|
inputset3,
|
||||||
|
helpers.configuration(),
|
||||||
|
auto_adjust_truncators=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert truncator3.lsbs_to_remove == 3
|
||||||
|
assert truncator4.lsbs_to_remove == 2
|
||||||
|
|
||||||
|
table3_formatted_string = format_constant(table3.table, 25)
|
||||||
|
table4_formatted_string = format_constant(table4.table, 25)
|
||||||
|
|
||||||
|
helpers.check_str(
|
||||||
|
f"""
|
||||||
|
|
||||||
|
%0 = x # EncryptedScalar<uint8>
|
||||||
|
%1 = truncate_bit_pattern(%0, lsbs_to_remove=3) # EncryptedScalar<uint8>
|
||||||
|
%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar<uint4>
|
||||||
|
%3 = truncate_bit_pattern(%2, lsbs_to_remove=2) # EncryptedScalar<uint4>
|
||||||
|
%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar<uint8>
|
||||||
|
return %4
|
||||||
|
|
||||||
|
""",
|
||||||
|
str(circuit3.graph.format(show_bounds=False)),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_truncating_without_adjustment():
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern with auto truncating but without adjustment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
truncator = fhe.AutoTruncator(target_msbs=5)
|
||||||
|
|
||||||
|
def function(x):
|
||||||
|
y = x + 1000
|
||||||
|
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator)
|
||||||
|
return np.sqrt(z).astype(np.int64)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
function(100)
|
||||||
|
|
||||||
|
assert str(excinfo.value) == (
|
||||||
|
"AutoTruncators cannot be used before adjustment, "
|
||||||
|
"please call AutoTruncator.adjust with the function that will be compiled "
|
||||||
|
"and provide the exact inputset that will be used for compilation"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_truncating_with_empty_inputset():
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern with auto truncating but with empty inputset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
truncator = fhe.AutoTruncator(target_msbs=5)
|
||||||
|
|
||||||
|
def function(x):
|
||||||
|
y = x + 1000
|
||||||
|
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator)
|
||||||
|
return np.sqrt(z).astype(np.int64)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError) as excinfo:
|
||||||
|
fhe.AutoTruncator.adjust(function, [])
|
||||||
|
|
||||||
|
assert str(excinfo.value) == "AutoTruncators cannot be adjusted with an empty inputset"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_truncating_recursive_adjustment():
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern with auto truncating but with recursive adjustment.
|
||||||
|
"""
|
||||||
|
|
||||||
|
truncator = fhe.AutoTruncator(target_msbs=5)
|
||||||
|
|
||||||
|
def function(x):
|
||||||
|
fhe.AutoTruncator.adjust(function, range(10))
|
||||||
|
y = x + 1000
|
||||||
|
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=truncator)
|
||||||
|
return np.sqrt(z).astype(np.int64)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
fhe.AutoTruncator.adjust(function, range(10))
|
||||||
|
|
||||||
|
assert str(excinfo.value) == "AutoTruncators cannot be adjusted recursively"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auto_truncating_construct_in_function():
|
||||||
|
"""
|
||||||
|
Test truncate bit pattern with auto truncating but truncator is constructed within the function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def function(x):
|
||||||
|
y = x + 1000
|
||||||
|
z = fhe.truncate_bit_pattern(y, lsbs_to_remove=fhe.AutoTruncator(target_msbs=5))
|
||||||
|
return np.sqrt(z).astype(np.int64)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError) as excinfo:
|
||||||
|
fhe.AutoTruncator.adjust(function, range(10))
|
||||||
|
|
||||||
|
assert str(excinfo.value) == (
|
||||||
|
"AutoTruncators cannot be constructed during adjustment, "
|
||||||
|
"please construct AutoTruncators outside the function and reference it"
|
||||||
|
)
|
||||||
@@ -0,0 +1,37 @@
|
|||||||
|
"""
|
||||||
|
Tests of 'truncate_bit_pattern' extension.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from concrete import fhe
|
||||||
|
|
||||||
|
|
||||||
|
def test_dump_load_auto_truncator():
|
||||||
|
"""
|
||||||
|
Test 'dump_dict' and 'load_dict' methods of AutoTruncator.
|
||||||
|
"""
|
||||||
|
|
||||||
|
truncator = fhe.AutoTruncator(target_msbs=3)
|
||||||
|
truncator.is_adjusted = True
|
||||||
|
truncator.input_min = 10
|
||||||
|
truncator.input_max = 20
|
||||||
|
truncator.input_bit_width = 5
|
||||||
|
truncator.lsbs_to_remove = 2
|
||||||
|
|
||||||
|
dumped = truncator.dump_dict()
|
||||||
|
assert dumped == {
|
||||||
|
"target_msbs": 3,
|
||||||
|
"is_adjusted": True,
|
||||||
|
"input_min": 10,
|
||||||
|
"input_max": 20,
|
||||||
|
"input_bit_width": 5,
|
||||||
|
"lsbs_to_remove": 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
loaded = fhe.AutoTruncator.load_dict(dumped)
|
||||||
|
|
||||||
|
assert loaded.target_msbs == 3
|
||||||
|
assert loaded.is_adjusted
|
||||||
|
assert loaded.input_min == 10
|
||||||
|
assert loaded.input_max == 20
|
||||||
|
assert loaded.input_bit_width == 5
|
||||||
|
assert loaded.lsbs_to_remove == 2
|
||||||
@@ -988,6 +988,23 @@ Function you are trying to compile cannot be compiled
|
|||||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit maximum operation is supported
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but only up to 16-bit maximum operation is supported
|
||||||
return %2
|
return %2
|
||||||
|
|
||||||
|
""", # noqa: E501
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
lambda x: fhe.truncate_bit_pattern(x, lsbs_to_remove=2),
|
||||||
|
{"x": "clear"},
|
||||||
|
[10, 20, 30],
|
||||||
|
RuntimeError,
|
||||||
|
"""
|
||||||
|
|
||||||
|
Function you are trying to compile cannot be compiled
|
||||||
|
|
||||||
|
%0 = x # ClearScalar<uint5> ∈ [10, 30]
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ operand is clear
|
||||||
|
%1 = truncate_bit_pattern(%0, lsbs_to_remove=2) # ClearScalar<uint5> ∈ [8, 28]
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ but clear truncate bit pattern is not supported
|
||||||
|
return %1
|
||||||
|
|
||||||
""", # noqa: E501
|
""", # noqa: E501
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
|
|||||||
Reference in New Issue
Block a user