mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-09 03:55:04 -05:00
92 lines
2.5 KiB
Python
92 lines
2.5 KiB
Python
"""
|
|
Tests of execution of round bit pattern operation.
|
|
"""
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
import concrete.numpy as cnp
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sample,lsbs_to_remove,expected_output",
    [
        # Samples 0..3 are expected to become 0.
        (0b_0000_0000, 3, 0b_0000_0000),
        (0b_0000_0001, 3, 0b_0000_0000),
        (0b_0000_0010, 3, 0b_0000_0000),
        (0b_0000_0011, 3, 0b_0000_0000),
        # Samples 4..11 are expected to become 8.
        (0b_0000_0100, 3, 0b_0000_1000),
        (0b_0000_0101, 3, 0b_0000_1000),
        (0b_0000_0110, 3, 0b_0000_1000),
        (0b_0000_0111, 3, 0b_0000_1000),
        (0b_0000_1000, 3, 0b_0000_1000),
        (0b_0000_1001, 3, 0b_0000_1000),
        (0b_0000_1010, 3, 0b_0000_1000),
        (0b_0000_1011, 3, 0b_0000_1000),
        # Samples 12..15 are expected to become 16.
        (0b_0000_1100, 3, 0b_0001_0000),
        (0b_0000_1101, 3, 0b_0001_0000),
        (0b_0000_1110, 3, 0b_0001_0000),
        (0b_0000_1111, 3, 0b_0001_0000),
    ],
)
def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
    """
    Check `cnp.round_bit_pattern` when called directly on plain integers.

    Every case passes `lsbs_to_remove=3` with a sample in the range 0..15
    and compares the returned value against the expected 0, 8 or 16.
    """

    rounded = cnp.round_bit_pattern(sample, lsbs_to_remove)
    assert rounded == expected_output
|
|
|
|
|
|
@pytest.mark.parametrize(
    "sample,lsbs_to_remove,expected_error,expected_message",
    [
        # A numpy array holding floats instead of integers.
        (
            np.array([3.2, 4.1]),
            3,
            TypeError,
            "Expected input elements to be integers but they are dtype[float64]",
        ),
        # A value that is neither an int nor a numpy array.
        (
            "foo",
            3,
            TypeError,
            "Expected input to be an int or a numpy array but it's str",
        ),
    ],
)
def test_bad_plain_round_bit_pattern(sample, lsbs_to_remove, expected_error, expected_message):
    """
    Check that `cnp.round_bit_pattern` rejects unsupported inputs.

    Each case must raise `expected_error`, and the string form of the raised
    exception must be exactly `expected_message`.
    """

    with pytest.raises(expected_error) as raised:
        cnp.round_bit_pattern(sample, lsbs_to_remove)

    actual_message = str(raised.value)
    assert actual_message == expected_message
|
|
|
|
|
|
@pytest.mark.parametrize(
    "input_bits,lsbs_to_remove",
    [
        # 3-bit inputs
        (3, 1),
        (3, 2),
        # 4-bit inputs
        (4, 1),
        (4, 2),
        (4, 3),
        # 5-bit inputs
        (5, 1),
        (5, 2),
        (5, 3),
        (5, 4),
    ],
)
def test_round_bit_pattern(input_bits, lsbs_to_remove, helpers):
    """
    Test round bit pattern inside a compiled circuit.

    A function that applies `cnp.round_bit_pattern` to an encrypted input is
    compiled with `virtual=True`, using the largest `input_bits`-bit value as
    its only inputset entry. The circuit, the function and a random sample
    below `2**input_bits` are then handed to `helpers.check_execution`.
    """

    @cnp.compiler({"x": "encrypted"})
    def function(x):
        rounded = cnp.round_bit_pattern(x, lsbs_to_remove)
        scaled = 50 * np.sin(rounded)
        return np.abs(scaled).astype(np.int64)

    upper_bound = 2**input_bits
    inputset = [upper_bound - 1]

    circuit = function.compile(inputset, helpers.configuration(), virtual=True)

    # The sample is drawn only after compilation, as in the original test.
    sample = np.random.randint(0, upper_bound)
    helpers.check_execution(circuit, function, sample)
|