(python) more consistent json for internal api

all requests and responses are JSON.
all binary payloads are explicitly encoded as base64
within api.py, and decoded back to bytes before leaving api.py.
User-facing code, e.g. bucket.py and bucket_service.py,
should not see base64 wrangling.
Neil Movva
2023-09-05 07:38:20 +00:00
parent afd6fad6f6
commit e27fb7d5ac
4 changed files with 121 additions and 83 deletions
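
In sketch form, the convention this commit establishes (the helper names here are illustrative, not part of the diff): binary values cross the wire as base64 strings inside ordinary JSON, and api.py is the only layer that converts in either direction.

```python
import base64
import json
from typing import Optional

def to_jsonable(v: Optional[bytes]) -> Optional[str]:
    """bytes -> base64 text; None passes through (later used to signal deletion)."""
    return None if v is None else base64.b64encode(v).decode("utf-8")

def from_jsonable(v: Optional[str]) -> Optional[bytes]:
    """base64 text -> bytes; None passes through."""
    return None if v is None else base64.b64decode(v)

payload = json.dumps({"k": to_jsonable(b"\x00\xff"), "gone": to_jsonable(None)})
assert json.loads(payload) == {"k": "AP8=", "gone": None}
```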

api.py

@@ -9,7 +9,6 @@ from typing import Any, Optional, Union
import requests
import httpx
import gzip
import asyncio
import json
import logging
import base64
@@ -116,14 +115,14 @@ def _post_data(api_key: str, url: str, data: Union[bytes, str]) -> bytes:
return resp.content
def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[Any, Any]:
def _post_data_json(api_key: str, url: str, data: Union[bytes, Any]) -> Any:
"""Perform an HTTP POST request, returning a JSON-parsed dict.
Request data can be any JSON string, or a raw bytestring that will be base64-encoded before send.
All requests and responses are compressed JSON."""
if len(data) > APIGW_MAX_SIZE:
raise ValueError(
f"Scheme public parameters too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes"
f"Request data is too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes"
)
c = httpx.Client(
@@ -133,20 +132,19 @@ def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[Any
}
)
data_textsafe: bytes
if type(data) == bytes:
data_textsafe = base64.b64encode(data)
elif type(data) == str:
data_textsafe = data.encode("utf-8")
data_jsonable = base64.b64encode(data)
else:
raise ValueError(f"Unsupported data type {type(data)}")
data_jsonable = data
data_json = json.dumps(data_jsonable).encode("utf-8")
# compress requests larger than 1KB
extra_headers = {}
if len(data_textsafe) > 1000:
payload = gzip.compress(data_textsafe)
if len(data_json) > 1000:
payload = gzip.compress(data_json)
extra_headers["Content-Encoding"] = "gzip"
else:
payload = data_textsafe
payload = data_json
resp = c.post(url, content=payload, headers=extra_headers)
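
Reconstructed as a standalone sketch (the URL and key are placeholders, and the APIGW_MAX_SIZE value is not shown in the diff, so the figure below is an assumption): serialize to JSON, gzip only when the body exceeds 1 KB, and advertise the compression via Content-Encoding. Note the decode to str before json.dumps, since json.dumps cannot serialize bytes.

```python
import base64
import gzip
import json
from typing import Any, Union

import httpx

APIGW_MAX_SIZE = 6 * 2**20  # assumed value; the diff only references the constant

def post_data_json(api_key: str, url: str, data: Union[bytes, Any]) -> Any:
    if isinstance(data, bytes):
        data = base64.b64encode(data).decode("utf-8")  # bytes -> base64 text
    body = json.dumps(data).encode("utf-8")
    if len(body) > APIGW_MAX_SIZE:
        raise ValueError(f"Request data is too large ({len(body)} bytes)")
    headers = {"x-api-key": api_key, "Content-Type": "application/json"}
    if len(body) > 1000:  # compress requests larger than 1 KB
        body = gzip.compress(body)
        headers["Content-Encoding"] = "gzip"
    resp = httpx.post(url, content=body, headers=headers)
    resp.raise_for_status()
    return resp.json()
```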
@@ -163,26 +161,28 @@ def _post_form_data(url: str, fields: dict[Any, Any], data: bytes):
async def _async_post_data(
api_key: str,
url: str,
data: Union[str, bytes],
data: Union[bytes, Any],
compress: bool = True,
decode_json: bool = True,
) -> Any:
"""Perform an async HTTP POST request, returning a JSON-parsed dict response"""
headers = {"x-api-key": api_key, "Content-Type": "application/json"}
if type(data) == str:
data = data.encode("utf-8")
elif type(data) == bytes:
data = base64.b64encode(data)
if type(data) == bytes:
data_jsonable = base64.b64encode(data).decode("utf-8")
else:
raise ValueError(f"Unsupported data type {type(data)}")
data_jsonable = data
data_json = json.dumps(data_jsonable).encode("utf-8")
if compress:
# apply gzip compression to data before sending
data = gzip.compress(data)
payload = gzip.compress(data_json)
headers["Content-Encoding"] = "gzip"
else:
payload = data_json
async with httpx.AsyncClient(timeout=httpx.Timeout(5, read=None)) as client:
r = await client.post(url, content=data, headers=headers)
r = await client.post(url, content=payload, headers=headers)
_check_http_error(r) # type: ignore
if decode_json:
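
The async variant differs mainly in client construction; a sketch of the timeout choice as it appears in the diff (5 s for connect/write/pool, no read timeout, presumably to tolerate long-running reads):

```python
import httpx

# default 5 s timeout for connect/write/pool; read=None disables the read timeout
TIMEOUT = httpx.Timeout(5, read=None)

async def post(url: str, payload: bytes, headers: dict[str, str]) -> httpx.Response:
    async with httpx.AsyncClient(timeout=TIMEOUT) as client:
        return await client.post(url, content=payload, headers=headers)
```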
@@ -212,14 +212,14 @@ class API:
def _service_url_for(self, path: str) -> str:
return self.service_endpoint + path
def create(self, data_json: str) -> dict[Any, Any]:
def create(self, data_jsonable: dict) -> dict[Any, Any]:
"""Create a new bucket, given the supplied data.
Args:
data_json (str): A JSON-encoded string of the new bucket request.
"""
return _post_data_json(
self.api_key, self._service_url_for(CREATE_PATH), data_json
self.api_key, self._service_url_for(CREATE_PATH), data_jsonable
)
def check(self, uuid: str) -> dict[Any, Any]:
@@ -251,14 +251,14 @@ class API:
def _url_for(self, bucket_name: str, path: str) -> str:
return self.service_endpoint + "/" + bucket_name + path
def modify(self, bucket_name: str, data_json: str) -> dict[Any, Any]:
def modify(self, bucket_name: str, data_jsonable: Any) -> dict[Any, Any]:
"""Modify existing bucket.
Args:
data_json (str): same as create.
"""
return _post_data_json(
self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_json
self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_jsonable
)
def meta(self, bucket_name: str) -> dict[Any, Any]:
@@ -325,14 +325,29 @@ class API:
"""Delete all keys in this bucket."""
_post_data(self.api_key, self._url_for(bucket_name, CLEAR_PATH), "")
def write(self, bucket_name: str, data: bytes):
def write(self, bucket_name: str, data: dict[str, Optional[bytes]]):
"""Write some data to this bucket."""
_post_data(self.api_key, self._url_for(bucket_name, WRITE_PATH), data)
data_jsonable = {
k: None if v is None else base64.b64encode(v).decode("utf-8")
for k, v in data.items()
}
_post_data_json(
self.api_key, self._url_for(bucket_name, WRITE_PATH), data_jsonable
)
async def async_write(self, bucket_name: str, data: str):
async def async_write(self, bucket_name: str, data: dict[str, Optional[bytes]]):
"""Write JSON payload to this bucket."""
data_jsonable = {
k: None if v is None else base64.b64encode(v).decode("utf-8")
for k, v in data.items()
}
await _async_post_data(
self.api_key, self._url_for(bucket_name, WRITE_PATH), data, decode_json=True
self.api_key,
self._url_for(bucket_name, WRITE_PATH),
data_jsonable,
compress=True,
)
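
Both write paths share the same encoding step; a sketch of it in isolation: bytes values become base64 strings, and None survives as JSON null, which the delete path below relies on.

```python
import base64
from typing import Optional

def encode_write_payload(
    data: dict[str, Optional[bytes]],
) -> dict[str, Optional[str]]:
    """Map bytes to base64 strings; keep None (JSON null) to mark deletions."""
    return {
        k: None if v is None else base64.b64encode(v).decode("utf-8")
        for k, v in data.items()
    }

assert encode_write_payload({"a": b"hi", "b": None}) == {"a": "aGk=", "b": None}
```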
def delete_key(self, bucket_name: str, key: str):
@@ -341,16 +356,26 @@ class API:
self.api_key, self._url_for(bucket_name, DELETE_PATH), key.encode("utf-8")
)
def private_read(self, bucket_name: str, data: bytes) -> bytes:
def private_read(
self, bucket_name: str, queries: list[bytes]
) -> list[Optional[bytes]]:
"""Privately read data from this bucket."""
val = _post_data(self.api_key, self._url_for(bucket_name, READ_PATH), data)
return base64.b64decode(val)
async def async_private_read(self, bucket_name: str, data: bytes) -> bytes:
"""Privately read data from this bucket."""
val: bytes = await _async_post_data(
self.api_key, self._url_for(bucket_name, READ_PATH), data, decode_json=False
data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries]
r = _post_data_json(
self.api_key, self._url_for(bucket_name, READ_PATH), data_jsonable
)
# AWS APIGW encodes its responses as base64
return base64.b64decode(val)
# return self.private_read(bucket_name, data)
return [base64.b64decode(v) if v is not None else None for v in r]
async def async_private_read(
self, bucket_name: str, queries: list[bytes]
) -> list[Optional[bytes]]:
"""Privately read data from this bucket."""
data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries]
r: list[str] = await _async_post_data(
self.api_key,
self._url_for(bucket_name, READ_PATH),
data_jsonable,
compress=True,
decode_json=True,
)
return [base64.b64decode(v) if v is not None else None for v in r]
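
Sync and async reads now share one wire shape; a sketch of the round trip, assuming the server answers with a JSON array whose positions match the query list and whose nulls mark missing rows:

```python
import base64
from typing import Optional

def encode_queries(queries: list[bytes]) -> list[str]:
    """Encrypted query blobs -> base64 strings, one JSON array element each."""
    return [base64.b64encode(q).decode("utf-8") for q in queries]

def decode_results(rows: list[Optional[str]]) -> list[Optional[bytes]]:
    """Positions line up with the queries; null means no row came back."""
    return [None if v is None else base64.b64decode(v) for v in rows]
```

Compared with the old length-prefixed byte stream, the JSON array makes framing implicit, at the cost of base64's roughly 33% size overhead.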

bucket.py

@@ -5,11 +5,10 @@ Abstracts functionality on an existing bucket.
from typing import Optional, Any, Union, Iterator
from . import api, serializer, seed
from . import api, seed
from .blyss_lib import BlyssLib
import json
import base64
import bz2
import time
import asyncio
@@ -82,9 +81,7 @@ class Bucket:
else:
raise e
def _split_into_chunks(
self, kv_pairs: dict[str, bytes]
) -> list[list[dict[str, str]]]:
def _split_into_chunks(self, kv_pairs: dict[str, bytes]) -> list[dict[str, bytes]]:
_MAX_PAYLOAD = 5 * 2**20 # 5 MiB
# 1. Bin keys by row index
@@ -99,25 +96,19 @@ class Bucket:
# 2. Prepare chunks of items, where each is a JSON-ready structure.
# Each chunk is less than the maximum payload size, and guarantees
# zero overlap of rows across chunks.
kv_chunks: list[list[dict[str, str]]] = []
current_chunk: list[dict[str, str]] = []
kv_chunks: list[dict[str, bytes]] = []
current_chunk: dict[str, bytes] = {}
current_chunk_size = 0
sorted_indices = sorted(keys_by_index.keys())
for i in sorted_indices:
keys = keys_by_index[i]
# prepare all keys in this row
row = []
row = {}
row_size = 0
for key in keys:
value = kv_pairs[key]
value_str = base64.b64encode(value).decode("utf-8")
fmt = {
"key": key,
"value": value_str,
"content-type": "application/octet-stream",
}
row.append(fmt)
row_size += int(24 + len(key) + len(value_str) + 48)
v = kv_pairs[key]
row[key] = v
row_size += int(16 + len(key) + len(v) * 4 / 3)
# if the new row doesn't fit into the current chunk, start a new one
if current_chunk_size + row_size > _MAX_PAYLOAD:
@@ -125,7 +116,7 @@ class Bucket:
current_chunk = row
current_chunk_size = row_size
else:
current_chunk.extend(row)
current_chunk.update(row)
current_chunk_size += row_size
# add the last chunk
@@ -134,13 +125,14 @@ class Bucket:
return kv_chunks
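
A self-contained sketch of the chunking logic (the per-entry estimate mirrors the diff: base64 inflates values by about 4/3, plus roughly 16 bytes of JSON quoting and punctuation; the real method also bins keys by row index first so no row straddles two chunks, which this sketch omits):

```python
MAX_PAYLOAD = 5 * 2**20  # 5 MiB, as in the diff

def estimated_size(key: str, value: bytes) -> int:
    # ~16 bytes of JSON punctuation + key + base64-expanded value
    return int(16 + len(key) + len(value) * 4 / 3)

def split_into_chunks(kv_pairs: dict[str, bytes]) -> list[dict[str, bytes]]:
    chunks: list[dict[str, bytes]] = []
    current: dict[str, bytes] = {}
    current_size = 0
    for key, value in kv_pairs.items():
        size = estimated_size(key, value)
        if current and current_size + size > MAX_PAYLOAD:
            chunks.append(current)
            current, current_size = {}, 0
        current[key] = value
        current_size += size
    if current:
        chunks.append(current)
    return chunks
```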
def _generate_query_stream(self, keys: list[str]) -> bytes:
def _generate_query_stream(self, keys: list[str]) -> list[bytes]:
assert self._public_uuid
# generate encrypted queries
queries: list[bytes] = [
self._lib.generate_query(self._public_uuid, self._lib.get_row(k))
for k in keys
]
return queries
# interleave the queries with their lengths (uint64_t)
query_lengths = [len(q).to_bytes(8, "little") for q in queries]
lengths_and_queries = [x for lq in zip(query_lengths, queries) for x in lq]
@@ -150,6 +142,14 @@ class Bucket:
multi_query = b"".join(lengths_and_queries)
return multi_query
def _decode_result(self, key: str, result_row: bytes) -> Optional[bytes]:
try:
decrypted_result = self._lib.decode_response(result_row)
decompressed_result = bz2.decompress(decrypted_result)
return self._lib.extract_result(key, decompressed_result)
except:
return None
def _unpack_query_result(
self, keys: list[str], raw_result: bytes, ignore_errors=False
) -> list[Optional[bytes]]:
@@ -162,9 +162,7 @@ class Bucket:
else:
raise RuntimeError(f"Failed to process query for key {key}.")
else:
decrypted_result = self._lib.decode_response(result)
decompressed_result = bz2.decompress(decrypted_result)
extracted_result = self._lib.extract_result(key, decompressed_result)
extracted_result = self._decode_result(key, result)
retrievals.append(extracted_result)
return retrievals
@@ -181,15 +179,18 @@ class Bucket:
self.setup()
assert self._public_uuid
multi_query = self._generate_query_stream(keys)
queries = self._generate_query_stream(keys)
start = time.perf_counter()
multi_result = self._api.private_read(self.name, multi_query)
rows_per_result = self._api.private_read(self.name, queries)
self._exfil = time.perf_counter() - start
retrievals = self._unpack_query_result(keys, multi_result)
results = [
self._decode_result(key, result) if result else None
for key, result in zip(keys, rows_per_result)
]
return retrievals
return results
def setup(self):
"""Prepares this bucket client for private reads.
@@ -218,7 +219,8 @@ class Bucket:
bucket_create_req = {
"name": new_name,
}
self._api.modify(self.name, json.dumps(bucket_create_req))
r = self._api.modify(self.name, bucket_create_req)
print(r)
self.name = new_name
def write(self, kv_pairs: dict[str, bytes]):
@@ -228,19 +230,20 @@ class Bucket:
kv_pairs: A dictionary of key-value pairs to write into the bucket.
Keys must be UTF8 strings, and values may be arbitrary bytes.
"""
concatenated_kv_items = b""
for key, value in kv_pairs.items():
concatenated_kv_items += serializer.wrap_key_val(key.encode("utf-8"), value)
# single call to API endpoint
self._api.write(self.name, concatenated_kv_items)
self._api.write(self.name, kv_pairs) # type: ignore
# bytes is a valid subset of Optional[bytes], despite mypy's complaints
def delete_key(self, key: str):
"""Deletes a single key-value pair from the bucket.
def delete_key(self, keys: str | list[str]):
"""Deletes key-value pairs from the bucket.
Args:
key: The key to delete.
"""
self._api.delete_key(self.name, key)
if isinstance(keys, str):
keys = [keys]
delete_payload = {k: None for k in keys}
self._api.write(self.name, delete_payload) # type: ignore
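
With the payload convention above, deletion is just a write of nulls; a usage sketch (bucket object assumed):

```python
bucket.delete_key("stale-key")         # single key
bucket.delete_key(["old-a", "old-b"])  # several keys, still one write call
```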
def destroy_entire_bucket(self):
"""Destroys the entire bucket. This action is permanent and irreversible."""
@@ -278,7 +281,7 @@ class Bucket:
keys = [keys]
single_query = True
results = [r if r is not None else None for r in self._private_read(keys)]
results = self._private_read(keys)
if single_query:
return results[0]
@@ -344,9 +347,10 @@ class AsyncBucket(Bucket):
# Make one write call per chunk, while respecting a max concurrency limit.
sem = asyncio.Semaphore(CONCURRENCY)
async def _paced_writer(chunk):
async def _paced_writer(chunk: dict[str, bytes]):
async with sem:
await self._api.async_write(self.name, json.dumps(chunk))
await self._api.async_write(self.name, chunk) # type: ignore
# bytes is a valid subset of Optional[bytes], despite mypy's complaints
_tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
await asyncio.gather(*_tasks)
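
The pacing pattern in isolation, as a sketch (CONCURRENCY's value is not shown in the diff, so the figure below is an assumption): one task per chunk, with a semaphore capping how many are in flight.

```python
import asyncio
from typing import Awaitable, Callable

CONCURRENCY = 8  # assumed; the diff only references the constant

async def write_chunks(
    chunks: list[dict[str, bytes]],
    write_one: Callable[[dict[str, bytes]], Awaitable[None]],
) -> None:
    sem = asyncio.Semaphore(CONCURRENCY)

    async def paced(chunk: dict[str, bytes]) -> None:
        async with sem:  # at most CONCURRENCY writes in flight
            await write_one(chunk)

    await asyncio.gather(*(asyncio.create_task(paced(c)) for c in chunks))
```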
@@ -359,9 +363,12 @@ class AsyncBucket(Bucket):
multi_query = self._generate_query_stream(keys)
start = time.perf_counter()
multi_result = await self._api.async_private_read(self.name, multi_query)
rows_per_result = await self._api.async_private_read(self.name, multi_query)
self._exfil = time.perf_counter() - start
retrievals = self._unpack_query_result(keys, multi_result)
results = [
self._decode_result(key, result) if result else None
for key, result in zip(keys, rows_per_result)
]
return retrievals
return results

bucket_service.py

@@ -1,6 +1,5 @@
from typing import Any, Optional, Union
from . import bucket, api, seed
import json
BLYSS_BUCKET_URL = "https://beta.api.blyss.dev"
DEFAULT_BUCKET_PARAMETERS = {
@@ -83,10 +82,10 @@ class BucketService:
parameters.update(usage_hints)
bucket_create_req = {
"name": bucket_name,
"parameters": json.dumps(parameters),
"parameters": parameters,
"open_access": open_access,
}
self._api.create(json.dumps(bucket_create_req))
r = self._api.create(bucket_create_req)
def exists(self, name: str) -> bool:
"""Check if a bucket exists.

View File

@@ -19,9 +19,14 @@ def key_to_gold_value(key: str, length: int = 512) -> bytes:
def verify_read(key: str, value: bytes):
expected = key_to_gold_value(key, len(value))
try:
assert value == key_to_gold_value(key, len(value))
assert value == expected
except:
print(f"read mismatch for key {key}")
print(f"received {value.hex()[:16]}")
print(f"expected {expected.hex()[:16]}")
print(traceback.format_exc())
raise
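
key_to_gold_value is not shown in this view; a hypothetical deterministic construction with the same signature (illustrative only) would stretch a hash of the key to the requested length:

```python
import hashlib

def key_to_gold_value(key: str, length: int = 512) -> bytes:
    # hypothetical: derive `length` reproducible bytes from the key
    out = b""
    counter = 0
    while len(out) < length:
        out += hashlib.sha256(f"{key}:{counter}".encode()).digest()
        counter += 1
    return out[:length]
```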
@@ -119,6 +124,8 @@ if __name__ == "__main__":
if len(sys.argv) > 2:
print("Using api_key from command line")
api_key = sys.argv[2]
if api_key == "none":
api_key = None
print("DEBUG", api_key, endpoint)
assert endpoint is not None
assert api_key is not None