From f4bf492f24da4790dbf9f9df78061144d34034ac Mon Sep 17 00:00:00 2001 From: Swifty Date: Tue, 3 Feb 2026 16:52:06 +0100 Subject: [PATCH] feat(platform): Add Redis-based SSE reconnection for long-running CoPilot operations (#11877) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Changes 🏗️ Adds Redis-based SSE reconnection support for long-running CoPilot operations (like Agent Generator), enabling clients to reconnect and resume receiving updates after disconnection. ### What this does: - **Stream Registry** - Redis-backed task tracking with message persistence via Redis Streams - **SSE Reconnection** - Clients can reconnect to active tasks using `task_id` and `last_message_id` - **Duplicate Message Fix** - Filters out in-progress assistant messages from session response when active stream exists - **Completion Consumer** - Handles background task completion notifications via Redis Streams ### Architecture: ``` 1. User sends message → Backend creates task in Redis 2. SSE chunks written to Redis Stream for persistence 3. Client receives chunks via SSE subscription 4. If client disconnects → Task continues in background 5. Client reconnects → GET /sessions/{id} returns active_stream info 6. Client subscribes to /tasks/{task_id}/stream with last_message_id 7. Missed messages replayed from Redis Stream ``` ### Key endpoints: - `GET /sessions/{session_id}` - Returns `active_stream` info if task is running - `GET /tasks/{task_id}/stream?last_message_id=X` - SSE endpoint for reconnection - `GET /tasks/{task_id}` - Get task status - `POST /operations/{op_id}/complete` - Webhook for external service completion ### Duplicate message fix: When `GET /sessions/{id}` detects an active stream: 1. Filters out the in-progress assistant message from response 2. Returns `last_message_id="0-0"` so client replays stream from beginning 3. 
Client receives complete response only through SSE (single source of truth) ### Frontend changes: - Task persistence in localStorage for cross-tab reconnection - Stream event dispatcher handles reconnection flow - Deduplication logic prevents duplicate messages ### Testing: - Manual testing of reconnection scenarios - Verified duplicate message fix works correctly ## Related - Resolves SSE timeout issues for Agent Generator - Fixes duplicate message bug on reconnection --- .../api/features/chat/completion_consumer.py | 368 +++++++++ .../api/features/chat/completion_handler.py | 344 +++++++++ .../backend/api/features/chat/config.py | 50 ++ .../api/features/chat/response_model.py | 4 + .../backend/api/features/chat/routes.py | 427 ++++++++++- .../backend/api/features/chat/service.py | 291 +++++++- .../api/features/chat/stream_registry.py | 704 ++++++++++++++++++ .../chat/tools/agent_generator/core.py | 27 +- .../chat/tools/agent_generator/service.py | 46 +- .../api/features/chat/tools/create_agent.py | 25 +- .../api/features/chat/tools/edit_agent.py | 25 +- .../backend/api/features/chat/tools/models.py | 21 + .../api/features/store/embeddings_e2e_test.py | 9 +- .../backend/backend/api/rest_api.py | 16 + .../agent_generator/test_core_integration.py | 9 +- .../CopilotShell/useCopilotShell.ts | 73 +- .../api/chat/tasks/[taskId]/stream/route.ts | 81 ++ .../frontend/src/app/api/openapi.json | 210 +++++- .../src/components/contextual/Chat/Chat.tsx | 36 +- .../contextual/Chat/SSE_RECONNECTION.md | 159 ++++ .../contextual/Chat/chat-constants.ts | 16 + .../components/contextual/Chat/chat-store.ts | 348 +++++++-- .../components/contextual/Chat/chat-types.ts | 71 +- .../ChatContainer/ChatContainer.tsx | 9 + .../createStreamEventDispatcher.ts | 27 +- .../Chat/components/ChatContainer/handlers.ts | 154 +++- .../Chat/components/ChatContainer/helpers.ts | 1 + .../ChatContainer/useChatContainer.ts | 343 +++++++-- .../components/ChatMessage/useChatMessage.ts | 1 + .../components/MessageList/MessageList.tsx | 5 - .../contextual/Chat/stream-executor.ts | 189 ++++- .../contextual/Chat/stream-utils.ts | 5 + 32 files changed, 3747 insertions(+), 347 deletions(-) create mode 100644 autogpt_platform/backend/backend/api/features/chat/completion_consumer.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/completion_handler.py create mode 100644 autogpt_platform/backend/backend/api/features/chat/stream_registry.py create mode 100644 autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts create mode 100644 autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md create mode 100644 autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py new file mode 100644 index 0000000000..f447d46bd7 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py @@ -0,0 +1,368 @@ +"""Redis Streams consumer for operation completion messages. + +This module provides a consumer (ChatCompletionConsumer) that listens for +completion notifications (OperationCompleteMessage) from external services +(like Agent Generator) and triggers the appropriate stream registry and +chat service updates via process_operation_success/process_operation_failure. + +Why Redis Streams instead of RabbitMQ? 
+-------------------------------------- +While the project typically uses RabbitMQ for async task queues (e.g., execution +queue), Redis Streams was chosen for chat completion notifications because: + +1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis + Streams (via stream_registry) for message persistence and replay. Using Redis + Streams for completion notifications keeps all chat streaming infrastructure + in one system, simplifying operations and reducing cross-system coordination. + +2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs, + allowing consumers to replay missed messages after reconnection. This aligns + with the SSE reconnection pattern where clients can resume from last_message_id. + +3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic + load balancing across pods with explicit message claiming (XAUTOCLAIM) for + recovering from dead consumers - ideal for the completion callback pattern. + +4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for + stream_registry) provides lower latency than an additional RabbitMQ hop. + +5. **Atomicity with Task State**: Completion processing often needs to update + task metadata stored in Redis. Keeping both in Redis enables simpler + transactional semantics without distributed coordination. + +The consumer uses Redis Streams with consumer groups for reliable message +processing across multiple platform pods, with XAUTOCLAIM for reclaiming +stale pending messages from dead consumers. +""" + +import asyncio +import logging +import os +import uuid +from typing import Any + +import orjson +from prisma import Prisma +from pydantic import BaseModel +from redis.exceptions import ResponseError + +from backend.data.redis_client import get_redis_async + +from . import stream_registry +from .completion_handler import process_operation_failure, process_operation_success +from .config import ChatConfig + +logger = logging.getLogger(__name__) +config = ChatConfig() + + +class OperationCompleteMessage(BaseModel): + """Message format for operation completion notifications.""" + + operation_id: str + task_id: str + success: bool + result: dict | str | None = None + error: str | None = None + + +class ChatCompletionConsumer: + """Consumer for chat operation completion messages from Redis Streams. + + This consumer initializes its own Prisma client in start() to ensure + database operations work correctly within this async context. + + Uses Redis consumer groups to allow multiple platform pods to consume + messages reliably with automatic redelivery on failure. 
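+
+    A minimal lifecycle sketch (assuming an asyncio entrypoint; start() and
+    stop() are this class's actual entry points):
+
+        consumer = ChatCompletionConsumer()
+        await consumer.start()  # create consumer group, spawn read loop
+        ...
+        await consumer.stop()   # cancel loop, disconnect Prisma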
+ """ + + def __init__(self): + self._consumer_task: asyncio.Task | None = None + self._running = False + self._prisma: Prisma | None = None + self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" + + async def start(self) -> None: + """Start the completion consumer.""" + if self._running: + logger.warning("Completion consumer already running") + return + + # Create consumer group if it doesn't exist + try: + redis = await get_redis_async() + await redis.xgroup_create( + config.stream_completion_name, + config.stream_consumer_group, + id="0", + mkstream=True, + ) + logger.info( + f"Created consumer group '{config.stream_consumer_group}' " + f"on stream '{config.stream_completion_name}'" + ) + except ResponseError as e: + if "BUSYGROUP" in str(e): + logger.debug( + f"Consumer group '{config.stream_consumer_group}' already exists" + ) + else: + raise + + self._running = True + self._consumer_task = asyncio.create_task(self._consume_messages()) + logger.info( + f"Chat completion consumer started (consumer: {self._consumer_name})" + ) + + async def _ensure_prisma(self) -> Prisma: + """Lazily initialize Prisma client on first use.""" + if self._prisma is None: + database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432") + self._prisma = Prisma(datasource={"url": database_url}) + await self._prisma.connect() + logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") + return self._prisma + + async def stop(self) -> None: + """Stop the completion consumer.""" + self._running = False + + if self._consumer_task: + self._consumer_task.cancel() + try: + await self._consumer_task + except asyncio.CancelledError: + pass + self._consumer_task = None + + if self._prisma: + await self._prisma.disconnect() + self._prisma = None + logger.info("[COMPLETION] Consumer Prisma client disconnected") + + logger.info("Chat completion consumer stopped") + + async def _consume_messages(self) -> None: + """Main message consumption loop with retry logic.""" + max_retries = 10 + retry_delay = 5 # seconds + retry_count = 0 + block_timeout = 5000 # milliseconds + + while self._running and retry_count < max_retries: + try: + redis = await get_redis_async() + + # Reset retry count on successful connection + retry_count = 0 + + while self._running: + # First, claim any stale pending messages from dead consumers + # Redis does NOT auto-redeliver pending messages; we must explicitly + # claim them using XAUTOCLAIM + try: + claimed_result = await redis.xautoclaim( + name=config.stream_completion_name, + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + min_idle_time=config.stream_claim_min_idle_ms, + start_id="0-0", + count=10, + ) + # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids]) + if claimed_result and len(claimed_result) >= 2: + claimed_entries = claimed_result[1] + if claimed_entries: + logger.info( + f"Claimed {len(claimed_entries)} stale pending messages" + ) + for entry_id, data in claimed_entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + except Exception as e: + logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}") + + # Read new messages from the stream + messages = await redis.xreadgroup( + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + streams={config.stream_completion_name: ">"}, + block=block_timeout, + count=10, + ) + + if not messages: + continue + + for stream_name, entries in messages: + for entry_id, data in entries: + if not self._running: + return + 
await self._process_entry(redis, entry_id, data) + + except asyncio.CancelledError: + logger.info("Consumer cancelled") + return + except Exception as e: + retry_count += 1 + logger.error( + f"Consumer error (retry {retry_count}/{max_retries}): {e}", + exc_info=True, + ) + if self._running and retry_count < max_retries: + await asyncio.sleep(retry_delay) + else: + logger.error("Max retries reached, stopping consumer") + return + + async def _process_entry( + self, redis: Any, entry_id: str, data: dict[str, Any] + ) -> None: + """Process a single stream entry and acknowledge it on success. + + Args: + redis: Redis client connection + entry_id: The stream entry ID + data: The entry data dict + """ + try: + # Handle the message + message_data = data.get("data") + if message_data: + await self._handle_message( + message_data.encode() + if isinstance(message_data, str) + else message_data + ) + + # Acknowledge the message after successful processing + await redis.xack( + config.stream_completion_name, + config.stream_consumer_group, + entry_id, + ) + except Exception as e: + logger.error( + f"Error processing completion message {entry_id}: {e}", + exc_info=True, + ) + # Message remains in pending state and will be claimed by + # XAUTOCLAIM after min_idle_time expires + + async def _handle_message(self, body: bytes) -> None: + """Handle a completion message using our own Prisma client.""" + try: + data = orjson.loads(body) + message = OperationCompleteMessage(**data) + except Exception as e: + logger.error(f"Failed to parse completion message: {e}") + return + + logger.info( + f"[COMPLETION] Received completion for operation {message.operation_id} " + f"(task_id={message.task_id}, success={message.success})" + ) + + # Find task in registry + task = await stream_registry.find_task_by_operation_id(message.operation_id) + if task is None: + task = await stream_registry.get_task(message.task_id) + + if task is None: + logger.warning( + f"[COMPLETION] Task not found for operation {message.operation_id} " + f"(task_id={message.task_id})" + ) + return + + logger.info( + f"[COMPLETION] Found task: task_id={task.task_id}, " + f"session_id={task.session_id}, tool_call_id={task.tool_call_id}" + ) + + # Guard against empty task fields + if not task.task_id or not task.session_id or not task.tool_call_id: + logger.error( + f"[COMPLETION] Task has empty critical fields! 
" + f"task_id={task.task_id!r}, session_id={task.session_id!r}, " + f"tool_call_id={task.tool_call_id!r}" + ) + return + + if message.success: + await self._handle_success(task, message) + else: + await self._handle_failure(task, message) + + async def _handle_success( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle successful operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_success(task, message.result, prisma) + + async def _handle_failure( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle failed operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_failure(task, message.error, prisma) + + +# Module-level consumer instance +_consumer: ChatCompletionConsumer | None = None + + +async def start_completion_consumer() -> None: + """Start the global completion consumer.""" + global _consumer + if _consumer is None: + _consumer = ChatCompletionConsumer() + await _consumer.start() + + +async def stop_completion_consumer() -> None: + """Stop the global completion consumer.""" + global _consumer + if _consumer: + await _consumer.stop() + _consumer = None + + +async def publish_operation_complete( + operation_id: str, + task_id: str, + success: bool, + result: dict | str | None = None, + error: str | None = None, +) -> None: + """Publish an operation completion message to Redis Streams. + + Args: + operation_id: The operation ID that completed. + task_id: The task ID associated with the operation. + success: Whether the operation succeeded. + result: The result data (for success). + error: The error message (for failure). + """ + message = OperationCompleteMessage( + operation_id=operation_id, + task_id=task_id, + success=success, + result=result, + error=error, + ) + + redis = await get_redis_async() + await redis.xadd( + config.stream_completion_name, + {"data": message.model_dump_json()}, + maxlen=config.stream_max_length, + ) + logger.info(f"Published completion for operation {operation_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_handler.py b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py new file mode 100644 index 0000000000..905fa2ddba --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py @@ -0,0 +1,344 @@ +"""Shared completion handling for operation success and failure. + +This module provides common logic for handling operation completion from both: +- The Redis Streams consumer (completion_consumer.py) +- The HTTP webhook endpoint (routes.py) +""" + +import logging +from typing import Any + +import orjson +from prisma import Prisma + +from . import service as chat_service +from . 
import stream_registry +from .response_model import StreamError, StreamToolOutputAvailable +from .tools.models import ErrorResponse + +logger = logging.getLogger(__name__) + +# Tools that produce agent_json that needs to be saved to library +AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"} + +# Keys that should be stripped from agent_json when returning in error responses +SENSITIVE_KEYS = frozenset( + { + "api_key", + "apikey", + "api_secret", + "password", + "secret", + "credentials", + "credential", + "token", + "access_token", + "refresh_token", + "private_key", + "privatekey", + "auth", + "authorization", + } +) + + +def _sanitize_agent_json(obj: Any) -> Any: + """Recursively sanitize agent_json by removing sensitive keys. + + Args: + obj: The object to sanitize (dict, list, or primitive) + + Returns: + Sanitized copy with sensitive keys removed/redacted + """ + if isinstance(obj, dict): + return { + k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v) + for k, v in obj.items() + } + elif isinstance(obj, list): + return [_sanitize_agent_json(item) for item in obj] + else: + return obj + + +class ToolMessageUpdateError(Exception): + """Raised when updating a tool message in the database fails.""" + + pass + + +async def _update_tool_message( + session_id: str, + tool_call_id: str, + content: str, + prisma_client: Prisma | None, +) -> None: + """Update tool message in database. + + Args: + session_id: The session ID + tool_call_id: The tool call ID to update + content: The new content for the message + prisma_client: Optional Prisma client. If None, uses chat_service. + + Raises: + ToolMessageUpdateError: If the database update fails. The caller should + handle this to avoid marking the task as completed with inconsistent state. + """ + try: + if prisma_client: + # Use provided Prisma client (for consumer with its own connection) + updated_count = await prisma_client.chatmessage.update_many( + where={ + "sessionId": session_id, + "toolCallId": tool_call_id, + }, + data={"content": content}, + ) + # Check if any rows were updated - 0 means message not found + if updated_count == 0: + raise ToolMessageUpdateError( + f"No message found with tool_call_id={tool_call_id} in session {session_id}" + ) + else: + # Use service function (for webhook endpoint) + await chat_service._update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=content, + ) + except ToolMessageUpdateError: + raise + except Exception as e: + logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True) + raise ToolMessageUpdateError( + f"Failed to update tool message for tool_call_id={tool_call_id}: {e}" + ) from e + + +def serialize_result(result: dict | list | str | int | float | bool | None) -> str: + """Serialize result to JSON string with sensible defaults. + + Args: + result: The result to serialize. Can be a dict, list, string, + number, boolean, or None. + + Returns: + JSON string representation of the result. Returns '{"status": "completed"}' + only when result is explicitly None. + """ + if isinstance(result, str): + return result + if result is None: + return '{"status": "completed"}' + return orjson.dumps(result).decode("utf-8") + + +async def _save_agent_from_result( + result: dict[str, Any], + user_id: str | None, + tool_name: str, +) -> dict[str, Any]: + """Save agent to library if result contains agent_json. 
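+
+    For edit_agent the save runs with is_update=True, updating the existing
+    agent; create_agent creates a new library entry.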
+ + Args: + result: The result dict that may contain agent_json + user_id: The user ID to save the agent for + tool_name: The tool name (create_agent or edit_agent) + + Returns: + Updated result dict with saved agent details, or original result if no agent_json + """ + if not user_id: + logger.warning("[COMPLETION] Cannot save agent: no user_id in task") + return result + + agent_json = result.get("agent_json") + if not agent_json: + logger.warning( + f"[COMPLETION] {tool_name} completed but no agent_json in result" + ) + return result + + try: + from .tools.agent_generator import save_agent_to_library + + is_update = tool_name == "edit_agent" + created_graph, library_agent = await save_agent_to_library( + agent_json, user_id, is_update=is_update + ) + + logger.info( + f"[COMPLETION] Saved agent '{created_graph.name}' to library " + f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})" + ) + + # Return a response similar to AgentSavedResponse + return { + "type": "agent_saved", + "message": f"Agent '{created_graph.name}' has been saved to your library!", + "agent_id": created_graph.id, + "agent_name": created_graph.name, + "library_agent_id": library_agent.id, + "library_agent_link": f"/library/agents/{library_agent.id}", + "agent_page_link": f"/build?flowID={created_graph.id}", + } + except Exception as e: + logger.error( + f"[COMPLETION] Failed to save agent to library: {e}", + exc_info=True, + ) + # Return error but don't fail the whole operation + # Sanitize agent_json to remove sensitive keys before returning + return { + "type": "error", + "message": f"Agent was generated but failed to save: {str(e)}", + "error": str(e), + "agent_json": _sanitize_agent_json(agent_json), + } + + +async def process_operation_success( + task: stream_registry.ActiveTask, + result: dict | str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle successful operation completion. + + Publishes the result to the stream registry, updates the database, + generates LLM continuation, and marks the task as completed. + + Args: + task: The active task that completed + result: The result data from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. + + Raises: + ToolMessageUpdateError: If the database update fails. The task will be + marked as failed instead of completed to avoid inconsistent state. 
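+
+    Processing order (mirroring the body below): save the agent to the library
+    for create_agent/edit_agent results -> publish the tool output chunk ->
+    persist the result to the tool message -> stream the LLM continuation ->
+    mark the task completed and release the operation lock.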
+ """ + # For agent generation tools, save the agent to library + if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): + result = await _save_agent_from_result(result, task.user_id, task.tool_name) + + # Serialize result for output (only substitute default when result is exactly None) + result_output = result if result is not None else {"status": "completed"} + output_str = ( + result_output + if isinstance(result_output, str) + else orjson.dumps(result_output).decode("utf-8") + ) + + # Publish result to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamToolOutputAvailable( + toolCallId=task.tool_call_id, + toolName=task.tool_name, + output=output_str, + success=True, + ), + ) + + # Update pending operation in database + # If this fails, we must not continue to mark the task as completed + result_str = serialize_result(result) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=result_str, + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - mark task as failed to avoid inconsistent state + logger.error( + f"[COMPLETION] DB update failed for task {task.task_id}, " + "marking as failed instead of completed" + ) + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText="Failed to save operation result to database"), + ) + await stream_registry.mark_task_completed(task.task_id, status="failed") + raise + + # Generate LLM continuation with streaming + try: + await chat_service._generate_llm_continuation_with_streaming( + session_id=task.session_id, + user_id=task.user_id, + task_id=task.task_id, + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to generate LLM continuation: {e}", + exc_info=True, + ) + + # Mark task as completed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="completed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info( + f"[COMPLETION] Successfully processed completion for task {task.task_id}" + ) + + +async def process_operation_failure( + task: stream_registry.ActiveTask, + error: str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle failed operation completion. + + Publishes the error to the stream registry, updates the database with + the error response, and marks the task as failed. + + Args: + task: The active task that failed + error: The error message from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. 
+ """ + error_msg = error or "Operation failed" + + # Publish error to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText=error_msg), + ) + + # Update pending operation with error + # If this fails, we still continue to mark the task as failed + error_response = ErrorResponse( + message=error_msg, + error=error, + ) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=error_response.model_dump_json(), + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - log but continue with cleanup + logger.error( + f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, " + "continuing with cleanup" + ) + + # Mark task as failed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="failed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}") diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index dba7934877..2e8dbf5413 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -44,6 +44,48 @@ class ChatConfig(BaseSettings): description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)", ) + # Stream registry configuration for SSE reconnection + stream_ttl: int = Field( + default=3600, + description="TTL in seconds for stream data in Redis (1 hour)", + ) + stream_max_length: int = Field( + default=10000, + description="Maximum number of messages to store per stream", + ) + + # Redis Streams configuration for completion consumer + stream_completion_name: str = Field( + default="chat:completions", + description="Redis Stream name for operation completions", + ) + stream_consumer_group: str = Field( + default="chat_consumers", + description="Consumer group name for completion stream", + ) + stream_claim_min_idle_ms: int = Field( + default=60000, + description="Minimum idle time in milliseconds before claiming pending messages from dead consumers", + ) + + # Redis key prefixes for stream registry + task_meta_prefix: str = Field( + default="chat:task:meta:", + description="Prefix for task metadata hash keys", + ) + task_stream_prefix: str = Field( + default="chat:stream:", + description="Prefix for task message stream keys", + ) + task_op_prefix: str = Field( + default="chat:task:op:", + description="Prefix for operation ID to task ID mapping keys", + ) + internal_api_key: str | None = Field( + default=None, + description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)", + ) + # Langfuse Prompt Management Configuration # Note: Langfuse credentials are in Settings().secrets (settings.py) langfuse_prompt_name: str = Field( @@ -82,6 +124,14 @@ class ChatConfig(BaseSettings): v = "https://openrouter.ai/api/v1" return v + @field_validator("internal_api_key", mode="before") + @classmethod + def get_internal_api_key(cls, v): + """Get internal API key from environment if not provided.""" + if v is None: + v = os.getenv("CHAT_INTERNAL_API_KEY") + return v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git 
a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/api/features/chat/response_model.py index 53a8cf3a1f..f627a42fcc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/response_model.py +++ b/autogpt_platform/backend/backend/api/features/chat/response_model.py @@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse): type: ResponseType = ResponseType.START messageId: str = Field(..., description="Unique message ID") + taskId: str | None = Field( + default=None, + description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", + ) class StreamFinish(StreamBaseResponse): diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index cab51543b1..3e731d86ac 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,19 +1,23 @@ """Chat API routes for chat session management and streaming via SSE.""" import logging +import uuid as uuid_module from collections.abc import AsyncGenerator from typing import Annotated from autogpt_libs import auth -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security from fastapi.responses import StreamingResponse from pydantic import BaseModel from backend.util.exceptions import NotFoundError from . import service as chat_service +from . import stream_registry +from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions +from .response_model import StreamFinish, StreamHeartbeat, StreamStart config = ChatConfig() @@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel): user_id: str | None +class ActiveStreamInfo(BaseModel): + """Information about an active stream for reconnection.""" + + task_id: str + last_message_id: str # Redis Stream message ID for resumption + operation_id: str # Operation ID for completion tracking + tool_name: str # Name of the tool being executed + + class SessionDetailResponse(BaseModel): """Response model providing complete details for a chat session, including messages.""" @@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel): updated_at: str user_id: str | None messages: list[dict] + active_stream: ActiveStreamInfo | None = None # Present if stream is still active class SessionSummaryResponse(BaseModel): @@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel): total: int +class OperationCompleteRequest(BaseModel): + """Request model for external completion webhook.""" + + success: bool + result: dict | str | None = None + error: str | None = None + + # ========== Routes ========== @@ -166,13 +188,14 @@ async def get_session( Retrieve the details of a specific chat session. Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages. + If there's an active stream for this session, returns the task_id for reconnection. Args: session_id: The unique identifier for the desired chat session. user_id: The optional authenticated user ID, or None for anonymous access. Returns: - SessionDetailResponse: Details for the requested session, or None if not found. + SessionDetailResponse: Details for the requested session, including active_stream info if applicable. 
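+
+        Illustrative active_stream payload (field values are examples):
+        {"task_id": "1a2b...", "last_message_id": "0-0",
+        "operation_id": "3c4d...", "tool_name": "create_agent"}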
""" session = await get_chat_session(session_id, user_id) @@ -180,11 +203,28 @@ async def get_session( raise NotFoundError(f"Session {session_id} not found.") messages = [message.model_dump() for message in session.messages] - logger.info( - f"Returning session {session_id}: " - f"message_count={len(messages)}, " - f"roles={[m.get('role') for m in messages]}" + + # Check if there's an active stream for this session + active_stream_info = None + active_task, last_message_id = await stream_registry.get_active_task_for_session( + session_id, user_id ) + if active_task: + # Filter out the in-progress assistant message from the session response. + # The client will receive the complete assistant response through the SSE + # stream replay instead, preventing duplicate content. + if messages and messages[-1].get("role") == "assistant": + messages = messages[:-1] + + # Use "0-0" as last_message_id to replay the stream from the beginning. + # Since we filtered out the cached assistant message, the client needs + # the full stream to reconstruct the response. + active_stream_info = ActiveStreamInfo( + task_id=active_task.task_id, + last_message_id="0-0", + operation_id=active_task.operation_id, + tool_name=active_task.tool_name, + ) return SessionDetailResponse( id=session.session_id, @@ -192,6 +232,7 @@ async def get_session( updated_at=session.updated_at.isoformat(), user_id=session.user_id or None, messages=messages, + active_stream=active_stream_info, ) @@ -211,49 +252,112 @@ async def stream_chat_post( - Tool call UI elements (if invoked) - Tool execution results + The AI generation runs in a background task that continues even if the client disconnects. + All chunks are written to Redis for reconnection support. If the client disconnects, + they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off. + Args: session_id: The chat session identifier to associate with the streamed messages. request: Request body containing message, is_user_message, and optional context. user_id: Optional authenticated user ID. Returns: - StreamingResponse: SSE-formatted response chunks. + StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event + containing the task_id for reconnection. 
""" + import asyncio + session = await _validate_and_get_session(session_id, user_id) + # Create a task in the stream registry for reconnection support + task_id = str(uuid_module.uuid4()) + operation_id = str(uuid_module.uuid4()) + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", # Not a tool call, but needed for the model + tool_name="chat", + operation_id=operation_id, + ) + + # Background task that runs the AI generation independently of SSE connection + async def run_ai_generation(): + try: + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + + async for chunk in chat_service.stream_chat_completion( + session_id, + request.message, + is_user_message=request.is_user_message, + user_id=user_id, + session=session, # Pass pre-fetched session to avoid double-fetch + context=request.context, + ): + # Write to Redis (subscribers will receive via XREAD) + await stream_registry.publish_chunk(task_id, chunk) + + # Mark task as completed + await stream_registry.mark_task_completed(task_id, "completed") + except Exception as e: + logger.error( + f"Error in background AI generation for session {session_id}: {e}" + ) + await stream_registry.mark_task_completed(task_id, "failed") + + # Start the AI generation in a background task + bg_task = asyncio.create_task(run_ai_generation()) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" + subscriber_queue = None + try: + # Subscribe to the task stream (this replays existing messages + live updates) + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id="0-0", # Get all messages from the beginning + ) + + if subscriber_queue is None: + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + return + + # Read from the subscriber queue and yield to SSE + while True: + try: + chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + + except GeneratorExit: + pass # Client disconnected - background task continues + except Exception as e: + logger.error(f"Error in SSE stream for task {task_id}: {e}") + finally: + # Unsubscribe when client disconnects or stream ends to prevent resource leak + if subscriber_queue is not None: + try: + await stream_registry.unsubscribe_from_task( + task_id, 
subscriber_queue + ) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" return StreamingResponse( event_generator(), @@ -366,6 +470,251 @@ async def session_assign_user( return {"status": "ok"} +# ========== Task Streaming (SSE Reconnection) ========== + + +@router.get( + "/tasks/{task_id}/stream", +) +async def stream_task( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), + last_message_id: str = Query( + default="0-0", + description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + ), +): + """ + Reconnect to a long-running task's SSE stream. + + When a long-running operation (like agent generation) starts, the client + receives a task_id. If the connection drops, the client can reconnect + using this endpoint to resume receiving updates. + + Args: + task_id: The task ID from the operation_started response. + user_id: Authenticated user ID for ownership validation. + last_message_id: Last Redis Stream message ID received ("0-0" for full replay). + + Returns: + StreamingResponse: SSE-formatted response chunks starting after last_message_id. + + Raises: + HTTPException: 404 if task not found, 410 if task expired, 403 if access denied. + """ + # Check task existence and expiry before subscribing + task, error_code = await stream_registry.get_task_with_expiry_info(task_id) + + if error_code == "TASK_EXPIRED": + raise HTTPException( + status_code=410, + detail={ + "code": "TASK_EXPIRED", + "message": "This operation has expired. Please try again.", + }, + ) + + if error_code == "TASK_NOT_FOUND": + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found.", + }, + ) + + # Validate ownership if task has an owner + if task and task.user_id and user_id != task.user_id: + raise HTTPException( + status_code=403, + detail={ + "code": "ACCESS_DENIED", + "message": "You do not have access to this task.", + }, + ) + + # Get subscriber queue from stream registry + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id=last_message_id, + ) + + if subscriber_queue is None: + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found or access denied.", + }, + ) + + async def event_generator() -> AsyncGenerator[str, None]: + import asyncio + + heartbeat_interval = 15.0 # Send heartbeat every 15 seconds + try: + while True: + try: + # Wait for next chunk with timeout for heartbeats + chunk = await asyncio.wait_for( + subscriber_queue.get(), timeout=heartbeat_interval + ) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + except Exception as e: + logger.error(f"Error in task stream {task_id}: {e}", exc_info=True) + finally: + # Unsubscribe when client disconnects or stream ends + try: + await stream_registry.unsubscribe_from_task(task_id, subscriber_queue) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + 
media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "x-vercel-ai-ui-message-stream": "v1", + }, + ) + + +@router.get( + "/tasks/{task_id}", +) +async def get_task_status( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), +) -> dict: + """ + Get the status of a long-running task. + + Args: + task_id: The task ID to check. + user_id: Authenticated user ID for ownership validation. + + Returns: + dict: Task status including task_id, status, tool_name, and operation_id. + + Raises: + NotFoundError: If task_id is not found or user doesn't have access. + """ + task = await stream_registry.get_task(task_id) + + if task is None: + raise NotFoundError(f"Task {task_id} not found.") + + # Validate ownership - if task has an owner, requester must match + if task.user_id and user_id != task.user_id: + raise NotFoundError(f"Task {task_id} not found.") + + return { + "task_id": task.task_id, + "session_id": task.session_id, + "status": task.status, + "tool_name": task.tool_name, + "operation_id": task.operation_id, + "created_at": task.created_at.isoformat(), + } + + +# ========== External Completion Webhook ========== + + +@router.post( + "/operations/{operation_id}/complete", + status_code=200, +) +async def complete_operation( + operation_id: str, + request: OperationCompleteRequest, + x_api_key: str | None = Header(default=None), +) -> dict: + """ + External completion webhook for long-running operations. + + Called by Agent Generator (or other services) when an operation completes. + This triggers the stream registry to publish completion and continue LLM generation. + + Args: + operation_id: The operation ID to complete. + request: Completion payload with success status and result/error. + x_api_key: Internal API key for authentication. + + Returns: + dict: Status of the completion. + + Raises: + HTTPException: If API key is invalid or operation not found. + """ + # Validate internal API key - reject if not configured or invalid + if not config.internal_api_key: + logger.error( + "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured" + ) + raise HTTPException( + status_code=503, + detail="Webhook not available: internal API key not configured", + ) + if x_api_key != config.internal_api_key: + raise HTTPException(status_code=401, detail="Invalid API key") + + # Find task by operation_id + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None: + raise HTTPException( + status_code=404, + detail=f"Operation {operation_id} not found", + ) + + logger.info( + f"Received completion webhook for operation {operation_id} " + f"(task_id={task.task_id}, success={request.success})" + ) + + if request.success: + await process_operation_success(task, request.result) + else: + await process_operation_failure(task, request.error) + + return {"status": "ok", "task_id": task.task_id} + + +# ========== Configuration ========== + + +@router.get("/config/ttl", status_code=200) +async def get_ttl_config() -> dict: + """ + Get the stream TTL configuration. + + Returns the Time-To-Live settings for chat streams, which determines + how long clients can reconnect to an active stream. + + Returns: + dict: TTL configuration with seconds and milliseconds values. 
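+
+        With the default stream_ttl of 3600, this returns:
+        {"stream_ttl_seconds": 3600, "stream_ttl_ms": 3600000}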
+ """ + return { + "stream_ttl_seconds": config.stream_ttl, + "stream_ttl_ms": config.stream_ttl * 1000, + } + + # ========== Health Check ========== diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 6336d1c5af..218575085b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -36,6 +36,7 @@ from backend.util.exceptions import NotFoundError from backend.util.settings import Settings from . import db as chat_db +from . import stream_registry from .config import ChatConfig from .model import ( ChatMessage, @@ -1184,8 +1185,9 @@ async def _yield_tool_call( ) return - # Generate operation ID + # Generate operation ID and task ID operation_id = str(uuid_module.uuid4()) + task_id = str(uuid_module.uuid4()) # Build a user-friendly message based on tool and arguments if tool_name == "create_agent": @@ -1228,6 +1230,16 @@ async def _yield_tool_call( # Wrap session save and task creation in try-except to release lock on failure try: + # Create task in stream registry for SSE reconnection support + await stream_registry.create_task( + task_id=task_id, + session_id=session.session_id, + user_id=session.user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + # Save assistant message with tool_call FIRST (required by LLM) assistant_message = ChatMessage( role="assistant", @@ -1249,23 +1261,27 @@ async def _yield_tool_call( session.messages.append(pending_message) await upsert_chat_session(session) logger.info( - f"Saved pending operation {operation_id} for tool {tool_name} " - f"in session {session.session_id}" + f"Saved pending operation {operation_id} (task_id={task_id}) " + f"for tool {tool_name} in session {session.session_id}" ) # Store task reference in module-level set to prevent GC before completion - task = asyncio.create_task( - _execute_long_running_tool( + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( tool_name=tool_name, parameters=arguments, tool_call_id=tool_call_id, operation_id=operation_id, + task_id=task_id, session_id=session.session_id, user_id=session.user_id, ) ) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + + # Associate the asyncio task with the stream registry task + await stream_registry.set_task_asyncio_task(task_id, bg_task) except Exception as e: # Roll back appended messages to prevent data corruption on subsequent saves if ( @@ -1283,6 +1299,11 @@ async def _yield_tool_call( # Release the Redis lock since the background task won't be spawned await _mark_operation_completed(tool_call_id) + # Mark stream registry task as failed if it was created + try: + await stream_registry.mark_task_completed(task_id, status="failed") + except Exception: + pass logger.error( f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True ) @@ -1296,6 +1317,7 @@ async def _yield_tool_call( message=started_msg, operation_id=operation_id, tool_name=tool_name, + task_id=task_id, # Include task_id for SSE reconnection ).model_dump_json(), success=True, ) @@ -1365,6 +1387,9 @@ async def _execute_long_running_tool( This function runs independently of the SSE connection, so the operation survives if the user closes their browser tab. + + NOTE: This is the legacy function without stream registry support. 
+ Use _execute_long_running_tool_with_streaming for new implementations. """ try: # Load fresh session (not stale reference) @@ -1417,6 +1442,133 @@ async def _execute_long_running_tool( await _mark_operation_completed(tool_call_id) +async def _execute_long_running_tool_with_streaming( + tool_name: str, + parameters: dict[str, Any], + tool_call_id: str, + operation_id: str, + task_id: str, + session_id: str, + user_id: str | None, +) -> None: + """Execute a long-running tool with stream registry support for SSE reconnection. + + This function runs independently of the SSE connection, publishes progress + to the stream registry, and survives if the user closes their browser tab. + Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming. + + If the external service returns a 202 Accepted (async), this function exits + early and lets the Redis Streams completion consumer handle the rest. + """ + # Track whether we delegated to async processing - if so, the Redis Streams + # completion consumer (stream_registry / completion_consumer) will handle cleanup, not us + delegated_to_async = False + + try: + # Load fresh session (not stale reference) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for background tool") + await stream_registry.mark_task_completed(task_id, status="failed") + return + + # Pass operation_id and task_id to the tool for async processing + enriched_parameters = { + **parameters, + "_operation_id": operation_id, + "_task_id": task_id, + } + + # Execute the actual tool + result = await execute_tool( + tool_name=tool_name, + parameters=enriched_parameters, + tool_call_id=tool_call_id, + user_id=user_id, + session=session, + ) + + # Check if the tool result indicates async processing + # (e.g., Agent Generator returned 202 Accepted) + try: + if isinstance(result.output, dict): + result_data = result.output + elif result.output: + result_data = orjson.loads(result.output) + else: + result_data = {} + if result_data.get("status") == "accepted": + logger.info( + f"Tool {tool_name} delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id}). " + f"Redis Streams completion consumer will handle the rest." 
+ ) + # Don't publish result, don't continue with LLM, and don't cleanup + # The Redis Streams consumer (completion_consumer) will handle + # everything when the external service completes via webhook + delegated_to_async = True + return + except (orjson.JSONDecodeError, TypeError): + pass # Not JSON or not async - continue normally + + # Publish tool result to stream registry + await stream_registry.publish_chunk(task_id, result) + + # Update the pending message with result + result_str = ( + result.output + if isinstance(result.output, str) + else orjson.dumps(result.output).decode("utf-8") + ) + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=result_str, + ) + + logger.info( + f"Background tool {tool_name} completed for session {session_id} " + f"(task_id={task_id})" + ) + + # Generate LLM continuation and stream chunks to registry + await _generate_llm_continuation_with_streaming( + session_id=session_id, + user_id=user_id, + task_id=task_id, + ) + + # Mark task as completed in stream registry + await stream_registry.mark_task_completed(task_id, status="completed") + + except Exception as e: + logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True) + error_response = ErrorResponse( + message=f"Tool {tool_name} failed: {str(e)}", + ) + + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=str(e)), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) + + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=error_response.model_dump_json(), + ) + + # Mark task as failed in stream registry + await stream_registry.mark_task_completed(task_id, status="failed") + finally: + # Only cleanup if we didn't delegate to async processing + # For async path, the Redis Streams completion consumer handles cleanup + if not delegated_to_async: + await _mark_operation_completed(tool_call_id) + + async def _update_pending_operation( session_id: str, tool_call_id: str, @@ -1597,3 +1749,128 @@ async def _generate_llm_continuation( except Exception as e: logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) + + +async def _generate_llm_continuation_with_streaming( + session_id: str, + user_id: str | None, + task_id: str, +) -> None: + """Generate an LLM response with streaming to the stream registry. + + This is called by background tasks to continue the conversation + after a tool result is saved. Chunks are published to the stream registry + so reconnecting clients can receive them. 
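+
+    Published event sequence (matching the calls below): StreamStart ->
+    StreamTextStart -> StreamTextDelta for each chunk -> StreamTextEnd, with
+    StreamError followed by StreamFinish on failure.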
+ """ + import uuid as uuid_module + + try: + # Load fresh session from DB (bypass cache to get the updated tool result) + await invalidate_session_cache(session_id) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for LLM continuation") + return + + # Build system prompt + system_prompt, _ = await _build_system_prompt(user_id) + + # Build messages in OpenAI format + messages = session.to_openai_messages() + if system_prompt: + from openai.types.chat import ChatCompletionSystemMessageParam + + system_message = ChatCompletionSystemMessageParam( + role="system", + content=system_prompt, + ) + messages = [system_message] + messages + + # Build extra_body for tracing + extra_body: dict[str, Any] = { + "posthogProperties": { + "environment": settings.config.app_env.value, + }, + } + if user_id: + extra_body["user"] = user_id[:128] + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] + + # Make streaming LLM call (no tools - just text response) + from typing import cast + + from openai.types.chat import ChatCompletionMessageParam + + # Generate unique IDs for AI SDK protocol + message_id = str(uuid_module.uuid4()) + text_block_id = str(uuid_module.uuid4()) + + # Publish start event + await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) + await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) + + # Stream the response + stream = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + stream=True, + ) + + assistant_content = "" + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + delta = chunk.choices[0].delta.content + assistant_content += delta + # Publish delta to stream registry + await stream_registry.publish_chunk( + task_id, + StreamTextDelta(id=text_block_id, delta=delta), + ) + + # Publish end events + await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) + + if assistant_content: + # Reload session from DB to avoid race condition with user messages + fresh_session = await get_chat_session(session_id, user_id) + if not fresh_session: + logger.error( + f"Session {session_id} disappeared during LLM continuation" + ) + return + + # Save assistant message to database + assistant_message = ChatMessage( + role="assistant", + content=assistant_content, + ) + fresh_session.messages.append(assistant_message) + + # Save to database (not cache) to persist the response + await upsert_chat_session(fresh_session) + + # Invalidate cache so next poll/refresh gets fresh data + await invalidate_session_cache(session_id) + + logger.info( + f"Generated streaming LLM continuation for session {session_id} " + f"(task_id={task_id}), response length: {len(assistant_content)}" + ) + else: + logger.warning( + f"Streaming LLM continuation returned empty response for {session_id}" + ) + + except Exception as e: + logger.error( + f"Failed to generate streaming LLM continuation: {e}", exc_info=True + ) + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=f"Failed to generate response: {e}"), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py new 
file mode 100644 index 0000000000..88a5023e2b --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -0,0 +1,704 @@ +"""Stream registry for managing reconnectable SSE streams. + +This module provides a registry for tracking active streaming tasks and their +messages. It uses Redis for all state management (no in-memory state), making +pods stateless and horizontally scalable. + +Architecture: +- Redis Stream: Persists all messages for replay and real-time delivery +- Redis Hash: Task metadata (status, session_id, etc.) + +Subscribers: +1. Replay missed messages from Redis Stream (XREAD) +2. Listen for live updates via blocking XREAD +3. No in-memory state required on the subscribing pod +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +import orjson + +from backend.data.redis_client import get_redis_async + +from .config import ChatConfig +from .response_model import StreamBaseResponse, StreamError, StreamFinish + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Track background tasks for this pod (just the asyncio.Task reference, not subscribers) +_local_tasks: dict[str, asyncio.Task] = {} + +# Track listener tasks per subscriber queue for cleanup +# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe +_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {} + +# Timeout for putting chunks into subscriber queues (seconds) +# If the queue is full and doesn't drain within this time, send an overflow error +QUEUE_PUT_TIMEOUT = 5.0 + +# Lua script for atomic compare-and-swap status update (idempotent completion) +# Returns 1 if status was updated, 0 if already completed/failed +COMPLETE_TASK_SCRIPT = """ +local current = redis.call("HGET", KEYS[1], "status") +if current == "running" then + redis.call("HSET", KEYS[1], "status", ARGV[1]) + return 1 +end +return 0 +""" + + +@dataclass +class ActiveTask: + """Represents an active streaming task (metadata only, no in-memory queues).""" + + task_id: str + session_id: str + user_id: str | None + tool_call_id: str + tool_name: str + operation_id: str + status: Literal["running", "completed", "failed"] = "running" + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + asyncio_task: asyncio.Task | None = None + + +def _get_task_meta_key(task_id: str) -> str: + """Get Redis key for task metadata.""" + return f"{config.task_meta_prefix}{task_id}" + + +def _get_task_stream_key(task_id: str) -> str: + """Get Redis key for task message stream.""" + return f"{config.task_stream_prefix}{task_id}" + + +def _get_operation_mapping_key(operation_id: str) -> str: + """Get Redis key for operation_id to task_id mapping.""" + return f"{config.task_op_prefix}{operation_id}" + + +async def create_task( + task_id: str, + session_id: str, + user_id: str | None, + tool_call_id: str, + tool_name: str, + operation_id: str, +) -> ActiveTask: + """Create a new streaming task in Redis. 
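+
+    Illustrative key layout (actual prefixes come from ChatConfig):
+
+        {task_meta_prefix}{task_id}     -> hash: status, session_id, tool info
+        {task_stream_prefix}{task_id}   -> Redis Stream of serialized chunks
+        {task_op_prefix}{operation_id}  -> task_id, for webhook lookups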
+
+    Args:
+        task_id: Unique identifier for the task
+        session_id: Chat session ID
+        user_id: User ID (may be None for anonymous)
+        tool_call_id: Tool call ID from the LLM
+        tool_name: Name of the tool being executed
+        operation_id: Operation ID for webhook callbacks
+
+    Returns:
+        The created ActiveTask instance (metadata only)
+    """
+    task = ActiveTask(
+        task_id=task_id,
+        session_id=session_id,
+        user_id=user_id,
+        tool_call_id=tool_call_id,
+        tool_name=tool_name,
+        operation_id=operation_id,
+    )
+
+    # Store metadata in Redis
+    redis = await get_redis_async()
+    meta_key = _get_task_meta_key(task_id)
+    op_key = _get_operation_mapping_key(operation_id)
+
+    await redis.hset(  # type: ignore[misc]
+        meta_key,
+        mapping={
+            "task_id": task_id,
+            "session_id": session_id,
+            "user_id": user_id or "",
+            "tool_call_id": tool_call_id,
+            "tool_name": tool_name,
+            "operation_id": operation_id,
+            "status": task.status,
+            "created_at": task.created_at.isoformat(),
+        },
+    )
+    await redis.expire(meta_key, config.stream_ttl)
+
+    # Create operation_id -> task_id mapping for webhook lookups
+    await redis.set(op_key, task_id, ex=config.stream_ttl)
+
+    logger.debug(f"Created task {task_id} for session {session_id}")
+
+    return task
+
+
+async def publish_chunk(
+    task_id: str,
+    chunk: StreamBaseResponse,
+) -> str:
+    """Publish a chunk to Redis Stream.
+
+    All delivery is via Redis Streams - no in-memory state.
+
+    Args:
+        task_id: Task ID to publish to
+        chunk: The stream response chunk to publish
+
+    Returns:
+        The Redis Stream message ID
+    """
+    chunk_json = chunk.model_dump_json()
+    message_id = "0-0"
+
+    try:
+        redis = await get_redis_async()
+        stream_key = _get_task_stream_key(task_id)
+
+        # Write to Redis Stream for persistence and real-time delivery
+        raw_id = await redis.xadd(
+            stream_key,
+            {"data": chunk_json},
+            maxlen=config.stream_max_length,
+        )
+        message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
+
+        # Set TTL on stream to match task metadata TTL
+        await redis.expire(stream_key, config.stream_ttl)
+    except Exception as e:
+        logger.error(
+            f"Failed to publish chunk for task {task_id}: {e}",
+            exc_info=True,
+        )
+
+    return message_id
+
+
+async def subscribe_to_task(
+    task_id: str,
+    user_id: str | None,
+    last_message_id: str = "0-0",
+) -> asyncio.Queue[StreamBaseResponse] | None:
+    """Subscribe to a task's stream with replay of missed messages.
+
+    This is fully stateless - uses Redis Stream for replay and blocking XREAD for live updates.
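+
+    Typical consumer sketch (the SSE handler shape is illustrative, not the
+    actual route implementation):
+
+        queue = await subscribe_to_task(task_id, user_id, last_message_id)
+        if queue is not None:
+            while True:
+                chunk = await queue.get()
+                ...  # forward chunk to the client as an SSE event
+                if isinstance(chunk, StreamFinish):
+                    break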
+ + Args: + task_id: Task ID to subscribe to + user_id: User ID for ownership validation + last_message_id: Last Redis Stream message ID received ("0-0" for full replay) + + Returns: + An asyncio Queue that will receive stream chunks, or None if task not found + or user doesn't have access + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + logger.debug(f"Task {task_id} not found in Redis") + return None + + # Note: Redis client uses decode_responses=True, so keys are strings + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + + # Validate ownership - if task has an owner, requester must match + if task_user_id: + if user_id != task_user_id: + logger.warning( + f"User {user_id} denied access to task {task_id} " + f"owned by {task_user_id}" + ) + return None + + subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() + stream_key = _get_task_stream_key(task_id) + + # Step 1: Replay messages from Redis Stream + messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) + + replayed_count = 0 + replay_last_id = last_message_id + if messages: + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + # Note: Redis client uses decode_responses=True, so keys are strings + if "data" in msg_data: + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + await subscriber_queue.put(chunk) + replayed_count += 1 + except Exception as e: + logger.warning(f"Failed to replay message: {e}") + + logger.debug(f"Task {task_id}: replayed {replayed_count} messages") + + # Step 2: If task is still running, start stream listener for live updates + if task_status == "running": + listener_task = asyncio.create_task( + _stream_listener(task_id, subscriber_queue, replay_last_id) + ) + # Track listener task for cleanup on unsubscribe + _listener_tasks[id(subscriber_queue)] = (task_id, listener_task) + else: + # Task is completed/failed - add finish marker + await subscriber_queue.put(StreamFinish()) + + return subscriber_queue + + +async def _stream_listener( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], + last_replayed_id: str, +) -> None: + """Listen to Redis Stream for new messages using blocking XREAD. + + This approach avoids the duplicate message issue that can occur with pub/sub + when messages are published during the gap between replay and subscription. 
+ + Args: + task_id: Task ID to listen for + subscriber_queue: Queue to deliver messages to + last_replayed_id: Last message ID from replay (continue from here) + """ + queue_id = id(subscriber_queue) + # Track the last successfully delivered message ID for recovery hints + last_delivered_id = last_replayed_id + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + current_id = last_replayed_id + + while True: + # Block for up to 30 seconds waiting for new messages + # This allows periodic checking if task is still running + messages = await redis.xread( + {stream_key: current_id}, block=30000, count=100 + ) + + if not messages: + # Timeout - check if task is still running + meta_key = _get_task_meta_key(task_id) + status = await redis.hget(meta_key, "status") # type: ignore[misc] + if status and status != "running": + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except asyncio.TimeoutError: + logger.warning( + f"Timeout delivering finish event for task {task_id}" + ) + break + continue + + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + current_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + + if "data" not in msg_data: + continue + + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + try: + await asyncio.wait_for( + subscriber_queue.put(chunk), + timeout=QUEUE_PUT_TIMEOUT, + ) + # Update last delivered ID on successful delivery + last_delivered_id = current_id + except asyncio.TimeoutError: + logger.warning( + f"Subscriber queue full for task {task_id}, " + f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s" + ) + # Send overflow error with recovery info + try: + overflow_error = StreamError( + errorText="Message delivery timeout - some messages may have been missed", + code="QUEUE_OVERFLOW", + details={ + "last_delivered_id": last_delivered_id, + "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}", + }, + ) + subscriber_queue.put_nowait(overflow_error) + except asyncio.QueueFull: + # Queue is completely stuck, nothing more we can do + logger.error( + f"Cannot deliver overflow error for task {task_id}, " + "queue completely blocked" + ) + + # Stop listening on finish + if isinstance(chunk, StreamFinish): + return + except Exception as e: + logger.warning(f"Error processing stream message: {e}") + + except asyncio.CancelledError: + logger.debug(f"Stream listener cancelled for task {task_id}") + raise # Re-raise to propagate cancellation + except Exception as e: + logger.error(f"Stream listener error for task {task_id}: {e}") + # On error, send finish to unblock subscriber + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except (asyncio.TimeoutError, asyncio.QueueFull): + logger.warning( + f"Could not deliver finish event for task {task_id} after error" + ) + finally: + # Clean up listener task mapping on exit + _listener_tasks.pop(queue_id, None) + + +async def mark_task_completed( + task_id: str, + status: Literal["completed", "failed"] = "completed", +) -> bool: + """Mark a task as completed and publish finish event. + + This is idempotent - calling multiple times with the same task_id is safe. + Uses atomic compare-and-swap via Lua script to prevent race conditions. + Status is updated first (source of truth), then finish event is published (best-effort). 
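+
+    Illustrative idempotency behaviour:
+
+        await mark_task_completed(task_id)  # -> True, status flipped via CAS
+        await mark_task_completed(task_id)  # -> False, already terminal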
+ + Args: + task_id: Task ID to mark as completed + status: Final status ("completed" or "failed") + + Returns: + True if task was newly marked completed, False if already completed/failed + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + + # Atomic compare-and-swap: only update if status is "running" + # This prevents race conditions when multiple callers try to complete simultaneously + result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc] + + if result == 0: + logger.debug(f"Task {task_id} already completed/failed, skipping") + return False + + # THEN publish finish event (best-effort - listeners can detect via status polling) + try: + await publish_chunk(task_id, StreamFinish()) + except Exception as e: + logger.error( + f"Failed to publish finish event for task {task_id}: {e}. " + "Listeners will detect completion via status polling." + ) + + # Clean up local task reference if exists + _local_tasks.pop(task_id, None) + return True + + +async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: + """Find a task by its operation ID. + + Used by webhook callbacks to locate the task to update. + + Args: + operation_id: Operation ID to search for + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + op_key = _get_operation_mapping_key(operation_id) + task_id = await redis.get(op_key) + + if not task_id: + return None + + task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id + return await get_task(task_id_str) + + +async def get_task(task_id: str) -> ActiveTask | None: + """Get a task by its ID from Redis. + + Args: + task_id: Task ID to look up + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + return None + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ) + + +async def get_task_with_expiry_info( + task_id: str, +) -> tuple[ActiveTask | None, str | None]: + """Get a task by its ID with expiration detection. 
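+
+    Caller sketch (the HTTP status mapping is illustrative; see the
+    stream endpoint for the actual responses):
+
+        task, err = await get_task_with_expiry_info(task_id)
+        if err == "TASK_EXPIRED":
+            ...  # e.g. respond 410 Gone
+        elif err == "TASK_NOT_FOUND":
+            ...  # e.g. respond 404 Not Found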
+ + Returns (task, error_code) where error_code is: + - None if task found + - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired) + - "TASK_NOT_FOUND" if neither exists + + Args: + task_id: Task ID to look up + + Returns: + Tuple of (ActiveTask or None, error_code or None) + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + stream_key = _get_task_stream_key(task_id) + + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + # Check if stream still has data (metadata expired but stream hasn't) + stream_len = await redis.xlen(stream_key) + if stream_len > 0: + return None, "TASK_EXPIRED" + return None, "TASK_NOT_FOUND" + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ( + ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ), + None, + ) + + +async def get_active_task_for_session( + session_id: str, + user_id: str | None = None, +) -> tuple[ActiveTask | None, str]: + """Get the active (running) task for a session, if any. + + Scans Redis for tasks matching the session_id with status="running". + + Args: + session_id: Session ID to look up + user_id: User ID for ownership validation (optional) + + Returns: + Tuple of (ActiveTask if found and running, last_message_id from Redis Stream) + """ + + redis = await get_redis_async() + + # Scan Redis for task metadata keys + cursor = 0 + tasks_checked = 0 + + while True: + cursor, keys = await redis.scan( + cursor, match=f"{config.task_meta_prefix}*", count=100 + ) + + for key in keys: + tasks_checked += 1 + meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc] + if not meta: + continue + + # Note: Redis client uses decode_responses=True, so keys/values are strings + task_session_id = meta.get("session_id", "") + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + task_id = meta.get("task_id", "") + + if task_session_id == session_id and task_status == "running": + # Validate ownership - if task has an owner, requester must match + if task_user_id and user_id != task_user_id: + continue + + # Get the last message ID from Redis Stream + stream_key = _get_task_stream_key(task_id) + last_id = "0-0" + try: + messages = await redis.xrevrange(stream_key, count=1) + if messages: + msg_id = messages[0][0] + last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + except Exception as e: + logger.warning(f"Failed to get last message ID: {e}") + + return ( + ActiveTask( + task_id=task_id, + session_id=task_session_id, + user_id=task_user_id, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status="running", + ), + last_id, + ) + + if cursor == 0: + break + + return None, "0-0" + + +def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None: + """Reconstruct a StreamBaseResponse from JSON data. 
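+
+    Example round-trip (the "text-delta" type value is illustrative):
+
+        _reconstruct_chunk({"type": "text-delta", "id": "b1", "delta": "Hi"})
+        # -> StreamTextDelta(id="b1", delta="Hi"); unknown types return None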
+ + Args: + chunk_data: Parsed JSON data from Redis + + Returns: + Reconstructed response object, or None if unknown type + """ + from .response_model import ( + ResponseType, + StreamError, + StreamFinish, + StreamHeartbeat, + StreamStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, + StreamUsage, + ) + + # Map response types to their corresponding classes + type_to_class: dict[str, type[StreamBaseResponse]] = { + ResponseType.START.value: StreamStart, + ResponseType.FINISH.value: StreamFinish, + ResponseType.TEXT_START.value: StreamTextStart, + ResponseType.TEXT_DELTA.value: StreamTextDelta, + ResponseType.TEXT_END.value: StreamTextEnd, + ResponseType.TOOL_INPUT_START.value: StreamToolInputStart, + ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable, + ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable, + ResponseType.ERROR.value: StreamError, + ResponseType.USAGE.value: StreamUsage, + ResponseType.HEARTBEAT.value: StreamHeartbeat, + } + + chunk_type = chunk_data.get("type") + chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type] + + if chunk_class is None: + logger.warning(f"Unknown chunk type: {chunk_type}") + return None + + try: + return chunk_class(**chunk_data) + except Exception as e: + logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}") + return None + + +async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None: + """Track the asyncio.Task for a task (local reference only). + + This is just for cleanup purposes - the task state is in Redis. + + Args: + task_id: Task ID + asyncio_task: The asyncio Task to track + """ + _local_tasks[task_id] = asyncio_task + + +async def unsubscribe_from_task( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], +) -> None: + """Clean up when a subscriber disconnects. + + Cancels the XREAD-based listener task associated with this subscriber queue + to prevent resource leaks. 
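+
+    Pairing sketch (cleanup should run even if the client drops mid-stream):
+
+        queue = await subscribe_to_task(task_id, user_id)
+        try:
+            ...  # drain queue and forward chunks
+        finally:
+            await unsubscribe_from_task(task_id, queue)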
+ + Args: + task_id: Task ID + subscriber_queue: The subscriber's queue used to look up the listener task + """ + queue_id = id(subscriber_queue) + listener_entry = _listener_tasks.pop(queue_id, None) + + if listener_entry is None: + logger.debug( + f"No listener task found for task {task_id} queue {queue_id} " + "(may have already completed)" + ) + return + + stored_task_id, listener_task = listener_entry + + if stored_task_id != task_id: + logger.warning( + f"Task ID mismatch in unsubscribe: expected {task_id}, " + f"found {stored_task_id}" + ) + + if listener_task.done(): + logger.debug(f"Listener task for task {task_id} already completed") + return + + # Cancel the listener task + listener_task.cancel() + + try: + # Wait for the task to be cancelled with a timeout + await asyncio.wait_for(listener_task, timeout=5.0) + except asyncio.CancelledError: + # Expected - the task was successfully cancelled + pass + except asyncio.TimeoutError: + logger.warning( + f"Timeout waiting for listener task cancellation for task {task_id}" + ) + except Exception as e: + logger.error(f"Error during listener task cancellation for task {task_id}: {e}") + + logger.debug(f"Successfully unsubscribed from task {task_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index 5b40091bbb..b88b9b2924 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -550,15 +550,21 @@ async def decompose_goal( async def generate_agent( instructions: DecompositionResult | dict[str, Any], library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Generate agent JSON from instructions. Args: instructions: Structured instructions from decompose_goal library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams + completion notification) + task_id: Task ID for async processing (enables Redis Streams persistence + and SSE delivery) Returns: - Agent JSON dict, error dict {"type": "error", ...}, or None on error + Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. @@ -566,8 +572,13 @@ async def generate_agent( _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent") result = await generate_agent_external( - dict(instructions), _to_dict_list(library_agents) + dict(instructions), _to_dict_list(library_agents), operation_id, task_id ) + + # Don't modify async response + if result and result.get("status") == "accepted": + return result + if result: if isinstance(result, dict) and result.get("type") == "error": return result @@ -819,6 +830,8 @@ async def generate_agent_patch( update_request: str, current_agent: dict[str, Any], library_agents: list[AgentSummary] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Update an existing agent using natural language. 
@@ -831,10 +844,12 @@ async def generate_agent_patch( update_request: Natural language description of changes current_agent: Current agent JSON library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, - error dict {"type": "error", ...}, or None on unexpected error + {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. @@ -842,7 +857,11 @@ async def generate_agent_patch( _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent_patch") return await generate_agent_patch_external( - update_request, current_agent, _to_dict_list(library_agents) + update_request, + current_agent, + _to_dict_list(library_agents), + operation_id, + task_id, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py index 780247a776..62411b4e1b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py @@ -212,24 +212,45 @@ async def decompose_goal_external( async def generate_agent_external( instructions: dict[str, Any], library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate an agent from instructions. Args: instructions: Structured instructions from decompose_goal library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Agent JSON dict on success, or error dict {"type": "error", ...} on error + Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error """ client = _get_client() + # Build request payload payload: dict[str, Any] = {"instructions": instructions} if library_agents: payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id try: response = await client.post("/api/generate-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() @@ -261,6 +282,8 @@ async def generate_agent_patch_external( update_request: str, current_agent: dict[str, Any], library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate a patch for an existing agent. 
@@ -268,21 +291,40 @@ async def generate_agent_patch_external( update_request: Natural language description of changes current_agent: Current agent JSON library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Updated agent JSON, clarifying questions dict, or error dict on error + Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error """ client = _get_client() + # Build request payload payload: dict[str, Any] = { "update_request": update_request, "current_agent_json": current_agent, } if library_agents: payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id try: response = await client.post("/api/update-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async update request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py index adb2c78fce..7333851a5b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py @@ -18,6 +18,7 @@ from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not description: return ErrorResponse( message="Please provide a description of what the agent should do.", @@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool): logger.warning(f"Failed to enrich library agents from steps: {e}") try: - agent_json = await generate_agent(decomposition_result, library_agents) + agent_json = await generate_agent( + decomposition_result, + library_agents, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if agent_json.get("status") == "accepted": + logger.info( + f"Agent generation delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent generation started. 
You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
         agent_name = agent_json.get("name", "Generated Agent")
         agent_description = agent_json.get("description", "")
         node_count = len(agent_json.get("nodes", []))
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py
index 2c2c48226b..3ae56407a7 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py
@@ -17,6 +17,7 @@ from .base import BaseTool
 from .models import (
     AgentPreviewResponse,
     AgentSavedResponse,
+    AsyncProcessingResponse,
     ClarificationNeededResponse,
     ClarifyingQuestion,
     ErrorResponse,
@@ -104,6 +105,10 @@ class EditAgentTool(BaseTool):
         save = kwargs.get("save", True)
         session_id = session.session_id if session else None
 
+        # Extract async processing params (passed by long-running tool handler)
+        operation_id = kwargs.get("_operation_id")
+        task_id = kwargs.get("_task_id")
+
         if not agent_id:
             return ErrorResponse(
                 message="Please provide the agent ID to edit.",
@@ -149,7 +154,11 @@ class EditAgentTool(BaseTool):
 
         try:
             result = await generate_agent_patch(
-                update_request, current_agent, library_agents
+                update_request,
+                current_agent,
+                library_agents,
+                operation_id=operation_id,
+                task_id=task_id,
             )
         except AgentGeneratorNotConfiguredError:
             return ErrorResponse(
@@ -169,6 +178,20 @@ class EditAgentTool(BaseTool):
                 session_id=session_id,
             )
 
+        # Check if Agent Generator accepted for async processing
+        if result.get("status") == "accepted":
+            logger.info(
+                f"Agent edit delegated to async processing "
+                f"(operation_id={operation_id}, task_id={task_id})"
+            )
+            return AsyncProcessingResponse(
+                message="Agent edit started. You'll be notified when it's complete.",
+                operation_id=operation_id,
+                task_id=task_id,
+                session_id=session_id,
+            )
+
+        # Check if the result is an error from the external service
         if isinstance(result, dict) and result.get("type") == "error":
             error_msg = result.get("error", "Unknown error")
             error_type = result.get("error_type", "unknown")
diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py
index 5ff8190c31..69c8c6c684 100644
--- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py
+++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py
@@ -372,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase):
 
     This is returned immediately to the client while the operation
     continues to execute. The user can close the tab and check back later.
+
+    The task_id can be used to reconnect to the SSE stream via
+    GET /chat/tasks/{task_id}/stream?last_message_id=0-0
     """
 
     type: ResponseType = ResponseType.OPERATION_STARTED
     operation_id: str
     tool_name: str
+    task_id: str | None = None  # For SSE reconnection
 
 
 class OperationPendingResponse(ToolResponseBase):
@@ -400,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase):
 
     type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
     tool_call_id: str
+
+
+class AsyncProcessingResponse(ToolResponseBase):
+    """Response when an operation has been delegated to async processing.
+
+    This is returned by tools when the external service accepts the request
+    for async processing (HTTP 202 Accepted). The Redis Streams completion
+    consumer will handle the result when the external service completes.
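+
+    Serialized shape as stored in the tool result (field values and the
+    exact type string are illustrative):
+
+        {"type": "operation_started", "status": "accepted",
+         "operation_id": "op-...", "task_id": "task-..."}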
+ + The status field is specifically "accepted" to allow the long-running tool + handler to detect this response and skip LLM continuation. + """ + + type: ResponseType = ResponseType.OPERATION_STARTED + status: str = "accepted" # Must be "accepted" for detection + operation_id: str | None = None + task_id: str | None = None diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py index bae5b97cd6..86af457f50 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py @@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination( cleanup_embeddings: list, ): """Test unified search pagination works correctly.""" + # Use a unique search term to avoid matching other test data + unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}" + # Create multiple items content_ids = [] for i in range(5): @@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination( content_type=ContentType.BLOCK, content_id=content_id, embedding=mock_embedding, - searchable_text=f"pagination test item number {i}", + searchable_text=f"{unique_term} item number {i}", metadata={"index": i}, user_id=None, ) # Get first page page1_results, total1 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=1, page_size=2, @@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination( # Get second page page2_results, total2 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=2, page_size=2, diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index b936312ce1..0eef76193e 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -40,6 +40,10 @@ import backend.data.user import backend.integrations.webhooks.utils import backend.util.service import backend.util.settings +from backend.api.features.chat.completion_consumer import ( + start_completion_consumer, + stop_completion_consumer, +) from backend.blocks.llm import DEFAULT_LLM_MODEL from backend.data.model import Credentials from backend.integrations.providers import ProviderName @@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI): await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL) await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs() + # Start chat completion consumer for Redis Streams notifications + try: + await start_completion_consumer() + except Exception as e: + logger.warning(f"Could not start chat completion consumer: {e}") + with launch_darkly_context(): yield + # Stop chat completion consumer + try: + await stop_completion_consumer() + except Exception as e: + logger.warning(f"Error stopping chat completion consumer: {e}") + try: await shutdown_cloud_storage_handler() except Exception as e: diff --git a/autogpt_platform/backend/test/agent_generator/test_core_integration.py b/autogpt_platform/backend/test/agent_generator/test_core_integration.py index 05ce4a3aff..528763e751 100644 --- a/autogpt_platform/backend/test/agent_generator/test_core_integration.py +++ b/autogpt_platform/backend/test/agent_generator/test_core_integration.py @@ -111,9 +111,7 @@ class TestGenerateAgent: instructions = {"type": "instructions", "steps": ["Step 1"]} result = 
await core.generate_agent(instructions) - # library_agents defaults to None - mock_external.assert_called_once_with(instructions, None) - # Result should have id, version, is_active added if not present + mock_external.assert_called_once_with(instructions, None, None, None) assert result is not None assert result["name"] == "Test Agent" assert "id" in result @@ -177,8 +175,9 @@ class TestGenerateAgentPatch: current_agent = {"nodes": [], "links": []} result = await core.generate_agent_patch("Add a node", current_agent) - # library_agents defaults to None - mock_external.assert_called_once_with("Add a node", current_agent, None) + mock_external.assert_called_once_with( + "Add a node", current_agent, None, None, None + ) assert result == expected_result @pytest.mark.asyncio diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts index 74fd663ab2..913c4d7ded 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts @@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useQueryClient } from "@tanstack/react-query"; import { usePathname, useSearchParams } from "next/navigation"; -import { useRef } from "react"; import { useCopilotStore } from "../../copilot-page-store"; import { useCopilotSessionId } from "../../useCopilotSessionId"; import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer"; @@ -70,41 +69,16 @@ export function useCopilotShell() { }); const stopStream = useChatStore((s) => s.stopStream); - const onStreamComplete = useChatStore((s) => s.onStreamComplete); - const isStreaming = useCopilotStore((s) => s.isStreaming); const isCreatingSession = useCopilotStore((s) => s.isCreatingSession); - const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession); - const openInterruptModal = useCopilotStore((s) => s.openInterruptModal); - const pendingActionRef = useRef<(() => void) | null>(null); - - async function stopCurrentStream() { - if (!currentSessionId) return; - - setIsSwitchingSession(true); - await new Promise((resolve) => { - const unsubscribe = onStreamComplete((completedId) => { - if (completedId === currentSessionId) { - clearTimeout(timeout); - unsubscribe(); - resolve(); - } - }); - const timeout = setTimeout(() => { - unsubscribe(); - resolve(); - }, 3000); - stopStream(currentSessionId); - }); - - queryClient.invalidateQueries({ - queryKey: getGetV2GetSessionQueryKey(currentSessionId), - }); - setIsSwitchingSession(false); - } - - function selectSession(sessionId: string) { + function handleSessionClick(sessionId: string) { if (sessionId === currentSessionId) return; + + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + stopStream(currentSessionId); + } + if (recentlyCreatedSessionsRef.current.has(sessionId)) { queryClient.invalidateQueries({ queryKey: getGetV2GetSessionQueryKey(sessionId), @@ -114,7 +88,12 @@ export function useCopilotShell() { if (isMobile) handleCloseDrawer(); } - function startNewChat() { + function handleNewChatClick() { + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + stopStream(currentSessionId); + } + resetPagination(); 
queryClient.invalidateQueries({
       queryKey: getGetV2ListSessionsQueryKey(),
     });
@@ -123,32 +102,6 @@
   }
 
-  function handleSessionClick(sessionId: string) {
-    if (sessionId === currentSessionId) return;
-
-    if (isStreaming) {
-      pendingActionRef.current = async () => {
-        await stopCurrentStream();
-        selectSession(sessionId);
-      };
-      openInterruptModal(pendingActionRef.current);
-    } else {
-      selectSession(sessionId);
-    }
-  }
-
-  function handleNewChatClick() {
-    if (isStreaming) {
-      pendingActionRef.current = async () => {
-        await stopCurrentStream();
-        startNewChat();
-      };
-      openInterruptModal(pendingActionRef.current);
-    } else {
-      startNewChat();
-    }
-  }
-
   return {
     isMobile,
     isDrawerOpen,
diff --git a/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts
new file mode 100644
index 0000000000..336786bfdb
--- /dev/null
+++ b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts
@@ -0,0 +1,81 @@
+import { environment } from "@/services/environment";
+import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
+import { NextRequest } from "next/server";
+
+/**
+ * SSE Proxy for task stream reconnection.
+ *
+ * This endpoint allows clients to reconnect to an ongoing or recently completed
+ * background task's stream. It replays missed messages from Redis Streams and
+ * subscribes to live updates if the task is still running.
+ *
+ * Client contract:
+ * 1. When receiving an operation_started event, store the task_id
+ * 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
+ * 3. Messages are replayed from the last_message_id position
+ * 4. Stream ends when "finish" event is received
+ */
+export async function GET(
+  request: NextRequest,
+  { params }: { params: Promise<{ taskId: string }> },
+) {
+  const { taskId } = await params;
+  const searchParams = request.nextUrl.searchParams;
+  const lastMessageId = searchParams.get("last_message_id") || "0-0";
+
+  try {
+    // Get auth token from server-side session
+    const token = await getServerAuthToken();
+
+    // Build backend URL
+    const backendUrl = environment.getAGPTServerBaseUrl();
+    const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
+    streamUrl.searchParams.set("last_message_id", lastMessageId);
+
+    // Forward request to backend with auth header
+    const headers: Record<string, string> = {
+      Accept: "text/event-stream",
+      "Cache-Control": "no-cache",
+      Connection: "keep-alive",
+    };
+
+    if (token) {
+      headers["Authorization"] = `Bearer ${token}`;
+    }
+
+    const response = await fetch(streamUrl.toString(), {
+      method: "GET",
+      headers,
+    });
+
+    if (!response.ok) {
+      const error = await response.text();
+      return new Response(error, {
+        status: response.status,
+        headers: { "Content-Type": "application/json" },
+      });
+    }
+
+    // Return the SSE stream directly
+    return new Response(response.body, {
+      headers: {
+        "Content-Type": "text/event-stream",
+        "Cache-Control": "no-cache, no-transform",
+        Connection: "keep-alive",
+        "X-Accel-Buffering": "no",
+      },
+    });
+  } catch (error) {
+    console.error("Task stream proxy error:", error);
+    return new Response(
+      JSON.stringify({
+        error: "Failed to connect to task stream",
+        detail: error instanceof Error ?
error.message : String(error), + }), + { + status: 500, + headers: { "Content-Type": "application/json" }, + }, + ); + } +} diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index aa4c49b1a2..5ed449829d 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -917,6 +917,28 @@ "security": [{ "HTTPBearerJWT": [] }] } }, + "/api/chat/config/ttl": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Ttl Config", + "description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.", + "operationId": "getV2GetTtlConfig", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "type": "object", + "title": "Response Getv2Getttlconfig" + } + } + } + } + } + } + }, "/api/chat/health": { "get": { "tags": ["v2", "chat", "chat"], @@ -939,6 +961,63 @@ } } }, + "/api/chat/operations/{operation_id}/complete": { + "post": { + "tags": ["v2", "chat", "chat"], + "summary": "Complete Operation", + "description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.", + "operationId": "postV2CompleteOperation", + "parameters": [ + { + "name": "operation_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Operation Id" } + }, + { + "name": "x-api-key", + "in": "header", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "X-Api-Key" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OperationCompleteRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Postv2Completeoperation" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/chat/sessions": { "get": { "tags": ["v2", "chat", "chat"], @@ -1022,7 +1101,7 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Session", - "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.", + "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data 
including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.", "operationId": "getV2GetSession", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1157,7 +1236,7 @@ "post": { "tags": ["v2", "chat", "chat"], "summary": "Stream Chat Post", - "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.", + "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. 
First chunk is a \"start\" event\n containing the task_id for reconnection.", "operationId": "postV2StreamChatPost", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1195,6 +1274,94 @@ } } }, + "/api/chat/tasks/{task_id}": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Task Status", + "description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.", + "operationId": "getV2GetTaskStatus", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Getv2Gettaskstatus" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/chat/tasks/{task_id}/stream": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Stream Task", + "description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.", + "operationId": "getV2StreamTask", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + }, + { + "name": "last_message_id", + "in": "query", + "required": false, + "schema": { + "type": "string", + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + "default": "0-0", + "title": "Last Message Id" + }, + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay." + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/credits": { "get": { "tags": ["v1", "credits"], @@ -6168,6 +6335,18 @@ "title": "AccuracyTrendsResponse", "description": "Response model for accuracy trends and alerts." 
}, + "ActiveStreamInfo": { + "properties": { + "task_id": { "type": "string", "title": "Task Id" }, + "last_message_id": { "type": "string", "title": "Last Message Id" }, + "operation_id": { "type": "string", "title": "Operation Id" }, + "tool_name": { "type": "string", "title": "Tool Name" } + }, + "type": "object", + "required": ["task_id", "last_message_id", "operation_id", "tool_name"], + "title": "ActiveStreamInfo", + "description": "Information about an active stream for reconnection." + }, "AddUserCreditsResponse": { "properties": { "new_balance": { "type": "integer", "title": "New Balance" }, @@ -8823,6 +9002,27 @@ ], "title": "OnboardingStep" }, + "OperationCompleteRequest": { + "properties": { + "success": { "type": "boolean", "title": "Success" }, + "result": { + "anyOf": [ + { "additionalProperties": true, "type": "object" }, + { "type": "string" }, + { "type": "null" } + ], + "title": "Result" + }, + "error": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Error" + } + }, + "type": "object", + "required": ["success"], + "title": "OperationCompleteRequest", + "description": "Request model for external completion webhook." + }, "Pagination": { "properties": { "total_items": { @@ -9678,6 +9878,12 @@ "items": { "additionalProperties": true, "type": "object" }, "type": "array", "title": "Messages" + }, + "active_stream": { + "anyOf": [ + { "$ref": "#/components/schemas/ActiveStreamInfo" }, + { "type": "null" } + ] } }, "type": "object", diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx index ada8c26231..da454150bf 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx @@ -1,7 +1,6 @@ "use client"; import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId"; -import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; @@ -25,8 +24,8 @@ export function Chat({ }: ChatProps) { const { urlSessionId } = useCopilotSessionId(); const hasHandledNotFoundRef = useRef(false); - const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession); const { + session, messages, isLoading, isCreating, @@ -38,6 +37,18 @@ export function Chat({ startPollingForOperation, } = useChat({ urlSessionId }); + // Extract active stream info for reconnection + const activeStream = ( + session as { + active_stream?: { + task_id: string; + last_message_id: string; + operation_id: string; + tool_name: string; + }; + } + )?.active_stream; + useEffect(() => { if (!onSessionNotFound) return; if (!urlSessionId) return; @@ -53,8 +64,7 @@ export function Chat({ isCreating, ]); - const shouldShowLoader = - (showLoader && (isLoading || isCreating)) || isSwitchingSession; + const shouldShowLoader = showLoader && (isLoading || isCreating); return (
@@ -66,21 +76,19 @@ export function Chat({
- {isSwitchingSession - ? "Switching chat..." - : "Loading your chat..."} + Loading your chat...
)} {/* Error State */} - {error && !isLoading && !isSwitchingSession && ( + {error && !isLoading && ( )} {/* Session Content */} - {sessionId && !isLoading && !error && !isSwitchingSession && ( + {sessionId && !isLoading && !error && ( )} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md new file mode 100644 index 0000000000..9e78679f4e --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md @@ -0,0 +1,159 @@ +# SSE Reconnection Contract for Long-Running Operations + +This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks. + +## Overview + +When a user triggers a long-running operation (like agent generation), the backend: + +1. Spawns a background task that survives SSE disconnections +2. Returns an `operation_started` response with a `task_id` +3. Stores stream messages in Redis Streams for replay + +Clients can reconnect to the task stream at any time to receive missed messages. + +## Client-Side Flow + +### 1. Receiving Operation Started + +When you receive an `operation_started` tool response: + +```typescript +// The response includes a task_id for reconnection +{ + type: "operation_started", + tool_name: "generate_agent", + operation_id: "uuid-...", + task_id: "task-uuid-...", // <-- Store this for reconnection + message: "Operation started. You can close this tab." +} +``` + +### 2. Storing Task Info + +Use the chat store to track the active task: + +```typescript +import { useChatStore } from "./chat-store"; + +// When operation_started is received: +useChatStore.getState().setActiveTask(sessionId, { + taskId: response.task_id, + operationId: response.operation_id, + toolName: response.tool_name, + lastMessageId: "0", +}); +``` + +### 3. Reconnecting to a Task + +To reconnect (e.g., after page refresh or tab reopen): + +```typescript +const { reconnectToTask, getActiveTask } = useChatStore.getState(); + +// Check if there's an active task for this session +const activeTask = getActiveTask(sessionId); + +if (activeTask) { + // Reconnect to the task stream + await reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, // Resume from last position + (chunk) => { + // Handle incoming chunks + console.log("Received chunk:", chunk); + }, + ); +} +``` + +### 4. 
+### 4. Tracking Message Position
+
+To enable precise replay, update the last message ID as chunks arrive:
+
+```typescript
+const { updateTaskLastMessageId } = useChatStore.getState();
+
+function handleChunk(chunk: StreamChunk) {
+  // If chunk has an index/id, track it
+  if (chunk.idx !== undefined) {
+    updateTaskLastMessageId(sessionId, String(chunk.idx));
+  }
+}
+```
+
+## API Endpoints
+
+### Task Stream Reconnection
+
+```
+GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
+```
+
+- `taskId`: The task ID from `operation_started`
+- `last_message_id`: Last received message ID (default: `"0"`, equivalent to the Redis stream ID `"0-0"`, which replays everything)
+
+Returns an SSE stream of missed messages followed by live updates.
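+For debugging, or for consumers outside the chat store, the endpoint can also be read directly. A rough sketch using `fetch` and standard `data:` SSE framing (the `tailTaskStream` helper is hypothetical; application code should go through `stream-executor` instead):
+
+```typescript
+async function tailTaskStream(taskId: string, lastMessageId = "0") {
+  const res = await fetch(
+    `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`,
+    { headers: { Accept: "text/event-stream" } },
+  );
+  if (!res.ok || !res.body) throw new Error(`HTTP ${res.status}`);
+
+  const reader = res.body.getReader();
+  const decoder = new TextDecoder();
+  let buffer = "";
+  for (;;) {
+    const { done, value } = await reader.read();
+    if (done) break;
+    buffer += decoder.decode(value, { stream: true });
+    const lines = buffer.split("\n");
+    buffer = lines.pop() ?? ""; // keep the trailing partial line
+    for (const line of lines) {
+      if (!line.startsWith("data: ")) continue; // only SSE data frames
+      console.log(JSON.parse(line.slice("data: ".length)));
+    }
+  }
+}
+```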
+## Chunk Types
+
+The reconnected stream follows the same Vercel AI SDK protocol:
+
+| Type                    | Description             |
+| ----------------------- | ----------------------- |
+| `start`                 | Message lifecycle start |
+| `text-delta`            | Streaming text content  |
+| `text-end`              | Text block completed    |
+| `tool-output-available` | Tool result available   |
+| `finish`                | Stream completed        |
+| `error`                 | Error occurred          |
+
+## Error Handling
+
+If reconnection fails:
+
+1. Check whether the task still exists (it may have expired; the default TTL is 1 hour)
+2. Fall back to polling the session for the final state, as sketched below
+3. Show an appropriate UI message to the user
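+Because the stream executor reports failures as an `error` chunk rather than throwing, the fallback is easiest to wire through the chunk callback. A sketch, assuming the backend error code is surfaced on the chunk as `chunk.code` (the same codes the handlers map to user-facing messages):
+
+```typescript
+import { useChatStore } from "./chat-store";
+
+// Codes that mean the task is gone for good; retrying cannot help.
+const PERMANENT_CODES = new Set(["TASK_EXPIRED", "TASK_NOT_FOUND", "ACCESS_DENIED"]);
+
+function resumeWithFallback(sessionId: string, taskId: string, lastId: string) {
+  const { reconnectToTask, clearActiveTask } = useChatStore.getState();
+  return reconnectToTask(sessionId, taskId, lastId, (chunk) => {
+    if (chunk.type !== "error") return;
+    if (chunk.code && PERMANENT_CODES.has(chunk.code)) {
+      clearActiveTask(sessionId); // stop reconnecting; the task is gone
+      // Fall back to GET /sessions/{sessionId} to render the final state.
+    }
+  });
+}
+```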
+## Persistence Considerations
+
+For robust reconnection across browser restarts:
+
+```typescript
+// Store in localStorage/sessionStorage
+const ACTIVE_TASKS_KEY = "chat_active_tasks";
+
+function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
+  const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
+  tasks[sessionId] = task;
+  localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
+}
+
+function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
+  return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
+}
+```
+
+## Backend Configuration
+
+The following backend settings affect reconnection behavior:
+
+| Setting             | Default | Description                        |
+| ------------------- | ------- | ---------------------------------- |
+| `stream_ttl`        | 3600s   | How long streams are kept in Redis |
+| `stream_max_length` | 1000    | Max messages per stream            |
+
+## Testing
+
+To test reconnection locally:
+
+1. Start a long-running operation (e.g., agent generation)
+2. Note the `task_id` from the `operation_started` response
+3. Close the browser tab
+4. Reopen and call `reconnectToTask` with the saved `task_id`
+5. Verify that missed messages are replayed
+
+See the main README for full local development setup.
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts
new file mode 100644
index 0000000000..8802de2155
--- /dev/null
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts
@@ -0,0 +1,16 @@
+/**
+ * Constants for the chat system.
+ *
+ * Centralizes magic strings and values used across chat components.
+ */
+
+// LocalStorage keys
+export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
+
+// Redis Stream IDs
+export const INITIAL_MESSAGE_ID = "0";
+export const INITIAL_STREAM_ID = "0-0";
+
+// TTL values (in milliseconds)
+export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
+export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts
index 8229630e5d..3083f65d2c 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts
@@ -1,6 +1,12 @@
 "use client";
 
 import { create } from "zustand";
+import {
+  ACTIVE_TASK_TTL_MS,
+  COMPLETED_STREAM_TTL_MS,
+  INITIAL_STREAM_ID,
+  STORAGE_KEY_ACTIVE_TASKS,
+} from "./chat-constants";
 import type {
   ActiveStream,
   StreamChunk,
@@ -8,15 +14,59 @@ import type {
   StreamResult,
   StreamStatus,
 } from "./chat-types";
-import { executeStream } from "./stream-executor";
+import { executeStream, executeTaskReconnect } from "./stream-executor";
 
-const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
+export interface ActiveTaskInfo {
+  taskId: string;
+  sessionId: string;
+  operationId: string;
+  toolName: string;
+  lastMessageId: string;
+  startedAt: number;
+}
+
+/** Load active tasks from localStorage */
+function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
+  if (typeof window === "undefined") return new Map();
+  try {
+    const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
+    if (!stored) return new Map();
+    const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
+    const now = Date.now();
+    const tasks = new Map<string, ActiveTaskInfo>();
+    // Filter out expired tasks
+    for (const [sessionId, task] of Object.entries(parsed)) {
+      if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
+        tasks.set(sessionId, task);
+      }
+    }
+    return tasks;
+  } catch {
+    return new Map();
+  }
+}
+
+/** Save active tasks to localStorage */
+function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
+  if (typeof window === "undefined") return;
+  try {
+    const obj: Record<string, ActiveTaskInfo> = {};
+    for (const [sessionId, task] of tasks) {
+      obj[sessionId] = task;
+    }
+    localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
+  } catch {
+    // Ignore storage errors
+  }
+}
 
 interface ChatStoreState {
   activeStreams: Map<string, ActiveStream>;
   completedStreams: Map<string, StreamResult>;
   activeSessions: Set<string>;
   streamCompleteCallbacks: Set<StreamCompleteCallback>;
+  /** Active tasks for SSE reconnection - keyed by sessionId */
+  activeTasks: Map<string, ActiveTaskInfo>;
 }
 
 interface ChatStoreActions {
@@ -41,6 +91,24 @@
   unregisterActiveSession: (sessionId: string) => void;
   isSessionActive: (sessionId: string) => boolean;
   onStreamComplete: (callback: StreamCompleteCallback) => () => void;
+  /** Track active task for SSE reconnection */
+  setActiveTask: (
+    sessionId: string,
+    taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
+  ) => void;
+  /** Get active task for a session */
+  getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
+  /** Clear active task when operation completes */
+  clearActiveTask: (sessionId: string) => void;
+  /** Reconnect to an existing task stream */
+  reconnectToTask: (
+    sessionId: string,
+    taskId: string,
+    lastMessageId?: string,
+    onChunk?: (chunk: StreamChunk) => void,
+  ) => Promise<void>;
+  /** Update last message ID for a task (for tracking replay position) */
+  updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
 }
 
 type ChatStore = ChatStoreState & ChatStoreActions;
@@ -64,18 +132,126 @@ function cleanupExpiredStreams(
   const now = Date.now();
   const cleaned = new Map(completedStreams);
   for (const [sessionId, result] of cleaned) {
-    if (now - result.completedAt > COMPLETED_STREAM_TTL) {
+    if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
       cleaned.delete(sessionId);
     }
   }
   return cleaned;
 }
 
+/**
+ * Finalize a stream by moving it from activeStreams to completedStreams.
+ * Also handles cleanup and notifications.
+ */
+function finalizeStream(
+  sessionId: string,
+  stream: ActiveStream,
+  onChunk: ((chunk: StreamChunk) => void) | undefined,
+  get: () => ChatStoreState & ChatStoreActions,
+  set: (state: Partial<ChatStoreState>) => void,
+): void {
+  if (onChunk) stream.onChunkCallbacks.delete(onChunk);
+
+  if (stream.status !== "streaming") {
+    const currentState = get();
+    const finalActiveStreams = new Map(currentState.activeStreams);
+    let finalCompletedStreams = new Map(currentState.completedStreams);
+
+    const storedStream = finalActiveStreams.get(sessionId);
+    if (storedStream === stream) {
+      const result: StreamResult = {
+        sessionId,
+        status: stream.status,
+        chunks: stream.chunks,
+        completedAt: Date.now(),
+        error: stream.error,
+      };
+      finalCompletedStreams.set(sessionId, result);
+      finalActiveStreams.delete(sessionId);
+      finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
+      set({
+        activeStreams: finalActiveStreams,
+        completedStreams: finalCompletedStreams,
+      });
+
+      if (stream.status === "completed" || stream.status === "error") {
+        notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId);
+      }
+    }
+  }
+}
+
+/**
+ * Clean up an existing stream for a session and move it to completed streams.
+ * Returns updated maps for both active and completed streams.
+ */
+function cleanupExistingStream(
+  sessionId: string,
+  activeStreams: Map<string, ActiveStream>,
+  completedStreams: Map<string, StreamResult>,
+  callbacks: Set<StreamCompleteCallback>,
+): {
+  activeStreams: Map<string, ActiveStream>;
+  completedStreams: Map<string, StreamResult>;
+} {
+  const newActiveStreams = new Map(activeStreams);
+  let newCompletedStreams = new Map(completedStreams);
+
+  const existingStream = newActiveStreams.get(sessionId);
+  if (existingStream) {
+    existingStream.abortController.abort();
+    const normalizedStatus =
+      existingStream.status === "streaming"
+        ? "completed"
+        : existingStream.status;
+    const result: StreamResult = {
+      sessionId,
+      status: normalizedStatus,
+      chunks: existingStream.chunks,
+      completedAt: Date.now(),
+      error: existingStream.error,
+    };
+    newCompletedStreams.set(sessionId, result);
+    newActiveStreams.delete(sessionId);
+    newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
+    if (normalizedStatus === "completed" || normalizedStatus === "error") {
+      notifyStreamComplete(callbacks, sessionId);
+    }
+  }
+
+  return {
+    activeStreams: newActiveStreams,
+    completedStreams: newCompletedStreams,
+  };
+}
+
+/**
+ * Create a new active stream with initial state. 
+ */ +function createActiveStream( + sessionId: string, + onChunk?: (chunk: StreamChunk) => void, +): ActiveStream { + const abortController = new AbortController(); + const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); + if (onChunk) initialCallbacks.add(onChunk); + + return { + sessionId, + abortController, + status: "streaming", + startedAt: Date.now(), + chunks: [], + onChunkCallbacks: initialCallbacks, + }; +} + export const useChatStore = create((set, get) => ({ activeStreams: new Map(), completedStreams: new Map(), activeSessions: new Set(), streamCompleteCallbacks: new Set(), + activeTasks: loadPersistedTasks(), startStream: async function startStream( sessionId, @@ -85,45 +261,21 @@ export const useChatStore = create((set, get) => ({ onChunk, ) { const state = get(); - const newActiveStreams = new Map(state.activeStreams); - let newCompletedStreams = new Map(state.completedStreams); const callbacks = state.streamCompleteCallbacks; - const existingStream = newActiveStreams.get(sessionId); - if (existingStream) { - existingStream.abortController.abort(); - const normalizedStatus = - existingStream.status === "streaming" - ? "completed" - : existingStream.status; - const result: StreamResult = { - sessionId, - status: normalizedStatus, - chunks: existingStream.chunks, - completedAt: Date.now(), - error: existingStream.error, - }; - newCompletedStreams.set(sessionId, result); - newActiveStreams.delete(sessionId); - newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); - if (normalizedStatus === "completed" || normalizedStatus === "error") { - notifyStreamComplete(callbacks, sessionId); - } - } - - const abortController = new AbortController(); - const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); - if (onChunk) initialCallbacks.add(onChunk); - - const stream: ActiveStream = { + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( sessionId, - abortController, - status: "streaming", - startedAt: Date.now(), - chunks: [], - onChunkCallbacks: initialCallbacks, - }; + state.activeStreams, + state.completedStreams, + callbacks, + ); + // Create new stream + const stream = createActiveStream(sessionId, onChunk); newActiveStreams.set(sessionId, stream); set({ activeStreams: newActiveStreams, @@ -133,36 +285,7 @@ export const useChatStore = create((set, get) => ({ try { await executeStream(stream, message, isUserMessage, context); } finally { - if (onChunk) stream.onChunkCallbacks.delete(onChunk); - if (stream.status !== "streaming") { - const currentState = get(); - const finalActiveStreams = new Map(currentState.activeStreams); - let finalCompletedStreams = new Map(currentState.completedStreams); - - const storedStream = finalActiveStreams.get(sessionId); - if (storedStream === stream) { - const result: StreamResult = { - sessionId, - status: stream.status, - chunks: stream.chunks, - completedAt: Date.now(), - error: stream.error, - }; - finalCompletedStreams.set(sessionId, result); - finalActiveStreams.delete(sessionId); - finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams); - set({ - activeStreams: finalActiveStreams, - completedStreams: finalCompletedStreams, - }); - if (stream.status === "completed" || stream.status === "error") { - notifyStreamComplete( - currentState.streamCompleteCallbacks, - sessionId, - ); - } - } - } + finalizeStream(sessionId, stream, onChunk, get, set); } }, @@ -286,4 +409,93 @@ export const 
useChatStore = create((set, get) => ({ set({ streamCompleteCallbacks: cleanedCallbacks }); }; }, + + setActiveTask: function setActiveTask(sessionId, taskInfo) { + const state = get(); + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...taskInfo, + sessionId, + startedAt: Date.now(), + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + getActiveTask: function getActiveTask(sessionId) { + return get().activeTasks.get(sessionId); + }, + + clearActiveTask: function clearActiveTask(sessionId) { + const state = get(); + if (!state.activeTasks.has(sessionId)) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + reconnectToTask: async function reconnectToTask( + sessionId, + taskId, + lastMessageId = INITIAL_STREAM_ID, + onChunk, + ) { + const state = get(); + const callbacks = state.streamCompleteCallbacks; + + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( + sessionId, + state.activeStreams, + state.completedStreams, + callbacks, + ); + + // Create new stream for reconnection + const stream = createActiveStream(sessionId, onChunk); + newActiveStreams.set(sessionId, stream); + set({ + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }); + + try { + await executeTaskReconnect(stream, taskId, lastMessageId); + } finally { + finalizeStream(sessionId, stream, onChunk, get, set); + + // Clear active task on completion + if (stream.status === "completed" || stream.status === "error") { + const taskState = get(); + if (taskState.activeTasks.has(sessionId)) { + const newActiveTasks = new Map(taskState.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + } + } + } + }, + + updateTaskLastMessageId: function updateTaskLastMessageId( + sessionId, + lastMessageId, + ) { + const state = get(); + const task = state.activeTasks.get(sessionId); + if (!task) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...task, + lastMessageId, + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, })); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts index 8c8aa7b704..34813e17fe 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts @@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error"; export interface StreamChunk { type: + | "stream_start" | "text_chunk" | "text_ended" | "tool_call" @@ -15,6 +16,7 @@ export interface StreamChunk { | "error" | "usage" | "stream_end"; + taskId?: string; timestamp?: string; content?: string; message?: string; @@ -41,7 +43,7 @@ export interface StreamChunk { } export type VercelStreamChunk = - | { type: "start"; messageId: string } + | { type: "start"; messageId: string; taskId?: string } | { type: "finish" } | { type: "text-start"; id: string } | { type: "text-delta"; id: string; delta: string } @@ -92,3 +94,70 @@ export interface StreamResult { } export type StreamCompleteCallback = (sessionId: string) => void; + +// Type guards for message types + +/** + * Check 
if a message has a toolId property.
+ */
+export function hasToolId<T extends object>(
+  msg: T,
+): msg is T & { toolId: string } {
+  return (
+    "toolId" in msg &&
+    typeof (msg as Record<string, unknown>).toolId === "string"
+  );
+}
+
+/**
+ * Check if a message has an operationId property.
+ */
+export function hasOperationId<T extends object>(
+  msg: T,
+): msg is T & { operationId: string } {
+  return (
+    "operationId" in msg &&
+    typeof (msg as Record<string, unknown>).operationId === "string"
+  );
+}
+
+/**
+ * Check if a message has a toolCallId property.
+ */
+export function hasToolCallId<T extends object>(
+  msg: T,
+): msg is T & { toolCallId: string } {
+  return (
+    "toolCallId" in msg &&
+    typeof (msg as Record<string, unknown>).toolCallId === "string"
+  );
+}
+
+/**
+ * Check if a message is an operation message type.
+ */
+export function isOperationMessage<T extends { type: string }>(
+  msg: T,
+): msg is T & {
+  type: "operation_started" | "operation_pending" | "operation_in_progress";
+} {
+  return (
+    msg.type === "operation_started" ||
+    msg.type === "operation_pending" ||
+    msg.type === "operation_in_progress"
+  );
+}
+
+/**
+ * Get the tool ID from a message if available.
+ * Checks toolId, operationId, and toolCallId properties.
+ */
+export function getToolIdFromMessage<T extends object>(
+  msg: T,
+): string | undefined {
+  const record = msg as Record<string, unknown>;
+  if (typeof record.toolId === "string") return record.toolId;
+  if (typeof record.operationId === "string") return record.operationId;
+  if (typeof record.toolCallId === "string") return record.toolCallId;
+  return undefined;
+}
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx
index dec221338a..5df9944f47 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx
@@ -17,6 +17,13 @@ export interface ChatContainerProps {
   className?: string;
   onStreamingChange?: (isStreaming: boolean) => void;
   onOperationStarted?: () => void;
+  /** Active stream info from the server for reconnection */
+  activeStream?: {
+    taskId: string;
+    lastMessageId: string;
+    operationId: string;
+    toolName: string;
+  };
 }
 
 export function ChatContainer({
@@ -26,6 +33,7 @@
   className,
   onStreamingChange,
   onOperationStarted,
+  activeStream,
 }: ChatContainerProps) {
   const {
     messages,
@@ -41,6 +49,7 @@
     initialMessages,
     initialPrompt,
     onOperationStarted,
+    activeStream,
   });
 
   useEffect(() => {
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts
index 82e9b05e88..af3b3329b7 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts
@@ -2,6 +2,7 @@ import { toast } from "sonner";
 import type { StreamChunk } from "../../chat-types";
 import type { HandlerDependencies } from "./handlers";
 import {
+  getErrorDisplayMessage,
   handleError,
   handleLoginNeeded,
   handleStreamEnd,
@@ -24,16 +25,22 @@
       chunk.type === "need_login" ||
       chunk.type === "error"
     ) {
-      if (!deps.hasResponseRef.current) {
-        console.info("[ChatStream] First response chunk:", {
-          type: chunk.type,
-          sessionId: deps.sessionId,
-        });
-      }
       deps.hasResponseRef.current = true;
     }
 
     switch (chunk.type) {
+      case "stream_start":
+        // Store task ID for SSE reconnection
+        if (chunk.taskId && deps.onActiveTaskStarted) {
+          deps.onActiveTaskStarted({
+            taskId: chunk.taskId,
+            operationId: chunk.taskId,
+            toolName: "chat",
+            toolCallId: "chat_stream",
+          });
+        }
+        break;
+
       case "text_chunk":
         handleTextChunk(chunk, deps);
         break;
@@ -56,11 +63,7 @@
         break;
 
       case "stream_end":
-        console.info("[ChatStream] Stream ended:", {
-          sessionId: deps.sessionId,
-          hasResponse: deps.hasResponseRef.current,
-          chunkCount: deps.streamingChunksRef.current.length,
-        });
+        // Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk
         handleStreamEnd(chunk, deps);
         break;
 
@@ -70,7 +73,7 @@
         // Show toast at dispatcher level to avoid circular dependencies
         if (!isRegionBlocked) {
           toast.error("Chat Error", {
-            description: chunk.message || chunk.content || "An error occurred",
+            description: getErrorDisplayMessage(chunk),
           });
         }
         break;
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts
index f3cac01f96..5aec5b9818 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts
@@ -18,11 +18,19 @@ export interface HandlerDependencies {
   setStreamingChunks: Dispatch<SetStateAction<string[]>>;
   streamingChunksRef: MutableRefObject<string[]>;
   hasResponseRef: MutableRefObject<boolean>;
+  textFinalizedRef: MutableRefObject<boolean>;
+  streamEndedRef: MutableRefObject<boolean>;
   setMessages: Dispatch<SetStateAction<ChatMessageData[]>>;
   setIsStreamingInitiated: Dispatch<SetStateAction<boolean>>;
   setIsRegionBlockedModalOpen: Dispatch<SetStateAction<boolean>>;
   sessionId: string;
   onOperationStarted?: () => void;
+  onActiveTaskStarted?: (taskInfo: {
+    taskId: string;
+    operationId: string;
+    toolName: string;
+    toolCallId: string;
+  }) => void;
 }
 
 export function isRegionBlockedError(chunk: StreamChunk): boolean {
@@ -32,6 +40,25 @@
   return message.toLowerCase().includes("not available in your region");
 }
 
+export function getUserFriendlyErrorMessage(
+  code: string | undefined,
+): string | undefined {
+  switch (code) {
+    case "TASK_EXPIRED":
+      return "This operation has expired. Please try again.";
+    case "TASK_NOT_FOUND":
+      return "Could not find the requested operation.";
+    case "ACCESS_DENIED":
+      return "You do not have access to this operation.";
+    case "QUEUE_OVERFLOW":
+      return "Connection was interrupted. Please refresh to continue.";
+    case "MODEL_NOT_AVAILABLE_REGION":
+      return "This model is not available in your region.";
+    default:
+      return undefined;
+  }
+}
+
 export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) {
   if (!chunk.content) return;
   deps.setHasTextChunks(true);
@@ -46,10 +73,15 @@
 export function handleTextEnded(
   _chunk: StreamChunk,
   deps: HandlerDependencies,
 ) {
+  if (deps.textFinalizedRef.current) {
+    return;
+  }
+
   const completedText = deps.streamingChunksRef.current.join("");
   if (completedText.trim()) {
+    deps.textFinalizedRef.current = true;
+
     deps.setMessages((prev) => {
-      // Check if this exact message already exists to prevent duplicates
       const exists = prev.some(
         (msg) =>
           msg.type === "message" &&
@@ -76,9 +108,14 @@
 export function handleToolCallStart(
   chunk: StreamChunk,
   deps: HandlerDependencies,
 ) {
+  // Use deterministic fallback instead of Date.now() to ensure same ID on replay
+  const toolId =
+    chunk.tool_id ||
+    `tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`;
+
   const toolCallMessage: Extract<ChatMessageData, { type: "tool_call" }> = {
     type: "tool_call",
-    toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`,
+    toolId,
     toolName: chunk.tool_name || "Executing",
     arguments: chunk.arguments || {},
     timestamp: new Date(),
@@ -111,6 +148,29 @@
   deps.setMessages(updateToolCallMessages);
 }
 
+const TOOL_RESPONSE_TYPES = new Set([
+  "tool_response",
+  "operation_started",
+  "operation_pending",
+  "operation_in_progress",
+  "execution_started",
+  "agent_carousel",
+  "clarification_needed",
+]);
+
+function hasResponseForTool(
+  messages: ChatMessageData[],
+  toolId: string,
+): boolean {
+  return messages.some((msg) => {
+    if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false;
+    const msgToolId =
+      (msg as { toolId?: string }).toolId ||
+      (msg as { toolCallId?: string }).toolCallId;
+    return msgToolId === toolId;
+  });
+}
+
 export function handleToolResponse(
   chunk: StreamChunk,
   deps: HandlerDependencies,
@@ -152,31 +212,49 @@
 ) {
     const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name);
     if (inputsMessage) {
-      deps.setMessages((prev) => [...prev, inputsMessage]);
+      deps.setMessages((prev) => {
+        // Check for duplicate inputs_needed message
+        const exists = prev.some((msg) => msg.type === "inputs_needed");
+        if (exists) return prev;
+        return [...prev, inputsMessage];
+      });
     }
     const credentialsMessage = extractCredentialsNeeded(
       parsedResult,
       chunk.tool_name,
     );
     if (credentialsMessage) {
-      deps.setMessages((prev) => [...prev, credentialsMessage]);
+      deps.setMessages((prev) => {
+        // Check for duplicate credentials_needed message
+        const exists = prev.some((msg) => msg.type === "credentials_needed");
+        if (exists) return prev;
+        return [...prev, credentialsMessage];
+      });
     }
   }
     return;
   }
 
-  // Trigger polling when operation_started is received
   if (responseMessage.type === "operation_started") {
     deps.onOperationStarted?.();
+    const taskId = (responseMessage as { taskId?: string }).taskId;
+    if (taskId && deps.onActiveTaskStarted) {
+      deps.onActiveTaskStarted({
+        taskId,
+        operationId:
+          (responseMessage as { operationId?: string }).operationId || "",
+        toolName: (responseMessage as { toolName?: string }).toolName || "",
+        toolCallId: (responseMessage as { toolId?: string }).toolId || "",
+      });
+    }
   }
 
   deps.setMessages((prev) => {
     const toolCallIndex = prev.findIndex(
       (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id,
     );
-    const hasResponse = 
prev.some( - (msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id, - ); - if (hasResponse) return prev; + if (hasResponseForTool(prev, chunk.tool_id!)) { + return prev; + } if (toolCallIndex !== -1) { const newMessages = [...prev]; newMessages.splice(toolCallIndex + 1, 0, responseMessage); @@ -198,28 +276,48 @@ export function handleLoginNeeded( agentInfo: chunk.agent_info, timestamp: new Date(), }; - deps.setMessages((prev) => [...prev, loginNeededMessage]); + deps.setMessages((prev) => { + // Check for duplicate login_needed message + const exists = prev.some((msg) => msg.type === "login_needed"); + if (exists) return prev; + return [...prev, loginNeededMessage]; + }); } export function handleStreamEnd( _chunk: StreamChunk, deps: HandlerDependencies, ) { + if (deps.streamEndedRef.current) { + return; + } + deps.streamEndedRef.current = true; + const completedContent = deps.streamingChunksRef.current.join(""); if (!completedContent.trim() && !deps.hasResponseRef.current) { - deps.setMessages((prev) => [ - ...prev, - { - type: "message", - role: "assistant", - content: "No response received. Please try again.", - timestamp: new Date(), - }, - ]); - } - if (completedContent.trim()) { deps.setMessages((prev) => { - // Check if this exact message already exists to prevent duplicates + const exists = prev.some( + (msg) => + msg.type === "message" && + msg.role === "assistant" && + msg.content === "No response received. Please try again.", + ); + if (exists) return prev; + return [ + ...prev, + { + type: "message", + role: "assistant", + content: "No response received. Please try again.", + timestamp: new Date(), + }, + ]; + }); + } + if (completedContent.trim() && !deps.textFinalizedRef.current) { + deps.textFinalizedRef.current = true; + + deps.setMessages((prev) => { const exists = prev.some( (msg) => msg.type === "message" && @@ -244,8 +342,6 @@ export function handleStreamEnd( } export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { - const errorMessage = chunk.message || chunk.content || "An error occurred"; - console.error("Stream error:", errorMessage); if (isRegionBlockedError(chunk)) { deps.setIsRegionBlockedModalOpen(true); } @@ -253,4 +349,14 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { deps.setHasTextChunks(false); deps.setStreamingChunks([]); deps.streamingChunksRef.current = []; + deps.textFinalizedRef.current = false; + deps.streamEndedRef.current = true; +} + +export function getErrorDisplayMessage(chunk: StreamChunk): string { + const friendlyMessage = getUserFriendlyErrorMessage(chunk.code); + if (friendlyMessage) { + return friendlyMessage; + } + return chunk.message || chunk.content || "An error occurred"; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts index e744c9bc34..f1e94cea17 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts @@ -349,6 +349,7 @@ export function parseToolResponse( toolName: (parsedResult.tool_name as string) || toolName, toolId, operationId: (parsedResult.operation_id as string) || "", + taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection message: (parsedResult.message as string) || "Operation started. 
You can close this tab.",
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts
index 46f384d055..248383df42 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts
@@ -1,10 +1,17 @@
 import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse";
 import { useEffect, useMemo, useRef, useState } from "react";
+import { INITIAL_STREAM_ID } from "../../chat-constants";
 import { useChatStore } from "../../chat-store";
 import { toast } from "sonner";
 import { useChatStream } from "../../useChatStream";
 import { usePageContext } from "../../usePageContext";
 import type { ChatMessageData } from "../ChatMessage/useChatMessage";
+import {
+  getToolIdFromMessage,
+  hasToolId,
+  isOperationMessage,
+  type StreamChunk,
+} from "../../chat-types";
 import { createStreamEventDispatcher } from "./createStreamEventDispatcher";
 import {
   createUserMessage,
@@ -14,6 +21,13 @@ import {
   processInitialMessages,
 } from "./helpers";
 
+const TOOL_RESULT_TYPES = new Set([
+  "tool_response",
+  "agent_carousel",
+  "execution_started",
+  "clarification_needed",
+]);
+
 // Helper to generate deduplication key for a message
 function getMessageKey(msg: ChatMessageData): string {
   if (msg.type === "message") {
@@ -23,14 +37,18 @@
     return `msg:${msg.role}:${msg.content}`;
   } else if (msg.type === "tool_call") {
     return `toolcall:${msg.toolId}`;
-  } else if (msg.type === "tool_response") {
-    return `toolresponse:${(msg as any).toolId}`;
-  } else if (
-    msg.type === "operation_started" ||
-    msg.type === "operation_pending" ||
-    msg.type === "operation_in_progress"
-  ) {
-    return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`;
+  } else if (TOOL_RESULT_TYPES.has(msg.type)) {
+    // Unified key for all tool result types - same toolId with different types
+    // (tool_response vs agent_carousel) should deduplicate to the same key
+    const toolId = getToolIdFromMessage(msg);
+    // If no toolId, fall back to content-based key to avoid empty key collisions
+    if (!toolId) {
+      return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`;
+    }
+    return `toolresult:${toolId}`;
+  } else if (isOperationMessage(msg)) {
+    const toolId = getToolIdFromMessage(msg) || "";
+    return `op:${toolId}:${msg.toolName}`;
   } else {
     return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`;
   }
@@ -41,6 +59,13 @@ interface Args {
   initialMessages: SessionDetailResponse["messages"];
   initialPrompt?: string;
   onOperationStarted?: () => void;
+  /** Active stream info from the server for reconnection */
+  activeStream?: {
+    taskId: string;
+    lastMessageId: string;
+    operationId: string;
+    toolName: string;
+  };
 }
 
 export function useChatContainer({
@@ -48,6 +73,7 @@
   initialMessages,
   initialPrompt,
   onOperationStarted,
+  activeStream,
 }: Args) {
   const [messages, setMessages] = useState<ChatMessageData[]>([]);
   const [streamingChunks, setStreamingChunks] = useState<string[]>([]);
@@ -57,6 +83,8 @@
     useState(false);
   const hasResponseRef = useRef(false);
   const streamingChunksRef = useRef<string[]>([]);
+  const textFinalizedRef = useRef(false);
+  const streamEndedRef = useRef(false);
   const previousSessionIdRef = useRef<string | null>(null);
   const {
     error,
@@ -65,44 +93,182 @@
   } = useChatStream();
   const activeStreams = useChatStore((s) => s.activeStreams);
   const subscribeToStream = useChatStore((s) => s.subscribeToStream);
+  const setActiveTask = useChatStore((s) => s.setActiveTask);
+  const getActiveTask = useChatStore((s) => s.getActiveTask);
+  const reconnectToTask = useChatStore((s) => s.reconnectToTask);
   const isStreaming = isStreamingInitiated || hasTextChunks;
+  // Track whether we've already connected to this activeStream to avoid duplicate connections
+  const connectedActiveStreamRef = useRef<string | null>(null);
+  // Track if component is mounted to prevent state updates after unmount
+  const isMountedRef = useRef(true);
+  // Track current dispatcher to prevent multiple dispatchers from adding messages
+  const currentDispatcherIdRef = useRef(0);
+
+  // Set mounted flag - reset on every mount, cleanup on unmount
+  useEffect(function trackMountedState() {
+    isMountedRef.current = true;
+    return function cleanup() {
+      isMountedRef.current = false;
+    };
+  }, []);
+
+  // Callback to store active task info for SSE reconnection
+  function handleActiveTaskStarted(taskInfo: {
+    taskId: string;
+    operationId: string;
+    toolName: string;
+    toolCallId: string;
+  }) {
+    if (!sessionId) return;
+    setActiveTask(sessionId, {
+      taskId: taskInfo.taskId,
+      operationId: taskInfo.operationId,
+      toolName: taskInfo.toolName,
+      lastMessageId: INITIAL_STREAM_ID,
+    });
+  }
+
+  // Create dispatcher for stream events - stable reference for current sessionId
+  // Each dispatcher gets a unique ID to prevent stale dispatchers from updating state
+  function createDispatcher() {
+    if (!sessionId) return () => {};
+    // Increment dispatcher ID - only the most recent dispatcher should update state
+    const dispatcherId = ++currentDispatcherIdRef.current;
+
+    const baseDispatcher = createStreamEventDispatcher({
+      setHasTextChunks,
+      setStreamingChunks,
+      streamingChunksRef,
+      hasResponseRef,
+      textFinalizedRef,
+      streamEndedRef,
+      setMessages,
+      setIsRegionBlockedModalOpen,
+      sessionId,
+      setIsStreamingInitiated,
+      onOperationStarted,
+      onActiveTaskStarted: handleActiveTaskStarted,
+    });
+
+    // Wrap dispatcher to check if it's still the current one
+    return function guardedDispatcher(chunk: StreamChunk) {
+      // Skip if component unmounted or this is a stale dispatcher
+      if (!isMountedRef.current) {
+        return;
+      }
+      if (dispatcherId !== currentDispatcherIdRef.current) {
+        return;
+      }
+      baseDispatcher(chunk);
+    };
+  }
 
   useEffect(
     function handleSessionChange() {
-      if (sessionId === previousSessionIdRef.current) return;
+      const isSessionChange = sessionId !== previousSessionIdRef.current;
 
-      const prevSession = previousSessionIdRef.current;
-      if (prevSession) {
-        stopStreaming(prevSession);
+      // Handle session change - reset state
+      if (isSessionChange) {
+        const prevSession = previousSessionIdRef.current;
+        if (prevSession) {
+          stopStreaming(prevSession);
+        }
+        previousSessionIdRef.current = sessionId;
+        connectedActiveStreamRef.current = null;
+        setMessages([]);
+        setStreamingChunks([]);
+        streamingChunksRef.current = [];
+        setHasTextChunks(false);
+        setIsStreamingInitiated(false);
+        hasResponseRef.current = false;
+        textFinalizedRef.current = false;
+        streamEndedRef.current = false;
       }
-      previousSessionIdRef.current = sessionId;
-      setMessages([]);
-      setStreamingChunks([]);
-      streamingChunksRef.current = [];
-      setHasTextChunks(false);
-      setIsStreamingInitiated(false);
-      
hasResponseRef.current = false; if (!sessionId) return; - const activeStream = activeStreams.get(sessionId); - if (!activeStream || activeStream.status !== "streaming") return; + // Priority 1: Check if server told us there's an active stream (most authoritative) + if (activeStream) { + const streamKey = `${sessionId}:${activeStream.taskId}`; - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - hasResponseRef, - setMessages, - setIsRegionBlockedModalOpen, - sessionId, - setIsStreamingInitiated, - onOperationStarted, - }); + if (connectedActiveStreamRef.current === streamKey) { + return; + } + + // Skip if there's already an active stream for this session in the store + const existingStream = activeStreams.get(sessionId); + if (existingStream && existingStream.status === "streaming") { + connectedActiveStreamRef.current = streamKey; + return; + } + + connectedActiveStreamRef.current = streamKey; + + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + setActiveTask(sessionId, { + taskId: activeStream.taskId, + operationId: activeStream.operationId, + toolName: activeStream.toolName, + lastMessageId: activeStream.lastMessageId, + }); + reconnectToTask( + sessionId, + activeStream.taskId, + activeStream.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + // and the stream will complete naturally. Cleanup would prematurely stop + // the stream when effect re-runs due to activeStreams changing. + return; + } + + // Only check localStorage/in-memory on session change + if (!isSessionChange) return; + + // Priority 2: Check localStorage for active task + const activeTask = getActiveTask(sessionId); + if (activeTask) { + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + return; + } + + // Priority 3: Check for an in-memory active stream (same-tab scenario) + const inMemoryStream = activeStreams.get(sessionId); + if (!inMemoryStream || inMemoryStream.status !== "streaming") { + return; + } setIsStreamingInitiated(true); const skipReplay = initialMessages.length > 0; - return subscribeToStream(sessionId, dispatcher, skipReplay); + return subscribeToStream(sessionId, createDispatcher(), skipReplay); }, [ sessionId, @@ -110,6 +276,10 @@ export function useChatContainer({ activeStreams, subscribeToStream, onOperationStarted, + getActiveTask, + reconnectToTask, + activeStream, + setActiveTask, ], ); @@ -124,7 +294,7 @@ export function useChatContainer({ msg.type === "agent_carousel" || msg.type === "execution_started" ) { - const toolId = (msg as any).toolId; + const toolId = hasToolId(msg) ? 
msg.toolId : undefined; if (toolId) { ids.add(toolId); } @@ -141,12 +311,8 @@ export function useChatContainer({ setMessages((prev) => { const filtered = prev.filter((msg) => { - if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg); if (toolId && completedToolIds.has(toolId)) { return false; // Remove - operation completed } @@ -174,12 +340,8 @@ export function useChatContainer({ // Filter local messages: remove duplicates and completed operation messages const newLocalMessages = messages.filter((msg) => { // Remove operation messages for completed tools - if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg); if (toolId && completedToolIds.has(toolId)) { return false; } @@ -190,7 +352,70 @@ export function useChatContainer({ }); // Server messages first (correct order), then new local messages - return [...processedInitial, ...newLocalMessages]; + const combined = [...processedInitial, ...newLocalMessages]; + + // Post-processing: Remove duplicate assistant messages that can occur during + // race conditions (e.g., rapid screen switching during SSE reconnection). + // Two assistant messages are considered duplicates if: + // - They are both text messages with role "assistant" + // - One message's content starts with the other's content (partial vs complete) + // - Or they have very similar content (>80% overlap at the start) + const deduplicated: ChatMessageData[] = []; + for (let i = 0; i < combined.length; i++) { + const current = combined[i]; + + // Check if this is an assistant text message + if (current.type !== "message" || current.role !== "assistant") { + deduplicated.push(current); + continue; + } + + // Look for duplicate assistant messages in the rest of the array + let dominated = false; + for (let j = 0; j < combined.length; j++) { + if (i === j) continue; + const other = combined[j]; + if (other.type !== "message" || other.role !== "assistant") continue; + + const currentContent = current.content || ""; + const otherContent = other.content || ""; + + // Skip empty messages + if (!currentContent.trim() || !otherContent.trim()) continue; + + // Check if current is a prefix of other (current is incomplete version) + if ( + otherContent.length > currentContent.length && + otherContent.startsWith(currentContent.slice(0, 100)) + ) { + // Current is a shorter/incomplete version of other - skip it + dominated = true; + break; + } + + // Check if messages are nearly identical (within a small difference) + // This catches cases where content differs only slightly + const minLen = Math.min(currentContent.length, otherContent.length); + const compareLen = Math.min(minLen, 200); // Compare first 200 chars + if ( + compareLen > 50 && + currentContent.slice(0, compareLen) === + otherContent.slice(0, compareLen) + ) { + // Same prefix - keep the longer one + if (otherContent.length > currentContent.length) { + dominated = true; + break; + } + } + } + + if (!dominated) { + deduplicated.push(current); + } + } + + return deduplicated; }, [initialMessages, messages, completedToolIds]); async function sendMessage( @@ -198,10 +423,8 @@ export function useChatContainer({ 
isUserMessage: boolean = true, context?: { url: string; content: string }, ) { - if (!sessionId) { - console.error("[useChatContainer] Cannot send message: no session ID"); - return; - } + if (!sessionId) return; + setIsRegionBlockedModalOpen(false); if (isUserMessage) { const userMessage = createUserMessage(content); @@ -214,31 +437,19 @@ export function useChatContainer({ setHasTextChunks(false); setIsStreamingInitiated(true); hasResponseRef.current = false; - - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - hasResponseRef, - setMessages, - setIsRegionBlockedModalOpen, - sessionId, - setIsStreamingInitiated, - onOperationStarted, - }); + textFinalizedRef.current = false; + streamEndedRef.current = false; try { await sendStreamMessage( sessionId, content, - dispatcher, + createDispatcher(), isUserMessage, context, ); } catch (err) { - console.error("[useChatContainer] Failed to send message:", err); setIsStreamingInitiated(false); - if (err instanceof Error && err.name === "AbortError") return; const errorMessage = diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/useChatMessage.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/useChatMessage.ts index d6526c78ab..6809497a93 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/useChatMessage.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatMessage/useChatMessage.ts @@ -111,6 +111,7 @@ export type ChatMessageData = toolName: string; toolId: string; operationId: string; + taskId?: string; // For SSE reconnection message: string; timestamp?: string | Date; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/MessageList.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/MessageList.tsx index 84f31f9d20..01d107c64e 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/MessageList.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/MessageList/MessageList.tsx @@ -31,11 +31,6 @@ export function MessageList({ isStreaming, }); - /** - * Keeps this for debugging purposes 💆🏽 - */ - console.log(messages); - return (
{/* Top fade shadow */}
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts b/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts
index b0d970c286..8f4c8f9fec 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/stream-executor.ts
@@ -1,3 +1,4 @@
+import { INITIAL_STREAM_ID } from "./chat-constants";
 import type {
   ActiveStream,
   StreamChunk,
@@ -10,8 +11,14 @@ import {
   parseSSELine,
 } from "./stream-utils";
 
-function notifySubscribers(stream: ActiveStream, chunk: StreamChunk) {
-  stream.chunks.push(chunk);
+function notifySubscribers(
+  stream: ActiveStream,
+  chunk: StreamChunk,
+  skipStore = false,
+) {
+  if (!skipStore) {
+    stream.chunks.push(chunk);
+  }
   for (const callback of stream.onChunkCallbacks) {
     try {
       callback(chunk);
@@ -21,36 +28,114 @@
     }
   }
 }
 
-export async function executeStream(
-  stream: ActiveStream,
-  message: string,
-  isUserMessage: boolean,
-  context?: { url: string; content: string },
-  retryCount: number = 0,
+interface StreamExecutionOptions {
+  stream: ActiveStream;
+  mode: "new" | "reconnect";
+  message?: string;
+  isUserMessage?: boolean;
+  context?: { url: string; content: string };
+  taskId?: string;
+  lastMessageId?: string;
+  retryCount?: number;
+}
+
+async function executeStreamInternal(
+  options: StreamExecutionOptions,
 ): Promise<void> {
+  const {
+    stream,
+    mode,
+    message,
+    isUserMessage,
+    context,
+    taskId,
+    lastMessageId = INITIAL_STREAM_ID,
+    retryCount = 0,
+  } = options;
+
   const { sessionId, abortController } = stream;
+  const isReconnect = mode === "reconnect";
+
+  if (isReconnect) {
+    if (!taskId) {
+      throw new Error("taskId is required for reconnect mode");
+    }
+    if (lastMessageId === null || lastMessageId === undefined) {
+      throw new Error("lastMessageId is required for reconnect mode");
+    }
+  } else {
+    if (!message) {
+      throw new Error("message is required for new stream mode");
+    }
+    if (isUserMessage === undefined) {
+      throw new Error("isUserMessage is required for new stream mode");
+    }
+  }
 
   try {
-    const url = `/api/chat/sessions/${sessionId}/stream`;
-    const body = JSON.stringify({
-      message,
-      is_user_message: isUserMessage,
-      context: context || null,
-    });
+    let url: string;
+    let fetchOptions: RequestInit;
 
-    const response = await fetch(url, {
-      method: "POST",
-      headers: {
-        "Content-Type": "application/json",
-        Accept: "text/event-stream",
-      },
-      body,
-      signal: abortController.signal,
-    });
+    if (isReconnect) {
+      url = `/api/chat/tasks/${taskId}/stream?last_message_id=${encodeURIComponent(lastMessageId)}`;
+      fetchOptions = {
+        method: "GET",
+        headers: {
+          Accept: "text/event-stream",
+        },
+        signal: abortController.signal,
+      };
+    } else {
+      url = `/api/chat/sessions/${sessionId}/stream`;
+      fetchOptions = {
+        method: "POST",
+        headers: {
+          "Content-Type": "application/json",
+          Accept: "text/event-stream",
+        },
+        body: JSON.stringify({
+          message,
+          is_user_message: isUserMessage,
+          context: context || null,
+        }),
+        signal: abortController.signal,
+      };
+    }
+
+    const response = await fetch(url, fetchOptions);
 
     if (!response.ok) {
      const errorText = await response.text();
-      throw new Error(errorText || `HTTP ${response.status}`);
+      let errorCode: string | undefined;
+      let errorMessage = errorText || `HTTP ${response.status}`;
+      try {
+        const parsed = JSON.parse(errorText);
+        if (parsed.detail) {
+          const detail =
+            typeof parsed.detail === "string"
+              ? parsed.detail
+              : parsed.detail.message || JSON.stringify(parsed.detail);
+          errorMessage = detail;
+          errorCode =
+            typeof parsed.detail === "object" ? parsed.detail.code : undefined;
+        }
+      } catch {} // non-JSON error body; keep the raw text
+
+      const isPermanentError =
+        isReconnect &&
+        (response.status === 404 ||
+          response.status === 403 ||
+          response.status === 410);
+
+      const error = new Error(errorMessage) as Error & {
+        status?: number;
+        isPermanent?: boolean;
+        taskErrorCode?: string;
+      };
+      error.status = response.status;
+      error.isPermanent = isPermanentError;
+      error.taskErrorCode = errorCode;
+      throw error;
     }
 
     if (!response.body) {
@@ -104,9 +189,7 @@
             );
             return;
           }
-        } catch (err) {
-          console.warn("[StreamExecutor] Failed to parse SSE chunk:", err);
-        }
+        } catch {} // skip malformed SSE lines
       }
     }
   }
@@ -117,19 +200,17 @@
       return;
     }
 
-    if (retryCount < MAX_RETRIES) {
+    const isPermanentError =
+      err instanceof Error &&
+      (err as Error & { isPermanent?: boolean }).isPermanent;
+
+    if (!isPermanentError && retryCount < MAX_RETRIES) {
       const retryDelay = INITIAL_RETRY_DELAY * Math.pow(2, retryCount);
-      console.log(
-        `[StreamExecutor] Retrying in ${retryDelay}ms (attempt ${retryCount + 1}/${MAX_RETRIES})`,
-      );
       await new Promise((resolve) => setTimeout(resolve, retryDelay));
-      return executeStream(
-        stream,
-        message,
-        isUserMessage,
-        context,
-        retryCount + 1,
-      );
+      return executeStreamInternal({
+        ...options,
+        retryCount: retryCount + 1,
+      });
     }
 
     stream.status = "error";
@@ -140,3 +221,35 @@
     });
   }
 }
+
+export async function executeStream(
+  stream: ActiveStream,
+  message: string,
+  isUserMessage: boolean,
+  context?: { url: string; content: string },
+  retryCount: number = 0,
+): Promise<void> {
+  return executeStreamInternal({
+    stream,
+    mode: "new",
+    message,
+    isUserMessage,
+    context,
+    retryCount,
+  });
+}
+
+export async function executeTaskReconnect(
+  stream: ActiveStream,
+  taskId: string,
+  lastMessageId: string = INITIAL_STREAM_ID,
+  retryCount: number = 0,
+): Promise<void> {
+  return executeStreamInternal({
+    stream,
+    mode: "reconnect",
+    taskId,
+    lastMessageId,
+    retryCount,
+  });
+}
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts b/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts
index 4100926e79..253e47b874 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/stream-utils.ts
@@ -28,6 +28,7 @@ export function normalizeStreamChunk(
 
   switch (chunk.type) {
     case "text-delta":
+      // Vercel AI SDK sends "delta" for text content
       return { type: "text_chunk", content: chunk.delta };
     case "text-end":
       return { type: "text_ended" };
@@ -63,6 +64,10 @@
     case "finish":
       return { type: "stream_end" };
     case "start":
+      // Start event with optional taskId for reconnection
+      return chunk.taskId
+        ? { type: "stream_start", taskId: chunk.taskId }
+        : null;
     case "text-start":
       return null;
     case "tool-input-start":