feat(frontend-python): add relu extension

Umut
2024-01-02 17:04:34 +03:00
committed by rudy-6-4
parent cc14e7a4f9
commit 8ef84bed42
10 changed files with 375 additions and 4 deletions

Binary image file added (38 KiB, not shown).
View File

@@ -112,3 +112,7 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
* Enable promotions in encrypted shifts instead of casting in runtime. See [Bitwise#Shifts](../tutorial/bitwise.md#Shifts) to learn more.
* **composable**: bool = False,
* Specify that the function must be composable with itself.
* **relu_on_bits_threshold**: int = 7,
* Bit-width to start implementing the ReLU extension with [fhe.bits](../tutorial/bit_extraction.md).
* **relu_on_bits_chunk_size**: int = 3,
* Chunk size of the ReLU extension when [fhe.bits](../tutorial/bit_extraction.md) implementation is used.
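For illustration (not part of this diff), here is a minimal sketch of setting the two options above as `compile` kwargs, which the surrounding documentation notes take precedence over the configuration:

```python
import numpy as np
from concrete import fhe

@fhe.compiler({"x": "encrypted"})
def f(x):
    return fhe.relu(x)

inputset = [np.random.randint(-2**6, 2**6) for _ in range(100)]

# kwargs passed to `compile` take precedence over the configuration,
# so the ReLU options can be tweaked directly here.
circuit = f.compile(
    inputset,
    relu_on_bits_threshold=7,
    relu_on_bits_chunk_size=3,
)
```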

View File

@@ -350,3 +350,134 @@ def is_vectors_same(x, y):
return is_same
```
## fhe.relu(value)
Allows you to perform the ReLU operation, with the same semantics as `x if x >= 0 else 0`:
```python
import numpy as np
from concrete import fhe

@fhe.compiler({"x": "encrypted"})
def f(x):
    return fhe.relu(x)

inputset = [np.random.randint(-10, 10) for _ in range(10)]
circuit = f.compile(inputset)

assert circuit.encrypt_run_decrypt(0) == 0
assert circuit.encrypt_run_decrypt(1) == 1
assert circuit.encrypt_run_decrypt(-1) == 0
assert circuit.encrypt_run_decrypt(-3) == 0
assert circuit.encrypt_run_decrypt(5) == 5
```
The ReLU extension can be converted in two different ways:
- With a single TLU on the original bit-width.
- With multiple TLUs on smaller bit-widths.
For small bit-widths, the first one is better as it'll have a single TLU on a small bit-width.
For big bit-widths, the second one is better as it won't have a TLU on a big bit-width.
The decision between the two can be controlled with the `relu_on_bits_threshold: int = 7` configuration option:
- `relu_on_bits_threshold=5` means:
  - 1-bit to 4-bits would be converted using the first way (i.e., using TLU)
  - 5-bits and more would be converted using the second way (i.e., using bits)
There is another option to customize the implementation, `relu_on_bits_chunk_size: int = 3`:
- `relu_on_bits_chunk_size=4` means:
  - When using the second implementation:
    - The input would be split into 4-bit chunks using [fhe.bits](../tutorial/bit_extraction.md), and then the ReLU would be applied to those chunks, which are then combined back, as shown in the sketch below.
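To make the chunked approach more concrete, here is a minimal plain-Python model of it (not part of this diff; no FHE involved, and the helper name `relu_on_bits_model` and the exact packing are illustrative assumptions based on the MLIR lowering added in this commit):

```python
def relu_on_bits_model(x: int, bit_width: int, chunk_size: int = 3) -> int:
    # Plain-integer model of the chunked ReLU conversion, not library code.
    # Two's-complement encoding of the signed input on `bit_width` bits.
    encoded = x & ((1 << bit_width) - 1)

    # The sign is the most significant bit of the encoding.
    sign = (encoded >> (bit_width - 1)) & 1

    result = 0
    for chunk_start in range(0, bit_width - 1, chunk_size):
        chunk_end = min(chunk_start + chunk_size, bit_width - 1)
        chunk = (encoded >> chunk_start) & ((1 << (chunk_end - chunk_start)) - 1)

        # Pack the chunk and the sign into a (chunk_size + 1)-bit value and
        # look it up: keep the chunk (shifted back in place) only if the sign is 0.
        packed = chunk + (sign << chunk_size)
        table = [
            (value << chunk_start) if value >> chunk_size == 0 else 0
            for value in range(2 ** (chunk_size + 1))
        ]
        result += table[packed]

    return result


assert all(relu_on_bits_model(x, bit_width=6) == max(x, 0) for x in range(-32, 32))
```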
Here is a script showing how execution cost is impacted when changing these values:
```python
from concrete import fhe
import numpy as np
import matplotlib.pyplot as plt

chunk_sizes = np.array(range(1, 6), dtype=int)
bit_widths = np.array(range(5, 17), dtype=int)

data = []
for bit_width in bit_widths:
    title = f"{bit_width=}:"
    print(title)
    print("-" * len(title))

    inputset = range(-2**(bit_width-1), 2**(bit_width-1))

    configuration = fhe.Configuration(relu_on_bits_threshold=17)
    compiler = fhe.Compiler(lambda x: fhe.relu((fhe.relu(x) - (2**(bit_width-2))) * 2), {"x": "encrypted"})
    circuit = compiler.compile(inputset, configuration)
    print(f"    Complexity: {circuit.complexity} # tlu")
    data.append((bit_width, 0, circuit.complexity))

    for chunk_size in chunk_sizes:
        configuration = fhe.Configuration(
            relu_on_bits_threshold=1,
            relu_on_bits_chunk_size=int(chunk_size),
        )
        circuit = compiler.compile(inputset, configuration)
        print(f"    Complexity: {circuit.complexity} # {chunk_size=}")
        data.append((bit_width, chunk_size, circuit.complexity))

    print()

data = np.array(data)

plt.title("ReLU using TLU vs using bits")
plt.xlabel("Input/Output precision")
plt.ylabel("Cost")

for i, chunk_size in enumerate([0] + list(chunk_sizes)):
    costs = [
        cost
        for _, candidate_chunk_size, cost in data
        if candidate_chunk_size == chunk_size
    ]
    assert len(costs) == len(bit_widths)

    label = "Single TLU" if i == 0 else f"Bits extract + multiple {chunk_size + 1}-bit TLUs"
    width_bar = 0.8 / (len(chunk_sizes) + 1)
    if i == 0:
        plt.hlines(
            costs,
            bit_widths - 0.45,
            bit_widths + 0.45,
            label=label,
            linestyle="--",
        )
    else:
        plt.bar(
            np.array(bit_widths) + width_bar * (i - (len(chunk_sizes) + 1) / 2),
            height=costs,
            width=width_bar,
            label=label,
        )

plt.xticks(bit_widths)
plt.legend(loc="upper left")
plt.show()
```
{% hint style="info" %}
You might need to run the script twice to avoid crashing when plotting.
{% endhint %}
The script will show the following figure:
![](../_static/tutorials/relu/configuration_and_cost.png)
{% hint style="info" %}
The default values of these options are set based on simple circuits. How they affect performance will depend on the circuit, so play around with them to get the most out of this extension.
{% endhint %}
{% hint style="warning" %}
Conversion with the second method (i.e., using chunks) only works in the `Native` encoding, which is usually selected when all table lookups in the circuit are at most 8 bits.
{% endhint %}

View File

@@ -40,6 +40,7 @@ from .extensions import (
one,
ones,
ones_like,
relu,
round_bit_pattern,
tag,
truncate_bit_pattern,

View File

@@ -928,6 +928,8 @@ class Configuration:
min_max_strategy_preference: List[MinMaxStrategy]
composable: bool
use_gpu: bool
relu_on_bits_threshold: int
relu_on_bits_chunk_size: int
def __init__(
self,
@@ -981,6 +983,8 @@ class Configuration:
] = None,
composable: bool = False,
use_gpu: bool = False,
relu_on_bits_threshold: int = 7,
relu_on_bits_chunk_size: int = 3,
):
self.verbose = verbose
self.compiler_debug_mode = compiler_debug_mode
@@ -1060,6 +1064,8 @@ class Configuration:
)
self.composable = composable
self.use_gpu = use_gpu
self.relu_on_bits_threshold = relu_on_bits_threshold
self.relu_on_bits_chunk_size = relu_on_bits_chunk_size
self._validate()
@@ -1117,6 +1123,8 @@ class Configuration:
] = KEEP,
composable: Union[Keep, bool] = KEEP,
use_gpu: Union[Keep, bool] = KEEP,
relu_on_bits_threshold: Union[Keep, int] = KEEP,
relu_on_bits_chunk_size: Union[Keep, int] = KEEP,
) -> "Configuration":
"""
Get a new configuration from another one with the specified changes.
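For illustration (not part of this diff), the new options can also be set through `fork`, which is how the tests below configure them; a minimal sketch assuming the defaults introduced in this commit:

```python
from concrete import fhe

base = fhe.Configuration()

# Fork the configuration with the new ReLU options (names taken from the diff above).
custom = base.fork(
    relu_on_bits_threshold=5,
    relu_on_bits_chunk_size=2,
)

assert custom.relu_on_bits_threshold == 5
assert base.relu_on_bits_threshold == 7  # the original configuration is unchanged
```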

View File

@@ -9,6 +9,7 @@ from .hint import hint
from .maxpool import maxpool
from .multivariate import multivariate
from .ones import one, ones, ones_like
from .relu import relu
from .round_bit_pattern import AutoRounder, round_bit_pattern
from .table import LookupTable
from .tag import tag

View File

@@ -0,0 +1,48 @@
"""
Declaration of `relu` extension.
"""
from copy import deepcopy
from typing import Any, Union
import numpy as np
from ..dtypes import Integer
from ..representation import Node
from ..tracing import Tracer
def relu(x: Union[Tracer, Any]) -> Union[Tracer, Any]:
"""
Rectified linear unit extension.
Computes:
x if x >= 0 else 0
Args:
x (Union[Tracer, Any]):
input to apply ReLU
Returns:
Union[Tracer, Any]:
Tracer that represent the operation during tracing
result of ReLU on `x` otherwise
"""
def evaluator(x):
return np.where(x >= 0, x, 0)
if not isinstance(x, Tracer):
return evaluator(x)
resulting_value = deepcopy(x.output)
if isinstance(resulting_value.dtype, Integer) and resulting_value.dtype.is_signed:
resulting_value.dtype.is_signed = False
computation = Node.generic(
"relu",
[deepcopy(x.output)],
resulting_value,
evaluator,
)
return Tracer(computation, [x])
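As a quick illustration of the two code paths above (not part of this diff): when the input is not a `Tracer`, `fhe.relu` simply evaluates with NumPy; during tracing it records a `relu` node instead.

```python
import numpy as np
from concrete import fhe

# Plain (non-traced) inputs go straight through the NumPy evaluator.
assert fhe.relu(-3) == 0
assert np.array_equal(fhe.relu(np.array([-2, 0, 5])), np.array([0, 0, 5]))
```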

View File

@@ -24,7 +24,12 @@ from mlir.ir import OpResult as MlirOperation
from mlir.ir import RankedTensorType
from mlir.ir import Type as MlirType
from ..compilation.configuration import BitwiseStrategy, ComparisonStrategy, MinMaxStrategy
from ..compilation.configuration import (
BitwiseStrategy,
ComparisonStrategy,
Configuration,
MinMaxStrategy,
)
from ..dtypes import Integer
from ..extensions.bits import MAX_EXTRACTABLE_BIT, MIN_EXTRACTABLE_BIT
from ..representation import Graph, Node
@@ -51,7 +56,9 @@ class Context:
conversion_cache: Dict[Tuple, Conversion]
constant_cache: Dict[MlirAttribute, MlirOperation]
def __init__(self, context: MlirContext, graph: Graph):
configuration: Configuration
def __init__(self, context: MlirContext, graph: Graph, configuration: Configuration):
self.context = context
self.graph = graph
@@ -61,6 +68,8 @@ class Context:
self.conversion_cache = {}
self.constant_cache = {}
self.configuration = configuration
# types
def i(self, width: int) -> ConversionType:
@@ -2787,6 +2796,62 @@ class Context:
return self.add(resulting_type, one, self.zeros(resulting_type))
def relu(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
    if x.bit_width < self.configuration.relu_on_bits_threshold:
        if x.bit_width > x.original_bit_width:
            shifter = self.constant(
                self.i(x.bit_width + 1),
                2 ** (x.bit_width - x.original_bit_width),
            )
            x = self.reinterpret(
                self.mul(x.type, x, shifter),
                bit_width=x.original_bit_width,
            )

        x_dtype = Integer(is_signed=x.is_signed, bit_width=x.bit_width)
        table = [x for x in range(x_dtype.max() + 1)] + [0 for x in range(abs(x_dtype.min()))]
        return self.tlu(resulting_type, x, table)

    if x.is_unsigned:
        if resulting_type.bit_width == x.bit_width:
            return x

        assert resulting_type.bit_width > x.bit_width
        return self.extract_bits(resulting_type, x, bits=slice(0, x.original_bit_width))

    if x.original_bit_width == 1:
        return self.zeros(resulting_type)

    chunk_size = self.configuration.relu_on_bits_chunk_size
    intermediate_type = self.tensor(self.eint(chunk_size + 1), shape=x.shape)

    sign = self.reinterpret(
        self.extract_bits(
            self.tensor(self.eint(1), shape=x.shape),
            x,
            bits=(x.original_bit_width - 1),
        ),
        bit_width=intermediate_type.bit_width,
    )

    filtered_chunks = []
    for chunk_start in range(0, x.original_bit_width - 1, chunk_size):
        chunk_end = min(chunk_start + chunk_size, x.original_bit_width - 1)
        chunk = self.extract_bits(intermediate_type, x, bits=slice(chunk_start, chunk_end))

        packed_chunk_and_sign = self.add(intermediate_type, chunk, sign)
        filtered_chunk = self.tlu(
            resulting_type,
            packed_chunk_and_sign,
            [
                (x << chunk_start) if x >> chunk_size == 0 else 0
                for x in range(2**intermediate_type.bit_width)
            ],
        )

        filtered_chunks.append(filtered_chunk)

    return self.tree_add(resulting_type, filtered_chunks)

def reshape(self, x: Conversion, shape: Tuple[int, ...]) -> Conversion:
    if x.is_scalar:
        x = self.tensorize(x)
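As a small plain-Python sanity check (not part of this diff) of the single-TLU table built in the first branch of `relu` above, assuming signed TLU inputs address the table in two's-complement order (equivalently, Python-style negative indexing, so negative values map to the second half of the table):

```python
bit_width = 4

# Same construction as `table` above: identity for the non-negative half,
# zeros for the negative half.
table = list(range(2 ** (bit_width - 1))) + [0] * 2 ** (bit_width - 1)

for x in range(-(2 ** (bit_width - 1)), 2 ** (bit_width - 1)):
    index = x & (2 ** bit_width - 1)  # two's-complement index of a signed value
    assert table[index] == max(x, 0)
```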

View File

@@ -34,7 +34,10 @@ class Converter:
"""
def convert(
self, graph: Graph, configuration: Configuration, mlir_context: MlirContext
self,
graph: Graph,
configuration: Configuration,
mlir_context: MlirContext,
) -> MlirModule:
"""
Convert a computation graph to MLIR.
@@ -61,7 +64,7 @@ class Converter:
module = MlirModule.create()
with MlirInsertionPoint(module.body):
ctx = Context(context, graph)
ctx = Context(context, graph, configuration)
input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()]
@@ -451,6 +454,10 @@ class Converter:
    assert len(preds) == 0
    return ctx.ones(ctx.typeof(node))

def relu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
    assert len(preds) == 1
    return ctx.relu(ctx.typeof(node), preds[0])

def reshape(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
    assert len(preds) == 1
    return ctx.reshape(preds[0], shape=node.output.shape)

View File

@@ -0,0 +1,106 @@
"""
Tests of execution of round bit pattern operation.
"""
import random
import numpy as np
import pytest
from concrete import fhe
from concrete.fhe.dtypes import Integer
# pylint: disable=redefined-outer-name
@pytest.mark.parametrize(
"sample,expected_output",
[
(0, 0),
(1, 1),
(-1, 0),
(10, 10),
(-10, 0),
],
)
def test_plain_relu(sample, expected_output):
"""
Test plain evaluation of relu.
"""
assert fhe.relu(sample) == expected_output
operations = [
lambda x: fhe.relu(x),
lambda x: fhe.relu(x) + 100,
]
cases = [
# fhe.relu(int1), should result in fhe.zero()
[
operation,
1,
True,
(),
0,
2,
]
for operation in operations
] + [
# fhe.relu should use an optimized TLU when it's assigned bit-width is bigger than the original
[
lambda x: fhe.relu(x) + (x + 10),
3,
True,
(),
10,
2,
]
]
with_tlu = set()
for function in operations:
for bit_width in [1, 2, 3, 4, 5, 8, 12, 16]:
for is_signed in [False, True]:
for shape in [(), (3,), (2, 3)]:
for threshold in [5, 7]:
for chunk_size in [2, 3]:
if bit_width < threshold:
key = (bit_width, is_signed)
if key in with_tlu:
continue
with_tlu.add(key)
cases += [
[
function,
bit_width,
is_signed,
shape,
threshold,
chunk_size,
]
]
@pytest.mark.parametrize(
"function,bit_width,is_signed,shape,threshold,chunk_size",
cases,
)
def test_relu(function, bit_width, is_signed, shape, threshold, chunk_size, helpers):
"""
Test encrypted evaluation of relu.
"""
dtype = Integer(is_signed, bit_width)
inputset = [np.random.randint(dtype.min(), dtype.max() + 1, size=shape) for _ in range(100)]
configuration = helpers.configuration().fork(
relu_on_bits_threshold=threshold,
relu_on_bits_chunk_size=chunk_size,
)
compiler = fhe.Compiler(function, {"x": "encrypted"})
circuit = compiler.compile(inputset, configuration)
for value in random.sample(inputset, 8):
helpers.check_execution(circuit, function, value, retries=3)