# (Source-viewer header residue: Files · 304 lines · 8.6 KiB · Python)

import time
import numpy as np
from concrete import fhe
# Number of rows in the database table; each row holds one key/value pair.
NUMBER_OF_ENTRIES = 5
# Keys and values are split into CHUNK_SIZE-bit chunks, one integer per chunk.
CHUNK_SIZE = 4
# Total bit-widths of keys and values.
KEY_SIZE = 32
VALUE_SIZE = 32

# Keys and values must split evenly into chunks. Explicit checks (rather than
# `assert`) keep this validation active under `python -O`.
if KEY_SIZE % CHUNK_SIZE != 0:
    raise ValueError(f"KEY_SIZE ({KEY_SIZE}) must be a multiple of CHUNK_SIZE ({CHUNK_SIZE})")
if VALUE_SIZE % CHUNK_SIZE != 0:
    raise ValueError(
        f"VALUE_SIZE ({VALUE_SIZE}) must be a multiple of CHUNK_SIZE ({CHUNK_SIZE})"
    )

NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE
NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE

# Each table row is laid out as [flag, key chunks..., value chunks...].
STATE_SHAPE = (NUMBER_OF_ENTRIES, 1 + NUMBER_OF_KEY_CHUNKS + NUMBER_OF_VALUE_CHUNKS)

# Column index of the flag and column slices of the key / value chunks.
FLAG = 0
KEY = slice(1, 1 + NUMBER_OF_KEY_CHUNKS)
VALUE = slice(1 + NUMBER_OF_KEY_CHUNKS, None)
def encode(number: int, width: int, chunk_size: "int | None" = None) -> np.ndarray:
    """Split ``number`` into big-endian chunks of ``chunk_size`` bits.

    Args:
        number: Non-negative integer that fits in ``width`` bits.
        width: Total bit-width of the encoding; a positive multiple of ``chunk_size``.
        chunk_size: Bits per chunk. Defaults to the module-level ``CHUNK_SIZE``.

    Returns:
        Integer array of length ``width // chunk_size``. Element 0 holds the most
        significant chunk.

    Raises:
        ValueError: if ``width`` is not a positive multiple of ``chunk_size``, or if
            ``number`` is negative or does not fit in ``width`` bits.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0 or width <= 0 or width % chunk_size != 0:
        raise ValueError(
            f"width ({width}) must be a positive multiple of chunk_size ({chunk_size})"
        )
    # np.binary_repr would silently emit extra digits for too-large values and a
    # two's-complement pattern for negative ones, corrupting the chunks.
    if not 0 <= int(number) < 2**width:
        raise ValueError(f"{number} does not fit in {width} unsigned bits")

    bits = np.binary_repr(number, width=width)
    return np.array([int(bits[i : i + chunk_size], 2) for i in range(0, width, chunk_size)])
def encode_key(number: int) -> np.ndarray:
    """Return the KEY_SIZE-bit chunked (big-endian) representation of a key."""
    key_chunks = encode(number, KEY_SIZE)
    return key_chunks
def encode_value(number: int) -> np.ndarray:
    """Return the VALUE_SIZE-bit chunked (big-endian) representation of a value."""
    value_chunks = encode(number, VALUE_SIZE)
    return value_chunks
def decode(encoded_number: np.ndarray, chunk_size: "int | None" = None) -> int:
    """Rebuild an integer from big-endian chunks (the inverse of ``encode``).

    Args:
        encoded_number: Chunks, most significant first.
        chunk_size: Bits per chunk. Defaults to the module-level ``CHUNK_SIZE``.

    Returns:
        The decoded value as a Python ``int`` (0 for an empty input).
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    result = 0
    for chunk in encoded_number:
        # Horner's scheme on Python ints. int() avoids numpy int64 arithmetic,
        # which would overflow for values of 64 bits or more.
        result = (result << chunk_size) + int(chunk)
    return result
# Lookup table indexed by `selection * 2**CHUNK_SIZE + chunk`:
# - the lower half (selection == 0) maps to 0;
# - the upper half (selection == 1) returns the chunk unchanged.
# It is sized from CHUNK_SIZE because a hard-coded 16 is only correct when
# CHUNK_SIZE == 4.
keep_selected_lut = fhe.LookupTable([0] * 2**CHUNK_SIZE + list(range(2**CHUNK_SIZE)))
def _insert_impl(state, key, value):
    """Traced FHE body: write (key, value) into the first row whose flag is 0.

    Args:
        state: Encrypted table of shape STATE_SHAPE. Column FLAG holds the row
            flag; columns KEY and VALUE hold the chunked key and value.
        key: Encrypted key chunks (NUMBER_OF_KEY_CHUNKS values).
        value: Encrypted value chunks (NUMBER_OF_VALUE_CHUNKS values).

    Returns:
        ``state`` plus an update that, in the selected row only, sets the flag to 1
        and adds the key/value chunks.

        - If no row has flag 0, the update is all zeros and the table is unchanged.
        - Existing keys are not checked, so inserting a key twice occupies two rows.
    """
    flags = state[:, FLAG]

    selection = fhe.zeros(NUMBER_OF_ENTRIES)
    found = fhe.zero()
    for i in range(NUMBER_OF_ENTRIES):
        # Pack "a row was already selected" and this row's flag into one value,
        # so a single comparison selects the row only when both are 0.
        packed_flag_and_already_found = (found * 2) + flags[i]
        is_selected = packed_flag_and_already_found == 0

        selection[i] = is_selected
        found += is_selected

    state_update = fhe.zeros(STATE_SHAPE)
    # Set the flag of the selected row to 1; all other rows get 0.
    state_update[:, FLAG] = selection

    # Column vector, so the selection broadcasts across every chunk of a row.
    selection = selection.reshape((-1, 1))

    # keep_selected_lut keeps the chunk where selection == 1 and yields 0 elsewhere.
    packed_selection_and_key = (selection * (2**CHUNK_SIZE)) + key
    key_update = keep_selected_lut[packed_selection_and_key]

    packed_selection_and_value = selection * (2**CHUNK_SIZE) + value
    value_update = keep_selected_lut[packed_selection_and_value]

    state_update[:, KEY] = key_update
    state_update[:, VALUE] = value_update

    # The update is added rather than assigned. This assumes the selected
    # (flag-0) row still holds zero key/value chunks, as in the initial state.
    new_state = state + state_update

    return new_state
def _replace_impl(state, key, value):
    """Traced FHE body: overwrite the value of every used row whose key matches.

    Args:
        state: Encrypted table of shape STATE_SHAPE (flag, key chunks, value chunks).
        key: Encrypted key chunks (NUMBER_OF_KEY_CHUNKS values).
        value: Encrypted new value chunks (NUMBER_OF_VALUE_CHUNKS values).

    Returns:
        ``state`` with its VALUE columns set to ``value`` in every row whose flag
        is 1 and whose key chunks all equal ``key``. Other rows keep their value,
        so the table is unchanged when nothing matches.
    """
    flags = state[:, FLAG]
    keys = state[:, KEY]
    values = state[:, VALUE]

    # Per row, count the key chunks equal to the corresponding chunk of `key`.
    number_of_matching_chunks = np.sum((keys - key) == 0, axis=1)
    # The count is at most NUMBER_OF_KEY_CHUNKS; hint that the intermediate
    # must be able to hold that value.
    fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS)
    equal_rows = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS

    # flags * 2 + equal_rows == 3 exactly when flag == 1 and the key matches
    # (assuming flags are 0/1).
    selection = (flags * 2 + equal_rows == 3).reshape((-1, 1))

    # New value in selected rows, 0 elsewhere.
    packed_selection_and_value = selection * (2**CHUNK_SIZE) + value
    set_value = keep_selected_lut[packed_selection_and_value]

    # Old value in non-selected rows, 0 in selected rows.
    inverse_selection = 1 - selection
    packed_inverse_selection_and_values = inverse_selection * (2**CHUNK_SIZE) + values
    kept_values = keep_selected_lut[packed_inverse_selection_and_values]

    # At most one of the two terms is non-zero per row, so the sum merges them.
    new_values = kept_values + set_value

    # Writes into the traced `state` and returns it.
    state[:, VALUE] = new_values
    return state
def _query_impl(state, key):
    """Traced FHE body: look up the value stored under ``key``.

    Only rows whose flag is 1 are considered. Without that check, an unused row
    (flag 0, key chunks still 0) matches a query for key 0. The query would then
    report a hit with value 0 for a key that was never inserted.

    Args:
        state: Encrypted table of shape STATE_SHAPE (flag, key chunks, value chunks).
        key: Encrypted key chunks (NUMBER_OF_KEY_CHUNKS values).

    Returns:
        Array ``[found, *value_chunks]``:

        - ``found`` is the number of used rows whose key matches.
        - ``value_chunks`` is the per-chunk sum of those rows' values: the stored
          value when exactly one row matches, all zeros when none does.
    """
    flags = state[:, FLAG]
    keys = state[:, KEY]
    values = state[:, VALUE]

    # Per row, count the key chunks equal to the corresponding chunk of `key`.
    number_of_matching_chunks = np.sum((keys - key) == 0, axis=1)
    fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS)
    equal_rows = number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS

    # Same packing as _replace_impl: == 3 exactly when the row is used (flag 1)
    # and its key matches.
    selection = (flags * 2 + equal_rows == 3).reshape((-1, 1))
    found = np.sum(selection)

    # Keep the value chunks of selected rows only, then collapse the rows.
    packed_selection_and_values = selection * (2**CHUNK_SIZE) + values
    value_selection = keep_selected_lut[packed_selection_and_values]
    value = np.sum(value_selection, axis=0)

    return fhe.array([found, *value])
class KeyValueDatabase:
    """Key-value store whose insert / replace / query run as compiled Concrete circuits.

    The table is kept in ``self._state`` as a plaintext int64 array of shape
    STATE_SHAPE. Each operation does the following:

    - encrypts the table together with the encoded key (and value);
    - runs the matching circuit through ``encrypt_run_decrypt``;
    - stores the decrypted result back in ``self._state``.
    """

    _state: np.ndarray
    _insert_circuit: fhe.Circuit
    _replace_circuit: fhe.Circuit
    _query_circuit: fhe.Circuit

    @staticmethod
    def _timed(message, action):
        """Print a blank line and ``message``, run ``action()``, then print its duration.

        Uses ``time.perf_counter`` (monotonic) rather than ``time.time``, so the
        reported duration is not skewed by wall-clock adjustments.

        Returns:
            Whatever ``action()`` returns.
        """
        print()
        print(message)
        start = time.perf_counter()
        result = action()
        end = time.perf_counter()
        print(f"(took {end - start:.3f} seconds)")
        return result

    def __init__(self):
        """Create an empty table, compile the three circuits and generate their keys.

        Each compilation and key generation prints a progress message and its timing.
        """
        self._state = np.zeros(STATE_SHAPE, dtype=np.int64)

        # Representative compilation inputs: an empty table, with key/value
        # chunks at their largest value 2**CHUNK_SIZE - 1.
        max_chunk = 2**CHUNK_SIZE - 1
        inputset_binary = [
            (
                # state
                np.zeros(STATE_SHAPE, dtype=np.int64),
                # key
                np.full(NUMBER_OF_KEY_CHUNKS, max_chunk, dtype=np.int64),
            )
        ]
        inputset_ternary = [
            (
                # state
                np.zeros(STATE_SHAPE, dtype=np.int64),
                # key
                np.full(NUMBER_OF_KEY_CHUNKS, max_chunk, dtype=np.int64),
                # value
                np.full(NUMBER_OF_VALUE_CHUNKS, max_chunk, dtype=np.int64),
            )
        ]

        # NOTE(review): as the option names say, this enables unsafe features and
        # an insecure key cache stored under ".keys". That is convenient for a
        # demo, but do not use it in production.
        configuration = fhe.Configuration(
            enable_unsafe_features=True,
            use_insecure_key_cache=True,
            insecure_key_cache_location=".keys",
        )

        insert_compiler = fhe.Compiler(
            _insert_impl, {"state": "encrypted", "key": "encrypted", "value": "encrypted"}
        )
        replace_compiler = fhe.Compiler(
            _replace_impl, {"state": "encrypted", "key": "encrypted", "value": "encrypted"}
        )
        query_compiler = fhe.Compiler(_query_impl, {"state": "encrypted", "key": "encrypted"})

        self._insert_circuit = self._timed(
            "Compiling insertion circuit...",
            lambda: insert_compiler.compile(inputset_ternary, configuration),
        )
        self._replace_circuit = self._timed(
            "Compiling replacement circuit...",
            lambda: replace_compiler.compile(inputset_ternary, configuration),
        )
        self._query_circuit = self._timed(
            "Compiling query circuit...",
            lambda: query_compiler.compile(inputset_binary, configuration),
        )

        self._timed("Generating insertion keys...", self._insert_circuit.keygen)
        self._timed("Generating replacement keys...", self._replace_circuit.keygen)
        self._timed("Generating query keys...", self._query_circuit.keygen)

    def insert(self, key, value):
        """Store ``value`` under ``key`` in the first row whose flag is 0.

        The insertion circuit leaves the table unchanged when no such row exists.
        """
        self._state = self._timed(
            "Inserting...",
            lambda: self._insert_circuit.encrypt_run_decrypt(
                self._state, encode_key(key), encode_value(value)
            ),
        )

    def replace(self, key, value):
        """Overwrite the value of every used row whose key equals ``key``.

        The table is unchanged when no row matches.
        """
        self._state = self._timed(
            "Replacing...",
            lambda: self._replace_circuit.encrypt_run_decrypt(
                self._state, encode_key(key), encode_value(value)
            ),
        )

    def query(self, key):
        """Return the value stored under ``key``, or None when nothing matches.

        None is returned when the circuit's first output (the match count) is 0.
        Otherwise the remaining outputs are decoded as the value chunks.
        """
        result = self._timed(
            "Querying...",
            lambda: self._query_circuit.encrypt_run_decrypt(self._state, encode_key(key)),
        )
        if result[0] == 0:
            return None
        return decode(result[1:])
db = KeyValueDatabase()

# Test: Insert/Query, then Replace/Query on the same key
db.insert(3, 4)
assert db.query(3) == 4
db.replace(3, 1)
assert db.query(3) == 1

# Test: Insert/Query a second key
db.insert(25, 40)
assert db.query(25) == 40

# Test: Query Not Found
assert db.query(4) is None

# Test: Replace/Query
db.replace(3, 5)
assert db.query(3) == 5

# Extreme keys/values representable in KEY_SIZE / VALUE_SIZE bits.
# NOTE(review): the lower bound is presumably 1 rather than 0 because unused
# rows hold all-zero key chunks. Confirm before changing it.
minimum_key, maximum_key = 1, 2**KEY_SIZE - 1
minimum_value, maximum_value = 1, 2**VALUE_SIZE - 1

# Test: Insert/Replace/Query Bounds
# Insert each (key, value) bound pair and read it straight back.
for bound_key, bound_value in ((minimum_key, minimum_value), (maximum_key, maximum_value)):
    db.insert(bound_key, bound_value)
    assert db.query(bound_key) == bound_value

# Swap the bounds: minimum_key gets maximum_value and maximum_key gets minimum_value.
for bound_key, bound_value in ((minimum_key, maximum_value), (maximum_key, minimum_value)):
    db.replace(bound_key, bound_value)
    assert db.query(bound_key) == bound_value