diff --git a/autogpt_platform/backend/backend/api/ws_api.py b/autogpt_platform/backend/backend/api/ws_api.py index b71fdb3526..e254d4b4db 100644 --- a/autogpt_platform/backend/backend/api/ws_api.py +++ b/autogpt_platform/backend/backend/api/ws_api.py @@ -66,18 +66,24 @@ async def event_broadcaster(manager: ConnectionManager): execution_bus = AsyncRedisExecutionEventBus() notification_bus = AsyncRedisNotificationEventBus() - async def execution_worker(): - async for event in execution_bus.listen("*"): - await manager.send_execution_update(event) + try: - async def notification_worker(): - async for notification in notification_bus.listen("*"): - await manager.send_notification( - user_id=notification.user_id, - payload=notification.payload, - ) + async def execution_worker(): + async for event in execution_bus.listen("*"): + await manager.send_execution_update(event) - await asyncio.gather(execution_worker(), notification_worker()) + async def notification_worker(): + async for notification in notification_bus.listen("*"): + await manager.send_notification( + user_id=notification.user_id, + payload=notification.payload, + ) + + await asyncio.gather(execution_worker(), notification_worker()) + finally: + # Ensure PubSub connections are closed on any exit to prevent leaks + await execution_bus.close() + await notification_bus.close() async def authenticate_websocket(websocket: WebSocket) -> str: diff --git a/autogpt_platform/backend/backend/data/event_bus.py b/autogpt_platform/backend/backend/data/event_bus.py index d8a1c5b729..614fb158b2 100644 --- a/autogpt_platform/backend/backend/data/event_bus.py +++ b/autogpt_platform/backend/backend/data/event_bus.py @@ -133,10 +133,23 @@ class RedisEventBus(BaseRedisEventBus[M], ABC): class AsyncRedisEventBus(BaseRedisEventBus[M], ABC): + def __init__(self): + self._pubsub: AsyncPubSub | None = None + @property async def connection(self) -> redis.AsyncRedis: return await redis.get_redis_async() + async def close(self) -> None: + """Close the PubSub connection if it exists.""" + if self._pubsub is not None: + try: + await self._pubsub.close() + except Exception: + logger.warning("Failed to close PubSub connection", exc_info=True) + finally: + self._pubsub = None + async def publish_event(self, event: M, channel_key: str): """ Publish an event to Redis. Gracefully handles connection failures @@ -157,6 +170,7 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC): await self.connection, channel_key ) assert isinstance(pubsub, AsyncPubSub) + self._pubsub = pubsub if "*" in channel_key: await pubsub.psubscribe(full_channel_name)