# Key Value Database

This is an interactive tutorial for an encrypted key-value database. The database supports three operations: **Insert, Replace, and Query**. All three operations are implemented as fully homomorphic encryption (FHE) circuits.

In `examples/key-value-database/`, you will find the following files:

- `static-size.py`: This file contains a static size database implementation, meaning that the number of entries is given as a parameter at the beginning.
- `dynamic-size.py`: This file contains a dynamic size database implementation, meaning that the database starts as a zero entry database, and is grown as needed.

This tutorial goes over the statically-sized database implementation. The dynamic database implementation is very similar, and the reader is encouraged to look at the code to see the differences.


Below are the import statements.

**time:** Used for measuring the time to create keys, encrypt and run circuits.

**concrete.fhe:** Used for implementing homomorphic circuits (imported below with `from concrete import fhe`).

**numpy:** Used for mathematical operations. The Concrete library compiles numpy operations into FHE operations.


In [1]:
import time

from concrete import fhe
import numpy as np

Below are the database configuration parameters. 

**Number of Entries:** Defines the maximum number of insertable (key, value) pairs. 

**Chunk Size:** Defines the size of each chunk in bits. Chunks are the smallest units that keys and values are split into.

**Key Size:** Defines the size of each key.

**Value Size:** Defines the size of each value.

In [2]:

# The number of entries in the database
NUMBER_OF_ENTRIES = 5
# The number of bits in each chunk
CHUNK_SIZE = 4

# The number of bits in the key and value
KEY_SIZE = 32
VALUE_SIZE = 32

Below is the definition of the state and the indexers used to access it.

Each row of the state holds one flag, followed by the key chunks and then the value chunks, so the row width depends on the key and value sizes. The table below gives some examples with a chunk size of 4.

| Flag Size | Key Size | Number of Key Chunks | Value Size | Number of Value Chunks |
| --- | --- | --- | --- | --- |
| 1         | 32       | 32/4 = 8                   | 32         | 32/4 = 8                      |
| 1         | 8        | 8/4 = 2                    | 16          | 16/4 = 4                       |
| 1         | 4        | 4/4 = 1                    | 4          | 4/4 = 1                       |

In [3]:

# Key and Value size must be a multiple of chunk size
assert KEY_SIZE % CHUNK_SIZE == 0
assert VALUE_SIZE % CHUNK_SIZE == 0

# Required number of chunks to store keys and values
NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE
NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE

# The shape of the state as a tensor
# Shape:
# | Flag Size | Key Size | Value Size |
# | 1         | 32/4 = 8 | 32/4 = 8   |
STATE_SHAPE = (NUMBER_OF_ENTRIES, 1 + NUMBER_OF_KEY_CHUNKS + NUMBER_OF_VALUE_CHUNKS)

The slices below are used to index the different parts of the state.

In [4]:

# Indexers for each part of the state
FLAG = 0
KEY = slice(1, 1 + NUMBER_OF_KEY_CHUNKS)
VALUE = slice(1 + NUMBER_OF_KEY_CHUNKS, None)
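
To make the row layout concrete, here is a small plain-Python illustration (no encryption involved). It repeats the constants and indexers from the cells above so that it runs on its own; the example row is made up for illustration.

```python
# Plain-Python illustration of the row layout; nothing here is encrypted.
CHUNK_SIZE = 4
KEY_SIZE = 32
VALUE_SIZE = 32

NUMBER_OF_KEY_CHUNKS = KEY_SIZE // CHUNK_SIZE      # 8
NUMBER_OF_VALUE_CHUNKS = VALUE_SIZE // CHUNK_SIZE  # 8

FLAG = 0
KEY = slice(1, 1 + NUMBER_OF_KEY_CHUNKS)
VALUE = slice(1 + NUMBER_OF_KEY_CHUNKS, None)

# One used row: flag = 1, key = 3, value = 4 (most significant chunk first)
row = [1] + [0, 0, 0, 0, 0, 0, 0, 3] + [0, 0, 0, 0, 0, 0, 0, 4]

print(len(row))    # 17 = 1 + 8 + 8
print(row[FLAG])   # 1
print(row[KEY])    # [0, 0, 0, 0, 0, 0, 0, 3]
print(row[VALUE])  # [0, 0, 0, 0, 0, 0, 0, 4]
```

The same `FLAG`, `KEY` and `VALUE` indexers are applied to the second axis of the state, for example `state[:, KEY]`, to pick that part out of every row at once.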

### Encode/Decode Functions

Encode/Decode functions are used to convert between integers and numpy arrays. The interface exposes integers, but the state is stored and processed as a numpy array.

#### Encode

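Encodes a number into a numpy array.

- The number is written in binary, zero-padded to `width` bits, and then split into chunks of `CHUNK_SIZE` bits.
- Each chunk is then converted to an integer.
- The integers are then stored in a numpy array, most significant chunk first.

The examples below assume `CHUNK_SIZE = 4`, so a width of 16 bits gives 4 chunks and a width of 12 bits gives 3 chunks.

| Function Call | Input (Integer) | Width (Bits) | Result (Numpy Array) |
| --- | --- | --- | --- |
| encode(25, 16) | 25 | 16 | [0, 0, 1, 9] |
| encode(40, 16) | 40 | 16 | [0, 0, 2, 8] |
| encode(11, 12) | 11 | 12 | [0, 0, 11] |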


In [5]:
# Encode a number as an array of chunks, most significant chunk first
def encode(number: int, width: int) -> np.ndarray:
    # Zero-padded binary string that is `width` bits long
    binary_repr = np.binary_repr(number, width=width)
    # Split the bit string into blocks of CHUNK_SIZE bits
    blocks = [binary_repr[i:i+CHUNK_SIZE] for i in range(0, len(binary_repr), CHUNK_SIZE)]
    return np.array([int(block, 2) for block in blocks])

# Encode a number with the key size
def encode_key(number: int) -> np.ndarray:
    return encode(number, width=KEY_SIZE)

# Encode a number with the value size
def encode_value(number: int) -> np.ndarray:
    return encode(number, width=VALUE_SIZE)

#### Decode

Decodes a numpy array into a number.

| Function Call | Input(Numpy Array) | Result(Integer) |
| --- | --- | --- |
| decode([0, 0, 1, 9]) | [0, 0, 1, 9] | 25 |
| decode([0, 0, 2, 8]) | [0, 0, 2, 8] | 40 |
| decode([0, 0, 11]) | [0, 0, 11] | 11 |

In [6]:
def decode(encoded_number: np.ndarray) -> int:
    result = 0
    for i in range(len(encoded_number)):
        # Chunks are stored most significant first, so index from the end
        result += 2**(CHUNK_SIZE*i) * int(encoded_number[(len(encoded_number) - i) - 1])
    return result
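
As a quick sanity check of the chunking scheme, the sketch below re-implements encoding and decoding with plain Python lists and shows that they round-trip. `encode_plain` and `decode_plain` are hypothetical helper names introduced only for this illustration, and the range check is an addition that the tutorial's `encode` does not perform. The sketch assumes `width` is a multiple of `CHUNK_SIZE`.

```python
CHUNK_SIZE = 4  # bits per chunk, as in the tutorial

def encode_plain(number: int, width: int) -> list:
    """Hypothetical list-based counterpart of `encode`, with a range check."""
    if not 0 <= number < 2 ** width:
        raise ValueError(f"{number} does not fit in {width} bits")
    bits = format(number, f"0{width}b")  # zero-padded binary string
    return [int(bits[i:i + CHUNK_SIZE], 2) for i in range(0, width, CHUNK_SIZE)]

def decode_plain(chunks: list) -> int:
    """Hypothetical list-based counterpart of `decode`."""
    result = 0
    for chunk in chunks:
        # Shift the accumulated value left by one chunk, then add the next chunk
        result = result * 2 ** CHUNK_SIZE + chunk
    return result

print(encode_plain(25, 16))                # [0, 0, 1, 9]
print(decode_plain([0, 0, 1, 9]))          # 25
print(decode_plain(encode_plain(40, 16)))  # 40
```

Here `width` is the total number of bits, so 16 bits with 4-bit chunks yields 4 chunks.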

### Row Selection with Table Lookups

The `keep_selected` function is used to select the correct row of the database for each operation.

Below is the Python definition of the function.

In [7]:
def keep_selected(value, selected):
    if selected:
        return value
    else:
        return 0

This function takes any value and a boolean flag that indicates whether the value is selected. Within homomorphic encryption circuits, we cannot compile this function, because encrypted values cannot affect control flow. Instead, we turn this function into a lookup table.

The selected bit is prepended to the value, which gives a packed index `i`, and the function is modified to act as below.

- `selected=0: i=0..15 -> 0`
- `selected=1: i=16..31 -> i-16`

Below is the Python code for the lookup table.

In [8]:
keep_selected_lut = fhe.LookupTable([0 for _ in range(16)] + [i for i in range(16)])

The most significant bit of the input to the lookup table represents the select bit. If `select=0`, the input is in `0..15` and the output is `0`. If `select=1`, the input is in `16..31` and the output is `i-16`, which is the value itself.

To summarize, we could implement the `keep_selected` function as below.

In [9]:
def keep_selected_using_lut(value, selected):
    packed = (2 ** CHUNK_SIZE) * selected + value
    return keep_selected_lut[packed]
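
The equivalence between the branching version and the table version can be checked exhaustively on plain integers. The sketch below uses an ordinary Python list in place of `fhe.LookupTable` (an assumption made so the example runs without Concrete) and compares both versions for every 4-bit value and both select bits. The name `keep_selected_using_list` is hypothetical.

```python
CHUNK_SIZE = 4

# Plain list standing in for fhe.LookupTable: 16 zeros, then the values 0..15
lut = [0] * 16 + list(range(16))

def keep_selected(value, selected):
    return value if selected else 0

def keep_selected_using_list(value, selected):
    # Prepend the select bit to the value to form the table index
    packed = (2 ** CHUNK_SIZE) * selected + value
    return lut[packed]

# Exhaustively compare the two versions over every chunk value and select bit
mismatches = [
    (value, selected)
    for value in range(2 ** CHUNK_SIZE)
    for selected in (0, 1)
    if keep_selected(value, selected) != keep_selected_using_list(value, selected)
]
print(mismatches)  # []
```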


### Circuit Implementation Functions

The following functions are used to implement the key-value database circuits. 
Three circuits are implemented: 
- insert: Inserts a key value pair into the database
- replace: Replaces the value of a key in the database
- query: Queries the database for a key and returns the value


#### Insert

The insert algorithm is as follows:
- Create a selection array to select a certain row of the database
- Fill this array by setting the entry for the first empty (unused) row of the database to 1
- Create a state update array, where the first empty row is set to the new key and value and all other rows are zero
- Add the state update array to the state

The implementation is below.

In [10]:

# Insert a key value pair into the database
# - state: The state of the database
# - key: The key to insert
# - value: The value to insert
# Returns the updated state
def _insert_impl(state, key, value):
    # Get the used bit from the state
    # This bit is used to determine if an entry is used or not
    flags = state[:, FLAG]

    # Create a selection array
    # This array is used to select the first unused entry
    selection = fhe.zeros(NUMBER_OF_ENTRIES)

    # The found bit is used to determine if an unused entry has been found
    found = fhe.zero()
    for i in range(NUMBER_OF_ENTRIES):
        # The packed flag and found bit are used to determine if the entry is unused
        # | Flag | Found |
        # | 0    | 0     | -> Unused, select
        # | 0    | 1     | -> Unused, skip
        # | 1    | 0     | -> Used, skip
        # | 1    | 1     | -> Used, skip
        packed_flag_and_found = (found * 2) + flags[i]
        # Use the packed flag and found bit to determine if the entry is unused
        is_selected = (packed_flag_and_found == 0)

        # Update the selection array
        selection[i] = is_selected
        # Update the found bit, so all entries will be 
        # skipped after the first unused entry is found
        found += is_selected

    # Create a state update array
    state_update = fhe.zeros(STATE_SHAPE)
    # Update the state update array with the selection array
    state_update[:, FLAG] = selection

    # Reshape the selection array into a column so that it broadcasts
    # against the key and value chunks
    selection = selection.reshape((-1, 1))

    # Create a packed selection and key array
    # This array is used to update the key of the selected entry
    packed_selection_and_key = (selection * (2 ** CHUNK_SIZE)) + key
    key_update = keep_selected_lut[packed_selection_and_key]

    # Create a packed selection and value array
    # This array is used to update the value of the selected entry
    packed_selection_and_value = selection * (2 ** CHUNK_SIZE) + value
    value_update = keep_selected_lut[packed_selection_and_value]

    # Update the state update array with the key and value update arrays
    state_update[:, KEY] = key_update
    state_update[:, VALUE] = value_update

    # Update the state with the state update array
    new_state = state + state_update
    return new_state
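
The selection loop in `_insert_impl` can be traced on unencrypted data. The sketch below is a plaintext model with plain Python lists, using hypothetical helper names (`first_unused_selection`, `insert_plain`) and a shortened row of 1 flag, 2 key chunks and 2 value chunks. A multiplication by the select bit stands in for the lookup table.

```python
def first_unused_selection(flags):
    """Return a one-hot list marking the first row whose flag is 0."""
    selection = []
    found = 0
    for flag in flags:
        packed_flag_and_found = found * 2 + flag
        # Selected only if the row is unused and nothing was selected before
        is_selected = int(packed_flag_and_found == 0)
        selection.append(is_selected)
        found += is_selected
    return selection

def insert_plain(state, key_chunks, value_chunks):
    """Add [1, *key, *value] to the selected row; other rows get zeros added."""
    selection = first_unused_selection([row[0] for row in state])
    new_row = [1] + key_chunks + value_chunks
    return [
        [old + selected * new for old, new in zip(row, new_row)]
        for row, selected in zip(state, selection)
    ]

state = [[1, 0, 3, 0, 4], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
print(first_unused_selection([1, 0, 0]))       # [0, 1, 0]
print(insert_plain(state, [1, 9], [2, 8])[1])  # [1, 1, 9, 2, 8]
print(first_unused_selection([1, 1, 1]))       # [0, 0, 0]
```

In this model, when every row is already in use the selection is all zeros, so the state update is zero and the state is returned unchanged.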

#### Replace

The replace algorithm is as follows:
- Create an equal-rows array to select rows that match the given key in the database
- Create a selection array by combining the equal-rows array with the flags, so that only rows that are currently in use can be selected
- The selection array is 1 for the used row that contains the key, and 0 for all other rows
- Create an inverse selection array by inverting the selection array
- The row set to 1 in the selection array will be updated, whereas all other rows will stay the same
- To do this, we use the lookup table to keep the new value only in the selected row, and the old values only in the rows picked by the inverse selection array
- We then add the two arrays to get the new values

In [11]:

# Replace the value of a key in the database
#   If the key is not in the database, nothing happens
#   If the key is in the database, the value is replaced
# - state: The state of the database
# - key: The key to replace
# - value: The value to replace
# Returns the updated state
def _replace_impl(state, key, value):
    # Get the flags, keys and values from the state
    flags = state[:, FLAG]
    keys = state[:, KEY]
    values = state[:, VALUE]

    

    # Create an equal_rows array
    # This array is used to select all entries with the given key
    # The equal_rows array is created by comparing the keys in the state
    # with the given key, and only setting the entry to 1 if the keys are equal
    # Example:
    #   keys = [[1, 0, 1, 0], [0, 1, 0, 1]]
    #   key = [1, 0, 1, 0]
    #   equal_rows = [1, 0]
    equal_rows = (np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS)

    # Create a selection array
    # This array is used to select the entry to change the value of
    # The selection array is created by combining the equal_rows array
    # with the flags array, which is used to determine if an entry is used or not
    # The reason for combining the equal_rows array with the flags array
    # is to make sure that only used entries are selected
    selection = (flags * 2 + equal_rows == 3).reshape((-1, 1))
    
    # Create a packed selection and value array
    # This array is used to update the value of the selected entry
    packed_selection_and_value = selection * (2 ** CHUNK_SIZE) + value
    set_value = keep_selected_lut[packed_selection_and_value]

    # Create an inverse selection array
    # This array is used to pick entries that are not selected
    # Example:
    #   selection = [1, 0, 0]
    #   inverse_selection = [0, 1, 1]
    inverse_selection = 1 - selection

    # Create a packed inverse selection and value array
    # This array is used to keep the value of the entries that are not selected
    packed_inverse_selection_and_values = inverse_selection * (2 ** CHUNK_SIZE) + values
    kept_values = keep_selected_lut[packed_inverse_selection_and_values]

    # Update the values of the state with the new values
    new_values = kept_values + set_value
    state[:, VALUE] = new_values

    return state
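
The same steps can be followed on unencrypted data. The sketch below is a plaintext model of the replace logic with plain Python lists; `replace_plain` is a hypothetical helper, rows are shortened to 1 flag, 2 key chunks and 2 value chunks, and multiplication by the select bit stands in for the lookup table.

```python
def replace_plain(state, key_chunks, value_chunks, n_key_chunks):
    """Plaintext model: overwrite the value of the used row whose key matches."""
    new_state = []
    for row in state:
        flag = row[0]
        key = row[1:1 + n_key_chunks]
        values = row[1 + n_key_chunks:]
        # 1 only when every key chunk matches
        equal = int(sum(a == b for a, b in zip(key, key_chunks)) == n_key_chunks)
        # 1 only when the row is used (flag = 1) and the key matches
        selected = int(flag * 2 + equal == 3)
        inverse = 1 - selected
        # Keep the old values in unselected rows, take the new ones otherwise
        new_values = [inverse * old + selected * new
                      for old, new in zip(values, value_chunks)]
        new_state.append([flag] + key + new_values)
    return new_state

state = [[1, 0, 3, 0, 4], [1, 1, 9, 2, 8], [0, 0, 0, 0, 0]]
print(replace_plain(state, [0, 3], [0, 1], 2)[0])       # [1, 0, 3, 0, 1]
print(replace_plain(state, [0, 7], [0, 1], 2) == state)  # True
```

Combining the flag with the key comparison matters for the all-zero key: it matches the key chunks of an unused row, but the row is not selected because its flag is 0.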

#### Query

The query algorithm is as follows:
- Create a selection array to select a certain row of the database
- Set the selection array to 1 for the row that contains the key
- Use the lookup table with the selection array to zero the values of all rows that do not contain the key
- Sum over the rows to get the remaining non-zero row, basically doing a dimension reduction
- Prepend the found flag to the value and return the resulting array
- The resulting array will be destructured in the non-encrypted query function

In [12]:

# Query the database for a key and return the value
# - state: The state of the database
# - key: The key to query
# Returns an array with the following format:
#   [found, value]
#   found: 1 if the key was found, 0 otherwise
#   value: The value of the key if the key was found, 0 otherwise
def _query_impl(state, key):
    # Get the keys and values from the state
    keys = state[:, KEY]
    values = state[:, VALUE]

    # Create a selection array
    # This array is used to select the entry with the given key
    # The selection array is created by comparing the keys in the state
    # with the given key, and only setting the entry to 1 if the keys are equal
    # Example:
    #   keys = [[1, 0, 1, 0], [0, 1, 0, 1]]
    #   key = [1, 0, 1, 0]
    #   selection = [1, 0]
    selection = (np.sum((keys - key) == 0, axis=1) == NUMBER_OF_KEY_CHUNKS).reshape((-1, 1))

    # Create a found bit
    # This bit is used to determine if the key was found
    # The found bit is set to 1 if the key was found, and 0 otherwise
    found = np.sum(selection)

    # Create a packed selection and value array
    # This array is used to get the value of the selected entry
    packed_selection_and_values = selection * (2 ** CHUNK_SIZE) + values
    value_selection = keep_selected_lut[packed_selection_and_values]

    # Sum the value selection array to get the value
    value = np.sum(value_selection, axis=0)

    # Return the found bit and the value
    return fhe.array([found, *value])
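
A plaintext model makes the mask-and-sum step easier to follow. The sketch below mirrors the comparisons in `_query_impl` with plain Python lists; `query_plain` is a hypothetical helper, rows are shortened to 1 flag, 2 key chunks and 2 value chunks, and multiplication by the select bit stands in for the lookup table.

```python
def query_plain(state, key_chunks, n_key_chunks):
    """Plaintext model: return [found, *value] for the given key chunks."""
    found = 0
    value = [0] * (len(state[0]) - 1 - n_key_chunks)
    for row in state:
        key = row[1:1 + n_key_chunks]
        values = row[1 + n_key_chunks:]
        # 1 only when every key chunk matches
        selected = int(sum(a == b for a, b in zip(key, key_chunks)) == n_key_chunks)
        found += selected
        # Unselected rows contribute zeros, so the sum keeps only the match
        value = [acc + selected * chunk for acc, chunk in zip(value, values)]
    return [found] + value

state = [[1, 0, 3, 0, 4], [1, 1, 9, 2, 8], [0, 0, 0, 0, 0]]
print(query_plain(state, [1, 9], 2))  # [1, 2, 8]
print(query_plain(state, [0, 7], 2))  # [0, 0, 0]
print(query_plain(state, [0, 0], 2))  # [1, 0, 0]
```

Like `_query_impl`, this model compares keys only and does not consult the flag, so in the model an all-zero key also matches unused rows, as the last call shows. The usage examples in this tutorial use keys starting from 1.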

### Key-Value Database Interface

The `KeyValueDatabase` class is the interface that exposes the functionality of the key-value database.

In [13]:
class KeyValueDatabase:
    """
    A key-value database that uses fully homomorphic encryption circuits to store the data.
    """

    # The state of the database, it holds all the 
    # keys and values as a table of entries
    _state: np.ndarray

    # The circuits used to implement the database
    _insert_circuit: fhe.Circuit
    _replace_circuit: fhe.Circuit
    _query_circuit: fhe.Circuit

    # Below is the initialization of the database.

    # First, we initialize the state, and provide the necessary input sets. 
    # In versions later than concrete-numpy.0.9.0, we can use the `direct circuit` 
    # functionality to define the bit-widths of encrypted values rather than using 
    # `input sets`. Input sets are used to determine the required bit-width of the 
    # encrypted values. Hence, we add the largest possible value in the database 
    # to the input sets.

    # Within the initialization phase, we create the required configuration, 
    # compilers, circuits, and keys. Circuit and key generation phase is 
    # timed and printed in the output.

    def __init__(self):
        # Initialize the state to all zeros
        self._state = np.zeros(STATE_SHAPE, dtype=np.int64)

        ## Input sets for initialization of the circuits
        # The input sets are used to initialize the circuits with the correct parameters

        # The input set for the query circuit
        inputset_binary = [
            (
                np.zeros(STATE_SHAPE, dtype=np.int64), # state
                np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1), # key
            )
        ]
        # The input set for the insert and replace circuits
        inputset_ternary = [
            (
                np.zeros(STATE_SHAPE, dtype=np.int64), # state
                np.ones(NUMBER_OF_KEY_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1), # key
                np.ones(NUMBER_OF_VALUE_CHUNKS, dtype=np.int64) * (2**CHUNK_SIZE - 1), # value
            )
        ]

        ## Circuit compilation

        # Create a configuration for the compiler
        configuration = fhe.Configuration(
            enable_unsafe_features=True,
            use_insecure_key_cache=True,
            insecure_key_cache_location=".keys",
        )

        # Create the compilers for the circuits
        # Each compiler is provided with
        # - The implementation of the circuit
        # - The inputs and their corresponding types of the circuit
        #  - "encrypted": The input is encrypted
        #  - "plain": The input is not encrypted
        insert_compiler = fhe.Compiler(
            _insert_impl,
            {"state": "encrypted", "key": "encrypted", "value": "encrypted"}
        )
        replace_compiler = fhe.Compiler(
            _replace_impl,
            {"state": "encrypted", "key": "encrypted", "value": "encrypted"}
        )
        query_compiler = fhe.Compiler(
            _query_impl,
            {"state": "encrypted", "key": "encrypted"}
        )


        ## Compile the circuits
        # The circuits are compiled with the input set and the configuration

        print()

        print("Compiling insertion circuit...")
        start = time.time()
        self._insert_circuit = insert_compiler.compile(inputset_ternary, configuration)
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

        print()

        print("Compiling replacement circuit...")
        start = time.time()
        self._replace_circuit = replace_compiler.compile(inputset_ternary, configuration)
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

        print()

        print("Compiling query circuit...")
        start = time.time()
        self._query_circuit = query_compiler.compile(inputset_binary, configuration)
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

        print()

        ## Generate the keys for the circuits
        # The keys are separately generated for each circuit

        print("Generating insertion keys...")
        start = time.time()
        self._insert_circuit.keygen()
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

        print()

        print("Generating replacement keys...")
        start = time.time()
        self._replace_circuit.keygen()
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

        print()

        print("Generating query keys...")
        start = time.time()
        self._query_circuit.keygen()
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

    ### The Interface Functions
        
    # The following methods are used to interact with the database. 
    # They are used to insert, replace and query the database. 
    # The methods are implemented by encrypting the inputs, 
    # running the circuit and decrypting the output.

    # Insert a key-value pair into the database
    # - key: The key to insert
    # - value: The value to insert
    # The key and value are encoded before they are inserted
    # The state of the database is updated with the new key-value pair
    def insert(self, key, value):
        print()
        print("Inserting...")
        start = time.time()
        self._state = self._insert_circuit.encrypt_run_decrypt(
            self._state, encode_key(key), encode_value(value)
        )
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

    # Replace a key-value pair in the database
    # - key: The key to replace
    # - value: The new value to insert with the key
    # The key and value are encoded before they are inserted
    # The state of the database is updated with the new key-value pair
    def replace(self, key, value):
        print()
        print("Replacing...")
        start = time.time()
        self._state = self._replace_circuit.encrypt_run_decrypt(
            self._state, encode_key(key), encode_value(value)
        )
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

    # Query the database for a key
    # - key: The key to query
    # The key is encoded before it is queried
    # Returns the value associated with the key or None if the key is not found
    def query(self, key):
        print()
        print("Querying...")
        start = time.time()
        result = self._query_circuit.encrypt_run_decrypt(
            self._state, encode_key(key)
        )
        end = time.time()
        print(f"(took {end - start:.3f} seconds)")

        if result[0] == 0:
            return None

        return decode(result[1:])


The implementation provided above is the statically-sized implementation. We will briefly discuss the dynamic implementation below.

Whereas the static implementation works with circuits over the whole database, the dynamic implementation works with circuits over a single row of the database.

In the dynamic implementation, we iterate over the rows of the database in a simple Python loop, and run the circuits over each row. This difference in implementation is reflected in the `insert`, `replace` and `query` functions.

In terms of comparison, the static implementation is more efficient with dense databases as it works with parallelized tensors, but its circuits always process every row, so a database of a given capacity takes the same amount of time to query whether it is empty or completely full. The dynamic implementation is more efficient with sparse databases as its cost grows with the number of entries actually stored, but it doesn't use circuit-level parallelization. Also, insertion is free in the dynamic implementation as it only appends a new item to a Python list.
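
To illustrate the cost difference, here is a minimal plaintext sketch of a dynamically sized store. It is not the code from `dynamic-size.py`: the class name `DynamicStorePlain` and the `row_steps` counter are hypothetical, and the per-row comparison stands in for what would be a per-row circuit run in the encrypted setting.

```python
class DynamicStorePlain:
    """Hypothetical plaintext stand-in for a dynamically sized store."""

    def __init__(self):
        self._rows = []     # grows as entries are inserted
        self.row_steps = 0  # counts per-row steps (circuit runs when encrypted)

    def insert(self, key, value):
        # No per-row work is needed: the new entry is simply appended
        self._rows.append([key, value])

    def replace(self, key, value):
        for row in self._rows:
            self.row_steps += 1
            if row[0] == key:
                row[1] = value

    def query(self, key):
        result = None
        # Every stored row is visited, so the cost grows with the entry count
        for stored_key, stored_value in self._rows:
            self.row_steps += 1
            if stored_key == key:
                result = stored_value
        return result

store = DynamicStorePlain()
store.insert(3, 4)
store.insert(25, 40)
print(store.query(25))  # 40
print(store.row_steps)  # 2
```

With two stored entries a query takes two per-row steps, whereas the static circuits always cover all `NUMBER_OF_ENTRIES` rows regardless of how many are in use.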

We have now finished the definition of the database. We can now use the database to insert, replace and query values.

## Usage

Below is the initialization of the database. As we provide parameters globally, we can simply initialize the database with the following command.

In [14]:
## Test: Initialization
# Initialize the database
db = KeyValueDatabase()


Compiling insertion circuit...
(took 1.178 seconds)

Compiling replacement circuit...
(took 0.626 seconds)

Compiling query circuit...
(took 0.603 seconds)

Generating insertion keys...
(took 0.188 seconds)

Generating replacement keys...
(took 0.280 seconds)

Generating query keys...
(took 0.227 seconds)


We can use the interface functions as provided below.

In [15]:
# Test: Insert/Query
# Insert (key: 3, value: 4) into the database
db.insert(3, 4)


Inserting...
(took 0.768 seconds)


In [16]:
# Query the database for the key 3
# The value 4 should be returned
assert db.query(3) == 4


Querying...
(took 0.460 seconds)


In [17]:
# Test: Replace/Query
# Replace the value of the key 3 with 1
db.replace(3, 1)


Replacing...
(took 0.806 seconds)


In [18]:
# Query the database for the key 3
# The value 1 should be returned
assert db.query(3) == 1


Querying...
(took 0.483 seconds)


In [19]:
# Test: Insert/Query
# Insert (key: 25, value: 40) into the database
db.insert(25, 40)


Inserting...
(took 0.618 seconds)


In [20]:
# Query the database for the key 25
# The value 40 should be returned
assert db.query(25) == 40


Querying...
(took 0.957 seconds)


In [21]:
# Test: Query Not Found
# Query the database for the key 4
# None should be returned
assert db.query(4) is None


Querying...
(took 1.133 seconds)


In [22]:
# Test: Replace/Query
# Replace the value of the key 3 with 5
db.replace(3, 5)


Replacing...
(took 2.325 seconds)


In [23]:
# Query the database for the key 3
# The value 5 should be returned
assert db.query(3) == 5


Querying...
(took 1.172 seconds)


We can now test the limits. We'll use the hyper-parameters `KEY_SIZE` and `VALUE_SIZE` to ensure that the examples work robustly against changes to the parameters.

In [24]:
# Define lower/upper bounds for the key
minimum_key = 1
maximum_key = 2 ** KEY_SIZE - 1
# Define lower/upper bounds for the value
minimum_value = 1
maximum_value = 2 ** VALUE_SIZE - 1
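
To see what these bounds look like after chunking, the plain-Python sketch below splits them into 4-bit chunks. `to_chunks` is a hypothetical helper written for this illustration, and the constants repeat the tutorial's configuration.

```python
CHUNK_SIZE = 4
KEY_SIZE = 32

def to_chunks(number, width):
    """Split a `width`-bit number into CHUNK_SIZE-bit chunks, high chunk first."""
    bits = format(number, f"0{width}b")
    return [int(bits[i:i + CHUNK_SIZE], 2) for i in range(0, width, CHUNK_SIZE)]

print(to_chunks(1, KEY_SIZE))                  # [0, 0, 0, 0, 0, 0, 0, 1]
print(to_chunks(2 ** KEY_SIZE - 1, KEY_SIZE))  # [15, 15, 15, 15, 15, 15, 15, 15]
```

Every chunk of the maximum key equals `2**CHUNK_SIZE - 1`, which is the same per-chunk value used for the key and value entries of the input sets in `__init__`.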

Below are the usage examples with these bounds.

In [25]:
# Test: Insert/Replace/Query Bounds
# Insert (key: minimum_key , value: minimum_value) into the database
db.insert(minimum_key, minimum_value)

# Query the database for the key=minimum_key
# The value minimum_value should be returned
assert db.query(minimum_key) == minimum_value

# Insert (key: maximum_key , value: maximum_value) into the database
db.insert(maximum_key, maximum_value)

# Query the database for the key=maximum_key
# The value maximum_value should be returned
assert db.query(maximum_key) == maximum_value

# Replace the value of key=minimum_key with maximum_value
db.replace(minimum_key, maximum_value)

# Query the database for the key=minimum_key
# The value maximum_value should be returned
assert db.query(minimum_key) == maximum_value

# Replace the value of key=maximum_key with minimum_value
db.replace(maximum_key, minimum_value)

# Query the database for the key=maximum_key
# The value minimum_value should be returned
assert db.query(maximum_key) == minimum_value


Inserting...
(took 1.358 seconds)

Querying...
(took 1.137 seconds)

Inserting...
(took 1.383 seconds)

Querying...
(took 1.207 seconds)

Replacing...
(took 2.404 seconds)

Querying...
(took 1.241 seconds)

Replacing...
(took 2.345 seconds)

Querying...
(took 1.213 seconds)
