Files
concrete/frontends/concrete-python/tests/execution/test_round_bit_pattern.py
2024-09-24 14:37:37 +02:00

874 lines
28 KiB
Python

"""
Tests of execution of round bit pattern operation.
"""
import numpy as np
import pytest
from concrete import fhe
from concrete.fhe.compilation.configuration import Exactness
from concrete.fhe.representation.utils import format_constant
@pytest.mark.parametrize(
    "sample,lsbs_to_remove,expected_output",
    [
        # lsbs_to_remove == 0: the sample is expected back unchanged
        (0b_0000_0011, 0, 0b_0000_0011),
        (0b_0000_0100, 0, 0b_0000_0100),
        # lsbs_to_remove == 3: samples 0..3 are expected to give 0
        (0b_0000_0000, 3, 0b_0000_0000),
        (0b_0000_0001, 3, 0b_0000_0000),
        (0b_0000_0010, 3, 0b_0000_0000),
        (0b_0000_0011, 3, 0b_0000_0000),
        # lsbs_to_remove == 3: samples 4..11 are expected to give 8
        (0b_0000_0100, 3, 0b_0000_1000),
        (0b_0000_0101, 3, 0b_0000_1000),
        (0b_0000_0110, 3, 0b_0000_1000),
        (0b_0000_0111, 3, 0b_0000_1000),
        (0b_0000_1000, 3, 0b_0000_1000),
        (0b_0000_1001, 3, 0b_0000_1000),
        (0b_0000_1010, 3, 0b_0000_1000),
        (0b_0000_1011, 3, 0b_0000_1000),
        # lsbs_to_remove == 3: samples 12..15 are expected to give 16
        (0b_0000_1100, 3, 0b_0001_0000),
        (0b_0000_1101, 3, 0b_0001_0000),
        (0b_0000_1110, 3, 0b_0001_0000),
        (0b_0000_1111, 3, 0b_0001_0000),
    ],
)
def test_plain_round_bit_pattern(sample, lsbs_to_remove, expected_output):
    """
    Test round bit pattern on plain samples, outside of any compiled circuit.

    Each case calls `fhe.round_bit_pattern` directly on `sample` and compares
    the returned value with `expected_output`.
    """

    actual_output = fhe.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove)
    assert actual_output == expected_output
@pytest.mark.parametrize(
    "sample,lsbs_to_remove,expected_error,expected_message",
    [
        # array whose elements are not integers
        (
            np.array([3.2, 4.1]),
            3,
            TypeError,
            f"Expected input elements to be integers but they are {type(np.array([3.2, 4.1]).dtype).__name__}",  # noqa: E501
        ),
        # sample that is neither an int nor a numpy array
        (
            "foo",
            3,
            TypeError,
            "Expected input to be an int or a numpy array but it's str",
        ),
    ],
)
def test_bad_plain_round_bit_pattern(
    sample,
    lsbs_to_remove,
    expected_error,
    expected_message,
):
    """
    Test that round bit pattern rejects bad plain samples.

    The call must raise `expected_error` and the text of the raised exception
    must be exactly `expected_message`.
    """

    with pytest.raises(expected_error) as raised:
        fhe.round_bit_pattern(sample, lsbs_to_remove=lsbs_to_remove)

    # checked after the `with` block: inside it, nothing runs past the raising call
    actual_message = str(raised.value)
    assert actual_message == expected_message
@pytest.mark.parametrize(
    "input_bits,lsbs_to_remove",
    [
        (3, 1),
        (3, 2),
        (4, 1),
        (4, 2),
        (4, 3),
        (5, 1),
        (5, 2),
        (5, 3),
        (5, 4),
    ],
)
@pytest.mark.parametrize(
    "mapper",
    [
        pytest.param(lambda x: x, id="x"),
        pytest.param(lambda x: x + 10, id="x + 10"),
        pytest.param(lambda x: x**2, id="x ** 2"),
        pytest.param(
            lambda x: fhe.univariate(lambda x: x if x >= 0 else 0)(x),
            id="relu",
        ),
    ],
)
@pytest.mark.parametrize("overflow", [True, False])
def test_round_bit_pattern(input_bits, lsbs_to_remove, mapper, overflow, helpers):
    """
    Test round bit pattern followed by a mapping, inside a compiled circuit.

    The circuit is compiled on a small inputset and then executed on one random
    sample below the upper bound and on every sample of the inputset.
    When `overflow` is set, the inputset is extended up to `2**input_bits - 1`.
    """

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        x_rounded = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
        return mapper(x_rounded)

    full_range = 2**input_bits
    half_unit = 2 ** (lsbs_to_remove - 1)

    # without overflow, samples stay below `full_range - half_unit`
    upper_bound = full_range - half_unit
    inputset = [0, upper_bound - 1]

    if overflow:
        # with overflow, the largest `input_bits`-bit value joins the inputset
        upper_bound = full_range
        inputset.append(upper_bound - 1)

    circuit = function.compile(inputset, helpers.configuration())

    random_sample = np.random.randint(0, upper_bound)
    helpers.check_execution(circuit, function, random_sample, retries=3)

    for sample in inputset:
        helpers.check_execution(circuit, function, sample, retries=3)
@pytest.mark.parametrize(
    "input_bits,lsbs_to_remove",
    [
        (3, 1),
        (3, 2),
        (4, 1),
        (4, 2),
        (4, 3),
    ],
)
def test_round_bit_pattern_unsigned_range(input_bits, lsbs_to_remove, helpers):
    """
    Test round bit pattern over the whole unsigned range of `input_bits` bits.

    Every value in `[0, 2**input_bits)` is used both for compilation and for
    execution checks.
    """

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        return fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)

    sample_count = 2**input_bits
    inputset = range(0, sample_count)

    circuit = function.compile(inputset, helpers.configuration())
    for sample in inputset:
        helpers.check_execution(circuit, function, sample, retries=3)
@pytest.mark.parametrize(
    "input_bits,lsbs_to_remove",
    [
        (3, 1),
        (3, 2),
        (4, 1),
        (4, 2),
        (4, 3),
    ],
)
def test_round_bit_pattern_signed_range(input_bits, lsbs_to_remove, helpers):
    """
    Test round bit pattern over the whole signed range of `input_bits` bits.

    Every value in `[-2**(input_bits - 1), 2**(input_bits - 1))` is used both
    for compilation and for execution checks.
    """

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        return fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)

    half_range = 2 ** (input_bits - 1)
    inputset = range(-half_range, half_range)

    circuit = function.compile(inputset, helpers.configuration())
    for sample in inputset:
        helpers.check_execution(circuit, function, sample, retries=3)
@pytest.mark.parametrize(
    "input_bits,lsbs_to_remove",
    [
        (3, 1),
        (3, 2),
        (4, 1),
        (4, 2),
        (4, 3),
    ],
)
def test_round_bit_pattern_unsigned_range_assigned(input_bits, lsbs_to_remove, helpers):
    """
    Test round bit pattern in unsigned range when the input is also used in a wide sum.

    Besides being rounded, `x` takes part in `50_000 - x`, so it is involved in
    an operation with much larger values than the inputset itself.
    """

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        x_rounded = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
        return (x_rounded**2) + (50_000 - x)

    sample_count = 2**input_bits
    inputset = range(0, sample_count)

    circuit = function.compile(inputset, helpers.configuration())
    for sample in inputset:
        helpers.check_execution(circuit, function, sample, retries=3)
@pytest.mark.parametrize(
    "input_bits,lsbs_to_remove",
    [
        (3, 1),
        (3, 2),
        (4, 1),
        (4, 2),
        (4, 3),
    ],
)
def test_round_bit_pattern_signed_range_assigned(input_bits, lsbs_to_remove, helpers):
    """
    Test round bit pattern in signed range when the input is also used in a wide sum.

    Besides being rounded, `x` takes part in `50_000 - x`, so it is involved in
    an operation with much larger values than the inputset itself.
    """

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        x_rounded = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
        return (x_rounded**2) + (50_000 - x)

    half_range = 2 ** (input_bits - 1)
    inputset = range(-half_range, half_range)

    circuit = function.compile(inputset, helpers.configuration())
    for sample in inputset:
        helpers.check_execution(circuit, function, sample, retries=3)
def test_round_bit_pattern_no_overflow_protection(helpers):
    """
    Test round bit pattern without overflow protection.

    The generated MLIR is compared with a reference, which differs depending on
    whether the configuration uses single precision.
    """

    # NOTE(review): the call below does not pass `overflow_protection=False`,
    # although the test name and summary say "no overflow protection" -- confirm intent.
    @fhe.compiler({"x": "encrypted"})
    def function(x):
        rounded = fhe.round_bit_pattern(x, lsbs_to_remove=2)
        return rounded**2

    inputset = range(-32, 32)
    configuration = helpers.configuration()
    circuit = function.compile(inputset, configuration)

    if configuration.single_precision:
        expected_mlir = """
module {
func.func @function(%arg0: !FHE.esint<11>) -> !FHE.eint<11> {
%c16_i12 = arith.constant 16 : i12
%0 = "FHE.mul_eint_int"(%arg0, %c16_i12) : (!FHE.esint<11>, i12) -> !FHE.esint<11>
%1 = "FHE.round"(%0) : (!FHE.esint<11>) -> !FHE.esint<5>
%c2_i12 = arith.constant 2 : i12
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
%2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<11>
return %2 : !FHE.eint<11>
}
}
"""  # noqa: E501
    else:
        expected_mlir = """
module {
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.eint<11> {
%0 = "FHE.round"(%arg0) : (!FHE.esint<7>) -> !FHE.esint<5>
%c2_i3 = arith.constant 2 : i3
%cst = arith.constant dense<[0, 16, 64, 144, 256, 400, 576, 784, 1024, 1296, 1600, 1936, 2304, 2704, 3136, 3600, 4096, 3600, 3136, 2704, 2304, 1936, 1600, 1296, 1024, 784, 576, 400, 256, 144, 64, 16]> : tensor<32xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<5>, tensor<32xi64>) -> !FHE.eint<11>
return %1 : !FHE.eint<11>
}
}
"""  # noqa: E501

    helpers.check_str(expected_mlir, circuit.mlir)
def test_round_bit_pattern_identity(helpers):
    """
    Test round bit pattern whose result is used several times outside of TLUs.

    The rounded value is added to itself, and the generated MLIR is compared
    with a reference, which differs depending on whether the configuration
    uses single precision.
    """

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        rounded = fhe.round_bit_pattern(x, lsbs_to_remove=2, overflow_protection=False)
        return rounded + rounded

    inputset = range(-20, 20)
    configuration = helpers.configuration()
    circuit = function.compile(inputset, configuration)

    if configuration.single_precision:
        expected_mlir = """
module {
func.func @function(%arg0: !FHE.esint<7>) -> !FHE.esint<7> {
%c2_i8 = arith.constant 2 : i8
%0 = "FHE.mul_eint_int"(%arg0, %c2_i8) : (!FHE.esint<7>, i8) -> !FHE.esint<7>
%1 = "FHE.round"(%0) : (!FHE.esint<7>) -> !FHE.esint<4>
%cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, -32, -28, -24, -20, -16, -12, -8, -4]> : tensor<16xi64>
%2 = "FHE.apply_lookup_table"(%1, %cst) : (!FHE.esint<4>, tensor<16xi64>) -> !FHE.esint<7>
%3 = "FHE.add_eint"(%2, %2) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
return %3 : !FHE.esint<7>
}
}
"""  # noqa: E501
    else:
        expected_mlir = """
module {
func.func @function(%arg0: !FHE.esint<6>) -> !FHE.esint<7> {
%0 = "FHE.round"(%arg0) : (!FHE.esint<6>) -> !FHE.esint<4>
%cst = arith.constant dense<[0, 4, 8, 12, 16, 20, 24, 28, -32, -28, -24, -20, -16, -12, -8, -4]> : tensor<16xi64>
%1 = "FHE.apply_lookup_table"(%0, %cst) : (!FHE.esint<4>, tensor<16xi64>) -> !FHE.esint<7>
%2 = "FHE.add_eint"(%1, %1) : (!FHE.esint<7>, !FHE.esint<7>) -> !FHE.esint<7>
return %2 : !FHE.esint<7>
}
}
"""  # noqa: E501

    helpers.check_str(expected_mlir, circuit.mlir)
def test_auto_rounding(helpers):
    """
    Test round bit pattern driven by auto rounders.

    Three scenarios are covered: adjustment requested through
    `auto_adjust_rounders=True` while tracing, adjustment requested manually
    through `fhe.AutoRounder.adjust`, and two chained rounders adjusted during
    compilation, whose resulting graph is compared with a reference.
    """

    # scenario 1: adjustment through the `auto_adjust_rounders` option
    # ----------------------------------------------------------------
    # y goes up to 1999, which takes 11 bits,
    # so keeping 5 msbs means removing 6 lsbs

    rounder1 = fhe.AutoRounder(target_msbs=5)

    @fhe.compiler({"x": "encrypted"})
    def function1(x):
        shifted = x + 1000
        rounded = fhe.round_bit_pattern(shifted, lsbs_to_remove=rounder1)
        return np.sqrt(rounded).astype(np.int64)

    inputset1 = range(1000)
    function1.trace(inputset1, helpers.configuration(), auto_adjust_rounders=True)

    assert rounder1.lsbs_to_remove == 6

    # scenario 2: manual adjustment
    # -----------------------------
    # y goes up to 1999, which takes 11 bits,
    # so keeping 3 msbs means removing 8 lsbs

    rounder2 = fhe.AutoRounder(target_msbs=3)

    @fhe.compiler({"x": "encrypted"})
    def function2(x):
        shifted = x + 1000
        rounded = fhe.round_bit_pattern(shifted, lsbs_to_remove=rounder2)
        return np.sqrt(rounded).astype(np.int64)

    inputset2 = range(1000)
    fhe.AutoRounder.adjust(function2, inputset2)

    assert rounder2.lsbs_to_remove == 8

    # scenario 3: two chained rounders
    # --------------------------------

    table_size = 2**8

    # the table keeps 2 ** 8 entries during evaluation, it won't matter after compilation
    entries3 = list(range(table_size))
    # inputs of this table are 8-bit and only their first 5 bits are meant to be used,
    # so every 8th entry is overwritten with a 4-bit value
    for position in range(0, table_size, 2**3):
        entries3[position] = np.random.randint(0, (2**4) - (2**2))
    # applying this table to an 8-bit value rounded to 5 msbs gives a 4-bit result
    table3 = fhe.LookupTable(entries3)
    # rounder used in front of table3, expected to end up with lsbs_to_remove of 3
    rounder3 = fhe.AutoRounder(target_msbs=5)

    # the table keeps 2 ** 8 entries during evaluation, it won't matter after compilation
    entries4 = list(range(table_size))
    # inputs of this table are 4-bit and only their first 2 bits are meant to be used,
    # so every 4th entry among the first 2 ** 4 is overwritten with an 8-bit value
    for position in range(0, 2**4, 2**2):
        entries4[position] = np.random.randint(2**7, 2**8)
    # applying this table to a 4-bit value rounded to 2 msbs gives an 8-bit result
    table4 = fhe.LookupTable(entries4)
    # rounder used in front of table4, expected to end up with lsbs_to_remove of 2
    rounder4 = fhe.AutoRounder(target_msbs=2)

    @fhe.compiler({"x": "encrypted"})
    def function3(x):
        first_rounded = fhe.round_bit_pattern(x, lsbs_to_remove=rounder3)
        first_lookup = table3[first_rounded]
        second_rounded = fhe.round_bit_pattern(first_lookup, lsbs_to_remove=rounder4)
        return table4[second_rounded]

    inputset3 = range((2**8) - (2**3))
    circuit3 = function3.compile(
        inputset3,
        helpers.configuration(),
        auto_adjust_rounders=True,
    )

    assert rounder3.lsbs_to_remove == 3
    assert rounder4.lsbs_to_remove == 2

    table3_formatted_string = format_constant(table3.table, 25)
    table4_formatted_string = format_constant(table4.table, 25)

    expected_graph = f"""
%0 = x # EncryptedScalar<uint8>
%1 = round_bit_pattern(%0, lsbs_to_remove=3, overflow_protection=True, exactness=None) # EncryptedScalar<uint8>
%2 = tlu(%1, table={table3_formatted_string}) # EncryptedScalar<uint4>
%3 = round_bit_pattern(%2, lsbs_to_remove=2, overflow_protection=True, exactness=None) # EncryptedScalar<uint4>
%4 = tlu(%3, table={table4_formatted_string}) # EncryptedScalar<uint8>
return %4
"""  # noqa: E501
    actual_graph = str(circuit3.graph.format(show_bounds=False))

    helpers.check_str(expected_graph, actual_graph)
def test_auto_rounding_without_adjustment():
    """
    Test that using an auto rounder which was never adjusted is rejected.

    Calling the function directly must raise a `RuntimeError` with the exact
    message checked below.
    """

    rounder = fhe.AutoRounder(target_msbs=5)

    def function(x):
        shifted = x + 1000
        rounded = fhe.round_bit_pattern(shifted, lsbs_to_remove=rounder)
        return np.sqrt(rounded).astype(np.int64)

    expected_message = (
        "AutoRounders cannot be used before adjustment, "
        "please call AutoRounder.adjust with the function that will be compiled "
        "and provide the exact inputset that will be used for compilation"
    )

    with pytest.raises(RuntimeError) as raised:
        function(100)

    assert str(raised.value) == expected_message
def test_auto_rounding_with_empty_inputset():
    """
    Test that adjusting auto rounders with an empty inputset is rejected.

    `fhe.AutoRounder.adjust` must raise a `ValueError` with the exact message
    checked below.
    """

    rounder = fhe.AutoRounder(target_msbs=5)

    def function(x):
        shifted = x + 1000
        rounded = fhe.round_bit_pattern(shifted, lsbs_to_remove=rounder)
        return np.sqrt(rounded).astype(np.int64)

    empty_inputset = []
    with pytest.raises(ValueError) as raised:
        fhe.AutoRounder.adjust(function, empty_inputset)

    actual_message = str(raised.value)
    assert actual_message == "AutoRounders cannot be adjusted with an empty inputset"
def test_auto_rounding_recursive_adjustment():
    """
    Test that adjusting auto rounders recursively is rejected.

    The function under adjustment requests an adjustment of itself, and the
    outer `fhe.AutoRounder.adjust` call must raise a `RuntimeError` with the
    exact message checked below.
    """

    rounder = fhe.AutoRounder(target_msbs=5)

    def function(x):
        # adjustment requested from within the function being adjusted
        fhe.AutoRounder.adjust(function, range(10))
        shifted = x + 1000
        rounded = fhe.round_bit_pattern(shifted, lsbs_to_remove=rounder)
        return np.sqrt(rounded).astype(np.int64)

    with pytest.raises(RuntimeError) as raised:
        fhe.AutoRounder.adjust(function, range(10))

    actual_message = str(raised.value)
    assert actual_message == "AutoRounders cannot be adjusted recursively"
def test_auto_rounding_construct_in_function():
    """
    Test that constructing an auto rounder inside the adjusted function is rejected.

    `fhe.AutoRounder.adjust` must raise a `RuntimeError` with the exact message
    checked below.
    """

    def function(x):
        shifted = x + 1000
        # the rounder is created here, on every call, instead of outside the function
        rounded = fhe.round_bit_pattern(shifted, lsbs_to_remove=fhe.AutoRounder(target_msbs=5))
        return np.sqrt(rounded).astype(np.int64)

    expected_message = (
        "AutoRounders cannot be constructed during adjustment, "
        "please construct AutoRounders outside the function and reference it"
    )

    with pytest.raises(RuntimeError) as raised:
        fhe.AutoRounder.adjust(function, range(10))

    assert str(raised.value) == expected_message
def test_overflowing_round_bit_pattern_with_lsbs_to_remove_of_one(helpers):
    """
    Test round bit pattern where overflow is detected when only one bit is to be removed.

    The circuit multiplies the input by a constant matrix, rounds the product
    with `lsbs_to_remove=1`, feeds it through a float subgraph ending in int64,
    and multiplies by a second constant matrix. It is compiled and executed on
    a single sample.
    """

    # pylint: disable=invalid-name

    def subgraph(v0):
        as_float = v0.astype(np.float32)
        shifted = np.add(as_float, 0)
        centered = np.subtract(shifted, np.array([[3, -4, -9, -12, 5]]))
        scaled = np.multiply(0.8477082011119198, centered)
        biased = np.add(
            scaled,
            [-0.16787443, -0.4266992, 0.33739513, 0.15412766, -0.39808342],
        )

        # gate is `0.16666666666666666 * biased + 0.5` limited to [0, 1]
        gate = np.multiply(0.16666666666666666, biased)
        gate = np.add(gate, 0.5)
        gate = np.minimum(1, gate)
        gate = np.maximum(0, gate)
        gated = np.multiply(biased, gate)

        # rescale, offset, round and clip into [-8, 7] before going back to integers
        rescaled = np.divide(gated, 7.667494564113974)
        offset = np.add(rescaled, -8)
        rounded = np.rint(offset)
        clipped = np.clip(rounded, -8, 7)
        return clipped.astype(np.int64)

    def function(v0):
        first_weights = np.array(
            [
                [2, -5, 1, 6, -4],
                [2, 7, 6, -3, 5],
                [0, -3, 4, -2, -5],
                [-5, 5, -4, 7, 3],
                [-2, 0, 2, 4, -4],
            ]
        )
        first_product = np.matmul(v0, first_weights)

        first_rounded = fhe.round_bit_pattern(first_product, lsbs_to_remove=1)
        activated = subgraph(first_rounded)

        second_weights = np.array(
            [
                [2, -3, 5, 2, -5],
                [-6, -1, -2, 7, 4],
                [-6, 6, -6, 3, 3],
                [-7, -1, -5, 1, 1],
                [-3, -1, 5, -1, -5],
            ]
        )
        return np.matmul(activated, second_weights)

    # pylint: enable=invalid-name

    sample = np.array([[-7, 5, -8, -8, -8]])

    compiler = fhe.Compiler(function, {"v0": "encrypted"})
    circuit = compiler.compile(inputset=[sample], configuration=helpers.configuration())

    helpers.check_execution(circuit, function, sample, retries=5)
def test_round_bit_pattern_overflow_to_sign_bit(helpers):
    """
    Test round bit pattern where it's applied to a signed value and p-1 bits are removed.

    The rounded operand is `x - 4` for `x` in `range(10)`, and the rounded
    result is divided by `2**3` before being returned.
    """

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        rounded = fhe.round_bit_pattern(x - 4, lsbs_to_remove=3)
        return rounded // (2**3)

    inputset = range(10)
    circuit = function.compile(inputset, helpers.configuration())

    # keys are generated once, before the execution checks
    circuit.keys.generate()

    for sample in inputset:
        helpers.check_execution(circuit, function, sample, retries=3)
def test_round_bit_pattern_approximate_enabling(helpers):
    """
    Test the different ways of selecting exact or approximate rounding.

    Exactness can be given on the `round_bit_pattern` call or through the
    `rounding_exactness` configuration option. Circuit complexities are compared
    to check which setting took effect in each combination.
    """

    @fhe.compiler({"x": "encrypted"})
    def function_default(x):
        return fhe.round_bit_pattern(x, lsbs_to_remove=8)

    @fhe.compiler({"x": "encrypted"})
    def function_exact(x):
        return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.EXACT)

    @fhe.compiler({"x": "encrypted"})
    def function_approx(x):
        return fhe.round_bit_pattern(x, lsbs_to_remove=8, exactness=Exactness.APPROXIMATE)

    inputset = [-(2**10), 2**10 - 1]
    configuration = helpers.configuration()

    def with_exactness(exactness):
        # a fresh fork of the base configuration on every call
        return configuration.fork(rounding_exactness=exactness)

    # call left at default exactness, under each configuration setting
    circuit_default_default = function_default.compile(inputset, configuration)
    circuit_default_exact = function_default.compile(
        inputset,
        with_exactness(Exactness.EXACT),
    )
    circuit_default_approx = function_default.compile(
        inputset,
        with_exactness(Exactness.APPROXIMATE),
    )

    # call with explicit exactness, under the opposite configuration setting
    circuit_exact = function_exact.compile(
        inputset,
        with_exactness(Exactness.APPROXIMATE),
    )
    circuit_approx = function_approx.compile(
        inputset,
        with_exactness(Exactness.EXACT),
    )

    exact_complexity = circuit_exact.complexity
    approx_complexity = circuit_approx.complexity

    assert approx_complexity < exact_complexity
    assert exact_complexity == circuit_default_default.complexity
    assert exact_complexity == circuit_default_exact.complexity
    assert approx_complexity == circuit_default_approx.complexity
@pytest.mark.parametrize(
    "accumulator_precision,reduced_precision,signed,conf",
    [
        (8, 4, True, fhe.ApproximateRoundingConfig(False, 4)),
        (7, 4, False, fhe.ApproximateRoundingConfig(False, 4)),
        (9, 3, True, fhe.ApproximateRoundingConfig(True, False)),
        (8, 3, False, fhe.ApproximateRoundingConfig(True, False)),
        (7, 3, False, fhe.ApproximateRoundingConfig(True, 3)),
        (7, 2, True, fhe.ApproximateRoundingConfig(False, 2)),
        (7, 2, False, fhe.ApproximateRoundingConfig(False, False, False, False)),
        (8, 1, True, fhe.ApproximateRoundingConfig(False, 1)),
        (8, 1, False, fhe.ApproximateRoundingConfig(True, False)),
        (6, 5, False, fhe.ApproximateRoundingConfig(True, 6)),
        (6, 5, False, fhe.ApproximateRoundingConfig(True, 5)),
    ],
)
def test_round_bit_pattern_approximate_off_by_one_errors(
    accumulator_precision, reduced_precision, signed, conf, helpers
):
    """
    Test that approximate rounding is off by at most 1 compared to exact rounding.

    The same function is compiled with the default configuration and with
    approximate rounding using `conf`. The approximate circuit must be cheaper,
    and both its encrypted and simulated results must be within 1 of the exact
    simulation on the tested samples.
    """

    lsbs_to_remove = accumulator_precision - reduced_precision

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        x = fhe.univariate(lambda x: x)(x)
        x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
        x = x // 2**lsbs_to_remove
        return x

    unit = 2**lsbs_to_remove
    if signed:
        half_range = 2 ** (accumulator_precision - 1)
        inputset = [-half_range, half_range - 1]
        # avoiding overflows: the top sample stops one unit below the range end
        testset = [-half_range, half_range - unit]
    else:
        full_range = 2**accumulator_precision
        inputset = [0, full_range - 1]
        # avoiding overflows: the top sample stops one unit below the range end
        testset = [0, full_range - unit]

    configuration = helpers.configuration()
    circuit_exact = function.compile(inputset, configuration)
    circuit_approx = function.compile(
        inputset,
        configuration.fork(
            approximate_rounding_config=conf,
            rounding_exactness=Exactness.APPROXIMATE,
        ),
    )

    # check it's better even with bad conf
    assert circuit_approx.complexity < circuit_exact.complexity

    nb_error = 0
    for sample in testset:
        approx = circuit_approx.encrypt_run_decrypt(sample)
        approx_simu = circuit_approx.simulate(sample)
        exact = circuit_exact.simulate(sample)

        assert abs(approx_simu - exact) <= 1

        delta = abs(approx - exact)
        assert delta <= 1

        nb_error += delta > 0

    # `unit` equals 2 ** (accumulator_precision - reduced_precision)
    nb_transitions = unit
    assert nb_error <= 3 * nb_transitions  # of the same order as transitions but small sample size
@pytest.mark.parametrize(
    "signed,physical",
    [
        (True, True),
        (True, False),
        (False, True),
        (False, False),
    ],
)
def test_round_bit_pattern_approximate_clippping(signed, physical, helpers):
    """
    Test clipping of approximate round bit pattern.

    With clipping enabled (logical clipping when `physical` is false, otherwise
    approximate clipping starting at the reduced precision), every output must
    stay inside the reduced-precision domain. Without any clipping, the last
    value of the input domain must fall outside of that domain in simulation.
    """

    accumulator_precision = 6
    reduced_precision = 3
    lsbs_to_remove = accumulator_precision - reduced_precision

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        x = fhe.univariate(lambda x: x)(x)
        x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove)
        x = x // 2**lsbs_to_remove
        return x

    if signed:
        input_half = 2 ** (accumulator_precision - 1)
        input_domain = range(-input_half, input_half)
    else:
        input_domain = range(0, 2 ** (accumulator_precision))

    configuration = helpers.configuration()

    approx_conf = fhe.ApproximateRoundingConfig(
        logical_clipping=not physical,
        approximate_clipping_start_precision=reduced_precision if physical else False,
        reduce_precision_after_approximate_clipping=False,
    )
    no_clipping_conf = fhe.ApproximateRoundingConfig(
        logical_clipping=False,
        approximate_clipping_start_precision=False,
    )

    # at least one of the two clipping mechanisms has to be active
    assert approx_conf.logical_clipping or approx_conf.approximate_clipping_start_precision

    def compile_with(rounding_config):
        return function.compile(
            input_domain,
            configuration.fork(
                approximate_rounding_config=rounding_config,
                rounding_exactness=Exactness.APPROXIMATE,
            ),
        )

    circuit_clipping = compile_with(approx_conf)
    circuit_no_clipping = compile_with(no_clipping_conf)

    if signed:
        output_half = 2 ** (reduced_precision - 1)
        clipped_output_domain = range(-output_half, output_half)
    else:
        clipped_output_domain = range(0, 2**reduced_precision)

    # With clipping
    for sample in input_domain:
        assert (
            circuit_clipping.encrypt_run_decrypt(sample) in clipped_output_domain
        ), circuit_clipping.mlir  # no overflow
        assert circuit_clipping.simulate(sample) in clipped_output_domain

    # Without clipping
    # overflow
    last_sample = input_domain[-1]
    assert circuit_no_clipping.simulate(last_sample) not in clipped_output_domain
@pytest.mark.parametrize(
    "signed,accumulator_precision",
    [
        (True, 13),
        (True, 24),
        (False, 13),
        (False, 24),
    ],
)
def test_round_bit_pattern_approximate_acc_to_6_costs(signed, accumulator_precision, helpers):
    """
    Test round bit pattern speedup when approximation is activated.

    The accumulator is rounded down to 6 bits. Complexities of the exact
    circuit and of two approximate circuits (clipping start precision of 6 and
    of 100) are compared, first with an input domain covering the whole
    accumulator range and overflow protection on, then with a narrower domain
    and overflow protection off.
    """

    reduced_precision = 6
    lsbs_to_remove = accumulator_precision - reduced_precision

    configuration = helpers.configuration().fork(
        single_precision=False,
        parameter_selection_strategy=fhe.ParameterSelectionStrategy.MULTI,
        composable=True,
    )

    def compile_approximate(target, domain, approx_conf):
        # compile `target` with approximate rounding driven by `approx_conf`
        return target.compile(
            domain,
            configuration.fork(
                approximate_rounding_config=approx_conf,
                rounding_exactness=Exactness.APPROXIMATE,
            ),
        )

    @fhe.compiler({"x": "encrypted"})
    def function(x):
        x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove, overflow_protection=True)
        x = x // 2**lsbs_to_remove
        return x

    # with overflow
    if signed:
        half_range = 2 ** (accumulator_precision - 1)
        input_domain = [-half_range, half_range - 1]
    else:
        input_domain = [0, 2 ** (accumulator_precision) - 1]

    circuit_exact = function.compile(input_domain, configuration)

    approx_conf_fastest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=6)
    approx_conf_safest = fhe.ApproximateRoundingConfig(approximate_clipping_start_precision=100)

    circuit_approx_fastest = compile_approximate(function, input_domain, approx_conf_fastest)
    circuit_approx_safest = compile_approximate(function, input_domain, approx_conf_safest)

    assert circuit_approx_safest.complexity < circuit_exact.complexity
    assert circuit_approx_fastest.complexity < circuit_approx_safest.complexity

    @fhe.compiler({"x": "encrypted"})
    def function(x):  # pylint: disable=function-redefined
        x = fhe.round_bit_pattern(x, lsbs_to_remove=lsbs_to_remove, overflow_protection=False)
        x = x // 2**lsbs_to_remove
        return x

    # without overflow
    if signed:
        input_domain = [
            -(2 ** (accumulator_precision - 1)),
            2 ** (accumulator_precision - 2) - 2,
        ]
    else:
        input_domain = [0, 2 ** (accumulator_precision - 1) - 2]

    circuit_exact_no_ovf = function.compile(input_domain, configuration)
    circuit_approx_fastest_no_ovf = compile_approximate(
        function, input_domain, approx_conf_fastest
    )
    circuit_approx_safest_no_ovf = compile_approximate(
        function, input_domain, approx_conf_safest
    )

    # with this domain, the clipping start precision makes no difference
    assert circuit_approx_fastest_no_ovf.complexity == circuit_approx_safest_no_ovf.complexity
    assert circuit_approx_safest_no_ovf.complexity < circuit_exact_no_ovf.complexity
    assert circuit_exact_no_ovf.complexity < circuit_exact.complexity