mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-08 20:38:06 -05:00
feat(frontend-python): add relu extension
This commit is contained in:
BIN
docs/_static/tutorials/relu/configuration_and_cost.png
vendored
Normal file
BIN
docs/_static/tutorials/relu/configuration_and_cost.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 38 KiB |
@@ -112,3 +112,7 @@ Additional kwargs to `compile` functions take higher precedence. So if you set t
|
||||
* Enable promotions in encrypted shifts instead of casting in runtime. See [Bitwise#Shifts](../tutorial/bitwise.md#Shifts) to learn more.
|
||||
* **composable**: bool = False,
|
||||
* Specify that the function must be composable with itself.
|
||||
* **relu_on_bits_threshold**: int = 7,
|
||||
* Bit-width to start implementing the ReLU extension with [fhe.bits](../tutorial/bit_extraction.md).
|
||||
* **relu_on_bits_chunk_size**: int = 3,
|
||||
* Chunk size of the ReLU extension when [fhe.bits](../tutorial/bit_extraction.md) implementation is used.
|
||||
|
||||
@@ -350,3 +350,134 @@ def is_vectors_same(x, y):
|
||||
|
||||
return is_same
|
||||
```
|
||||
|
||||
|
||||
## fhe.relu(value)
|
||||
|
||||
Allows you to perform ReLU operation, with the same semantic as `x if x >= 0 else 0`:
|
||||
|
||||
```python
|
||||
import numpy as np
|
||||
from concrete import fhe
|
||||
|
||||
@fhe.compiler({"x": "encrypted"})
|
||||
def f(x):
|
||||
return fhe.relu(x)
|
||||
|
||||
inputset = [np.random.randint(-10, 10) for _ in range(10)]
|
||||
circuit = f.compile(inputset)
|
||||
|
||||
assert circuit.encrypt_run_decrypt(0) == 0
|
||||
assert circuit.encrypt_run_decrypt(1) == 1
|
||||
assert circuit.encrypt_run_decrypt(-1) == 0
|
||||
assert circuit.encrypt_run_decrypt(-3) == 0
|
||||
assert circuit.encrypt_run_decrypt(5) == 5
|
||||
```
|
||||
|
||||
ReLU extension can be converted in two different ways:
|
||||
- With a single TLU on the original bit-width.
|
||||
- With multiple TLUs on smaller bit-widths.
|
||||
|
||||
For small bit-widths, the first one is better as it'll have a single TLU on a small bit-width.
|
||||
For big bit-widths, the second one is better as it won't have a TLU on a big bit-width.
|
||||
|
||||
The decision between the two can be controlled with `relu_on_bits_threshold: int = 7` configuration option:
|
||||
- `relu_on_bits_threshold=5` means:
|
||||
- 1-bit to 4-bits would be converted using the first way (i.e., using TLU)
|
||||
- 5-bits and more would be converted using the second way (i.e., using bits)
|
||||
|
||||
There is another option to customize the implementation `relu_on_bits_chunk_size: int = 2`:
|
||||
- `relu_on_bits_chunk_size=4` means:
|
||||
- When using the second implementation:
|
||||
- The input would be split to 4-bit chunks using [fhe.bits](../tutorial/bit_extraction.md), and then the ReLU would be applied to those chunks, which are then combined back.
|
||||
|
||||
Here is a script showing how execution cost is impacted when changing these values:
|
||||
```python
|
||||
from concrete import fhe
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
chunk_sizes = np.array(range(1, 6), dtype=int)
|
||||
bit_widths = np.array(range(5, 17), dtype=int)
|
||||
|
||||
data = []
|
||||
for bit_width in bit_widths:
|
||||
title = f"{bit_width=}:"
|
||||
print(title)
|
||||
print("-" * len(title))
|
||||
|
||||
inputset = range(-2**(bit_width-1), 2**(bit_width-1))
|
||||
configuration = fhe.Configuration(relu_on_bits_threshold=17)
|
||||
|
||||
compiler = fhe.Compiler(lambda x: fhe.relu((fhe.relu(x) - (2**(bit_width-2))) * 2), {"x": "encrypted"})
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
print(f" Complexity: {circuit.complexity} # tlu")
|
||||
data.append((bit_width, 0, circuit.complexity))
|
||||
|
||||
for chunk_size in chunk_sizes:
|
||||
configuration = fhe.Configuration(
|
||||
relu_on_bits_threshold=1,
|
||||
relu_on_bits_chunk_size=int(chunk_size),
|
||||
)
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
print(f" Complexity: {circuit.complexity} # {chunk_size=}")
|
||||
data.append((bit_width, chunk_size, circuit.complexity))
|
||||
|
||||
print()
|
||||
|
||||
data = np.array(data)
|
||||
|
||||
plt.title(f"ReLU using TLU vs using bits")
|
||||
plt.xlabel("Input/Output precision")
|
||||
plt.ylabel("Cost")
|
||||
|
||||
for i, chunk_size in enumerate([0] + list(chunk_sizes)):
|
||||
costs = [
|
||||
cost
|
||||
for _, candidate_chunk_size, cost in data
|
||||
if candidate_chunk_size == chunk_size
|
||||
]
|
||||
assert len(costs) == len(bit_widths)
|
||||
|
||||
label = "Single TLU" if i == 0 else f"Bits extract + multiples {chunk_size + 1} bits TLUs"
|
||||
width_bar = 0.8 / (len(chunk_sizes) + 1)
|
||||
|
||||
if i == 0:
|
||||
plt.hlines(
|
||||
costs,
|
||||
bit_widths - 0.45,
|
||||
bit_widths + 0.45,
|
||||
label=label,
|
||||
linestyle="--",
|
||||
)
|
||||
else:
|
||||
plt.bar(
|
||||
np.array(bit_widths) + width_bar * (i - (len(chunk_sizes) + 1) / 2),
|
||||
height=costs,
|
||||
width=width_bar,
|
||||
label=label,
|
||||
)
|
||||
|
||||
plt.xticks(bit_widths)
|
||||
plt.legend(loc="upper left")
|
||||
|
||||
plt.show()
|
||||
```
|
||||
|
||||
{% hint style="info" %}
|
||||
You might need to run the script twice to avoid crashing when plotting.
|
||||
{% endhint %}
|
||||
|
||||
The script will show the following figure:
|
||||
|
||||

|
||||
|
||||
{% hint style="info" %}
|
||||
The default values of these options are set based on simple circuits. How they affect performance will depend on the circuit, so play around with them to get the most out of this extension.
|
||||
{% endhint %}
|
||||
|
||||
{% hint style="warning" %}
|
||||
Conversion with the second method (i.e., using chunks) only works in `Native` encoding, which is usually selected when all table lookups in the circuit are below or equal to 8 bits.
|
||||
{% endhint %}
|
||||
|
||||
@@ -40,6 +40,7 @@ from .extensions import (
|
||||
one,
|
||||
ones,
|
||||
ones_like,
|
||||
relu,
|
||||
round_bit_pattern,
|
||||
tag,
|
||||
truncate_bit_pattern,
|
||||
|
||||
@@ -928,6 +928,8 @@ class Configuration:
|
||||
min_max_strategy_preference: List[MinMaxStrategy]
|
||||
composable: bool
|
||||
use_gpu: bool
|
||||
relu_on_bits_threshold: int
|
||||
relu_on_bits_chunk_size: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -981,6 +983,8 @@ class Configuration:
|
||||
] = None,
|
||||
composable: bool = False,
|
||||
use_gpu: bool = False,
|
||||
relu_on_bits_threshold: int = 7,
|
||||
relu_on_bits_chunk_size: int = 3,
|
||||
):
|
||||
self.verbose = verbose
|
||||
self.compiler_debug_mode = compiler_debug_mode
|
||||
@@ -1060,6 +1064,8 @@ class Configuration:
|
||||
)
|
||||
self.composable = composable
|
||||
self.use_gpu = use_gpu
|
||||
self.relu_on_bits_threshold = relu_on_bits_threshold
|
||||
self.relu_on_bits_chunk_size = relu_on_bits_chunk_size
|
||||
|
||||
self._validate()
|
||||
|
||||
@@ -1117,6 +1123,8 @@ class Configuration:
|
||||
] = KEEP,
|
||||
composable: Union[Keep, bool] = KEEP,
|
||||
use_gpu: Union[Keep, bool] = KEEP,
|
||||
relu_on_bits_threshold: Union[Keep, int] = KEEP,
|
||||
relu_on_bits_chunk_size: Union[Keep, int] = KEEP,
|
||||
) -> "Configuration":
|
||||
"""
|
||||
Get a new configuration from another one specified changes.
|
||||
|
||||
@@ -9,6 +9,7 @@ from .hint import hint
|
||||
from .maxpool import maxpool
|
||||
from .multivariate import multivariate
|
||||
from .ones import one, ones, ones_like
|
||||
from .relu import relu
|
||||
from .round_bit_pattern import AutoRounder, round_bit_pattern
|
||||
from .table import LookupTable
|
||||
from .tag import tag
|
||||
|
||||
48
frontends/concrete-python/concrete/fhe/extensions/relu.py
Normal file
48
frontends/concrete-python/concrete/fhe/extensions/relu.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
Declaration of `relu` extension.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import Any, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from ..dtypes import Integer
|
||||
from ..representation import Node
|
||||
from ..tracing import Tracer
|
||||
|
||||
|
||||
def relu(x: Union[Tracer, Any]) -> Union[Tracer, Any]:
|
||||
"""
|
||||
Rectified linear unit extension.
|
||||
|
||||
Computes:
|
||||
x if x >= 0 else 0
|
||||
|
||||
Args:
|
||||
x (Union[Tracer, Any]):
|
||||
input to apply ReLU
|
||||
|
||||
Returns:
|
||||
Union[Tracer, Any]:
|
||||
Tracer that represent the operation during tracing
|
||||
result of ReLU on `x` otherwise
|
||||
"""
|
||||
|
||||
def evaluator(x):
|
||||
return np.where(x >= 0, x, 0)
|
||||
|
||||
if not isinstance(x, Tracer):
|
||||
return evaluator(x)
|
||||
|
||||
resulting_value = deepcopy(x.output)
|
||||
if isinstance(resulting_value.dtype, Integer) and resulting_value.dtype.is_signed:
|
||||
resulting_value.dtype.is_signed = False
|
||||
|
||||
computation = Node.generic(
|
||||
"relu",
|
||||
[deepcopy(x.output)],
|
||||
resulting_value,
|
||||
evaluator,
|
||||
)
|
||||
return Tracer(computation, [x])
|
||||
@@ -24,7 +24,12 @@ from mlir.ir import OpResult as MlirOperation
|
||||
from mlir.ir import RankedTensorType
|
||||
from mlir.ir import Type as MlirType
|
||||
|
||||
from ..compilation.configuration import BitwiseStrategy, ComparisonStrategy, MinMaxStrategy
|
||||
from ..compilation.configuration import (
|
||||
BitwiseStrategy,
|
||||
ComparisonStrategy,
|
||||
Configuration,
|
||||
MinMaxStrategy,
|
||||
)
|
||||
from ..dtypes import Integer
|
||||
from ..extensions.bits import MAX_EXTRACTABLE_BIT, MIN_EXTRACTABLE_BIT
|
||||
from ..representation import Graph, Node
|
||||
@@ -51,7 +56,9 @@ class Context:
|
||||
conversion_cache: Dict[Tuple, Conversion]
|
||||
constant_cache: Dict[MlirAttribute, MlirOperation]
|
||||
|
||||
def __init__(self, context: MlirContext, graph: Graph):
|
||||
configuration: Configuration
|
||||
|
||||
def __init__(self, context: MlirContext, graph: Graph, configuration: Configuration):
|
||||
self.context = context
|
||||
|
||||
self.graph = graph
|
||||
@@ -61,6 +68,8 @@ class Context:
|
||||
self.conversion_cache = {}
|
||||
self.constant_cache = {}
|
||||
|
||||
self.configuration = configuration
|
||||
|
||||
# types
|
||||
|
||||
def i(self, width: int) -> ConversionType:
|
||||
@@ -2787,6 +2796,62 @@ class Context:
|
||||
|
||||
return self.add(resulting_type, one, self.zeros(resulting_type))
|
||||
|
||||
def relu(self, resulting_type: ConversionType, x: Conversion) -> Conversion:
|
||||
if x.bit_width < self.configuration.relu_on_bits_threshold:
|
||||
if x.bit_width > x.original_bit_width:
|
||||
shifter = self.constant(
|
||||
self.i(x.bit_width + 1),
|
||||
2 ** (x.bit_width - x.original_bit_width),
|
||||
)
|
||||
x = self.reinterpret(
|
||||
self.mul(x.type, x, shifter),
|
||||
bit_width=x.original_bit_width,
|
||||
)
|
||||
|
||||
x_dtype = Integer(is_signed=x.is_signed, bit_width=x.bit_width)
|
||||
table = [x for x in range(x_dtype.max() + 1)] + [0 for x in range(abs(x_dtype.min()))]
|
||||
return self.tlu(resulting_type, x, table)
|
||||
|
||||
if x.is_unsigned:
|
||||
if resulting_type.bit_width == x.bit_width:
|
||||
return x
|
||||
|
||||
assert resulting_type.bit_width > x.bit_width
|
||||
return self.extract_bits(resulting_type, x, bits=slice(0, x.original_bit_width))
|
||||
|
||||
if x.original_bit_width == 1:
|
||||
return self.zeros(resulting_type)
|
||||
|
||||
chunk_size = self.configuration.relu_on_bits_chunk_size
|
||||
intermediate_type = self.tensor(self.eint(chunk_size + 1), shape=x.shape)
|
||||
sign = self.reinterpret(
|
||||
self.extract_bits(
|
||||
self.tensor(self.eint(1), shape=x.shape),
|
||||
x,
|
||||
bits=(x.original_bit_width - 1),
|
||||
),
|
||||
bit_width=intermediate_type.bit_width,
|
||||
)
|
||||
|
||||
filtered_chunks = []
|
||||
for chunk_start in range(0, x.original_bit_width - 1, chunk_size):
|
||||
chunk_end = min(chunk_start + chunk_size, x.original_bit_width - 1)
|
||||
|
||||
chunk = self.extract_bits(intermediate_type, x, bits=slice(chunk_start, chunk_end))
|
||||
packed_chunk_and_sign = self.add(intermediate_type, chunk, sign)
|
||||
filtered_chunk = self.tlu(
|
||||
resulting_type,
|
||||
packed_chunk_and_sign,
|
||||
[
|
||||
(x << chunk_start) if x >> chunk_size == 0 else 0
|
||||
for x in range(2**intermediate_type.bit_width)
|
||||
],
|
||||
)
|
||||
|
||||
filtered_chunks.append(filtered_chunk)
|
||||
|
||||
return self.tree_add(resulting_type, filtered_chunks)
|
||||
|
||||
def reshape(self, x: Conversion, shape: Tuple[int, ...]) -> Conversion:
|
||||
if x.is_scalar:
|
||||
x = self.tensorize(x)
|
||||
|
||||
@@ -34,7 +34,10 @@ class Converter:
|
||||
"""
|
||||
|
||||
def convert(
|
||||
self, graph: Graph, configuration: Configuration, mlir_context: MlirContext
|
||||
self,
|
||||
graph: Graph,
|
||||
configuration: Configuration,
|
||||
mlir_context: MlirContext,
|
||||
) -> MlirModule:
|
||||
"""
|
||||
Convert a computation graph to MLIR.
|
||||
@@ -61,7 +64,7 @@ class Converter:
|
||||
|
||||
module = MlirModule.create()
|
||||
with MlirInsertionPoint(module.body):
|
||||
ctx = Context(context, graph)
|
||||
ctx = Context(context, graph, configuration)
|
||||
|
||||
input_types = [ctx.typeof(node).mlir for node in graph.ordered_inputs()]
|
||||
|
||||
@@ -451,6 +454,10 @@ class Converter:
|
||||
assert len(preds) == 0
|
||||
return ctx.ones(ctx.typeof(node))
|
||||
|
||||
def relu(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.relu(ctx.typeof(node), preds[0])
|
||||
|
||||
def reshape(self, ctx: Context, node: Node, preds: List[Conversion]) -> Conversion:
|
||||
assert len(preds) == 1
|
||||
return ctx.reshape(preds[0], shape=node.output.shape)
|
||||
|
||||
106
frontends/concrete-python/tests/execution/test_relu.py
Normal file
106
frontends/concrete-python/tests/execution/test_relu.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""
|
||||
Tests of execution of round bit pattern operation.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from concrete import fhe
|
||||
from concrete.fhe.dtypes import Integer
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sample,expected_output",
|
||||
[
|
||||
(0, 0),
|
||||
(1, 1),
|
||||
(-1, 0),
|
||||
(10, 10),
|
||||
(-10, 0),
|
||||
],
|
||||
)
|
||||
def test_plain_relu(sample, expected_output):
|
||||
"""
|
||||
Test plain evaluation of relu.
|
||||
"""
|
||||
assert fhe.relu(sample) == expected_output
|
||||
|
||||
|
||||
operations = [
|
||||
lambda x: fhe.relu(x),
|
||||
lambda x: fhe.relu(x) + 100,
|
||||
]
|
||||
cases = [
|
||||
# fhe.relu(int1), should result in fhe.zero()
|
||||
[
|
||||
operation,
|
||||
1,
|
||||
True,
|
||||
(),
|
||||
0,
|
||||
2,
|
||||
]
|
||||
for operation in operations
|
||||
] + [
|
||||
# fhe.relu should use an optimized TLU when it's assigned bit-width is bigger than the original
|
||||
[
|
||||
lambda x: fhe.relu(x) + (x + 10),
|
||||
3,
|
||||
True,
|
||||
(),
|
||||
10,
|
||||
2,
|
||||
]
|
||||
]
|
||||
|
||||
with_tlu = set()
|
||||
for function in operations:
|
||||
for bit_width in [1, 2, 3, 4, 5, 8, 12, 16]:
|
||||
for is_signed in [False, True]:
|
||||
for shape in [(), (3,), (2, 3)]:
|
||||
for threshold in [5, 7]:
|
||||
for chunk_size in [2, 3]:
|
||||
if bit_width < threshold:
|
||||
key = (bit_width, is_signed)
|
||||
if key in with_tlu:
|
||||
continue
|
||||
with_tlu.add(key)
|
||||
|
||||
cases += [
|
||||
[
|
||||
function,
|
||||
bit_width,
|
||||
is_signed,
|
||||
shape,
|
||||
threshold,
|
||||
chunk_size,
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"function,bit_width,is_signed,shape,threshold,chunk_size",
|
||||
cases,
|
||||
)
|
||||
def test_relu(function, bit_width, is_signed, shape, threshold, chunk_size, helpers):
|
||||
"""
|
||||
Test encrypted evaluation of relu.
|
||||
"""
|
||||
|
||||
dtype = Integer(is_signed, bit_width)
|
||||
|
||||
inputset = [np.random.randint(dtype.min(), dtype.max() + 1, size=shape) for _ in range(100)]
|
||||
configuration = helpers.configuration().fork(
|
||||
relu_on_bits_threshold=threshold,
|
||||
relu_on_bits_chunk_size=chunk_size,
|
||||
)
|
||||
|
||||
compiler = fhe.Compiler(function, {"x": "encrypted"})
|
||||
circuit = compiler.compile(inputset, configuration)
|
||||
|
||||
for value in random.sample(inputset, 8):
|
||||
helpers.check_execution(circuit, function, value, retries=3)
|
||||
Reference in New Issue
Block a user