"""Toy key-value database whose lookups and updates run through Concrete FHE circuits.

Keys and values are unsigned integers split into ``CHUNK_SIZE``-bit chunks.
Two circuits are compiled: one that conditionally replaces the value of an
entry and one that conditionally extracts it. Both decide "is this the entry
we are looking for?" inside the circuit by comparing key chunks.
"""

import time
from contextlib import contextmanager
from typing import List

import numpy as np
from concrete import fhe

CHUNK_SIZE = 4
# Number of distinct values a single chunk can take.
NUMBER_OF_CHUNK_VALUES = 2**CHUNK_SIZE

KEY_SIZE = 32
VALUE_SIZE = 32

assert KEY_SIZE % CHUNK_SIZE == 0
assert VALUE_SIZE % CHUNK_SIZE == 0

NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE
NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE


@contextmanager
def _timed(label):
    """Print a blank line and ``label``, run the body, then print the elapsed time.

    The elapsed time is intentionally not printed when the body raises, which
    matches the behaviour of straight-line timing code.
    """
    print()
    print(label)
    start = time.perf_counter()
    yield
    print(f"(took {time.perf_counter() - start:.3f} seconds)")


def encode(number, width):
    """Split ``number`` into big-endian chunks of ``CHUNK_SIZE`` bits.

    Args:
        number: Non-negative integer to encode.
        width: Total number of bits of the encoding.

    Returns:
        A numpy array holding one integer per chunk, most significant first.

    Raises:
        ValueError: If ``number`` is negative or does not fit in ``width`` bits.
    """
    # Without this check a negative number would be encoded as its two's
    # complement (aliasing another valid key/value), and a too-large number
    # could yield more chunks than the circuits were compiled for.
    if not 0 <= number < 2**width:
        message = f"{number} cannot be encoded as an unsigned {width}-bit integer"
        raise ValueError(message)

    binary_repr = np.binary_repr(number, width=width)
    return np.array(
        [
            int(binary_repr[start : start + CHUNK_SIZE], 2)
            for start in range(0, len(binary_repr), CHUNK_SIZE)
        ]
    )


def encode_key(number):
    """Encode ``number`` as ``NUMBER_OF_KEY_CHUNKS`` key chunks."""
    return encode(number, width=KEY_SIZE)


def encode_value(number):
    """Encode ``number`` as ``NUMBER_OF_VALUE_CHUNKS`` value chunks."""
    return encode(number, width=VALUE_SIZE)


def decode(encoded_number):
    """Rebuild the integer from its big-endian chunks (inverse of ``encode``).

    Returns:
        A Python ``int``; an empty sequence decodes to ``0``.
    """
    result = 0
    # Walk from the least significant chunk; converting to a Python int keeps
    # the arithmetic exact regardless of the numpy dtype of the chunks.
    for position, chunk in enumerate(encoded_number[::-1]):
        result += int(chunk) << (CHUNK_SIZE * position)
    return result


# Both tables are indexed by ``NUMBER_OF_CHUNK_VALUES * match + chunk``:
# the first half is hit when match == 0, the second half when match == 1.
keep_if_match_lut = fhe.LookupTable(
    [0] * NUMBER_OF_CHUNK_VALUES + list(range(NUMBER_OF_CHUNK_VALUES))
)
keep_if_no_match_lut = fhe.LookupTable(
    list(range(NUMBER_OF_CHUNK_VALUES)) + [0] * NUMBER_OF_CHUNK_VALUES
)


def _is_key_match(key, candidate_key):
    """Return 1 when every chunk of ``candidate_key`` equals ``key``, else 0."""
    number_of_matching_chunks = np.sum((candidate_key - key) == 0)
    # The count can reach NUMBER_OF_KEY_CHUNKS, so tell the compiler it must
    # be able to store that value.
    fhe.hint(number_of_matching_chunks, can_store=NUMBER_OF_KEY_CHUNKS)
    return number_of_matching_chunks == NUMBER_OF_KEY_CHUNKS


def _replace_impl(key, value, candidate_key, candidate_value):
    """Circuit body: new value chunks for one stored entry.

    Returns ``value`` when ``candidate_key`` equals ``key`` and
    ``candidate_value`` otherwise.
    """
    match = _is_key_match(key, candidate_key)

    packed_match_and_value = NUMBER_OF_CHUNK_VALUES * match + value
    value_if_match_else_zeros = keep_if_match_lut[packed_match_and_value]

    packed_match_and_candidate_value = NUMBER_OF_CHUNK_VALUES * match + candidate_value
    zeros_if_match_else_candidate_value = keep_if_no_match_lut[packed_match_and_candidate_value]

    # Exactly one of the two terms is non-zero per chunk, so the sum selects.
    return value_if_match_else_zeros + zeros_if_match_else_candidate_value


def _query_impl(key, candidate_key, candidate_value):
    """Circuit body: contribution of one stored entry to a query.

    Returns an array of ``1 + NUMBER_OF_VALUE_CHUNKS`` elements: the match
    flag followed by ``candidate_value`` on a match, or zeros otherwise.
    """
    match = _is_key_match(key, candidate_key)

    packed_match_and_candidate_value = NUMBER_OF_CHUNK_VALUES * match + candidate_value
    candidate_value_if_match_else_zeros = keep_if_match_lut[packed_match_and_candidate_value]

    return fhe.array([match, *candidate_value_if_match_else_zeros])


class KeyValueDatabase:
    """In-memory key-value store whose replace/query logic runs in FHE circuits.

    Entries are kept in ``_state`` as ``[encoded_key, encoded_value]`` pairs of
    chunk arrays; every replace/query visits all entries and calls
    ``encrypt_run_decrypt`` on the corresponding circuit for each of them.
    """

    _state: List[List[np.ndarray]]

    _replace_circuit: fhe.Circuit
    _query_circuit: fhe.Circuit

    def __init__(self):
        """Compile both circuits and generate their keys."""
        self._state = []

        # NOTE(review): these options enable an insecure on-disk key cache
        # under ".keys"; suitable for a demo only.
        configuration = fhe.Configuration(
            enable_unsafe_features=True,
            use_insecure_key_cache=True,
            insecure_key_cache_location=".keys",
        )

        replace_compiler = fhe.Compiler(
            _replace_impl,
            {
                "key": "encrypted",
                "value": "encrypted",
                "candidate_key": "encrypted",
                "candidate_value": "encrypted",
            },
        )
        query_compiler = fhe.Compiler(
            _query_impl,
            {
                "key": "encrypted",
                "candidate_key": "encrypted",
                "candidate_value": "encrypted",
            },
        )

        # A single sample with every chunk at its maximum value.
        maximum_chunk = NUMBER_OF_CHUNK_VALUES - 1
        maximum_key_chunks = np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * maximum_chunk
        maximum_value_chunks = np.ones(NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) * maximum_chunk

        replace_inputset = [
            (
                maximum_key_chunks.copy(),  # key
                maximum_value_chunks.copy(),  # value
                maximum_key_chunks.copy(),  # candidate_key
                maximum_value_chunks.copy(),  # candidate_value
            )
        ]
        query_inputset = [
            (
                maximum_key_chunks.copy(),  # key
                maximum_key_chunks.copy(),  # candidate_key
                maximum_value_chunks.copy(),  # candidate_value
            )
        ]

        with _timed("Compiling replacement circuit..."):
            self._replace_circuit = replace_compiler.compile(replace_inputset, configuration)

        with _timed("Compiling query circuit..."):
            self._query_circuit = query_compiler.compile(query_inputset, configuration)

        with _timed("Generating replacement keys..."):
            self._replace_circuit.keygen()

        with _timed("Generating query keys..."):
            self._query_circuit.keygen()

    def insert(self, key, value):
        """Append a new ``(key, value)`` entry.

        No duplicate check is performed; querying a key that was inserted
        more than once raises ``RuntimeError``.

        Raises:
            ValueError: If ``key`` or ``value`` is out of range.
        """
        with _timed("Inserting..."):
            self._state.append([encode_key(key), encode_value(value)])

    def replace(self, key, value):
        """Set the value of every entry whose key equals ``key`` to ``value``.

        Entries with a different key keep their value. The circuit is run once
        per stored entry.

        Raises:
            ValueError: If ``key`` or ``value`` is out of range.
        """
        with _timed("Replacing..."):
            encoded_key = encode_key(key)
            encoded_value = encode_value(value)

            for entry in self._state:
                entry[1] = self._replace_circuit.encrypt_run_decrypt(
                    encoded_key, encoded_value, *entry
                )

    def query(self, key):
        """Return the value stored under ``key``, or ``None`` if it is absent.

        Raises:
            ValueError: If ``key`` is out of range.
            RuntimeError: If ``key`` matches more than one entry.
        """
        with _timed("Querying..."):
            encoded_key = encode_key(key)

            # Slot 0 counts matching entries; the remaining slots sum the
            # value chunks of the matching entries (zeros for non-matches).
            accumulation = np.zeros(1 + NUMBER_OF_VALUE_CHUNKS, dtype=np.int64)
            for entry in self._state:
                accumulation += self._query_circuit.encrypt_run_decrypt(encoded_key, *entry)

            match_count = accumulation[0]
            if match_count > 1:
                message = "Key inserted multiple times"
                raise RuntimeError(message)

            result = decode(accumulation[1:]) if match_count == 1 else None

        return result


db = KeyValueDatabase()

# Test: Insert/Query
db.insert(3, 4)
assert db.query(3) == 4

db.replace(3, 1)
assert db.query(3) == 1

# Test: Insert/Query
db.insert(25, 40)
assert db.query(25) == 40

# Test: Query Not Found
assert db.query(4) is None

# Test: Replace/Query
db.replace(3, 5)
assert db.query(3) == 5

# Define lower/upper bounds for the key
minimum_key = 0
maximum_key = 2**KEY_SIZE - 1
# Define lower/upper bounds for the value
minimum_value = 0
maximum_value = 2**VALUE_SIZE - 1

# Test: Insert/Replace/Query Bounds
# Insert (key: minimum_key, value: minimum_value) into the database
db.insert(minimum_key, minimum_value)

# Query the database for the key=minimum_key
# The value minimum_value should be returned
assert db.query(minimum_key) == minimum_value

# Insert (key: maximum_key, value: maximum_value) into the database
db.insert(maximum_key, maximum_value)

# Query the database for the key=maximum_key
# The value maximum_value should be returned
assert db.query(maximum_key) == maximum_value

# Replace the value of key=minimum_key with maximum_value
db.replace(minimum_key, maximum_value)

# Query the database for the key=minimum_key
# The value maximum_value should be returned
assert db.query(minimum_key) == maximum_value

# Replace the value of key=maximum_key with minimum_value
db.replace(maximum_key, minimum_value)

# Query the database for the key=maximum_key
# The value minimum_value should be returned
assert db.query(maximum_key) == minimum_value