mirror of https://github.com/blyssprivacy/sdk.git
synced 2026-04-26 03:00:13 -04:00
(python) more consistent json for internal api
All requests and responses are JSON. All binary payloads are explicitly encoded as base64 within api.py, and decoded back to bytes before leaving api.py. User-facing code, e.g. bucket.py and bucket_service.py, should not see base64 wrangling.
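
For illustration, a minimal sketch of the convention described above. The helper names here are invented for the example; in the SDK the equivalent logic lives inside api.py (e.g. in _post_data_json):

import base64
import json
from typing import Optional


def to_jsonable(data: dict[str, Optional[bytes]]) -> dict[str, Optional[str]]:
    # Base64-encode binary values so the whole payload is plain JSON text.
    return {k: None if v is None else base64.b64encode(v).decode("utf-8") for k, v in data.items()}


def from_jsonable(payload: dict[str, Optional[str]]) -> dict[str, Optional[bytes]]:
    # Decode the base64 strings back to bytes, so callers only ever see bytes.
    return {k: None if v is None else base64.b64decode(v) for k, v in payload.items()}


def roundtrip_example() -> None:
    body = json.dumps(to_jsonable({"key": b"\x00\x01binary"}))   # what goes over the wire
    assert from_jsonable(json.loads(body)) == {"key": b"\x00\x01binary"}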
@@ -9,7 +9,6 @@ from typing import Any, Optional, Union
 import requests
 import httpx
 import gzip
 import asyncio
 import json
 import logging
 import base64
@@ -116,14 +115,14 @@ def _post_data(api_key: str, url: str, data: Union[bytes, str]) -> bytes:
     return resp.content


-def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[Any, Any]:
+def _post_data_json(api_key: str, url: str, data: Union[bytes, Any]) -> Any:
     """Perform an HTTP POST request, returning a JSON-parsed dict.
     Request data can be any JSON string, or a raw bytestring that will be base64-encoded before send.
     All requests and responses are compressed JSON."""

     if len(data) > APIGW_MAX_SIZE:
         raise ValueError(
-            f"Scheme public parameters too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes"
+            f"Request data is too large ({len(data)} bytes); maximum size is {APIGW_MAX_SIZE} bytes"
         )

     c = httpx.Client(
@@ -133,20 +132,19 @@ def _post_data_json(api_key: str, url: str, data: Union[bytes, str]) -> dict[Any, Any]:
         }
     )

-    data_textsafe: bytes
     if type(data) == bytes:
-        data_textsafe = base64.b64encode(data)
-    elif type(data) == str:
-        data_textsafe = data.encode("utf-8")
+        data_jsonable = base64.b64encode(data)
     else:
-        raise ValueError(f"Unsupported data type {type(data)}")
+        data_jsonable = data
+    data_json = json.dumps(data_jsonable).encode("utf-8")

     # compress requests larger than 1KB
     extra_headers = {}
-    if len(data_textsafe) > 1000:
-        payload = gzip.compress(data_textsafe)
+    if len(data_json) > 1000:
+        payload = gzip.compress(data_json)
         extra_headers["Content-Encoding"] = "gzip"
     else:
-        payload = data_textsafe
+        payload = data_json

     resp = c.post(url, content=payload, headers=extra_headers)
@@ -163,26 +161,28 @@ def _post_form_data(url: str, fields: dict[Any, Any], data: bytes):
 async def _async_post_data(
     api_key: str,
     url: str,
-    data: Union[str, bytes],
+    data: Union[bytes, Any],
     compress: bool = True,
     decode_json: bool = True,
 ) -> Any:
     """Perform an async HTTP POST request, returning a JSON-parsed dict response"""
     headers = {"x-api-key": api_key, "Content-Type": "application/json"}
-    if type(data) == str:
-        data = data.encode("utf-8")
-    elif type(data) == bytes:
-        data = base64.b64encode(data)
+
+    if type(data) == bytes:
+        data_jsonable = base64.b64encode(data).decode("utf-8")
     else:
-        raise ValueError(f"Unsupported data type {type(data)}")
+        data_jsonable = data
+    data_json = json.dumps(data_jsonable).encode("utf-8")

     if compress:
         # apply gzip compression to data before sending
-        data = gzip.compress(data)
+        payload = gzip.compress(data_json)
         headers["Content-Encoding"] = "gzip"
+    else:
+        payload = data_json

     async with httpx.AsyncClient(timeout=httpx.Timeout(5, read=None)) as client:
-        r = await client.post(url, content=data, headers=headers)
+        r = await client.post(url, content=payload, headers=headers)

         _check_http_error(r)  # type: ignore
         if decode_json:
@@ -212,14 +212,14 @@ class API:
     def _service_url_for(self, path: str) -> str:
         return self.service_endpoint + path

-    def create(self, data_json: str) -> dict[Any, Any]:
+    def create(self, data_jsonable: dict) -> dict[Any, Any]:
         """Create a new bucket, given the supplied data.

         Args:
             data_json (str): A JSON-encoded string of the new bucket request.
         """
         return _post_data_json(
-            self.api_key, self._service_url_for(CREATE_PATH), data_json
+            self.api_key, self._service_url_for(CREATE_PATH), data_jsonable
         )

     def check(self, uuid: str) -> dict[Any, Any]:
@@ -251,14 +251,14 @@ class API:
     def _url_for(self, bucket_name: str, path: str) -> str:
         return self.service_endpoint + "/" + bucket_name + path

-    def modify(self, bucket_name: str, data_json: str) -> dict[Any, Any]:
+    def modify(self, bucket_name: str, data_jsonable: Any) -> dict[Any, Any]:
         """Modify existing bucket.

         Args:
             data_json (str): same as create.
         """
         return _post_data_json(
-            self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_json
+            self.api_key, self._url_for(bucket_name, MODIFY_PATH), data_jsonable
         )

     def meta(self, bucket_name: str) -> dict[Any, Any]:
@@ -325,14 +325,29 @@ class API:
         """Delete all keys in this bucket."""
         _post_data(self.api_key, self._url_for(bucket_name, CLEAR_PATH), "")

-    def write(self, bucket_name: str, data: bytes):
+    def write(self, bucket_name: str, data: dict[str, Optional[bytes]]):
         """Write some data to this bucket."""
-        _post_data(self.api_key, self._url_for(bucket_name, WRITE_PATH), data)
+        data_jsonable = {
+            k: None if v is None else base64.b64encode(v).decode("utf-8")
+            for k, v in data.items()
+        }
+        _post_data_json(
+            self.api_key, self._url_for(bucket_name, WRITE_PATH), data_jsonable
+        )

-    async def async_write(self, bucket_name: str, data: str):
+    async def async_write(self, bucket_name: str, data: dict[str, Optional[bytes]]):
         """Write JSON payload to this bucket."""
+
+        data_jsonable = {
+            k: None if v is None else base64.b64encode(v).decode("utf-8")
+            for k, v in data.items()
+        }
+
         await _async_post_data(
-            self.api_key, self._url_for(bucket_name, WRITE_PATH), data, decode_json=True
+            self.api_key,
+            self._url_for(bucket_name, WRITE_PATH),
+            data_jsonable,
+            compress=True,
         )

     def delete_key(self, bucket_name: str, key: str):
@@ -341,16 +356,26 @@ class API:
             self.api_key, self._url_for(bucket_name, DELETE_PATH), key.encode("utf-8")
         )

-    def private_read(self, bucket_name: str, data: bytes) -> bytes:
+    def private_read(
+        self, bucket_name: str, queries: list[bytes]
+    ) -> list[Optional[bytes]]:
         """Privately read data from this bucket."""
-        val = _post_data(self.api_key, self._url_for(bucket_name, READ_PATH), data)
-        return base64.b64decode(val)
-
-    async def async_private_read(self, bucket_name: str, data: bytes) -> bytes:
-        """Privately read data from this bucket."""
-        val: bytes = await _async_post_data(
-            self.api_key, self._url_for(bucket_name, READ_PATH), data, decode_json=False
+        data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries]
+        r = _post_data_json(
+            self.api_key, self._url_for(bucket_name, READ_PATH), data_jsonable
         )
-        # AWS APIGW encodes its responses as base64
-        return base64.b64decode(val)
-        # return self.private_read(bucket_name, data)
+        return [base64.b64decode(v) if v is not None else None for v in r]
+
+    async def async_private_read(
+        self, bucket_name: str, queries: list[bytes]
+    ) -> list[Optional[bytes]]:
+        """Privately read data from this bucket."""
+        data_jsonable = [base64.b64encode(q).decode("utf-8") for q in queries]
+        r: list[str] = await _async_post_data(
+            self.api_key,
+            self._url_for(bucket_name, READ_PATH),
+            data_jsonable,
+            compress=True,
+            decode_json=True,
+        )
+        return [base64.b64decode(v) if v is not None else None for v in r]
@@ -5,11 +5,10 @@ Abstracts functionality on an existing bucket.

 from typing import Optional, Any, Union, Iterator

-from . import api, serializer, seed
+from . import api, seed
 from .blyss_lib import BlyssLib

 import json
 import base64
 import bz2
 import time
 import asyncio
@@ -82,9 +81,7 @@ class Bucket:
         else:
             raise e

-    def _split_into_chunks(
-        self, kv_pairs: dict[str, bytes]
-    ) -> list[list[dict[str, str]]]:
+    def _split_into_chunks(self, kv_pairs: dict[str, bytes]) -> list[dict[str, bytes]]:
         _MAX_PAYLOAD = 5 * 2**20  # 5 MiB

         # 1. Bin keys by row index
@@ -99,25 +96,19 @@ class Bucket:
         # 2. Prepare chunks of items, where each is a JSON-ready structure.
         # Each chunk is less than the maximum payload size, and guarantees
         # zero overlap of rows across chunks.
-        kv_chunks: list[list[dict[str, str]]] = []
-        current_chunk: list[dict[str, str]] = []
+        kv_chunks: list[dict[str, bytes]] = []
+        current_chunk: dict[str, bytes] = {}
         current_chunk_size = 0
         sorted_indices = sorted(keys_by_index.keys())
         for i in sorted_indices:
             keys = keys_by_index[i]
             # prepare all keys in this row
-            row = []
+            row = {}
             row_size = 0
             for key in keys:
-                value = kv_pairs[key]
-                value_str = base64.b64encode(value).decode("utf-8")
-                fmt = {
-                    "key": key,
-                    "value": value_str,
-                    "content-type": "application/octet-stream",
-                }
-                row.append(fmt)
-                row_size += int(24 + len(key) + len(value_str) + 48)
+                v = kv_pairs[key]
+                row[key] = v
+                row_size += int(16 + len(key) + len(v) * 4 / 3)

             # if the new row doesn't fit into the current chunk, start a new one
             if current_chunk_size + row_size > _MAX_PAYLOAD:
@@ -125,7 +116,7 @@ class Bucket:
                 current_chunk = row
                 current_chunk_size = row_size
             else:
-                current_chunk.extend(row)
+                current_chunk.update(row)
                 current_chunk_size += row_size

         # add the last chunk
@@ -134,13 +125,14 @@ class Bucket:
         return kv_chunks

-    def _generate_query_stream(self, keys: list[str]) -> bytes:
+    def _generate_query_stream(self, keys: list[str]) -> list[bytes]:
         assert self._public_uuid
         # generate encrypted queries
         queries: list[bytes] = [
             self._lib.generate_query(self._public_uuid, self._lib.get_row(k))
             for k in keys
         ]
+        return queries
-        # interleave the queries with their lengths (uint64_t)
-        query_lengths = [len(q).to_bytes(8, "little") for q in queries]
-        lengths_and_queries = [x for lq in zip(query_lengths, queries) for x in lq]
@@ -150,6 +142,14 @@ class Bucket:
-        multi_query = b"".join(lengths_and_queries)
-        return multi_query

+    def _decode_result(self, key: str, result_row: bytes) -> Optional[bytes]:
+        try:
+            decrypted_result = self._lib.decode_response(result_row)
+            decompressed_result = bz2.decompress(decrypted_result)
+            return self._lib.extract_result(key, decompressed_result)
+        except:
+            return None
+
     def _unpack_query_result(
         self, keys: list[str], raw_result: bytes, ignore_errors=False
     ) -> list[Optional[bytes]]:
@@ -162,9 +162,7 @@ class Bucket:
                 else:
                     raise RuntimeError(f"Failed to process query for key {key}.")
             else:
-                decrypted_result = self._lib.decode_response(result)
-                decompressed_result = bz2.decompress(decrypted_result)
-                extracted_result = self._lib.extract_result(key, decompressed_result)
+                extracted_result = self._decode_result(key, result)
             retrievals.append(extracted_result)
         return retrievals
@@ -181,15 +179,18 @@ class Bucket:
             self.setup()
         assert self._public_uuid

-        multi_query = self._generate_query_stream(keys)
+        queries = self._generate_query_stream(keys)

         start = time.perf_counter()
-        multi_result = self._api.private_read(self.name, multi_query)
+        rows_per_result = self._api.private_read(self.name, queries)
         self._exfil = time.perf_counter() - start

-        retrievals = self._unpack_query_result(keys, multi_result)
+        results = [
+            self._decode_result(key, result) if result else None
+            for key, result in zip(keys, rows_per_result)
+        ]

-        return retrievals
+        return results

     def setup(self):
         """Prepares this bucket client for private reads.
@@ -218,7 +219,8 @@ class Bucket:
         bucket_create_req = {
             "name": new_name,
         }
-        self._api.modify(self.name, json.dumps(bucket_create_req))
+        r = self._api.modify(self.name, bucket_create_req)
+        print(r)
         self.name = new_name

     def write(self, kv_pairs: dict[str, bytes]):
@@ -228,19 +230,20 @@ class Bucket:
             kv_pairs: A dictionary of key-value pairs to write into the bucket.
             Keys must be UTF8 strings, and values may be arbitrary bytes.
         """
-        concatenated_kv_items = b""
-        for key, value in kv_pairs.items():
-            concatenated_kv_items += serializer.wrap_key_val(key.encode("utf-8"), value)
-        # single call to API endpoint
-        self._api.write(self.name, concatenated_kv_items)
+        self._api.write(self.name, kv_pairs)  # type: ignore
+        # bytes is a valid subset of Optional[bytes], despite mypy's complaints

-    def delete_key(self, key: str):
-        """Deletes a single key-value pair from the bucket.
+    def delete_key(self, keys: str | list[str]):
+        """Deletes key-value pairs from the bucket.

         Args:
             key: The key to delete.
         """
-        self._api.delete_key(self.name, key)
+        if isinstance(keys, str):
+            keys = [keys]
+
+        delete_payload = {k: None for k in keys}
+        self._api.write(self.name, delete_payload)  # type: ignore

     def destroy_entire_bucket(self):
         """Destroys the entire bucket. This action is permanent and irreversible."""
@@ -278,7 +281,7 @@ class Bucket:
             keys = [keys]
             single_query = True

-        results = [r if r is not None else None for r in self._private_read(keys)]
+        results = self._private_read(keys)
         if single_query:
             return results[0]
@@ -344,9 +347,10 @@ class AsyncBucket(Bucket):
         # Make one write call per chunk, while respecting a max concurrency limit.
         sem = asyncio.Semaphore(CONCURRENCY)

-        async def _paced_writer(chunk):
+        async def _paced_writer(chunk: dict[str, bytes]):
             async with sem:
-                await self._api.async_write(self.name, json.dumps(chunk))
+                await self._api.async_write(self.name, chunk)  # type: ignore
+                # bytes is a valid subset of Optional[bytes], despite mypy's complaints

         _tasks = [asyncio.create_task(_paced_writer(c)) for c in kv_chunks]
         await asyncio.gather(*_tasks)
@@ -359,9 +363,12 @@ class AsyncBucket(Bucket):
         multi_query = self._generate_query_stream(keys)

         start = time.perf_counter()
-        multi_result = await self._api.async_private_read(self.name, multi_query)
+        rows_per_result = await self._api.async_private_read(self.name, multi_query)
         self._exfil = time.perf_counter() - start

-        retrievals = self._unpack_query_result(keys, multi_result)
+        results = [
+            self._decode_result(key, result) if result else None
+            for key, result in zip(keys, rows_per_result)
+        ]

-        return retrievals
+        return results
@@ -1,6 +1,5 @@
 from typing import Any, Optional, Union
 from . import bucket, api, seed
-import json

 BLYSS_BUCKET_URL = "https://beta.api.blyss.dev"
 DEFAULT_BUCKET_PARAMETERS = {
@@ -83,10 +82,10 @@ class BucketService:
         parameters.update(usage_hints)
         bucket_create_req = {
             "name": bucket_name,
-            "parameters": json.dumps(parameters),
+            "parameters": parameters,
             "open_access": open_access,
         }
-        self._api.create(json.dumps(bucket_create_req))
+        r = self._api.create(bucket_create_req)

     def exists(self, name: str) -> bool:
         """Check if a bucket exists.
@@ -19,9 +19,14 @@ def key_to_gold_value(key: str, length: int = 512) -> bytes:


 def verify_read(key: str, value: bytes):
+    expected = key_to_gold_value(key, len(value))
     try:
-        assert value == key_to_gold_value(key, len(value))
+        assert value == expected
     except:
         print(f"read mismatch for key {key}")
+        print(f"received {value.hex()[:16]}")
+        print(f"expected {expected.hex()[:16]}")
+
         print(traceback.format_exc())
         raise
@@ -119,6 +124,8 @@ if __name__ == "__main__":
     if len(sys.argv) > 2:
         print("Using api_key from command line")
         api_key = sys.argv[2]
+        if api_key == "none":
+            api_key = None
     print("DEBUG", api_key, endpoint)
     assert endpoint is not None
     assert api_key is not None
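
As a rough usage sketch of the bucket-level API after this change (assuming an already-initialized blyss Bucket client named bucket; the exact construction is omitted here), callers pass and receive raw bytes while all base64/JSON handling stays inside api.py:

bucket.write({"alice": b"value-a", "bob": b"value-b"})  # values are raw bytes
bucket.delete_key("alice")                              # a single key...
bucket.delete_key(["bob", "carol"])                     # ...or several; deletions are written as None tombstones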