From e27fb7d5ac2fb014c4790ab731cbe052236ce149 Mon Sep 17 00:00:00 2001 From: Neil Movva Date: Tue, 5 Sep 2023 07:38:20 +0000 Subject: [PATCH] (python) more consistent json for internal api All requests and responses are JSON. All binary payloads are explicitly encoded as base64 within api.py, and decoded back to bytes before leaving api.py. User-facing code, e.g. bucket.py and bucket_service.py, should not see base64 wrangling. --- python/blyss/api.py | 101 ++++++++++++++++++++------------- python/blyss/bucket.py | 89 ++++++++++++++++------------- python/blyss/bucket_service.py | 5 +- python/tests/test_service.py | 9 ++- 4 files changed, 121 insertions(+), 83 deletions(-) diff --git a/python/blyss/api.py b/python/blyss/api.py index f7590fc..1a6de22 100644 --- a/python/blyss/api.py +++ b/python/blyss/api.py @@ -9,7 +9,6 @@ from typing import Any, Optional, Union import requests import httpx import gzip -import asyncio import json import logging import base64 @@ -116,14 +115,14 @@ def _post_data(api_key: str, url: str, data: Union[bytes, str]) -> bytes: return resp.content -def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[Any, Any]: +def _post_data_json(api_key: str, url: str, data: Union[bytes, Any]) -> Any: """Perform an HTTP POST request, returning a JSON-parsed dict. Request data can be any JSON string, or a raw bytestring that will be base64-encoded before send. 
All requests and responses are compressed JSON.""" if len(data) > APIGW_MAX_SIZE: raise ValueError( - f"Scheme public parameters too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes" + f"Request data is too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes" ) c = httpx.Client( @@ -133,20 +132,19 @@ def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[Any } ) - data_textsafe: bytes if type(data) == bytes: - data_textsafe = base64.b64encode(data) - elif type(data) == str: - data_textsafe = data.encode("utf-8") + data_jsonable = base64.b64encode(data) else: - raise ValueError(f"Unsupported data type {type(data)}") + data_jsonable = data + data_json = json.dumps(data_jsonable).encode("utf-8") + # compress requests larger than 1KB extra_headers = {} - if len(data_textsafe) > 1000: - payload = gzip.compress(data_textsafe) + if len(data_json) > 1000: + payload = gzip.compress(data_json) extra_headers["Content-Encoding"] = "gzip" else: - payload = data_textsafe + payload = data_json resp = c.post(url, content=payload, headers=extra_headers) @@ -163,26 +161,28 @@ def _post_form_data(url: str, fields: dict[Any, Any], data: bytes): async def _async_post_data( api_key: str, url: str, - data: Union[str, bytes], + data: Union[bytes, Any], compress: bool = True, decode_json: bool = True, ) -> Any: """Perform an async HTTP POST request, returning a JSON-parsed dict response""" headers = {"x-api-key": api_key, "Content-Type": "application/json"} - if type(data) == str: - data = data.encode("utf-8") - elif type(data) == bytes: - data = base64.b64encode(data) + + if type(data) == bytes: + data_jsonable = base64.b64encode(data).decode("utf-8") else: - raise ValueError(f"Unsupported data type {type(data)}") + data_jsonable = data + data_json = json.dumps(data_jsonable).encode("utf-8") if compress: # apply gzip compression to data before sending - data = gzip.compress(data) + payload = gzip.compress(data_json) 
headers["Content-Encoding"] = "gzip" + else: + payload = data_json async with httpx.AsyncClient(timeout=httpx.Timeout(5, read=None)) as client: - r = await client.post(url, content=data, headers=headers) + r = await client.post(url, content=payload, headers=headers) _check_http_error(r) # type: ignore if decode_json: @@ -212,14 +212,14 @@ class API: def _service_url_for(self, path: str) -> str: return self.service_endpoint + path - def create(self, data_json: str) -> dict[Any, Any]: + def create(self, data_jsonable: dict) -> dict[Any, Any]: """Create a new bucket, given the supplied data. Args: data_json (str): A JSON-encoded string of the new bucket request. """ return _post_data_json( - self.api_key, self._service_url_for(CREATE_PATH), data_json + self.api_key, self._service_url_for(CREATE_PATH), data_jsonable ) def check(self, uuid: str) -> dict[Any, Any]: @@ -251,14 +251,14 @@ class API: def _url_for(self, bucket_name: str, path: str) -> str: return self.service_endpoint + "/" + bucket_name + path - def modify(self, bucket_name: str, data_json: str) -> dict[Any, Any]: + def modify(self, bucket_name: str, data_jsonable: Any) -> dict[Any, Any]: """Modify existing bucket. Args: data_json (str): same as create. 
""" return _post_data_json( - self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_json + self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_jsonable ) def meta(self, bucket_name: str) -> dict[Any, Any]: @@ -325,14 +325,29 @@ class API: """Delete all keys in this bucket.""" _post_data(self.api_key, self._url_for(bucket_name, CLEAR_PATH), "") - def write(self, bucket_name: str, data: bytes): + def write(self, bucket_name: str, data: dict[str, Optional[bytes]]): """Write some data to this bucket.""" - _post_data(self.api_key, self._url_for(bucket_name, WRITE_PATH), data) + data_jsonable = { + k: None if v is None else base64.b64encode(v).decode("utf-8") + for k, v in data.items() + } + _post_data_json( + self.api_key, self._url_for(bucket_name, WRITE_PATH), data_jsonable + ) - async def async_write(self, bucket_name: str, data: str): + async def async_write(self, bucket_name: str, data: dict[str, Optional[bytes]]): """Write JSON payload to this bucket.""" + + data_jsonable = { + k: None if v is None else base64.b64encode(v).decode("utf-8") + for k, v in data.items() + } + await _async_post_data( - self.api_key, self._url_for(bucket_name, WRITE_PATH), data, decode_json=True + self.api_key, + self._url_for(bucket_name, WRITE_PATH), + data_jsonable, + compress=True, ) def delete_key(self, bucket_name: str, key: str): @@ -341,16 +356,26 @@ class API: self.api_key, self._url_for(bucket_name, DELETE_PATH), key.encode("utf-8") ) - def private_read(self, bucket_name: str, data: bytes) -> bytes: + def private_read( + self, bucket_name: str, queries: list[bytes] + ) -> list[Optional[bytes]]: """Privately read data from this bucket.""" - val = _post_data(self.api_key, self._url_for(bucket_name, READ_PATH), data) - return base64.b64decode(val) - - async def async_private_read(self, bucket_name: str, data: bytes) -> bytes: - """Privately read data from this bucket.""" - val: bytes = await _async_post_data( - self.api_key, self._url_for(bucket_name, READ_PATH), 
data, decode_json=False + data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries] + r = _post_data_json( + self.api_key, self._url_for(bucket_name, READ_PATH), data_jsonable ) - # AWS APIGW encodes its responses as base64 - return base64.b64decode(val) - # return self.private_read(bucket_name, data) + return [base64.b64decode(v) if v is not None else None for v in r] + + async def async_private_read( + self, bucket_name: str, queries: list[bytes] + ) -> list[Optional[bytes]]: + """Privately read data from this bucket.""" + data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries] + r: list[str] = await _async_post_data( + self.api_key, + self._url_for(bucket_name, READ_PATH), + data_jsonable, + compress=True, + decode_json=True, + ) + return [base64.b64decode(v) if v is not None else None for v in r] diff --git a/python/blyss/bucket.py b/python/blyss/bucket.py index e87dd43..54c5a58 100644 --- a/python/blyss/bucket.py +++ b/python/blyss/bucket.py @@ -5,11 +5,10 @@ Abstracts functionality on an existing bucket. from typing import Optional, Any, Union, Iterator -from . import api, serializer, seed +from . import api, seed from .blyss_lib import BlyssLib import json -import base64 import bz2 import time import asyncio @@ -82,9 +81,7 @@ class Bucket: else: raise e - def _split_into_chunks( - self, kv_pairs: dict[str, bytes] - ) -> list[list[dict[str, str]]]: + def _split_into_chunks(self, kv_pairs: dict[str, bytes]) -> list[dict[str, bytes]]: _MAX_PAYLOAD = 5 * 2**20 # 5 MiB # 1. Bin keys by row index @@ -99,25 +96,19 @@ class Bucket: # 2. Prepare chunks of items, where each is a JSON-ready structure. # Each chunk is less than the maximum payload size, and guarantees # zero overlap of rows across chunks. 
- kv_chunks: list[list[dict[str, str]]] = [] - current_chunk: list[dict[str, str]] = [] + kv_chunks: list[dict[str, bytes]] = [] + current_chunk: dict[str, bytes] = {} current_chunk_size = 0 sorted_indices = sorted(keys_by_index.keys()) for i in sorted_indices: keys = keys_by_index[i] # prepare all keys in this row - row = [] + row = {} row_size = 0 for key in keys: - value = kv_pairs[key] - value_str = base64.b64encode(value).decode("utf-8") - fmt = { - "key": key, - "value": value_str, - "content-type": "application/octet-stream", - } - row.append(fmt) - row_size += int(24 + len(key) + len(value_str) + 48) + v = kv_pairs[key] + row[key] = v + row_size += int(16 + len(key) + len(v) * 4 / 3) # if the new row doesn't fit into the current chunk, start a new one if current_chunk_size + row_size > _MAX_PAYLOAD: @@ -125,7 +116,7 @@ class Bucket: current_chunk = row current_chunk_size = row_size else: - current_chunk.extend(row) + current_chunk.update(row) current_chunk_size += row_size # add the last chunk @@ -134,13 +125,14 @@ class Bucket: return kv_chunks - def _generate_query_stream(self, keys: list[str]) -> bytes: + def _generate_query_stream(self, keys: list[str]) -> list[bytes]: assert self._public_uuid # generate encrypted queries queries: list[bytes] = [ self._lib.generate_query(self._public_uuid, self._lib.get_row(k)) for k in keys ] + return queries # interleave the queries with their lengths (uint64_t) query_lengths = [len(q).to_bytes(8, "little") for q in queries] lengths_and_queries = [x for lq in zip(query_lengths, queries) for x in lq] @@ -150,6 +142,14 @@ class Bucket: multi_query = b"".join(lengths_and_queries) return multi_query + def _decode_result(self, key: str, result_row: bytes) -> Optional[bytes]: + try: + decrypted_result = self._lib.decode_response(result_row) + decompressed_result = bz2.decompress(decrypted_result) + return self._lib.extract_result(key, decompressed_result) + except: + return None + def _unpack_query_result( self, keys: 
list[str], raw_result: bytes, ignore_errors=False ) -> list[Optional[bytes]]: @@ -162,9 +162,7 @@ class Bucket: else: raise RuntimeError(f"Failed to process query for key {key}.") else: - decrypted_result = self._lib.decode_response(result) - decompressed_result = bz2.decompress(decrypted_result) - extracted_result = self._lib.extract_result(key, decompressed_result) + extracted_result = self._decode_result(key, result) retrievals.append(extracted_result) return retrievals @@ -181,15 +179,18 @@ class Bucket: self.setup() assert self._public_uuid - multi_query = self._generate_query_stream(keys) + queries = self._generate_query_stream(keys) start = time.perf_counter() - multi_result = self._api.private_read(self.name, multi_query) + rows_per_result = self._api.private_read(self.name, queries) self._exfil = time.perf_counter() - start - retrievals = self._unpack_query_result(keys, multi_result) + results = [ + self._decode_result(key, result) if result else None + for key, result in zip(keys, rows_per_result) + ] - return retrievals + return results def setup(self): """Prepares this bucket client for private reads. @@ -218,7 +219,8 @@ class Bucket: bucket_create_req = { "name": new_name, } - self._api.modify(self.name, json.dumps(bucket_create_req)) + r = self._api.modify(self.name, bucket_create_req) + print(r) self.name = new_name def write(self, kv_pairs: dict[str, bytes]): @@ -228,19 +230,20 @@ class Bucket: kv_pairs: A dictionary of key-value pairs to write into the bucket. Keys must be UTF8 strings, and values may be arbitrary bytes. 
""" - concatenated_kv_items = b"" - for key, value in kv_pairs.items(): - concatenated_kv_items += serializer.wrap_key_val(key.encode("utf-8"), value) - # single call to API endpoint - self._api.write(self.name, concatenated_kv_items) + self._api.write(self.name, kv_pairs) # type: ignore + # bytes is a valid subset of Optional[bytes], despite mypy's complaints - def delete_key(self, key: str): - """Deletes a single key-value pair from the bucket. + def delete_key(self, keys: str | list[str]): + """Deletes key-value pairs from the bucket. Args: key: The key to delete. """ - self._api.delete_key(self.name, key) + if isinstance(keys, str): + keys = [keys] + + delete_payload = {k: None for k in keys} + self._api.write(self.name, delete_payload) # type: ignore def destroy_entire_bucket(self): """Destroys the entire bucket. This action is permanent and irreversible.""" @@ -278,7 +281,7 @@ class Bucket: keys = [keys] single_query = True - results = [r if r is not None else None for r in self._private_read(keys)] + results = self._private_read(keys) if single_query: return results[0] @@ -344,9 +347,10 @@ class AsyncBucket(Bucket): # Make one write call per chunk, while respecting a max concurrency limit. 
sem = asyncio.Semaphore(CONCURRENCY) - async def _paced_writer(chunk): + async def _paced_writer(chunk: dict[str, bytes]): async with sem: - await self._api.async_write(self.name, json.dumps(chunk)) + await self._api.async_write(self.name, chunk) # type: ignore + # bytes is a valid subset of Optional[bytes], despite mypy's complaints _tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks] await asyncio.gather(*_tasks) @@ -359,9 +363,12 @@ class AsyncBucket(Bucket): multi_query = self._generate_query_stream(keys) start = time.perf_counter() - multi_result = await self._api.async_private_read(self.name, multi_query) + rows_per_result = await self._api.async_private_read(self.name, multi_query) self._exfil = time.perf_counter() - start - retrievals = self._unpack_query_result(keys, multi_result) + results = [ + self._decode_result(key, result) if result else None + for key, result in zip(keys, rows_per_result) + ] - return retrievals + return results diff --git a/python/blyss/bucket_service.py b/python/blyss/bucket_service.py index e92431f..7dc442c 100644 --- a/python/blyss/bucket_service.py +++ b/python/blyss/bucket_service.py @@ -1,6 +1,5 @@ from typing import Any, Optional, Union from . import bucket, api, seed -import json BLYSS_BUCKET_URL = "https://beta.api.blyss.dev" DEFAULT_BUCKET_PARAMETERS = { @@ -83,10 +82,10 @@ class BucketService: parameters.update(usage_hints) bucket_create_req = { "name": bucket_name, - "parameters": json.dumps(parameters), + "parameters": parameters, "open_access": open_access, } - self._api.create(json.dumps(bucket_create_req)) + r = self._api.create(bucket_create_req) def exists(self, name: str) -> bool: """Check if a bucket exists. 
diff --git a/python/tests/test_service.py b/python/tests/test_service.py index 5688ca9..9098127 100644 --- a/python/tests/test_service.py +++ b/python/tests/test_service.py @@ -19,9 +19,14 @@ def key_to_gold_value(key: str, length: int = 512) -> bytes: def verify_read(key: str, value: bytes): + expected = key_to_gold_value(key, len(value)) try: - assert value == key_to_gold_value(key, len(value)) + assert value == expected except: + print(f"read mismatch for key {key}") + print(f"received {value.hex()[:16]}") + print(f"expected {expected.hex()[:16]}") + print(traceback.format_exc()) raise @@ -119,6 +124,8 @@ if __name__ == "__main__": if len(sys.argv) > 2: print("Using api_key from command line") api_key = sys.argv[2] + if api_key == "none": + api_key = None print("DEBUG", api_key, endpoint) assert endpoint is not None assert api_key is not None