Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-02-02 19:05:10 -05:00
update to only use redis for integration
@@ -1,27 +1,24 @@
-"""RabbitMQ consumer for operation completion messages.
+"""Redis Streams consumer for operation completion messages.
 
 This module provides a consumer that listens for completion notifications
 from external services (like Agent Generator) and triggers the appropriate
 stream registry and chat service updates.
 
 The consumer initializes its own Prisma client to avoid async context issues.
+The consumer uses Redis Streams with consumer groups for reliable message
+processing across multiple platform pods.
 """
 
 import asyncio
 import logging
 import os
 import uuid
 
 import orjson
 from prisma import Prisma
 from pydantic import BaseModel
+from redis.exceptions import ResponseError
 
-from backend.data.rabbitmq import (
-    AsyncRabbitMQ,
-    Exchange,
-    ExchangeType,
-    Queue,
-    RabbitMQConfig,
-)
+from backend.data.redis_client import get_redis_async
 
 from . import service as chat_service
 from . import stream_registry
@@ -30,24 +27,10 @@ from .tools.models import ErrorResponse
 
 logger = logging.getLogger(__name__)
 
-# Queue and exchange configuration
-OPERATION_COMPLETE_EXCHANGE = Exchange(
-    name="chat_operations",
-    type=ExchangeType.DIRECT,
-    durable=True,
-)
-
-OPERATION_COMPLETE_QUEUE = Queue(
-    name="chat_operation_complete",
-    durable=True,
-    exchange=OPERATION_COMPLETE_EXCHANGE,
-    routing_key="operation.complete",
-)
-
-RABBITMQ_CONFIG = RabbitMQConfig(
-    exchanges=[OPERATION_COMPLETE_EXCHANGE],
-    queues=[OPERATION_COMPLETE_QUEUE],
-)
+# Stream configuration
+COMPLETION_STREAM = "chat:completions"
+CONSUMER_GROUP = "chat_consumers"
+STREAM_MAX_LENGTH = 10000
 
 
 class OperationCompleteMessage(BaseModel):
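For orientation, here is a minimal round-trip over the three constants above. This is a sketch, not project code: it talks to `redis.asyncio` directly (the real module goes through `get_redis_async()`), and the consumer name `consumer-demo` is made up.

```python
# Minimal XADD -> XREADGROUP -> XACK round-trip over the constants above.
import asyncio

import redis.asyncio as aioredis
from redis.exceptions import ResponseError

COMPLETION_STREAM = "chat:completions"
CONSUMER_GROUP = "chat_consumers"
STREAM_MAX_LENGTH = 10000


async def main() -> None:
    r = aioredis.Redis(decode_responses=True)
    try:
        # Producer: append one entry, capping the stream near 10k entries.
        await r.xadd(
            COMPLETION_STREAM,
            {"data": '{"operation_id": "op-1"}'},
            maxlen=STREAM_MAX_LENGTH,
        )

        # Consumer: create the group once (BUSYGROUP means it already exists),
        # then read a batch of never-delivered entries and acknowledge them.
        try:
            await r.xgroup_create(
                COMPLETION_STREAM, CONSUMER_GROUP, id="0", mkstream=True
            )
        except ResponseError as e:
            if "BUSYGROUP" not in str(e):
                raise

        batch = await r.xreadgroup(
            CONSUMER_GROUP,
            "consumer-demo",
            {COMPLETION_STREAM: ">"},
            count=10,
            block=1000,
        )
        for _stream, entries in batch:
            for entry_id, fields in entries:
                print(entry_id, fields["data"])
                await r.xack(COMPLETION_STREAM, CONSUMER_GROUP, entry_id)
    finally:
        await r.aclose()


asyncio.run(main())
```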
@@ -61,17 +44,20 @@ class OperationCompleteMessage(BaseModel):
 
 
 class ChatCompletionConsumer:
-    """Consumer for chat operation completion messages from RabbitMQ.
+    """Consumer for chat operation completion messages from Redis Streams.
 
     This consumer initializes its own Prisma client in start() to ensure
     database operations work correctly within this async context.
+
+    Uses Redis consumer groups to allow multiple platform pods to consume
+    messages reliably with automatic redelivery on failure.
     """
 
     def __init__(self):
-        self._rabbitmq: AsyncRabbitMQ | None = None
         self._consumer_task: asyncio.Task | None = None
         self._running = False
         self._prisma: Prisma | None = None
+        self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
 
     async def start(self) -> None:
        """Start the completion consumer."""
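The per-pod consumer name derived from `uuid4` is what makes the multi-pod claim in this docstring work: within a single group, Redis delivers each new entry to exactly one consumer name, and the entry stays in that reader's pending list until acknowledged. A small sketch under the same assumptions as before (`read_batch` and its client setup are illustrative, not project code):

```python
import redis.asyncio as aioredis

COMPLETION_STREAM = "chat:completions"
CONSUMER_GROUP = "chat_consumers"


async def read_batch(consumer_name: str):
    """Read as a named group member.

    Entries are load-balanced across consumer names, not fanned out to
    all of them; each entry read here sits in this consumer's pending
    list until XACKed or claimed by someone else.
    """
    r = aioredis.Redis(decode_responses=True)
    try:
        return await r.xreadgroup(
            CONSUMER_GROUP, consumer_name, {COMPLETION_STREAM: ">"}, count=5
        )
    finally:
        await r.aclose()
```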
@@ -79,15 +65,29 @@ class ChatCompletionConsumer:
             logger.warning("Completion consumer already running")
             return
 
-        # Don't initialize Prisma here - do it lazily on first message
-        # to ensure it's in the same async context as the message handler
-
-        self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
-        await self._rabbitmq.connect()
+        # Create consumer group if it doesn't exist
+        try:
+            redis = await get_redis_async()
+            await redis.xgroup_create(
+                COMPLETION_STREAM,
+                CONSUMER_GROUP,
+                id="0",
+                mkstream=True,
+            )
+            logger.info(
+                f"Created consumer group '{CONSUMER_GROUP}' on stream '{COMPLETION_STREAM}'"
+            )
+        except ResponseError as e:
+            if "BUSYGROUP" in str(e):
+                logger.debug(f"Consumer group '{CONSUMER_GROUP}' already exists")
+            else:
+                raise
 
         self._running = True
         self._consumer_task = asyncio.create_task(self._consume_messages())
-        logger.info("Chat completion consumer started")
+        logger.info(
+            f"Chat completion consumer started (consumer: {self._consumer_name})"
+        )
 
     async def _ensure_prisma(self) -> Prisma:
         """Lazily initialize Prisma client on first use."""
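Two choices baked into this call are worth spelling out: `id="0"` starts the group before all existing entries, so history written before the first pod boots still gets consumed (`"$"` would deliver only entries added after creation), and `mkstream=True` creates the stream if no producer has written to it yet. A hypothetical `ensure_group` helper capturing the same BUSYGROUP-tolerant pattern:

```python
import redis.asyncio as aioredis
from redis.exceptions import ResponseError


async def ensure_group(
    r: aioredis.Redis, stream: str, group: str, start_id: str = "0"
) -> None:
    """Idempotently create a consumer group.

    start_id="0" hands the group the stream's full history; "$" would
    make it see only entries added after creation. mkstream=True creates
    the stream when no producer has written to it yet.
    """
    try:
        await r.xgroup_create(stream, group, id=start_id, mkstream=True)
    except ResponseError as e:
        # A second pod racing the same call gets BUSYGROUP, which is benign.
        if "BUSYGROUP" not in str(e):
            raise
```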
@@ -110,10 +110,6 @@ class ChatCompletionConsumer:
                 pass
             self._consumer_task = None
 
-        if self._rabbitmq:
-            await self._rabbitmq.disconnect()
-            self._rabbitmq = None
-
         if self._prisma:
             await self._prisma.disconnect()
             self._prisma = None
@@ -126,33 +122,54 @@ class ChatCompletionConsumer:
         max_retries = 10
         retry_delay = 5  # seconds
         retry_count = 0
+        block_timeout = 5000  # milliseconds
 
         while self._running and retry_count < max_retries:
-            if not self._rabbitmq:
-                logger.error("RabbitMQ not initialized")
-                return
-
             try:
-                channel = await self._rabbitmq.get_channel()
-                queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
+                redis = await get_redis_async()
 
                 # Reset retry count on successful connection
                 retry_count = 0
 
-                async with queue.iterator() as queue_iter:
-                    async for message in queue_iter:
-                        if not self._running:
-                            return
+                while self._running:
+                    # Read new messages from the stream
+                    messages = await redis.xreadgroup(
+                        groupname=CONSUMER_GROUP,
+                        consumername=self._consumer_name,
+                        streams={COMPLETION_STREAM: ">"},
+                        block=block_timeout,
+                        count=10,
+                    )
 
-                        try:
-                            async with message.process():
-                                await self._handle_message(message.body)
-                        except Exception as e:
-                            logger.error(
-                                f"Error processing completion message: {e}",
-                                exc_info=True,
-                            )
-                            # Message will be requeued due to exception
+                    if not messages:
+                        continue
 
+                    for stream_name, entries in messages:
+                        for entry_id, data in entries:
+                            if not self._running:
+                                return
+
+                            try:
+                                # Handle the message
+                                message_data = data.get("data")
+                                if message_data:
+                                    await self._handle_message(
+                                        message_data.encode()
+                                        if isinstance(message_data, str)
+                                        else message_data
+                                    )
+
+                                # Acknowledge the message
+                                await redis.xack(
+                                    COMPLETION_STREAM, CONSUMER_GROUP, entry_id
+                                )
+                            except Exception as e:
+                                logger.error(
+                                    f"Error processing completion message {entry_id}: {e}",
+                                    exc_info=True,
+                                )
+                                # Message will be redelivered to another consumer
+                                # or can be claimed after timeout
 
             except asyncio.CancelledError:
                 logger.info("Consumer cancelled")
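Note that this loop only ever asks for `">"`, i.e. entries never delivered to this group. An entry read by a pod that dies before `xack` stays in that pod's pending list; the closing comment says it "can be claimed after timeout", but no claiming is implemented in this hunk. A hedged sketch of what that could look like with `XAUTOCLAIM` (Redis 6.2+); `claim_stale` is hypothetical and `handle` stands in for `_handle_message`:

```python
import redis.asyncio as aioredis

COMPLETION_STREAM = "chat:completions"
CONSUMER_GROUP = "chat_consumers"


async def handle(fields: dict) -> None:
    """Placeholder for the real _handle_message."""
    print(fields)


async def claim_stale(consumer_name: str, min_idle_ms: int = 60_000) -> None:
    """Take over entries another consumer read but never XACKed."""
    r = aioredis.Redis(decode_responses=True)
    try:
        start_id = "0-0"
        while True:
            # Claims entries idle longer than min_idle_ms for this consumer.
            resp = await r.xautoclaim(
                COMPLETION_STREAM,
                CONSUMER_GROUP,
                consumer_name,
                min_idle_time=min_idle_ms,
                start_id=start_id,
            )
            next_id, entries = resp[0], resp[1]  # Redis 7 appends deleted IDs
            for entry_id, fields in entries:
                await handle(fields)
                await r.xack(COMPLETION_STREAM, CONSUMER_GROUP, entry_id)
            if not entries:
                break
            start_id = next_id
    finally:
        await r.aclose()
```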
@@ -363,7 +380,7 @@ async def publish_operation_complete(
     result: dict | str | None = None,
     error: str | None = None,
 ) -> None:
-    """Publish an operation completion message.
+    """Publish an operation completion message to Redis Streams.
 
     Args:
         operation_id: The operation ID that completed.
@@ -380,14 +397,10 @@ async def publish_operation_complete(
         error=error,
     )
 
-    rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
-    try:
-        await rabbitmq.connect()
-        await rabbitmq.publish_message(
-            routing_key="operation.complete",
-            message=message.model_dump_json(),
-            exchange=OPERATION_COMPLETE_EXCHANGE,
-        )
-        logger.info(f"Published completion for operation {operation_id}")
-    finally:
-        await rabbitmq.disconnect()
+    redis = await get_redis_async()
+    await redis.xadd(
+        COMPLETION_STREAM,
+        {"data": message.model_dump_json()},
+        maxlen=STREAM_MAX_LENGTH,
+    )
+    logger.info(f"Published completion for operation {operation_id}")
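One detail of the new publish path: redis-py's `xadd` defaults to `approximate=True`, so this is sent as `MAXLEN ~ 10000`, meaning trimming happens lazily at macro-node boundaries and the stream can briefly exceed the cap in exchange for cheaper writes. A standalone sketch of the same call (`publish_raw` is illustrative, not project code):

```python
import redis.asyncio as aioredis

COMPLETION_STREAM = "chat:completions"
STREAM_MAX_LENGTH = 10000


async def publish_raw(payload_json: str) -> str:
    """One XADD per message, with no per-message connect/disconnect
    cycle, unlike the removed RabbitMQ path. Returns the new entry ID."""
    r = aioredis.Redis(decode_responses=True)
    try:
        return await r.xadd(
            COMPLETION_STREAM,
            {"data": payload_json},
            maxlen=STREAM_MAX_LENGTH,  # sent as `MAXLEN ~ 10000`
        )
    finally:
        await r.aclose()
```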
@@ -209,14 +209,28 @@ async def get_session(
         session_id, user_id
     )
     if active_task:
+        # Filter out the in-progress assistant message from the session response.
+        # The client will receive the complete assistant response through the SSE
+        # stream replay instead, preventing duplicate content.
+        if messages and messages[-1].get("role") == "assistant":
+            original_count = len(messages)
+            messages = messages[:-1]
+            logger.info(
+                f"[SSE-RECONNECT] Filtered out in-progress assistant message "
+                f"(was {original_count} messages, now {len(messages)})"
+            )
+
+        # Use "0-0" as last_message_id to replay the stream from the beginning.
+        # Since we filtered out the cached assistant message, the client needs
+        # the full stream to reconstruct the response.
         active_stream_info = ActiveStreamInfo(
             task_id=active_task.task_id,
-            last_message_id=last_message_id,
+            last_message_id="0-0",
         )
         logger.info(
             f"[SSE-RECONNECT] Session {session_id} HAS active stream: "
             f"task_id={active_task.task_id}, status={active_task.status}, "
-            f"last_message_id={last_message_id}"
+            f"last_message_id=0-0 (replay from start)"
         )
     else:
         logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream")
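`"0-0"` sorts before any entry ID Redis can generate, so a subscriber given this value replays the task stream from its very first entry. A sketch of the reading side under that convention; the `chat:stream:{task_id}` key name is hypothetical, since the real layout belongs to `stream_registry`:

```python
import redis.asyncio as aioredis


async def replay_task_stream(task_id: str):
    """Yield every entry of a task's stream from the start; "0-0" sorts
    before any generated ID, so nothing is skipped."""
    r = aioredis.Redis(decode_responses=True)
    try:
        # Hypothetical key naming; stream_registry owns the real scheme.
        entries = await r.xrange(f"chat:stream:{task_id}", min="0-0", max="+")
        for entry_id, fields in entries:
            yield entry_id, fields
    finally:
        await r.aclose()
```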
@@ -211,9 +211,7 @@ async def subscribe_to_task(
     task_status = meta.get("status", "")
     task_user_id = meta.get("user_id", "") or None
 
-    logger.info(
-        f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}"
-    )
+    logger.info(f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}")
 
     # Validate ownership
     if user_id and task_user_id and task_user_id != user_id:
@@ -256,9 +254,7 @@ async def subscribe_to_task(
         logger.info(
             f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener"
         )
-        asyncio.create_task(
-            _stream_listener(task_id, subscriber_queue, replay_last_id)
-        )
+        asyncio.create_task(_stream_listener(task_id, subscriber_queue, replay_last_id))
     else:
         # Task is completed/failed - add finish marker
         logger.info(
@@ -470,9 +466,7 @@ async def get_active_task_for_session(
     tasks_checked = 0
 
     while True:
-        cursor, keys = await redis.scan(
-            cursor, match=f"{TASK_META_PREFIX}*", count=100
-        )
+        cursor, keys = await redis.scan(cursor, match=f"{TASK_META_PREFIX}*", count=100)
 
         for key in keys:
             tasks_checked += 1
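The reformatted `scan` call relies on the usual cursor contract: `count` is a hint rather than a page size, and iteration ends only when the server returns cursor 0. As a standalone sketch, with `prefix` standing in for `TASK_META_PREFIX`:

```python
import redis.asyncio as aioredis


async def scan_keys(prefix: str) -> list[str]:
    """Collect every key matching prefix* by walking the SCAN cursor."""
    r = aioredis.Redis(decode_responses=True)
    try:
        cursor, keys = 0, []
        while True:
            cursor, batch = await r.scan(cursor, match=f"{prefix}*", count=100)
            keys.extend(batch)
            if cursor == 0:  # a zero cursor marks the end of the iteration
                return keys
    finally:
        await r.aclose()
```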
@@ -842,5 +842,9 @@ async def generate_agent_patch(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
     return await generate_agent_patch_external(
-        update_request, current_agent, _to_dict_list(library_agents), operation_id, task_id
+        update_request,
+        current_agent,
+        _to_dict_list(library_agents),
+        operation_id,
+        task_id,
     )