Buffered queue (#14)

* Import enum from local

* Change class name because of previous rebase

* Fix double await because of previous changes in RPCClient

* Change Singal class to create a stream with a bufferedqueue and an asyncgenerator

* Change StatusBackend to use previous signal changes. Removed __aexit__ to have better control over lifetime

* Clean imports

* Add enums class
This commit is contained in:
Alberto Soutullo
2025-09-25 10:40:02 +02:00
committed by GitHub
parent 8b518b75b2
commit 2c6f8a57b8
7 changed files with 182 additions and 127 deletions

View File

@@ -10,9 +10,9 @@ class AccountAsyncService(AsyncService):
super().__init__(rpc, "accounts")
async def get_accounts(self) -> dict:
response_dict = await self.rpc.rpc_valid_request("getAccounts")
return response_dict
json_response = await self.rpc.rpc_valid_request("getAccounts")
return json_response
async def get_account_keypairs(self) -> dict:
response_dict = await self.rpc.rpc_valid_request("getKeypairs")
return response_dict
json_response = await self.rpc.rpc_valid_request("getKeypairs")
return json_response

32
src/enums.py Normal file
View File

@@ -0,0 +1,32 @@
# Python Imports
from enum import Enum
class MessageContentType(Enum):
UNKNOWN_CONTENT_TYPE = 0
TEXT_PLAIN = 1
STICKER = 2
STATUS = 3
EMOJI = 4
TRANSACTION_COMMAND = 5
SYSTEM_MESSAGE_CONTENT_PRIVATE_GROUP = 6
IMAGE = 7
AUDIO = 8
COMMUNITY = 9
SYSTEM_MESSAGE_GAP = 10
CONTACT_REQUEST = 11
DISCORD_MESSAGE = 12
IDENTITY_VERIFICATION = 13
SYSTEM_MESSAGE_PINNED_MESSAGE = 14
SYSTEM_MESSAGE_MUTUAL_EVENT_SENT = 15
SYSTEM_MESSAGE_MUTUAL_EVENT_ACCEPTED = 16
SYSTEM_MESSAGE_MUTUAL_EVENT_REMOVED = 17
BRIDGE_MESSAGE = 18
class SignalType(Enum):
MESSAGES_NEW = "messages.new"
MESSAGE_DELIVERED = "message.delivered"
NODE_READY = "node.ready"
NODE_STARTED = "node.started"
NODE_LOGIN = "node.login"
NODE_LOGOUT = "node.stopped"

View File

@@ -1,5 +1,5 @@
# Python Imports
from typing import Optional, Any
from typing import Optional
# Project Imports
from src.rpc_client import AsyncRpcClient
@@ -10,6 +10,7 @@ class AsyncService:
self.rpc = async_rpc_client
self.name = name
async def rpc_request(self, method: str, params: Optional[list] = None, enable_logging: bool = True) -> Any:
async def rpc_request(self, method: str, params: Optional[list] = None, enable_logging: bool = True) -> dict:
# In order to be validated, the response is already awaited, so this already returns the dict data
full_method_name = f"{self.name}_{method}"
return await self.rpc.rpc_valid_request(full_method_name, params or [], enable_logging=enable_logging)

View File

@@ -1,15 +1,17 @@
# Python Imports
import asyncio
import contextlib
import json
import logging
import os
from enum import Enum
from typing import Optional, Callable
from typing import Optional, AsyncGenerator
from aiohttp import ClientSession, ClientWebSocketResponse, WSMsgType
from pathlib import Path
from datetime import datetime
from collections import deque
# Project Imports
from src.enums import SignalType
logger = logging.getLogger(__name__)
@@ -17,38 +19,33 @@ LOG_SIGNALS_TO_FILE = False
SIGNALS_DIR = os.path.dirname(os.path.abspath(__file__))
class SignalType(Enum):
MESSAGES_NEW = "messages.new"
MESSAGE_DELIVERED = "message.delivered"
NODE_READY = "node.ready"
NODE_STARTED = "node.started"
NODE_LOGIN = "node.login"
NODE_LOGOUT = "node.stopped"
class BufferedQueue:
def __init__(self, max_size: int = 100):
self.queue = asyncio.Queue()
self.buffer = deque(maxlen=max_size)
async def put(self, item):
self.buffer.append(item)
await self.queue.put(item)
async def get(self):
return await self.queue.get()
def recent(self) -> list:
return list(self.buffer)
class AsyncSignalClient:
def __init__(self, ws_url: str, await_signals: list[str]):
def __init__(self, ws_url: str, await_signals: list[str], buffer_size: int = 100):
self.url = f"{ws_url}/signals"
self.await_signals = await_signals
self.ws: Optional[ClientWebSocketResponse] = None
self.session: Optional[ClientSession] = None
self.signal_file_path = None
self.signal_lock = asyncio.Lock()
self.listener_task = None
self.received_signals: dict[str, dict] = {
# For each signal type, store:
# - list of received signals
# - expected received event delta count (resets to 1 after each wait_for_event call)
# - expected received event count
# - a function that takes the received signal as an argument and returns True if the signal is accepted (counted) or discarded
signal: {
"received": [],
"delta_count": 1,
"expected_count": 1,
"accept_fn": None,
}
for signal in self.await_signals
self.signal_queues: dict[str, BufferedQueue] = {
signal: BufferedQueue(max_size=buffer_size) for signal in self.await_signals
}
if LOG_SIGNALS_TO_FILE: # Not being used currently
@@ -61,7 +58,8 @@ class AsyncSignalClient:
async def __aenter__(self):
self.session = ClientSession()
self.ws = await self.session.ws_connect(self.url)
asyncio.create_task(self._listen())
self.listener_task = asyncio.create_task(self._listen())
await asyncio.sleep(0) # Yield control to ensure _listen starts
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
@@ -69,8 +67,13 @@ class AsyncSignalClient:
await self.ws.close()
if self.session:
await self.session.close()
if self.listener_task:
self.listener_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.listener_task
async def _listen(self):
logger.debug("WebSocket listener started")
async for msg in self.ws:
if msg.type == WSMsgType.TEXT:
await self.on_message(msg.data)
@@ -79,50 +82,44 @@ class AsyncSignalClient:
async def on_message(self, signal: str):
signal_data = json.loads(signal)
logger.debug(f"Received WebSocket message: {signal_data}")
if LOG_SIGNALS_TO_FILE:
pass # TODO
pass # TODO: write to file if needed
signal_type = signal_data.get("type")
if signal_type in self.await_signals:
async with self.signal_lock:
accept_fn = self.received_signals[signal_type]["accept_fn"]
if not accept_fn or accept_fn(signal_data):
self.received_signals[signal_type]["received"].append(signal_data)
if signal_type in self.signal_queues:
await self.signal_queues[signal_type].put(signal_data)
logger.debug(f"Queued signal: {signal_type}")
else:
logger.debug(f"Ignored signal not in await list: {signal_type}")
# Used to set up how many instances of a signal to wait for, before triggering the actions
# that cause them to be emitted.
async def prepare_wait_for_signal(self, signal_type: str, delta_count: int, accept_fn: Optional[Callable] = None):
if signal_type not in self.await_signals:
async def wait_for_signal(self, signal_type: str, timeout: int = 20) -> dict:
if signal_type not in self.signal_queues:
raise ValueError(f"Signal type {signal_type} is not in the list of awaited signals")
async with self.signal_lock:
self.received_signals[signal_type]["delta_count"] = delta_count
self.received_signals[signal_type]["expected_count"] = (
len(self.received_signals[signal_type]["received"]) + delta_count
)
self.received_signals[signal_type]["accept_fn"] = accept_fn
async def wait_for_signal(self, signal_type: str, timeout: int = 20) -> dict | list[dict]:
if signal_type not in self.await_signals:
raise ValueError(f"Signal type {signal_type} is not in the list of awaited signals")
start_time = asyncio.get_event_loop().time()
while True:
async with self.signal_lock:
received = self.received_signals[signal_type]["received"]
expected = self.received_signals[signal_type]["expected_count"]
delta_count = self.received_signals[signal_type]["delta_count"]
if len(received) >= expected:
await self.prepare_wait_for_signal(signal_type, 1)
return received[-1] if delta_count == 1 else received[-delta_count:]
if asyncio.get_event_loop().time() - start_time >= timeout:
try:
signal = await asyncio.wait_for(self.signal_queues[signal_type].get(), timeout)
logger.info(f"Received {signal_type} signal: {signal}")
return signal
except asyncio.TimeoutError:
raise TimeoutError(f"Signal {signal_type} not received in {timeout} seconds")
await asyncio.sleep(0.2)
async def signal_stream(self, signal_type: str) -> AsyncGenerator[dict, None]:
if signal_type not in self.signal_queues:
raise ValueError(f"Signal type {signal_type} is not in the list of awaited signals")
while True:
yield await self.signal_queues[signal_type].get()
def get_recent_signals(self, signal_type: str) -> list:
if signal_type not in self.signal_queues:
raise ValueError(f"Signal type {signal_type} is not in the list of awaited signals")
return self.signal_queues[signal_type].recent()
async def wait_for_login(self) -> dict:
logger.info("Waiting for login signal...")
signal = await self.wait_for_signal(SignalType.NODE_LOGIN.value)
if "error" in signal["event"]:
logger.info(f"Login signal received: {signal}")
if "error" in signal.get("event", {}):
error_details = signal["event"]["error"]
assert not error_details, f"Unexpected error during login: {error_details}"
self.node_login_event = signal
@@ -134,12 +131,12 @@ class AsyncSignalClient:
async def find_signal_containing_string(self, signal_type: str, event_string: str, timeout=20) -> Optional[dict]:
start_time = asyncio.get_event_loop().time()
while True:
async with self.signal_lock:
for event in self.received_signals.get(signal_type, {}).get("received", []):
if event_string in json.dumps(event):
try:
signal = await asyncio.wait_for(self.signal_queues[signal_type].get(), timeout)
if event_string in json.dumps(signal):
logger.info(f"Found {signal_type} containing '{event_string}'")
return event
if asyncio.get_event_loop().time() - start_time >= timeout:
raise TimeoutError(f"Signal {signal_type} containing '{event_string}' not received in {timeout} seconds")
await asyncio.sleep(0.2)
return signal
except asyncio.TimeoutError:
raise TimeoutError(
f"Signal {signal_type} containing '{event_string}' not received in {timeout} seconds"
)

View File

@@ -1,10 +1,12 @@
import json
# Python Imports
import logging
import json
from typing import List, Dict
from aiohttp import ClientSession, ClientTimeout, ClientResponse
from aiohttp import ClientSession, ClientTimeout
# Project Imports
from src.account_service import AccountAsyncService
from src.enums import SignalType
from src.rpc_client import AsyncRpcClient
from src.signal_client import AsyncSignalClient
from src.wakuext_service import WakuextAsyncService
@@ -35,37 +37,45 @@ class StatusBackend:
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.signal.__aexit__(exc_type, exc_val, exc_tb)
await self.rpc.__aexit__(exc_type, exc_val, exc_tb)
# Let the caller handle shutdown
pass
async def shutdown(self):
await self.signal.__aexit__(None, None, None)
await self.rpc.__aexit__(None, None, None)
await self.session.close()
async def call_rpc(self, method: str, params: List = None):
return await self.rpc.rpc_valid_request(method, params or [])
async def api_request(self, method: str, data: Dict, url: str = None) -> ClientResponse:
url = url or self.api_url
url = f"{url}/{method}"
logger.debug(f"Sending async POST request to {url} with data: {json.dumps(data, sort_keys=True)}")
async def api_request(self, method: str, data: Dict) -> dict:
url = f"{self.api_url}/{method}"
logger.debug(f"Sending POST to {url} with data: {data}")
async with self.session.post(url, json=data) as response:
logger.debug(f"Got response: {await response.text()}")
return response
logger.debug(f"Received response from {method}: {response.status}")
async def verify_is_valid_api_response(self, response: ClientResponse):
if response.status != 200:
raise AssertionError(f"Bad HTTP status: {response.status}")
body = await response.text()
raise AssertionError(f"Bad HTTP status: {response.status}, body: {body}")
try:
json_data = await response.json()
if "error" in json_data:
except json.JSONDecodeError:
body = await response.text()
raise AssertionError(f"Invalid JSON in response: {body}")
if json_data.get("error"):
raise AssertionError(f"API error: {json_data['error']}")
except Exception as e:
raise AssertionError(f"Invalid JSON response: {e}")
async def api_valid_request(self, method: str, data: Dict, url: str = None) -> ClientResponse:
response = await self.api_request(method, data, url)
await self.verify_is_valid_api_response(response)
return response
return json_data
async def start_status_backend(self) -> ClientResponse:
async def api_valid_request(self, method: str, data: Dict) -> dict:
json_data = await self.api_request(method, data)
logger.debug(f"Valid response from {method}: {json_data}")
return json_data
async def start_status_backend(self) -> dict:
await self.__aenter__()
try:
await self.logout()
logger.debug("Successfully logged out")
@@ -126,22 +136,33 @@ class StatusBackend:
self._set_networks(data)
return data
async def create_account_and_login(self, **kwargs) -> ClientResponse:
return await self.api_valid_request("CreateAccountAndLogin", self._create_account_request(**kwargs))
async def create_account_and_login(self, **kwargs) -> dict | None:
response = await self.api_valid_request("CreateAccountAndLogin", self._create_account_request(**kwargs))
async def login(self, key_uid: str) -> ClientResponse:
return await self.api_valid_request("LoginAccount", {
signal = await self.signal.wait_for_login()
self.set_public_key(signal)
self.signal.node_login_event = signal
return response
async def login(self, key_uid: str) -> dict:
response = await self.api_valid_request("LoginAccount", {
"password": "Strong12345",
"keyUid": key_uid,
"kdfIterations": 256000,
})
signal = await self.signal.wait_for_login()
self.set_public_key(signal)
return response
async def logout(self) -> ClientResponse:
async def logout(self) -> dict:
return await self.api_valid_request("Logout", {})
def set_public_key(self):
# Only make sense to call this method if the lodes are logged in, otherwise public_key will be set to None.
self.public_key = self.signal.node_login_event.get("event", {}).get("settings", {}).get("public-key")
def set_public_key(self, signal_data: dict):
self.public_key = signal_data.get("event", {}).get("settings", {}).get("public-key")
def find_key_uid(self) -> str:
return self.signal.node_login_event.get("event", {}).get("account", {}).get("key-uid")
recent = self.signal.get_recent_signals(SignalType.NODE_LOGIN.value)
if not recent:
raise RuntimeError("No login signal received to extract key UID")
return recent[-1].get("event", {}).get("account", {}).get("key-uid")

View File

@@ -11,42 +11,46 @@ class WakuextAsyncService(AsyncService):
super().__init__(async_rpc_client, "wakuext")
async def start_messenger(self):
response = await self.rpc_request("startMessenger")
json_response = await response.json()
json_response = await self.rpc_request("startMessenger")
if "error" in json_response:
assert json_response["error"]["code"] == -32000
assert json_response["error"]["message"] == "messenger already started"
return
async def create_community(self, name: str, color="#ffffff", membership: int = 3) -> Dict:
async def create_community(self, name: str, color="#ffffff", membership: int = 3) -> dict:
# TODO check what is membership = 3
params = [{"membership": membership, "name": name, "color": color, "description": name}]
response = await self.rpc_request("createCommunity", params)
return await response.json()
json_response = await self.rpc_request("createCommunity", params)
return json_response
async def fetch_community(self, community_key: str) -> Dict:
async def fetch_community(self, community_key: str) -> dict:
params = [{"communityKey": community_key, "waitForResponse": True, "tryDatabase": True}]
response = await self.rpc_request("fetchCommunity", params)
return await response.json()
json_response = await self.rpc_request("fetchCommunity", params)
return json_response
async def request_to_join_community(self, community_id: str, address: str = "fakeaddress") -> Dict:
async def request_to_join_community(self, community_id: str, address: str = "fakeaddress") -> dict:
params = [{"communityId": community_id, "addressesToReveal": [address], "airdropAddress": address}]
response = await self.rpc_request("requestToJoinCommunity", params)
return await response.json()
json_response = await self.rpc_request("requestToJoinCommunity", params)
return json_response
async def accept_request_to_join_community(self, request_to_join_id: str) -> Dict:
async def accept_request_to_join_community(self, request_to_join_id: str) -> dict:
params = [{"id": request_to_join_id}]
response = await self.rpc_request("acceptRequestToJoinCommunity", params)
return await response.json()
json_response = await self.rpc_request("acceptRequestToJoinCommunity", params)
return json_response
async def send_chat_message(self, chat_id: str, message: str, content_type: int = 1) -> Dict:
async def send_chat_message(self, chat_id: str, message: str, content_type: int = 1) -> dict:
# TODO content type can always be 1? (plain TEXT), does it need to be community type for communities?
params = [{"chatId": chat_id, "text": message, "contentType": content_type}]
response = await self.rpc_request("sendChatMessage", params)
return await response.json()
json_response = await self.rpc_request("sendChatMessage", params)
return json_response
async def send_contact_request(self, contact_id: str, message: str) -> Dict:
params = [{"id": contact_id, "message": message}]
response = await self.rpc_request("sendContactRequest", params)
return await response.json()
json_response = await self.rpc_request("sendContactRequest", params)
return json_response
async def accept_contact_request(self, request_id: str) -> dict:
params = [{"id": request_id}]
json_response = await self.rpc_request("acceptContactRequest", params)
return json_response

View File

@@ -10,6 +10,6 @@ class WalletAsyncService(AsyncService):
def __init__(self, async_rpc_client: AsyncRpcClient):
super().__init__(async_rpc_client, "wallet")
async def start_wallet(self) -> Any:
response = await self.rpc_request("startWallet")
return await response.json()
async def start_wallet(self) -> dict:
json_response = await self.rpc_request("startWallet")
return json_response