Files
concrete/frontends/concrete-python/tests/execution/test_bit_extraction.py

377 lines
13 KiB
Python

"""
Tests of execution of bit extraction.
"""
import random
import numpy as np
import pytest
from concrete import fhe
# Two 8-bit samples whose bit patterns are exact complements of each other
# (0b0110_0100 == 100 and 0b1001_1011 == 155), so each of bits 0-7 is checked
# once set and once cleared.
_PATTERN = 0b_0110_0100
_COMPLEMENT = 0b_1001_1011


@pytest.mark.parametrize(
    "sample,operation,expected_output",
    [
        # single bits of the first sample (bit 30 is far above its width)
        (_PATTERN, lambda x: fhe.bits(x)[0], 0b_0),
        (_PATTERN, lambda x: fhe.bits(x)[1], 0b_0),
        (_PATTERN, lambda x: fhe.bits(x)[2], 0b_1),
        (_PATTERN, lambda x: fhe.bits(x)[3], 0b_0),
        (_PATTERN, lambda x: fhe.bits(x)[4], 0b_0),
        (_PATTERN, lambda x: fhe.bits(x)[5], 0b_1),
        (_PATTERN, lambda x: fhe.bits(x)[6], 0b_1),
        (_PATTERN, lambda x: fhe.bits(x)[7], 0b_0),
        (_PATTERN, lambda x: fhe.bits(x)[30], 0b_0),
        # single bits of the complemented sample
        (_COMPLEMENT, lambda x: fhe.bits(x)[0], 0b_1),
        (_COMPLEMENT, lambda x: fhe.bits(x)[1], 0b_1),
        (_COMPLEMENT, lambda x: fhe.bits(x)[2], 0b_0),
        (_COMPLEMENT, lambda x: fhe.bits(x)[3], 0b_1),
        (_COMPLEMENT, lambda x: fhe.bits(x)[4], 0b_1),
        (_COMPLEMENT, lambda x: fhe.bits(x)[5], 0b_0),
        (_COMPLEMENT, lambda x: fhe.bits(x)[6], 0b_0),
        (_COMPLEMENT, lambda x: fhe.bits(x)[7], 0b_1),
        (_COMPLEMENT, lambda x: fhe.bits(x)[30], 0b_0),
        # slices of the first sample; the first extracted bit lands in bit 0
        (_PATTERN, lambda x: fhe.bits(x)[1:3], 0b_10),
        (_PATTERN, lambda x: fhe.bits(x)[:3], 0b_100),
        (_PATTERN, lambda x: fhe.bits(x)[1:], 0b_0110_010),
        (_PATTERN, lambda x: fhe.bits(x)[1:6:2], 0b_100),
        (_PATTERN, lambda x: fhe.bits(x)[3:1:-1], 0b_10),
        (_PATTERN, lambda x: fhe.bits(x)[2::-1], 0b_001),
        (_PATTERN, lambda x: fhe.bits(x)[30:32], 0b_00),
        # slices of the complemented sample
        (_COMPLEMENT, lambda x: fhe.bits(x)[1:3], 0b_01),
        (_COMPLEMENT, lambda x: fhe.bits(x)[:3], 0b_011),
        (_COMPLEMENT, lambda x: fhe.bits(x)[1:], 0b_1001_101),
        (_COMPLEMENT, lambda x: fhe.bits(x)[1:6:2], 0b_011),
        (_COMPLEMENT, lambda x: fhe.bits(x)[3:1:-1], 0b_01),
        (_COMPLEMENT, lambda x: fhe.bits(x)[2::-1], 0b_110),
        (_COMPLEMENT, lambda x: fhe.bits(x)[30:32], 0b_00),
    ],
)
def test_plain_bit_extraction(sample, operation, expected_output):
    """
    Check `fhe.bits` indexing and slicing on plain (non-encrypted) integers.

    A slice is read back as an integer whose bit 0 holds the first extracted
    bit, e.g. `[3:1:-1]` places bit 3 at position 0 and bit 2 at position 1.
    """
    actual_output = operation(sample)
    assert actual_output == expected_output
@pytest.mark.parametrize(
    "sample,operation,expected_error,expected_message",
    [
        # malformed slice bounds / steps
        (
            100,
            lambda x: fhe.bits(x)[1.1:3],  # type: ignore
            ValueError,
            "Extracting bits using a non integer start (e.g., 1.1) isn't supported",
        ),
        (
            100,
            lambda x: fhe.bits(x)[-2:3],
            ValueError,
            "Extracting bits using a negative start (e.g., -2) isn't supported",
        ),
        (
            100,
            lambda x: fhe.bits(x)[1:3.1],  # type: ignore
            ValueError,
            "Extracting bits using a non integer stop (e.g., 3.1) isn't supported",
        ),
        (
            100,
            lambda x: fhe.bits(x)[1:-2],
            ValueError,
            "Extracting bits using a negative stop (e.g., -2) isn't supported",
        ),
        (
            100,
            lambda x: fhe.bits(x)[1:3:1.1],  # type: ignore
            ValueError,
            "Extracting bits using a non integer step (e.g., 1.1) isn't supported",
        ),
        (
            100,
            lambda x: fhe.bits(x)[1:3:0],
            ValueError,
            "Extracting bits using zero step isn't supported",
        ),
        (
            0b_1001_1011,
            lambda x: fhe.bits(x)[::-1],
            ValueError,
            "Extracting bits in reverse (step < 0) isn't supported without providing the start bit",
        ),
        # malformed single-bit indices
        (
            0b_1001_1011,
            lambda x: fhe.bits(x)[-1],
            ValueError,
            "Extracting bits from the back (index == -1 < 0) isn't supported",
        ),
        (
            100,
            lambda x: fhe.bits(x)[2.1],  # type: ignore
            ValueError,
            "Bits of 100 cannot be extracted using 2.1 since it's not an integer or a slice",
        ),
        # samples that cannot be bit-extracted as requested
        (
            3.2,
            lambda x: fhe.bits(x)[0],
            ValueError,
            "Bits of 3.2 cannot be extracted since it's not an integer",
        ),
        (
            -100,
            lambda x: fhe.bits(x)[1:],
            ValueError,
            (
                "Extracting bits without an upper bound (stop is None) "
                "isn't supported on signed values (e.g., -100)"
            ),
        ),
    ],
)
def test_bad_plain_bit_extraction(
    sample,
    operation,
    expected_error,
    expected_message,
):
    """
    Check that unsupported `fhe.bits` usages on plain values raise the expected
    exception type, with a message that matches exactly (not just a substring).
    """
    with pytest.raises(expected_error) as caught:
        operation(sample)

    actual_message = str(caught.value)
    assert actual_message == expected_message
@pytest.mark.parametrize(
    "input_bit_width,input_is_signed,operation",
    [
        # unsigned
        pytest.param(3, False, lambda x: fhe.bits(x)[0:3], id="unsigned-3b[0:3]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[0], id="unsigned-5b[0]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[1], id="unsigned-5b[1]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[2], id="unsigned-5b[2]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[3], id="unsigned-5b[3]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[4], id="unsigned-5b[4]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[30], id="unsigned-5b[30]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[1:3], id="unsigned-5b[1:3]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[:3], id="unsigned-5b[:3]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[1:], id="unsigned-5b[1:]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[1:6:2], id="unsigned-5b[1:6:2]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[3:1:-1], id="unsigned-5b[3:1:-1]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[2::-1], id="unsigned-5b[2::-1]"),
        pytest.param(5, False, lambda x: fhe.bits(x)[1:30:10], id="unsigned-5b[1:30:10]"),
        # signed (no open-ended `[1:]` case: slices without a stop are rejected
        # on signed values, see test_bad_plain_bit_extraction)
        pytest.param(3, True, lambda x: fhe.bits(x)[0:3], id="signed-3b[0:3]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[0], id="signed-5b[0]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[1], id="signed-5b[1]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[2], id="signed-5b[2]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[3], id="signed-5b[3]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[4], id="signed-5b[4]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[30], id="signed-5b[30]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[1:3], id="signed-5b[1:3]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[:3], id="signed-5b[:3]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[1:6:2], id="signed-5b[1:6:2]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[3:1:-1], id="signed-5b[3:1:-1]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[2::-1], id="signed-5b[2::-1]"),
        pytest.param(5, True, lambda x: fhe.bits(x)[1:30:10], id="signed-5b[1:30:10]"),
        # unsigned (result bit-width increased)
        pytest.param(3, False, lambda x: fhe.bits(x)[0:3] + 100, id="unsigned-3b[0:3] + 100"),
        pytest.param(5, False, lambda x: fhe.bits(x)[0] + 100, id="unsigned-5b[0] + 100"),
        pytest.param(5, False, lambda x: fhe.bits(x)[1:3] + 100, id="unsigned-5b[1:3] + 100"),
        # signed (result bit-width increased)
        # NOTE: the 3-bit case previously omitted `+ 100`, duplicating
        # "signed-3b[0:3]" despite its id; it now matches its unsigned twin.
        pytest.param(3, True, lambda x: fhe.bits(x)[0:3] + 100, id="signed-3b[0:3] + 100"),
        pytest.param(5, True, lambda x: fhe.bits(x)[0] + 100, id="signed-5b[0] + 100"),
        pytest.param(5, True, lambda x: fhe.bits(x)[1:3] + 100, id="signed-5b[1:3] + 100"),
        # unsigned high bits
        pytest.param(10, False, lambda x: fhe.bits(x)[5:15], id="unsigned-10b[5:15]"),
    ],
)
def test_bit_extraction(input_bit_width, input_is_signed, operation, helpers):
    """
    Test bit extraction on encrypted inputs.

    For scalar, 1-D and 2-D input shapes, the operation is compiled with an
    inputset of `2**input_bit_width` random samples spanning the full range of
    the input type. Up to 8 of those samples are then executed through
    `helpers.check_execution` (presumably comparing the encrypted result with
    a plain evaluation of `operation` — see the helpers fixture).
    """
    # `np.random.randint` excludes `high`, so these bounds give exactly
    # [0, 2**w) for unsigned inputs and [-2**(w-1), 2**(w-1)) for signed ones.
    lower_bound = 0 if not input_is_signed else -(2 ** (input_bit_width - 1))
    upper_bound = 2 ** (input_bit_width if not input_is_signed else (input_bit_width - 1))

    sizes = [(), (2,), (3, 2)]
    for size in sizes:
        inputset = [
            np.random.randint(lower_bound, upper_bound, size=size)
            for _ in range(2**input_bit_width)
        ]

        compiler = fhe.Compiler(operation, {"x": "encrypted"})
        circuit = compiler.compile(inputset, helpers.configuration())

        # Encrypted execution is slow; only a small random subset is run.
        values = inputset if len(inputset) <= 8 else random.sample(inputset, 8)
        for value in values:
            helpers.check_execution(circuit, operation, value, retries=3)
def mlir_count_ops(mlir, operation):
    """
    Return how many lines of the MLIR text contain the substring `operation`.

    Each line counts at most once, no matter how often the substring appears
    on it, and matching is a plain substring test (not a token match).
    """
    matching_lines = [line for line in mlir.splitlines() if operation in line]
    return len(matching_lines)
def test_highest_bit_extraction_mlir(helpers):
    """
    Test extraction of the most significant bit of an 8-bit input.

    The compiled MLIR is expected to contain one `lsb` line fewer than the
    precision (one lsb saved) and no `lookup` line.
    """
    precision = 8
    top_bit = precision - 1

    @fhe.compiler({"x": "encrypted"})
    def operation(x):
        return fhe.bits(x)[top_bit]

    circuit = operation.compile(list(range(2**precision)), helpers.configuration())
    mlir = circuit.mlir

    assert mlir_count_ops(mlir, "lsb") == top_bit
    assert mlir_count_ops(mlir, "lookup") == 0
def test_bits_extraction_to_same_bitwidth_mlir(helpers):
    """
    Test extraction of every bit of an 8-bit input, returned without further
    arithmetic.

    The compiled MLIR is expected to contain `precision - 1` lines with `lsb`
    and no `lookup` line.
    """
    precision = 8
    inputset = list(range(2**precision))

    @fhe.compiler({"x": "encrypted"})
    def operation(x):
        return tuple(fhe.bits(x)[i] for i in range(precision))

    circuit = operation.compile(inputset, helpers.configuration())
    assert mlir_count_ops(circuit.mlir, "lsb") == precision - 1
    assert mlir_count_ops(circuit.mlir, "lookup") == 0
def test_bits_extraction_to_bigger_bitwidth_mlir(helpers):
    """
    Test bit extraction to bigger bitwidth.

    Every bit of an 8-bit input is offset by `2**precision + 1` (257), so the
    outputs (257..258) need more bits than the input. The compiled MLIR is
    expected to contain `precision` lines with `lsb` and no `lookup` line.
    """
    precision = 8
    inputset = list(range(2**precision))

    @fhe.compiler({"x": "encrypted"})
    def operation(x):
        return tuple(fhe.bits(x)[i] + (2**precision + 1) for i in range(precision))

    circuit = operation.compile(inputset, helpers.configuration())
    # (a leftover debug `print(circuit.mlir)` was removed here)
    assert mlir_count_ops(circuit.mlir, "lsb") == precision
    assert mlir_count_ops(circuit.mlir, "lookup") == 0
def test_seq_bits_extraction_to_same_bitwidth_mlir(helpers):
    """
    Test sequential bit extraction to the same bitwidth.

    Every bit of an 8-bit input is offset by `2**precision - 2` (254), so the
    outputs (254..255) still fit in the input's 8 bits. The compiled MLIR is
    expected to contain `precision` lines with `lsb` and no `lookup` line.
    """
    precision = 8
    inputset = list(range(2**precision))

    @fhe.compiler({"x": "encrypted"})
    def operation(x):
        return tuple(fhe.bits(x)[i] + (2**precision - 2) for i in range(precision))

    circuit = operation.compile(inputset, helpers.configuration())
    assert mlir_count_ops(circuit.mlir, "lsb") == precision
    assert mlir_count_ops(circuit.mlir, "lookup") == 0
def test_seq_bits_extraction_to_smaller_bitwidth_mlir(helpers):
    """
    Test sequential bit extraction to smaller bitwidth.

    Each bit of an 8-bit input is returned unchanged, as a value much
    narrower than the input. The compiled MLIR should contain one `lsb` line
    fewer than the precision and no `lookup` line.
    """
    width = 8

    @fhe.compiler({"x": "encrypted"})
    def operation(x):
        return tuple(fhe.bits(x)[index] for index in range(width))

    circuit = operation.compile(list(range(2**width)), helpers.configuration())
    mlir = circuit.mlir

    expected_lsb_lines = width - 1
    assert mlir_count_ops(mlir, "lsb") == expected_lsb_lines
    assert mlir_count_ops(mlir, "lookup") == 0
def test_seq_bits_extraction_to_bigger_bitwidth_mlir(helpers):
    """
    Test sequential bit extraction to bigger bitwidth.

    Each bit of an 8-bit input is offset by `2 ** (precision + 1)` (512), so
    the outputs need more bits than the input. The compiled MLIR should contain
    as many `lsb` lines as the precision and no `lookup` line.
    """
    width = 8
    offset = 2 ** (width + 1)

    @fhe.compiler({"x": "encrypted"})
    def operation(x):
        return tuple(fhe.bits(x)[index] + offset for index in range(width))

    circuit = operation.compile(list(range(2**width)), helpers.configuration())
    mlir = circuit.mlir

    assert mlir_count_ops(mlir, "lsb") == width
    assert mlir_count_ops(mlir, "lookup") == 0
def test_bit_extract_to_1_tlu(helpers):
    """
    Test how `fhe.bits(x)[0:2]` is lowered depending on the input precision.

    With a 3-bit input the MLIR is expected to hold a single `lookup` line and
    no `lsb` line; with a 4-bit input it is expected to hold two `lsb` lines
    and no `lookup` line.
    """
    # (input precision, expected "lsb" line count, expected "lookup" line count)
    cases = [
        (3, 0, 1),
        (4, 2, 0),
    ]

    for precision, expected_lsb, expected_lookup in cases:

        @fhe.compiler({"x": "encrypted"})
        def operation(x):
            return fhe.bits(x)[0:2]

        circuit = operation.compile(list(range(2**precision)), helpers.configuration())
        assert mlir_count_ops(circuit.mlir, "lsb") == expected_lsb
        assert mlir_count_ops(circuit.mlir, "lookup") == expected_lookup
def test_bit_extraction_in_direct_circuit(helpers):
    """
    Test single bit extraction in a direct circuit and check assigned bit width.

    The circuit is built with `fhe.circuit`, so the input type comes from the
    `fhe.uint4` annotation (no inputset is used). The test asserts that the
    textual graph reports the extracted bit as `EncryptedScalar<uint1>`.
    """
    configuration = helpers.configuration()

    @fhe.circuit({"x": "encrypted"}, configuration)
    def circuit(x: fhe.uint4):
        return fhe.bits(x)[0]

    # Exact textual comparison of the circuit's string form; only leading and
    # trailing whitespace of the whole text is ignored via `.strip()`.
    # NOTE(review): the spacing before each `#` in the expected text must match
    # what `str(circuit)` prints (it may be column-aligned) — verify against
    # actual output if this assertion starts failing.
    assert (
        str(circuit).strip()
        == (
            """
%0 = x # EncryptedScalar<uint4>
%1 = bits(%0)[0] # EncryptedScalar<uint1>
return %1
"""
        ).strip()
    )