# Mirror of https://github.com/zama-ai/concrete.git
# Synced 2026-01-13 14:58:01 -05:00
# 424 lines, 12 KiB, Python
"""
Tests of the examples.
"""
|
|
|
|
import os
import subprocess
from typing import Optional

import numpy as np
import pytest

from examples.game_of_life.game_of_life import GameOfLife
from examples.key_value_database.static_size import StaticKeyValueDatabase
from examples.levenshtein_distance.levenshtein_distance import (
    Alphabet,
    LevenshteinDistance,
    levenshtein_clear,
)
|
|
|
|
|
|
def test_static_kvdb(helpers):
    """
    Test static key-value database example.

    Inserts, replaces and queries entries in a 4-entry database and compares
    the decrypted state with the expected rows after every mutation.
    """

    # Every expected row below is written as [flag] + 4 key chunks + 4 value chunks.
    empty_row = [0] * 9

    def inspect(db: StaticKeyValueDatabase) -> np.ndarray:
        # Run the inspection function on the current state and decrypt its result.
        return db.inspect.decrypt(db.inspect.run(db.state))  # type: ignore

    def insert(db: StaticKeyValueDatabase, key: int, value: int):
        encoded_key, encoded_value = db.encode_key(key), db.encode_value(value)
        encrypted_arguments = db.insert.encrypt(  # type: ignore
            None,
            encoded_key,
            encoded_value,
        )
        _, encrypted_key, encrypted_value = encrypted_arguments
        db.state = db.insert.run(db.state, encrypted_key, encrypted_value)  # type: ignore

    def replace(db: StaticKeyValueDatabase, key: int, value: int):
        encoded_key, encoded_value = db.encode_key(key), db.encode_value(value)
        encrypted_arguments = db.replace.encrypt(  # type: ignore
            None,
            encoded_key,
            encoded_value,
        )
        _, encrypted_key, encrypted_value = encrypted_arguments
        db.state = db.replace.run(db.state, encrypted_key, encrypted_value)  # type: ignore

    def query(db: StaticKeyValueDatabase, key: int) -> Optional[int]:
        _, encrypted_key = db.query.encrypt(None, db.encode_key(key))  # type: ignore
        encrypted_found, encrypted_value = db.query.run(db.state, encrypted_key)  # type: ignore
        found, value = db.query.decrypt(encrypted_found, encrypted_value)  # type: ignore
        return db.decode_value(value) if found else None  # type: ignore

    def assert_state(db: StaticKeyValueDatabase, *filled_rows):
        # The given rows come first; the remaining entries (out of 4) must be all zeros.
        expected_state = [*filled_rows, *([empty_row] * (4 - len(filled_rows)))]
        assert np.array_equal(inspect(db), expected_state)

    db = StaticKeyValueDatabase(
        number_of_entries=4,
        key_size=8,
        value_size=8,
        chunk_size=2,
        configuration=helpers.configuration(),
    )
    db.keygen()

    # -- regular keys and values --

    db.initialize()
    assert_state(db)

    assert query(db, 3) is None

    insert(db, 3, 4)
    assert_state(
        db,
        [1] + [0, 0, 0, 3] + [0, 0, 1, 0],
    )

    assert query(db, 3) == 4

    replace(db, 3, 1)
    assert_state(
        db,
        [1] + [0, 0, 0, 3] + [0, 0, 0, 1],
    )

    assert query(db, 3) == 1

    insert(db, 25, 40)
    assert_state(
        db,
        [1] + [0, 0, 0, 3] + [0, 0, 0, 1],
        [1] + [0, 1, 2, 1] + [0, 2, 2, 0],
    )

    assert query(db, 25) == 40

    # -- boundary keys and values --

    minimum_key, maximum_key = 0, 2**db.key_size - 1
    minimum_value, maximum_value = 0, 2**db.value_size - 1

    db.initialize()
    assert_state(db)

    insert(db, minimum_key, minimum_value)
    assert_state(
        db,
        [1] + [0, 0, 0, 0] + [0, 0, 0, 0],
    )

    assert query(db, minimum_key) == minimum_value

    replace(db, minimum_key, maximum_value)
    assert_state(
        db,
        [1] + [0, 0, 0, 0] + [3, 3, 3, 3],
    )

    assert query(db, minimum_key) == maximum_value

    insert(db, maximum_key, maximum_value)
    assert_state(
        db,
        [1] + [0, 0, 0, 0] + [3, 3, 3, 3],
        [1] + [3, 3, 3, 3] + [3, 3, 3, 3],
    )

    assert query(db, maximum_key) == maximum_value

    replace(db, maximum_key, minimum_value)
    assert_state(
        db,
        [1] + [0, 0, 0, 0] + [3, 3, 3, 3],
        [1] + [3, 3, 3, 3] + [0, 0, 0, 0],
    )

    assert query(db, maximum_key) == minimum_value
|
|
|
|
|
|
@pytest.mark.parametrize(
    "mode",
    [
        "simulate",
        "fhe",
    ],
)
def test_levenshtein_distance(mode, helpers):
    """
    Test levenshtein distance example.

    Each case is a pair of strings with its hand-written expected distance;
    the calculation runs in the parametrized mode.
    """

    configuration = helpers.configuration()
    if mode == "simulate":
        # simulation mode: turn off FHE execution and turn on FHE simulation
        configuration = configuration.fork(fhe_execution=False, fhe_simulation=True)

    max_string_length = 5
    levenshtein_distance = LevenshteinDistance(
        Alphabet.lowercase(),
        max_string_length,
        configuration,
    )
    levenshtein_distance.module.keygen()

    cases = [
        # same
        ("hello", "hello", 0),
        # one character missing from the end
        ("hell", "hello", 1),
        ("hello", "hell", 1),
        # one character missing from the start
        ("ello", "hello", 1),
        ("hello", "ello", 1),
        # one character missing from the middle
        ("hllo", "hello", 1),
        ("hello", "hllo", 1),
        # two characters missing from the start and the end
        ("ell", "hello", 2),
        ("hello", "ell", 2),
        # three characters missing from the start, the end and the middle
        ("el", "hello", 3),
        ("hello", "el", 3),
        # shifted one character
        ("hello", "elloh", 2),
        ("elloh", "hello", 2),
        # shifted two characters
        ("hello", "llohe", 4),
        ("llohe", "hello", 4),
        # shifted three characters
        ("hello", "lohel", 4),
        ("lohel", "hello", 4),
        # shifted four characters
        ("hello", "ohell", 2),
        ("ohell", "hello", 2),
        # completely different
        ("hello", "numpy", 5),
    ]

    for first, second, expected in cases:
        actual = levenshtein_distance.calculate(first, second, mode, show_distance=True)
        assert actual == expected
|
|
|
|
|
|
@pytest.mark.parametrize(
    "alphabet_name",
    Alphabet.return_available_alphabets(),
)
@pytest.mark.parametrize(
    "max_length",
    [2, 3],
)
def test_levenshtein_distance_randomly(alphabet_name, max_length, helpers):
    """
    Test levenshtein distance example with randomly generated strings.

    The simulated result is compared with the clear `levenshtein_clear` reference.
    """

    base_configuration = helpers.configuration()
    simulation_configuration = base_configuration.fork(
        fhe_execution=False,
        fhe_simulation=True,
    )

    alphabet = Alphabet.init_by_name(alphabet_name)
    levenshtein_distance = LevenshteinDistance(alphabet, max_length, simulation_configuration)
    levenshtein_distance.module.keygen()

    random_pairs = alphabet.prepare_random_patterns(0, max_length, nb_strings=3)
    for left, right in random_pairs:
        reference = levenshtein_clear(left, right)
        simulated = levenshtein_distance.calculate(left, right, "simulate", show_distance=True)
        assert simulated == reference
|
|
|
|
|
|
@pytest.mark.parametrize(
    "implementation",
    GameOfLife.implementations(),
)
@pytest.mark.parametrize(
    "dimension,sample_input_output_pairs",
    [
        pytest.param(
            4,
            [
                (
                    [
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                    ],
                    # should become
                    [
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 0, 0, 0],
                        [0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 1],
                    ],
                    # should become
                    [
                        [0, 0, 0, 0],
                        [0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [0, 0, 0, 0],
                        [0, 1, 1, 0],
                        [0, 1, 1, 0],
                        [0, 0, 0, 0],
                    ],
                    # should become
                    [
                        [0, 0, 0, 0],
                        [0, 1, 1, 0],
                        [0, 1, 1, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 0, 0, 0],
                        [0, 1, 1, 0],
                        [0, 1, 1, 0],
                        [0, 0, 0, 1],
                    ],
                    # should become
                    [
                        [0, 0, 0, 0],
                        [0, 0, 1, 0],
                        [0, 1, 0, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 1, 0, 0],
                        [1, 0, 0, 0],
                        [0, 0, 0, 1],
                        [0, 0, 1, 1],
                    ],
                    # should become
                    [
                        [0, 0, 0, 0],
                        [0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 0],
                    ],
                ),
                (
                    [
                        [1, 1, 1, 1],
                        [1, 1, 1, 1],
                        [1, 1, 1, 1],
                        [1, 1, 1, 1],
                    ],
                    # should become
                    [
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                        [0, 0, 0, 0],
                    ],
                ),
            ],
        ),
    ],
)
def test_game_of_life(implementation, dimension, sample_input_output_pairs, helpers):
    """
    Test game of life implementation.

    Every sample board is encrypted, run through the circuit and decrypted,
    then compared with the board it is expected to become.
    """

    game_of_life = GameOfLife.implementation(
        implementation,
        dimension,
        helpers.configuration(),
    )
    game_of_life.circuit.keygen()

    board_shape = (dimension, dimension)
    # The circuit is fed the board with two leading singleton axes.
    circuit_input_shape = (1, 1) + board_shape

    for sample_input, expected_output in sample_input_output_pairs:
        sample = np.array(sample_input).reshape(circuit_input_shape)
        actual_output = game_of_life.circuit.encrypt_run_decrypt(sample).reshape(board_shape)

        assert np.array_equal(
            actual_output, expected_output
        ), f"""

Expected Output
===============
{expected_output}

Actual Output
=============
{actual_output}

        """
|
|
|
|
|
|
def test_tfhers_example():
    """
    Test tfhers example by running its `test.sh` script.

    The script is executed with `sh` from inside the `examples/tfhers`
    directory (resolved relative to this file) and must exit with status 0.
    """

    path_to_test_script = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "..", "examples", "tfhers"
    )
    test_script_filename = "test.sh"

    # An argument list plus `cwd` avoids interpolating the path into a shell
    # string, so a checkout path with spaces or shell metacharacters still works.
    completed = subprocess.run(
        ["sh", test_script_filename],
        cwd=path_to_test_script,
        check=False,
    )
    assert completed.returncode == 0, "test script failed"
|
|
|
|
|
|
def test_tfhers_ml_example():
    """
    Test tfhers-ml example by running its `test.sh` script.

    The script is executed with `sh` from inside the `examples/tfhers-ml`
    directory (resolved relative to this file) and must exit with status 0.
    """

    path_to_test_script = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "..", "..", "examples", "tfhers-ml"
    )
    test_script_filename = "test.sh"

    # An argument list plus `cwd` avoids interpolating the path into a shell
    # string, so a checkout path with spaces or shell metacharacters still works.
    completed = subprocess.run(
        ["sh", test_script_filename],
        cwd=path_to_test_script,
        check=False,
    )
    assert completed.returncode == 0, "test script failed"
|