feat(platform): WebSocket-based notifications (#11297)

This enables real time notifications from backend to browser via
WebSocket using Redis bus for moving notifications from REST process to
WebSocket process.
This is needed for (follow-up) backend-completion of onboarding tasks
with instant notifications.

### Changes 🏗️

- Add new `AsyncRedisNotificationEventBus` to enable publishing
notifications to the Redis event bus
- Consume notifications in `ws_api.py` similarly to execution events and
send them via WebSocket
- Store WebSocket user connections in `ConnectionManager`
- Add relevant tests and types

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Notifications are sent to the frontend
This commit is contained in:
Krzysztof Czerwinski
2025-11-09 10:42:20 +09:00
committed by GitHub
parent 8058b9487b
commit 18bb78d93e
9 changed files with 159 additions and 18 deletions

View File

@@ -0,0 +1,33 @@
from __future__ import annotations
from typing import AsyncGenerator
from pydantic import BaseModel
from backend.data.event_bus import AsyncRedisEventBus
from backend.server.model import NotificationPayload
from backend.util.settings import Settings
class NotificationEvent(BaseModel):
"""Generic notification event destined for websocket delivery."""
user_id: str
payload: NotificationPayload
class AsyncRedisNotificationEventBus(AsyncRedisEventBus[NotificationEvent]):
Model = NotificationEvent # type: ignore
@property
def event_bus_name(self) -> str:
return Settings().config.notification_event_bus_name
async def publish(self, event: NotificationEvent) -> None:
await self.publish_event(event, event.user_id)
async def listen(
self, user_id: str = "*"
) -> AsyncGenerator[NotificationEvent, None]:
async for event in self.listen_events(user_id):
yield event

View File

@@ -1,3 +1,4 @@
import asyncio
from typing import Dict, Set
from fastapi import WebSocket
@@ -7,7 +8,7 @@ from backend.data.execution import (
GraphExecutionEvent,
NodeExecutionEvent,
)
from backend.server.model import WSMessage, WSMethod
from backend.server.model import NotificationPayload, WSMessage, WSMethod
_EVENT_TYPE_TO_METHOD_MAP: dict[ExecutionEventType, WSMethod] = {
ExecutionEventType.GRAPH_EXEC_UPDATE: WSMethod.GRAPH_EXECUTION_EVENT,
@@ -19,15 +20,24 @@ class ConnectionManager:
def __init__(self):
self.active_connections: Set[WebSocket] = set()
self.subscriptions: Dict[str, Set[WebSocket]] = {}
self.user_connections: Dict[str, Set[WebSocket]] = {}
async def connect_socket(self, websocket: WebSocket):
async def connect_socket(self, websocket: WebSocket, *, user_id: str):
await websocket.accept()
self.active_connections.add(websocket)
if user_id not in self.user_connections:
self.user_connections[user_id] = set()
self.user_connections[user_id].add(websocket)
def disconnect_socket(self, websocket: WebSocket):
self.active_connections.remove(websocket)
def disconnect_socket(self, websocket: WebSocket, *, user_id: str):
self.active_connections.discard(websocket)
for subscribers in self.subscriptions.values():
subscribers.discard(websocket)
user_conns = self.user_connections.get(user_id)
if user_conns is not None:
user_conns.discard(websocket)
if not user_conns:
self.user_connections.pop(user_id, None)
async def subscribe_graph_exec(
self, *, user_id: str, graph_exec_id: str, websocket: WebSocket
@@ -92,6 +102,26 @@ class ConnectionManager:
return n_sent
async def send_notification(
self, *, user_id: str, payload: NotificationPayload
) -> int:
"""Send a notification to all websocket connections belonging to a user."""
message = WSMessage(
method=WSMethod.NOTIFICATION,
data=payload.model_dump(),
).model_dump_json()
connections = tuple(self.user_connections.get(user_id, set()))
if not connections:
return 0
await asyncio.gather(
*(connection.send_text(message) for connection in connections),
return_exceptions=True,
)
return len(connections)
async def _subscribe(self, channel_key: str, websocket: WebSocket) -> str:
if channel_key not in self.subscriptions:
self.subscriptions[channel_key] = set()

View File

@@ -10,7 +10,7 @@ from backend.data.execution import (
NodeExecutionEvent,
)
from backend.server.conn_manager import ConnectionManager
from backend.server.model import WSMessage, WSMethod
from backend.server.model import NotificationPayload, WSMessage, WSMethod
@pytest.fixture
@@ -29,8 +29,9 @@ def mock_websocket() -> AsyncMock:
async def test_connect(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
await connection_manager.connect_socket(mock_websocket)
await connection_manager.connect_socket(mock_websocket, user_id="user-1")
assert mock_websocket in connection_manager.active_connections
assert mock_websocket in connection_manager.user_connections["user-1"]
mock_websocket.accept.assert_called_once()
@@ -39,11 +40,13 @@ def test_disconnect(
) -> None:
connection_manager.active_connections.add(mock_websocket)
connection_manager.subscriptions["test_channel_42"] = {mock_websocket}
connection_manager.user_connections["user-1"] = {mock_websocket}
connection_manager.disconnect_socket(mock_websocket)
connection_manager.disconnect_socket(mock_websocket, user_id="user-1")
assert mock_websocket not in connection_manager.active_connections
assert mock_websocket not in connection_manager.subscriptions["test_channel_42"]
assert "user-1" not in connection_manager.user_connections
@pytest.mark.asyncio
@@ -207,3 +210,22 @@ async def test_send_execution_result_no_subscribers(
await connection_manager.send_execution_update(result)
mock_websocket.send_text.assert_not_called()
@pytest.mark.asyncio
async def test_send_notification(
connection_manager: ConnectionManager, mock_websocket: AsyncMock
) -> None:
connection_manager.user_connections["user-1"] = {mock_websocket}
await connection_manager.send_notification(
user_id="user-1", payload=NotificationPayload(type="info", event="hey")
)
mock_websocket.send_text.assert_called_once()
sent_message = mock_websocket.send_text.call_args[0][0]
expected_message = WSMessage(
method=WSMethod.NOTIFICATION,
data={"type": "info", "event": "hey"},
).model_dump_json()
assert sent_message == expected_message

View File

@@ -14,6 +14,7 @@ class WSMethod(enum.Enum):
UNSUBSCRIBE = "unsubscribe"
GRAPH_EXECUTION_EVENT = "graph_execution_event"
NODE_EXECUTION_EVENT = "node_execution_event"
NOTIFICATION = "notification"
ERROR = "error"
HEARTBEAT = "heartbeat"
@@ -76,3 +77,12 @@ class TimezoneResponse(pydantic.BaseModel):
class UpdateTimezoneRequest(pydantic.BaseModel):
timezone: TimeZoneName
class NotificationPayload(pydantic.BaseModel):
type: str
event: str
class OnboardingNotificationPayload(NotificationPayload):
step: str

View File

@@ -10,6 +10,7 @@ from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
from starlette.middleware.cors import CORSMiddleware
from backend.data.execution import AsyncRedisExecutionEventBus
from backend.data.notification_bus import AsyncRedisNotificationEventBus
from backend.data.user import DEFAULT_USER_ID
from backend.monitoring.instrumentation import (
instrument_fastapi,
@@ -62,9 +63,21 @@ def get_connection_manager():
@continuous_retry()
async def event_broadcaster(manager: ConnectionManager):
event_queue = AsyncRedisExecutionEventBus()
async for event in event_queue.listen("*"):
await manager.send_execution_update(event)
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
await asyncio.gather(execution_worker(), notification_worker())
async def authenticate_websocket(websocket: WebSocket) -> str:
@@ -229,7 +242,7 @@ async def websocket_router(
user_id = await authenticate_websocket(websocket)
if not user_id:
return
await manager.connect_socket(websocket)
await manager.connect_socket(websocket, user_id=user_id)
# Track WebSocket connection
update_websocket_connections(user_id, 1)
@@ -302,7 +315,7 @@ async def websocket_router(
)
except WebSocketDisconnect:
manager.disconnect_socket(websocket)
manager.disconnect_socket(websocket, user_id=user_id)
logger.debug("WebSocket client disconnected")
finally:
update_websocket_connections(user_id, -1)

View File

@@ -96,7 +96,9 @@ async def test_websocket_router_subscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
)
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
mock_manager.connect_socket.assert_called_once_with(
mock_websocket, user_id=DEFAULT_USER_ID
)
mock_manager.subscribe_graph_exec.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_exec_id="test-graph-exec-1",
@@ -115,7 +117,9 @@ async def test_websocket_router_subscribe(
snapshot.snapshot_dir = "snapshots"
snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "sub")
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
mock_manager.disconnect_socket.assert_called_once_with(
mock_websocket, user_id=DEFAULT_USER_ID
)
@pytest.mark.asyncio
@@ -142,7 +146,9 @@ async def test_websocket_router_unsubscribe(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
)
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
mock_manager.connect_socket.assert_called_once_with(
mock_websocket, user_id=DEFAULT_USER_ID
)
mock_manager.unsubscribe_graph_exec.assert_called_once_with(
user_id=DEFAULT_USER_ID,
graph_exec_id="test-graph-exec-1",
@@ -158,7 +164,9 @@ async def test_websocket_router_unsubscribe(
snapshot.snapshot_dir = "snapshots"
snapshot.assert_match(json.dumps(parsed_message, indent=2, sort_keys=True), "unsub")
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
mock_manager.disconnect_socket.assert_called_once_with(
mock_websocket, user_id=DEFAULT_USER_ID
)
@pytest.mark.asyncio
@@ -179,11 +187,15 @@ async def test_websocket_router_invalid_method(
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager)
)
mock_manager.connect_socket.assert_called_once_with(mock_websocket)
mock_manager.connect_socket.assert_called_once_with(
mock_websocket, user_id=DEFAULT_USER_ID
)
mock_websocket.send_text.assert_called_once()
assert '"method":"error"' in mock_websocket.send_text.call_args[0][0]
assert '"success":false' in mock_websocket.send_text.call_args[0][0]
mock_manager.disconnect_socket.assert_called_once_with(mock_websocket)
mock_manager.disconnect_socket.assert_called_once_with(
mock_websocket, user_id=DEFAULT_USER_ID
)
@pytest.mark.asyncio

View File

@@ -418,6 +418,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
description="Name of the event bus",
)
notification_event_bus_name: str = Field(
default="notification_event",
description="Name of the websocket notification event bus",
)
trust_endpoints_for_requests: List[str] = Field(
default_factory=list,
description="A whitelist of trusted internal endpoints for the backend to make requests to.",

View File

@@ -70,6 +70,7 @@ import type {
UserOnboarding,
UserPasswordCredentials,
UsersBalanceHistoryResponse,
WebSocketNotification,
} from "./types";
import { environment } from "@/services/environment";
@@ -1346,6 +1347,7 @@ type WebsocketMessageTypeMap = {
subscribe_graph_executions: { graph_id: GraphID };
graph_execution_event: GraphExecution;
node_execution_event: NodeExecutionResult;
notification: WebSocketNotification;
heartbeat: "ping" | "pong";
};

View File

@@ -976,6 +976,20 @@ export interface UserOnboarding {
agentRuns: number;
}
export interface OnboardingNotificationPayload {
type: "onboarding";
event: string;
step: OnboardingStep;
}
export type WebSocketNotification =
| OnboardingNotificationPayload
| {
type: string;
event: string;
[key: string]: unknown;
};
/* *** UTILITIES *** */
/** Use branded types for IDs -> deny mixing IDs between different object classes */