Files
concrete/frontends/concrete-python/tests/compilation/test_circuit.py

841 lines
25 KiB
Python

"""
Tests of `Circuit` class.
"""
import tempfile
from pathlib import Path
import numpy as np
import pytest
from concrete import fhe
from concrete.fhe import Client, ClientSpecs, EvaluationKeys, LookupTable, Server, Value
def test_circuit_str(helpers):
"""
Test `__str__` method of `Circuit` class.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)]
circuit = f.compile(inputset, configuration.fork(p_error=6e-5))
assert str(circuit) == circuit.graph.format()
@pytest.mark.graphviz
def test_circuit_draw(helpers):
"""
Test `draw` method of `Circuit` class.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return (x**2) * y + 2
inputset = [
(
np.random.randint(0, 2**4, size=(2,)),
np.random.randint(0, 2**5, size=()),
)
for _ in range(100)
]
circuit = f.compile(inputset, configuration)
drawing = circuit.draw()
assert drawing.suffix == ".png"
assert drawing.exists()
with tempfile.TemporaryDirectory() as path:
tmpdir = Path(path)
png = tmpdir / "drawing.png"
circuit.draw(save_to=png)
assert png.exists()
def test_circuit_feedback(helpers):
"""
Test feedback properties of `Circuit` class.
"""
configuration = helpers.configuration()
p_error = 0.1
global_p_error = 0.05
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return np.sqrt(((x + y) ** 2) + 10).astype(np.int64)
inputset = [(np.random.randint(0, 2**2), np.random.randint(0, 2**2)) for _ in range(100)]
circuit = f.compile(inputset, configuration, p_error=p_error, global_p_error=global_p_error)
assert isinstance(circuit.complexity, float)
assert isinstance(circuit.size_of_secret_keys, int)
assert isinstance(circuit.size_of_bootstrap_keys, int)
assert isinstance(circuit.size_of_keyswitch_keys, int)
assert isinstance(circuit.size_of_inputs, int)
assert isinstance(circuit.size_of_outputs, int)
assert isinstance(circuit.p_error, float)
assert isinstance(circuit.global_p_error, float)
assert isinstance(circuit.memory_usage_per_location, dict)
assert all(
isinstance(key, str) and isinstance(value, int)
for key, value in circuit.memory_usage_per_location.items()
)
assert circuit.p_error <= p_error
assert circuit.global_p_error <= global_p_error
def test_circuit_bad_run(helpers):
"""
Test `run` method of `Circuit` class with bad parameters.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(100)]
circuit = f.compile(inputset, configuration)
# with 1 argument
# ---------------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(1)
assert str(excinfo.value) == "Expected 2 inputs but got 1"
# with 3 arguments
# ----------------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(1, 2, 3)
assert str(excinfo.value) == "Expected 2 inputs but got 3"
# with negative argument 0
# ------------------------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(-1, 11)
assert str(excinfo.value) == (
"Expected argument 0 to be EncryptedScalar<uint6> but it's EncryptedScalar<int1>"
)
# with negative argument 1
# ------------------------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(1, -11)
assert str(excinfo.value) == (
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<int5>"
)
# with large argument 0
# ---------------------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(100, 10)
assert str(excinfo.value) == (
"Expected argument 0 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
)
# with large argument 1
# ---------------------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(1, 100)
assert str(excinfo.value) == (
"Expected argument 1 to be EncryptedScalar<uint6> but it's EncryptedScalar<uint7>"
)
# with None
# ---------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt(None, 10)
assert str(excinfo.value) == "Expected argument 0 to be an fhe.Value but it's None"
# with non Value
# --------------
with pytest.raises(ValueError) as excinfo:
_, b = circuit.encrypt(None, 10)
circuit.run({"yes": "no"}, b)
assert str(excinfo.value) == "Expected argument 0 to be an fhe.Value but it's dict"
# with invalid argument
# ---------------------
with pytest.raises(ValueError) as excinfo:
circuit.encrypt_run_decrypt({"yes": "no"}, 10)
assert str(excinfo.value) == "Expected argument 0 to be EncryptedScalar<uint6> but it's dict"
def test_circuit_separate_args(helpers):
"""
Test running circuit with separately encrypted args.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def function(x, y):
return x + y
inputset = [
(
np.random.randint(0, 10, size=()),
np.random.randint(0, 10, size=(3,)),
)
for _ in range(10)
]
circuit = function.compile(inputset, configuration)
x = 4
y = [1, 2, 3]
x_encrypted, _ = circuit.encrypt(x, None)
_, y_encrypted = circuit.encrypt(None, y)
x_plus_y_encrypted = circuit.run(x_encrypted, y_encrypted)
x_plus_y = circuit.decrypt(x_plus_y_encrypted)
assert np.array_equal(x_plus_y, x + np.array(y))
def test_client_server_api(helpers):
"""
Test client/server API.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = [np.random.randint(0, 10, size=(3,)) for _ in range(10)]
circuit = function.compile(inputset, configuration.fork())
# for coverage
circuit.keygen()
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
server_path = tmp_dir_path / "server.zip"
circuit.server.save(server_path)
client_path = tmp_dir_path / "client.zip"
circuit.client.save(client_path)
circuit.cleanup()
server = Server.load(server_path)
serialized_client_specs = server.client_specs.serialize()
client_specs = ClientSpecs.deserialize(serialized_client_specs)
clients = [
Client(client_specs, configuration.insecure_key_cache_location),
Client.load(client_path, configuration.insecure_key_cache_location),
]
for client in clients:
arg = client.encrypt([3, 8, 1])
serialized_arg = arg.serialize()
serialized_evaluation_keys = client.evaluation_keys.serialize()
deserialized_arg = Value.deserialize(serialized_arg)
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
serialized_result = result.serialize()
deserialized_result = Value.deserialize(serialized_result)
output = client.decrypt(deserialized_result)
assert np.array_equal(output, [45, 50, 43])
with pytest.raises(RuntimeError) as excinfo:
server.save("UNUSED", via_mlir=True)
assert str(excinfo.value) == "Loaded server objects cannot be saved again via MLIR"
with pytest.raises(ValueError) as excinfo:
client.encrypt([1, 2, 3], function_name="foo")
assert str(excinfo.value) == "Function `foo` is not in the module"
server.cleanup()
def test_client_server_api_run_with_clear(helpers):
"""
Test running server run API with a clear input.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "clear", "z": "clear"})
def function(x, y, z):
return x + y + z
inputset = fhe.inputset(fhe.uint3, fhe.uint3, fhe.tensor[fhe.uint3, 2, 2]) # type: ignore
circuit = function.compile(inputset, configuration.fork())
client = circuit.client
server = circuit.server
x, y, z = 3, 2, [[1, 2], [3, 4]]
encrypted_x, _, _ = client.encrypt(x, None, None)
encrypted_result = server.run(encrypted_x, y, z, evaluation_keys=client.evaluation_keys)
result = client.decrypt(encrypted_result)
assert np.array_equal(result, x + y + np.array(z))
with pytest.raises(ValueError) as excinfo:
server.run(1, 2, 3, evaluation_keys=client.evaluation_keys)
assert str(excinfo.value) == "Expected argument 0 to be an fhe.Value but it's int"
with pytest.raises(RuntimeError) as excinfo:
server.run(encrypted_x, [2, 2], 3, evaluation_keys=client.evaluation_keys)
assert str(excinfo.value) == "Tried to transform plaintext value with incompatible shape."
def test_client_server_api_crt(helpers):
"""
Test client/server API on a CRT circuit.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted"})
def function(x):
return x**2
inputset = [np.random.randint(0, 200, size=(3,)) for _ in range(10)]
circuit = function.compile(inputset, configuration.fork())
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
server_path = tmp_dir_path / "server.zip"
circuit.server.save(server_path)
client_path = tmp_dir_path / "client.zip"
circuit.client.save(client_path)
server = Server.load(server_path)
serialized_client_specs = server.client_specs.serialize()
client_specs = ClientSpecs.deserialize(serialized_client_specs)
clients = [
Client(client_specs, configuration.insecure_key_cache_location),
Client.load(client_path, configuration.insecure_key_cache_location),
]
for client in clients:
arg = client.encrypt([100, 150, 10])
serialized_arg = arg.serialize()
serialized_evaluation_keys = client.evaluation_keys.serialize()
deserialized_arg = Value.deserialize(serialized_arg)
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
serialized_result = result.serialize()
deserialized_result = Value.deserialize(serialized_result)
output = client.decrypt(deserialized_result)
assert np.array_equal(output, [100**2, 150**2, 10**2])
def test_client_server_api_via_mlir(helpers):
"""
Test client/server API.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted"})
def function(x):
return x + 42
inputset = [np.random.randint(0, 10, size=(3,)) for _ in range(10)]
circuit = function.compile(inputset, configuration.fork())
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
server_path = tmp_dir_path / "server.zip"
circuit.server.save(server_path, via_mlir=True)
client_path = tmp_dir_path / "client.zip"
circuit.client.save(client_path)
circuit.cleanup()
server = Server.load(server_path)
serialized_client_specs = server.client_specs.serialize()
client_specs = ClientSpecs.deserialize(serialized_client_specs)
clients = [
Client(client_specs, configuration.insecure_key_cache_location),
Client.load(client_path, configuration.insecure_key_cache_location),
]
for client in clients:
arg = client.encrypt([3, 8, 1])
serialized_arg = arg.serialize()
serialized_evaluation_keys = client.evaluation_keys.serialize()
deserialized_arg = Value.deserialize(serialized_arg)
deserialized_evaluation_keys = EvaluationKeys.deserialize(serialized_evaluation_keys)
result = server.run(deserialized_arg, evaluation_keys=deserialized_evaluation_keys)
serialized_result = result.serialize()
deserialized_result = Value.deserialize(serialized_result)
output = client.decrypt(deserialized_result)
assert np.array_equal(output, [45, 50, 43])
server.cleanup()
def test_server_loading_via_mlir_kwargs(helpers):
"""
Test server loading via MLIR with kwarg overrides.
"""
configuration = helpers.configuration().fork(global_p_error=None, p_error=0.001)
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def function(x, y):
return x == y
inputset = fhe.inputset(fhe.uint4, fhe.uint4)
circuit = function.compile(inputset, configuration)
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_dir_path = Path(tmp_dir)
server_path = tmp_dir_path / "server.zip"
circuit.server.save(server_path, via_mlir=True)
server = Server.load(server_path, p_error=0.05)
assert server.complexity < circuit.complexity
def test_circuit_run_with_unused_arg(helpers):
"""
Test `encrypt_run_decrypt` method of `Circuit` class with unused arguments.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y): # pylint: disable=unused-argument
return x + 10
inputset = [(np.random.randint(2**3, 2**4), np.random.randint(2**4, 2**5)) for _ in range(100)]
circuit = f.compile(inputset, configuration)
with pytest.raises(ValueError, match="Expected 2 inputs but got 1"):
circuit.encrypt_run_decrypt(10)
assert circuit.encrypt_run_decrypt(10, 0) == 20
assert circuit.encrypt_run_decrypt(10, 10) == 20
assert circuit.encrypt_run_decrypt(10, 20) == 20
@pytest.mark.dataflow
def test_dataflow_circuit(helpers):
"""
Test execution with dataflow_parallelize=True.
"""
configuration = helpers.configuration().fork(dataflow_parallelize=True)
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return (x**2) + (y // 2)
inputset = [(np.random.randint(0, 2**3), np.random.randint(0, 2**3)) for _ in range(100)]
circuit = f.compile(inputset, configuration)
assert circuit.encrypt_run_decrypt(5, 6) == 28
def test_circuit_sim_disabled(helpers):
"""
Test attempt to simulate without enabling fhe simulation.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration)
assert circuit.simulate(*inputset[0]) == f(*inputset[0])
def test_circuit_fhe_exec_disabled(helpers):
"""
Test attempt to run fhe execution without it being enabled.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration.fork(fhe_execution=False))
assert circuit.encrypt_run_decrypt(*inputset[0]) == f(*inputset[0])
def test_circuit_fhe_exec_no_eval_keys(helpers):
"""
Test attempt to run fhe execution without eval keys.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return x + y
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration)
with pytest.raises(RuntimeError) as excinfo:
# as we can't encrypt, we just pass plain inputs, and it should lead to the expected error
encrypted_args = inputset[0]
circuit.server.run(*encrypted_args)
assert (
str(excinfo.value) == "Expected evaluation keys to be provided when not in simulation mode"
)
def test_circuit_eval_graph_scalar(helpers):
"""
Test evaluation of the graph.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration.fork(fhe_simulation=False, fhe_execution=False))
assert f(*inputset[0]) == circuit.graph(*inputset[0])
def test_circuit_eval_graph_tensor(helpers):
"""
Test evaluation of the graph.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
inputset = [
(
np.random.randint(0, 2**4, size=[2, 2]),
np.random.randint(0, 2**5, size=[2, 2]),
)
for _ in range(2)
]
circuit = f.compile(inputset, configuration.fork(fhe_simulation=False, fhe_execution=False))
assert np.all(f(*inputset[0]) == circuit.graph(*inputset[0]))
def test_circuit_compile_sim_only(helpers):
"""
Test compiling with simulation only.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
lut = LookupTable(list(range(128)))
return lut[x + y]
inputset = [(np.random.randint(0, 2**4), np.random.randint(0, 2**5)) for _ in range(2)]
circuit = f.compile(inputset, configuration.fork(fhe_simulation=True, fhe_execution=False))
assert f(*inputset[0]) == circuit.simulate(*inputset[0])
def tagged_function(x, y, z):
"""
A tagged function to test statistics.
"""
with fhe.tag("a"):
x = fhe.univariate(lambda v: v)(x)
with fhe.tag("b"):
y = fhe.univariate(lambda v: v)(y)
with fhe.tag("c"):
z = fhe.univariate(lambda v: v)(z)
return x + y + z
@pytest.mark.parametrize(
"function,parameters,expected_statistics",
[
pytest.param(
lambda x: x**2,
{
"x": {"status": "encrypted", "range": [0, 10], "shape": ()},
},
{
"programmable_bootstrap_count": 1,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == ()",
),
pytest.param(
lambda x: x**2,
{
"x": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
},
{
"programmable_bootstrap_count": 3,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == (3,)",
),
pytest.param(
lambda x: x**2,
{
"x": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
},
{
"programmable_bootstrap_count": 3 * 2,
"clear_addition_count": 0,
"encrypted_addition_count": 0,
"clear_multiplication_count": 0,
"encrypted_negation_count": 0,
},
id="x**2 | x.is_encrypted | x.shape == (3, 2)",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"status": "encrypted", "range": [0, 10], "shape": ()},
"y": {"status": "encrypted", "range": [0, 10], "shape": ()},
},
{
"programmable_bootstrap_count": 2,
"clear_addition_count": 1,
"encrypted_addition_count": 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 2,
},
id="x * y | x.is_encrypted | x.shape == () | y.is_encrypted | y.shape == ()",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
"y": {"status": "encrypted", "range": [0, 10], "shape": (3,)},
},
{
"programmable_bootstrap_count": 3 * 2,
"clear_addition_count": 3 * 1,
"encrypted_addition_count": 3 * 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 3 * 2,
},
id="x * y | x.is_encrypted | x.shape == (3,) | y.is_encrypted | y.shape == (3,)",
),
pytest.param(
lambda x, y: x * y,
{
"x": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
"y": {"status": "encrypted", "range": [0, 10], "shape": (3, 2)},
},
{
"programmable_bootstrap_count": 3 * 2 * 2,
"clear_addition_count": 3 * 2 * 1,
"encrypted_addition_count": 3 * 2 * 3,
"clear_multiplication_count": 0,
"encrypted_negation_count": 3 * 2 * 2,
},
id="x * y | x.is_encrypted | x.shape == (3, 2) | y.is_encrypted | y.shape == (3, 2)",
),
pytest.param(
tagged_function,
{
"x": {"status": "encrypted", "range": [0, 2**3 - 1], "shape": ()},
"y": {"status": "encrypted", "range": [0, 2**4 - 1], "shape": ()},
"z": {"status": "encrypted", "range": [0, 2**5 - 1], "shape": ()},
},
{
"programmable_bootstrap_count_per_tag": {
"a": 3,
"a.b": 2,
"a.b.c": 1,
},
},
id="tagged_function",
),
],
)
def test_statistics(function, parameters, expected_statistics, helpers):
"""
Test statistics of the circuit provided by the compiler.
"""
parameter_encryption_statuses = helpers.generate_encryption_statuses(parameters)
configuration = helpers.configuration()
compiler = fhe.Compiler(function, parameter_encryption_statuses)
inputset = helpers.generate_inputset(parameters)
circuit = compiler.compile(inputset, configuration)
for name, expected_value in expected_statistics.items():
assert hasattr(circuit, name)
attr = getattr(circuit, name)
if callable(attr):
attr = attr()
assert (
attr == expected_value
), f"""
Expected {name} to be {expected_value} but it's {getattr(circuit, name)}
""".strip()
def test_setting_keys(helpers):
"""
Test setting circuit.keys explicitly.
"""
configuration = helpers.configuration()
@fhe.compiler({"x": "encrypted", "y": "encrypted"})
def f(x, y):
return (x + y) ** 2
inputset = [
(
np.random.randint(0, 2**3, size=(10,)),
np.random.randint(0, 2**5, size=(10,)),
)
for _ in range(100)
]
circuit = f.compile(inputset, configuration.fork(use_insecure_key_cache=False))
circuit.keygen(force=True, seed=100)
keys1 = circuit.keys.serialize()
circuit.keygen(force=True, seed=200)
keys2 = circuit.keys.serialize()
assert keys1 != keys2
sample_x = np.random.randint(0, 2**3, size=(10,))
sample_y = np.random.randint(0, 2**5, size=(10,))
sample = circuit.encrypt(sample_x, sample_y)
output = circuit.run(*sample)
circuit.keys = fhe.Keys.deserialize(keys1)
result = circuit.decrypt(output)
assert not np.array_equal(result, (sample_x + sample_y) ** 2)
circuit.keys = fhe.Keys.deserialize(keys2)
result = circuit.decrypt(output)
assert np.array_equal(result, (sample_x + sample_y) ** 2)
def test_simulate_encrypt_run_decrypt(helpers):
"""
Test `simulate_encrypt_run_decrypt` configuration option.
"""
def f(x, y):
return x + y
inputset = fhe.inputset(fhe.uint3, fhe.uint3)
configuration = helpers.configuration().fork(
fhe_execution=False,
fhe_simulation=True,
simulate_encrypt_run_decrypt=True,
)
compiler = fhe.Compiler(f, {"x": "encrypted", "y": "encrypted"})
circuit = compiler.compile(inputset, configuration)
sample_x, sample_y = 3, 4
encrypted_x, encrypted_y = circuit.encrypt(sample_x, sample_y)
encrypted_result = circuit.run(encrypted_x, encrypted_y)
result = circuit.decrypt(encrypted_result)
assert result == sample_x + sample_y
# Make sure computation happened in simulation.
assert isinstance(encrypted_x, int)
assert isinstance(encrypted_y, int)
assert hasattr(circuit, "simulator")
assert not hasattr(circuit, "server")
assert isinstance(encrypted_result, int)