mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-06 22:03:59 -05:00
fix(backend/ws): Add user_id to websocket event subscription key (#9660)
- Add `user_id` to WS subscription key - Add error catching to WS message handler
This commit is contained in:
committed by
GitHub
parent
90b147ff51
commit
9a661b5101
@@ -1,7 +1,7 @@
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from multiprocessing import Manager
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, Type, TypeVar
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, Optional, Type, TypeVar
|
||||
|
||||
from prisma import Json
|
||||
from prisma.enums import AgentExecutionStatus
|
||||
@@ -65,6 +65,7 @@ class ExecutionQueue(Generic[T]):
|
||||
|
||||
|
||||
class ExecutionResult(BaseModel):
|
||||
user_id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
graph_exec_id: str
|
||||
@@ -80,27 +81,28 @@ class ExecutionResult(BaseModel):
|
||||
end_time: datetime | None
|
||||
|
||||
@staticmethod
|
||||
def from_graph(graph: AgentGraphExecution):
|
||||
def from_graph(graph_exec: AgentGraphExecution):
|
||||
return ExecutionResult(
|
||||
graph_id=graph.agentGraphId,
|
||||
graph_version=graph.agentGraphVersion,
|
||||
graph_exec_id=graph.id,
|
||||
user_id=graph_exec.userId,
|
||||
graph_id=graph_exec.agentGraphId,
|
||||
graph_version=graph_exec.agentGraphVersion,
|
||||
graph_exec_id=graph_exec.id,
|
||||
node_exec_id="",
|
||||
node_id="",
|
||||
block_id="",
|
||||
status=graph.executionStatus,
|
||||
status=graph_exec.executionStatus,
|
||||
# TODO: Populate input_data & output_data from AgentNodeExecutions
|
||||
# Input & Output comes AgentInputBlock & AgentOutputBlock.
|
||||
input_data={},
|
||||
output_data={},
|
||||
add_time=graph.createdAt,
|
||||
queue_time=graph.createdAt,
|
||||
start_time=graph.startedAt,
|
||||
end_time=graph.updatedAt,
|
||||
add_time=graph_exec.createdAt,
|
||||
queue_time=graph_exec.createdAt,
|
||||
start_time=graph_exec.startedAt,
|
||||
end_time=graph_exec.updatedAt,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_db(execution: AgentNodeExecution):
|
||||
def from_db(execution: AgentNodeExecution, user_id: Optional[str] = None):
|
||||
if execution.executionData:
|
||||
# Execution that has been queued for execution will persist its data.
|
||||
input_data = type.convert(execution.executionData, dict[str, Any])
|
||||
@@ -115,8 +117,15 @@ class ExecutionResult(BaseModel):
|
||||
output_data[data.name].append(type.convert(data.data, Type[Any]))
|
||||
|
||||
graph_execution: AgentGraphExecution | None = execution.AgentGraphExecution
|
||||
if graph_execution:
|
||||
user_id = graph_execution.userId
|
||||
elif not user_id:
|
||||
raise ValueError(
|
||||
"AgentGraphExecution must be included or user_id passed in"
|
||||
)
|
||||
|
||||
return ExecutionResult(
|
||||
user_id=user_id,
|
||||
graph_id=graph_execution.agentGraphId if graph_execution else "",
|
||||
graph_version=graph_execution.agentGraphVersion if graph_execution else 0,
|
||||
graph_exec_id=execution.agentGraphExecutionId,
|
||||
@@ -175,7 +184,7 @@ async def create_graph_execution(
|
||||
)
|
||||
|
||||
return result.id, [
|
||||
ExecutionResult.from_db(execution)
|
||||
ExecutionResult.from_db(execution, result.userId)
|
||||
for execution in result.AgentNodeExecutions or []
|
||||
]
|
||||
|
||||
|
||||
@@ -217,7 +217,8 @@ class GraphExecution(GraphExecutionMeta):
|
||||
graph_exec = GraphExecutionMeta.from_db(_graph_exec)
|
||||
|
||||
node_executions = [
|
||||
ExecutionResult.from_db(ne) for ne in _graph_exec.AgentNodeExecutions
|
||||
ExecutionResult.from_db(ne, _graph_exec.userId)
|
||||
for ne in _graph_exec.AgentNodeExecutions
|
||||
]
|
||||
|
||||
inputs = {
|
||||
|
||||
@@ -20,23 +20,25 @@ class ConnectionManager:
|
||||
for subscribers in self.subscriptions.values():
|
||||
subscribers.discard(websocket)
|
||||
|
||||
async def subscribe(self, graph_id: str, graph_version: int, websocket: WebSocket):
|
||||
key = f"{graph_id}_{graph_version}"
|
||||
async def subscribe(
|
||||
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
):
|
||||
key = f"{user_id}_{graph_id}_{graph_version}"
|
||||
if key not in self.subscriptions:
|
||||
self.subscriptions[key] = set()
|
||||
self.subscriptions[key].add(websocket)
|
||||
|
||||
async def unsubscribe(
|
||||
self, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
self, *, user_id: str, graph_id: str, graph_version: int, websocket: WebSocket
|
||||
):
|
||||
key = f"{graph_id}_{graph_version}"
|
||||
key = f"{user_id}_{graph_id}_{graph_version}"
|
||||
if key in self.subscriptions:
|
||||
self.subscriptions[key].discard(websocket)
|
||||
if not self.subscriptions[key]:
|
||||
del self.subscriptions[key]
|
||||
|
||||
async def send_execution_result(self, result: execution.ExecutionResult):
|
||||
key = f"{result.graph_id}_{result.graph_version}"
|
||||
key = f"{result.user_id}_{result.graph_id}_{result.graph_version}"
|
||||
if key in self.subscriptions:
|
||||
message = WsMessage(
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
|
||||
@@ -4,6 +4,7 @@ from contextlib import asynccontextmanager
|
||||
|
||||
import uvicorn
|
||||
from autogpt_libs.auth import parse_jwt_token
|
||||
from autogpt_libs.utils.cache import thread_cached
|
||||
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
@@ -12,7 +13,7 @@ from backend.data.execution import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import ExecutionSubscription, Methods, WsMessage
|
||||
from backend.util.service import AppProcess
|
||||
from backend.util.service import AppProcess, get_service_client
|
||||
from backend.util.settings import AppEnvironment, Config, Settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -39,6 +40,13 @@ def get_connection_manager():
|
||||
return _connection_manager
|
||||
|
||||
|
||||
@thread_cached
|
||||
def get_db_client():
|
||||
from backend.executor import DatabaseManager
|
||||
|
||||
return get_service_client(DatabaseManager)
|
||||
|
||||
|
||||
async def event_broadcaster(manager: ConnectionManager):
|
||||
try:
|
||||
redis.connect()
|
||||
@@ -74,7 +82,10 @@ async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
|
||||
async def handle_subscribe(
|
||||
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WsMessage,
|
||||
):
|
||||
if not message.data:
|
||||
await websocket.send_text(
|
||||
@@ -85,20 +96,47 @@ async def handle_subscribe(
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
ex_sub = ExecutionSubscription.model_validate(message.data)
|
||||
await manager.subscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
|
||||
logger.debug(f"New execution subscription for graph {ex_sub.graph_id}")
|
||||
sub_req = ExecutionSubscription.model_validate(message.data)
|
||||
|
||||
# Verify that user has read access to graph
|
||||
# if not get_db_client().get_graph(
|
||||
# graph_id=sub_req.graph_id,
|
||||
# version=sub_req.graph_version,
|
||||
# user_id=user_id,
|
||||
# ):
|
||||
# await websocket.send_text(
|
||||
# WsMessage(
|
||||
# method=Methods.ERROR,
|
||||
# success=False,
|
||||
# error="Access denied",
|
||||
# ).model_dump_json()
|
||||
# )
|
||||
# return
|
||||
|
||||
await connection_manager.subscribe(
|
||||
user_id=user_id,
|
||||
graph_id=sub_req.graph_id,
|
||||
graph_version=sub_req.graph_version,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"New execution subscription for user #{user_id} "
|
||||
f"graph #{sub_req.graph_id}v{sub_req.graph_version}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.SUBSCRIBE,
|
||||
success=True,
|
||||
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
|
||||
channel=f"{user_id}_{sub_req.graph_id}_{sub_req.graph_version}",
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
async def handle_unsubscribe(
|
||||
websocket: WebSocket, manager: ConnectionManager, message: WsMessage
|
||||
connection_manager: ConnectionManager,
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
message: WsMessage,
|
||||
):
|
||||
if not message.data:
|
||||
await websocket.send_text(
|
||||
@@ -109,14 +147,22 @@ async def handle_unsubscribe(
|
||||
).model_dump_json()
|
||||
)
|
||||
else:
|
||||
ex_sub = ExecutionSubscription.model_validate(message.data)
|
||||
await manager.unsubscribe(ex_sub.graph_id, ex_sub.graph_version, websocket)
|
||||
logger.debug(f"Removed execution subscription for graph {ex_sub.graph_id}")
|
||||
unsub_req = ExecutionSubscription.model_validate(message.data)
|
||||
await connection_manager.unsubscribe(
|
||||
user_id=user_id,
|
||||
graph_id=unsub_req.graph_id,
|
||||
graph_version=unsub_req.graph_version,
|
||||
websocket=websocket,
|
||||
)
|
||||
logger.debug(
|
||||
f"Removed execution subscription for user #{user_id} "
|
||||
f"graph #{unsub_req.graph_id}v{unsub_req.graph_version}"
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsMessage(
|
||||
method=Methods.UNSUBSCRIBE,
|
||||
success=True,
|
||||
channel=f"{ex_sub.graph_id}_{ex_sub.graph_version}",
|
||||
channel=f"{unsub_req.graph_id}_{unsub_req.graph_version}",
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
@@ -145,13 +191,32 @@ async def websocket_router(
|
||||
)
|
||||
continue
|
||||
|
||||
if message.method == Methods.SUBSCRIBE:
|
||||
await handle_subscribe(websocket, manager, message)
|
||||
try:
|
||||
if message.method == Methods.SUBSCRIBE:
|
||||
await handle_subscribe(
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
|
||||
elif message.method == Methods.UNSUBSCRIBE:
|
||||
await handle_unsubscribe(websocket, manager, message)
|
||||
elif message.method == Methods.UNSUBSCRIBE:
|
||||
await handle_unsubscribe(
|
||||
connection_manager=manager,
|
||||
websocket=websocket,
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error while handling '{message.method}' message "
|
||||
f"for user #{user_id}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
elif message.method == Methods.ERROR:
|
||||
if message.method == Methods.ERROR:
|
||||
logger.error(f"WebSocket Error message received: {message.data}")
|
||||
|
||||
else:
|
||||
|
||||
@@ -46,17 +46,27 @@ def test_disconnect(
|
||||
async def test_subscribe(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
await connection_manager.subscribe("test_graph", 1, mock_websocket)
|
||||
assert mock_websocket in connection_manager.subscriptions["test_graph_1"]
|
||||
await connection_manager.subscribe(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
assert mock_websocket in connection_manager.subscriptions["user-1_test_graph_1"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsubscribe(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
|
||||
await connection_manager.unsubscribe("test_graph", 1, mock_websocket)
|
||||
await connection_manager.unsubscribe(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
|
||||
assert "test_graph" not in connection_manager.subscriptions
|
||||
|
||||
@@ -65,8 +75,9 @@ async def test_unsubscribe(
|
||||
async def test_send_execution_result(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["test_graph_1"] = {mock_websocket}
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
@@ -87,17 +98,45 @@ async def test_send_execution_result(
|
||||
mock_websocket.send_text.assert_called_once_with(
|
||||
WsMessage(
|
||||
method=Methods.EXECUTION_EVENT,
|
||||
channel="test_graph_1",
|
||||
channel="user-1_test_graph_1",
|
||||
data=result.model_dump(),
|
||||
).model_dump_json()
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_execution_result_user_mismatch(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
connection_manager.subscriptions["user-1_test_graph_1"] = {mock_websocket}
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
user_id="user-2",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
node_exec_id="test_node_exec_id",
|
||||
node_id="test_node_id",
|
||||
block_id="test_block_id",
|
||||
status=ExecutionStatus.COMPLETED,
|
||||
input_data={"input1": "value1"},
|
||||
output_data={"output1": ["result1"]},
|
||||
add_time=datetime.now(tz=timezone.utc),
|
||||
queue_time=None,
|
||||
start_time=datetime.now(tz=timezone.utc),
|
||||
end_time=datetime.now(tz=timezone.utc),
|
||||
)
|
||||
|
||||
await connection_manager.send_execution_result(result)
|
||||
|
||||
mock_websocket.send_text.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_execution_result_no_subscribers(
|
||||
connection_manager: ConnectionManager, mock_websocket: AsyncMock
|
||||
) -> None:
|
||||
result: ExecutionResult = ExecutionResult(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
graph_exec_id="test_exec_id",
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import AsyncMock
|
||||
import pytest
|
||||
from fastapi import WebSocket, WebSocketDisconnect
|
||||
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.ws_api import (
|
||||
Methods,
|
||||
@@ -41,7 +42,12 @@ async def test_websocket_router_subscribe(
|
||||
)
|
||||
|
||||
mock_manager.connect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
|
||||
mock_manager.subscribe.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
@@ -65,7 +71,12 @@ async def test_websocket_router_unsubscribe(
|
||||
)
|
||||
|
||||
mock_manager.connect.assert_called_once_with(mock_websocket)
|
||||
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
|
||||
mock_manager.unsubscribe.assert_called_once_with(
|
||||
user_id=DEFAULT_USER_ID,
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
@@ -101,10 +112,18 @@ async def test_handle_subscribe_success(
|
||||
)
|
||||
|
||||
await handle_subscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.subscribe.assert_called_once_with("test_graph", 1, mock_websocket)
|
||||
mock_manager.subscribe.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"subscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
@@ -117,7 +136,10 @@ async def test_handle_subscribe_missing_data(
|
||||
message = WsMessage(method=Methods.SUBSCRIBE)
|
||||
|
||||
await handle_subscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.subscribe.assert_not_called()
|
||||
@@ -135,10 +157,18 @@ async def test_handle_unsubscribe_success(
|
||||
)
|
||||
|
||||
await handle_unsubscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.unsubscribe.assert_called_once_with("test_graph", 1, mock_websocket)
|
||||
mock_manager.unsubscribe.assert_called_once_with(
|
||||
user_id="user-1",
|
||||
graph_id="test_graph",
|
||||
graph_version=1,
|
||||
websocket=mock_websocket,
|
||||
)
|
||||
mock_websocket.send_text.assert_called_once()
|
||||
assert '"method":"unsubscribe"' in mock_websocket.send_text.call_args[0][0]
|
||||
assert '"success":true' in mock_websocket.send_text.call_args[0][0]
|
||||
@@ -151,7 +181,10 @@ async def test_handle_unsubscribe_missing_data(
|
||||
message = WsMessage(method=Methods.UNSUBSCRIBE)
|
||||
|
||||
await handle_unsubscribe(
|
||||
cast(WebSocket, mock_websocket), cast(ConnectionManager, mock_manager), message
|
||||
connection_manager=cast(ConnectionManager, mock_manager),
|
||||
websocket=cast(WebSocket, mock_websocket),
|
||||
user_id="user-1",
|
||||
message=message,
|
||||
)
|
||||
|
||||
mock_manager.unsubscribe.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user