Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-08 03:00:28 -04:00.
Commit: feat(backend): convert RedisEventQueue into Redis Pub/Sub (#8387).
This page shows the changes contained in that commit.
@@ -2,11 +2,18 @@ import json
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime
|
||||
from typing import Any, AsyncGenerator, Generator, Generic, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
from redis.asyncio.client import PubSub as AsyncPubSub
|
||||
from redis.client import PubSub
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.execution import ExecutionResult
|
||||
from backend.util.settings import Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = Config()
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
@@ -16,35 +23,122 @@ class DateTimeEncoder(json.JSONEncoder):
|
||||
return super().default(o)
|
||||
|
||||
|
||||
class AbstractEventQueue(ABC):
|
||||
@abstractmethod
|
||||
def put(self, execution_result: ExecutionResult):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get(self) -> ExecutionResult | None:
|
||||
pass
|
||||
M = TypeVar("M", bound=BaseModel)
|
||||
|
||||
|
||||
class RedisEventQueue(AbstractEventQueue):
|
||||
def __init__(self):
|
||||
self.queue_name = redis.QUEUE_NAME
|
||||
class BaseRedisEventBus(Generic[M], ABC):
    """Shared plumbing for the sync and async Redis pub/sub event buses.

    Subclasses set ``Model`` (the pydantic model carried on the bus) and
    implement ``event_bus_name``.  Events are JSON-serialized with
    ``DateTimeEncoder`` and published on channel
    ``"<event_bus_name>-<channel_key>"``.
    """

    # Pydantic model class used to (de)serialize events on this bus.
    Model: type[M]

    @property
    @abstractmethod
    def event_bus_name(self) -> str:
        pass

    def _serialize_message(self, item: M, channel_key: str) -> tuple[str, str]:
        """Return ``(json_payload, channel_name)`` for publishing *item*."""
        channel_name = f"{self.event_bus_name}-{channel_key}"
        message = json.dumps(item.model_dump(), cls=DateTimeEncoder)
        logger.info(f"[{channel_name}] Publishing an event to Redis {message}")
        return message, channel_name

    def _deserialize_message(self, msg: Any, channel_key: str) -> M | None:
        """Parse a raw pubsub entry into a ``Model``; None when not parseable."""
        # Pattern subscriptions deliver "pmessage" entries, plain ones "message".
        expected_type = "pmessage" if "*" in channel_key else "message"
        if msg["type"] != expected_type:
            return None
        try:
            data = json.loads(msg["data"])
            logger.info(f"Consuming an event from Redis {data}")
            return self.Model(**data)
        except Exception as e:
            # Malformed payloads are logged and dropped (implicitly returns None).
            logger.error(f"Failed to parse event result from Redis {msg} {e}")

    def _subscribe(
        self, connection: redis.Redis | redis.AsyncRedis, channel_key: str
    ) -> tuple[PubSub | AsyncPubSub, str]:
        """Create a pubsub object and the full channel name.

        NOTE: the caller performs the actual (p)subscribe, since that step
        differs between the sync and async clients.
        """
        pubsub = connection.pubsub()
        return pubsub, f"{self.event_bus_name}-{channel_key}"
|
||||
|
||||
|
||||
class RedisEventBus(BaseRedisEventBus[M], ABC):
    """Synchronous Redis pub/sub bus; ``listen_events`` blocks the caller."""

    # Pydantic model class used to (de)serialize events on this bus.
    Model: type[M]

    @property
    def connection(self) -> redis.Redis:
        # Shared process-wide sync client from backend.data.redis.
        return redis.get_redis()

    def publish_event(self, event: M, channel_key: str):
        """Serialize *event* and publish it on the derived channel."""
        payload, channel = self._serialize_message(event, channel_key)
        self.connection.publish(channel, payload)

    def listen_events(self, channel_key: str) -> Generator[M, None, None]:
        """Yield events from the channel; '*' in the key selects pattern mode."""
        pubsub, channel = self._subscribe(self.connection, channel_key)
        assert isinstance(pubsub, PubSub)

        if "*" in channel_key:
            pubsub.psubscribe(channel)
        else:
            pubsub.subscribe(channel)

        for raw_message in pubsub.listen():
            event = self._deserialize_message(raw_message, channel_key)
            if event:
                yield event
|
||||
|
||||
|
||||
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
    """Asyncio flavor of the Redis pub/sub bus."""

    # Pydantic model class used to (de)serialize events on this bus.
    Model: type[M]

    @property
    async def connection(self) -> redis.AsyncRedis:
        # NOTE: an *async* property — callers must ``await self.connection``.
        return await redis.get_redis_async()

    async def publish_event(self, event: M, channel_key: str):
        """Serialize *event* and publish it without blocking the event loop."""
        payload, channel = self._serialize_message(event, channel_key)
        conn = await self.connection
        await conn.publish(channel, payload)

    async def listen_events(self, channel_key: str) -> AsyncGenerator[M, None]:
        """Asynchronously yield events; '*' in the key selects pattern mode."""
        conn = await self.connection
        pubsub, channel = self._subscribe(conn, channel_key)
        assert isinstance(pubsub, AsyncPubSub)

        if "*" in channel_key:
            await pubsub.psubscribe(channel)
        else:
            await pubsub.subscribe(channel)

        async for raw_message in pubsub.listen():
            event = self._deserialize_message(raw_message, channel_key)
            if event:
                yield event
|
||||
|
||||
|
||||
class RedisExecutionEventBus(RedisEventBus[ExecutionResult]):
    """Sync pub/sub bus carrying ``ExecutionResult`` updates."""

    Model = ExecutionResult

    @property
    def event_bus_name(self) -> str:
        return config.execution_event_bus_name

    def publish(self, res: ExecutionResult):
        """Publish *res* on the channel keyed by its graph and execution ids."""
        self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")

    def listen(
        self, graph_id: str = "*", graph_exec_id: str = "*"
    ) -> Generator[ExecutionResult, None, None]:
        """Yield execution results; '*' wildcards match any graph/execution."""
        for result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
            yield result
|
||||
|
||||
|
||||
class AsyncRedisExecutionEventBus(AsyncRedisEventBus[ExecutionResult]):
    """Async pub/sub bus carrying ``ExecutionResult`` updates."""

    Model = ExecutionResult

    @property
    def event_bus_name(self) -> str:
        return config.execution_event_bus_name

    async def publish(self, res: ExecutionResult):
        """Publish *res* on the channel keyed by its graph and execution ids."""
        await self.publish_event(res, f"{res.graph_id}-{res.graph_exec_id}")

    async def listen(
        self, graph_id: str = "*", graph_exec_id: str = "*"
    ) -> AsyncGenerator[ExecutionResult, None]:
        """Asynchronously yield execution results; '*' wildcards match any."""
        async for result in self.listen_events(f"{graph_id}-{graph_exec_id}"):
            yield result
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from redis import Redis
|
||||
from redis.asyncio import Redis as AsyncRedis
|
||||
|
||||
from backend.util.retry import conn_retry
|
||||
|
||||
@@ -11,10 +12,10 @@ load_dotenv()
|
||||
HOST = os.getenv("REDIS_HOST", "localhost")
|
||||
PORT = int(os.getenv("REDIS_PORT", "6379"))
|
||||
PASSWORD = os.getenv("REDIS_PASSWORD", "password")
|
||||
QUEUE_NAME = os.getenv("REDIS_QUEUE", "execution_events")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
connection: Redis | None = None
|
||||
connection_async: AsyncRedis | None = None
|
||||
|
||||
|
||||
@conn_retry("Redis", "Acquiring connection")
|
||||
@@ -42,7 +43,42 @@ def disconnect():
|
||||
connection = None
|
||||
|
||||
|
||||
def get_redis() -> Redis:
|
||||
if not connection:
|
||||
raise RuntimeError("Redis connection is not established")
|
||||
return connection
|
||||
def get_redis(auto_connect: bool = True) -> Redis:
    """Return the shared sync Redis client.

    Reuses the module-level connection when present; otherwise connects on
    demand.  With ``auto_connect=False`` a missing connection raises
    ``RuntimeError`` instead.
    """
    if connection:
        return connection
    if not auto_connect:
        raise RuntimeError("Redis connection is not established")
    return connect()
|
||||
|
||||
|
||||
@conn_retry("AsyncRedis", "Acquiring connection")
async def connect_async() -> AsyncRedis:
    """Create (or reuse) the module-level async Redis connection.

    The connection is verified with a PING before being cached; failures are
    retried by the ``conn_retry`` decorator.
    """
    global connection_async
    if connection_async:
        return connection_async

    client = AsyncRedis(
        host=HOST,
        port=PORT,
        password=PASSWORD,
        decode_responses=True,
    )
    await client.ping()  # fail fast if the server is unreachable
    connection_async = client
    return connection_async
|
||||
|
||||
|
||||
@conn_retry("AsyncRedis", "Releasing connection")
async def disconnect_async():
    """Close the module-level async Redis connection (if any) and clear it."""
    global connection_async
    if connection_async:
        await connection_async.close()
    connection_async = None
|
||||
|
||||
|
||||
async def get_redis_async(auto_connect: bool = True) -> AsyncRedis:
    """Return the shared async Redis client.

    Reuses the cached connection when present; otherwise connects on demand.
    With ``auto_connect=False`` a missing connection raises ``RuntimeError``
    instead.
    """
    if connection_async:
        return connection_async
    if not auto_connect:
        raise RuntimeError("AsyncRedis connection is not established")
    return await connect_async()
|
||||
|
||||
@@ -15,7 +15,7 @@ from backend.data.execution import (
|
||||
upsert_execution_output,
|
||||
)
|
||||
from backend.data.graph import get_graph, get_node
|
||||
from backend.data.queue import RedisEventQueue
|
||||
from backend.data.queue import RedisExecutionEventBus
|
||||
from backend.data.user import get_user_metadata, update_user_metadata
|
||||
from backend.util.service import AppService, expose
|
||||
from backend.util.settings import Config
|
||||
@@ -30,7 +30,7 @@ class DatabaseManager(AppService):
|
||||
super().__init__()
|
||||
self.use_db = True
|
||||
self.use_redis = True
|
||||
self.event_queue = RedisEventQueue()
|
||||
self.event_queue = RedisExecutionEventBus()
|
||||
|
||||
@classmethod
|
||||
def get_port(cls) -> int:
|
||||
@@ -38,7 +38,7 @@ class DatabaseManager(AppService):
|
||||
|
||||
@expose
|
||||
def send_execution_update(self, execution_result_dict: dict[Any, Any]):
|
||||
self.event_queue.put(ExecutionResult(**execution_result_dict))
|
||||
self.event_queue.publish(ExecutionResult(**execution_result_dict))
|
||||
|
||||
@staticmethod
|
||||
def exposed_run_and_wait(
|
||||
|
||||
@@ -8,7 +8,7 @@ from fastapi import Depends, FastAPI, WebSocket, WebSocketDisconnect
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from backend.data import redis
|
||||
from backend.data.queue import RedisEventQueue
|
||||
from backend.data.queue import AsyncRedisExecutionEventBus
|
||||
from backend.data.user import DEFAULT_USER_ID
|
||||
from backend.server.conn_manager import ConnectionManager
|
||||
from backend.server.model import ExecutionSubscription, Methods, WsMessage
|
||||
@@ -51,13 +51,9 @@ def get_connection_manager():
|
||||
async def event_broadcaster(manager: ConnectionManager):
    """Relay execution results from the Redis pub/sub bus to websocket clients.

    Runs until the listen stream ends or an error occurs; errors are logged
    with traceback and re-raised so the supervisor can restart the task.
    """
    try:
        redis.connect()
        bus = AsyncRedisExecutionEventBus()
        async for execution_result in bus.listen():
            await manager.send_execution_result(execution_result)
    except Exception as e:
        logger.exception(f"Event broadcaster error: {e}")
        raise
|
||||
|
||||
@@ -148,6 +148,11 @@ class Config(UpdateTrackingModel["Config"], BaseSettings):
|
||||
description="What environment to behave as: local or cloud",
|
||||
)
|
||||
|
||||
execution_event_bus_name: str = Field(
|
||||
default="execution_event",
|
||||
description="Name of the event bus",
|
||||
)
|
||||
|
||||
backend_cors_allow_origins: List[str] = Field(default_factory=list)
|
||||
|
||||
@field_validator("backend_cors_allow_origins")
|
||||
|
||||
Reference in New Issue
Block a user