diff --git a/python/blyss/api.py b/python/blyss/api.py
index f7590fc..1a6de22 100644
--- a/python/blyss/api.py
+++ b/python/blyss/api.py
@@ -9,7 +9,6 @@ from typing import Any, Optional, Union
 import requests
 import httpx
 import gzip
-import asyncio
 import json
 import logging
 import base64
@@ -116,14 +115,14 @@ def _post_data(api_key: str, url: str, data: Union[bytes, str]) -> bytes:
     return resp.content
 
 
-def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[Any, Any]:
+def _post_data_json(api_key: str, url: str, data: Union[bytes, Any]) -> Any:
     """Perform an HTTP POST request, returning a JSON-parsed dict.
 
-    Request data can be any JSON string, or a raw bytestring that will be
-    base64-encoded before send. All requests and responses are compressed JSON."""
+    Request data can be any JSON-serializable object, or a raw bytestring that
+    will be base64-encoded before send. All requests and responses are compressed JSON."""
     if len(data) > APIGW_MAX_SIZE:
         raise ValueError(
-            f"Scheme public parameters too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes"
+            f"Request data is too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes"
         )
 
     c = httpx.Client(
@@ -133,20 +132,19 @@ def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[An
         }
     )
 
 
-    data_textsafe: bytes
     if type(data) == bytes:
-        data_textsafe = base64.b64encode(data)
-    elif type(data) == str:
-        data_textsafe = data.encode("utf-8")
+        data_jsonable = base64.b64encode(data).decode("utf-8")
     else:
-        raise ValueError(f"Unsupported data type {type(data)}")
+        data_jsonable = data
+    data_json = json.dumps(data_jsonable).encode("utf-8")
 
+    # compress requests larger than 1KB
     extra_headers = {}
-    if len(data_textsafe) > 1000:
-        payload = gzip.compress(data_textsafe)
+    if len(data_json) > 1000:
+        payload = gzip.compress(data_json)
         extra_headers["Content-Encoding"] = "gzip"
     else:
-        payload = data_textsafe
+        payload = data_json
 
     resp = c.post(url, content=payload, headers=extra_headers)
@@ -163,26 +161,28 @@ def _post_form_data(url: str, fields: dict[Any, Any], data: bytes):
 async def _async_post_data(
     api_key: str,
     url: str,
-    data: Union[str, bytes],
+    data: Union[bytes, Any],
     compress: bool = True,
     decode_json: bool = True,
 ) -> Any:
     """Perform an async HTTP POST request, returning a JSON-parsed dict response"""
     headers = {"x-api-key": api_key, "Content-Type": "application/json"}
-    if type(data) == str:
-        data = data.encode("utf-8")
-    elif type(data) == bytes:
-        data = base64.b64encode(data)
+
+    if type(data) == bytes:
+        data_jsonable = base64.b64encode(data).decode("utf-8")
     else:
-        raise ValueError(f"Unsupported data type {type(data)}")
+        data_jsonable = data
+    data_json = json.dumps(data_jsonable).encode("utf-8")
 
     if compress:
         # apply gzip compression to data before sending
-        data = gzip.compress(data)
+        payload = gzip.compress(data_json)
         headers["Content-Encoding"] = "gzip"
+    else:
+        payload = data_json
 
     async with httpx.AsyncClient(timeout=httpx.Timeout(5, read=None)) as client:
-        r = await client.post(url, content=data, headers=headers)
+        r = await client.post(url, content=payload, headers=headers)
 
     _check_http_error(r)  # type: ignore
     if decode_json:
""" return _post_data_json( - self.api_key, self._service_url_for(CREATE_PATH), data_json + self.api_key, self._service_url_for(CREATE_PATH), data_jsonable ) def check(self, uuid: str) -> dict[Any, Any]: @@ -251,14 +251,14 @@ class API: def _url_for(self, bucket_name: str, path: str) -> str: return self.service_endpoint + "/" + bucket_name + path - def modify(self, bucket_name: str, data_json: str) -> dict[Any, Any]: + def modify(self, bucket_name: str, data_jsonable: Any) -> dict[Any, Any]: """Modify existing bucket. Args: data_json (str): same as create. """ return _post_data_json( - self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_json + self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_jsonable ) def meta(self, bucket_name: str) -> dict[Any, Any]: @@ -325,14 +325,29 @@ class API: """Delete all keys in this bucket.""" _post_data(self.api_key, self._url_for(bucket_name, CLEAR_PATH), "") - def write(self, bucket_name: str, data: bytes): + def write(self, bucket_name: str, data: dict[str, Optional[bytes]]): """Write some data to this bucket.""" - _post_data(self.api_key, self._url_for(bucket_name, WRITE_PATH), data) + data_jsonable = { + k: None if v is None else base64.b64encode(v).decode("utf-8") + for k, v in data.items() + } + _post_data_json( + self.api_key, self._url_for(bucket_name, WRITE_PATH), data_jsonable + ) - async def async_write(self, bucket_name: str, data: str): + async def async_write(self, bucket_name: str, data: dict[str, Optional[bytes]]): """Write JSON payload to this bucket.""" + + data_jsonable = { + k: None if v is None else base64.b64encode(v).decode("utf-8") + for k, v in data.items() + } + await _async_post_data( - self.api_key, self._url_for(bucket_name, WRITE_PATH), data, decode_json=True + self.api_key, + self._url_for(bucket_name, WRITE_PATH), + data_jsonable, + compress=True, ) def delete_key(self, bucket_name: str, key: str): @@ -341,16 +356,26 @@ class API: self.api_key, self._url_for(bucket_name, DELETE_PATH), key.encode("utf-8") ) - def private_read(self, bucket_name: str, data: bytes) -> bytes: + def private_read( + self, bucket_name: str, queries: list[bytes] + ) -> list[Optional[bytes]]: """Privately read data from this bucket.""" - val = _post_data(self.api_key, self._url_for(bucket_name, READ_PATH), data) - return base64.b64decode(val) - - async def async_private_read(self, bucket_name: str, data: bytes) -> bytes: - """Privately read data from this bucket.""" - val: bytes = await _async_post_data( - self.api_key, self._url_for(bucket_name, READ_PATH), data, decode_json=False + data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries] + r = _post_data_json( + self.api_key, self._url_for(bucket_name, READ_PATH), data_jsonable ) - # AWS APIGW encodes its responses as base64 - return base64.b64decode(val) - # return self.private_read(bucket_name, data) + return [base64.b64decode(v) if v is not None else None for v in r] + + async def async_private_read( + self, bucket_name: str, queries: list[bytes] + ) -> list[Optional[bytes]]: + """Privately read data from this bucket.""" + data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries] + r: list[str] = await _async_post_data( + self.api_key, + self._url_for(bucket_name, READ_PATH), + data_jsonable, + compress=True, + decode_json=True, + ) + return [base64.b64decode(v) if v is not None else None for v in r] diff --git a/python/blyss/bucket.py b/python/blyss/bucket.py index e87dd43..54c5a58 100644 --- a/python/blyss/bucket.py +++ b/python/blyss/bucket.py @@ 
diff --git a/python/blyss/bucket.py b/python/blyss/bucket.py
index e87dd43..54c5a58 100644
--- a/python/blyss/bucket.py
+++ b/python/blyss/bucket.py
@@ -5,11 +5,10 @@ Abstracts functionality on an existing bucket.
 
 from typing import Optional, Any, Union, Iterator
 
-from . import api, serializer, seed
+from . import api, seed
 from .blyss_lib import BlyssLib
 
 import json
-import base64
 import bz2
 import time
 import asyncio
@@ -82,9 +81,7 @@ class Bucket:
             else:
                 raise e
 
-    def _split_into_chunks(
-        self, kv_pairs: dict[str, bytes]
-    ) -> list[list[dict[str, str]]]:
+    def _split_into_chunks(self, kv_pairs: dict[str, bytes]) -> list[dict[str, bytes]]:
         _MAX_PAYLOAD = 5 * 2**20  # 5 MiB
 
         # 1. Bin keys by row index
@@ -99,25 +96,19 @@
-        # 2. Prepare chunks of items, where each is a JSON-ready structure.
+        # 2. Prepare chunks of items, where each is a dict of raw key-value pairs.
         # Each chunk is less than the maximum payload size, and guarantees
         # zero overlap of rows across chunks.
-        kv_chunks: list[list[dict[str, str]]] = []
-        current_chunk: list[dict[str, str]] = []
+        kv_chunks: list[dict[str, bytes]] = []
+        current_chunk: dict[str, bytes] = {}
         current_chunk_size = 0
         sorted_indices = sorted(keys_by_index.keys())
        for i in sorted_indices:
             keys = keys_by_index[i]
             # prepare all keys in this row
-            row = []
+            row = {}
             row_size = 0
             for key in keys:
-                value = kv_pairs[key]
-                value_str = base64.b64encode(value).decode("utf-8")
-                fmt = {
-                    "key": key,
-                    "value": value_str,
-                    "content-type": "application/octet-stream",
-                }
-                row.append(fmt)
-                row_size += int(24 + len(key) + len(value_str) + 48)
+                v = kv_pairs[key]
+                row[key] = v
+                row_size += int(16 + len(key) + len(v) * 4 / 3)
 
             # if the new row doesn't fit into the current chunk, start a new one
@@ -125,7 +116,7 @@
                 current_chunk = row
                 current_chunk_size = row_size
             else:
-                current_chunk.extend(row)
+                current_chunk.update(row)
                 current_chunk_size += row_size
 
         # add the last chunk
@@ -134,13 +125,14 @@
         return kv_chunks
 
-    def _generate_query_stream(self, keys: list[str]) -> bytes:
+    def _generate_query_stream(self, keys: list[str]) -> list[bytes]:
         assert self._public_uuid
         # generate encrypted queries
         queries: list[bytes] = [
             self._lib.generate_query(self._public_uuid, self._lib.get_row(k))
             for k in keys
         ]
+        return queries
 
         # interleave the queries with their lengths (uint64_t)
         query_lengths = [len(q).to_bytes(8, "little") for q in queries]
         lengths_and_queries = [x for lq in zip(query_lengths, queries) for x in lq]
@@ -150,6 +142,14 @@
         multi_query = b"".join(lengths_and_queries)
         return multi_query
 
+    def _decode_result(self, key: str, result_row: bytes) -> Optional[bytes]:
+        try:
+            decrypted_result = self._lib.decode_response(result_row)
+            decompressed_result = bz2.decompress(decrypted_result)
+            return self._lib.extract_result(key, decompressed_result)
+        except Exception:
+            return None
+
     def _unpack_query_result(
         self, keys: list[str], raw_result: bytes, ignore_errors=False
     ) -> list[Optional[bytes]]:
@@ -162,9 +162,7 @@
                 else:
                     raise RuntimeError(f"Failed to process query for key {key}.")
             else:
-                decrypted_result = self._lib.decode_response(result)
-                decompressed_result = bz2.decompress(decrypted_result)
-                extracted_result = self._lib.extract_result(key, decompressed_result)
+                extracted_result = self._decode_result(key, result)
                 retrievals.append(extracted_result)
 
         return retrievals
@@ -181,15 +179,18 @@
             self.setup()
         assert self._public_uuid
 
-        multi_query = self._generate_query_stream(keys)
+        queries = self._generate_query_stream(keys)
 
         start = time.perf_counter()
-        multi_result = self._api.private_read(self.name, multi_query)
+        rows_per_result = self._api.private_read(self.name, queries)
         self._exfil = time.perf_counter() - start
 
-        retrievals = self._unpack_query_result(keys, multi_result)
+        results = [
+            self._decode_result(key, result) if result else None
+            for key, result in zip(keys, rows_per_result)
+        ]
 
-        return retrievals
+        return results
 
     def setup(self):
         """Prepares this bucket client for private reads.
@@ -218,7 +219,7 @@
         bucket_create_req = {
             "name": new_name,
         }
-        self._api.modify(self.name, json.dumps(bucket_create_req))
+        self._api.modify(self.name, bucket_create_req)
         self.name = new_name
 
     def write(self, kv_pairs: dict[str, bytes]):
@@ -228,19 +230,20 @@
             kv_pairs: A dictionary of key-value pairs to write into the bucket.
                 Keys must be UTF8 strings, and values may be arbitrary bytes.
         """
-        concatenated_kv_items = b""
-        for key, value in kv_pairs.items():
-            concatenated_kv_items += serializer.wrap_key_val(key.encode("utf-8"), value)
-        # single call to API endpoint
-        self._api.write(self.name, concatenated_kv_items)
+        self._api.write(self.name, kv_pairs)  # type: ignore
+        # bytes is a valid subset of Optional[bytes], despite mypy's complaints
 
-    def delete_key(self, key: str):
-        """Deletes a single key-value pair from the bucket.
+    def delete_key(self, keys: Union[str, list[str]]):
+        """Deletes key-value pairs from the bucket.
 
         Args:
-            key: The key to delete.
+            keys: A single key, or a list of keys, to delete.
         """
-        self._api.delete_key(self.name, key)
+        if isinstance(keys, str):
+            keys = [keys]
+
+        delete_payload = {k: None for k in keys}
+        self._api.write(self.name, delete_payload)  # type: ignore
 
     def destroy_entire_bucket(self):
         """Destroys the entire bucket. This action is permanent and irreversible."""
@@ -278,7 +281,7 @@
             keys = [keys]
             single_query = True
 
-        results = [r if r is not None else None for r in self._private_read(keys)]
+        results = self._private_read(keys)
 
         if single_query:
             return results[0]
@@ -344,9 +347,10 @@ class AsyncBucket(Bucket):
         # Make one write call per chunk, while respecting a max concurrency limit.
         sem = asyncio.Semaphore(CONCURRENCY)
 
-        async def _paced_writer(chunk):
+        async def _paced_writer(chunk: dict[str, bytes]):
             async with sem:
-                await self._api.async_write(self.name, json.dumps(chunk))
+                await self._api.async_write(self.name, chunk)  # type: ignore
+                # bytes is a valid subset of Optional[bytes], despite mypy's complaints
 
         _tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
         await asyncio.gather(*_tasks)
@@ -359,9 +363,12 @@
-        multi_query = self._generate_query_stream(keys)
+        queries = self._generate_query_stream(keys)
 
         start = time.perf_counter()
-        multi_result = await self._api.async_private_read(self.name, multi_query)
+        rows_per_result = await self._api.async_private_read(self.name, queries)
         self._exfil = time.perf_counter() - start
 
-        retrievals = self._unpack_query_result(keys, multi_result)
+        results = [
+            self._decode_result(key, result) if result else None
+            for key, result in zip(keys, rows_per_result)
+        ]
 
-        return retrievals
+        return results
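
Note: the new per-row size estimate in _split_into_chunks budgets 4 bytes of JSON text per 3 bytes of raw value (base64 expansion) plus roughly 16 bytes of per-pair JSON punctuation, replacing the old estimate for the wrapped list-of-dicts format. A standalone check of that heuristic against the real encoded size (helper names are illustrative):

import base64
import json

def estimated_size(key: str, value: bytes) -> int:
    # Same heuristic as Bucket._split_into_chunks.
    return int(16 + len(key) + len(value) * 4 / 3)

def actual_size(key: str, value: bytes) -> int:
    # Real size of one pair after base64 encoding and JSON serialization.
    encoded = {key: base64.b64encode(value).decode("utf-8")}
    return len(json.dumps(encoded).encode("utf-8"))

key, value = "item-0001", b"\xff" * 3000
print(estimated_size(key, value), actual_size(key, value))  # 4025 4017

The estimate errs slightly high, which is the safe direction for staying under the 5 MiB chunk cap.
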
diff --git a/python/blyss/bucket_service.py b/python/blyss/bucket_service.py
index e92431f..7dc442c 100644
--- a/python/blyss/bucket_service.py
+++ b/python/blyss/bucket_service.py
@@ -1,6 +1,5 @@
 from typing import Any, Optional, Union
 from . import bucket, api, seed
-import json
 
 BLYSS_BUCKET_URL = "https://beta.api.blyss.dev"
 DEFAULT_BUCKET_PARAMETERS = {
@@ -83,10 +82,10 @@ class BucketService:
         parameters.update(usage_hints)
         bucket_create_req = {
             "name": bucket_name,
-            "parameters": json.dumps(parameters),
+            "parameters": parameters,
             "open_access": open_access,
         }
-        self._api.create(json.dumps(bucket_create_req))
+        self._api.create(bucket_create_req)
 
     def exists(self, name: str) -> bool:
         """Check if a bucket exists.
diff --git a/python/tests/test_service.py b/python/tests/test_service.py
index 5688ca9..9098127 100644
--- a/python/tests/test_service.py
+++ b/python/tests/test_service.py
@@ -19,9 +19,14 @@ def key_to_gold_value(key: str, length: int = 512) -> bytes:
 
 
 def verify_read(key: str, value: bytes):
+    expected = key_to_gold_value(key, len(value))
     try:
-        assert value == key_to_gold_value(key, len(value))
+        assert value == expected
     except:
+        print(f"read mismatch for key {key}")
+        print(f"received {value.hex()[:16]}")
+        print(f"expected {expected.hex()[:16]}")
+        print(traceback.format_exc())
         raise
 
 
@@ -119,6 +124,7 @@ if __name__ == "__main__":
     if len(sys.argv) > 2:
         print("Using api_key from command line")
         api_key = sys.argv[2]
+        if api_key == "none":
+            api_key = None
     print("DEBUG", api_key, endpoint)
     assert endpoint is not None
-    assert api_key is not None
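
Note: verify_read above exercises the tail of the new read path, in which the server returns a JSON list with one base64 string (or null) per query, in query order. A standalone sketch of the client-side decoding step, mirroring the list comprehension in API.private_read (decode_read_response is an illustrative name, not a library function):

import base64
from typing import Optional

def decode_read_response(rows: list[Optional[str]]) -> list[Optional[bytes]]:
    # One entry per query: a base64-encoded response row, or None when the
    # server returned null for that query.
    return [base64.b64decode(v) if v is not None else None for v in rows]

print(decode_read_response(["aGVsbG8=", None]))  # [b'hello', None]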