mirror of
https://github.com/zama-ai/concrete.git
synced 2026-02-08 19:44:57 -05:00
docs: add key-value database tutorial
This commit is contained in:
@@ -21,6 +21,7 @@
|
||||
* [Floating Points](tutorial/floating\_points.md)
|
||||
* [Virtual Circuits](tutorial/virtual\_circuits.md)
|
||||
* [Direct Circuits](tutorial/direct\_circuits.md)
|
||||
* [Key Value Database](tutorial/key\_value\_database.md)
|
||||
|
||||
## How To
|
||||
|
||||
|
||||
1084
docs/tutorial/key_value_database.ipynb
Normal file
1084
docs/tutorial/key_value_database.ipynb
Normal file
File diff suppressed because it is too large
Load Diff
3
docs/tutorial/key_value_database.md
Normal file
3
docs/tutorial/key_value_database.md
Normal file
@@ -0,0 +1,3 @@
|
||||
# Key Value Database
|
||||
|
||||
This is an interactive tutorial written as a Jupyter Notebook, which you can find [here](https://github.com/zama-ai/concrete-numpy/blob/main/docs/tutorial/key_value_database.ipynb).
|
||||
247
examples/key-value-database/dynamic-size.py
Normal file
247
examples/key-value-database/dynamic-size.py
Normal file
@@ -0,0 +1,247 @@
|
||||
import time
|
||||
|
||||
import concrete.numpy as cnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
CHUNK_SIZE = 4
|
||||
|
||||
KEY_SIZE = 32
|
||||
VALUE_SIZE = 32
|
||||
|
||||
assert KEY_SIZE % CHUNK_SIZE == 0
|
||||
assert VALUE_SIZE % CHUNK_SIZE == 0
|
||||
|
||||
NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE
|
||||
NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE
|
||||
|
||||
|
||||
def encode(number, width):
|
||||
binary_repr = np.binary_repr(number, width=width)
|
||||
blocks = [binary_repr[i:i+CHUNK_SIZE] for i in range(0, len(binary_repr), CHUNK_SIZE)]
|
||||
return np.array([int(block, 2) for block in blocks])
|
||||
|
||||
def encode_key(number):
|
||||
return encode(number, width=KEY_SIZE)
|
||||
|
||||
def encode_value(number):
|
||||
return encode(number, width=VALUE_SIZE)
|
||||
|
||||
def decode(encoded_number):
|
||||
result = 0
|
||||
for i in range(len(encoded_number)):
|
||||
result += 2**(CHUNK_SIZE*i) * encoded_number[(len(encoded_number) - i) - 1]
|
||||
return result
|
||||
|
||||
|
||||
keep_if_match_lut = cnp.LookupTable([0 for _ in range(16)] + [i for i in range(16)])
|
||||
keep_if_no_match_lut = cnp.LookupTable([i for i in range(16)] + [0 for _ in range(16)])
|
||||
|
||||
def _replace_impl(key, value, candidate_key, candidate_value):
|
||||
match = np.sum((candidate_key - key) == 0) == NUMBER_OF_KEY_CHUNKS
|
||||
|
||||
packed_match_and_value = (2 ** CHUNK_SIZE) * match + value
|
||||
value_if_match_else_zeros = keep_if_match_lut[packed_match_and_value]
|
||||
|
||||
packed_match_and_candidate_value = (2 ** CHUNK_SIZE) * match + candidate_value
|
||||
zeros_if_match_else_candidate_value = keep_if_no_match_lut[packed_match_and_candidate_value]
|
||||
|
||||
return value_if_match_else_zeros + zeros_if_match_else_candidate_value
|
||||
|
||||
def _query_impl(key, candidate_key, candidate_value):
|
||||
match = np.sum((candidate_key - key) == 0) == NUMBER_OF_KEY_CHUNKS
|
||||
|
||||
packed_match_and_candidate_value = (2 ** CHUNK_SIZE) * match + candidate_value
|
||||
candidate_value_if_match_else_zeros = keep_if_match_lut[packed_match_and_candidate_value]
|
||||
|
||||
return cnp.array([match, *candidate_value_if_match_else_zeros])
|
||||
|
||||
|
||||
class KeyValueDatabase:
|
||||
|
||||
_state: list[np.ndarray]
|
||||
|
||||
_replace_circuit: cnp.Circuit
|
||||
_query_circuit: cnp.Circuit
|
||||
|
||||
def __init__(self):
|
||||
self._state = []
|
||||
|
||||
configuration = cnp.Configuration(
|
||||
enable_unsafe_features=True,
|
||||
use_insecure_key_cache=True,
|
||||
insecure_key_cache_location=".keys",
|
||||
)
|
||||
|
||||
replace_compiler = cnp.Compiler(
|
||||
_replace_impl,
|
||||
{
|
||||
"key": "encrypted", "value": "encrypted",
|
||||
"candidate_key": "encrypted", "candidate_value": "encrypted",
|
||||
}
|
||||
)
|
||||
query_compiler = cnp.Compiler(
|
||||
_query_impl,
|
||||
{
|
||||
"key": "encrypted",
|
||||
"candidate_key": "encrypted", "candidate_value": "encrypted",
|
||||
}
|
||||
)
|
||||
|
||||
replace_inputset = [
|
||||
(
|
||||
# key
|
||||
np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
# value
|
||||
np.ones(NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
# candidate_key
|
||||
np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
# candidate_value
|
||||
np.ones(NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
)
|
||||
]
|
||||
query_inputset = [
|
||||
(
|
||||
# key
|
||||
np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
# candidate_key
|
||||
np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
# candidate_value
|
||||
np.ones(NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
)
|
||||
]
|
||||
|
||||
print("Compiling replacement circuit...")
|
||||
start = time.time()
|
||||
self._replace_circuit = replace_compiler.compile(replace_inputset, configuration)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Compiling query circuit...")
|
||||
start = time.time()
|
||||
self._query_circuit = query_compiler.compile(query_inputset, configuration)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Generating replacement keys...")
|
||||
start = time.time()
|
||||
self._replace_circuit.keygen()
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Generating query keys...")
|
||||
start = time.time()
|
||||
self._query_circuit.keygen()
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
def insert(self, key, value):
|
||||
print()
|
||||
print(f"Inserting...")
|
||||
start = time.time()
|
||||
|
||||
self._state.append([encode_key(key), encode_value(value)])
|
||||
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
def replace(self, key, value):
|
||||
print()
|
||||
print(f"Replacing...")
|
||||
start = time.time()
|
||||
|
||||
encoded_key = encode_key(key)
|
||||
encoded_value = encode_value(value)
|
||||
|
||||
for entry in self._state:
|
||||
entry[1] = self._replace_circuit.encrypt_run_decrypt(encoded_key, encoded_value, *entry)
|
||||
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
def query(self, key):
|
||||
print()
|
||||
print(f"Querying...")
|
||||
start = time.time()
|
||||
|
||||
encoded_key = encode_key(key)
|
||||
|
||||
accumulation = np.zeros(1 + NUMBER_OF_VALUE_CHUNKS, dtype=np.uint64)
|
||||
for entry in self._state:
|
||||
accumulation += self._query_circuit.encrypt_run_decrypt(encoded_key, *entry)
|
||||
|
||||
match_count = accumulation[0]
|
||||
if match_count > 1:
|
||||
raise RuntimeError("Key inserted multiple times")
|
||||
|
||||
result = decode(accumulation[1:]) if match_count == 1 else None
|
||||
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
db = KeyValueDatabase()
|
||||
|
||||
# Test: Insert/Query
|
||||
db.insert(3, 4)
|
||||
assert db.query(3) == 4
|
||||
|
||||
db.replace(3, 1)
|
||||
assert db.query(3) == 1
|
||||
|
||||
# Test: Insert/Query
|
||||
db.insert(25, 40)
|
||||
assert db.query(25) == 40
|
||||
|
||||
# Test: Query Not Found
|
||||
assert db.query(4) is None
|
||||
|
||||
# Test: Replace/Query
|
||||
db.replace(3, 5)
|
||||
assert db.query(3) == 5
|
||||
|
||||
|
||||
# Define lower/upper bounds for the key
|
||||
minimum_key = 0
|
||||
maximum_key = 2 ** KEY_SIZE - 1
|
||||
# Define lower/upper bounds for the value
|
||||
minimum_value = 0
|
||||
maximum_value = 2 ** VALUE_SIZE - 1
|
||||
|
||||
|
||||
# Test: Insert/Replace/Query Bounds
|
||||
# Insert (key: minimum_key , value: minimum_value) into the database
|
||||
db.insert(minimum_key, minimum_value)
|
||||
|
||||
# Query the database for the key=minimum_key
|
||||
# The value minimum_value should be returned
|
||||
assert db.query(minimum_key) == minimum_value
|
||||
|
||||
# Insert (key: maximum_key , value: maximum_value) into the database
|
||||
db.insert(maximum_key, maximum_value)
|
||||
|
||||
# Query the database for the key=maximum_key
|
||||
# The value maximum_value should be returned
|
||||
assert db.query(maximum_key) == maximum_value
|
||||
|
||||
# Replace the value of key=minimum_key with maximum_value
|
||||
db.replace(minimum_key, maximum_value)
|
||||
|
||||
# Query the database for the key=minimum_key
|
||||
# The value maximum_value should be returned
|
||||
assert db.query(minimum_key) == maximum_value
|
||||
|
||||
# Replace the value of key=maximum_key with minimum_value
|
||||
db.replace(maximum_key, minimum_value)
|
||||
|
||||
# Query the database for the key=maximum_key
|
||||
# The value minimum_value should be returned
|
||||
assert db.query(maximum_key) == minimum_value
|
||||
299
examples/key-value-database/static-size.py
Normal file
299
examples/key-value-database/static-size.py
Normal file
@@ -0,0 +1,299 @@
|
||||
import time
|
||||
|
||||
import concrete.numpy as cnp
|
||||
import numpy as np
|
||||
|
||||
|
||||
NUMBER_OF_ENTRIES = 5
|
||||
CHUNK_SIZE = 4
|
||||
|
||||
KEY_SIZE = 32
|
||||
VALUE_SIZE = 32
|
||||
|
||||
assert KEY_SIZE % CHUNK_SIZE == 0
|
||||
assert VALUE_SIZE % CHUNK_SIZE == 0
|
||||
|
||||
NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE
|
||||
NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE
|
||||
|
||||
STATE_SHAPE = (NUMBER_OF_ENTRIES, 1 + NUMBER_OF_KEY_CHUNKS + NUMBER_OF_VALUE_CHUNKS)
|
||||
|
||||
FLAG = 0
|
||||
KEY = slice(1, 1 + NUMBER_OF_KEY_CHUNKS)
|
||||
VALUE = slice(1 + NUMBER_OF_KEY_CHUNKS, None)
|
||||
|
||||
|
||||
def encode(number: int, width: int) -> np.array:
|
||||
binary_repr = np.binary_repr(number, width=width)
|
||||
blocks = [binary_repr[i:i+CHUNK_SIZE] for i in range(0, len(binary_repr), CHUNK_SIZE)]
|
||||
return np.array([int(block, 2) for block in blocks])
|
||||
|
||||
def encode_key(number: int) -> np.array:
|
||||
return encode(number, width=KEY_SIZE)
|
||||
|
||||
def encode_value(number: int) -> np.array:
|
||||
return encode(number, width=VALUE_SIZE)
|
||||
|
||||
def decode(encoded_number: np.array) -> int:
|
||||
result = 0
|
||||
for i in range(len(encoded_number)):
|
||||
result += 2**(CHUNK_SIZE*i) * encoded_number[(len(encoded_number) - i) - 1]
|
||||
return result
|
||||
|
||||
|
||||
keep_selected_lut = cnp.LookupTable([0 for _ in range(16)] + [i for i in range(16)])
|
||||
|
||||
def _insert_impl(state, key, value):
|
||||
flags = state[:, FLAG]
|
||||
|
||||
selection = cnp.zeros(NUMBER_OF_ENTRIES)
|
||||
|
||||
found = cnp.zero()
|
||||
for i in range(NUMBER_OF_ENTRIES):
|
||||
packed_flag_and_already_found = (found * 2) + flags[i]
|
||||
is_selected = (packed_flag_and_already_found == 0)
|
||||
|
||||
selection[i] = is_selected
|
||||
found += is_selected
|
||||
|
||||
state_update = cnp.zeros(STATE_SHAPE)
|
||||
state_update[:, FLAG] = selection
|
||||
|
||||
selection = selection.reshape((-1, 1))
|
||||
|
||||
packed_selection_and_key = (selection * (2 ** CHUNK_SIZE)) + key
|
||||
key_update = keep_selected_lut[packed_selection_and_key]
|
||||
|
||||
packed_selection_and_value = selection * (2 ** CHUNK_SIZE) + value
|
||||
value_update = keep_selected_lut[packed_selection_and_value]
|
||||
|
||||
state_update[:, KEY] = key_update
|
||||
state_update[:, VALUE] = value_update
|
||||
|
||||
new_state = state + state_update
|
||||
return new_state
|
||||
|
||||
def _replace_impl(state, key, value):
|
||||
flags = state[:, FLAG]
|
||||
keys = state[:, KEY]
|
||||
values = state[:, VALUE]
|
||||
|
||||
equal_rows = (np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS)
|
||||
selection = (flags * 2 + equal_rows == 3).reshape((-1, 1))
|
||||
|
||||
packed_selection_and_value = selection * (2 ** CHUNK_SIZE) + value
|
||||
set_value = keep_selected_lut[packed_selection_and_value]
|
||||
|
||||
inverse_selection = 1 - selection
|
||||
packed_inverse_selection_and_values = inverse_selection * (2 ** CHUNK_SIZE) + values
|
||||
kept_values = keep_selected_lut[packed_inverse_selection_and_values]
|
||||
|
||||
new_values = kept_values + set_value
|
||||
state[:, VALUE] = new_values
|
||||
|
||||
return state
|
||||
|
||||
def _query_impl(state, key):
|
||||
keys = state[:, KEY]
|
||||
values = state[:, VALUE]
|
||||
|
||||
selection = (np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1))
|
||||
found = np.sum(selection)
|
||||
|
||||
packed_selection_and_values = selection * (2 ** CHUNK_SIZE) + values
|
||||
value_selection = keep_selected_lut[packed_selection_and_values]
|
||||
value = np.sum(value_selection, axis=0)
|
||||
|
||||
return cnp.array([found, *value])
|
||||
|
||||
|
||||
class KeyValueDatabase:
|
||||
|
||||
_state: np.ndarray
|
||||
|
||||
_insert_circuit: cnp.Circuit
|
||||
_replace_circuit: cnp.Circuit
|
||||
_query_circuit: cnp.Circuit
|
||||
|
||||
def __init__(self):
|
||||
self._state = np.zeros(STATE_SHAPE, dtype=np.int64)
|
||||
|
||||
inputset_binary = [
|
||||
(
|
||||
# state
|
||||
np.zeros(STATE_SHAPE, dtype=np.int64),
|
||||
# key
|
||||
np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
)
|
||||
]
|
||||
inputset_ternary = [
|
||||
(
|
||||
# state
|
||||
np.zeros(STATE_SHAPE, dtype=np.int64),
|
||||
# key
|
||||
np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
# value
|
||||
np.ones(NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1),
|
||||
)
|
||||
]
|
||||
|
||||
configuration = cnp.Configuration(
|
||||
enable_unsafe_features=True,
|
||||
use_insecure_key_cache=True,
|
||||
insecure_key_cache_location=".keys",
|
||||
)
|
||||
|
||||
insert_compiler = cnp.Compiler(
|
||||
_insert_impl,
|
||||
{"state": "encrypted", "key": "encrypted", "value": "encrypted"}
|
||||
)
|
||||
replace_compiler = cnp.Compiler(
|
||||
_replace_impl,
|
||||
{"state": "encrypted", "key": "encrypted", "value": "encrypted"}
|
||||
)
|
||||
query_compiler = cnp.Compiler(
|
||||
_query_impl,
|
||||
{"state": "encrypted", "key": "encrypted"}
|
||||
)
|
||||
|
||||
print()
|
||||
|
||||
print("Compiling insertion circuit...")
|
||||
start = time.time()
|
||||
self._insert_circuit = insert_compiler.compile(inputset_ternary, configuration)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Compiling replacement circuit...")
|
||||
start = time.time()
|
||||
self._replace_circuit = replace_compiler.compile(inputset_ternary, configuration)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Compiling query circuit...")
|
||||
start = time.time()
|
||||
self._query_circuit = query_compiler.compile(inputset_binary, configuration)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Generating insertion keys...")
|
||||
start = time.time()
|
||||
self._insert_circuit.keygen()
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Generating replacement keys...")
|
||||
start = time.time()
|
||||
self._replace_circuit.keygen()
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
print()
|
||||
|
||||
print("Generating query keys...")
|
||||
start = time.time()
|
||||
self._query_circuit.keygen()
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
def insert(self, key, value):
|
||||
print()
|
||||
print(f"Inserting...")
|
||||
start = time.time()
|
||||
self._state = self._insert_circuit.encrypt_run_decrypt(
|
||||
self._state, encode_key(key), encode_value(value)
|
||||
)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
def replace(self, key, value):
|
||||
print()
|
||||
print(f"Replacing...")
|
||||
start = time.time()
|
||||
self._state = self._replace_circuit.encrypt_run_decrypt(
|
||||
self._state, encode_key(key), encode_value(value)
|
||||
)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
def query(self, key):
|
||||
print()
|
||||
print(f"Querying...")
|
||||
start = time.time()
|
||||
result = self._query_circuit.encrypt_run_decrypt(
|
||||
self._state, encode_key(key)
|
||||
)
|
||||
end = time.time()
|
||||
print(f"(took {end - start:.3f} seconds)")
|
||||
|
||||
if result[0] == 0:
|
||||
return None
|
||||
|
||||
return decode(result[1:])
|
||||
|
||||
|
||||
db = KeyValueDatabase()
|
||||
|
||||
# Test: Insert/Query
|
||||
db.insert(3, 4)
|
||||
assert db.query(3) == 4
|
||||
|
||||
db.replace(3, 1)
|
||||
assert db.query(3) == 1
|
||||
|
||||
# Test: Insert/Query
|
||||
db.insert(25, 40)
|
||||
assert db.query(25) == 40
|
||||
|
||||
# Test: Query Not Found
|
||||
assert db.query(4) is None
|
||||
|
||||
# Test: Replace/Query
|
||||
db.replace(3, 5)
|
||||
assert db.query(3) == 5
|
||||
|
||||
|
||||
# Define lower/upper bounds for the key
|
||||
minimum_key = 1
|
||||
maximum_key = 2 ** KEY_SIZE - 1
|
||||
# Define lower/upper bounds for the value
|
||||
minimum_value = 1
|
||||
maximum_value = 2 ** VALUE_SIZE - 1
|
||||
|
||||
|
||||
# Test: Insert/Replace/Query Bounds
|
||||
# Insert (key: minimum_key , value: minimum_value) into the database
|
||||
db.insert(minimum_key, minimum_value)
|
||||
|
||||
# Query the database for the key=minimum_key
|
||||
# The value minimum_value should be returned
|
||||
assert db.query(minimum_key) == minimum_value
|
||||
|
||||
# Insert (key: maximum_key , value: maximum_value) into the database
|
||||
db.insert(maximum_key, maximum_value)
|
||||
|
||||
# Query the database for the key=maximum_key
|
||||
# The value maximum_value should be returned
|
||||
assert db.query(maximum_key) == maximum_value
|
||||
|
||||
# Replace the value of key=minimum_key with maximum_value
|
||||
db.replace(minimum_key, maximum_value)
|
||||
|
||||
# Query the database for the key=minimum_key
|
||||
# The value maximum_value should be returned
|
||||
assert db.query(minimum_key) == maximum_value
|
||||
|
||||
# Replace the value of key=maximum_key with minimum_value
|
||||
db.replace(maximum_key, minimum_value)
|
||||
|
||||
# Query the database for the key=maximum_key
|
||||
# The value minimum_value should be returned
|
||||
assert db.query(maximum_key) == minimum_value
|
||||
Reference in New Issue
Block a user