Files
concrete/frontends/concrete-python/tests/compilation/test_keys.py
youben11 d056758c3c perf(frontend/compiler): support ser keyset using path
reduce memory usage by avoiding unnecessary copy
2024-12-10 11:51:53 +01:00

211 lines
6.0 KiB
Python

"""
Tests of `Keys` class.
"""
import tempfile
from pathlib import Path
import pytest
from concrete import fhe
def test_keys_save_load(helpers):
    """
    Check that keys saved by one circuit can be loaded by another.

    A first circuit generates keys, evaluates an encrypted `5` and saves its
    keys under a temporary directory. A second, separately compiled circuit
    loads those keys and must decrypt the evaluation to `25`.
    """

    @fhe.compiler({"x": "encrypted"})
    def f(x):
        return x**2

    inputset = range(10)

    def compile_circuit():
        # Every circuit gets its own configuration, forked with the insecure key cache off.
        return f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))

    with tempfile.TemporaryDirectory() as workdir:
        keys_location = str(Path(workdir) / "keys")

        producer = compile_circuit()
        producer.keygen()
        encrypted_result = producer.run(producer.encrypt(5))
        producer.keys.save(keys_location)

        consumer = compile_circuit()
        consumer.keys.load(keys_location)

        assert consumer.decrypt(encrypted_result) == 25
def test_keys_bad_save_load(helpers):
    """
    Test saving/loading keys where location is (not) empty.

    Loading from a path that does not exist and saving to a path that already
    exists are both expected to raise `ValueError` with a descriptive message.
    """

    @fhe.compiler({"x": "encrypted"})
    def f(x):
        return x**2

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_dir_path = Path(tmp_dir)
        keys_path = tmp_dir_path / "keys"

        inputset = range(10)
        circuit = f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))

        # Nothing has been written to `keys_path` yet, so loading must fail.
        with pytest.raises(ValueError) as excinfo:
            circuit.keys.load(keys_path)

        expected_message = f"Unable to load keys from {keys_path} because it doesn't exist"
        helpers.check_str(expected_message, str(excinfo.value))

        # Occupy the location so that saving has something to collide with.
        # `write_text` is used instead of `open(...) as f` so that `f`, which
        # names the compiled function above, is not rebound to a file handle.
        keys_path.write_text("foo", encoding="utf-8")

        circuit.keys.generate()
        with pytest.raises(ValueError) as excinfo:
            circuit.keys.save(keys_path)

        expected_message = f"Unable to save keys to {keys_path} because it already exists"
        helpers.check_str(expected_message, str(excinfo.value))
def test_keys_load_if_exists_generate_and_save_otherwise(helpers):
    """
    Exercise `load_if_exists_generate_and_save_otherwise` with two circuits sharing a path.

    The first circuit calls the method on a path inside a fresh temporary
    directory, then evaluates an encrypted `5`. The second circuit calls the
    method on the same path afterwards and must decrypt the evaluation to `25`.
    """

    @fhe.compiler({"x": "encrypted"})
    def f(x):
        return x**2

    inputset = range(10)

    def compile_circuit():
        # Every circuit gets its own configuration, forked with the insecure key cache off.
        return f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))

    with tempfile.TemporaryDirectory() as workdir:
        keys_location = str(Path(workdir) / "keys")

        first = compile_circuit()
        first.keys.load_if_exists_generate_and_save_otherwise(keys_location)
        encrypted_result = first.run(first.encrypt(5))

        second = compile_circuit()
        second.keys.load_if_exists_generate_and_save_otherwise(keys_location)

        assert second.decrypt(encrypted_result) == 25
def test_keys_bad_seed(helpers):
    """
    Test key generation with out-of-range seeds.

    Both `secret_seed` and `encryption_seed` are tried with `-1` and `2**128`;
    each call must raise `ValueError` with the message naming the offending seed.
    """

    @fhe.compiler({"x": "encrypted"})
    def f(x):
        return x**2

    inputset = range(10)
    circuit = f.compile(inputset, helpers.configuration())
    server = circuit.server

    client1 = fhe.Client(server.client_specs)

    # (seed keyword, invalid value) pairs, listed in the order they are checked.
    invalid_seeds = [
        ("secret_seed", -1),
        ("secret_seed", 2**128),
        ("encryption_seed", 2**128),
        ("encryption_seed", -1),
    ]

    for seed_name, invalid_value in invalid_seeds:
        with pytest.raises(ValueError) as excinfo:
            client1.keys.generate(**{seed_name: invalid_value})

        assert str(excinfo.value) == f"{seed_name} must be a positive 128 bits integer"
def test_keys_serialize_deserialize(helpers):
    """
    Check that keys survive a `serialize` / `deserialize` round trip.

    One client generates keys and encrypts `5`, and the server evaluates it.
    A second client is given keys rebuilt from the first client's serialized
    keys and must decrypt the evaluation to `25`.
    """

    @fhe.compiler({"x": "encrypted"})
    def f(x):
        return x**2

    circuit = f.compile(range(10), helpers.configuration())
    server = circuit.server

    sender = fhe.Client(server.client_specs)
    sender.keys.generate()
    encrypted_result = server.run(sender.encrypt(5), evaluation_keys=sender.evaluation_keys)

    receiver = fhe.Client(server.client_specs)
    serialized_keys = sender.keys.serialize()
    receiver.keys = fhe.Keys.deserialize(serialized_keys)

    assert receiver.decrypt(encrypted_result) == 25
def test_keys_serialize_before_generation(helpers):
    """
    Check that serializing keys fails when no key generation call was made.

    Both `serialize` and `serialize_to_file` are invoked on a freshly compiled
    circuit and must raise `RuntimeError` with the same message.
    """

    @fhe.compiler({"x": "encrypted"})
    def f(x):
        return x + 42

    circuit = f.compile(range(10), configuration=helpers.configuration())
    expected_message = "Keys cannot be serialized before they are generated"

    with pytest.raises(RuntimeError) as in_memory_error:
        circuit.keys.serialize()
    helpers.check_str(expected_message, str(in_memory_error.value))

    with pytest.raises(RuntimeError) as to_file_error:
        # the target path is irrelevant since the call is expected to raise
        circuit.keys.serialize_to_file(Path("_keys_file"))
    helpers.check_str(expected_message, str(to_file_error.value))
def test_keys_generate_manual_seed(helpers):
    """
    Check key generation driven by explicit seeds.

    A first circuit generates keys with `seed=42, encryption_seed=24` and
    evaluates an encrypted `5`. A second, separately compiled circuit generates
    keys with `seed=42` only and must decrypt that evaluation to `25`. The
    `are_generated` flag is checked before and after each generation.
    """

    @fhe.compiler({"x": "encrypted"})
    def f(x):
        return x**2

    inputset = range(10)

    def compile_circuit():
        # Every circuit gets its own configuration, forked with the insecure key cache off.
        return f.compile(inputset, helpers.configuration().fork(use_insecure_key_cache=False))

    seeded = compile_circuit()
    assert not seeded.keys.are_generated
    seeded.keygen(seed=42, encryption_seed=24)
    assert seeded.keys.are_generated

    encrypted_result = seeded.run(seeded.encrypt(5))

    reseeded = compile_circuit()
    assert not reseeded.keys.are_generated
    reseeded.keygen(seed=42)
    assert reseeded.keys.are_generated

    assert reseeded.decrypt(encrypted_result) == 25