Update to only use Redis for integration

Swifty
2026-02-02 15:33:36 +01:00
parent e812ee9265
commit ef3fab57fd
4 changed files with 107 additions and 82 deletions

View File

@@ -1,27 +1,24 @@
"""RabbitMQ consumer for operation completion messages.
"""Redis Streams consumer for operation completion messages.
This module provides a consumer that listens for completion notifications
from external services (like Agent Generator) and triggers the appropriate
stream registry and chat service updates.
The consumer initializes its own Prisma client to avoid async context issues.
The consumer uses Redis Streams with consumer groups for reliable message
processing across multiple platform pods.
"""
import asyncio
import logging
import os
import uuid
import orjson
from prisma import Prisma
from pydantic import BaseModel
from redis.exceptions import ResponseError
from backend.data.rabbitmq import (
AsyncRabbitMQ,
Exchange,
ExchangeType,
Queue,
RabbitMQConfig,
)
from backend.data.redis_client import get_redis_async
from . import service as chat_service
from . import stream_registry
@@ -30,24 +27,10 @@ from .tools.models import ErrorResponse
 logger = logging.getLogger(__name__)
-# Queue and exchange configuration
-OPERATION_COMPLETE_EXCHANGE = Exchange(
-    name="chat_operations",
-    type=ExchangeType.DIRECT,
-    durable=True,
-)
-OPERATION_COMPLETE_QUEUE = Queue(
-    name="chat_operation_complete",
-    durable=True,
-    exchange=OPERATION_COMPLETE_EXCHANGE,
-    routing_key="operation.complete",
-)
-RABBITMQ_CONFIG = RabbitMQConfig(
-    exchanges=[OPERATION_COMPLETE_EXCHANGE],
-    queues=[OPERATION_COMPLETE_QUEUE],
-)
+# Stream configuration
+COMPLETION_STREAM = "chat:completions"
+CONSUMER_GROUP = "chat_consumers"
+STREAM_MAX_LENGTH = 10000
 class OperationCompleteMessage(BaseModel):
@@ -61,17 +44,20 @@ class OperationCompleteMessage(BaseModel):
 class ChatCompletionConsumer:
-    """Consumer for chat operation completion messages from RabbitMQ.
+    """Consumer for chat operation completion messages from Redis Streams.
-    This consumer initializes its own Prisma client in start() to ensure
-    database operations work correctly within this async context.
+    Uses Redis consumer groups to allow multiple platform pods to consume
+    messages reliably with automatic redelivery on failure.
     """
     def __init__(self):
-        self._rabbitmq: AsyncRabbitMQ | None = None
         self._consumer_task: asyncio.Task | None = None
         self._running = False
         self._prisma: Prisma | None = None
+        self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
     async def start(self) -> None:
         """Start the completion consumer."""
@@ -79,15 +65,29 @@ class ChatCompletionConsumer:
logger.warning("Completion consumer already running")
return
# Don't initialize Prisma here - do it lazily on first message
# to ensure it's in the same async context as the message handler
self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
await self._rabbitmq.connect()
# Create consumer group if it doesn't exist
try:
redis = await get_redis_async()
await redis.xgroup_create(
COMPLETION_STREAM,
CONSUMER_GROUP,
id="0",
mkstream=True,
)
logger.info(
f"Created consumer group '{CONSUMER_GROUP}' on stream '{COMPLETION_STREAM}'"
)
except ResponseError as e:
if "BUSYGROUP" in str(e):
logger.debug(f"Consumer group '{CONSUMER_GROUP}' already exists")
else:
raise
self._running = True
self._consumer_task = asyncio.create_task(self._consume_messages())
logger.info("Chat completion consumer started")
logger.info(
f"Chat completion consumer started (consumer: {self._consumer_name})"
)
async def _ensure_prisma(self) -> Prisma:
"""Lazily initialize Prisma client on first use."""
@@ -110,10 +110,6 @@ class ChatCompletionConsumer:
                 pass
             self._consumer_task = None
-        if self._rabbitmq:
-            await self._rabbitmq.disconnect()
-            self._rabbitmq = None
         if self._prisma:
             await self._prisma.disconnect()
             self._prisma = None
@@ -126,33 +122,54 @@ class ChatCompletionConsumer:
         max_retries = 10
         retry_delay = 5  # seconds
         retry_count = 0
+        block_timeout = 5000  # milliseconds
         while self._running and retry_count < max_retries:
-            if not self._rabbitmq:
-                logger.error("RabbitMQ not initialized")
-                return
             try:
-                channel = await self._rabbitmq.get_channel()
-                queue = await channel.get_queue(OPERATION_COMPLETE_QUEUE.name)
+                redis = await get_redis_async()
                 # Reset retry count on successful connection
                 retry_count = 0
-                async with queue.iterator() as queue_iter:
-                    async for message in queue_iter:
-                        if not self._running:
-                            return
+                while self._running:
+                    # Read new messages from the stream
+                    messages = await redis.xreadgroup(
+                        groupname=CONSUMER_GROUP,
+                        consumername=self._consumer_name,
+                        streams={COMPLETION_STREAM: ">"},
+                        block=block_timeout,
+                        count=10,
+                    )
-                        try:
-                            async with message.process():
-                                await self._handle_message(message.body)
-                        except Exception as e:
-                            logger.error(
-                                f"Error processing completion message: {e}",
-                                exc_info=True,
-                            )
-                            # Message will be requeued due to exception
+                    if not messages:
+                        continue
+                    for stream_name, entries in messages:
+                        for entry_id, data in entries:
+                            if not self._running:
+                                return
+                            try:
+                                # Handle the message
+                                message_data = data.get("data")
+                                if message_data:
+                                    await self._handle_message(
+                                        message_data.encode()
+                                        if isinstance(message_data, str)
+                                        else message_data
+                                    )
+                                # Acknowledge the message
+                                await redis.xack(
+                                    COMPLETION_STREAM, CONSUMER_GROUP, entry_id
+                                )
+                            except Exception as e:
+                                logger.error(
+                                    f"Error processing completion message {entry_id}: {e}",
+                                    exc_info=True,
+                                )
+                                # Message will be redelivered to another consumer
+                                # or can be claimed after timeout
         except asyncio.CancelledError:
             logger.info("Consumer cancelled")
@@ -363,7 +380,7 @@ async def publish_operation_complete(
     result: dict | str | None = None,
     error: str | None = None,
 ) -> None:
-    """Publish an operation completion message.
+    """Publish an operation completion message to Redis Streams.
     Args:
         operation_id: The operation ID that completed.
@@ -380,14 +397,10 @@ async def publish_operation_complete(
         error=error,
     )
-    rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG)
-    try:
-        await rabbitmq.connect()
-        await rabbitmq.publish_message(
-            routing_key="operation.complete",
-            message=message.model_dump_json(),
-            exchange=OPERATION_COMPLETE_EXCHANGE,
-        )
-        logger.info(f"Published completion for operation {operation_id}")
-    finally:
-        await rabbitmq.disconnect()
+    redis = await get_redis_async()
+    await redis.xadd(
+        COMPLETION_STREAM,
+        {"data": message.model_dump_json()},
+        maxlen=STREAM_MAX_LENGTH,
+    )
+    logger.info(f"Published completion for operation {operation_id}")

View File

@@ -209,14 +209,28 @@ async def get_session(
         session_id, user_id
     )
     if active_task:
+        # Filter out the in-progress assistant message from the session response.
+        # The client will receive the complete assistant response through the SSE
+        # stream replay instead, preventing duplicate content.
+        if messages and messages[-1].get("role") == "assistant":
+            original_count = len(messages)
+            messages = messages[:-1]
+            logger.info(
+                f"[SSE-RECONNECT] Filtered out in-progress assistant message "
+                f"(was {original_count} messages, now {len(messages)})"
+            )
+        # Use "0-0" as last_message_id to replay the stream from the beginning.
+        # Since we filtered out the cached assistant message, the client needs
+        # the full stream to reconstruct the response.
         active_stream_info = ActiveStreamInfo(
             task_id=active_task.task_id,
-            last_message_id=last_message_id,
+            last_message_id="0-0",
         )
         logger.info(
             f"[SSE-RECONNECT] Session {session_id} HAS active stream: "
             f"task_id={active_task.task_id}, status={active_task.status}, "
-            f"last_message_id={last_message_id}"
+            f"last_message_id=0-0 (replay from start)"
         )
     else:
         logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream")

View File

@@ -211,9 +211,7 @@ async def subscribe_to_task(
     task_status = meta.get("status", "")
     task_user_id = meta.get("user_id", "") or None
-    logger.info(
-        f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}"
-    )
+    logger.info(f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}")
     # Validate ownership
     if user_id and task_user_id and task_user_id != user_id:
@@ -256,9 +254,7 @@ async def subscribe_to_task(
         logger.info(
             f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener"
         )
-        asyncio.create_task(
-            _stream_listener(task_id, subscriber_queue, replay_last_id)
-        )
+        asyncio.create_task(_stream_listener(task_id, subscriber_queue, replay_last_id))
     else:
         # Task is completed/failed - add finish marker
         logger.info(
@@ -470,9 +466,7 @@ async def get_active_task_for_session(
     tasks_checked = 0
     while True:
-        cursor, keys = await redis.scan(
-            cursor, match=f"{TASK_META_PREFIX}*", count=100
-        )
+        cursor, keys = await redis.scan(cursor, match=f"{TASK_META_PREFIX}*", count=100)
         for key in keys:
             tasks_checked += 1
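
For completeness, a sketch of the full SCAN loop the hunk above truncates, again assuming redis-py's asyncio client. The TASK_META_PREFIX value is a placeholder; only the call shape comes from the diff:

# Sketch only: cursor-based SCAN iterates the keyspace in steps instead of
# blocking the server the way KEYS would; iteration ends when the returned
# cursor is 0.
import asyncio
import redis.asyncio as aioredis

TASK_META_PREFIX = "chat:task_meta:"  # assumed value; the diff only shows the name

async def scan_task_meta() -> list[str]:
    r = aioredis.Redis(decode_responses=True)
    cursor = 0
    found: list[str] = []
    while True:
        cursor, keys = await r.scan(cursor, match=f"{TASK_META_PREFIX}*", count=100)
        found.extend(keys)
        if cursor == 0:  # a full pass over the keyspace has completed
            return found

asyncio.run(scan_task_meta())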

View File

@@ -842,5 +842,9 @@ async def generate_agent_patch(
     _check_service_configured()
     logger.info("Calling external Agent Generator service for generate_agent_patch")
     return await generate_agent_patch_external(
-        update_request, current_agent, _to_dict_list(library_agents), operation_id, task_id
+        update_request,
+        current_agent,
+        _to_dict_list(library_agents),
+        operation_id,
+        task_id,
     )