mirror of https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-01-29 17:08:01 -05:00
pr comments
@@ -21,7 +21,7 @@ from backend.data.rabbitmq import (
 from . import service as chat_service
 from . import stream_registry
-from .response_model import StreamError, StreamToolOutputAvailable
+from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
 from .tools.models import ErrorResponse
 
 logger = logging.getLogger(__name__)
@@ -96,38 +96,52 @@ class ChatCompletionConsumer:
         logger.info("Chat completion consumer stopped")
 
     async def _consume_messages(self) -> None:
-        """Main message consumption loop."""
-        if not self._rabbitmq:
-            logger.error("RabbitMQ not initialized")
-            return
-
-        try:
-            channel = await self._rabbitmq.get_channel()
-            queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
-
-            async with queue.iterator() as queue_iter:
-                async for message in queue_iter:
-                    if not self._running:
-                        break
-                    try:
-                        async with message.process():
-                            await self._handle_message(message.body)
-                    except Exception as e:
-                        logger.error(
-                            f"Error processing completion message: {e}",
-                            exc_info=True,
-                        )
-                        # Message will be requeued due to exception
-
-        except asyncio.CancelledError:
-            logger.info("Consumer cancelled")
-        except Exception as e:
-            logger.error(f"Consumer error: {e}", exc_info=True)
-            # Attempt to reconnect after a delay
-            if self._running:
-                await asyncio.sleep(5)
-                await self._consume_messages()
+        """Main message consumption loop with retry logic."""
+        max_retries = 10
+        retry_delay = 5  # seconds
+        retry_count = 0
+
+        while self._running and retry_count < max_retries:
+            if not self._rabbitmq:
+                logger.error("RabbitMQ not initialized")
+                return
+
+            try:
+                channel = await self._rabbitmq.get_channel()
+                queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
+
+                # Reset retry count on successful connection
+                retry_count = 0
+
+                async with queue.iterator() as queue_iter:
+                    async for message in queue_iter:
+                        if not self._running:
+                            return
+
+                        try:
+                            async with message.process():
+                                await self._handle_message(message.body)
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing completion message: {e}",
+                                exc_info=True,
+                            )
+                            # Message will be requeued due to exception
+
+            except asyncio.CancelledError:
+                logger.info("Consumer cancelled")
+                return
+            except Exception as e:
+                retry_count += 1
+                logger.error(
+                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
+                    exc_info=True,
+                )
+                if self._running and retry_count < max_retries:
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error("Max retries reached, stopping consumer")
+                    return
 
     async def _handle_message(self, body: bytes) -> None:
         """Handle a single completion message."""
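Note on the retry change above: the counter resets after every successful connection, so max_retries bounds consecutive failures rather than total failures over the consumer's lifetime. A minimal runnable sketch of that pattern, where the connect callable is an illustrative stand-in, not from the codebase:

import asyncio

async def consume_with_retries(connect, max_retries: int = 10, retry_delay: float = 5.0) -> None:
    # Capped retry loop: only *consecutive* failures count toward max_retries.
    retry_count = 0
    while retry_count < max_retries:
        try:
            await connect()  # stand-in for get_channel()/get_queue()
            retry_count = 0  # connected: reset the consecutive-failure count
            return           # stand-in for the blocking queue-iterator loop
        except asyncio.CancelledError:
            return           # shutdown requested; do not retry
        except Exception:
            retry_count += 1
            if retry_count >= max_retries:
                return       # give up after max_retries consecutive failures
            await asyncio.sleep(retry_delay)

attempts = iter([ConnectionError("down"), ConnectionError("down"), None])

async def flaky_connect() -> None:
    err = next(attempts)
    if err:
        raise err

asyncio.run(consume_with_retries(flaky_connect, retry_delay=0))  # succeeds on attempt 3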
@@ -206,8 +220,9 @@ class ChatCompletionConsumer:
             task_id=task.task_id,
         )
 
-        # Mark task as completed
+        # Mark task as completed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="completed")
+        await chat_service._mark_operation_completed(task.tool_call_id)
 
         logger.info(
             f"Successfully processed completion for task {task.task_id} "
@@ -222,11 +237,12 @@ class ChatCompletionConsumer:
         """Handle failed operation completion."""
         error_msg = message.error or "Operation failed"
 
-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task.task_id,
             StreamError(errorText=error_msg),
         )
+        await stream_registry.publish_chunk(task.task_id, StreamFinish())
 
         # Update pending operation with error
         error_response = ErrorResponse(
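Why the extra StreamFinish publishes matter: an error chunk alone does not terminate the stream, so clients blocked waiting for the next event would hang. A toy model of the contract, using an asyncio.Queue in place of the Redis-backed stream registry and modeling only the fields visible in this diff:

import asyncio
from dataclasses import dataclass

@dataclass
class StreamError:
    errorText: str

@dataclass
class StreamFinish:
    pass

async def fail_stream(stream: asyncio.Queue) -> None:
    await stream.put(StreamError(errorText="Operation failed"))
    await stream.put(StreamFinish())  # terminal event: lets readers exit

async def read_stream(stream: asyncio.Queue) -> None:
    while True:
        event = await stream.get()  # without StreamFinish, this await hangs forever
        if isinstance(event, StreamFinish):
            break
        print(event)

async def main() -> None:
    stream: asyncio.Queue = asyncio.Queue()
    await fail_stream(stream)
    await read_stream(stream)

asyncio.run(main())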
@@ -239,8 +255,9 @@ class ChatCompletionConsumer:
             result=error_response.model_dump_json(),
         )
 
-        # Mark task as failed
+        # Mark task as failed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="failed")
+        await chat_service._mark_operation_completed(task.tool_call_id)
 
         logger.info(
             f"Processed failure for task {task.task_id} "
@@ -96,6 +96,14 @@ class ChatConfig(BaseSettings):
             v = "https://openrouter.ai/api/v1"
         return v
 
+    @field_validator("internal_api_key", mode="before")
+    @classmethod
+    def get_internal_api_key(cls, v):
+        """Get internal API key from environment if not provided."""
+        if v is None:
+            v = os.getenv("CHAT_INTERNAL_API_KEY")
+        return v
+
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",
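The new validator backfills internal_api_key from the environment when the field is unset. A self-contained sketch of the same pattern, assuming pydantic v2 with pydantic-settings and omitting the surrounding ChatConfig fields:

import os

from pydantic import field_validator
from pydantic_settings import BaseSettings

class ChatConfig(BaseSettings):
    internal_api_key: str | None = None

    @field_validator("internal_api_key", mode="before")
    @classmethod
    def get_internal_api_key(cls, v):
        # mode="before" runs prior to type validation, so a None default
        # can be replaced with the environment value first.
        if v is None:
            v = os.getenv("CHAT_INTERNAL_API_KEY")
        return v

os.environ["CHAT_INTERNAL_API_KEY"] = "secret"
assert ChatConfig().internal_api_key == "secret"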
@@ -4,6 +4,7 @@ import logging
 from collections.abc import AsyncGenerator
 from typing import Annotated
 
+import orjson
 from autogpt_libs import auth
 from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
 from fastapi.responses import StreamingResponse
@@ -532,16 +533,17 @@ async def complete_operation(
     Raises:
         HTTPException: If API key is invalid or operation not found.
     """
-    # Validate internal API key
-    if config.internal_api_key:
-        if x_api_key != config.internal_api_key:
-            raise HTTPException(status_code=401, detail="Invalid API key")
-    else:
-        # If no internal API key is configured, log a warning
-        logger.warning(
-            "Operation complete webhook called without API key validation "
-            "(CHAT_INTERNAL_API_KEY not configured)"
-        )
+    # Validate internal API key - reject if not configured or invalid
+    if not config.internal_api_key:
+        logger.error(
+            "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
+        )
+        raise HTTPException(
+            status_code=503,
+            detail="Webhook not available: internal API key not configured",
+        )
+    if x_api_key != config.internal_api_key:
+        raise HTTPException(status_code=401, detail="Invalid API key")
 
     # Find task by operation_id
     task = await stream_registry.find_task_by_operation_id(operation_id)
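The webhook now fails closed: 503 when CHAT_INTERNAL_API_KEY is unconfigured, 401 on a key mismatch. A hypothetical caller, for illustration only; the URL, path, and payload shape below are assumptions, not taken from this diff:

import httpx

resp = httpx.post(
    "http://localhost:8006/api/chat/operations/op-123/complete",  # hypothetical path
    headers={"X-API-Key": "value-of-CHAT_INTERNAL_API_KEY"},
    json={"success": True, "result": {"status": "done"}},         # hypothetical payload
)
# 503 -> server has no CHAT_INTERNAL_API_KEY configured (fails closed)
# 401 -> key mismatch
# 200 -> completion accepted
print(resp.status_code)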
@@ -569,7 +571,7 @@ async def complete_operation(
                 output=(
                     result_output
                     if isinstance(result_output, str)
-                    else str(result_output)
+                    else orjson.dumps(result_output).decode("utf-8")
                 ),
                 success=True,
             ),
@@ -581,7 +583,11 @@ async def complete_operation(
         result_str = (
             request.result
             if isinstance(request.result, str)
-            else str(request.result) if request.result else '{"status": "completed"}'
+            else (
+                orjson.dumps(request.result).decode("utf-8")
+                if request.result
+                else '{"status": "completed"}'
+            )
         )
         await svc._update_pending_operation(
             session_id=task.session_id,
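The switch from str(...) to orjson.dumps(...).decode("utf-8") in the two hunks above matters because str() on a dict yields a Python repr, not JSON:

import orjson

result = {"status": "completed", "rows": 3}

print(str(result))                           # {'status': 'completed', 'rows': 3}  (not valid JSON)
print(orjson.dumps(result).decode("utf-8"))  # {"status":"completed","rows":3}     (valid JSON)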
@@ -596,8 +602,9 @@ async def complete_operation(
             task_id=task.task_id,
         )
 
-        # Mark task as completed
+        # Mark task as completed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="completed")
+        await svc._mark_operation_completed(task.tool_call_id)
     else:
         # Publish error to stream registry
         from .response_model import StreamError
@@ -607,6 +614,8 @@ async def complete_operation(
             task.task_id,
             StreamError(errorText=error_msg),
         )
+        # Send finish event to end the stream
+        await stream_registry.publish_chunk(task.task_id, StreamFinish())
 
         # Update pending operation with error
         from . import service as svc
@@ -622,8 +631,9 @@ async def complete_operation(
             result=error_response.model_dump_json(),
         )
 
-        # Mark task as failed
+        # Mark task as failed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="failed")
+        await svc._mark_operation_completed(task.tool_call_id)
 
     return {"status": "ok", "task_id": task.task_id}
@@ -1881,6 +1881,10 @@ async def _execute_long_running_tool_with_streaming(
     If the external service returns a 202 Accepted (async), this function exits
     early and lets the RabbitMQ completion consumer handle the rest.
     """
+    # Track whether we delegated to async processing - if so, the RabbitMQ
+    # completion consumer will handle cleanup, not us
+    delegated_to_async = False
+
     try:
         # Load fresh session (not stale reference)
         session = await get_chat_session(session_id, user_id)
@@ -1915,9 +1919,10 @@ async def _execute_long_running_tool_with_streaming(
                     f"(operation_id={operation_id}, task_id={task_id}). "
                     f"RabbitMQ completion consumer will handle the rest."
                 )
-                # Don't publish result, don't continue with LLM
+                # Don't publish result, don't continue with LLM, and don't cleanup
+                # The RabbitMQ consumer will handle everything when the external
+                # service completes and publishes to the queue
+                delegated_to_async = True
                 return
         except (orjson.JSONDecodeError, TypeError):
             pass  # Not JSON or not async - continue normally
@@ -1958,11 +1963,12 @@ async def _execute_long_running_tool_with_streaming(
             message=f"Tool {tool_name} failed: {str(e)}",
         )
 
-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task_id,
             StreamError(errorText=str(e)),
         )
+        await stream_registry.publish_chunk(task_id, StreamFinish())
 
         await _update_pending_operation(
             session_id=session_id,
@@ -1973,7 +1979,10 @@ async def _execute_long_running_tool_with_streaming(
         # Mark task as failed in stream registry
         await stream_registry.mark_task_completed(task_id, status="failed")
     finally:
-        await _mark_operation_completed(tool_call_id)
+        # Only cleanup if we didn't delegate to async processing
+        # For async path, the RabbitMQ completion consumer handles cleanup
+        if not delegated_to_async:
+            await _mark_operation_completed(tool_call_id)
 
 
 async def _update_pending_operation(
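The delegated_to_async flag exists because finally runs even after the early return on the 202 path; without the guard, the operation lock would be released while the RabbitMQ consumer still owns the lifecycle. A runnable sketch of the pattern, where call_external_service and release_operation_lock are stand-ins, not real helpers from the codebase:

import asyncio

async def call_external_service() -> bool:
    return True  # stand-in: pretend the service replied 202 Accepted

async def release_operation_lock() -> None:
    print("lock released")

async def run_tool() -> None:
    delegated_to_async = False
    try:
        if await call_external_service():
            delegated_to_async = True
            return  # finally still executes after this early return
    finally:
        # Skip cleanup on the async path; the completion consumer handles it.
        if not delegated_to_async:
            await release_operation_lock()

asyncio.run(run_tool())  # prints nothing: cleanup was deferred to the consumer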
@@ -2221,8 +2230,9 @@ async def _generate_llm_continuation_with_streaming(
         logger.error(
             f"Failed to generate streaming LLM continuation: {e}", exc_info=True
         )
-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task_id,
             StreamError(errorText=f"Failed to generate response: {e}"),
         )
+        await stream_registry.publish_chunk(task_id, StreamFinish())
@@ -22,6 +22,10 @@ import backend.api.features.admin.store_admin_routes
 import backend.api.features.builder
 import backend.api.features.builder.routes
 import backend.api.features.chat.routes as chat_routes
+from backend.api.features.chat.completion_consumer import (
+    start_completion_consumer,
+    stop_completion_consumer,
+)
 import backend.api.features.executions.review.routes
 import backend.api.features.library.db
 import backend.api.features.library.model
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
 
+    # Start chat completion consumer for RabbitMQ notifications
+    try:
+        await start_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Could not start chat completion consumer: {e}")
+
     with launch_darkly_context():
         yield
 
+    # Stop chat completion consumer
+    try:
+        await stop_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Error stopping chat completion consumer: {e}")
+
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e:
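For context, the lifespan wiring pattern the hunk above extends: code before yield runs at startup, code after it at shutdown, and consumer failures are deliberately non-fatal so the API still boots without RabbitMQ. A minimal sketch, with stub functions standing in for the real consumer imports:

import contextlib
import logging

import fastapi

logger = logging.getLogger(__name__)

async def start_completion_consumer() -> None: ...  # stand-in for the real import
async def stop_completion_consumer() -> None: ...   # stand-in for the real import

@contextlib.asynccontextmanager
async def lifespan_context(app: fastapi.FastAPI):
    try:
        await start_completion_consumer()  # startup: before yield
    except Exception as e:
        logger.warning(f"Could not start chat completion consumer: {e}")
    yield
    try:
        await stop_completion_consumer()   # shutdown: after yield
    except Exception as e:
        logger.warning(f"Error stopping chat completion consumer: {e}")

app = fastapi.FastAPI(lifespan=lifespan_context)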