# Mirror of https://github.com/blyssprivacy/sdk.git
# (synced 2026-04-26 03:00:13 -04:00; as mirrored: 367 lines, 13 KiB, Python)
"""Bucket
|
|
|
|
Abstracts functionality on an existing bucket.
|
|
"""
|
|
|
|
import asyncio
import base64
import bz2
import json
from collections import defaultdict
from typing import Optional, Any, Union, Iterator

from . import api, seed
from .blyss_lib import BlyssLib
|
|
|
|
|
|
class Bucket:
|
|
"""Interface to a single Blyss bucket."""
|
|
|
|
name: str
|
|
"""Name of the bucket. See [bucket naming rules](https://docs.blyss.dev/docs/buckets#names)."""
|
|
_public_uuid: Optional[str] = None
|
|
|
|
def __init__(self, api: api.API, name: str, secret_seed: Optional[str] = None):
|
|
"""
|
|
@private
|
|
Initialize a client for a single, existing Blyss bucket.
|
|
|
|
Args:
|
|
api: A target API to send all underlying API calls to.
|
|
name: The name of the bucket.
|
|
secret_seed: An optional secret seed to initialize the client with.
|
|
A random one will be generated if not supplied.
|
|
"""
|
|
self._basic_init(api, name, secret_seed)
|
|
self._metadata = self._api._blocking_meta(self.name)
|
|
self._lib = BlyssLib(
|
|
json.dumps(self._metadata["pir_scheme"]), self._secret_seed
|
|
)
|
|
|
|
def _basic_init(self, api: api.API, name: str, secret_seed: Optional[str]):
|
|
self.name: str = name
|
|
# Internal attributes
|
|
self._api = api
|
|
if secret_seed:
|
|
self._secret_seed = secret_seed
|
|
else:
|
|
self._secret_seed = seed.get_random_seed()
|
|
|
|
def _check(self) -> bool:
|
|
"""Checks if the server has this client's public params.
|
|
|
|
Args:
|
|
uuid (str): The key to check.
|
|
|
|
Returns:
|
|
bool: Whether the server has the given UUID.
|
|
"""
|
|
if self._public_uuid is None:
|
|
raise RuntimeError("Bucket not initialized. Call setup() first.")
|
|
return self._api._blocking_check(self._public_uuid)
|
|
|
|
def _split_into_json_chunks(
|
|
self, kv_pairs: dict[str, Optional[bytes]]
|
|
) -> list[dict[str, Optional[str]]]:
|
|
_MAX_PAYLOAD = 5 * 2**20 # 5 MiB
|
|
|
|
# 1. Bin keys by row index
|
|
keys_by_index: dict[int, list[str]] = {}
|
|
for k in kv_pairs.keys():
|
|
i = self._lib.get_row(k)
|
|
if i in keys_by_index:
|
|
keys_by_index[i].append(k)
|
|
else:
|
|
keys_by_index[i] = [k]
|
|
|
|
# 2. Prepare chunks of items, where each is a JSON-ready structure.
|
|
# Each chunk is less than the maximum payload size, and guarantees
|
|
# zero overlap of rows across chunks.
|
|
kv_chunks: list[dict[str, Optional[str]]] = []
|
|
current_chunk: dict[str, Optional[str]] = {}
|
|
current_chunk_size = 0
|
|
sorted_indices = sorted(keys_by_index.keys())
|
|
for i in sorted_indices:
|
|
keys = keys_by_index[i]
|
|
# prepare all keys in this row
|
|
row = {}
|
|
row_size = 0
|
|
for key in keys:
|
|
vi = kv_pairs[key]
|
|
if vi is None:
|
|
v = None
|
|
else:
|
|
v = base64.b64encode(vi).decode("utf-8")
|
|
row[key] = v
|
|
row_size += int(
|
|
16 + len(key) + (len(v) if v is not None else 4)
|
|
) # 4 bytes for 'null'
|
|
|
|
# if the new row doesn't fit into the current chunk, start a new one
|
|
if current_chunk_size + row_size > _MAX_PAYLOAD:
|
|
kv_chunks.append(current_chunk)
|
|
current_chunk = row
|
|
current_chunk_size = row_size
|
|
else:
|
|
current_chunk.update(row)
|
|
current_chunk_size += row_size
|
|
|
|
# add the last chunk
|
|
if len(current_chunk) > 0:
|
|
kv_chunks.append(current_chunk)
|
|
|
|
return kv_chunks
|
|
|
|
def _generate_query_stream(self, keys: list[str]) -> list[bytes]:
|
|
assert self._public_uuid
|
|
# generate encrypted queries
|
|
queries: list[bytes] = [
|
|
self._lib.generate_query(self._public_uuid, self._lib.get_row(k))
|
|
for k in keys
|
|
]
|
|
return queries
|
|
|
|
def _decode_result_row(
|
|
self, result_row: bytes, silence_errors: bool = True
|
|
) -> Optional[bytes]:
|
|
try:
|
|
decrypted_result = self._lib.decode_response(result_row)
|
|
decompressed_result = bz2.decompress(decrypted_result)
|
|
return decompressed_result
|
|
except:
|
|
if not silence_errors:
|
|
raise
|
|
return None
|
|
|
|
def setup(self):
|
|
"""Prepares this bucket client for private reads.
|
|
|
|
This method will be called automatically by :method:`read`, but
|
|
clients may call it explicitly prior to make subsequent
|
|
`private_read` calls faster.
|
|
|
|
Can upload significant amounts of data (1-10 MB).
|
|
|
|
"""
|
|
public_params = self._lib.generate_keys_with_public_params()
|
|
self._public_uuid = self._api._blocking_setup(self.name, public_params)
|
|
assert self._check()
|
|
|
|
def info(self) -> dict[Any, Any]:
|
|
"""Fetch this bucket's properties from the service, such as access permissions and PIR scheme parameters."""
|
|
return self._api._blocking_meta(self.name)
|
|
|
|
def rename(self, new_name: str):
|
|
"""Rename this bucket to new_name."""
|
|
bucket_create_req = {
|
|
"name": new_name,
|
|
}
|
|
self._api._blocking_modify(self.name, bucket_create_req)
|
|
self.name = new_name
|
|
|
|
def write(self, kv_pairs: dict[str, Optional[bytes]]):
|
|
"""Writes the supplied key-value pair(s) into the bucket.
|
|
|
|
Args:
|
|
kv_pairs: A dictionary of key-value pairs to write into the bucket.
|
|
Keys must be UTF8 strings, and values may be arbitrary bytes.
|
|
"""
|
|
kv_json = {
|
|
k: base64.b64encode(v).decode("utf-8") if v else None
|
|
for k, v in kv_pairs.items()
|
|
}
|
|
self._api._blocking_write(self.name, kv_json)
|
|
|
|
def delete_key(self, keys: str | list[str]):
|
|
"""Deletes key-value pairs from the bucket.
|
|
|
|
Args:
|
|
key: The key to delete.
|
|
"""
|
|
if isinstance(keys, str):
|
|
keys = [keys]
|
|
|
|
# Writing None to a key is interpreted as a delete.
|
|
delete_payload = {k: None for k in keys}
|
|
self._api._blocking_write(self.name, delete_payload)
|
|
|
|
def destroy_entire_bucket(self):
|
|
"""Destroys the entire bucket. This action is permanent and irreversible."""
|
|
self._api._blocking_destroy(self.name)
|
|
|
|
def clear_entire_bucket(self):
|
|
"""Deletes all keys in this bucket. This action is permanent and irreversible.
|
|
|
|
Differs from destroy in that the bucket's metadata
|
|
(e.g. permissions, PIR scheme parameters, and clients' setup data) are preserved.
|
|
"""
|
|
self._api._blocking_clear(self.name)
|
|
|
|
def private_read(self, keys: list[str]) -> list[Optional[bytes]]:
|
|
"""Privately reads the supplied keys from the bucket,
|
|
and returns the corresponding values.
|
|
|
|
Data will be accessed using fully homomorphic encryption, designed to
|
|
make it impossible for any entity (including the Blyss service!) to
|
|
determine which keys are being read.
|
|
|
|
Args:
|
|
keys: A list of keys to privately retrieve.
|
|
Results will be returned in the same order.
|
|
|
|
Returns:
|
|
For each key, the value found for the key in the bucket,
|
|
or None if the key was not found.
|
|
"""
|
|
row_indices_per_key = [self._lib.get_row(k) for k in keys]
|
|
rows_per_result = self.private_read_row(row_indices_per_key)
|
|
results = [
|
|
self._lib.extract_result(key, row) if row else None
|
|
for key, row in zip(keys, rows_per_result)
|
|
]
|
|
|
|
return results
|
|
|
|
def private_read_row(self, row_indices: list[int]) -> list[Optional[bytes]]:
|
|
"""Direct API for private reads; fetches full bucket rows rather than individual keys.
|
|
|
|
Args:
|
|
row_indices: A list of row indices to privately retrieve.
|
|
Results will be returned in the same order.
|
|
|
|
Returns:
|
|
For each row index, the value found for the row in the bucket,
|
|
or None if the row was not found.
|
|
"""
|
|
if not self._public_uuid or not self._check():
|
|
self.setup()
|
|
assert self._public_uuid
|
|
|
|
queries = [self._lib.generate_query(self._public_uuid, i) for i in row_indices]
|
|
raw_rows_per_result = self._api._blocking_private_read(self.name, queries)
|
|
rows_per_result = [
|
|
self._decode_result_row(rr) if rr else None for rr in raw_rows_per_result
|
|
]
|
|
|
|
return rows_per_result
|
|
|
|
def private_key_intersect(self, keys: list[str]) -> list[str]:
|
|
"""Privately intersects the given set of keys with the keys in this bucket,
|
|
returning the keys that intersected. This is generally slower than a single
|
|
private read, but much faster than making a private read for each key.
|
|
|
|
Has the same privacy guarantees as private_read - zero information is leaked
|
|
about keys being intersected.
|
|
|
|
Requires that the bucket was created with key_storage_policy of "bloom" or "full".
|
|
If the bucket cannot support private bloom filter lookups, an exception will be raised.
|
|
|
|
Args:
|
|
keys: A list of keys to privately intersect with this bucket.
|
|
"""
|
|
bloom_filter = self._api._blocking_bloom(self.name)
|
|
present_keys = list(filter(bloom_filter.lookup, keys))
|
|
return present_keys
|
|
|
|
|
|
class AsyncBucket(Bucket):
    """Asyncio-compatible version of Bucket."""

    def __init__(self, api: "api.API", name: str, secret_seed: Optional[str] = None):
        """Store the client state without contacting the service.

        The instance is not usable until `async_init` has been awaited,
        because the bucket metadata and the local library are created there.

        Args:
            api: A target API to send all underlying API calls to.
            name: The name of the bucket.
            secret_seed: An optional secret seed to initialize the client with.
                         A random one will be generated if not supplied.
        """
        self._basic_init(api, name, secret_seed)

    async def async_init(self):
        """Python constructors can't be async, so instances of `AsyncBucket` must call this method after construction."""
        self._metadata = await self._api.meta(self.name)
        self._lib = BlyssLib(
            json.dumps(self._metadata["pir_scheme"]), self._secret_seed
        )

    async def _check(self) -> bool:
        """Check whether the server has this client's public params.

        Returns:
            True if the check call succeeds, False if it fails with a 404.

        Raises:
            RuntimeError: If `setup` has not assigned a public UUID yet.
            api.ApiException: For any API failure other than a 404.
        """
        if self._public_uuid is None:
            raise RuntimeError("Bucket not initialized. Call setup() first.")
        try:
            await self._api.check(self._public_uuid)
        except api.ApiException as e:
            if e.code == 404:
                return False
            # Bare raise keeps the original traceback intact.
            raise
        return True

    async def setup(self):
        """Asyncio version of `Bucket.setup`: upload public params for private reads.

        Raises:
            RuntimeError: If the server does not report the uploaded public
                parameters as present after setup.
        """
        public_params = self._lib.generate_keys_with_public_params()
        self._public_uuid = await self._api.setup(self.name, public_params)
        # Explicit check instead of `assert`, which would be skipped under -O.
        if not await self._check():
            raise RuntimeError(
                "Bucket setup failed: server does not have this client's public params."
            )

    async def info(self) -> dict[str, Any]:
        """Fetch this bucket's properties from the service."""
        return await self._api.meta(self.name)

    async def rename(self, new_name: str):
        """Rename this bucket to new_name."""
        bucket_create_req = {
            "name": new_name,
        }
        await self._api.modify(self.name, bucket_create_req)
        # Only update the local name once the service call has returned.
        self.name = new_name

    async def delete_key(self, keys: str | list[str]):
        """Delete one key, or a list of keys, from the bucket."""
        keys = [keys] if isinstance(keys, str) else keys
        # Writing None to a key is interpreted as a delete.
        delete_payload: dict[str, Optional[str]] = {k: None for k in keys}
        await self._api.write(self.name, delete_payload)

    async def destroy_entire_bucket(self):
        """Destroys the entire bucket. This action is permanent and irreversible."""
        await self._api.destroy(self.name)

    async def clear_entire_bucket(self):
        """Deletes all keys in this bucket. This action is permanent and irreversible."""
        await self._api.clear(self.name)

    async def write(self, kv_pairs: dict[str, Optional[bytes]], CONCURRENCY=4):
        """
        Functionally equivalent to Bucket.write.

        Handles chunking and parallel submission of writes, up to CONCURRENCY.
        For maximum performance, call this function with as much data as possible.
        Data races are possible with parallel writes, but will never corrupt data.

        Args:
            kv_pairs: A dictionary of key-value pairs to write into the bucket.
            CONCURRENCY: The number of concurrent server writes. Values are
                         clamped to the range 1..8.
        """
        # Clamp from below as well: Semaphore(0) would block every writer
        # forever, and a negative value makes Semaphore raise.
        CONCURRENCY = max(1, min(CONCURRENCY, 8))

        # Split the key-value pairs into chunks not exceeding max payload size.
        # kv_chunks are JSON-ready, i.e. values are base64-encoded strings.
        kv_chunks = self._split_into_json_chunks(kv_pairs)

        # Make one write call per chunk, while respecting a max concurrency limit.
        sem = asyncio.Semaphore(CONCURRENCY)

        async def _paced_writer(chunk: dict[str, Optional[str]]):
            async with sem:
                await self._api.write(self.name, chunk)

        _tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
        await asyncio.gather(*_tasks)

    async def private_read(self, keys: list[str]) -> list[Optional[bytes]]:
        """Privately read the supplied keys; asyncio version of `Bucket.private_read`.

        Args:
            keys: A list of keys to privately retrieve.
                  Results will be returned in the same order.

        Returns:
            For each key, the value found for the key in the bucket,
            or None if the key was not found.
        """
        row_indices_per_key = [self._lib.get_row(k) for k in keys]
        rows_per_result = await self.private_read_row(row_indices_per_key)
        return [
            self._lib.extract_result(key, row) if row else None
            for key, row in zip(keys, rows_per_result)
        ]

    async def private_read_row(self, row_indices: list[int]) -> list[Optional[bytes]]:
        """Privately fetch full bucket rows; asyncio version of `Bucket.private_read_row`.

        Runs `setup` first when this client has no public UUID yet, or when
        the server no longer reports it as present.

        Args:
            row_indices: A list of row indices to privately retrieve.
                         Results will be returned in the same order.

        Returns:
            For each row index, the decoded row, or None if the row was empty
            or could not be decoded.
        """
        if not self._public_uuid or not await self._check():
            await self.setup()
        assert self._public_uuid  # narrows Optional[str] for type checkers

        queries = [self._lib.generate_query(self._public_uuid, i) for i in row_indices]
        raw_rows_per_result = await self._api.private_read(self.name, queries)
        return [
            self._decode_result_row(rr) if rr else None for rr in raw_rows_per_result
        ]
|