diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py
index 4ad7afb9e9..f53b4673f3 100644
--- a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py
+++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py
@@ -21,7 +21,7 @@ from backend.data.rabbitmq import (
 from . import service as chat_service
 from . import stream_registry
-from .response_model import StreamError, StreamToolOutputAvailable
+from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
 from .tools.models import ErrorResponse

 logger = logging.getLogger(__name__)
@@ -96,38 +96,52 @@ class ChatCompletionConsumer:
         logger.info("Chat completion consumer stopped")

     async def _consume_messages(self) -> None:
-        """Main message consumption loop."""
-        if not self._rabbitmq:
-            logger.error("RabbitMQ not initialized")
-            return
+        """Main message consumption loop with retry logic."""
+        max_retries = 10
+        retry_delay = 5  # seconds
+        retry_count = 0

-        try:
-            channel = await self._rabbitmq.get_channel()
-            queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
+        while self._running and retry_count < max_retries:
+            if not self._rabbitmq:
+                logger.error("RabbitMQ not initialized")
+                return

-            async with queue.iterator() as queue_iter:
-                async for message in queue_iter:
-                    if not self._running:
-                        break
+            try:
+                channel = await self._rabbitmq.get_channel()
+                queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)

-                    try:
-                        async with message.process():
-                            await self._handle_message(message.body)
-                    except Exception as e:
-                        logger.error(
-                            f"Error processing completion message: {e}",
-                            exc_info=True,
-                        )
-                        # Message will be requeued due to exception
+                # Reset retry count on successful connection
+                retry_count = 0

-        except asyncio.CancelledError:
-            logger.info("Consumer cancelled")
-        except Exception as e:
-            logger.error(f"Consumer error: {e}", exc_info=True)
-            # Attempt to reconnect after a delay
-            if self._running:
-                await asyncio.sleep(5)
-                await self._consume_messages()
+                async with queue.iterator() as queue_iter:
+                    async for message in queue_iter:
+                        if not self._running:
+                            return
+
+                        try:
+                            async with message.process():
+                                await self._handle_message(message.body)
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing completion message: {e}",
+                                exc_info=True,
+                            )
+                            # Message will be requeued due to exception
+
+            except asyncio.CancelledError:
+                logger.info("Consumer cancelled")
+                return
+            except Exception as e:
+                retry_count += 1
+                logger.error(
+                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
+                    exc_info=True,
+                )
+                if self._running and retry_count < max_retries:
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error("Max retries reached, stopping consumer")
+                    return

     async def _handle_message(self, body: bytes) -> None:
         """Handle a single completion message."""
@@ -206,8 +220,9 @@ class ChatCompletionConsumer:
             task_id=task.task_id,
         )

-        # Mark task as completed
+        # Mark task as completed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="completed")
+        await chat_service._mark_operation_completed(task.tool_call_id)

         logger.info(
             f"Successfully processed completion for task {task.task_id} "
@@ -222,11 +237,12 @@ class ChatCompletionConsumer:
         """Handle failed operation completion."""
         error_msg = message.error or "Operation failed"

-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task.task_id,
             StreamError(errorText=error_msg),
         )
+        await stream_registry.publish_chunk(task.task_id, StreamFinish())

         # Update pending operation with error
         error_response = ErrorResponse(
@@ -239,8 +255,9 @@ class ChatCompletionConsumer:
             result=error_response.model_dump_json(),
         )

-        # Mark task as failed
+        # Mark task as failed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="failed")
+        await chat_service._mark_operation_completed(task.tool_call_id)

         logger.info(
             f"Processed failure for task {task.task_id} "
diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py
index 7f58a08d6a..2b1692c026 100644
--- a/autogpt_platform/backend/backend/api/features/chat/config.py
+++ b/autogpt_platform/backend/backend/api/features/chat/config.py
@@ -96,6 +96,14 @@ class ChatConfig(BaseSettings):
             v = "https://openrouter.ai/api/v1"
         return v

+    @field_validator("internal_api_key", mode="before")
+    @classmethod
+    def get_internal_api_key(cls, v):
+        """Get internal API key from environment if not provided."""
+        if v is None:
+            v = os.getenv("CHAT_INTERNAL_API_KEY")
+        return v
+
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",
diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py
index 3b98d7e542..14d5a41482 100644
--- a/autogpt_platform/backend/backend/api/features/chat/routes.py
+++ b/autogpt_platform/backend/backend/api/features/chat/routes.py
@@ -4,6 +4,7 @@ import logging
 from collections.abc import AsyncGenerator
 from typing import Annotated

+import orjson
 from autogpt_libs import auth
 from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
 from fastapi.responses import StreamingResponse
@@ -532,16 +533,17 @@ async def complete_operation(
     Raises:
         HTTPException: If API key is invalid or operation not found.
""" - # Validate internal API key - if config.internal_api_key: - if x_api_key != config.internal_api_key: - raise HTTPException(status_code=401, detail="Invalid API key") - else: - # If no internal API key is configured, log a warning - logger.warning( - "Operation complete webhook called without API key validation " - "(CHAT_INTERNAL_API_KEY not configured)" + # Validate internal API key - reject if not configured or invalid + if not config.internal_api_key: + logger.error( + "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured" ) + raise HTTPException( + status_code=503, + detail="Webhook not available: internal API key not configured", + ) + if x_api_key != config.internal_api_key: + raise HTTPException(status_code=401, detail="Invalid API key") # Find task by operation_id task = await stream_registry.find_task_by_operation_id(operation_id) @@ -569,7 +571,7 @@ async def complete_operation( output=( result_output if isinstance(result_output, str) - else str(result_output) + else orjson.dumps(result_output).decode("utf-8") ), success=True, ), @@ -581,7 +583,11 @@ async def complete_operation( result_str = ( request.result if isinstance(request.result, str) - else str(request.result) if request.result else '{"status": "completed"}' + else ( + orjson.dumps(request.result).decode("utf-8") + if request.result + else '{"status": "completed"}' + ) ) await svc._update_pending_operation( session_id=task.session_id, @@ -596,8 +602,9 @@ async def complete_operation( task_id=task.task_id, ) - # Mark task as completed + # Mark task as completed and release Redis lock await stream_registry.mark_task_completed(task.task_id, status="completed") + await svc._mark_operation_completed(task.tool_call_id) else: # Publish error to stream registry from .response_model import StreamError @@ -607,6 +614,8 @@ async def complete_operation( task.task_id, StreamError(errorText=error_msg), ) + # Send finish event to end the stream + await stream_registry.publish_chunk(task.task_id, StreamFinish()) # Update pending operation with error from . import service as svc @@ -622,8 +631,9 @@ async def complete_operation( result=error_response.model_dump_json(), ) - # Mark task as failed + # Mark task as failed and release Redis lock await stream_registry.mark_task_completed(task.task_id, status="failed") + await svc._mark_operation_completed(task.tool_call_id) return {"status": "ok", "task_id": task.task_id} diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 9b5085038a..49f52f7668 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -1881,6 +1881,10 @@ async def _execute_long_running_tool_with_streaming( If the external service returns a 202 Accepted (async), this function exits early and lets the RabbitMQ completion consumer handle the rest. """ + # Track whether we delegated to async processing - if so, the RabbitMQ + # completion consumer will handle cleanup, not us + delegated_to_async = False + try: # Load fresh session (not stale reference) session = await get_chat_session(session_id, user_id) @@ -1915,9 +1919,10 @@ async def _execute_long_running_tool_with_streaming( f"(operation_id={operation_id}, task_id={task_id}). " f"RabbitMQ completion consumer will handle the rest." 
                 )
-                # Don't publish result, don't continue with LLM
+                # Don't publish result, don't continue with LLM, and don't cleanup
                 # The RabbitMQ consumer will handle everything when the external
                 # service completes and publishes to the queue
+                delegated_to_async = True
                 return
         except (orjson.JSONDecodeError, TypeError):
             pass  # Not JSON or not async - continue normally
@@ -1958,11 +1963,12 @@
             message=f"Tool {tool_name} failed: {str(e)}",
         )

-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task_id,
             StreamError(errorText=str(e)),
         )
+        await stream_registry.publish_chunk(task_id, StreamFinish())

         await _update_pending_operation(
             session_id=session_id,
@@ -1973,7 +1979,10 @@
         # Mark task as failed in stream registry
         await stream_registry.mark_task_completed(task_id, status="failed")
     finally:
-        await _mark_operation_completed(tool_call_id)
+        # Only cleanup if we didn't delegate to async processing
+        # For async path, the RabbitMQ completion consumer handles cleanup
+        if not delegated_to_async:
+            await _mark_operation_completed(tool_call_id)


 async def _update_pending_operation(
@@ -2221,8 +2230,9 @@ async def _generate_llm_continuation_with_streaming(
         logger.error(
             f"Failed to generate streaming LLM continuation: {e}", exc_info=True
         )
-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task_id,
             StreamError(errorText=f"Failed to generate response: {e}"),
         )
+        await stream_registry.publish_chunk(task_id, StreamFinish())
diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py
index b936312ce1..824211be5f 100644
--- a/autogpt_platform/backend/backend/api/rest_api.py
+++ b/autogpt_platform/backend/backend/api/rest_api.py
@@ -22,6 +22,10 @@ import backend.api.features.admin.store_admin_routes
 import backend.api.features.builder
 import backend.api.features.builder.routes
 import backend.api.features.chat.routes as chat_routes
+from backend.api.features.chat.completion_consumer import (
+    start_completion_consumer,
+    stop_completion_consumer,
+)
 import backend.api.features.executions.review.routes
 import backend.api.features.library.db
 import backend.api.features.library.model
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()

+    # Start chat completion consumer for RabbitMQ notifications
+    try:
+        await start_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Could not start chat completion consumer: {e}")
+
     with launch_darkly_context():
         yield

+    # Stop chat completion consumer
+    try:
+        await stop_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Error stopping chat completion consumer: {e}")
+
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e: