diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py index f53b4673f3..83ca543c0e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py +++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py @@ -3,12 +3,16 @@ This module provides a consumer that listens for completion notifications from external services (like Agent Generator) and triggers the appropriate stream registry and chat service updates. + +The consumer initializes its own Prisma client to avoid async context issues. """ import asyncio import logging +import os import orjson +from prisma import Prisma from pydantic import BaseModel from backend.data.rabbitmq import ( @@ -57,12 +61,17 @@ class OperationCompleteMessage(BaseModel): class ChatCompletionConsumer: - """Consumer for chat operation completion messages from RabbitMQ.""" + """Consumer for chat operation completion messages from RabbitMQ. + + This consumer initializes its own Prisma client in start() to ensure + database operations work correctly within this async context. + """ def __init__(self): self._rabbitmq: AsyncRabbitMQ | None = None self._consumer_task: asyncio.Task | None = None self._running = False + self._prisma: Prisma | None = None async def start(self) -> None: """Start the completion consumer.""" @@ -70,6 +79,9 @@ class ChatCompletionConsumer: logger.warning("Completion consumer already running") return + # Don't initialize Prisma here - do it lazily on first message + # to ensure it's in the same async context as the message handler + self._rabbitmq = AsyncRabbitMQ(RABBITMQ_CONFIG) await self._rabbitmq.connect() @@ -77,6 +89,15 @@ class ChatCompletionConsumer: self._consumer_task = asyncio.create_task(self._consume_messages()) logger.info("Chat completion consumer started") + async def _ensure_prisma(self) -> Prisma: + """Lazily initialize Prisma client on first use.""" + if self._prisma is None: + database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432") + self._prisma = Prisma(datasource={"url": database_url}) + await self._prisma.connect() + logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") + return self._prisma + async def stop(self) -> None: """Stop the completion consumer.""" self._running = False @@ -93,6 +114,11 @@ class ChatCompletionConsumer: await self._rabbitmq.disconnect() self._rabbitmq = None + if self._prisma: + await self._prisma.disconnect() + self._prisma = None + logger.info("[COMPLETION] Consumer Prisma client disconnected") + logger.info("Chat completion consumer stopped") async def _consume_messages(self) -> None: @@ -144,7 +170,7 @@ class ChatCompletionConsumer: return async def _handle_message(self, body: bytes) -> None: - """Handle a single completion message.""" + """Handle a completion message using our own Prisma client.""" try: data = orjson.loads(body) message = OperationCompleteMessage(**data) @@ -153,23 +179,36 @@ class ChatCompletionConsumer: return logger.info( - f"Received completion for operation {message.operation_id} " + f"[COMPLETION] Received completion for operation {message.operation_id} " f"(task_id={message.task_id}, success={message.success})" ) # Find task in registry task = await stream_registry.find_task_by_operation_id(message.operation_id) if task is None: - # Try to look up by task_id directly task = await stream_registry.get_task(message.task_id) if task is None: logger.warning( - f"Task not found for operation {message.operation_id} " + f"[COMPLETION] Task not found for operation {message.operation_id} " f"(task_id={message.task_id})" ) return + logger.info( + f"[COMPLETION] Found task: task_id={task.task_id}, " + f"session_id={task.session_id}, tool_call_id={task.tool_call_id}" + ) + + # Guard against empty task fields + if not task.task_id or not task.session_id or not task.tool_call_id: + logger.error( + f"[COMPLETION] Task has empty critical fields! " + f"task_id={task.task_id!r}, session_id={task.session_id!r}, " + f"tool_call_id={task.tool_call_id!r}" + ) + return + if message.success: await self._handle_success(task, message) else: @@ -197,7 +236,7 @@ class ChatCompletionConsumer: ), ) - # Update pending operation in database + # Update pending operation in database using our Prisma client result_str = ( message.result if isinstance(message.result, str) @@ -207,26 +246,45 @@ class ChatCompletionConsumer: else '{"status": "completed"}' ) ) - await chat_service._update_pending_operation( - session_id=task.session_id, - tool_call_id=task.tool_call_id, - result=result_str, - ) + try: + prisma = await self._ensure_prisma() + await prisma.chatmessage.update_many( + where={ + "sessionId": task.session_id, + "toolCallId": task.tool_call_id, + }, + data={"content": result_str}, + ) + logger.info( + f"[COMPLETION] Updated tool message for session {task.session_id}" + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to update tool message: {e}", exc_info=True + ) # Generate LLM continuation with streaming - await chat_service._generate_llm_continuation_with_streaming( - session_id=task.session_id, - user_id=task.user_id, - task_id=task.task_id, - ) + try: + await chat_service._generate_llm_continuation_with_streaming( + session_id=task.session_id, + user_id=task.user_id, + task_id=task.task_id, + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to generate LLM continuation: {e}", + exc_info=True, + ) # Mark task as completed and release Redis lock await stream_registry.mark_task_completed(task.task_id, status="completed") - await chat_service._mark_operation_completed(task.tool_call_id) + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") logger.info( - f"Successfully processed completion for task {task.task_id} " - f"(operation {message.operation_id})" + f"[COMPLETION] Successfully processed completion for task {task.task_id}" ) async def _handle_failure( @@ -237,31 +295,44 @@ class ChatCompletionConsumer: """Handle failed operation completion.""" error_msg = message.error or "Operation failed" - # Publish error to stream registry followed by finish event + # Publish error to stream registry await stream_registry.publish_chunk( task.task_id, StreamError(errorText=error_msg), ) await stream_registry.publish_chunk(task.task_id, StreamFinish()) - # Update pending operation with error + # Update pending operation with error using our Prisma client error_response = ErrorResponse( message=error_msg, error=message.error, ) - await chat_service._update_pending_operation( - session_id=task.session_id, - tool_call_id=task.tool_call_id, - result=error_response.model_dump_json(), - ) + try: + prisma = await self._ensure_prisma() + await prisma.chatmessage.update_many( + where={ + "sessionId": task.session_id, + "toolCallId": task.tool_call_id, + }, + data={"content": error_response.model_dump_json()}, + ) + logger.info( + f"[COMPLETION] Updated tool message with error for session {task.session_id}" + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to update tool message: {e}", exc_info=True + ) # Mark task as failed and release Redis lock await stream_registry.mark_task_completed(task.task_id, status="failed") - await chat_service._mark_operation_completed(task.tool_call_id) + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") logger.info( - f"Processed failure for task {task.task_id} " - f"(operation {message.operation_id}): {error_msg}" + f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}" ) @@ -294,9 +365,6 @@ async def publish_operation_complete( ) -> None: """Publish an operation completion message. - This is a helper function for testing or for services that want to - publish completion messages directly. - Args: operation_id: The operation ID that completed. task_id: The task ID associated with the operation. diff --git a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/api/features/chat/response_model.py index 53a8cf3a1f..f627a42fcc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/response_model.py +++ b/autogpt_platform/backend/backend/api/features/chat/response_model.py @@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse): type: ResponseType = ResponseType.START messageId: str = Field(..., description="Unique message ID") + taskId: str | None = Field( + default=None, + description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", + ) class StreamFinish(StreamBaseResponse): diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 14d5a41482..4ddc1b33ef 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,6 +1,7 @@ """Chat API routes for chat session management and streaming via SSE.""" import logging +import uuid as uuid_module from collections.abc import AsyncGenerator from typing import Annotated @@ -16,7 +17,7 @@ from . import service as chat_service from . import stream_registry from .config import ChatConfig from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions -from .response_model import StreamFinish, StreamHeartbeat +from .response_model import StreamFinish, StreamHeartbeat, StreamStart config = ChatConfig() @@ -58,6 +59,13 @@ class CreateSessionResponse(BaseModel): user_id: str | None +class ActiveStreamInfo(BaseModel): + """Information about an active stream for reconnection.""" + + task_id: str + last_message_id: str # Redis Stream message ID for resumption + + class SessionDetailResponse(BaseModel): """Response model providing complete details for a chat session, including messages.""" @@ -66,6 +74,7 @@ class SessionDetailResponse(BaseModel): updated_at: str user_id: str | None messages: list[dict] + active_stream: ActiveStreamInfo | None = None # Present if stream is still active class SessionSummaryResponse(BaseModel): @@ -177,13 +186,14 @@ async def get_session( Retrieve the details of a specific chat session. Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages. + If there's an active stream for this session, returns the task_id for reconnection. Args: session_id: The unique identifier for the desired chat session. user_id: The optional authenticated user ID, or None for anonymous access. Returns: - SessionDetailResponse: Details for the requested session, or None if not found. + SessionDetailResponse: Details for the requested session, including active_stream info if applicable. """ session = await get_chat_session(session_id, user_id) @@ -191,10 +201,31 @@ async def get_session( raise NotFoundError(f"Session {session_id} not found.") messages = [message.model_dump() for message in session.messages] + + # Check if there's an active stream for this session + active_stream_info = None + logger.info(f"[SSE-RECONNECT] Checking for active stream in session {session_id}") + active_task, last_message_id = await stream_registry.get_active_task_for_session( + session_id, user_id + ) + if active_task: + active_stream_info = ActiveStreamInfo( + task_id=active_task.task_id, + last_message_id=last_message_id, + ) + logger.info( + f"[SSE-RECONNECT] Session {session_id} HAS active stream: " + f"task_id={active_task.task_id}, status={active_task.status}, " + f"last_message_id={last_message_id}" + ) + else: + logger.info(f"[SSE-RECONNECT] Session {session_id} has NO active stream") + logger.info( f"Returning session {session_id}: " f"message_count={len(messages)}, " - f"roles={[m.get('role') for m in messages]}" + f"roles={[m.get('role') for m in messages]}, " + f"has_active_stream={active_stream_info is not None}" ) return SessionDetailResponse( @@ -203,6 +234,7 @@ async def get_session( updated_at=session.updated_at.isoformat(), user_id=session.user_id or None, messages=messages, + active_stream=active_stream_info, ) @@ -222,49 +254,136 @@ async def stream_chat_post( - Tool call UI elements (if invoked) - Tool execution results + The AI generation runs in a background task that continues even if the client disconnects. + All chunks are written to Redis for reconnection support. If the client disconnects, + they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off. + Args: session_id: The chat session identifier to associate with the streamed messages. request: Request body containing message, is_user_message, and optional context. user_id: Optional authenticated user ID. Returns: - StreamingResponse: SSE-formatted response chunks. + StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event + containing the task_id for reconnection. """ + import asyncio + session = await _validate_and_get_session(session_id, user_id) - async def event_generator() -> AsyncGenerator[str, None]: + # Create a task in the stream registry for reconnection support + task_id = str(uuid_module.uuid4()) + operation_id = str(uuid_module.uuid4()) + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", # Not a tool call, but needed for the model + tool_name="chat", + operation_id=operation_id, + ) + logger.info( + f"[SSE-RECONNECT] Created stream task for reconnection support: " + f"task_id={task_id}, session_id={session_id}" + ) + + # Background task that runs the AI generation independently of SSE connection + async def run_ai_generation(): chunk_count = 0 first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" + try: + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + + async for chunk in chat_service.stream_chat_completion( + session_id, + request.message, + is_user_message=request.is_user_message, + user_id=user_id, + session=session, # Pass pre-fetched session to avoid double-fetch + context=request.context, + ): + if chunk_count < 3: + logger.info( + "Chat stream chunk", + extra={ + "session_id": session_id, + "chunk_type": str(chunk.type), + }, + ) + if not first_chunk_type: + first_chunk_type = str(chunk.type) + chunk_count += 1 + # Write to Redis (subscribers will receive via pub/sub or polling) + await stream_registry.publish_chunk(task_id, chunk) + + # Mark task as completed + await stream_registry.mark_task_completed(task_id, "completed") + logger.info( + "[SSE-RECONNECT] Background AI generation completed", + extra={ + "session_id": session_id, + "task_id": task_id, + "chunk_count": chunk_count, + "first_chunk_type": first_chunk_type, + }, + ) + except Exception as e: + logger.error( + f"[SSE-RECONNECT] Error in background AI generation for session " + f"{session_id}: {e}" + ) + await stream_registry.mark_task_completed(task_id, "failed") + + # Start the AI generation in a background task + bg_task = asyncio.create_task(run_ai_generation()) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + logger.info(f"[SSE-RECONNECT] Started background AI generation task for {task_id}") + + # SSE endpoint that subscribes to the task's stream + async def event_generator() -> AsyncGenerator[str, None]: + try: + # Subscribe to the task stream (this replays existing messages + live updates) + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id="0-0", # Get all messages from the beginning + ) + + if subscriber_queue is None: + logger.error(f"Failed to subscribe to task {task_id}") + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + return + + # Read from the subscriber queue and yield to SSE + while True: + try: + chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + logger.info( + f"[SSE-RECONNECT] SSE subscriber received finish for task {task_id}" + ) + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + + except GeneratorExit: + # Client disconnected - that's fine, background task continues + logger.info( + f"[SSE-RECONNECT] SSE client disconnected for task {task_id}, " + f"background generation continues" + ) + except Exception as e: + logger.error(f"Error in SSE stream for task {task_id}: {e}") + finally: + # AI SDK protocol termination + yield "data: [DONE]\n\n" return StreamingResponse( event_generator(), @@ -409,6 +528,11 @@ async def stream_task( Raises: NotFoundError: If task_id is not found or user doesn't have access. """ + logger.info( + f"[SSE-RECONNECT] Client reconnecting to task stream: " + f"task_id={task_id}, last_message_id={last_message_id}" + ) + # Get subscriber queue from stream registry subscriber_queue = await stream_registry.subscribe_to_task( task_id=task_id, @@ -417,8 +541,15 @@ async def stream_task( ) if subscriber_queue is None: + logger.warning( + f"[SSE-RECONNECT] Task not found or access denied: task_id={task_id}" + ) raise NotFoundError(f"Task {task_id} not found or access denied.") + logger.info( + f"[SSE-RECONNECT] Successfully subscribed to task stream: task_id={task_id}" + ) + async def event_generator() -> AsyncGenerator[str, None]: import asyncio diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py index b39861bb7f..45f3fc233c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -1,12 +1,18 @@ """Stream registry for managing reconnectable SSE streams. This module provides a registry for tracking active streaming tasks and their -messages. It supports: -- Creating tasks with unique IDs for long-running operations -- Publishing stream messages to both Redis Streams and in-memory queues -- Subscribing to tasks with replay of missed messages -- Looking up tasks by operation_id for webhook callbacks -- Cross-pod real-time delivery via Redis pub/sub +messages. It uses Redis for all state management (no in-memory state), making +pods stateless and horizontally scalable. + +Architecture: +- Redis Stream: Persists all messages for replay +- Redis Pub/Sub: Real-time delivery to subscribers +- Redis Hash: Task metadata (status, session_id, etc.) + +Subscribers: +1. Replay missed messages from Redis Stream +2. Subscribe to pub/sub channel for live updates +3. No in-memory state required on the subscribing pod """ import asyncio @@ -25,13 +31,10 @@ from .response_model import StreamBaseResponse, StreamFinish logger = logging.getLogger(__name__) config = ChatConfig() -# Track active pub/sub listeners for cross-pod delivery -_pubsub_listeners: dict[str, asyncio.Task] = {} - @dataclass class ActiveTask: - """Represents an active streaming task.""" + """Represents an active streaming task (metadata only, no in-memory queues).""" task_id: str session_id: str @@ -41,22 +44,17 @@ class ActiveTask: operation_id: str status: Literal["running", "completed", "failed"] = "running" created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) - queue: asyncio.Queue[StreamBaseResponse] = field(default_factory=asyncio.Queue) asyncio_task: asyncio.Task | None = None - # Lock for atomic status checks and subscriber management - lock: asyncio.Lock = field(default_factory=asyncio.Lock) - # Set of subscriber queues for fan-out - subscribers: set[asyncio.Queue[StreamBaseResponse]] = field(default_factory=set) -# Module-level registry for active tasks -_active_tasks: dict[str, ActiveTask] = {} - # Redis key patterns TASK_META_PREFIX = "chat:task:meta:" # Hash for task metadata TASK_STREAM_PREFIX = "chat:stream:" # Redis Stream for messages TASK_OP_PREFIX = "chat:task:op:" # Operation ID -> task_id mapping -TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for cross-pod delivery +TASK_PUBSUB_PREFIX = "chat:task:pubsub:" # Pub/sub channel for real-time delivery + +# Track background tasks for this pod (just the asyncio.Task reference, not subscribers) +_local_tasks: dict[str, asyncio.Task] = {} def _get_task_meta_key(task_id: str) -> str: @@ -75,7 +73,7 @@ def _get_operation_mapping_key(operation_id: str) -> str: def _get_task_pubsub_channel(task_id: str) -> str: - """Get Redis pub/sub channel for task cross-pod delivery.""" + """Get Redis pub/sub channel for task real-time delivery.""" return f"{TASK_PUBSUB_PREFIX}{task_id}" @@ -87,7 +85,7 @@ async def create_task( tool_name: str, operation_id: str, ) -> ActiveTask: - """Create a new streaming task in memory and Redis. + """Create a new streaming task in Redis. Args: task_id: Unique identifier for the task @@ -98,7 +96,7 @@ async def create_task( operation_id: Operation ID for webhook callbacks Returns: - The created ActiveTask instance + The created ActiveTask instance (metadata only) """ task = ActiveTask( task_id=task_id, @@ -109,10 +107,7 @@ async def create_task( operation_id=operation_id, ) - # Store in memory registry - _active_tasks[task_id] = task - - # Store metadata in Redis for durability + # Store metadata in Redis redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) op_key = _get_operation_mapping_key(operation_id) @@ -136,8 +131,7 @@ async def create_task( await redis.set(op_key, task_id, ex=config.stream_ttl) logger.info( - f"Created streaming task {task_id} for operation {operation_id} " - f"in session {session_id}" + f"[SSE-RECONNECT] Created task {task_id} for session {session_id} in Redis" ) return task @@ -147,41 +141,26 @@ async def publish_chunk( task_id: str, chunk: StreamBaseResponse, ) -> str: - """Publish a chunk to the task's stream. + """Publish a chunk to Redis Stream and pub/sub channel. - Delivers to in-memory subscribers first (for real-time), then persists to - Redis Stream (for replay). This order ensures live subscribers get messages - even if Redis temporarily fails. + All delivery is via Redis - no in-memory state. Args: task_id: Task ID to publish to chunk: The stream response chunk to publish Returns: - The Redis Stream message ID (format: "timestamp-sequence"), or "0-0" if - Redis persistence failed + The Redis Stream message ID """ - # Deliver to in-memory subscribers FIRST for real-time updates - task = _active_tasks.get(task_id) - if task: - async with task.lock: - for subscriber_queue in task.subscribers: - try: - subscriber_queue.put_nowait(chunk) - except asyncio.QueueFull: - logger.warning( - f"Subscriber queue full for task {task_id}, dropping chunk" - ) - - # Then persist to Redis Stream for replay (with error handling) - message_id = "0-0" chunk_json = chunk.model_dump_json() + message_id = "0-0" + try: redis = await get_redis_async() stream_key = _get_task_stream_key(task_id) + pubsub_channel = _get_task_pubsub_channel(task_id) - # Add to Redis Stream with auto-generated ID - # The ID format is "timestamp-sequence" which gives us ordering + # Write to Redis Stream for persistence/replay raw_id = await redis.xadd( stream_key, {"data": chunk_json}, @@ -189,14 +168,13 @@ async def publish_chunk( ) message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() - # Publish to pub/sub for cross-pod real-time delivery - pubsub_channel = _get_task_pubsub_channel(task_id) + # Publish to pub/sub for real-time delivery await redis.publish(pubsub_channel, chunk_json) logger.debug(f"Published chunk to task {task_id}, message_id={message_id}") except Exception as e: logger.error( - f"Failed to persist chunk to Redis for task {task_id}: {e}", + f"Failed to publish chunk for task {task_id}: {e}", exc_info=True, ) @@ -210,6 +188,8 @@ async def subscribe_to_task( ) -> asyncio.Queue[StreamBaseResponse] | None: """Subscribe to a task's stream with replay of missed messages. + This is fully stateless - uses Redis Stream for replay and pub/sub for live updates. + Args: task_id: Task ID to subscribe to user_id: User ID for ownership validation @@ -219,102 +199,23 @@ async def subscribe_to_task( An asyncio Queue that will receive stream chunks, or None if task not found or user doesn't have access """ - # Check in-memory first - task = _active_tasks.get(task_id) - - if task: - # Validate ownership - if user_id and task.user_id and task.user_id != user_id: - logger.warning( - f"User {user_id} attempted to subscribe to task {task_id} " - f"owned by {task.user_id}" - ) - return None - - # Create a new queue for this subscriber - subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() - - # Replay from Redis Stream - redis = await get_redis_async() - stream_key = _get_task_stream_key(task_id) - - # Track the last message ID we've seen for gap detection - replay_last_id = last_message_id - - # Read all messages from stream starting after last_message_id - # xread returns messages with ID > last_message_id - messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) - - if messages: - # messages format: [[stream_name, [(id, {data: json}), ...]]] - for _stream_name, stream_messages in messages: - for msg_id, msg_data in stream_messages: - # Track the last message ID we've processed - replay_last_id = ( - msg_id if isinstance(msg_id, str) else msg_id.decode() - ) - if b"data" in msg_data: - try: - chunk_data = orjson.loads(msg_data[b"data"]) - # Reconstruct the appropriate response type - chunk = _reconstruct_chunk(chunk_data) - if chunk: - await subscriber_queue.put(chunk) - except Exception as e: - logger.warning(f"Failed to replay message: {e}") - - # Atomically check status and register subscriber under lock - # This prevents race condition where task completes between check and subscribe - should_start_pubsub = False - async with task.lock: - if task.status == "running": - # Register this subscriber for live updates - task.subscribers.add(subscriber_queue) - # Start pub/sub listener if this is the first subscriber - should_start_pubsub = len(task.subscribers) == 1 - logger.debug( - f"Registered subscriber for task {task_id}, " - f"total subscribers: {len(task.subscribers)}" - ) - else: - # Task is done, add finish marker - await subscriber_queue.put(StreamFinish()) - - # After registering, do a second read to catch any messages published - # between the first read and registration (closes the race window) - if task.status == "running": - gap_messages = await redis.xread( - {stream_key: replay_last_id}, block=0, count=1000 - ) - if gap_messages: - for _stream_name, stream_messages in gap_messages: - for _msg_id, msg_data in stream_messages: - if b"data" in msg_data: - try: - chunk_data = orjson.loads(msg_data[b"data"]) - chunk = _reconstruct_chunk(chunk_data) - if chunk: - await subscriber_queue.put(chunk) - except Exception as e: - logger.warning(f"Failed to replay gap message: {e}") - - # Start pub/sub listener outside the lock to avoid deadlocks - if should_start_pubsub: - await start_pubsub_listener(task_id) - - return subscriber_queue - - # Try to load from Redis if not in memory redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] if not meta: - logger.warning(f"Task {task_id} not found in memory or Redis") + logger.warning(f"[SSE-RECONNECT] Task {task_id} not found in Redis") return None + # Note: Redis client uses decode_responses=True, so keys are strings + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + + logger.info( + f"[SSE-RECONNECT] Subscribing to task {task_id}: status={task_status}" + ) + # Validate ownership - task_user_id = meta.get(b"user_id", b"").decode() or None if user_id and task_user_id and task_user_id != user_id: logger.warning( f"User {user_id} attempted to subscribe to task {task_id} " @@ -322,79 +223,158 @@ async def subscribe_to_task( ) return None - # Replay from Redis Stream only (task is not in memory, so it's completed/crashed) - subscriber_queue = asyncio.Queue() + subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() stream_key = _get_task_stream_key(task_id) - # Read all messages starting after last_message_id + # Step 1: Replay messages from Redis Stream messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) + replayed_count = 0 + replay_last_id = last_message_id if messages: for _stream_name, stream_messages in messages: - for _msg_id, msg_data in stream_messages: - if b"data" in msg_data: + for msg_id, msg_data in stream_messages: + replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + # Note: Redis client uses decode_responses=True, so keys are strings + if "data" in msg_data: try: - chunk_data = orjson.loads(msg_data[b"data"]) + chunk_data = orjson.loads(msg_data["data"]) chunk = _reconstruct_chunk(chunk_data) if chunk: await subscriber_queue.put(chunk) + replayed_count += 1 except Exception as e: logger.warning(f"Failed to replay message: {e}") - # Add finish marker since task is not active - await subscriber_queue.put(StreamFinish()) + logger.info( + f"[SSE-RECONNECT] Task {task_id}: replayed {replayed_count} messages " + f"(last_id={replay_last_id})" + ) + + # Step 2: If task is still running, start stream listener for live updates + if task_status == "running": + logger.info( + f"[SSE-RECONNECT] Task {task_id} is running, starting stream listener" + ) + asyncio.create_task( + _stream_listener(task_id, subscriber_queue, replay_last_id) + ) + else: + # Task is completed/failed - add finish marker + logger.info( + f"[SSE-RECONNECT] Task {task_id} is {task_status}, adding finish marker" + ) + await subscriber_queue.put(StreamFinish()) return subscriber_queue +async def _stream_listener( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], + last_replayed_id: str, +) -> None: + """Listen to Redis Stream for new messages using blocking XREAD. + + This approach avoids the duplicate message issue that can occur with pub/sub + when messages are published during the gap between replay and subscription. + + Args: + task_id: Task ID to listen for + subscriber_queue: Queue to deliver messages to + last_replayed_id: Last message ID from replay (continue from here) + """ + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + current_id = last_replayed_id + + logger.debug( + f"[SSE-RECONNECT] Stream listener started for task {task_id}, " + f"from ID {current_id}" + ) + + while True: + # Block for up to 30 seconds waiting for new messages + # This allows periodic checking if task is still running + messages = await redis.xread( + {stream_key: current_id}, block=30000, count=100 + ) + + if not messages: + # Timeout - check if task is still running + meta_key = _get_task_meta_key(task_id) + status = await redis.hget(meta_key, "status") # type: ignore[misc] + if status and status != "running": + logger.info( + f"[SSE-RECONNECT] Task {task_id} no longer running " + f"(status={status}), stopping listener" + ) + subscriber_queue.put_nowait(StreamFinish()) + break + continue + + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + current_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + + if "data" not in msg_data: + continue + + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + try: + subscriber_queue.put_nowait(chunk) + except asyncio.QueueFull: + logger.warning( + f"Subscriber queue full for task {task_id}" + ) + + # Stop listening on finish + if isinstance(chunk, StreamFinish): + logger.info( + f"[SSE-RECONNECT] Task {task_id} finished " + "via stream" + ) + return + except Exception as e: + logger.warning(f"Error processing stream message: {e}") + + except asyncio.CancelledError: + logger.debug(f"[SSE-RECONNECT] Stream listener cancelled for task {task_id}") + except Exception as e: + logger.error(f"Stream listener error for task {task_id}: {e}") + # On error, send finish to unblock subscriber + try: + subscriber_queue.put_nowait(StreamFinish()) + except asyncio.QueueFull: + pass + + async def mark_task_completed( task_id: str, status: Literal["completed", "failed"] = "completed", ) -> None: - """Mark a task as completed and publish final event. + """Mark a task as completed and publish finish event. Args: task_id: Task ID to mark as completed status: Final status ("completed" or "failed") """ - task = _active_tasks.get(task_id) - - if task: - # Acquire lock to prevent new subscribers during completion - async with task.lock: - task.status = status - # Send finish event directly to all current subscribers - finish_event = StreamFinish() - for subscriber_queue in task.subscribers: - try: - subscriber_queue.put_nowait(finish_event) - except asyncio.QueueFull: - logger.warning( - f"Subscriber queue full for task {task_id} during completion" - ) - # Clear subscribers since task is done - task.subscribers.clear() - - # Stop pub/sub listener since task is done - await stop_pubsub_listener(task_id) - - # Also publish to Redis Stream for replay (and pub/sub for cross-pod) - await publish_chunk(task_id, StreamFinish()) - - # Remove from active tasks after a short delay to allow subscribers to finish - async def _cleanup(): - await asyncio.sleep(5) - _active_tasks.pop(task_id, None) - logger.info(f"Cleaned up task {task_id} from memory") - - asyncio.create_task(_cleanup()) + # Publish finish event (goes to Redis Stream + pub/sub) + await publish_chunk(task_id, StreamFinish()) # Update Redis metadata redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) await redis.hset(meta_key, "status", status) # type: ignore[misc] - logger.info(f"Marked task {task_id} as {status}") + # Clean up local task reference if exists + _local_tasks.pop(task_id, None) + + logger.info(f"[SSE-RECONNECT] Marked task {task_id} as {status}") async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: @@ -408,43 +388,26 @@ async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: Returns: ActiveTask if found, None otherwise """ - # Check in-memory first - for task in _active_tasks.values(): - if task.operation_id == operation_id: - return task - - # Try Redis lookup redis = await get_redis_async() op_key = _get_operation_mapping_key(operation_id) task_id = await redis.get(op_key) - if task_id: - task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id - # Check if task is in memory - if task_id_str in _active_tasks: - return _active_tasks[task_id_str] + logger.info( + f"[SSE-RECONNECT] find_task_by_operation_id: " + f"op_key={op_key}, task_id_from_redis={task_id!r}" + ) - # Load metadata from Redis - meta_key = _get_task_meta_key(task_id_str) - meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + if not task_id: + logger.info(f"[SSE-RECONNECT] No task_id found for operation {operation_id}") + return None - if meta: - # Reconstruct task object (not fully active, but has metadata) - return ActiveTask( - task_id=meta.get(b"task_id", b"").decode(), - session_id=meta.get(b"session_id", b"").decode(), - user_id=meta.get(b"user_id", b"").decode() or None, - tool_call_id=meta.get(b"tool_call_id", b"").decode(), - tool_name=meta.get(b"tool_name", b"").decode(), - operation_id=operation_id, - status=meta.get(b"status", b"running").decode(), # type: ignore - ) - - return None + task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id + logger.info(f"[SSE-RECONNECT] Looking up task by task_id={task_id_str}") + return await get_task(task_id_str) async def get_task(task_id: str) -> ActiveTask | None: - """Get a task by its ID. + """Get a task by its ID from Redis. Args: task_id: Task ID to look up @@ -452,27 +415,127 @@ async def get_task(task_id: str) -> ActiveTask | None: Returns: ActiveTask if found, None otherwise """ - # Check in-memory first - if task_id in _active_tasks: - return _active_tasks[task_id] - - # Try Redis lookup redis = await get_redis_async() meta_key = _get_task_meta_key(task_id) meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] - if meta: - return ActiveTask( - task_id=meta.get(b"task_id", b"").decode(), - session_id=meta.get(b"session_id", b"").decode(), - user_id=meta.get(b"user_id", b"").decode() or None, - tool_call_id=meta.get(b"tool_call_id", b"").decode(), - tool_name=meta.get(b"tool_name", b"").decode(), - operation_id=meta.get(b"operation_id", b"").decode(), - status=meta.get(b"status", b"running").decode(), # type: ignore[arg-type] + logger.info( + f"[SSE-RECONNECT] get_task: meta_key={meta_key}, " + f"meta_keys={list(meta.keys()) if meta else 'empty'}, " + f"meta={meta}" + ) + + if not meta: + logger.info(f"[SSE-RECONNECT] No metadata found for task {task_id}") + return None + + # Note: Redis client uses decode_responses=True, so keys/values are strings + task = ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ) + logger.info( + f"[SSE-RECONNECT] get_task returning: task_id={task.task_id}, " + f"session_id={task.session_id}, operation_id={task.operation_id}" + ) + return task + + +async def get_active_task_for_session( + session_id: str, + user_id: str | None = None, +) -> tuple[ActiveTask | None, str]: + """Get the active (running) task for a session, if any. + + Scans Redis for tasks matching the session_id with status="running". + + Args: + session_id: Session ID to look up + user_id: User ID for ownership validation (optional) + + Returns: + Tuple of (ActiveTask if found and running, last_message_id from Redis Stream) + """ + logger.info(f"[SSE-RECONNECT] Looking for active task for session {session_id}") + + redis = await get_redis_async() + + # Scan Redis for task metadata keys + cursor = 0 + tasks_checked = 0 + + while True: + cursor, keys = await redis.scan( + cursor, match=f"{TASK_META_PREFIX}*", count=100 ) - return None + for key in keys: + tasks_checked += 1 + meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc] + if not meta: + continue + + # Note: Redis client uses decode_responses=True, so keys/values are strings + task_session_id = meta.get("session_id", "") + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + task_id = meta.get("task_id", "") + + # Log tasks found for this session + if task_session_id == session_id: + logger.info( + f"[SSE-RECONNECT] Found task for session: " + f"task_id={task_id}, status={task_status}" + ) + + if task_session_id == session_id and task_status == "running": + # Validate ownership + if user_id and task_user_id and task_user_id != user_id: + logger.info(f"[SSE-RECONNECT] Task {task_id} ownership mismatch") + continue + + # Get the last message ID from Redis Stream + stream_key = _get_task_stream_key(task_id) + last_id = "0-0" + try: + messages = await redis.xrevrange(stream_key, count=1) + if messages: + msg_id = messages[0][0] + last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + except Exception as e: + logger.warning(f"Failed to get last message ID: {e}") + + logger.info( + f"[SSE-RECONNECT] Found active task: task_id={task_id}, " + f"last_message_id={last_id}" + ) + + return ( + ActiveTask( + task_id=task_id, + session_id=task_session_id, + user_id=task_user_id, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status="running", + ), + last_id, + ) + + if cursor == 0: + break + + logger.info( + f"[SSE-RECONNECT] No active task found for session {session_id} " + f"(checked {tasks_checked} tasks)" + ) + return None, "0-0" def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None: @@ -533,116 +596,30 @@ def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None: async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None: - """Associate an asyncio.Task with an ActiveTask. + """Track the asyncio.Task for a task (local reference only). + + This is just for cleanup purposes - the task state is in Redis. Args: task_id: Task ID - asyncio_task: The asyncio Task to associate + asyncio_task: The asyncio Task to track """ - task = _active_tasks.get(task_id) - if task: - task.asyncio_task = asyncio_task + _local_tasks[task_id] = asyncio_task async def unsubscribe_from_task( task_id: str, subscriber_queue: asyncio.Queue[StreamBaseResponse], ) -> None: - """Unsubscribe a queue from a task's stream. + """Clean up when a subscriber disconnects. - Should be called when a client disconnects to clean up resources. - Also stops the pub/sub listener if there are no more local subscribers. + With Redis-based pub/sub, there's no explicit unsubscription needed. + The pub/sub listener task will be garbage collected when the subscriber + stops reading from the queue. Args: - task_id: Task ID to unsubscribe from - subscriber_queue: The queue to remove from subscribers + task_id: Task ID + subscriber_queue: The subscriber's queue (unused, kept for API compat) """ - task = _active_tasks.get(task_id) - if task: - async with task.lock: - task.subscribers.discard(subscriber_queue) - remaining = len(task.subscribers) - logger.debug( - f"Unsubscribed from task {task_id}, " - f"remaining subscribers: {remaining}" - ) - # Stop pub/sub listener if no more local subscribers - if remaining == 0: - await stop_pubsub_listener(task_id) - - -async def start_pubsub_listener(task_id: str) -> None: - """Start listening to Redis pub/sub for cross-pod delivery. - - This enables real-time updates when another pod publishes chunks for a task - that has local subscribers on this pod. - - Args: - task_id: Task ID to listen for - """ - if task_id in _pubsub_listeners: - return # Already listening - - task = _active_tasks.get(task_id) - if not task: - return - - async def _listener(): - try: - redis = await get_redis_async() - pubsub = redis.pubsub() - channel = _get_task_pubsub_channel(task_id) - await pubsub.subscribe(channel) - logger.debug(f"Started pub/sub listener for task {task_id}") - - async for message in pubsub.listen(): - if message["type"] != "message": - continue - - try: - chunk_data = orjson.loads(message["data"]) - chunk = _reconstruct_chunk(chunk_data) - if chunk: - # Deliver to local subscribers - local_task = _active_tasks.get(task_id) - if local_task: - async with local_task.lock: - for queue in local_task.subscribers: - try: - queue.put_nowait(chunk) - except asyncio.QueueFull: - pass - # Stop listening if this was a finish event - if isinstance(chunk, StreamFinish): - break - except Exception as e: - logger.warning(f"Error processing pub/sub message: {e}") - - await pubsub.unsubscribe(channel) - await pubsub.close() - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"Pub/sub listener error for task {task_id}: {e}") - finally: - _pubsub_listeners.pop(task_id, None) - logger.debug(f"Stopped pub/sub listener for task {task_id}") - - listener_task = asyncio.create_task(_listener()) - _pubsub_listeners[task_id] = listener_task - - -async def stop_pubsub_listener(task_id: str) -> None: - """Stop the pub/sub listener for a task. - - Args: - task_id: Task ID to stop listening for - """ - listener = _pubsub_listeners.pop(task_id, None) - if listener and not listener.done(): - listener.cancel() - try: - await listener - except asyncio.CancelledError: - pass - logger.debug(f"Cancelled pub/sub listener for task {task_id}") + # No-op - pub/sub listener cleans up automatically + logger.debug(f"[SSE-RECONNECT] Subscriber disconnected from task {task_id}") diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index e88a8c2d18..43ae0eb180 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -1079,7 +1079,7 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Session", - "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.", + "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.", "operationId": "getV2GetSession", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1214,7 +1214,7 @@ "post": { "tags": ["v2", "chat", "chat"], "summary": "Stream Chat Post", - "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.", + "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.", "operationId": "postV2StreamChatPost", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -6313,6 +6313,16 @@ "title": "AccuracyTrendsResponse", "description": "Response model for accuracy trends and alerts." }, + "ActiveStreamInfo": { + "properties": { + "task_id": { "type": "string", "title": "Task Id" }, + "last_message_id": { "type": "string", "title": "Last Message Id" } + }, + "type": "object", + "required": ["task_id", "last_message_id"], + "title": "ActiveStreamInfo", + "description": "Information about an active stream for reconnection." + }, "AddUserCreditsResponse": { "properties": { "new_balance": { "type": "integer", "title": "New Balance" }, @@ -9808,6 +9818,12 @@ "items": { "additionalProperties": true, "type": "object" }, "type": "array", "title": "Messages" + }, + "active_stream": { + "anyOf": [ + { "$ref": "#/components/schemas/ActiveStreamInfo" }, + { "type": "null" } + ] } }, "type": "object", diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx index 5a2eef33b5..e0eee93d13 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx @@ -25,6 +25,7 @@ export function Chat({ const { urlSessionId } = useCopilotSessionId(); const hasHandledNotFoundRef = useRef(false); const { + session, messages, isLoading, isCreating, @@ -36,6 +37,21 @@ export function Chat({ startPollingForOperation, } = useChat({ urlSessionId }); + // Extract active stream info for reconnection + const activeStream = (session as { active_stream?: { task_id: string; last_message_id: string } })?.active_stream; + + // Debug logging for SSE reconnection + if (session) { + console.info("[SSE-RECONNECT] Session loaded:", { + sessionId, + hasActiveStream: !!activeStream, + activeStream: activeStream ? { + taskId: activeStream.task_id, + lastMessageId: activeStream.last_message_id, + } : null, + }); + } + useEffect(() => { if (!onSessionNotFound) return; if (!urlSessionId) return; @@ -83,6 +99,10 @@ export function Chat({ className="flex-1" onStreamingChange={onStreamingChange} onOperationStarted={startPollingForOperation} + activeStream={activeStream ? { + taskId: activeStream.task_id, + lastMessageId: activeStream.last_message_id, + } : undefined} /> )} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts index 9149ca5d04..ed1e6dcd99 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts @@ -391,6 +391,12 @@ export const useChatStore = create((set, get) => ({ lastMessageId = "0-0", // Redis Stream ID format onChunk, ) { + console.info("[SSE-RECONNECT] reconnectToTask called:", { + sessionId, + taskId, + lastMessageId, + }); + const state = get(); const newActiveStreams = new Map(state.activeStreams); let newCompletedStreams = new Map(state.completedStreams); @@ -435,8 +441,14 @@ export const useChatStore = create((set, get) => ({ completedStreams: newCompletedStreams, }); + console.info("[SSE-RECONNECT] Starting executeTaskReconnect..."); try { await executeTaskReconnect(stream, taskId, lastMessageId); + console.info("[SSE-RECONNECT] executeTaskReconnect completed:", { + sessionId, + taskId, + streamStatus: stream.status, + }); } finally { if (onChunk) stream.onChunkCallbacks.delete(onChunk); if (stream.status !== "streaming") { @@ -468,9 +480,16 @@ export const useChatStore = create((set, get) => ({ // Clear active task on completion const taskState = get(); const newActiveTasks = new Map(taskState.activeTasks); + const hadActiveTask = newActiveTasks.has(sessionId); newActiveTasks.delete(sessionId); set({ activeTasks: newActiveTasks }); persistTasks(newActiveTasks); + if (hadActiveTask) { + console.info( + `[ChatStore] Cleared active task for session ${sessionId} ` + + `(stream status: ${stream.status})`, + ); + } } } } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts index 8c8aa7b704..e04e69b32b 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts @@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error"; export interface StreamChunk { type: + | "stream_start" | "text_chunk" | "text_ended" | "tool_call" @@ -15,6 +16,7 @@ export interface StreamChunk { | "error" | "usage" | "stream_end"; + taskId?: string; // Task ID for SSE reconnection timestamp?: string; content?: string; message?: string; @@ -41,7 +43,7 @@ export interface StreamChunk { } export type VercelStreamChunk = - | { type: "start"; messageId: string } + | { type: "start"; messageId: string; taskId?: string } | { type: "finish" } | { type: "text-start"; id: string } | { type: "text-delta"; id: string; delta: string } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx index dec221338a..d49d01c11f 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx @@ -17,6 +17,11 @@ export interface ChatContainerProps { className?: string; onStreamingChange?: (isStreaming: boolean) => void; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + }; } export function ChatContainer({ @@ -26,6 +31,7 @@ export function ChatContainer({ className, onStreamingChange, onOperationStarted, + activeStream, }: ChatContainerProps) { const { messages, @@ -41,6 +47,7 @@ export function ChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }); useEffect(() => { diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts index 82e9b05e88..ddab29e1f1 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts @@ -34,6 +34,22 @@ export function createStreamEventDispatcher( } switch (chunk.type) { + case "stream_start": + // Store task ID for SSE reconnection + if (chunk.taskId && deps.onActiveTaskStarted) { + console.info("[ChatStream] Stream started with task ID:", { + sessionId: deps.sessionId, + taskId: chunk.taskId, + }); + deps.onActiveTaskStarted({ + taskId: chunk.taskId, + operationId: chunk.taskId, // Use taskId as operationId for chat streams + toolName: "chat", + toolCallId: "chat_stream", + }); + } + break; + case "text_chunk": handleTextChunk(chunk, deps); break; @@ -56,7 +72,8 @@ export function createStreamEventDispatcher( break; case "stream_end": - console.info("[ChatStream] Stream ended:", { + // Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk + console.info("[SSE-RECONNECT] Stream ended:", { sessionId: deps.sessionId, hasResponse: deps.hasResponseRef.current, chunkCount: deps.streamingChunksRef.current.length, diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts index 34941ed6a7..cd260abab9 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts @@ -221,8 +221,10 @@ export function handleStreamEnd( _chunk: StreamChunk, deps: HandlerDependencies, ) { + console.info("[SSE-RECONNECT] handleStreamEnd called, resetting streaming state"); const completedContent = deps.streamingChunksRef.current.join(""); if (!completedContent.trim() && !deps.hasResponseRef.current) { + console.info("[SSE-RECONNECT] No content received, adding placeholder message"); deps.setMessages((prev) => [ ...prev, { @@ -261,10 +263,14 @@ export function handleStreamEnd( export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { const errorMessage = chunk.message || chunk.content || "An error occurred"; - console.error("Stream error:", errorMessage); + console.error("[ChatStream] Stream error:", errorMessage, { + sessionId: deps.sessionId, + chunk, + }); if (isRegionBlockedError(chunk)) { deps.setIsRegionBlockedModalOpen(true); } + console.info("[ChatStream] Resetting streaming state due to error"); deps.setIsStreamingInitiated(false); deps.setHasTextChunks(false); deps.setStreamingChunks([]); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts index f72b64aae2..6d18295044 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts @@ -41,6 +41,11 @@ interface Args { initialMessages: SessionDetailResponse["messages"]; initialPrompt?: string; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + }; } export function useChatContainer({ @@ -48,6 +53,7 @@ export function useChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }: Args) { const [messages, setMessages] = useState([]); const [streamingChunks, setStreamingChunks] = useState([]); @@ -69,6 +75,8 @@ export function useChatContainer({ const getActiveTask = useChatStore((s) => s.getActiveTask); const reconnectToTask = useChatStore((s) => s.reconnectToTask); const isStreaming = isStreamingInitiated || hasTextChunks; + // Track whether we've already connected to this activeStream to avoid duplicate connections + const connectedActiveStreamRef = useRef(null); // Callback to store active task info for SSE reconnection function handleActiveTaskStarted(taskInfo: { @@ -88,25 +96,131 @@ export function useChatContainer({ useEffect( function handleSessionChange() { - if (sessionId === previousSessionIdRef.current) return; + const isSessionChange = sessionId !== previousSessionIdRef.current; - const prevSession = previousSessionIdRef.current; - if (prevSession) { - stopStreaming(prevSession); + console.info("[SSE-RECONNECT] handleSessionChange effect running:", { + sessionId, + previousSessionId: previousSessionIdRef.current, + isSessionChange, + hasActiveStream: !!activeStream, + activeStreamTaskId: activeStream?.taskId, + connectedActiveStream: connectedActiveStreamRef.current, + }); + + // Handle session change - reset state + if (isSessionChange) { + console.info("[SSE-RECONNECT] Session changed, resetting state"); + const prevSession = previousSessionIdRef.current; + if (prevSession) { + stopStreaming(prevSession); + } + previousSessionIdRef.current = sessionId; + connectedActiveStreamRef.current = null; // Reset connected stream tracker + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + setIsStreamingInitiated(false); + hasResponseRef.current = false; } - previousSessionIdRef.current = sessionId; - setMessages([]); - setStreamingChunks([]); - streamingChunksRef.current = []; - setHasTextChunks(false); - setIsStreamingInitiated(false); - hasResponseRef.current = false; - if (!sessionId) return; + if (!sessionId) { + console.info("[SSE-RECONNECT] No sessionId, skipping reconnection check"); + return; + } - // Check if there's an active task for this session that we should reconnect to + // Priority 1: Check if server told us there's an active stream (most authoritative) + // Also handles the case where activeStream arrives after initial session load + if (activeStream) { + // Skip if we've already connected to this exact stream + // Check and set immediately to prevent race conditions from effect re-runs + const streamKey = `${sessionId}:${activeStream.taskId}`; + if (connectedActiveStreamRef.current === streamKey) { + console.info( + "[SSE-RECONNECT] Already connected to this stream, skipping:", + { streamKey }, + ); + return; + } + + // Also skip if there's already an active stream for this session in the store + // (handles case where effect re-runs due to activeStreams state change) + const existingStream = activeStreams.get(sessionId); + if (existingStream && existingStream.status === "streaming") { + console.info( + "[SSE-RECONNECT] Active stream already exists in store, skipping:", + { sessionId, status: existingStream.status }, + ); + connectedActiveStreamRef.current = streamKey; + return; + } + + // Set immediately after check to prevent race conditions + connectedActiveStreamRef.current = streamKey; + + console.info( + "[SSE-RECONNECT] Server reports active stream, initiating reconnection:", + { + sessionId, + taskId: activeStream.taskId, + lastMessageId: activeStream.lastMessageId, + streamKey, + }, + ); + + const dispatcher = createStreamEventDispatcher({ + setHasTextChunks, + setStreamingChunks, + streamingChunksRef, + hasResponseRef, + setMessages, + setIsRegionBlockedModalOpen, + sessionId, + setIsStreamingInitiated, + onOperationStarted, + onActiveTaskStarted: handleActiveTaskStarted, + }); + + setIsStreamingInitiated(true); + // Store this as the active task for future reconnects + setActiveTask(sessionId, { + taskId: activeStream.taskId, + operationId: activeStream.taskId, + toolName: "chat", + lastMessageId: activeStream.lastMessageId, + }); + // Reconnect to the task stream + console.info("[SSE-RECONNECT] Calling reconnectToTask..."); + reconnectToTask( + sessionId, + activeStream.taskId, + activeStream.lastMessageId, + dispatcher, + ); + return; + } + + // Only check localStorage/in-memory on session change, not on every render + if (!isSessionChange) { + console.info( + "[SSE-RECONNECT] No active stream and not a session change, skipping fallbacks", + ); + return; + } + + // Priority 2: Check localStorage for active task (client-side state) + console.info("[SSE-RECONNECT] Checking localStorage for active task..."); const activeTask = getActiveTask(sessionId); if (activeTask) { + console.info( + "[SSE-RECONNECT] Found active task in localStorage, attempting reconnect:", + { + sessionId, + taskId: activeTask.taskId, + lastMessageId: activeTask.lastMessageId, + }, + ); + const dispatcher = createStreamEventDispatcher({ setHasTextChunks, setStreamingChunks, @@ -122,6 +236,7 @@ export function useChatContainer({ setIsStreamingInitiated(true); // Reconnect to the task stream + console.info("[SSE-RECONNECT] Calling reconnectToTask from localStorage..."); reconnectToTask( sessionId, activeTask.taskId, @@ -129,11 +244,20 @@ export function useChatContainer({ dispatcher, ); return; + } else { + console.info("[SSE-RECONNECT] No active task in localStorage"); } - // Otherwise check for an in-memory active stream - const activeStream = activeStreams.get(sessionId); - if (!activeStream || activeStream.status !== "streaming") return; + // Priority 3: Check for an in-memory active stream (same-tab scenario) + console.info("[SSE-RECONNECT] Checking in-memory active streams..."); + const inMemoryStream = activeStreams.get(sessionId); + if (!inMemoryStream || inMemoryStream.status !== "streaming") { + console.info("[SSE-RECONNECT] No in-memory active stream found:", { + hasStream: !!inMemoryStream, + status: inMemoryStream?.status, + }); + return; + } const dispatcher = createStreamEventDispatcher({ setHasTextChunks, @@ -160,6 +284,8 @@ export function useChatContainer({ onOperationStarted, getActiveTask, reconnectToTask, + activeStream, + setActiveTask, ], ); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts b/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts index 97781ebf06..b491a14024 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts @@ -167,8 +167,15 @@ export async function executeTaskReconnect( ): Promise { const { abortController } = stream; + console.info("[SSE-RECONNECT] executeTaskReconnect starting:", { + taskId, + lastMessageId, + retryCount, + }); + try { const url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`; + console.info("[SSE-RECONNECT] Fetching task stream:", { url }); const response = await fetch(url, { method: "GET", @@ -178,15 +185,33 @@ export async function executeTaskReconnect( signal: abortController.signal, }); + console.info("[SSE-RECONNECT] Task stream response:", { + status: response.status, + ok: response.ok, + }); + if (!response.ok) { const errorText = await response.text(); - throw new Error(errorText || `HTTP ${response.status}`); + console.error("[SSE-RECONNECT] Task stream error response:", { + status: response.status, + errorText, + }); + // Don't retry on 404 (task not found) or 403 (access denied) - these are permanent errors + const isPermanentError = + response.status === 404 || response.status === 403; + const error = new Error(errorText || `HTTP ${response.status}`); + (error as Error & { status?: number }).status = response.status; + (error as Error & { isPermanent?: boolean }).isPermanent = + isPermanentError; + throw error; } if (!response.body) { throw new Error("Response body is null"); } + console.info("[SSE-RECONNECT] Task stream connected, reading chunks..."); + const reader = response.body.getReader(); const decoder = new TextDecoder(); let buffer = ""; @@ -195,6 +220,7 @@ export async function executeTaskReconnect( const { done, value } = await reader.read(); if (done) { + console.info("[SSE-RECONNECT] Task stream reader done (connection closed)"); notifySubscribers(stream, { type: "stream_end" }); stream.status = "completed"; return; @@ -208,6 +234,7 @@ export async function executeTaskReconnect( const data = parseSSELine(line); if (data !== null) { if (data === "[DONE]") { + console.info("[SSE-RECONNECT] Task stream received [DONE] signal"); notifySubscribers(stream, { type: "stream_end" }); stream.status = "completed"; return; @@ -220,14 +247,24 @@ export async function executeTaskReconnect( const chunk = normalizeStreamChunk(rawChunk); if (!chunk) continue; + // Log first few chunks for debugging + if (stream.chunks.length < 3) { + console.info("[SSE-RECONNECT] Task stream chunk received:", { + type: chunk.type, + chunkIndex: stream.chunks.length, + }); + } + notifySubscribers(stream, chunk); if (chunk.type === "stream_end") { + console.info("[SSE-RECONNECT] Task stream completed via stream_end chunk"); stream.status = "completed"; return; } if (chunk.type === "error") { + console.error("[SSE-RECONNECT] Task stream error chunk:", chunk); stream.status = "error"; stream.error = new Error( chunk.message || chunk.content || "Stream error", @@ -250,17 +287,35 @@ export async function executeTaskReconnect( return; } - if (retryCount < MAX_RETRIES) { + // Check if this is a permanent error (404/403) that shouldn't be retried + const isPermanentError = + err instanceof Error && + (err as Error & { isPermanent?: boolean }).isPermanent; + + if (!isPermanentError && retryCount < MAX_RETRIES) { const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount); console.log( `[StreamExecutor] Task reconnect retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`, ); await new Promise((resolve) => setTimeout(resolve, retryDelay)); - return executeTaskReconnect(stream, taskId, lastMessageId, retryCount + 1); + return executeTaskReconnect( + stream, + taskId, + lastMessageId, + retryCount + 1, + ); + } + + // Log permanent errors differently for debugging + if (isPermanentError) { + console.log( + `[StreamExecutor] Task reconnect failed permanently (task not found or access denied): ${(err as Error).message}`, + ); } stream.status = "error"; - stream.error = err instanceof Error ? err : new Error("Task reconnect failed"); + stream.error = + err instanceof Error ? err : new Error("Task reconnect failed"); notifySubscribers(stream, { type: "error", message: stream.error.message, diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts b/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts index 4100926e79..d1953ef0e6 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts @@ -28,7 +28,8 @@ export function normalizeStreamChunk( switch (chunk.type) { case "text-delta": - return { type: "text_chunk", content: chunk.delta }; + // Backend sends "content", Vercel AI SDK sends "delta" + return { type: "text_chunk", content: chunk.delta || chunk.content }; case "text-end": return { type: "text_ended" }; case "tool-input-available": @@ -63,6 +64,10 @@ export function normalizeStreamChunk( case "finish": return { type: "stream_end" }; case "start": + // Start event with optional taskId for reconnection + return chunk.taskId + ? { type: "stream_start", taskId: chunk.taskId } + : null; case "text-start": return null; case "tool-input-start":