mirror of
https://github.com/zama-ai/concrete.git
synced 2026-01-10 13:27:57 -05:00
107 lines
2.6 KiB
Python
107 lines
2.6 KiB
Python
"""
|
|
Tests of execution of relu operation.
|
|
"""
|
|
|
|
import random
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from concrete import fhe
|
|
from concrete.fhe.dtypes import Integer
|
|
|
|
# pylint: disable=redefined-outer-name
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sample,expected_output",
    [(0, 0), (1, 1), (-1, 0), (10, 10), (-10, 0)],
)
def test_plain_relu(sample, expected_output):
    """
    Test that relu applied to a plain (non-encrypted) integer gives the expected value.
    """

    # Non-positive samples are expected to map to 0, positive ones to themselves.
    actual_output = fhe.relu(sample)
    assert actual_output == expected_output
|
|
|
|
|
|
# Functions under test; each one takes a single argument `x`.
operations = [
    lambda x: fhe.relu(x),
    lambda x: fhe.relu(x) + 100,
]

# Every case is laid out as:
# [function, bit_width, is_signed, shape, threshold, chunk_size]

# fhe.relu(int1), should result in fhe.zero()
cases = [[operation, 1, True, (), 0, 2] for operation in operations]

# fhe.relu should use an optimized TLU when its assigned bit-width is bigger than the original
cases.append([lambda x: fhe.relu(x) + (x + 10), 3, True, (), 10, 2])

# (bit_width, is_signed) pairs that already have a case with bit_width < threshold.
with_tlu = set()
for function in operations:
    for bit_width in (1, 2, 3, 4, 5, 8, 12, 16):
        for is_signed in (False, True):
            for shape in ((), (3,), (2, 3)):
                for threshold in (5, 7):
                    for chunk_size in (2, 3):
                        if bit_width < threshold:
                            # Keep only the first combination seen for each
                            # (bit_width, is_signed) pair below the threshold.
                            # NOTE(review): presumably this is the TLU-based path and is
                            # deduplicated to limit the number of tests -- confirm.
                            key = (bit_width, is_signed)
                            if key in with_tlu:
                                continue
                            with_tlu.add(key)

                        cases.append(
                            [function, bit_width, is_signed, shape, threshold, chunk_size]
                        )
|
|
|
|
|
|
@pytest.mark.parametrize(
    "function,bit_width,is_signed,shape,threshold,chunk_size",
    cases,
)
def test_relu(function, bit_width, is_signed, shape, threshold, chunk_size, helpers):
    """
    Test that relu evaluated on encrypted inputs is executed as expected.

    The function is compiled with `x` encrypted over 100 random samples of the
    requested integer type and shape, then 8 of those samples are checked.
    """

    input_type = Integer(is_signed, bit_width)

    # The upper bound of np.random.randint is exclusive, hence the `+ 1`.
    lower_bound = input_type.min()
    upper_bound = input_type.max() + 1
    inputset = [np.random.randint(lower_bound, upper_bound, size=shape) for _ in range(100)]

    base_configuration = helpers.configuration()
    configuration = base_configuration.fork(
        relu_on_bits_threshold=threshold,
        relu_on_bits_chunk_size=chunk_size,
    )

    circuit = fhe.Compiler(function, {"x": "encrypted"}).compile(inputset, configuration)

    # Check a random subset of the inputset.
    for value in random.sample(inputset, 8):
        helpers.check_execution(circuit, function, value, retries=3)
|