diff --git a/autogpt_platform/backend/backend/data/rabbitmq.py b/autogpt_platform/backend/backend/data/rabbitmq.py index bdf2090083..96070b53df 100644 --- a/autogpt_platform/backend/backend/data/rabbitmq.py +++ b/autogpt_platform/backend/backend/data/rabbitmq.py @@ -291,6 +291,18 @@ class AsyncRabbitMQ(RabbitMQBase): exchange, routing_key=queue.routing_key or queue.name ) + async def _ensure_channel(self) -> aio_pika.abc.AbstractChannel: + """Get a valid channel, reconnecting if the current one is stale.""" + if not self.is_ready: + # Reset stale channel so connect() creates a fresh one + self._channel = None + await self.connect() + + if self._channel is None: + raise RuntimeError("Channel should be established after connect") + + return self._channel + @func_retry async def publish_message( self, @@ -299,32 +311,58 @@ class AsyncRabbitMQ(RabbitMQBase): exchange: Optional[Exchange] = None, persistent: bool = True, ) -> None: - if not self.is_ready: + try: + channel = await self._ensure_channel() + except Exception: + # Force full reconnect on channel acquisition failure + self._channel = None + self._connection = None + channel = await self._ensure_channel() + + try: + if exchange: + exchange_obj = await channel.get_exchange(exchange.name) + else: + exchange_obj = channel.default_exchange + + await exchange_obj.publish( + aio_pika.Message( + body=message.encode(), + delivery_mode=( + aio_pika.DeliveryMode.PERSISTENT + if persistent + else aio_pika.DeliveryMode.NOT_PERSISTENT + ), + ), + routing_key=routing_key, + ) + except aio_pika.exceptions.ChannelInvalidStateError: + logger.warning( + "RabbitMQ channel invalid, reconnecting and retrying publish" + ) + self._channel = None + self._connection = None await self.connect() - if self._channel is None: - raise RuntimeError("Channel should be established after connect") + if self._channel is None: + raise RuntimeError("Channel should be established after reconnect") - if exchange: - exchange_obj = await self._channel.get_exchange(exchange.name) - else: - exchange_obj = self._channel.default_exchange + if exchange: + exchange_obj = await self._channel.get_exchange(exchange.name) + else: + exchange_obj = self._channel.default_exchange - await exchange_obj.publish( - aio_pika.Message( - body=message.encode(), - delivery_mode=( - aio_pika.DeliveryMode.PERSISTENT - if persistent - else aio_pika.DeliveryMode.NOT_PERSISTENT + await exchange_obj.publish( + aio_pika.Message( + body=message.encode(), + delivery_mode=( + aio_pika.DeliveryMode.PERSISTENT + if persistent + else aio_pika.DeliveryMode.NOT_PERSISTENT + ), ), - ), - routing_key=routing_key, - ) + routing_key=routing_key, + ) async def get_channel(self) -> aio_pika.abc.AbstractChannel: - if not self.is_ready: - await self.connect() - if self._channel is None: - raise RuntimeError("Channel should be established after connect") - return self._channel + return await self._ensure_channel()