Files
concrete/frontends/concrete-python/tests/execution/test_examples.py
2024-12-18 10:44:53 +01:00

424 lines
12 KiB
Python

"""
Tests of the examples.
"""
import os
import subprocess
from typing import Optional

import numpy as np
import pytest

from examples.game_of_life.game_of_life import GameOfLife
from examples.key_value_database.static_size import StaticKeyValueDatabase
from examples.levenshtein_distance.levenshtein_distance import (
    Alphabet,
    LevenshteinDistance,
    levenshtein_clear,
)
def test_static_kvdb(helpers):
    """
    Test static key-value database example.

    The database is built with 4 entries, 8-bit keys/values and 2-bit chunks, so each
    decrypted state row reads ``[flag] + 4 key chunks + 4 value chunks``. The flag is 1
    for a used entry and 0 for an empty one. For example, key 25 (0b00011001) shows up
    as chunks [0, 1, 2, 1].

    The whole state is decrypted and compared after every insert or replace. After
    re-initialization, the extreme keys and values (0 and 2**8 - 1) are exercised.
    """
    configuration = helpers.configuration()

    def inspect(db: StaticKeyValueDatabase) -> np.ndarray:
        """Run the inspect circuit on the current state and return it decrypted."""
        encrypted_state = db.inspect.run(db.state)
        clear_state = db.inspect.decrypt(encrypted_state)
        return clear_state  # type: ignore

    def write(db: StaticKeyValueDatabase, operation, key: int, value: int):
        """
        Encode and encrypt ``key``/``value``, run ``operation`` on the current state,
        and store the resulting state back on ``db``.

        ``insert`` and ``replace`` share this helper and differ only in the circuit used.
        """
        encoded_key, encoded_value = db.encode_key(key), db.encode_value(value)
        # `None` fills the state slot: db.state is handed to `run` as-is (presumably
        # already encrypted), so only the key and the value are encrypted here.
        _, encrypted_key, encrypted_value = operation.encrypt(  # type: ignore
            None,
            encoded_key,
            encoded_value,
        )
        db.state = operation.run(db.state, encrypted_key, encrypted_value)  # type: ignore

    def insert(db: StaticKeyValueDatabase, key: int, value: int):
        """Insert ``key`` -> ``value`` into the encrypted state."""
        write(db, db.insert, key, value)

    def replace(db: StaticKeyValueDatabase, key: int, value: int):
        """Replace the value stored under ``key`` in the encrypted state."""
        write(db, db.replace, key, value)

    def query(db: StaticKeyValueDatabase, key: int) -> Optional[int]:
        """Return the decoded value stored under ``key``, or None if it was not found."""
        encoded_key = db.encode_key(key)
        _, encrypted_key = db.query.encrypt(None, encoded_key)  # type: ignore
        encrypted_found, encrypted_value = db.query.run(db.state, encrypted_key)  # type: ignore
        found, value = db.query.decrypt(encrypted_found, encrypted_value)  # type: ignore
        if not found:
            return None
        return db.decode_value(value)  # type: ignore

    db = StaticKeyValueDatabase(
        number_of_entries=4,
        key_size=8,
        value_size=8,
        chunk_size=2,
        configuration=configuration,
    )
    db.keygen()
    db.initialize()

    # Regular keys and values.

    assert np.array_equal(
        inspect(db),
        [
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, 3) is None

    insert(db, 3, 4)
    assert np.array_equal(
        inspect(db),
        [
            [1] + [0, 0, 0, 3] + [0, 0, 1, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, 3) == 4

    replace(db, 3, 1)
    assert np.array_equal(
        inspect(db),
        [
            [1] + [0, 0, 0, 3] + [0, 0, 0, 1],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, 3) == 1

    insert(db, 25, 40)
    assert np.array_equal(
        inspect(db),
        [
            [1] + [0, 0, 0, 3] + [0, 0, 0, 1],
            [1] + [0, 1, 2, 1] + [0, 2, 2, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, 25) == 40

    # Extreme keys and values, on a freshly initialized state.

    minimum_key = 0
    maximum_key = 2**db.key_size - 1

    minimum_value = 0
    maximum_value = 2**db.value_size - 1

    db.initialize()
    assert np.array_equal(
        inspect(db),
        [
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )

    insert(db, minimum_key, minimum_value)
    assert np.array_equal(
        inspect(db),
        [
            [1] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, minimum_key) == minimum_value

    replace(db, minimum_key, maximum_value)
    assert np.array_equal(
        inspect(db),
        [
            [1] + [0, 0, 0, 0] + [3, 3, 3, 3],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, minimum_key) == maximum_value

    insert(db, maximum_key, maximum_value)
    assert np.array_equal(
        inspect(db),
        [
            [1] + [0, 0, 0, 0] + [3, 3, 3, 3],
            [1] + [3, 3, 3, 3] + [3, 3, 3, 3],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, maximum_key) == maximum_value

    replace(db, maximum_key, minimum_value)
    assert np.array_equal(
        inspect(db),
        [
            [1] + [0, 0, 0, 0] + [3, 3, 3, 3],
            [1] + [3, 3, 3, 3] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
            [0] + [0, 0, 0, 0] + [0, 0, 0, 0],
        ],
    )
    assert query(db, maximum_key) == minimum_value
@pytest.mark.parametrize(
    "mode",
    [
        "simulate",
        "fhe",
    ],
)
def test_levenshtein_distance(mode, helpers):
    """
    Test levenshtein distance example.

    Compiles the example for lowercase strings of at most 5 characters. Each pair in a
    fixed set is evaluated in ``mode`` ("simulate" or "fhe") and compared with its
    hand-computed distance. A failure message names the offending pair.
    """
    configuration = helpers.configuration()
    if mode == "simulate":
        configuration = configuration.fork(fhe_execution=False, fhe_simulation=True)

    alphabet = Alphabet.lowercase()
    max_string_length = 5

    levenshtein_distance = LevenshteinDistance(alphabet, max_string_length, configuration)
    levenshtein_distance.module.keygen()

    samples = [
        # same
        ("hello", "hello", 0),
        # one character missing from the end
        ("hell", "hello", 1),
        ("hello", "hell", 1),
        # one character missing from the start
        ("ello", "hello", 1),
        ("hello", "ello", 1),
        # one character missing from the middle
        ("hllo", "hello", 1),
        ("hello", "hllo", 1),
        # two characters missing from the start and the end
        ("ell", "hello", 2),
        ("hello", "ell", 2),
        # three characters missing from the start, the end and the middle
        ("el", "hello", 3),
        ("hello", "el", 3),
        # shifted one character
        ("hello", "elloh", 2),
        ("elloh", "hello", 2),
        # shifted two characters
        ("hello", "llohe", 4),
        ("llohe", "hello", 4),
        # shifted three characters
        ("hello", "lohel", 4),
        ("lohel", "hello", 4),
        # shifted four characters
        ("hello", "ohell", 2),
        ("ohell", "hello", 2),
        # completely different
        ("hello", "numpy", 5),
    ]

    for str1, str2, expected_distance in samples:
        actual_distance = levenshtein_distance.calculate(str1, str2, mode, show_distance=True)
        # Name the failing pair; a bare `==` failure in a loop does not say which one it was.
        assert actual_distance == expected_distance, (
            f"levenshtein({str1!r}, {str2!r}) in {mode} mode: "
            f"expected {expected_distance}, got {actual_distance}"
        )
@pytest.mark.parametrize(
    "alphabet_name",
    Alphabet.return_available_alphabets(),
)
@pytest.mark.parametrize(
    "max_length",
    [2, 3],
)
def test_levenshtein_distance_randomly(alphabet_name, max_length, helpers):
    """
    Test levenshtein distance example with randomly generated strings.

    For every available alphabet and maximum length, the simulated distance of a few
    randomly generated string pairs is checked against the clear reference
    implementation ``levenshtein_clear``. Because the strings are random, the failure
    message includes them so that a failure can be reproduced.
    """
    configuration = helpers.configuration().fork(fhe_execution=False, fhe_simulation=True)

    alphabet = Alphabet.init_by_name(alphabet_name)
    levenshtein_distance = LevenshteinDistance(alphabet, max_length, configuration)
    levenshtein_distance.module.keygen()

    for str1, str2 in alphabet.prepare_random_patterns(0, max_length, nb_strings=3):
        expected_distance = levenshtein_clear(str1, str2)
        actual_distance = levenshtein_distance.calculate(str1, str2, "simulate", show_distance=True)
        assert actual_distance == expected_distance, (
            f"levenshtein({str1!r}, {str2!r}) with alphabet {alphabet_name!r}: "
            f"expected {expected_distance}, got {actual_distance}"
        )
@pytest.mark.parametrize("implementation", GameOfLife.implementations())
@pytest.mark.parametrize(
    "dimension,sample_input_output_pairs",
    [
        pytest.param(
            4,
            # Each entry is (board before, board after one step).
            [
                (
                    [
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                    ],
                    [
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 0, 0, 0],
                        [0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 1],
                    ],
                    [
                        [0, 0, 0, 0],
                        [0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [0, 0, 0, 0],
                        [0, 1, 1, 0],
                        [0, 1, 1, 0],
                        [0, 0, 0, 0],
                    ],
                    [
                        [0, 0, 0, 0],
                        [0, 1, 1, 0],
                        [0, 1, 1, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 0, 0, 0],
                        [0, 1, 1, 0],
                        [0, 1, 1, 0],
                        [0, 0, 0, 1],
                    ],
                    [
                        [0, 0, 0, 0],
                        [0, 0, 1, 0],
                        [0, 1, 0, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 1, 0, 0],
                        [1, 0, 0, 0],
                        [0, 0, 0, 1],
                        [0, 0, 1, 1],
                    ],
                    [
                        [0, 0, 0, 0],
                        [0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 1, 1, 1],
                        [1, 1, 1, 1],
                        [1, 1, 1, 1],
                        [1, 1, 1, 1],
                    ],
                    [
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                    ],
                ),
            ],
        ),
    ],
)
def test_game_of_life(implementation, dimension, sample_input_output_pairs, helpers):
    """
    Test game of life implementation.

    Each sample board is fed to the implementation's circuit as a
    ``(1, 1, dimension, dimension)`` array. The decrypted result, reshaped back to
    ``(dimension, dimension)``, must equal the expected next board.
    """
    game_of_life = GameOfLife.implementation(
        implementation,
        dimension,
        helpers.configuration(),
    )
    game_of_life.circuit.keygen()

    board_shape = (dimension, dimension)
    for sample_input, expected_output in sample_input_output_pairs:
        circuit_input = np.array(sample_input).reshape((1, 1, *board_shape))
        decrypted = game_of_life.circuit.encrypt_run_decrypt(circuit_input)
        actual_output = decrypted.reshape(board_shape)

        boards_match = np.array_equal(actual_output, expected_output)
        assert boards_match, f"""
Expected Output
===============
{expected_output}
Actual Output
=============
{actual_output}
"""
def _example_dir(example_name: str) -> str:
    """Return the absolute path of ``examples/<example_name>`` relative to this test file."""
    tests_dir = os.path.dirname(os.path.abspath(__file__))
    return os.path.join(tests_dir, "..", "..", "examples", example_name)


def _run_example_test_script(example_dir: str, script_name: str = "test.sh") -> int:
    """
    Run ``script_name`` with ``sh`` from inside ``example_dir`` and return its exit code.

    The working directory is set with ``cwd=`` and the command is passed as an argument
    list, never as a ``cd ... && sh ...`` shell string. This keeps directory paths that
    contain spaces or shell metacharacters working. A missing directory raises instead
    of failing silently.
    """
    completed = subprocess.run(["sh", script_name], cwd=example_dir, check=False)
    return completed.returncode


def test_tfhers_example():
    """
    Run the test script shipped with the ``tfhers`` example and require it to succeed.
    """
    assert _run_example_test_script(_example_dir("tfhers")) == 0, "test script failed"


def test_tfhers_ml_example():
    """
    Run the test script shipped with the ``tfhers-ml`` example and require it to succeed.
    """
    assert _run_example_test_script(_example_dir("tfhers-ml")) == 0, "test script failed"