mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-09 15:17:59 -05:00
We'll soon need a more feature-complete external API. To make way for this, I'm moving some files around so that:
- We can more easily create new versions of our external API
- The file structure of our internal API is more homogeneous

These changes are quite opinionated, but IMO they're better in any case than the chaotic structure we have now.

### Changes 🏗️
- Move `backend/server` -> `backend/api`
- Move `backend/server/routers` + `backend/server/v2` -> `backend/api/features`
- Change absolute sibling imports to relative imports
- Move `backend/server/v2/AutoMod` -> `backend/executor/automod`
- Combine `backend/server/routers/analytics_*test.py` -> `backend/api/features/analytics_test.py`
- Sort OpenAPI spec file

### Checklist 📋
#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - CI tests
  - [x] Clicking around in the app -> no obvious breakage
351 lines
10 KiB
Python
351 lines
10 KiB
Python
import asyncio
|
|
import logging
|
|
from contextlib import asynccontextmanager
|
|
from typing import Protocol
|
|
|
|
import pydantic
|
|
import uvicorn
|
|
from autogpt_libs.auth.jwt_utils import parse_jwt_token
|
|
from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
from backend.api.conn_manager import ConnectionManager
|
|
from backend.api.model import (
|
|
WSMessage,
|
|
WSMethod,
|
|
WSSubscribeGraphExecutionRequest,
|
|
WSSubscribeGraphExecutionsRequest,
|
|
)
|
|
from backend.api.utils.cors import build_cors_params
|
|
from backend.data.execution import AsyncRedisExecutionEventBus
|
|
from backend.data.notification_bus import AsyncRedisNotificationEventBus
|
|
from backend.data.user import DEFAULT_USER_ID
|
|
from backend.monitoring.instrumentation import (
|
|
instrument_fastapi,
|
|
update_websocket_connections,
|
|
)
|
|
from backend.util.retry import continuous_retry
|
|
from backend.util.service import AppProcess
|
|
from backend.util.settings import AppEnvironment, Config, Settings
|
|
|
|
# Module-wide logger, named after this module.
logger = logging.getLogger(name=__name__)

# Application settings; `settings.config` is read throughout this module.
settings = Settings()
|
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan: run the Redis event broadcaster while the app is up.

    On startup, spawns ``event_broadcaster`` as a background task bound to the
    shared connection manager. On shutdown, cancels that task and waits for it
    to unwind, so it is not left pending when the event loop is torn down.
    """
    manager = get_connection_manager()
    fut = asyncio.create_task(event_broadcaster(manager))
    fut.add_done_callback(lambda _: logger.info("Event broadcaster stopped"))
    try:
        yield
    finally:
        fut.cancel()
        # return_exceptions=True collects the task's CancelledError (or any
        # other failure) instead of re-raising it, so shutdown can proceed.
        # NOTE(review): assumes @continuous_retry lets CancelledError
        # propagate instead of retrying on it -- confirm, or this await hangs.
        await asyncio.gather(fut, return_exceptions=True)
|
|
|
|
|
|
# Whether the server runs in the local environment; this gates both the
# interactive docs and the visibility of /metrics in the OpenAPI schema.
_is_local_env = settings.config.app_env == AppEnvironment.LOCAL

docs_url = "/docs" if _is_local_env else None
app = FastAPI(lifespan=lifespan, docs_url=docs_url)

# Lazily populated by get_connection_manager().
_connection_manager = None

# Add Prometheus instrumentation
instrument_fastapi(
    app,
    service_name="websocket-server",
    expose_endpoint=True,
    endpoint="/metrics",
    include_in_schema=_is_local_env,
)
|
|
|
|
|
|
def get_connection_manager():
    """Return the module-wide ``ConnectionManager``, creating it on first call.

    Both ``lifespan`` and the ``/ws`` route (as a FastAPI dependency) obtain
    the manager through this function, so they share a single instance.
    """
    global _connection_manager
    manager = _connection_manager
    if manager is None:
        manager = _connection_manager = ConnectionManager()
    return manager
|
|
|
|
|
|
@continuous_retry()
async def event_broadcaster(manager: ConnectionManager):
    """Forward Redis execution events and notifications to WebSocket clients.

    Two workers run concurrently:
    - one relays every execution event (``listen("*")``) through
      ``manager.send_execution_update``;
    - one relays every notification through ``manager.send_notification``
      with the notification's ``user_id`` and ``payload``.

    If either worker fails, the other is cancelled before the error
    propagates. ``asyncio.gather`` alone would leave the survivor running.
    Presumably ``@continuous_retry`` then re-invokes this function, which
    would stack a fresh pair of workers on top of the orphan and deliver
    events twice.
    """
    execution_bus = AsyncRedisExecutionEventBus()
    notification_bus = AsyncRedisNotificationEventBus()

    async def execution_worker():
        async for event in execution_bus.listen("*"):
            await manager.send_execution_update(event)

    async def notification_worker():
        async for notification in notification_bus.listen("*"):
            await manager.send_notification(
                user_id=notification.user_id,
                payload=notification.payload,
            )

    workers = [
        asyncio.create_task(execution_worker()),
        asyncio.create_task(notification_worker()),
    ]
    try:
        # Propagates the first failure (or our own cancellation).
        await asyncio.gather(*workers)
    finally:
        # Stop whichever worker is still running. cancel() is a no-op on
        # tasks that already finished.
        for worker in workers:
            worker.cancel()
        # Let them unwind; return_exceptions=True keeps their CancelledError
        # from replacing the original error being propagated.
        await asyncio.gather(*workers, return_exceptions=True)
|
|
|
|
|
|
async def authenticate_websocket(websocket: WebSocket) -> str:
    """Resolve the user ID for an incoming WebSocket connection.

    When ``settings.config.enable_auth`` is falsy, every connection is
    attributed to ``DEFAULT_USER_ID``. Otherwise, the JWT is taken from the
    ``token`` query parameter and its ``sub`` claim is returned.

    On failure the socket is closed and ``""`` is returned. Close codes:
    - 4001: no ``token`` query parameter;
    - 4002: the token parsed but its ``sub`` claim is missing or empty;
    - 4003: ``parse_jwt_token`` raised ``ValueError``.
    """
    if not settings.config.enable_auth:
        return DEFAULT_USER_ID

    token = websocket.query_params.get("token")
    if not token:
        await websocket.close(code=4001, reason="Missing authentication token")
        return ""

    try:
        subject = parse_jwt_token(token).get("sub")
        if subject:
            return subject
        await websocket.close(code=4002, reason="Invalid token")
    except ValueError:
        await websocket.close(code=4003, reason="Invalid token")
    return ""
|
|
|
|
|
|
# ===================== Message Handlers ===================== #
|
|
|
|
|
|
class WSMessageHandler(Protocol):
    """Structural type of the coroutines stored in the message dispatch table.

    The router invokes handlers with keyword arguments only, so the parameter
    names below are part of the contract.
    """

    async def __call__(
        self,
        connection_manager: ConnectionManager,
        websocket: WebSocket,
        user_id: str,
        message: WSMessage,
    ):
        """Handle one parsed client message on behalf of ``user_id``."""
|
|
|
|
|
|
async def handle_subscribe(
    connection_manager: ConnectionManager,
    websocket: WebSocket,
    user_id: str,
    message: WSMessage,
):
    """Subscribe ``websocket`` to graph-execution updates for ``user_id``.

    Handles two methods:
    - ``SUBSCRIBE_GRAPH_EXEC``: ``message.data`` is validated as a
      ``WSSubscribeGraphExecutionRequest`` and the socket is subscribed to a
      single execution (``graph_exec_id``).
    - ``SUBSCRIBE_GRAPH_EXECS``: ``message.data`` is validated as a
      ``WSSubscribeGraphExecutionsRequest`` and the socket is subscribed to the
      executions of one graph (``graph_id``).

    On success the client gets a ``WSMessage`` echoing the request method with
    ``success=True`` and the subscribed ``channel``. An empty ``message.data``
    gets an ERROR reply instead.

    Raises:
        ValueError: if called for any other method.
        (Validation errors from ``model_validate`` propagate to the caller,
        presumably ``pydantic.ValidationError``.)
    """
    if not message.data:
        error_reply = WSMessage(
            method=WSMethod.ERROR,
            success=False,
            error="Subscription data missing",
        )
        await websocket.send_text(error_reply.model_dump_json())
        return

    # TODO: verify that the user has read access to the graph before
    # subscribing (e.g. a DB graph lookup scoped to user_id), and reply with
    # an "Access denied" ERROR message otherwise.

    method = message.method
    if method == WSMethod.SUBSCRIBE_GRAPH_EXEC:
        exec_request = WSSubscribeGraphExecutionRequest.model_validate(message.data)
        channel_key = await connection_manager.subscribe_graph_exec(
            user_id=user_id,
            graph_exec_id=exec_request.graph_exec_id,
            websocket=websocket,
        )
    elif method == WSMethod.SUBSCRIBE_GRAPH_EXECS:
        execs_request = WSSubscribeGraphExecutionsRequest.model_validate(message.data)
        channel_key = await connection_manager.subscribe_graph_execs(
            user_id=user_id,
            graph_id=execs_request.graph_id,
            websocket=websocket,
        )
    else:
        raise ValueError(
            f"{handle_subscribe.__name__} can't handle '{message.method}' messages"
        )

    logger.debug(f"New subscription on channel {channel_key} for user #{user_id}")
    confirmation = WSMessage(
        method=message.method,
        success=True,
        channel=channel_key,
    )
    await websocket.send_text(confirmation.model_dump_json())
|
|
|
|
|
|
async def handle_unsubscribe(
    connection_manager: ConnectionManager,
    websocket: WebSocket,
    user_id: str,
    message: WSMessage,
):
    """Remove ``websocket``'s subscription to a single graph execution.

    ``message.data`` is validated as a ``WSSubscribeGraphExecutionRequest``
    and its ``graph_exec_id`` is passed to
    ``connection_manager.unsubscribe_graph_exec``. The client gets an
    UNSUBSCRIBE success reply carrying the channel key, or an ERROR reply when
    ``message.data`` is empty.

    NOTE(review): only per-execution channels can be left through this path;
    there is no unsubscribe counterpart for ``SUBSCRIBE_GRAPH_EXECS``.
    """
    if message.data:
        request = WSSubscribeGraphExecutionRequest.model_validate(message.data)
        channel_key = await connection_manager.unsubscribe_graph_exec(
            user_id=user_id,
            graph_exec_id=request.graph_exec_id,
            websocket=websocket,
        )
        logger.debug(
            f"Removed subscription on channel {channel_key} for user #{user_id}"
        )
        reply = WSMessage(
            method=WSMethod.UNSUBSCRIBE,
            success=True,
            channel=channel_key,
        )
    else:
        reply = WSMessage(
            method=WSMethod.ERROR,
            success=False,
            error="Subscription data missing",
        )
    await websocket.send_text(reply.model_dump_json())
|
|
|
|
|
|
async def handle_heartbeat(
    connection_manager: ConnectionManager,
    websocket: WebSocket,
    user_id: str,
    message: WSMessage,
):
    """Answer a client heartbeat with ``pong``.

    The reply is a plain JSON object (not a serialized ``WSMessage``) with the
    keys ``method``, ``data`` and ``success``. Every argument except
    ``websocket`` is unused; they are only there to match ``WSMessageHandler``.
    """
    pong_reply = {
        "method": WSMethod.HEARTBEAT.value,
        "data": "pong",
        "success": True,
    }
    await websocket.send_json(pong_reply)
|
|
|
|
|
|
# Dispatch table: client message method -> coroutine that handles it.
# Both subscribe variants share one handler, which branches on
# `message.method`. Methods absent from this table (such as ERROR) are dealt
# with directly by the WebSocket router.
_MSG_HANDLERS: dict[WSMethod, WSMessageHandler] = {
    # Keep-alive
    WSMethod.HEARTBEAT: handle_heartbeat,
    # Subscription management
    WSMethod.SUBSCRIBE_GRAPH_EXEC: handle_subscribe,
    WSMethod.SUBSCRIBE_GRAPH_EXECS: handle_subscribe,
    WSMethod.UNSUBSCRIBE: handle_unsubscribe,
}
|
|
|
|
|
|
# ===================== WebSocket Server ===================== #
|
|
|
|
|
|
async def _send_ws_error(websocket: WebSocket, error: str) -> None:
    """Send the client a failed ``WSMessage`` with method ERROR and ``error``."""
    await websocket.send_text(
        WSMessage(
            method=WSMethod.ERROR,
            success=False,
            error=error,
        ).model_dump_json()
    )


async def _process_ws_message(
    manager: ConnectionManager,
    websocket: WebSocket,
    user_id: str,
    data: str,
) -> None:
    """Parse one text frame from the client and act on it.

    - A frame that is not a valid ``WSMessage`` gets an ERROR reply.
    - A method listed in ``_MSG_HANDLERS`` is dispatched to its handler. A
      ``pydantic.ValidationError`` raised by the handler gets an ERROR reply.
      Any other handler exception is logged with its traceback, and the
      connection stays open.
    - A client-sent ERROR message is only logged.
    - Any other method gets an ERROR reply saying it is not processed.

    Exceptions raised while sending a reply propagate to the caller.
    """
    try:
        message = WSMessage.model_validate_json(data)
    except pydantic.ValidationError as e:
        logger.error(
            "Invalid WebSocket message from user #%s: %s",
            user_id,
            e,
        )
        await _send_ws_error(
            websocket, "Invalid message format. Review the schema and retry"
        )
        return

    handler = _MSG_HANDLERS.get(message.method)
    if handler is not None:
        try:
            await handler(
                connection_manager=manager,
                websocket=websocket,
                user_id=user_id,
                message=message,
            )
        except pydantic.ValidationError as e:
            logger.error(
                "Validation error while handling '%s' for user #%s: %s",
                message.method.value,
                user_id,
                e,
            )
            await _send_ws_error(
                websocket, "Invalid message data. Refer to the API schema"
            )
        except Exception as e:
            # logger.exception keeps the traceback so unexpected handler
            # failures are debuggable; the connection stays open.
            logger.exception(
                "Error while handling '%s' message for user #%s: %s",
                message.method.value,
                user_id,
                e,
            )
        return

    if message.method == WSMethod.ERROR:
        logger.error(f"WebSocket Error message received: {message.data}")
        return

    logger.warning(
        f"Unknown WebSocket message type {message.method} received: "
        f"{message.data}"
    )
    await _send_ws_error(websocket, "Message type is not processed by the server")


@app.websocket("/ws")
async def websocket_router(
    websocket: WebSocket, manager: ConnectionManager = Depends(get_connection_manager)
):
    """Serve one client connection on ``/ws``.

    Authenticates the socket, returning early when ``authenticate_websocket``
    yields no user ID (it has already closed the socket in that case). It then
    registers the socket with ``manager`` and processes text frames until the
    client disconnects. Whatever ends the receive loop, the socket is
    deregistered from ``manager`` and the connection metric is decremented.
    """
    user_id = await authenticate_websocket(websocket)
    if not user_id:
        return
    await manager.connect_socket(websocket, user_id=user_id)

    # Track WebSocket connection
    update_websocket_connections(user_id, 1)

    try:
        while True:
            data = await websocket.receive_text()
            await _process_ws_message(manager, websocket, user_id, data)
    except WebSocketDisconnect:
        logger.debug("WebSocket client disconnected")
    finally:
        # Deregister in `finally`, not only on WebSocketDisconnect. Any other
        # error escaping the loop (e.g. a failed error reply via send_text)
        # would otherwise leave a dead socket registered with the manager.
        try:
            manager.disconnect_socket(websocket, user_id=user_id)
        finally:
            update_websocket_connections(user_id, -1)
|
|
|
|
|
|
@app.get("/")
async def health():
    """Health-check endpoint at ``/``; always answers ``{"status": "healthy"}``."""
    payload = {"status": "healthy"}
    return payload
|
|
|
|
|
|
class WebsocketServer(AppProcess):
    """``AppProcess`` that serves the WebSocket FastAPI ``app`` with uvicorn."""

    def run(self):
        """Wrap ``app`` in CORS middleware and hand it to ``uvicorn.run``.

        Allowed origins come from ``settings.config.backend_cors_allow_origins``
        (expanded by ``build_cors_params`` for the current ``app_env``).
        Credentials, all methods and all headers are allowed. Host and port are
        read from ``Config()``.
        """
        allowed_origins = settings.config.backend_cors_allow_origins
        logger.info(f"CORS allow origins: {allowed_origins}")

        cors_params = build_cors_params(allowed_origins, settings.config.app_env)
        server_app = CORSMiddleware(
            app=app,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
            **cors_params,
        )

        uvicorn.run(
            server_app,
            ws="websockets-sansio",
            log_config=None,
            host=Config().websocket_server_host,
            port=Config().websocket_server_port,
        )
|