diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py index 83ca543c0e..a336e68f7a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py +++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py @@ -1,27 +1,24 @@ -"""RabbitMQ consumer for operation completion messages. +"""Redis Streams consumer for operation completion messages. This module provides a consumer that listens for completion notifications from external services (like Agent Generator) and triggers the appropriate stream registry and chat service updates. -The consumer initializes its own Prisma client to avoid async context issues. +The consumer uses Redis Streams with consumer groups for reliable message +processing across multiple platform pods. """ import asyncio import logging import os +import uuid import orjson from prisma import Prisma from pydantic import BaseModel +from redis.exceptions import ResponseError -from backend.data.rabbitmq import ( - AsyncRabbitMQ, - Exchange, - ExchangeType, - Queue, - RabbitMQConfig, -) +from backend.data.redis_client import get_redis_async from . import service as chat_service from . 
import stream_registry @@ -30,24 +27,10 @@ from .tools.models import ErrorResponse logger = logging.getLogger(__name__) -# Queue and exchange configuration -OPERATION_COMPLETE_EXCHANGE = Exchange( - name="chat_operations", - type=ExchangeType.DIRECT, - durable=True, -) - -OPERATION_COMPLETE_QUEUE = Queue( - name="chat_operation_complete", - durable=True, - exchange=OPERATION_COMPLETE_EXCHANGE, - routing_key="operation.complete", -) - -RABBITMQ_CONFIG = RabbitMQConfig( - exchanges=[OPERATION_COMPLETE_EXCHANGE], - queues=[OPERATION_COMPLETE_QUEUE], -) +# Stream configuration +COMPLETION_STREAM = "chat:completions" +CONSUMER_GROUP = "chat_consumers" +STREAM_MAX_LENGTH = 10000 class OperationCompleteMessage(BaseModel): @@ -61,17 +44,20 @@ class OperationCompleteMessage(BaseModel): class ChatCompletionConsumer: - """Consumer for chat operation completion messages from RabbitMQ. + """Consumer for chat operation completion messages from Redis Streams. This consumer initializes its own Prisma client in start() to ensure database operations work correctly within this async context. + + Uses a Redis consumer group so multiple platform pods can share the stream. + Messages whose handling fails are left unacknowledged (pending), not retried. 
""" def __init__(self): - self._rabbitmq: AsyncRabbitMQ | None = None self._consumer_task: asyncio.Task | None = None self._running = False self._prisma: Prisma | None = None + self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" async def start(self) -> None: """Start the completion consumer.""" @@ -79,15 +65,29 @@ class ChatCompletionConsumer: logger.warning("Completion consumer already running") return - # Don't initialize Prisma here - do it lazily on first message - # to ensure it's in the same async context as the message handler - - self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG) - await self._rabbitmq.connect() + # Create consumer group if it doesn't exist + try: + redis = await get_redis_async() + await redis.xgroup_create( + COMPLETION_STREAM, + CONSUMER_GROUP, + id="0", + mkstream=True, + ) + logger.info( + f"Created consumer group '{CONSUMER_GROUP}' on stream '{COMPLETION_STREAM}'" + ) + except ResponseError as e: + if "BUSYGROUP" in str(e): + logger.debug(f"Consumer group '{CONSUMER_GROUP}' already exists") + else: + raise self._running = True self._consumer_task = asyncio.create_task(self._consume_messages()) - logger.info("Chat completion consumer started") + logger.info( + f"Chat completion consumer started (consumer: {self._consumer_name})" + ) async def _ensure_prisma(self) -> Prisma: """Lazily initialize Prisma client on first use.""" @@ -110,10 +110,6 @@ class ChatCompletionConsumer: pass self._consumer_task = None - if self._rabbitmq: - await self._rabbitmq.disconnect() - self._rabbitmq = None - if self._prisma: await self._prisma.disconnect() self._prisma = None @@ -126,33 +122,54 @@ class ChatCompletionConsumer: max_retries = 10 retry_delay = 5 # seconds retry_count = 0 + block_timeout = 5000 # milliseconds while self._running and retry_count < max_retries: - if not self._rabbitmq: - logger.error("RabbitMQ not initialized") - return - try: - channel = await self._rabbitmq.get_channel() - queue = await 
channel.get_queue(OPERATION_COMPLETE_QUEUE.name) + redis = await get_redis_async() # Reset retry count on successful connection retry_count = 0 - async with queue.iterator() as queue_iter: - async for message in queue_iter: - if not self._running: - return + while self._running: + # Read new messages from the stream + messages = await redis.xreadgroup( + groupname=CONSUMER_GROUP, + consumername=self._consumer_name, + streams={COMPLETION_STREAM: ">"}, + block=block_timeout, + count=10, + ) - try: - async with message.process(): - await self._handle_message(message.body) - except Exception as e: - logger.error( - f"Error processing completion message: {e}", - exc_info=True, - ) - # Message will be requeued due to exception + if not messages: + continue + + for stream_name, entries in messages: + for entry_id, data in entries: + if not self._running: + return + + try: + # Handle the message + message_data = data.get("data") + if message_data: + await self._handle_message( + message_data.encode() + if isinstance(message_data, str) + else message_data + ) + + # Acknowledge the message + await redis.xack( + COMPLETION_STREAM, CONSUMER_GROUP, entry_id + ) + except Exception as e: + logger.error( + f"Error processing completion message {entry_id}: {e}", + exc_info=True, + ) + # Not acked: the entry stays in the pending list. Reading with ">" + # never returns it again; it must be claimed (XAUTOCLAIM) to retry except asyncio.CancelledError: logger.info("Consumer cancelled") @@ -363,7 +380,7 @@ async def publish_operation_complete( result: dict | str | None = None, error: str | None = None, ) -> None: - """Publish an operation completion message. + """Publish an operation completion message to Redis Streams. Args: operation_id: The operation ID that completed. 
@@ -380,14 +397,10 @@ async def publish_operation_complete( error=error, ) - rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG) - try: - await rabbitmq.connect() - await rabbitmq.publish_message( - routing_key="operation.complete", - message=message.model_dump_json(), - exchange=OPERATION_COMPLETE_EXCHANGE, - ) - logger.info(f"Published completion for operation {operation_id}") - finally: - await rabbitmq.disconnect() + redis = await get_redis_async() + await redis.xadd( + COMPLETION_STREAM, + {"data": message.model_dump_json()}, + maxlen=STREAM_MAX_LENGTH, + ) + logger.info(f"Published completion for operation {operation_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 4ddc1b33ef..b9d7a4aba7 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -209,14 +209,28 @@ async def get_session( session_id, user_id ) if active_task: + # Filter out the in-progress assistant message from the session response. + # The client will receive the complete assistant response through the SSE + # stream replay instead, preventing duplicate content. + if messages and messages[-1].get("role") == "assistant": + original_count = len(messages) + messages = messages[:-1] + logger.info( + f"[SSE-RECONNECT] Filtered out in-progress assistant message " + f"(was {original_count} messages, now {len(messages)})" + ) + + # Use "0-0" as last_message_id to replay the stream from the beginning. + # Since we filtered out the cached assistant message, the client needs + # the full stream to reconstruct the response. 
active_stream_info = ActiveStreamInfo( task_id=active_task.task_id, - last_message_id=last_message_id, + last_message_id="0-0", ) logger.info( f"[SSE-RECONNECT] Session {session_id} HAS active stream: " f"task_id={active_task.task_id}, status={active_task.status}, " - f"last_message_id={last_message_id}" + f"last_message_id=0-0 (replay from start)" ) else: logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream") diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index 45f3fc233c..8fe12692f0 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -211,9 +211,7 @@ async def subscribe_to_task( task_status = meta.get("status", "") task_user_id = meta.get("user_id", "") or None - logger.info( - f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}" - ) + logger.info(f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}") # Validate ownership if user_id and task_user_id and task_user_id != user_id: @@ -256,9 +254,7 @@ async def subscribe_to_task( logger.info( f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener" ) - asyncio.create_task( - _stream_listener(task_id, subscriber_queue, replay_last_id) - ) + asyncio.create_task(_stream_listener(task_id, subscriber_queue, replay_last_id)) else: # Task is completed/failed - add finish marker logger.info( @@ -470,9 +466,7 @@ async def get_active_task_for_session( tasks_checked = 0 while True: - cursor, keys = await redis.scan( - cursor, match=f"{TASK_META_PREFIX}*", count=100 - ) + cursor, keys = await redis.scan(cursor, match=f"{TASK_META_PREFIX}*", count=100) for key in keys: tasks_checked += 1 diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py 
b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index 25dc05f22a..66d02f8f3e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -842,5 +842,9 @@ async def generate_agent_patch( _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent_patch") return await generate_agent_patch_external( - update_request, current_agent, _to_dict_list(library_agents), operation_id, task_id + update_request, + current_agent, + _to_dict_list(library_agents), + operation_id, + task_id, )