Files
concrete/tests/execution/test_round_bit_pattern.py

92 lines
2.5 KiB
Python

"""
Tests of execution of round bit pattern operation.
"""
import numpy as np
import pytest
import concrete.numpy as cnp
@pytest.mark.parametrize(
"sample,lsbs_to_remove,expected_output",
[
(0b_0000_0000, 3, 0b_0000_0000),
(0b_0000_0001, 3, 0b_0000_0000),
(0b_0000_0010, 3, 0b_0000_0000),
(0b_0000_0011, 3, 0b_0000_0000),
(0b_0000_0100, 3, 0b_0000_1000),
(0b_0000_0101, 3, 0b_0000_1000),
(0b_0000_0110, 3, 0b_0000_1000),
(0b_0000_0111, 3, 0b_0000_1000),
(0b_0000_1000, 3, 0b_0000_1000),
(0b_0000_1001, 3, 0b_0000_1000),
(0b_0000_1010, 3, 0b_0000_1000),
(0b_0000_1011, 3, 0b_0000_1000),
(0b_0000_1100, 3, 0b_0001_0000),
(0b_0000_1101, 3, 0b_0001_0000),
(0b_0000_1110, 3, 0b_0001_0000),
(0b_0000_1111, 3, 0b_0001_0000),
],
)
def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
"""
Test round bit pattern in evaluation context.
"""
assert cnp.round_bit_pattern(sample, lsbs_to_remove) == expected_output
@pytest.mark.parametrize(
"sample,lsbs_to_remove,expected_error,expected_message",
[
(
np.array([3.2, 4.1]),
3,
TypeError,
"Expected input elements to be integers but they are dtype[float64]",
),
(
"foo",
3,
TypeError,
"Expected input to be an int or a numpy array but it's str",
),
],
)
def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message):
"""
Test round bit pattern in evaluation context with bad parameters.
"""
with pytest.raises(expected_error) as excinfo:
cnp.round_bit_pattern(sample, lsbs_to_remove)
assert str(excinfo.value) == expected_message
@pytest.mark.parametrize(
"input_bits,lsbs_to_remove",
[
(3, 1),
(3, 2),
(4, 1),
(4, 2),
(4, 3),
(5, 1),
(5, 2),
(5, 3),
(5, 4),
],
)
def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers):
"""
Test round bit pattern in evaluation context.
"""
@cnp.compiler({"x": "encrypted"})
def function(x):
return np.abs(50 * np.sin(cnp.round_bit_pattern(x, lsbs_to_remove))).astype(np.int64)
circuit = function.compile([(2**input_bits) - 1], helpers.configuration(), virtual=True)
helpers.check_execution(circuit, function, np.random.randint(0, 2**input_bits))