Uncouple Copilot task execution from the REST API server. This should improve performance and scalability, and it allows task execution to continue regardless of the state of the user's connection.

- Resolves #12023

### Changes 🏗️

- Add `backend.copilot.executor` -> `CoPilotExecutor` (set up similarly to `backend.executor` -> `ExecutionManager`). This executor service uses RabbitMQ-based task distribution and sticks with the existing Redis Streams setup for task output. It uses a cluster lock mechanism to ensure a task is only executed by one pod, and the `DatabaseManager` for pooled DB access.
- Add `backend.data.db_accessors` for automatic choice of direct/proxied DB access

Chat requests now flow: API → RabbitMQ → CoPilot Executor → Redis Streams → SSE Client (the completion-notification leg of this flow is sketched below). This enables horizontal scaling of chat processing and isolates long-running LLM operations from the API service.

- Move non-API Copilot code into `backend.copilot` (from `backend.api.features.chat`)
  - Updated import paths for all usages
- Move `backend.executor.database` to `backend.data.db_manager` and add methods for the copilot executor
  - Updated import paths for all usages
- Make `backend.copilot.db` RPC-compatible (-> DB ops return ~~Prisma~~ Pydantic models)
- Make `backend.data.workspace` RPC-compatible
- Make `backend.data.graphs.get_store_listed_graphs` RPC-compatible

DX:

- Add `copilot_executor` service to Docker setup

Config:

- Add `Config.num_copilot_workers` (default 5) and `Config.copilot_executor_port` (default 8008)
- Remove unused `Config.agent_server_port`

> [!WARNING]
> **This change adds a new microservice to the system, with entrypoint `backend.copilot.executor`.**
> The `docker compose` setup has been updated, but if you run the Platform on something else, you'll have to update your deployment config to include this new service.
>
> When running locally, the `CoPilotExecutor` uses port 8008 by default.

### Checklist 📋

#### For code changes:

- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  - [x] Copilot works
  - [x] Processes messages when triggered
  - [x] Can use its tools

#### For configuration changes:

- [x] `.env.default` is updated or already compatible with my changes
- [x] `docker-compose.yml` is updated or already compatible with my changes
- [x] I have included a list of my configuration changes in the PR description (under **Changes**)

---------

Co-authored-by: Zamil Majdy <zamil.majdy@agpt.co>
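
For illustration, a minimal sketch of how an external worker could hand a finished operation back to the chat pipeline over the Redis Streams leg of this flow. It uses `publish_operation_complete` from the completion-consumer module shown below; the import path, the function name `on_agent_generation_done`, and the IDs are hypothetical:

```python
# Hypothetical import path for the completion-consumer module shown below.
from backend.copilot.completion_consumer import publish_operation_complete


async def on_agent_generation_done(operation_id: str, task_id: str) -> None:
    # XADDs an OperationCompleteMessage onto the completion stream; whichever
    # pod's ChatCompletionConsumer claims it will resolve the waiting task.
    await publish_operation_complete(
        operation_id=operation_id,
        task_id=task_id,
        success=True,
        result={"status": "done"},  # made-up result payload
    )
```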
"""Redis Streams consumer for operation completion messages.
|
|
|
|
This module provides a consumer (ChatCompletionConsumer) that listens for
|
|
completion notifications (OperationCompleteMessage) from external services
|
|
(like Agent Generator) and triggers the appropriate stream registry and
|
|
chat service updates via process_operation_success/process_operation_failure.
|
|
|
|
Why Redis Streams instead of RabbitMQ?
|
|
--------------------------------------
|
|
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
|
queue), Redis Streams was chosen for chat completion notifications because:
|
|
|
|
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
|
Streams (via stream_registry) for message persistence and replay. Using Redis
|
|
Streams for completion notifications keeps all chat streaming infrastructure
|
|
in one system, simplifying operations and reducing cross-system coordination.
|
|
|
|
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
|
allowing consumers to replay missed messages after reconnection. This aligns
|
|
with the SSE reconnection pattern where clients can resume from last_message_id.
|
|
|
|
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
|
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
|
recovering from dead consumers - ideal for the completion callback pattern.
|
|
|
|
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
|
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
|
|
|
5. **Atomicity with Task State**: Completion processing often needs to update
|
|
task metadata stored in Redis. Keeping both in Redis enables simpler
|
|
transactional semantics without distributed coordination.
|
|
|
|
The consumer uses Redis Streams with consumer groups for reliable message
|
|
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
|
stale pending messages from dead consumers.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import uuid
|
|
from typing import Any
|
|
|
|
import orjson
|
|
from pydantic import BaseModel
|
|
from redis.exceptions import ResponseError
|
|
|
|
from backend.data.redis_client import get_redis_async
|
|
|
|
from . import stream_registry
|
|
from .completion_handler import process_operation_failure, process_operation_success
|
|
from .config import ChatConfig
|
|
|
|
logger = logging.getLogger(__name__)
|
|
config = ChatConfig()
|
|
|
|
|
|
class OperationCompleteMessage(BaseModel):
|
|
"""Message format for operation completion notifications."""
|
|
|
|
operation_id: str
|
|
task_id: str
|
|
success: bool
|
|
result: dict | str | None = None
|
|
error: str | None = None
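
# For reference, entries on the completion stream wrap this model as JSON under
# a single "data" field (see publish_operation_complete at the bottom of this
# file). An illustrative payload, with made-up values:
#
#   {"operation_id": "op-123", "task_id": "task-456", "success": true,
#    "result": {"answer": "..."}, "error": null}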


class ChatCompletionConsumer:
    """Consumer for chat operation completion messages from Redis Streams.

    Database operations are handled through the chat_db() accessor, which
    routes through DatabaseManager RPC when Prisma is not directly connected.

    Uses Redis consumer groups so that multiple platform pods can consume
    messages reliably; messages left pending by a failed consumer are
    redelivered via explicit XAUTOCLAIM once they exceed the configured
    idle time.
    """

    def __init__(self):
        self._consumer_task: asyncio.Task | None = None
        self._running = False
        self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"

    async def start(self) -> None:
        """Start the completion consumer."""
        if self._running:
            logger.warning("Completion consumer already running")
            return

        # Create consumer group if it doesn't exist
        try:
            redis = await get_redis_async()
            await redis.xgroup_create(
                config.stream_completion_name,
                config.stream_consumer_group,
                id="0",
                mkstream=True,
            )
            logger.info(
                f"Created consumer group '{config.stream_consumer_group}' "
                f"on stream '{config.stream_completion_name}'"
            )
        except ResponseError as e:
            if "BUSYGROUP" in str(e):
                logger.debug(
                    f"Consumer group '{config.stream_consumer_group}' already exists"
                )
            else:
                raise

        self._running = True
        self._consumer_task = asyncio.create_task(self._consume_messages())
        logger.info(
            f"Chat completion consumer started (consumer: {self._consumer_name})"
        )

    async def stop(self) -> None:
        """Stop the completion consumer."""
        self._running = False

        if self._consumer_task:
            self._consumer_task.cancel()
            try:
                await self._consumer_task
            except asyncio.CancelledError:
                pass
            self._consumer_task = None

        logger.info("Chat completion consumer stopped")

    async def _consume_messages(self) -> None:
        """Main message consumption loop with retry logic."""
        max_retries = 10
        retry_delay = 5  # seconds
        retry_count = 0
        block_timeout = 5000  # milliseconds
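
        # Note: XREADGROUP with BLOCK returns nothing once block_timeout
        # elapses, so the consume loop below re-checks self._running at
        # least every ~5s rather than blocking indefinitely.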
        while self._running and retry_count < max_retries:
            try:
                redis = await get_redis_async()

                # Reset retry count on successful connection
                retry_count = 0

                while self._running:
                    # First, claim any stale pending messages from dead consumers.
                    # Redis does NOT auto-redeliver pending messages; we must
                    # explicitly claim them using XAUTOCLAIM
                    try:
                        claimed_result = await redis.xautoclaim(
                            name=config.stream_completion_name,
                            groupname=config.stream_consumer_group,
                            consumername=self._consumer_name,
                            min_idle_time=config.stream_claim_min_idle_ms,
                            start_id="0-0",
                            count=10,
                        )
                        # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
                        if claimed_result and len(claimed_result) >= 2:
                            claimed_entries = claimed_result[1]
                            if claimed_entries:
                                logger.info(
                                    f"Claimed {len(claimed_entries)} stale pending messages"
                                )
                                for entry_id, data in claimed_entries:
                                    if not self._running:
                                        return
                                    await self._process_entry(redis, entry_id, data)
                    except Exception as e:
                        logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")

                    # Read new messages from the stream
                    messages = await redis.xreadgroup(
                        groupname=config.stream_consumer_group,
                        consumername=self._consumer_name,
                        streams={config.stream_completion_name: ">"},
                        block=block_timeout,
                        count=10,
                    )

                    if not messages:
                        continue

                    for stream_name, entries in messages:
                        for entry_id, data in entries:
                            if not self._running:
                                return
                            await self._process_entry(redis, entry_id, data)

            except asyncio.CancelledError:
                logger.info("Consumer cancelled")
                return
            except Exception as e:
                retry_count += 1
                logger.error(
                    f"Consumer error (retry {retry_count}/{max_retries}): {e}",
                    exc_info=True,
                )
                if self._running and retry_count < max_retries:
                    await asyncio.sleep(retry_delay)
                else:
                    logger.error("Max retries reached, stopping consumer")
                    return

    async def _process_entry(
        self, redis: Any, entry_id: str, data: dict[str, Any]
    ) -> None:
        """Process a single stream entry and acknowledge it on success.

        Args:
            redis: Redis client connection
            entry_id: The stream entry ID
            data: The entry data dict
        """
        try:
            # Handle the message
            message_data = data.get("data")
            if message_data:
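                # _handle_message expects bytes; stream values may arrive as
                # str when the Redis client decodes responses, so normalize.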
                await self._handle_message(
                    message_data.encode()
                    if isinstance(message_data, str)
                    else message_data
                )

            # Acknowledge the message after successful processing
            await redis.xack(
                config.stream_completion_name,
                config.stream_consumer_group,
                entry_id,
            )
        except Exception as e:
            logger.error(
                f"Error processing completion message {entry_id}: {e}",
                exc_info=True,
            )
            # Message remains in pending state and will be claimed by
            # XAUTOCLAIM after min_idle_time expires

    async def _handle_message(self, body: bytes) -> None:
        """Handle a completion message."""
        try:
            data = orjson.loads(body)
            message = OperationCompleteMessage(**data)
        except Exception as e:
            logger.error(f"Failed to parse completion message: {e}")
            return

        logger.info(
            f"[COMPLETION] Received completion for operation {message.operation_id} "
            f"(task_id={message.task_id}, success={message.success})"
        )

        # Find task in registry
        task = await stream_registry.find_task_by_operation_id(message.operation_id)
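        # Fall back to a direct task_id lookup in case the operation-id index
        # has no entry for this operation.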
        if task is None:
            task = await stream_registry.get_task(message.task_id)

        if task is None:
            logger.warning(
                f"[COMPLETION] Task not found for operation {message.operation_id} "
                f"(task_id={message.task_id})"
            )
            return

        logger.info(
            f"[COMPLETION] Found task: task_id={task.task_id}, "
            f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
        )

        # Guard against empty task fields
        if not task.task_id or not task.session_id or not task.tool_call_id:
            logger.error(
                f"[COMPLETION] Task has empty critical fields! "
                f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
                f"tool_call_id={task.tool_call_id!r}"
            )
            return

        if message.success:
            await self._handle_success(task, message)
        else:
            await self._handle_failure(task, message)

    async def _handle_success(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle successful operation completion."""
        await process_operation_success(task, message.result)

    async def _handle_failure(
        self,
        task: stream_registry.ActiveTask,
        message: OperationCompleteMessage,
    ) -> None:
        """Handle failed operation completion."""
        await process_operation_failure(task, message.error)


# Module-level consumer instance
_consumer: ChatCompletionConsumer | None = None


async def start_completion_consumer() -> None:
    """Start the global completion consumer."""
    global _consumer
    if _consumer is None:
        _consumer = ChatCompletionConsumer()
    await _consumer.start()


async def stop_completion_consumer() -> None:
    """Stop the global completion consumer."""
    global _consumer
    if _consumer:
        await _consumer.stop()
        _consumer = None


async def publish_operation_complete(
    operation_id: str,
    task_id: str,
    success: bool,
    result: dict | str | None = None,
    error: str | None = None,
) -> None:
    """Publish an operation completion message to Redis Streams.

    Args:
        operation_id: The operation ID that completed.
        task_id: The task ID associated with the operation.
        success: Whether the operation succeeded.
        result: The result data (for success).
        error: The error message (for failure).
    """
    message = OperationCompleteMessage(
        operation_id=operation_id,
        task_id=task_id,
        success=success,
        result=result,
        error=error,
    )

    redis = await get_redis_async()
    await redis.xadd(
        config.stream_completion_name,
        {"data": message.model_dump_json()},
        maxlen=config.stream_max_length,
    )
    logger.info(f"Published completion for operation {operation_id}")