"""Bucket
|
|
|
|
Abstracts functionality on an existing bucket.
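
Example (a minimal sketch; `blyss.Client` and `Client.connect` are the SDK's
top-level entry points, but exact constructor arguments may differ by version):

    import blyss
    client = blyss.Client("<your API key>")
    bucket = client.connect("my-bucket")
    bucket.write({"some-key": b"some-value"})
    value = bucket.private_read("some-key")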
"""
from typing import Optional, Any, Union, Iterator

from . import api, serializer, seed
from .blyss_lib import BlyssLib

import json
import base64
import bz2
import time
import asyncio


def _chunk_parser(raw_data: bytes) -> Iterator[bytes]:
    """
    Parse a bytestream containing an arbitrary number of length-prefixed chunks.
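
    Wire format (all integers are little-endian uint64):
        [num_chunks][len(chunk_0)][chunk_0][len(chunk_1)][chunk_1]...

    Example (illustrative only):
        stream = (1).to_bytes(8, "little") + (3).to_bytes(8, "little") + b"abc"
        assert list(_chunk_parser(stream)) == [b"abc"]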
    """
    data = memoryview(raw_data)
    i = 0
    num_chunks = int.from_bytes(data[:8], "little", signed=False)
    i += 8
    for _ in range(num_chunks):
        chunk_len = int.from_bytes(data[i : i + 8], "little", signed=False)
        i += 8
        chunk_data = bytes(data[i : i + chunk_len])
        i += chunk_len
        yield chunk_data


class Bucket:
    """Interface to a single Blyss bucket."""

    def __init__(self, api: api.API, name: str, secret_seed: Optional[str] = None):
        """
        @private

        Initialize a client for a single, existing Blyss bucket.

        Args:
            api: A target API to send all underlying API calls to.
            name: The name of the bucket.
            secret_seed: An optional secret seed to initialize the client with.
                A random one will be generated if not supplied.
        """
        self.name: str = name
        """Name of the bucket. See [bucket naming rules](https://docs.blyss.dev/docs/buckets#names)."""

        # Internal attributes
        self._api = api
        self._metadata = self._api.meta(self.name)
        if secret_seed:
            self._secret_seed = secret_seed
        else:
            self._secret_seed = seed.get_random_seed()
        self._lib = BlyssLib(
            json.dumps(self._metadata["pir_scheme"]), self._secret_seed
        )
        self._public_uuid: Optional[str] = None
        self._exfil: Any = None  # used for benchmarking

    def _check(self, uuid: str) -> bool:
        """Checks whether the server has setup data for the given UUID.

        Args:
            uuid (str): The UUID of the client setup data to check for.

        Returns:
            bool: Whether the server has the given UUID.
        """
        try:
            self._api.check(uuid)
            return True
        except api.ApiException as e:
            if e.code == 404:
                return False
            raise

    async def _async_check(self, uuid: str) -> bool:
        """Async version of `_check`."""
        try:
            await self._api.async_check(uuid)
            return True
        except api.ApiException as e:
            if e.code == 404:
                return False
            raise

    def _split_into_chunks(
        self, kv_pairs: dict[str, bytes]
    ) -> list[list[dict[str, str]]]:
        """Split key-value pairs into JSON-ready chunks of bounded payload size.

        Keys that hash to the same row always land in the same chunk, so no
        row ever straddles a chunk boundary.
        """
        _MAX_PAYLOAD = 5 * 2**20  # 5 MiB

        # 1. Bin keys by row index.
        keys_by_index: dict[int, list[str]] = {}
        for k in kv_pairs.keys():
            i = self._lib.get_row(k)
            if i in keys_by_index:
                keys_by_index[i].append(k)
            else:
                keys_by_index[i] = [k]

        # 2. Prepare chunks of items, where each is a JSON-ready structure.
        # Each chunk is less than the maximum payload size, and guarantees
        # zero overlap of rows across chunks.
        kv_chunks: list[list[dict[str, str]]] = []
        current_chunk: list[dict[str, str]] = []
        current_chunk_size = 0
        sorted_indices = sorted(keys_by_index.keys())
        for i in sorted_indices:
            keys = keys_by_index[i]
            # prepare all keys in this row
            row = []
            row_size = 0
            for key in keys:
                value = kv_pairs[key]
                value_str = base64.b64encode(value).decode("utf-8")
                fmt = {
                    "key": key,
                    "value": value_str,
                    "content-type": "application/octet-stream",
                }
                row.append(fmt)
                # rough per-item size estimate: key, base64 value, plus
                # fixed allowances for JSON field names and overhead
                row_size += int(24 + len(key) + len(value_str) + 48)

            # if the new row doesn't fit into the current chunk, start a new one
            if current_chunk_size + row_size > _MAX_PAYLOAD:
                kv_chunks.append(current_chunk)
                current_chunk = row
                current_chunk_size = row_size
            else:
                current_chunk.extend(row)
                current_chunk_size += row_size

        # add the last chunk
        if len(current_chunk) > 0:
            kv_chunks.append(current_chunk)

        return kv_chunks

    def _generate_query_stream(self, keys: list[str]) -> bytes:
        """Encrypt one query per key and serialize them into a single stream,
        using the same length-prefixed framing that `_chunk_parser` consumes."""
        assert self._public_uuid
        # generate encrypted queries
        queries: list[bytes] = [
            self._lib.generate_query(self._public_uuid, self._lib.get_row(k))
            for k in keys
        ]
        # interleave the queries with their lengths (little-endian uint64)
        query_lengths = [len(q).to_bytes(8, "little") for q in queries]
        lengths_and_queries = [x for lq in zip(query_lengths, queries) for x in lq]
        # prepend the total number of queries (little-endian uint64)
        lengths_and_queries.insert(0, len(queries).to_bytes(8, "little"))
        # serialize the queries
        multi_query = b"".join(lengths_and_queries)
        return multi_query

    def _unpack_query_result(
        self, keys: list[str], raw_result: bytes, ignore_errors: bool = False
    ) -> list[Optional[bytes]]:
        retrievals = []
        for key, result in zip(keys, _chunk_parser(raw_result)):
            if len(result) == 0:
                # zero-length chunk: the server failed to process this query
                if ignore_errors:
                    extracted_result = None
                else:
                    raise RuntimeError(f"Failed to process query for key {key}.")
            else:
                decrypted_result = self._lib.decode_response(result)
                decompressed_result = bz2.decompress(decrypted_result)
                extracted_result = self._lib.extract_result(key, decompressed_result)
            retrievals.append(extracted_result)
        return retrievals

    def _private_read(self, keys: list[str]) -> list[Optional[bytes]]:
        """Performs the underlying private retrieval.

        Args:
            keys (list[str]): A list of keys to retrieve.

        Returns:
            A list of values (bytes) corresponding to keys; None for keys not found.
        """
        if not self._public_uuid or not self._check(self._public_uuid):
            self.setup()
        assert self._public_uuid

        multi_query = self._generate_query_stream(keys)

        start = time.perf_counter()
        multi_result = self._api.private_read(self.name, multi_query)
        self._exfil = time.perf_counter() - start

        retrievals = self._unpack_query_result(keys, multi_result)

        return retrievals

    def setup(self):
        """Prepares this bucket client for private reads.

        This method is called automatically by `private_read`, but clients
        may call it explicitly beforehand to make subsequent `private_read`
        calls faster.

        Can upload a significant amount of data (1-10 MB).
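
        Example (illustrative):
            bucket.setup()                      # one-time upload of public params
            value = bucket.private_read("key")  # later reads reuse the setup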
        """
        public_params = self._lib.generate_keys_with_public_params()
        setup_resp = self._api.setup(self.name, bytes(public_params))
        self._public_uuid = setup_resp["uuid"]

    def info(self) -> dict[Any, Any]:
        """Fetch this bucket's properties from the service, such as access permissions and PIR scheme parameters."""
        return self._api.meta(self.name)

    def list_keys(self) -> list[str]:
        """List all key strings in this bucket. Only available if the bucket was created with keyStoragePolicy="full"."""
        return self._api.list_keys(self.name)

    def rename(self, new_name: str):
        """Rename this bucket to new_name."""
        bucket_modify_req = {
            "name": new_name,
        }
        self._api.modify(self.name, json.dumps(bucket_modify_req))
        self.name = new_name

    def write(self, kv_pairs: dict[str, bytes]):
        """Writes the supplied key-value pair(s) into the bucket.

        Args:
            kv_pairs: A dictionary of key-value pairs to write into the bucket.
                Keys must be UTF-8 strings, and values may be arbitrary bytes.
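
        Example (illustrative):
            bucket.write({"alice": b"some bytes", "bob": "hello".encode("utf-8")})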
        """
        concatenated_kv_items = b""
        for key, value in kv_pairs.items():
            concatenated_kv_items += serializer.wrap_key_val(key.encode("utf-8"), value)
        # single call to API endpoint
        self._api.write(self.name, concatenated_kv_items)

    def delete_key(self, key: str):
        """Deletes a single key-value pair from the bucket.

        Args:
            key: The key to delete.
        """
        self._api.delete_key(self.name, key)

    def destroy_entire_bucket(self):
        """Destroys the entire bucket. This action is permanent and irreversible."""
        self._api.destroy(self.name)

    def clear_entire_bucket(self):
        """Deletes all keys in this bucket. This action is permanent and irreversible.

        Differs from destroy in that the bucket's metadata
        (e.g. permissions, PIR scheme parameters, and clients' setup data) is preserved.
        """
        self._api.clear(self.name)

    def private_read(
        self, keys: Union[str, list[str]]
    ) -> Union[Optional[bytes], list[Optional[bytes]]]:
        """Privately reads the supplied key(s) from the bucket,
        and returns the corresponding value(s).

        Data is accessed using fully homomorphic encryption, designed to
        make it impossible for any entity (including the Blyss service!) to
        determine which key(s) are being read.

        Args:
            keys: A key or list of keys to privately retrieve.
                If a list of keys is supplied,
                results will be returned in the same order.

        Returns:
            For each key, the value found for the key in the bucket,
            or None if the key was not found.
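
        Example (illustrative):
            value = bucket.private_read("user-123")        # bytes or None
            values = bucket.private_read(["a", "b", "c"])  # list of bytes/None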
        """
        single_query = False
        if isinstance(keys, str):
            keys = [keys]
            single_query = True

        results = self._private_read(keys)
        if single_query:
            return results[0]

        return results

    def private_key_intersect(self, keys: list[str]) -> list[str]:
        """Privately intersects the given set of keys with the keys in this bucket,
        returning the keys that intersected. This is generally slower than a single
        private read, but much faster than making a private read for each key.

        Has the same privacy guarantees as private_read: zero information is leaked
        about the keys being intersected. Note that the lookup is backed by a Bloom
        filter, so rare false positives are possible.

        Requires that the bucket was created with a key_storage_policy of "bloom" or "full".
        If the bucket cannot support private Bloom filter lookups, an exception will be raised.

        Args:
            keys: A list of keys to privately intersect with this bucket.
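
        Example (illustrative):
            present = bucket.private_key_intersect(["alice", "bob", "carol"])
            # e.g. ["alice", "carol"], if only those keys exist in the bucket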
        """
        bloom_filter = self._api.bloom(self.name)
        present_keys = list(filter(bloom_filter.lookup, keys))
        return present_keys


class AsyncBucket(Bucket):
    """Asyncio-compatible version of Bucket."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    async def write(self, kv_pairs: dict[str, bytes], CONCURRENCY=4):
        """
        Functionally equivalent to Bucket.write.

        Handles chunking and parallel submission of writes, with up to
        CONCURRENCY requests in flight at once. For maximum performance,
        call this function with as much data as possible. Data races are
        possible with parallel writes, but will never corrupt data.

        Args:
            kv_pairs: A dictionary of key-value pairs to write into the bucket.
            CONCURRENCY: The number of concurrent server writes. Maximum is 8.
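
        Example (illustrative; must be awaited inside an event loop):
            await bucket.write({f"key-{i}": b"value" for i in range(10_000)})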
        """
        CONCURRENCY = min(CONCURRENCY, 8)

        # Split the key-value pairs into chunks not exceeding the max payload size.
        kv_chunks = self._split_into_chunks(kv_pairs)
        # Make one write call per chunk, while respecting the concurrency limit.
        sem = asyncio.Semaphore(CONCURRENCY)

        async def _paced_writer(chunk):
            async with sem:
                await self._api.async_write(self.name, json.dumps(chunk))

        _tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
        await asyncio.gather(*_tasks)

    async def private_read(self, keys: list[str]) -> list[Optional[bytes]]:
        """Async version of Bucket.private_read. Accepts only a list of keys."""
        if not self._public_uuid or not await self._async_check(self._public_uuid):
            # note: setup() performs a blocking upload via the sync API
            self.setup()
        assert self._public_uuid

        multi_query = self._generate_query_stream(keys)

        start = time.perf_counter()
        multi_result = await self._api.async_private_read(self.name, multi_query)
        self._exfil = time.perf_counter() - start

        retrievals = self._unpack_query_result(keys, multi_result)

        return retrievals