diff --git a/autogpt_platform/backend/backend/data/rabbitmq.py b/autogpt_platform/backend/backend/data/rabbitmq.py index bdf2090083..524e21748a 100644 --- a/autogpt_platform/backend/backend/data/rabbitmq.py +++ b/autogpt_platform/backend/backend/data/rabbitmq.py @@ -1,3 +1,4 @@ +import asyncio import logging from abc import ABC, abstractmethod from enum import Enum @@ -225,6 +226,10 @@ class SyncRabbitMQ(RabbitMQBase): class AsyncRabbitMQ(RabbitMQBase): """Asynchronous RabbitMQ client""" + def __init__(self, config: RabbitMQConfig): + super().__init__(config) + self._reconnect_lock: asyncio.Lock | None = None + @property def is_connected(self) -> bool: return bool(self._connection and not self._connection.is_closed) @@ -235,7 +240,17 @@ class AsyncRabbitMQ(RabbitMQBase): @conn_retry("AsyncRabbitMQ", "Acquiring async connection") async def connect(self): - if self.is_connected: + if self.is_connected and self._channel and not self._channel.is_closed: + return + + if ( + self.is_connected + and self._connection + and (self._channel is None or self._channel.is_closed) + ): + self._channel = await self._connection.channel() + await self._channel.set_qos(prefetch_count=1) + await self.declare_infrastructure() return self._connection = await aio_pika.connect_robust( @@ -291,24 +306,46 @@ class AsyncRabbitMQ(RabbitMQBase): exchange, routing_key=queue.routing_key or queue.name ) - @func_retry - async def publish_message( + @property + def _lock(self) -> asyncio.Lock: + if self._reconnect_lock is None: + self._reconnect_lock = asyncio.Lock() + return self._reconnect_lock + + async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel: + """Get a valid channel, reconnecting if the current one is stale. + + Uses a lock to prevent concurrent reconnection attempts from racing. + """ + if self.is_ready: + return self._channel # type: ignore # is_ready guarantees non-None + + async with self._lock: + # Double-check after acquiring lock + if self.is_ready: + return self._channel # type: ignore + + self._channel = None + await self.connect() + + if self._channel is None: + raise RuntimeError("Channel should be established after connect") + + return self._channel + + async def _publish_once( self, routing_key: str, message: str, exchange: Optional[Exchange] = None, persistent: bool = True, ) -> None: - if not self.is_ready: - await self.connect() - - if self._channel is None: - raise RuntimeError("Channel should be established after connect") + channel = await self._ensure_channel() if exchange: - exchange_obj = await self._channel.get_exchange(exchange.name) + exchange_obj = await channel.get_exchange(exchange.name) else: - exchange_obj = self._channel.default_exchange + exchange_obj = channel.default_exchange await exchange_obj.publish( aio_pika.Message( @@ -322,9 +359,23 @@ class AsyncRabbitMQ(RabbitMQBase): routing_key=routing_key, ) + @func_retry + async def publish_message( + self, + routing_key: str, + message: str, + exchange: Optional[Exchange] = None, + persistent: bool = True, + ) -> None: + try: + await self._publish_once(routing_key, message, exchange, persistent) + except aio_pika.exceptions.ChannelInvalidStateError: + logger.warning( + "RabbitMQ channel invalid, forcing reconnect and retrying publish" + ) + async with self._lock: + self._channel = None + await self._publish_once(routing_key, message, exchange, persistent) + async def get_channel(self) -> aio_pika.abc.AbstractChannel: - if not self.is_ready: - await self.connect() - if self._channel is None: - raise RuntimeError("Channel should be established after connect") - return self._channel + return await self._ensure_channel()