Files
concrete/frontends/concrete-python/tests/execution/test_relu.py
2024-01-11 16:09:42 +01:00

107 lines
2.6 KiB
Python

"""
Tests of execution of relu operation.
"""
import itertools
import random

import numpy as np
import pytest

from concrete import fhe
from concrete.fhe.dtypes import Integer
# pylint: disable=redefined-outer-name
@pytest.mark.parametrize(
    "sample,expected_output",
    [
        (0, 0),
        (1, 1),
        (-1, 0),
        (10, 10),
        (-10, 0),
    ],
)
def test_plain_relu(sample, expected_output):
    """
    Check that relu applied to a plain (non-encrypted) integer gives the expected value.
    """

    actual_output = fhe.relu(sample)
    assert actual_output == expected_output
# Functions under test: relu on its own, and relu whose result feeds an addition.
operations = [
    lambda x: fhe.relu(x),
    lambda x: fhe.relu(x) + 100,
]

# Every case is [function, bit_width, is_signed, shape, threshold, chunk_size],
# in the same order as the argument names of the parametrized `test_relu`.
cases = [
    # fhe.relu(int1), should result in fhe.zero()
    [operation, 1, True, (), 0, 2]
    for operation in operations
] + [
    # fhe.relu should use an optimized TLU when its assigned bit-width is bigger than the original
    [lambda x: fhe.relu(x) + (x + 10), 3, True, (), 10, 2],
]

# (bit_width, is_signed) pairs for which a `bit_width < threshold` case was already generated.
with_tlu = set()

# `itertools.product` advances the rightmost iterable fastest, so the combinations
# are visited in exactly the order the equivalent nested `for` loops would produce.
for function, bit_width, is_signed, shape, threshold, chunk_size in itertools.product(
    operations,
    [1, 2, 3, 4, 5, 8, 12, 16],
    [False, True],
    [(), (3,), (2, 3)],
    [5, 7],
    [2, 3],
):
    if bit_width < threshold:
        # Below the threshold only the first combination seen for each
        # (bit_width, is_signed) pair is kept. The set is shared by all functions,
        # so only the first function in `operations` contributes such cases.
        # NOTE(review): presumably shape/threshold/chunk size are irrelevant on
        # this path, which is why one case per pair is enough -- confirm.
        key = (bit_width, is_signed)
        if key in with_tlu:
            continue
        with_tlu.add(key)

    cases.append([function, bit_width, is_signed, shape, threshold, chunk_size])
@pytest.mark.parametrize(
    "function,bit_width,is_signed,shape,threshold,chunk_size",
    cases,
)
def test_relu(function, bit_width, is_signed, shape, threshold, chunk_size, helpers):
    """
    Check that relu-based functions give correct results when compiled for an encrypted input.
    """

    dtype = Integer(is_signed, bit_width)
    lower_bound = dtype.min()
    upper_bound = dtype.max() + 1  # randint excludes its upper bound

    # 100 random samples of the requested shape spanning the full range of the dtype
    inputset = []
    for _ in range(100):
        inputset.append(np.random.randint(lower_bound, upper_bound, size=shape))

    configuration = helpers.configuration().fork(
        relu_on_bits_threshold=threshold,
        relu_on_bits_chunk_size=chunk_size,
    )

    circuit = fhe.Compiler(function, {"x": "encrypted"}).compile(inputset, configuration)

    # only 8 randomly picked samples from the inputset are executed
    for sample in random.sample(inputset, 8):
        helpers.check_execution(circuit, function, sample, retries=3)