Files
sdk/python/blyss/bucket.py
2023-08-28 16:20:22 -07:00

359 lines
13 KiB
Python

"""Bucket
Abstracts functionality on an existing bucket.
"""
from typing import Optional, Any, Union, Iterator
from . import api, serializer, seed
from .blyss_lib import BlyssLib
import json
import base64
import bz2
import time
import asyncio
def _chunk_parser(raw_data: bytes) -> Iterator[bytes]:
    """
    Lazily split a bytestream into its length-prefixed chunks.

    Layout, as read by this function: an 8-byte little-endian unsigned chunk
    count, followed by that many records, each an 8-byte little-endian
    unsigned length and then that many bytes of payload.

    Args:
        raw_data: The serialized stream.

    Yields:
        Each chunk's payload as an independent `bytes` object.
    """
    word = 8
    view = memoryview(raw_data)  # slice without copying the whole stream

    def read_u64(offset: int) -> int:
        # Slicing past the end yields an empty view, which decodes to 0.
        return int.from_bytes(view[offset : offset + word], "little", signed=False)

    remaining = read_u64(0)
    cursor = word
    while remaining > 0:
        size = read_u64(cursor)
        start = cursor + word
        cursor = start + size
        remaining -= 1
        yield bytes(view[start:cursor])
class Bucket:
    """Interface to a single Blyss bucket."""

    def __init__(self, api: api.API, name: str, secret_seed: Optional[str] = None):
        """
        @private
        Initialize a client for a single, existing Blyss bucket.

        Args:
            api: A target API to send all underlying API calls to.
            name: The name of the bucket.
            secret_seed: An optional secret seed to initialize the client with.
                         A random one will be generated if not supplied.
        """
        self.name: str = name
        """Name of the bucket. See [bucket naming rules](https://docs.blyss.dev/docs/buckets#names)."""

        # Internal attributes
        self._api = api
        # Fetched once at construction; the PIR scheme parameters in it are
        # used to configure the client library below.
        self._metadata = self._api.meta(self.name)
        if secret_seed:
            self._secret_seed = secret_seed
        else:
            self._secret_seed = seed.get_random_seed()
        self._lib = BlyssLib(
            json.dumps(self._metadata["pir_scheme"]), self._secret_seed
        )
        # UUID returned by the service from `setup`; None until setup has run.
        self._public_uuid: Optional[str] = None
        self._exfil: Any = None  # used for benchmarking

    def _check(self, uuid: str) -> bool:
        """Checks if the server has the given UUID.

        Args:
            uuid (str): The key to check.

        Returns:
            bool: Whether the server has the given UUID.

        Raises:
            api.ApiException: For any API failure other than a 404.
        """
        try:
            self._api.check(uuid)
        except api.ApiException as e:
            if e.code == 404:
                return False
            raise  # re-raise with the original traceback
        return True

    async def _async_check(self, uuid: str) -> bool:
        """Asynchronous counterpart of `_check`.

        Args:
            uuid (str): The key to check.

        Returns:
            bool: Whether the server has the given UUID.

        Raises:
            api.ApiException: For any API failure other than a 404.
        """
        try:
            await self._api.async_check(uuid)
        except api.ApiException as e:
            if e.code == 404:
                return False
            raise  # re-raise with the original traceback
        return True

    def _split_into_chunks(
        self, kv_pairs: dict[str, bytes]
    ) -> list[list[dict[str, str]]]:
        """Group key-value pairs into JSON-ready chunks bounded by payload size.

        All keys that map to the same row (per `self._lib.get_row`) are kept
        in the same chunk, so no row is split across chunks. A single row
        larger than the limit is emitted as its own (oversized) chunk.

        Args:
            kv_pairs: Mapping of keys to raw values.

        Returns:
            A list of chunks; each chunk is a list of dicts with "key",
            "value" (base64 text) and "content-type" entries. Never contains
            an empty chunk.
        """
        _MAX_PAYLOAD = 5 * 2**20  # 5 MiB

        # 1. Bin keys by row index
        keys_by_index: dict[int, list[str]] = {}
        for k in kv_pairs:
            keys_by_index.setdefault(self._lib.get_row(k), []).append(k)

        # 2. Prepare chunks of items, where each is a JSON-ready structure.
        #    Each chunk is less than the maximum payload size, and guarantees
        #    zero overlap of rows across chunks.
        kv_chunks: list[list[dict[str, str]]] = []
        current_chunk: list[dict[str, str]] = []
        current_chunk_size = 0
        for i in sorted(keys_by_index):
            # prepare all keys in this row
            row = []
            row_size = 0
            for key in keys_by_index[i]:
                value_str = base64.b64encode(kv_pairs[key]).decode("utf-8")
                row.append(
                    {
                        "key": key,
                        "value": value_str,
                        "content-type": "application/octet-stream",
                    }
                )
                # Size estimate: fixed per-item overhead plus key and value text.
                # NOTE(review): len(key) counts characters, not serialized
                # bytes, so non-ASCII keys may be underestimated.
                row_size += int(24 + len(key) + len(value_str) + 48)

            # if the new row doesn't fit into the current chunk, start a new one
            if current_chunk_size + row_size > _MAX_PAYLOAD:
                # Only flush a chunk that holds something: when the very first
                # row is already oversized there is nothing to flush, and an
                # empty chunk would turn into an empty write request.
                if current_chunk:
                    kv_chunks.append(current_chunk)
                current_chunk = row
                current_chunk_size = row_size
            else:
                current_chunk.extend(row)
                current_chunk_size += row_size

        # add the last chunk
        if current_chunk:
            kv_chunks.append(current_chunk)
        return kv_chunks

    def _generate_query_stream(self, keys: list[str]) -> bytes:
        """Build the serialized multi-query payload for the given keys.

        Layout: query count (uint64 little-endian), then for each query its
        length (uint64 little-endian) followed by the query bytes.

        Args:
            keys: Keys to generate one encrypted query each for.

        Returns:
            The concatenated, length-prefixed query stream.
        """
        assert self._public_uuid

        # generate encrypted queries
        queries: list[bytes] = [
            self._lib.generate_query(self._public_uuid, self._lib.get_row(k))
            for k in keys
        ]

        # prepend the total number of queries (uint64_t), then interleave
        # each query with its length (uint64_t)
        parts: list[bytes] = [len(queries).to_bytes(8, "little")]
        for q in queries:
            parts.append(len(q).to_bytes(8, "little"))
            parts.append(q)

        # serialize the queries
        return b"".join(parts)

    def _unpack_query_result(
        self, keys: list[str], raw_result: bytes, ignore_errors=False
    ) -> list[Optional[bytes]]:
        """Decode a multi-query response into one value per requested key.

        Args:
            keys: The keys that were queried, in query order.
            raw_result: The length-prefixed response stream from the service.
            ignore_errors: If True, a failed query yields None instead of
                raising.

        Returns:
            A list with exactly one entry per key, in the same order.

        Raises:
            RuntimeError: If a query's result is empty or missing and
                `ignore_errors` is False.
        """
        chunks = _chunk_parser(raw_result)
        retrievals = []
        for key in keys:
            # A response with fewer chunks than keys is treated like an empty
            # result, so the output always lines up with `keys` rather than
            # being silently truncated.
            result = next(chunks, b"")
            if len(result) == 0:
                # error in processing this query
                if not ignore_errors:
                    raise RuntimeError(f"Failed to process query for key {key}.")
                extracted_result = None
            else:
                decrypted_result = self._lib.decode_response(result)
                decompressed_result = bz2.decompress(decrypted_result)
                extracted_result = self._lib.extract_result(key, decompressed_result)
            retrievals.append(extracted_result)
        return retrievals

    def _private_read(self, keys: list[str]) -> list[Optional[bytes]]:
        """Performs the underlying private retrieval.

        Args:
            keys (list[str]): A list of keys to retrieve.

        Returns:
            a list of values (bytes) corresponding to keys. None for keys not found.
        """
        # (Re)run setup if we have never done so or the server no longer
        # recognizes our UUID.
        if not self._public_uuid or not self._check(self._public_uuid):
            self.setup()
            assert self._public_uuid

        multi_query = self._generate_query_stream(keys)

        start = time.perf_counter()
        multi_result = self._api.private_read(self.name, multi_query)
        self._exfil = time.perf_counter() - start  # round-trip time of the read call

        return self._unpack_query_result(keys, multi_result)

    def setup(self):
        """Prepares this bucket client for private reads.

        This method will be called automatically by `private_read`, but
        clients may call it explicitly prior to make subsequent
        `private_read` calls faster.

        Can upload significant amounts of data (1-10 MB).
        """
        public_params = self._lib.generate_keys_with_public_params()
        setup_resp = self._api.setup(self.name, bytes(public_params))
        self._public_uuid = setup_resp["uuid"]

    def info(self) -> dict[Any, Any]:
        """Fetch this bucket's properties from the service, such as access permissions and PIR scheme parameters."""
        return self._api.meta(self.name)

    def list_keys(self) -> list[str]:
        """List all key strings in this bucket. Only available if bucket was created with keyStoragePolicy="full"."""
        return self._api.list_keys(self.name)

    def rename(self, new_name: str):
        """Rename this bucket to new_name."""
        bucket_create_req = {
            "name": new_name,
        }
        self._api.modify(self.name, json.dumps(bucket_create_req))
        # Only update the local name once the service call has returned.
        self.name = new_name

    def write(self, kv_pairs: dict[str, bytes]):
        """Writes the supplied key-value pair(s) into the bucket.

        Args:
            kv_pairs: A dictionary of key-value pairs to write into the bucket.
                      Keys must be UTF8 strings, and values may be arbitrary bytes.
        """
        # Join once instead of repeatedly re-allocating a growing bytes object.
        concatenated_kv_items = b"".join(
            serializer.wrap_key_val(key.encode("utf-8"), value)
            for key, value in kv_pairs.items()
        )
        # single call to API endpoint
        self._api.write(self.name, concatenated_kv_items)

    def delete_key(self, key: str):
        """Deletes a single key-value pair from the bucket.

        Args:
            key: The key to delete.
        """
        self._api.delete_key(self.name, key)

    def destroy_entire_bucket(self):
        """Destroys the entire bucket. This action is permanent and irreversible."""
        self._api.destroy(self.name)

    def clear_entire_bucket(self):
        """Deletes all keys in this bucket. This action is permanent and irreversible.

        Differs from destroy in that the bucket's metadata
        (e.g. permissions, PIR scheme parameters, and clients' setup data) are preserved.
        """
        self._api.clear(self.name)

    def private_read(
        self, keys: Union[str, list[str]]
    ) -> Union[Optional[bytes], list[Optional[bytes]]]:
        """Privately reads the supplied key(s) from the bucket,
        and returns the corresponding value(s).

        Data will be accessed using fully homomorphic encryption, designed to
        make it impossible for any entity (including the Blyss service!) to
        determine which key(s) are being read.

        Args:
            keys: A key or list of keys to privately retrieve.
                  If a list of keys is supplied,
                  results will be returned in the same order.

        Returns:
            For each key, the value found for the key in the bucket,
            or None if the key was not found. A single value is returned
            when a single key string is supplied; otherwise a list.
        """
        if isinstance(keys, str):
            return self._private_read([keys])[0]
        return self._private_read(keys)

    def private_key_intersect(self, keys: list[str]) -> list[str]:
        """Privately intersects the given set of keys with the keys in this bucket,
        returning the keys that intersected. This is generally slower than a single
        private read, but much faster than making a private read for each key.

        Has the same privacy guarantees as private_read - zero information is leaked
        about keys being intersected.

        Requires that the bucket was created with key_storage_policy of "bloom" or "full".
        If the bucket cannot support private bloom filter lookups, an exception will be raised.

        Args:
            keys: A list of keys to privately intersect with this bucket.

        Returns:
            The subset of `keys` for which the bucket's bloom filter lookup
            succeeds, in their original order.
        """
        # The filter is fetched from the service and evaluated locally.
        bloom_filter = self._api.bloom(self.name)
        return [key for key in keys if bloom_filter.lookup(key)]
class AsyncBucket(Bucket):
    """Asyncio-compatible version of Bucket."""

    def __init__(self, *args, **kwargs):
        """Initialize exactly like `Bucket`; all arguments are forwarded."""
        super().__init__(*args, **kwargs)

    async def write(self, kv_pairs: dict[str, bytes], CONCURRENCY=4):
        """
        Functionally equivalent to Bucket.write.

        Handles chunking and parallel submission of writes, up to CONCURRENCY.
        For maximum performance, call this function with as much data as possible.
        Data races are possible with parallel writes, but will never corrupt data.

        Args:
            kv_pairs: A dictionary of key-value pairs to write into the bucket.
            CONCURRENCY: The number of concurrent server writes. Maximum is 8;
                         values below 1 are treated as 1.
        """
        # Clamp to [1, 8]: a semaphore of 0 would block every writer forever,
        # and a negative value makes asyncio.Semaphore raise ValueError.
        CONCURRENCY = max(1, min(CONCURRENCY, 8))

        # Split the key-value pairs into chunks not exceeding max payload size.
        kv_chunks = self._split_into_chunks(kv_pairs)

        # Make one write call per chunk, while respecting a max concurrency limit.
        sem = asyncio.Semaphore(CONCURRENCY)

        async def _paced_writer(chunk):
            async with sem:
                await self._api.async_write(self.name, json.dumps(chunk))

        _tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
        await asyncio.gather(*_tasks)

    async def private_read(
        self, keys: Union[str, list[str]]
    ) -> Union[Optional[bytes], list[Optional[bytes]]]:
        """Asynchronously and privately read the supplied key(s) from the bucket.

        Args:
            keys: A key or list of keys to privately retrieve. A bare string is
                  treated as one key (not iterated character by character).

        Returns:
            A list of values in the same order as `keys` when a list is
            supplied, or a single value when a single key string is supplied.
        """
        single_query = isinstance(keys, str)
        if single_query:
            keys = [keys]

        if not self._public_uuid or not await self._async_check(self._public_uuid):
            # NOTE(review): setup() is the synchronous implementation and is
            # called without awaiting, so it runs inline in this coroutine.
            self.setup()
            assert self._public_uuid

        multi_query = self._generate_query_stream(keys)

        start = time.perf_counter()
        multi_result = await self._api.async_private_read(self.name, multi_query)
        self._exfil = time.perf_counter() - start  # round-trip time of the read call

        retrievals = self._unpack_query_result(keys, multi_result)
        return retrievals[0] if single_query else retrievals