mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 07:08:09 -05:00
feat(platform): WebSocket-based notifications (#11297)
This enables real time notifications from backend to browser via WebSocket using Redis bus for moving notifications from REST process to WebSocket process. This is needed for (follow-up) backend-completion of onboarding tasks with instant notifications. ### Changes 🏗️ - Add new `AsyncRedisNotificationEventBus` to enable publishing notifications to the Redis event bus - Consume notifications in `ws_api.py` similarly to execution events and send them via WebSocket - Store WebSocket user connections in `ConnectionManager` - Add relevant tests and types ### Checklist 📋 #### For code changes: - [x] I have clearly listed my changes in the PR description - [x] I have made a test plan - [x] I have tested my changes according to the test plan: - [x] Notifications are sent to the frontend
This commit is contained in:
committed by
GitHub
parent
8058b9487b
commit
18bb78d93e
33
autogpt_platform/backend/backend/data/notification_bus.py
Normal file
33
autogpt_platform/backend/backend/data/notification_bus.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.data.event_bus import AsyncRedisEventBus
|
||||
from backend.server.model import NotificationPayload
|
||||
from backend.util.settings import Settings
|
||||
|
||||
|
||||
class NotificationEvent(BaseModel):
|
||||
"""Generic notification event destined for websocket delivery."""
|
||||
|
||||
user_id: str
|
||||
payload: NotificationPayload
|
||||
|
||||
|
||||
class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
|
||||
Model = NotificationEvent # type: ignore
|
||||
|
||||
@property
|
||||
def event_bus_name(self) -> str:
|
||||
return Settings().config.notification_event_bus_name
|
||||
|
||||
async def publish(self, event: NotificationEvent) -> None:
|
||||
await self.publish_event(event, event.user_id)
|
||||
|
||||
async def listen(
|
||||
self, user_id: str = "*"
|
||||
) -> AsyncGenerator[NotificationEvent, None]:
|
||||
async for event in self.listen_events(user_id):
|
||||
yield event
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from typing import Dict, Set
|
||||
|
||||
from fastapi import WebSocket
|
||||
@@ -7,7 +8,7 @@ from backend.data.execution import (
|
||||
GraphExecutionEvent,
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.model import WSMessage, WSMethod
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
|
||||
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
|
||||
@@ -19,15 +20,24 @@ class ConnectionManager:
|
||||
def __init__(self):
|
||||
self.active_connections: Set[WebSocket] = set()
|
||||
self.subscriptions: Dict[str, Set[WebSocket]] = {}
|
||||
self.user_connections: Dict[str, Set[WebSocket]] = {}
|
||||
|
||||
async def connect_socket(self, websocket: WebSocket):
|
||||
async def connect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
await websocket.accept()
|
||||
self.active_connections.add(websocket)
|
||||
if user_id not in self.user_connections:
|
||||
self.user_connections[user_id] = set()
|
||||
self.user_connections[user_id].add(websocket)
|
||||
|
||||
def disconnect_socket(self, websocket: WebSocket):
|
||||
self.active_connections.remove(websocket)
|
||||
def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
|
||||
self.active_connections.discard(websocket)
|
||||
for subscribers in self.subscriptions.values():
|
||||
subscribers.discard(websocket)
|
||||
user_conns = self.user_connections.get(user_id)
|
||||
if user_conns is not None:
|
||||
user_conns.discard(websocket)
|
||||
if not user_conns:
|
||||
self.user_connections.pop(user_id, None)
|
||||
|
||||
async def subscribe_graph_exec(
|
||||
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
|
||||
@@ -92,6 +102,26 @@ class ConnectionManager:
|
||||
|
||||
return n_sent
|
||||
|
||||
async def send_notification(
|
||||
self, *, user_id: str, payload: NotificationPayload
|
||||
) -> int:
|
||||
"""Send a notification to all websocket connections belonging to a user."""
|
||||
message = WSMessage(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data=payload.model_dump(),
|
||||
).model_dump_json()
|
||||
|
||||
connections = tuple(self.user_connections.get(user_id, set()))
|
||||
if not connections:
|
||||
return 0
|
||||
|
||||
await asyncio.gather(
|
||||
*(connection.send_text(message) for connection in connections),
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
return len(connections)
|
||||
|
||||
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
|
||||
if channel_key not in self.subscriptions:
|
||||
self.subscriptions[channel_key] = set()
|
||||
|
||||
@@ -10,7 +10,7 @@ from backend.data.execution import (
|
||||
NodeExecutionEvent,
|
||||
)
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import WSMessage, WSMethod
|
||||
from backend.server.model import NotificationPayload, WSMessage, WSMethod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -29,8 +29,9 @@ def mock_websocket() -> AsyncMock:
|
||||
async def test_connect(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
await connection_manager.connect_socket(mock_websocket)
|
||||
await connection_manager.connect_socket(mock_websocket, user_id="user-1")
|
||||
assert mock_websocket in connection_manager.active_connections
|
||||
assert mock_websocket in connection_manager.user_connections["user-1"]
|
||||
mock_websocket.accept.assert_called_once()
|
||||
|
||||
|
||||
@@ -39,11 +40,13 @@ def test_disconnect(
|
||||
) -> None:
|
||||
connection_manager.active_connections.add(mock_websocket)
|
||||
connection_manager.subscriptions["test_channel_42"] = {mock_websocket}
|
||||
connection_manager.user_connections["user-1"] = {mock_websocket}
|
||||
|
||||
connection_manager.disconnect_socket(mock_websocket)
|
||||
connection_manager.disconnect_socket(mock_websocket, user_id="user-1")
|
||||
|
||||
assert mock_websocket not in connection_manager.active_connections
|
||||
assert mock_websocket not in connection_manager.subscriptions["test_channel_42"]
|
||||
assert "user-1" not in connection_manager.user_connections
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -207,3 +210,22 @@ async def test_send_execution_result_no_subscribers(
|
||||
await connection_manager.send_execution_update(result)
|
||||
|
||||
mock_websocket.send_text.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_notification(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.user_connections["user-1"] = {mock_websocket}
|
||||
|
||||
await connection_manager.send_notification(
|
||||
user_id="user-1", payload=NotificationPayload(type="info", event="hey")
|
||||
)
|
||||
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
sent_message = mock_websocket.send_text.call_args[0][0]
|
||||
expected_message = WSMessage(
|
||||
method=WSMethod.NOTIFICATION,
|
||||
data={"type": "info", "event": "hey"},
|
||||
).model_dump_json()
|
||||
assert sent_message == expected_message
|
||||
|
||||
@@ -14,6 +14,7 @@ class WSMethod(enum.Enum):
|
||||
UNSUBSCRIBE = "unsubscribe"
|
||||
GRAPH_EXECUTION_EVENT = "graph_execution_event"
|
||||
NODE_EXECUTION_EVENT = "node_execution_event"
|
||||
NOTIFICATION = "notification"
|
||||
ERROR = "error"
|
||||
HEARTBEAT = "heartbeat"
|
||||
|
||||
@@ -76,3 +77,12 @@ class TimezoneResponse(pydantic.BaseModel):
|
||||
|
||||
class UpdateTimezoneRequest(pydantic.BaseModel):
|
||||
timezone: TimeZoneName
|
||||
|
||||
|
||||
class NotificationPayload(pydantic.BaseModel):
|
||||
type: str
|
||||
event: str
|
||||
|
||||
|
||||
class OnboardingNotificationPayload(NotificationPayload):
|
||||
step: str
|
||||
|
||||
@@ -10,6 +10,7 @@ from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.notification_bus import AsyncRedisNotificationEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.monitoring.instrumentation import (
|
||||
instrument_fastapi,
|
||||
@@ -62,9 +63,21 @@ def get_connection_manager():
|
||||
|
||||
@continuous_retry()
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
event_queue = AsyncRedisExecutionEventBus()
|
||||
async for event in event_queue.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
execution_bus = AsyncRedisExecutionEventBus()
|
||||
notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
async def execution_worker():
|
||||
async for event in execution_bus.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
|
||||
async def notification_worker():
|
||||
async for notification in notification_bus.listen("*"):
|
||||
await manager.send_notification(
|
||||
user_id=notification.user_id,
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
@@ -229,7 +242,7 @@ async def websocket_router(
|
||||
user_id = await authenticate_websocket(websocket)
|
||||
if not user_id:
|
||||
return
|
||||
await manager.connect_socket(websocket)
|
||||
await manager.connect_socket(websocket, user_id=user_id)
|
||||
|
||||
# Track WebSocket connection
|
||||
update_websocket_connections(user_id, 1)
|
||||
@@ -302,7 +315,7 @@ async def websocket_router(
|
||||
)
|
||||
|
||||
except WebSocketDisconnect:
|
||||
manager.disconnect_socket(websocket)
|
||||
manager.disconnect_socket(websocket, user_id=user_id)
|
||||
logger.debug("WebSocket client disconnected")
|
||||
finally:
|
||||
update_websocket_connections(user_id, -1)
|
||||
|
||||
@@ -96,7 +96,9 @@ async def test_websocket_router_subscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
mock_manager.subscribe_graph_exec.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_exec_id="test-graph-exec-1",
|
||||
@@ -115,7 +117,9 @@ async def test_websocket_router_subscribe(
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "sub")
|
||||
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -142,7 +146,9 @@ async def test_websocket_router_unsubscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
mock_manager.unsubscribe_graph_exec.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_exec_id="test-graph-exec-1",
|
||||
@@ -158,7 +164,9 @@ async def test_websocket_router_unsubscribe(
|
||||
snapshot.snapshot_dir = "snapshots"
|
||||
snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "unsub")
|
||||
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -179,11 +187,15 @@ async def test_websocket_router_invalid_method(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
|
||||
)
|
||||
|
||||
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.connect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
|
||||
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
|
||||
mock_manager.disconnect_socket.assert_called_once_with(
|
||||
mock_websocket, user_id=DEFAULT_USER_ID
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -418,6 +418,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="Name of the event bus",
|
||||
)
|
||||
|
||||
notification_event_bus_name: str = Field(
|
||||
default="notification_event",
|
||||
description="Name of the websocket notification event bus",
|
||||
)
|
||||
|
||||
trust_endpoints_for_requests: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="A whitelist of trusted internal endpoints for the backend to make requests to.",
|
||||
|
||||
@@ -70,6 +70,7 @@ import type {
|
||||
UserOnboarding,
|
||||
UserPasswordCredentials,
|
||||
UsersBalanceHistoryResponse,
|
||||
WebSocketNotification,
|
||||
} from "./types";
|
||||
import { environment } from "@/services/environment";
|
||||
|
||||
@@ -1346,6 +1347,7 @@ type WebsocketMessageTypeMap = {
|
||||
subscribe_graph_executions: { graph_id: GraphID };
|
||||
graph_execution_event: GraphExecution;
|
||||
node_execution_event: NodeExecutionResult;
|
||||
notification: WebSocketNotification;
|
||||
heartbeat: "ping" | "pong";
|
||||
};
|
||||
|
||||
|
||||
@@ -976,6 +976,20 @@ export interface UserOnboarding {
|
||||
agentRuns: number;
|
||||
}
|
||||
|
||||
export interface OnboardingNotificationPayload {
|
||||
type: "onboarding";
|
||||
event: string;
|
||||
step: OnboardingStep;
|
||||
}
|
||||
|
||||
export type WebSocketNotification =
|
||||
| OnboardingNotificationPayload
|
||||
| {
|
||||
type: string;
|
||||
event: string;
|
||||
[key: string]: unknown;
|
||||
};
|
||||
|
||||
/* *** UTILITIES *** */
|
||||
|
||||
/** Use branded types for IDs -> deny mixing IDs between different object classes */
|
||||
|
||||
Reference in New Issue
Block a user