mirror of
https://github.com/vacp2p/status-benchmarks.git
synced 2026-01-08 15:13:59 -05:00
Buffered queue (#14)
* Import enum from local * Change class name because of previous rebase * Fix double await because of previous changes in RPCClient * Change Singal class to create a stream with a bufferedqueue and an asyncgenerator * Change StatusBackend to use previous signal changes. Removed __aexit__ to have better control over lifetime * Clean imports * Add enums class
This commit is contained in:
@@ -10,9 +10,9 @@ class AccountAsyncService(AsyncService):
|
||||
super().__init__(rpc, "accounts")
|
||||
|
||||
async def get_accounts(self) -> dict:
|
||||
response_dict = await self.rpc.rpc_valid_request("getAccounts")
|
||||
return response_dict
|
||||
json_response = await self.rpc.rpc_valid_request("getAccounts")
|
||||
return json_response
|
||||
|
||||
async def get_account_keypairs(self) -> dict:
|
||||
response_dict = await self.rpc.rpc_valid_request("getKeypairs")
|
||||
return response_dict
|
||||
json_response = await self.rpc.rpc_valid_request("getKeypairs")
|
||||
return json_response
|
||||
|
||||
32
src/enums.py
Normal file
32
src/enums.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Python Imports
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class MessageContentType(Enum):
|
||||
UNKNOWN_CONTENT_TYPE = 0
|
||||
TEXT_PLAIN = 1
|
||||
STICKER = 2
|
||||
STATUS = 3
|
||||
EMOJI = 4
|
||||
TRANSACTION_COMMAND = 5
|
||||
SYSTEM_MESSAGE_CONTENT_PRIVATE_GROUP = 6
|
||||
IMAGE = 7
|
||||
AUDIO = 8
|
||||
COMMUNITY = 9
|
||||
SYSTEM_MESSAGE_GAP = 10
|
||||
CONTACT_REQUEST = 11
|
||||
DISCORD_MESSAGE = 12
|
||||
IDENTITY_VERIFICATION = 13
|
||||
SYSTEM_MESSAGE_PINNED_MESSAGE = 14
|
||||
SYSTEM_MESSAGE_MUTUAL_EVENT_SENT = 15
|
||||
SYSTEM_MESSAGE_MUTUAL_EVENT_ACCEPTED = 16
|
||||
SYSTEM_MESSAGE_MUTUAL_EVENT_REMOVED = 17
|
||||
BRIDGE_MESSAGE = 18
|
||||
|
||||
class SignalType(Enum):
|
||||
MESSAGES_NEW = "messages.new"
|
||||
MESSAGE_DELIVERED = "message.delivered"
|
||||
NODE_READY = "node.ready"
|
||||
NODE_STARTED = "node.started"
|
||||
NODE_LOGIN = "node.login"
|
||||
NODE_LOGOUT = "node.stopped"
|
||||
@@ -1,5 +1,5 @@
|
||||
# Python Imports
|
||||
from typing import Optional, Any
|
||||
from typing import Optional
|
||||
|
||||
# Project Imports
|
||||
from src.rpc_client import AsyncRpcClient
|
||||
@@ -10,6 +10,7 @@ class AsyncService:
|
||||
self.rpc = async_rpc_client
|
||||
self.name = name
|
||||
|
||||
async def rpc_request(self, method: str, params: Optional[list] = None, enable_logging: bool = True) -> Any:
|
||||
async def rpc_request(self, method: str, params: Optional[list] = None, enable_logging: bool = True) -> dict:
|
||||
# In order to be validated, the response is already awaited, so this already returns the dict data
|
||||
full_method_name = f"{self.name}_{method}"
|
||||
return await self.rpc.rpc_valid_request(full_method_name, params or [], enable_logging=enable_logging)
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
# Python Imports
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from enum import Enum
|
||||
from typing import Optional, Callable
|
||||
from typing import Optional, AsyncGenerator
|
||||
from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from collections import deque
|
||||
|
||||
# Project Imports
|
||||
from src.enums import SignalType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -17,38 +19,33 @@ LOG_SIGNALS_TO_FILE = False
|
||||
SIGNALS_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
class SignalType(Enum):
|
||||
MESSAGES_NEW = "messages.new"
|
||||
MESSAGE_DELIVERED = "message.delivered"
|
||||
NODE_READY = "node.ready"
|
||||
NODE_STARTED = "node.started"
|
||||
NODE_LOGIN = "node.login"
|
||||
NODE_LOGOUT = "node.stopped"
|
||||
class BufferedQueue:
|
||||
def __init__(self, max_size: int = 100):
|
||||
self.queue = asyncio.Queue()
|
||||
self.buffer = deque(maxlen=max_size)
|
||||
|
||||
async def put(self, item):
|
||||
self.buffer.append(item)
|
||||
await self.queue.put(item)
|
||||
|
||||
async def get(self):
|
||||
return await self.queue.get()
|
||||
|
||||
def recent(self) -> list:
|
||||
return list(self.buffer)
|
||||
|
||||
|
||||
class AsyncSignalClient:
|
||||
def __init__(self, ws_url: str, await_signals: list[str]):
|
||||
def __init__(self, ws_url: str, await_signals: list[str], buffer_size: int = 100):
|
||||
self.url = f"{ws_url}/signals"
|
||||
|
||||
self.await_signals = await_signals
|
||||
self.ws: Optional[ClientWebSocketResponse] = None
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.signal_file_path = None
|
||||
self.signal_lock = asyncio.Lock()
|
||||
self.listener_task = None
|
||||
|
||||
self.received_signals: dict[str, dict] = {
|
||||
# For each signal type, store:
|
||||
# - list of received signals
|
||||
# - expected received event delta count (resets to 1 after each wait_for_event call)
|
||||
# - expected received event count
|
||||
# - a function that takes the received signal as an argument and returns True if the signal is accepted (counted) or discarded
|
||||
signal: {
|
||||
"received": [],
|
||||
"delta_count": 1,
|
||||
"expected_count": 1,
|
||||
"accept_fn": None,
|
||||
}
|
||||
for signal in self.await_signals
|
||||
self.signal_queues: dict[str, BufferedQueue] = {
|
||||
signal: BufferedQueue(max_size=buffer_size) for signal in self.await_signals
|
||||
}
|
||||
|
||||
if LOG_SIGNALS_TO_FILE: # Not being used currently
|
||||
@@ -61,7 +58,8 @@ class AsyncSignalClient:
|
||||
async def __aenter__(self):
|
||||
self.session = ClientSession()
|
||||
self.ws = await self.session.ws_connect(self.url)
|
||||
asyncio.create_task(self._listen())
|
||||
self.listener_task = asyncio.create_task(self._listen())
|
||||
await asyncio.sleep(0) # Yield control to ensure _listen starts
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
@@ -69,8 +67,13 @@ class AsyncSignalClient:
|
||||
await self.ws.close()
|
||||
if self.session:
|
||||
await self.session.close()
|
||||
if self.listener_task:
|
||||
self.listener_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self.listener_task
|
||||
|
||||
async def _listen(self):
|
||||
logger.debug("WebSocket listener started")
|
||||
async for msg in self.ws:
|
||||
if msg.type == WSMsgType.TEXT:
|
||||
await self.on_message(msg.data)
|
||||
@@ -79,50 +82,44 @@ class AsyncSignalClient:
|
||||
|
||||
async def on_message(self, signal: str):
|
||||
signal_data = json.loads(signal)
|
||||
logger.debug(f"Received WebSocket message: {signal_data}")
|
||||
|
||||
if LOG_SIGNALS_TO_FILE:
|
||||
pass # TODO
|
||||
pass # TODO: write to file if needed
|
||||
|
||||
signal_type = signal_data.get("type")
|
||||
if signal_type in self.await_signals:
|
||||
async with self.signal_lock:
|
||||
accept_fn = self.received_signals[signal_type]["accept_fn"]
|
||||
if not accept_fn or accept_fn(signal_data):
|
||||
self.received_signals[signal_type]["received"].append(signal_data)
|
||||
if signal_type in self.signal_queues:
|
||||
await self.signal_queues[signal_type].put(signal_data)
|
||||
logger.debug(f"Queued signal: {signal_type}")
|
||||
else:
|
||||
logger.debug(f"Ignored signal not in await list: {signal_type}")
|
||||
|
||||
# Used to set up how many instances of a signal to wait for, before triggering the actions
|
||||
# that cause them to be emitted.
|
||||
async def prepare_wait_for_signal(self, signal_type: str, delta_count: int, accept_fn: Optional[Callable] = None):
|
||||
if signal_type not in self.await_signals:
|
||||
async def wait_for_signal(self, signal_type: str, timeout: int = 20) -> dict:
|
||||
if signal_type not in self.signal_queues:
|
||||
raise ValueError(f"Signal type {signal_type} is not in the list of awaited signals")
|
||||
async with self.signal_lock:
|
||||
self.received_signals[signal_type]["delta_count"] = delta_count
|
||||
self.received_signals[signal_type]["expected_count"] = (
|
||||
len(self.received_signals[signal_type]["received"]) + delta_count
|
||||
)
|
||||
self.received_signals[signal_type]["accept_fn"] = accept_fn
|
||||
try:
|
||||
signal = await asyncio.wait_for(self.signal_queues[signal_type].get(), timeout)
|
||||
logger.info(f"Received {signal_type} signal: {signal}")
|
||||
return signal
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(f"Signal {signal_type} not received in {timeout} seconds")
|
||||
|
||||
async def wait_for_signal(self, signal_type: str, timeout: int = 20) -> dict | list[dict]:
|
||||
if signal_type not in self.await_signals:
|
||||
async def signal_stream(self, signal_type: str) -> AsyncGenerator[dict, None]:
|
||||
if signal_type not in self.signal_queues:
|
||||
raise ValueError(f"Signal type {signal_type} is not in the list of awaited signals")
|
||||
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
async with self.signal_lock:
|
||||
received = self.received_signals[signal_type]["received"]
|
||||
expected = self.received_signals[signal_type]["expected_count"]
|
||||
delta_count = self.received_signals[signal_type]["delta_count"]
|
||||
yield await self.signal_queues[signal_type].get()
|
||||
|
||||
if len(received) >= expected:
|
||||
await self.prepare_wait_for_signal(signal_type, 1)
|
||||
return received[-1] if delta_count == 1 else received[-delta_count:]
|
||||
|
||||
if asyncio.get_event_loop().time() - start_time >= timeout:
|
||||
raise TimeoutError(f"Signal {signal_type} not received in {timeout} seconds")
|
||||
await asyncio.sleep(0.2)
|
||||
def get_recent_signals(self, signal_type: str) -> list:
|
||||
if signal_type not in self.signal_queues:
|
||||
raise ValueError(f"Signal type {signal_type} is not in the list of awaited signals")
|
||||
return self.signal_queues[signal_type].recent()
|
||||
|
||||
async def wait_for_login(self) -> dict:
|
||||
logger.info("Waiting for login signal...")
|
||||
signal = await self.wait_for_signal(SignalType.NODE_LOGIN.value)
|
||||
if "error" in signal["event"]:
|
||||
logger.info(f"Login signal received: {signal}")
|
||||
if "error" in signal.get("event", {}):
|
||||
error_details = signal["event"]["error"]
|
||||
assert not error_details, f"Unexpected error during login: {error_details}"
|
||||
self.node_login_event = signal
|
||||
@@ -134,12 +131,12 @@ class AsyncSignalClient:
|
||||
async def find_signal_containing_string(self, signal_type: str, event_string: str, timeout=20) -> Optional[dict]:
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
while True:
|
||||
async with self.signal_lock:
|
||||
for event in self.received_signals.get(signal_type, {}).get("received", []):
|
||||
if event_string in json.dumps(event):
|
||||
logger.info(f"Found {signal_type} containing '{event_string}'")
|
||||
return event
|
||||
|
||||
if asyncio.get_event_loop().time() - start_time >= timeout:
|
||||
raise TimeoutError(f"Signal {signal_type} containing '{event_string}' not received in {timeout} seconds")
|
||||
await asyncio.sleep(0.2)
|
||||
try:
|
||||
signal = await asyncio.wait_for(self.signal_queues[signal_type].get(), timeout)
|
||||
if event_string in json.dumps(signal):
|
||||
logger.info(f"Found {signal_type} containing '{event_string}'")
|
||||
return signal
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(
|
||||
f"Signal {signal_type} containing '{event_string}' not received in {timeout} seconds"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
# Python Imports
|
||||
import logging
|
||||
import json
|
||||
from typing import List, Dict
|
||||
from aiohttp import ClientSession, ClientTimeout, ClientResponse
|
||||
from aiohttp import ClientSession, ClientTimeout
|
||||
|
||||
# Project Imports
|
||||
from src.account_service import AccountAsyncService
|
||||
from src.enums import SignalType
|
||||
from src.rpc_client import AsyncRpcClient
|
||||
from src.signal_client import AsyncSignalClient
|
||||
from src.wakuext_service import WakuextAsyncService
|
||||
@@ -35,37 +37,45 @@ class StatusBackend:
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
await self.signal.__aexit__(exc_type, exc_val, exc_tb)
|
||||
await self.rpc.__aexit__(exc_type, exc_val, exc_tb)
|
||||
# Let the caller handle shutdown
|
||||
pass
|
||||
|
||||
async def shutdown(self):
|
||||
await self.signal.__aexit__(None, None, None)
|
||||
await self.rpc.__aexit__(None, None, None)
|
||||
await self.session.close()
|
||||
|
||||
async def call_rpc(self, method: str, params: List = None):
|
||||
return await self.rpc.rpc_valid_request(method, params or [])
|
||||
|
||||
async def api_request(self, method: str, data: Dict, url: str = None) -> ClientResponse:
|
||||
url = url or self.api_url
|
||||
url = f"{url}/{method}"
|
||||
logger.debug(f"Sending async POST request to {url} with data: {json.dumps(data, sort_keys=True)}")
|
||||
async def api_request(self, method: str, data: Dict) -> dict:
|
||||
url = f"{self.api_url}/{method}"
|
||||
logger.debug(f"Sending POST to {url} with data: {data}")
|
||||
async with self.session.post(url, json=data) as response:
|
||||
logger.debug(f"Got response: {await response.text()}")
|
||||
return response
|
||||
logger.debug(f"Received response from {method}: {response.status}")
|
||||
|
||||
async def verify_is_valid_api_response(self, response: ClientResponse):
|
||||
if response.status != 200:
|
||||
raise AssertionError(f"Bad HTTP status: {response.status}")
|
||||
try:
|
||||
json_data = await response.json()
|
||||
if "error" in json_data:
|
||||
if response.status != 200:
|
||||
body = await response.text()
|
||||
raise AssertionError(f"Bad HTTP status: {response.status}, body: {body}")
|
||||
|
||||
try:
|
||||
json_data = await response.json()
|
||||
except json.JSONDecodeError:
|
||||
body = await response.text()
|
||||
raise AssertionError(f"Invalid JSON in response: {body}")
|
||||
|
||||
if json_data.get("error"):
|
||||
raise AssertionError(f"API error: {json_data['error']}")
|
||||
except Exception as e:
|
||||
raise AssertionError(f"Invalid JSON response: {e}")
|
||||
|
||||
async def api_valid_request(self, method: str, data: Dict, url: str = None) -> ClientResponse:
|
||||
response = await self.api_request(method, data, url)
|
||||
await self.verify_is_valid_api_response(response)
|
||||
return response
|
||||
return json_data
|
||||
|
||||
async def start_status_backend(self) -> ClientResponse:
|
||||
async def api_valid_request(self, method: str, data: Dict) -> dict:
|
||||
json_data = await self.api_request(method, data)
|
||||
logger.debug(f"Valid response from {method}: {json_data}")
|
||||
return json_data
|
||||
|
||||
async def start_status_backend(self) -> dict:
|
||||
await self.__aenter__()
|
||||
try:
|
||||
await self.logout()
|
||||
logger.debug("Successfully logged out")
|
||||
@@ -126,22 +136,33 @@ class StatusBackend:
|
||||
self._set_networks(data)
|
||||
return data
|
||||
|
||||
async def create_account_and_login(self, **kwargs) -> ClientResponse:
|
||||
return await self.api_valid_request("CreateAccountAndLogin", self._create_account_request(**kwargs))
|
||||
async def create_account_and_login(self, **kwargs) -> dict | None:
|
||||
response = await self.api_valid_request("CreateAccountAndLogin", self._create_account_request(**kwargs))
|
||||
|
||||
async def login(self, key_uid: str) -> ClientResponse:
|
||||
return await self.api_valid_request("LoginAccount", {
|
||||
signal = await self.signal.wait_for_login()
|
||||
|
||||
self.set_public_key(signal)
|
||||
self.signal.node_login_event = signal
|
||||
return response
|
||||
|
||||
async def login(self, key_uid: str) -> dict:
|
||||
response = await self.api_valid_request("LoginAccount", {
|
||||
"password": "Strong12345",
|
||||
"keyUid": key_uid,
|
||||
"kdfIterations": 256000,
|
||||
})
|
||||
signal = await self.signal.wait_for_login()
|
||||
self.set_public_key(signal)
|
||||
return response
|
||||
|
||||
async def logout(self) -> ClientResponse:
|
||||
async def logout(self) -> dict:
|
||||
return await self.api_valid_request("Logout", {})
|
||||
|
||||
def set_public_key(self):
|
||||
# Only make sense to call this method if the lodes are logged in, otherwise public_key will be set to None.
|
||||
self.public_key = self.signal.node_login_event.get("event", {}).get("settings", {}).get("public-key")
|
||||
def set_public_key(self, signal_data: dict):
|
||||
self.public_key = signal_data.get("event", {}).get("settings", {}).get("public-key")
|
||||
|
||||
def find_key_uid(self) -> str:
|
||||
return self.signal.node_login_event.get("event", {}).get("account", {}).get("key-uid")
|
||||
recent = self.signal.get_recent_signals(SignalType.NODE_LOGIN.value)
|
||||
if not recent:
|
||||
raise RuntimeError("No login signal received to extract key UID")
|
||||
return recent[-1].get("event", {}).get("account", {}).get("key-uid")
|
||||
|
||||
@@ -11,42 +11,46 @@ class WakuextAsyncService(AsyncService):
|
||||
super().__init__(async_rpc_client, "wakuext")
|
||||
|
||||
async def start_messenger(self):
|
||||
response = await self.rpc_request("startMessenger")
|
||||
json_response = await response.json()
|
||||
json_response = await self.rpc_request("startMessenger")
|
||||
|
||||
if "error" in json_response:
|
||||
assert json_response["error"]["code"] == -32000
|
||||
assert json_response["error"]["message"] == "messenger already started"
|
||||
return
|
||||
|
||||
async def create_community(self, name: str, color="#ffffff", membership: int = 3) -> Dict:
|
||||
async def create_community(self, name: str, color="#ffffff", membership: int = 3) -> dict:
|
||||
# TODO check what is membership = 3
|
||||
params = [{"membership": membership, "name": name, "color": color, "description": name}]
|
||||
response = await self.rpc_request("createCommunity", params)
|
||||
return await response.json()
|
||||
json_response = await self.rpc_request("createCommunity", params)
|
||||
return json_response
|
||||
|
||||
async def fetch_community(self, community_key: str) -> Dict:
|
||||
async def fetch_community(self, community_key: str) -> dict:
|
||||
params = [{"communityKey": community_key, "waitForResponse": True, "tryDatabase": True}]
|
||||
response = await self.rpc_request("fetchCommunity", params)
|
||||
return await response.json()
|
||||
json_response = await self.rpc_request("fetchCommunity", params)
|
||||
return json_response
|
||||
|
||||
async def request_to_join_community(self, community_id: str, address: str = "fakeaddress") -> Dict:
|
||||
async def request_to_join_community(self, community_id: str, address: str = "fakeaddress") -> dict:
|
||||
params = [{"communityId": community_id, "addressesToReveal": [address], "airdropAddress": address}]
|
||||
response = await self.rpc_request("requestToJoinCommunity", params)
|
||||
return await response.json()
|
||||
json_response = await self.rpc_request("requestToJoinCommunity", params)
|
||||
return json_response
|
||||
|
||||
async def accept_request_to_join_community(self, request_to_join_id: str) -> Dict:
|
||||
async def accept_request_to_join_community(self, request_to_join_id: str) -> dict:
|
||||
params = [{"id": request_to_join_id}]
|
||||
response = await self.rpc_request("acceptRequestToJoinCommunity", params)
|
||||
return await response.json()
|
||||
json_response = await self.rpc_request("acceptRequestToJoinCommunity", params)
|
||||
return json_response
|
||||
|
||||
async def send_chat_message(self, chat_id: str, message: str, content_type: int = 1) -> Dict:
|
||||
async def send_chat_message(self, chat_id: str, message: str, content_type: int = 1) -> dict:
|
||||
# TODO content type can always be 1? (plain TEXT), does it need to be community type for communities?
|
||||
params = [{"chatId": chat_id, "text": message, "contentType": content_type}]
|
||||
response = await self.rpc_request("sendChatMessage", params)
|
||||
return await response.json()
|
||||
json_response = await self.rpc_request("sendChatMessage", params)
|
||||
return json_response
|
||||
|
||||
async def send_contact_request(self, contact_id: str, message: str) -> Dict:
|
||||
params = [{"id": contact_id, "message": message}]
|
||||
response = await self.rpc_request("sendContactRequest", params)
|
||||
return await response.json()
|
||||
json_response = await self.rpc_request("sendContactRequest", params)
|
||||
return json_response
|
||||
|
||||
async def accept_contact_request(self, request_id: str) -> dict:
|
||||
params = [{"id": request_id}]
|
||||
json_response = await self.rpc_request("acceptContactRequest", params)
|
||||
return json_response
|
||||
|
||||
@@ -10,6 +10,6 @@ class WalletAsyncService(AsyncService):
|
||||
def __init__(self, async_rpc_client: AsyncRpcClient):
|
||||
super().__init__(async_rpc_client, "wallet")
|
||||
|
||||
async def start_wallet(self) -> Any:
|
||||
response = await self.rpc_request("startWallet")
|
||||
return await response.json()
|
||||
async def start_wallet(self) -> dict:
|
||||
json_response = await self.rpc_request("startWallet")
|
||||
return json_response
|
||||
|
||||
Reference in New Issue
Block a user