pr comments

Swifty
2026-01-29 22:29:17 +01:00
parent 46af3b94f2
commit bb608ea60d
5 changed files with 110 additions and 49 deletions

View File

@@ -21,7 +21,7 @@ from backend.data.rabbitmq import (
 from . import service as chat_service
 from . import stream_registry
-from .response_model import StreamError, StreamToolOutputAvailable
+from .response_model import StreamError, StreamFinish, StreamToolOutputAvailable
 from .tools.models import ErrorResponse
 
 logger = logging.getLogger(__name__)
@@ -96,38 +96,52 @@ class ChatCompletionConsumer:
         logger.info("Chat completion consumer stopped")
 
     async def _consume_messages(self) -> None:
-        """Main message consumption loop."""
-        if not self._rabbitmq:
-            logger.error("RabbitMQ not initialized")
-            return
+        """Main message consumption loop with retry logic."""
+        max_retries = 10
+        retry_delay = 5  # seconds
+        retry_count = 0
-        try:
-            channel = await self._rabbitmq.get_channel()
-            queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
+        while self._running and retry_count < max_retries:
+            if not self._rabbitmq:
+                logger.error("RabbitMQ not initialized")
+                return
-            async with queue.iterator() as queue_iter:
-                async for message in queue_iter:
-                    if not self._running:
-                        break
-                    try:
-                        async with message.process():
-                            await self._handle_message(message.body)
-                    except Exception as e:
-                        logger.error(
-                            f"Error processing completion message: {e}",
-                            exc_info=True,
-                        )
-                        # Message will be requeued due to exception
-        except asyncio.CancelledError:
-            logger.info("Consumer cancelled")
-        except Exception as e:
-            logger.error(f"Consumer error: {e}", exc_info=True)
-            # Attempt to reconnect after a delay
-            if self._running:
-                await asyncio.sleep(5)
-                await self._consume_messages()
+            try:
+                channel = await self._rabbitmq.get_channel()
+                queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
+
+                # Reset retry count on successful connection
+                retry_count = 0
+
+                async with queue.iterator() as queue_iter:
+                    async for message in queue_iter:
+                        if not self._running:
+                            return
+                        try:
+                            async with message.process():
+                                await self._handle_message(message.body)
+                        except Exception as e:
+                            logger.error(
+                                f"Error processing completion message: {e}",
+                                exc_info=True,
+                            )
+                            # Message will be requeued due to exception
+            except asyncio.CancelledError:
+                logger.info("Consumer cancelled")
+                return
+            except Exception as e:
+                retry_count += 1
+                logger.error(
+                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
+                    exc_info=True,
+                )
+                if self._running and retry_count < max_retries:
+                    await asyncio.sleep(retry_delay)
+                else:
+                    logger.error("Max retries reached, stopping consumer")
+                    return
 
     async def _handle_message(self, body: bytes) -> None:
         """Handle a single completion message."""
@@ -206,8 +220,9 @@ class ChatCompletionConsumer:
             task_id=task.task_id,
         )
 
-        # Mark task as completed
+        # Mark task as completed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="completed")
+        await chat_service._mark_operation_completed(task.tool_call_id)
 
         logger.info(
             f"Successfully processed completion for task {task.task_id} "
@@ -222,11 +237,12 @@ class ChatCompletionConsumer:
         """Handle failed operation completion."""
         error_msg = message.error or "Operation failed"
 
-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task.task_id,
             StreamError(errorText=error_msg),
         )
+        await stream_registry.publish_chunk(task.task_id, StreamFinish())
 
         # Update pending operation with error
         error_response = ErrorResponse(
@@ -239,8 +255,9 @@ class ChatCompletionConsumer:
             result=error_response.model_dump_json(),
         )
 
-        # Mark task as failed
+        # Mark task as failed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="failed")
+        await chat_service._mark_operation_completed(task.tool_call_id)
 
         logger.info(
             f"Processed failure for task {task.task_id} "

View File

@@ -96,6 +96,14 @@ class ChatConfig(BaseSettings):
             v = "https://openrouter.ai/api/v1"
         return v
 
+    @field_validator("internal_api_key", mode="before")
+    @classmethod
+    def get_internal_api_key(cls, v):
+        """Get internal API key from environment if not provided."""
+        if v is None:
+            v = os.getenv("CHAT_INTERNAL_API_KEY")
+        return v
+
     # Prompt paths for different contexts
     PROMPT_PATHS: dict[str, str] = {
         "default": "prompts/chat_system.md",

View File

@@ -4,6 +4,7 @@ import logging
 from collections.abc import AsyncGenerator
 from typing import Annotated
 
+import orjson
 from autogpt_libs import auth
 from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
 from fastapi.responses import StreamingResponse
@@ -532,16 +533,17 @@ async def complete_operation(
     Raises:
         HTTPException: If API key is invalid or operation not found.
     """
-    # Validate internal API key
-    if config.internal_api_key:
-        if x_api_key != config.internal_api_key:
-            raise HTTPException(status_code=401, detail="Invalid API key")
-    else:
-        # If no internal API key is configured, log a warning
-        logger.warning(
-            "Operation complete webhook called without API key validation "
-            "(CHAT_INTERNAL_API_KEY not configured)"
-        )
+    # Validate internal API key - reject if not configured or invalid
+    if not config.internal_api_key:
+        logger.error(
+            "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
+        )
+        raise HTTPException(
+            status_code=503,
+            detail="Webhook not available: internal API key not configured",
+        )
+    if x_api_key != config.internal_api_key:
+        raise HTTPException(status_code=401, detail="Invalid API key")
 
     # Find task by operation_id
     task = await stream_registry.find_task_by_operation_id(operation_id)
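The reordered check means a missing CHAT_INTERNAL_API_KEY now fails closed (503) instead of silently skipping auth, and a wrong key still gets 401. What a caller observes, sketched with httpx against a hypothetical URL and payload (the real route path and request model aren't shown in this diff):

import httpx

# Hypothetical URL and body - placeholders for the real webhook route
resp = httpx.post(
    "http://localhost:8006/api/chat/operations/op-123/complete",
    headers={"X-API-Key": "wrong-key"},
    json={"success": True, "result": {"status": "done"}},
)
# 503 if CHAT_INTERNAL_API_KEY is unset on the server,
# 401 if it is set and the header doesn't match,
# 200 with {"status": "ok", "task_id": ...} on a valid key
print(resp.status_code)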
@@ -569,7 +571,7 @@ async def complete_operation(
             output=(
                 result_output
                 if isinstance(result_output, str)
-                else str(result_output)
+                else orjson.dumps(result_output).decode("utf-8")
             ),
             success=True,
         ),
@@ -581,7 +583,11 @@ async def complete_operation(
         result_str = (
             request.result
             if isinstance(request.result, str)
-            else str(request.result) if request.result else '{"status": "completed"}'
+            else (
+                orjson.dumps(request.result).decode("utf-8")
+                if request.result
+                else '{"status": "completed"}'
+            )
         )
         await svc._update_pending_operation(
             session_id=task.session_id,
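The str() to orjson.dumps(...).decode("utf-8") changes in this file fix a subtle bug: str() on a dict yields Python repr with single quotes, which is not valid JSON and breaks any consumer that later parses the stored result. For example:

import json

import orjson

payload = {"status": "completed", "count": 2}

str(payload)                           # "{'status': 'completed', 'count': 2}" - not JSON
orjson.dumps(payload).decode("utf-8")  # '{"status":"completed","count":2}' - valid JSON

json.loads(orjson.dumps(payload).decode("utf-8"))  # round-trips fine
# json.loads(str(payload))             # would raise json.JSONDecodeError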
@@ -596,8 +602,9 @@ async def complete_operation(
             task_id=task.task_id,
         )
 
-        # Mark task as completed
+        # Mark task as completed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="completed")
+        await svc._mark_operation_completed(task.tool_call_id)
     else:
         # Publish error to stream registry
         from .response_model import StreamError
@@ -607,6 +614,8 @@ async def complete_operation(
             task.task_id,
             StreamError(errorText=error_msg),
         )
+        # Send finish event to end the stream
+        await stream_registry.publish_chunk(task.task_id, StreamFinish())
 
         # Update pending operation with error
         from . import service as svc
@@ -622,8 +631,9 @@ async def complete_operation(
             result=error_response.model_dump_json(),
         )
 
-        # Mark task as failed
+        # Mark task as failed and release Redis lock
         await stream_registry.mark_task_completed(task.task_id, status="failed")
+        await svc._mark_operation_completed(task.tool_call_id)
 
     return {"status": "ok", "task_id": task.task_id}

View File

@@ -1881,6 +1881,10 @@ async def _execute_long_running_tool_with_streaming(
     If the external service returns a 202 Accepted (async), this function exits
     early and lets the RabbitMQ completion consumer handle the rest.
     """
+    # Track whether we delegated to async processing - if so, the RabbitMQ
+    # completion consumer will handle cleanup, not us
+    delegated_to_async = False
+
     try:
         # Load fresh session (not stale reference)
         session = await get_chat_session(session_id, user_id)
@@ -1915,9 +1919,10 @@ async def _execute_long_running_tool_with_streaming(
                     f"(operation_id={operation_id}, task_id={task_id}). "
                    f"RabbitMQ completion consumer will handle the rest."
                 )
-                # Don't publish result, don't continue with LLM
+                # Don't publish result, don't continue with LLM, and don't cleanup
+                # The RabbitMQ consumer will handle everything when the external
+                # service completes and publishes to the queue
+                delegated_to_async = True
                 return
         except (orjson.JSONDecodeError, TypeError):
             pass  # Not JSON or not async - continue normally
@@ -1958,11 +1963,12 @@ async def _execute_long_running_tool_with_streaming(
             message=f"Tool {tool_name} failed: {str(e)}",
         )
 
-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task_id,
             StreamError(errorText=str(e)),
         )
+        await stream_registry.publish_chunk(task_id, StreamFinish())
 
         await _update_pending_operation(
             session_id=session_id,
@@ -1973,7 +1979,10 @@ async def _execute_long_running_tool_with_streaming(
         # Mark task as failed in stream registry
         await stream_registry.mark_task_completed(task_id, status="failed")
     finally:
-        await _mark_operation_completed(tool_call_id)
+        # Only cleanup if we didn't delegate to async processing
+        # For async path, the RabbitMQ completion consumer handles cleanup
+        if not delegated_to_async:
+            await _mark_operation_completed(tool_call_id)
 
 
 async def _update_pending_operation(
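finally runs even on the early return of the 202-Accepted path, so before this change the operation lock was released while the external service was still working, letting a duplicate operation start. The flag makes cleanup conditional on still owning the operation. The control flow, reduced to a sketch with hypothetical helpers:

async def run_tool(tool_call_id: str) -> None:
    delegated_to_async = False
    try:
        response = await call_external_service()  # hypothetical helper
        if response.status_code == 202:
            delegated_to_async = True  # the consumer owns cleanup now
            return                     # the finally block still executes
        await publish_result(response)  # hypothetical helper
    finally:
        if not delegated_to_async:
            await release_operation_lock(tool_call_id)  # only if we still own it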
@@ -2221,8 +2230,9 @@ async def _generate_llm_continuation_with_streaming(
         logger.error(
             f"Failed to generate streaming LLM continuation: {e}", exc_info=True
         )
-        # Publish error to stream registry
+        # Publish error to stream registry followed by finish event
         await stream_registry.publish_chunk(
             task_id,
             StreamError(errorText=f"Failed to generate response: {e}"),
         )
+        await stream_registry.publish_chunk(task_id, StreamFinish())

View File

@@ -22,6 +22,10 @@ import backend.api.features.admin.store_admin_routes
 import backend.api.features.builder
 import backend.api.features.builder.routes
 import backend.api.features.chat.routes as chat_routes
+from backend.api.features.chat.completion_consumer import (
+    start_completion_consumer,
+    stop_completion_consumer,
+)
 import backend.api.features.executions.review.routes
 import backend.api.features.library.db
 import backend.api.features.library.model
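These imports feed the lifespan hunk below, which starts the consumer at boot and stops it at shutdown, both best-effort so a broker problem degrades chat completions rather than blocking the whole API. A minimal sketch of the same wiring in a generic FastAPI app, assuming the two helpers manage their background task internally:

import contextlib
import logging

import fastapi

from backend.api.features.chat.completion_consumer import (
    start_completion_consumer,
    stop_completion_consumer,
)

logger = logging.getLogger(__name__)

@contextlib.asynccontextmanager
async def lifespan(app: fastapi.FastAPI):
    # Best-effort startup: a broken broker degrades chat, not the whole API
    try:
        await start_completion_consumer()
    except Exception as e:
        logger.warning(f"Could not start chat completion consumer: {e}")
    yield
    # Mirror-image shutdown, also best-effort
    try:
        await stop_completion_consumer()
    except Exception as e:
        logger.warning(f"Error stopping chat completion consumer: {e}")

app = fastapi.FastAPI(lifespan=lifespan)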
@@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI):
     await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
     await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
 
+    # Start chat completion consumer for RabbitMQ notifications
+    try:
+        await start_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Could not start chat completion consumer: {e}")
+
     with launch_darkly_context():
         yield
 
+    # Stop chat completion consumer
+    try:
+        await stop_completion_consumer()
+    except Exception as e:
+        logger.warning(f"Error stopping chat completion consumer: {e}")
+
     try:
         await shutdown_cloud_storage_handler()
     except Exception as e: