diff --git a/.gitignore b/.gitignore index 1a2291b516..012a0b5227 100644 --- a/.gitignore +++ b/.gitignore @@ -180,3 +180,4 @@ autogpt_platform/backend/settings.py .claude/settings.local.json CLAUDE.local.md /autogpt_platform/backend/logs +.next \ No newline at end of file diff --git a/README.md b/README.md index 3572fe318b..349d8818ef 100644 --- a/README.md +++ b/README.md @@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following ### Updated Setup Instructions: We've moved to a fully maintained and regularly updated documentation site. -👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/) +👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started) This tutorial assumes you have Docker, VSCode, git and npm installed. diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py new file mode 100644 index 0000000000..f447d46bd7 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py @@ -0,0 +1,368 @@ +"""Redis Streams consumer for operation completion messages. + +This module provides a consumer (ChatCompletionConsumer) that listens for +completion notifications (OperationCompleteMessage) from external services +(like Agent Generator) and triggers the appropriate stream registry and +chat service updates via process_operation_success/process_operation_failure. + +Why Redis Streams instead of RabbitMQ? +-------------------------------------- +While the project typically uses RabbitMQ for async task queues (e.g., execution +queue), Redis Streams was chosen for chat completion notifications because: + +1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis + Streams (via stream_registry) for message persistence and replay. 
Using Redis + Streams for completion notifications keeps all chat streaming infrastructure + in one system, simplifying operations and reducing cross-system coordination. + +2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs, + allowing consumers to replay missed messages after reconnection. This aligns + with the SSE reconnection pattern where clients can resume from last_message_id. + +3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic + load balancing across pods with explicit message claiming (XAUTOCLAIM) for + recovering from dead consumers - ideal for the completion callback pattern. + +4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for + stream_registry) provides lower latency than an additional RabbitMQ hop. + +5. **Atomicity with Task State**: Completion processing often needs to update + task metadata stored in Redis. Keeping both in Redis enables simpler + transactional semantics without distributed coordination. + +The consumer uses Redis Streams with consumer groups for reliable message +processing across multiple platform pods, with XAUTOCLAIM for reclaiming +stale pending messages from dead consumers. +""" + +import asyncio +import logging +import os +import uuid +from typing import Any + +import orjson +from prisma import Prisma +from pydantic import BaseModel +from redis.exceptions import ResponseError + +from backend.data.redis_client import get_redis_async + +from . 
import stream_registry +from .completion_handler import process_operation_failure, process_operation_success +from .config import ChatConfig + +logger = logging.getLogger(__name__) +config = ChatConfig() + + +class OperationCompleteMessage(BaseModel): + """Message format for operation completion notifications.""" + + operation_id: str + task_id: str + success: bool + result: dict | str | None = None + error: str | None = None + + +class ChatCompletionConsumer: + """Consumer for chat operation completion messages from Redis Streams. + + This consumer initializes its own Prisma client in start() to ensure + database operations work correctly within this async context. + + Uses Redis consumer groups to allow multiple platform pods to consume + messages reliably with automatic redelivery on failure. + """ + + def __init__(self): + self._consumer_task: asyncio.Task | None = None + self._running = False + self._prisma: Prisma | None = None + self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" + + async def start(self) -> None: + """Start the completion consumer.""" + if self._running: + logger.warning("Completion consumer already running") + return + + # Create consumer group if it doesn't exist + try: + redis = await get_redis_async() + await redis.xgroup_create( + config.stream_completion_name, + config.stream_consumer_group, + id="0", + mkstream=True, + ) + logger.info( + f"Created consumer group '{config.stream_consumer_group}' " + f"on stream '{config.stream_completion_name}'" + ) + except ResponseError as e: + if "BUSYGROUP" in str(e): + logger.debug( + f"Consumer group '{config.stream_consumer_group}' already exists" + ) + else: + raise + + self._running = True + self._consumer_task = asyncio.create_task(self._consume_messages()) + logger.info( + f"Chat completion consumer started (consumer: {self._consumer_name})" + ) + + async def _ensure_prisma(self) -> Prisma: + """Lazily initialize Prisma client on first use.""" + if self._prisma is None: + database_url 
= os.getenv("DATABASE_URL", "postgresql://localhost:5432") + self._prisma = Prisma(datasource={"url": database_url}) + await self._prisma.connect() + logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") + return self._prisma + + async def stop(self) -> None: + """Stop the completion consumer.""" + self._running = False + + if self._consumer_task: + self._consumer_task.cancel() + try: + await self._consumer_task + except asyncio.CancelledError: + pass + self._consumer_task = None + + if self._prisma: + await self._prisma.disconnect() + self._prisma = None + logger.info("[COMPLETION] Consumer Prisma client disconnected") + + logger.info("Chat completion consumer stopped") + + async def _consume_messages(self) -> None: + """Main message consumption loop with retry logic.""" + max_retries = 10 + retry_delay = 5 # seconds + retry_count = 0 + block_timeout = 5000 # milliseconds + + while self._running and retry_count < max_retries: + try: + redis = await get_redis_async() + + # Reset retry count on successful connection + retry_count = 0 + + while self._running: + # First, claim any stale pending messages from dead consumers + # Redis does NOT auto-redeliver pending messages; we must explicitly + # claim them using XAUTOCLAIM + try: + claimed_result = await redis.xautoclaim( + name=config.stream_completion_name, + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + min_idle_time=config.stream_claim_min_idle_ms, + start_id="0-0", + count=10, + ) + # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids]) + if claimed_result and len(claimed_result) >= 2: + claimed_entries = claimed_result[1] + if claimed_entries: + logger.info( + f"Claimed {len(claimed_entries)} stale pending messages" + ) + for entry_id, data in claimed_entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + except Exception as e: + logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}") + + # Read new 
messages from the stream + messages = await redis.xreadgroup( + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + streams={config.stream_completion_name: ">"}, + block=block_timeout, + count=10, + ) + + if not messages: + continue + + for stream_name, entries in messages: + for entry_id, data in entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + + except asyncio.CancelledError: + logger.info("Consumer cancelled") + return + except Exception as e: + retry_count += 1 + logger.error( + f"Consumer error (retry {retry_count}/{max_retries}): {e}", + exc_info=True, + ) + if self._running and retry_count < max_retries: + await asyncio.sleep(retry_delay) + else: + logger.error("Max retries reached, stopping consumer") + return + + async def _process_entry( + self, redis: Any, entry_id: str, data: dict[str, Any] + ) -> None: + """Process a single stream entry and acknowledge it on success. + + Args: + redis: Redis client connection + entry_id: The stream entry ID + data: The entry data dict + """ + try: + # Handle the message + message_data = data.get("data") + if message_data: + await self._handle_message( + message_data.encode() + if isinstance(message_data, str) + else message_data + ) + + # Acknowledge the message after successful processing + await redis.xack( + config.stream_completion_name, + config.stream_consumer_group, + entry_id, + ) + except Exception as e: + logger.error( + f"Error processing completion message {entry_id}: {e}", + exc_info=True, + ) + # Message remains in pending state and will be claimed by + # XAUTOCLAIM after min_idle_time expires + + async def _handle_message(self, body: bytes) -> None: + """Handle a completion message using our own Prisma client.""" + try: + data = orjson.loads(body) + message = OperationCompleteMessage(**data) + except Exception as e: + logger.error(f"Failed to parse completion message: {e}") + return + + logger.info( + f"[COMPLETION] Received 
completion for operation {message.operation_id} " + f"(task_id={message.task_id}, success={message.success})" + ) + + # Find task in registry + task = await stream_registry.find_task_by_operation_id(message.operation_id) + if task is None: + task = await stream_registry.get_task(message.task_id) + + if task is None: + logger.warning( + f"[COMPLETION] Task not found for operation {message.operation_id} " + f"(task_id={message.task_id})" + ) + return + + logger.info( + f"[COMPLETION] Found task: task_id={task.task_id}, " + f"session_id={task.session_id}, tool_call_id={task.tool_call_id}" + ) + + # Guard against empty task fields + if not task.task_id or not task.session_id or not task.tool_call_id: + logger.error( + f"[COMPLETION] Task has empty critical fields! " + f"task_id={task.task_id!r}, session_id={task.session_id!r}, " + f"tool_call_id={task.tool_call_id!r}" + ) + return + + if message.success: + await self._handle_success(task, message) + else: + await self._handle_failure(task, message) + + async def _handle_success( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle successful operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_success(task, message.result, prisma) + + async def _handle_failure( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle failed operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_failure(task, message.error, prisma) + + +# Module-level consumer instance +_consumer: ChatCompletionConsumer | None = None + + +async def start_completion_consumer() -> None: + """Start the global completion consumer.""" + global _consumer + if _consumer is None: + _consumer = ChatCompletionConsumer() + await _consumer.start() + + +async def stop_completion_consumer() -> None: + """Stop the global completion consumer.""" + global _consumer + if _consumer: + await 
_consumer.stop() + _consumer = None + + +async def publish_operation_complete( + operation_id: str, + task_id: str, + success: bool, + result: dict | str | None = None, + error: str | None = None, +) -> None: + """Publish an operation completion message to Redis Streams. + + Args: + operation_id: The operation ID that completed. + task_id: The task ID associated with the operation. + success: Whether the operation succeeded. + result: The result data (for success). + error: The error message (for failure). + """ + message = OperationCompleteMessage( + operation_id=operation_id, + task_id=task_id, + success=success, + result=result, + error=error, + ) + + redis = await get_redis_async() + await redis.xadd( + config.stream_completion_name, + {"data": message.model_dump_json()}, + maxlen=config.stream_max_length, + ) + logger.info(f"Published completion for operation {operation_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_handler.py b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py new file mode 100644 index 0000000000..905fa2ddba --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py @@ -0,0 +1,344 @@ +"""Shared completion handling for operation success and failure. + +This module provides common logic for handling operation completion from both: +- The Redis Streams consumer (completion_consumer.py) +- The HTTP webhook endpoint (routes.py) +""" + +import logging +from typing import Any + +import orjson +from prisma import Prisma + +from . import service as chat_service +from . 
import stream_registry +from .response_model import StreamError, StreamToolOutputAvailable +from .tools.models import ErrorResponse + +logger = logging.getLogger(__name__) + +# Tools that produce agent_json that needs to be saved to library +AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"} + +# Keys that should be stripped from agent_json when returning in error responses +SENSITIVE_KEYS = frozenset( + { + "api_key", + "apikey", + "api_secret", + "password", + "secret", + "credentials", + "credential", + "token", + "access_token", + "refresh_token", + "private_key", + "privatekey", + "auth", + "authorization", + } +) + + +def _sanitize_agent_json(obj: Any) -> Any: + """Recursively sanitize agent_json by removing sensitive keys. + + Args: + obj: The object to sanitize (dict, list, or primitive) + + Returns: + Sanitized copy with sensitive keys removed/redacted + """ + if isinstance(obj, dict): + return { + k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v) + for k, v in obj.items() + } + elif isinstance(obj, list): + return [_sanitize_agent_json(item) for item in obj] + else: + return obj + + +class ToolMessageUpdateError(Exception): + """Raised when updating a tool message in the database fails.""" + + pass + + +async def _update_tool_message( + session_id: str, + tool_call_id: str, + content: str, + prisma_client: Prisma | None, +) -> None: + """Update tool message in database. + + Args: + session_id: The session ID + tool_call_id: The tool call ID to update + content: The new content for the message + prisma_client: Optional Prisma client. If None, uses chat_service. + + Raises: + ToolMessageUpdateError: If the database update fails. The caller should + handle this to avoid marking the task as completed with inconsistent state. 
+ """ + try: + if prisma_client: + # Use provided Prisma client (for consumer with its own connection) + updated_count = await prisma_client.chatmessage.update_many( + where={ + "sessionId": session_id, + "toolCallId": tool_call_id, + }, + data={"content": content}, + ) + # Check if any rows were updated - 0 means message not found + if updated_count == 0: + raise ToolMessageUpdateError( + f"No message found with tool_call_id={tool_call_id} in session {session_id}" + ) + else: + # Use service function (for webhook endpoint) + await chat_service._update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=content, + ) + except ToolMessageUpdateError: + raise + except Exception as e: + logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True) + raise ToolMessageUpdateError( + f"Failed to update tool message for tool_call_id={tool_call_id}: {e}" + ) from e + + +def serialize_result(result: dict | list | str | int | float | bool | None) -> str: + """Serialize result to JSON string with sensible defaults. + + Args: + result: The result to serialize. Can be a dict, list, string, + number, boolean, or None. + + Returns: + JSON string representation of the result. Returns '{"status": "completed"}' + only when result is explicitly None. + """ + if isinstance(result, str): + return result + if result is None: + return '{"status": "completed"}' + return orjson.dumps(result).decode("utf-8") + + +async def _save_agent_from_result( + result: dict[str, Any], + user_id: str | None, + tool_name: str, +) -> dict[str, Any]: + """Save agent to library if result contains agent_json. 
+ + Args: + result: The result dict that may contain agent_json + user_id: The user ID to save the agent for + tool_name: The tool name (create_agent or edit_agent) + + Returns: + Updated result dict with saved agent details, or original result if no agent_json + """ + if not user_id: + logger.warning("[COMPLETION] Cannot save agent: no user_id in task") + return result + + agent_json = result.get("agent_json") + if not agent_json: + logger.warning( + f"[COMPLETION] {tool_name} completed but no agent_json in result" + ) + return result + + try: + from .tools.agent_generator import save_agent_to_library + + is_update = tool_name == "edit_agent" + created_graph, library_agent = await save_agent_to_library( + agent_json, user_id, is_update=is_update + ) + + logger.info( + f"[COMPLETION] Saved agent '{created_graph.name}' to library " + f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})" + ) + + # Return a response similar to AgentSavedResponse + return { + "type": "agent_saved", + "message": f"Agent '{created_graph.name}' has been saved to your library!", + "agent_id": created_graph.id, + "agent_name": created_graph.name, + "library_agent_id": library_agent.id, + "library_agent_link": f"/library/agents/{library_agent.id}", + "agent_page_link": f"/build?flowID={created_graph.id}", + } + except Exception as e: + logger.error( + f"[COMPLETION] Failed to save agent to library: {e}", + exc_info=True, + ) + # Return error but don't fail the whole operation + # Sanitize agent_json to remove sensitive keys before returning + return { + "type": "error", + "message": f"Agent was generated but failed to save: {str(e)}", + "error": str(e), + "agent_json": _sanitize_agent_json(agent_json), + } + + +async def process_operation_success( + task: stream_registry.ActiveTask, + result: dict | str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle successful operation completion. 
+ + Publishes the result to the stream registry, updates the database, + generates LLM continuation, and marks the task as completed. + + Args: + task: The active task that completed + result: The result data from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. + + Raises: + ToolMessageUpdateError: If the database update fails. The task will be + marked as failed instead of completed to avoid inconsistent state. + """ + # For agent generation tools, save the agent to library + if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): + result = await _save_agent_from_result(result, task.user_id, task.tool_name) + + # Serialize result for output (only substitute default when result is exactly None) + result_output = result if result is not None else {"status": "completed"} + output_str = ( + result_output + if isinstance(result_output, str) + else orjson.dumps(result_output).decode("utf-8") + ) + + # Publish result to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamToolOutputAvailable( + toolCallId=task.tool_call_id, + toolName=task.tool_name, + output=output_str, + success=True, + ), + ) + + # Update pending operation in database + # If this fails, we must not continue to mark the task as completed + result_str = serialize_result(result) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=result_str, + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - mark task as failed to avoid inconsistent state + logger.error( + f"[COMPLETION] DB update failed for task {task.task_id}, " + "marking as failed instead of completed" + ) + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText="Failed to save operation result to database"), + ) + await stream_registry.mark_task_completed(task.task_id, status="failed") + raise 
+ + # Generate LLM continuation with streaming + try: + await chat_service._generate_llm_continuation_with_streaming( + session_id=task.session_id, + user_id=task.user_id, + task_id=task.task_id, + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to generate LLM continuation: {e}", + exc_info=True, + ) + + # Mark task as completed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="completed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info( + f"[COMPLETION] Successfully processed completion for task {task.task_id}" + ) + + +async def process_operation_failure( + task: stream_registry.ActiveTask, + error: str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle failed operation completion. + + Publishes the error to the stream registry, updates the database with + the error response, and marks the task as failed. + + Args: + task: The active task that failed + error: The error message from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. 
+ """ + error_msg = error or "Operation failed" + + # Publish error to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText=error_msg), + ) + + # Update pending operation with error + # If this fails, we still continue to mark the task as failed + error_response = ErrorResponse( + message=error_msg, + error=error, + ) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=error_response.model_dump_json(), + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - log but continue with cleanup + logger.error( + f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, " + "continuing with cleanup" + ) + + # Mark task as failed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="failed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}") diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index dba7934877..2e8dbf5413 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -44,6 +44,48 @@ class ChatConfig(BaseSettings): description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)", ) + # Stream registry configuration for SSE reconnection + stream_ttl: int = Field( + default=3600, + description="TTL in seconds for stream data in Redis (1 hour)", + ) + stream_max_length: int = Field( + default=10000, + description="Maximum number of messages to store per stream", + ) + + # Redis Streams configuration for completion consumer + stream_completion_name: str = Field( + 
default="chat:completions", + description="Redis Stream name for operation completions", + ) + stream_consumer_group: str = Field( + default="chat_consumers", + description="Consumer group name for completion stream", + ) + stream_claim_min_idle_ms: int = Field( + default=60000, + description="Minimum idle time in milliseconds before claiming pending messages from dead consumers", + ) + + # Redis key prefixes for stream registry + task_meta_prefix: str = Field( + default="chat:task:meta:", + description="Prefix for task metadata hash keys", + ) + task_stream_prefix: str = Field( + default="chat:stream:", + description="Prefix for task message stream keys", + ) + task_op_prefix: str = Field( + default="chat:task:op:", + description="Prefix for operation ID to task ID mapping keys", + ) + internal_api_key: str | None = Field( + default=None, + description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)", + ) + # Langfuse Prompt Management Configuration # Note: Langfuse credentials are in Settings().secrets (settings.py) langfuse_prompt_name: str = Field( @@ -82,6 +124,14 @@ class ChatConfig(BaseSettings): v = "https://openrouter.ai/api/v1" return v + @field_validator("internal_api_key", mode="before") + @classmethod + def get_internal_api_key(cls, v): + """Get internal API key from environment if not provided.""" + if v is None: + v = os.getenv("CHAT_INTERNAL_API_KEY") + return v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/api/features/chat/response_model.py index 53a8cf3a1f..f627a42fcc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/response_model.py +++ b/autogpt_platform/backend/backend/api/features/chat/response_model.py @@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse): type: ResponseType = ResponseType.START messageId: str = Field(..., 
description="Unique message ID") + taskId: str | None = Field( + default=None, + description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", + ) class StreamFinish(StreamBaseResponse): diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index cab51543b1..3e731d86ac 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,19 +1,23 @@ """Chat API routes for chat session management and streaming via SSE.""" import logging +import uuid as uuid_module from collections.abc import AsyncGenerator from typing import Annotated from autogpt_libs import auth -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security from fastapi.responses import StreamingResponse from pydantic import BaseModel from backend.util.exceptions import NotFoundError from . import service as chat_service +from . 
import stream_registry +from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions +from .response_model import StreamFinish, StreamHeartbeat, StreamStart config = ChatConfig() @@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel): user_id: str | None +class ActiveStreamInfo(BaseModel): + """Information about an active stream for reconnection.""" + + task_id: str + last_message_id: str # Redis Stream message ID for resumption + operation_id: str # Operation ID for completion tracking + tool_name: str # Name of the tool being executed + + class SessionDetailResponse(BaseModel): """Response model providing complete details for a chat session, including messages.""" @@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel): updated_at: str user_id: str | None messages: list[dict] + active_stream: ActiveStreamInfo | None = None # Present if stream is still active class SessionSummaryResponse(BaseModel): @@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel): total: int +class OperationCompleteRequest(BaseModel): + """Request model for external completion webhook.""" + + success: bool + result: dict | str | None = None + error: str | None = None + + # ========== Routes ========== @@ -166,13 +188,14 @@ async def get_session( Retrieve the details of a specific chat session. Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages. + If there's an active stream for this session, returns the task_id for reconnection. Args: session_id: The unique identifier for the desired chat session. user_id: The optional authenticated user ID, or None for anonymous access. Returns: - SessionDetailResponse: Details for the requested session, or None if not found. + SessionDetailResponse: Details for the requested session, including active_stream info if applicable. 
""" session = await get_chat_session(session_id, user_id) @@ -180,11 +203,28 @@ async def get_session( raise NotFoundError(f"Session {session_id} not found.") messages = [message.model_dump() for message in session.messages] - logger.info( - f"Returning session {session_id}: " - f"message_count={len(messages)}, " - f"roles={[m.get('role') for m in messages]}" + + # Check if there's an active stream for this session + active_stream_info = None + active_task, last_message_id = await stream_registry.get_active_task_for_session( + session_id, user_id ) + if active_task: + # Filter out the in-progress assistant message from the session response. + # The client will receive the complete assistant response through the SSE + # stream replay instead, preventing duplicate content. + if messages and messages[-1].get("role") == "assistant": + messages = messages[:-1] + + # Use "0-0" as last_message_id to replay the stream from the beginning. + # Since we filtered out the cached assistant message, the client needs + # the full stream to reconstruct the response. + active_stream_info = ActiveStreamInfo( + task_id=active_task.task_id, + last_message_id="0-0", + operation_id=active_task.operation_id, + tool_name=active_task.tool_name, + ) return SessionDetailResponse( id=session.session_id, @@ -192,6 +232,7 @@ async def get_session( updated_at=session.updated_at.isoformat(), user_id=session.user_id or None, messages=messages, + active_stream=active_stream_info, ) @@ -211,49 +252,112 @@ async def stream_chat_post( - Tool call UI elements (if invoked) - Tool execution results + The AI generation runs in a background task that continues even if the client disconnects. + All chunks are written to Redis for reconnection support. If the client disconnects, + they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off. + Args: session_id: The chat session identifier to associate with the streamed messages. 
request: Request body containing message, is_user_message, and optional context. user_id: Optional authenticated user ID. Returns: - StreamingResponse: SSE-formatted response chunks. + StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event + containing the task_id for reconnection. """ + import asyncio + session = await _validate_and_get_session(session_id, user_id) + # Create a task in the stream registry for reconnection support + task_id = str(uuid_module.uuid4()) + operation_id = str(uuid_module.uuid4()) + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", # Not a tool call, but needed for the model + tool_name="chat", + operation_id=operation_id, + ) + + # Background task that runs the AI generation independently of SSE connection + async def run_ai_generation(): + try: + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + + async for chunk in chat_service.stream_chat_completion( + session_id, + request.message, + is_user_message=request.is_user_message, + user_id=user_id, + session=session, # Pass pre-fetched session to avoid double-fetch + context=request.context, + ): + # Write to Redis (subscribers will receive via XREAD) + await stream_registry.publish_chunk(task_id, chunk) + + # Mark task as completed + await stream_registry.mark_task_completed(task_id, "completed") + except Exception as e: + logger.error( + f"Error in background AI generation for session {session_id}: {e}" + ) + await stream_registry.mark_task_completed(task_id, "failed") + + # Start the AI generation in a background task + bg_task = asyncio.create_task(run_ai_generation()) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 
0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" + subscriber_queue = None + try: + # Subscribe to the task stream (this replays existing messages + live updates) + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id="0-0", # Get all messages from the beginning + ) + + if subscriber_queue is None: + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + return + + # Read from the subscriber queue and yield to SSE + while True: + try: + chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + + except GeneratorExit: + pass # Client disconnected - background task continues + except Exception as e: + logger.error(f"Error in SSE stream for task {task_id}: {e}") + finally: + # Unsubscribe when client disconnects or stream ends to prevent resource leak + if subscriber_queue is not None: + try: + await stream_registry.unsubscribe_from_task( + task_id, subscriber_queue + ) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + 
exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" return StreamingResponse( event_generator(), @@ -366,6 +470,251 @@ async def session_assign_user( return {"status": "ok"} +# ========== Task Streaming (SSE Reconnection) ========== + + +@router.get( + "/tasks/{task_id}/stream", +) +async def stream_task( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), + last_message_id: str = Query( + default="0-0", + description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + ), +): + """ + Reconnect to a long-running task's SSE stream. + + When a long-running operation (like agent generation) starts, the client + receives a task_id. If the connection drops, the client can reconnect + using this endpoint to resume receiving updates. + + Args: + task_id: The task ID from the operation_started response. + user_id: Authenticated user ID for ownership validation. + last_message_id: Last Redis Stream message ID received ("0-0" for full replay). + + Returns: + StreamingResponse: SSE-formatted response chunks starting after last_message_id. + + Raises: + HTTPException: 404 if task not found, 410 if task expired, 403 if access denied. + """ + # Check task existence and expiry before subscribing + task, error_code = await stream_registry.get_task_with_expiry_info(task_id) + + if error_code == "TASK_EXPIRED": + raise HTTPException( + status_code=410, + detail={ + "code": "TASK_EXPIRED", + "message": "This operation has expired. 
Please try again.", + }, + ) + + if error_code == "TASK_NOT_FOUND": + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found.", + }, + ) + + # Validate ownership if task has an owner + if task and task.user_id and user_id != task.user_id: + raise HTTPException( + status_code=403, + detail={ + "code": "ACCESS_DENIED", + "message": "You do not have access to this task.", + }, + ) + + # Get subscriber queue from stream registry + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id=last_message_id, + ) + + if subscriber_queue is None: + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found or access denied.", + }, + ) + + async def event_generator() -> AsyncGenerator[str, None]: + import asyncio + + heartbeat_interval = 15.0 # Send heartbeat every 15 seconds + try: + while True: + try: + # Wait for next chunk with timeout for heartbeats + chunk = await asyncio.wait_for( + subscriber_queue.get(), timeout=heartbeat_interval + ) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + except Exception as e: + logger.error(f"Error in task stream {task_id}: {e}", exc_info=True) + finally: + # Unsubscribe when client disconnects or stream ends + try: + await stream_registry.unsubscribe_from_task(task_id, subscriber_queue) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + 
"x-vercel-ai-ui-message-stream": "v1", + }, + ) + + +@router.get( + "/tasks/{task_id}", +) +async def get_task_status( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), +) -> dict: + """ + Get the status of a long-running task. + + Args: + task_id: The task ID to check. + user_id: Authenticated user ID for ownership validation. + + Returns: + dict: Task status including task_id, status, tool_name, and operation_id. + + Raises: + NotFoundError: If task_id is not found or user doesn't have access. + """ + task = await stream_registry.get_task(task_id) + + if task is None: + raise NotFoundError(f"Task {task_id} not found.") + + # Validate ownership - if task has an owner, requester must match + if task.user_id and user_id != task.user_id: + raise NotFoundError(f"Task {task_id} not found.") + + return { + "task_id": task.task_id, + "session_id": task.session_id, + "status": task.status, + "tool_name": task.tool_name, + "operation_id": task.operation_id, + "created_at": task.created_at.isoformat(), + } + + +# ========== External Completion Webhook ========== + + +@router.post( + "/operations/{operation_id}/complete", + status_code=200, +) +async def complete_operation( + operation_id: str, + request: OperationCompleteRequest, + x_api_key: str | None = Header(default=None), +) -> dict: + """ + External completion webhook for long-running operations. + + Called by Agent Generator (or other services) when an operation completes. + This triggers the stream registry to publish completion and continue LLM generation. + + Args: + operation_id: The operation ID to complete. + request: Completion payload with success status and result/error. + x_api_key: Internal API key for authentication. + + Returns: + dict: Status of the completion. + + Raises: + HTTPException: If API key is invalid or operation not found. 
+ """ + # Validate internal API key - reject if not configured or invalid + if not config.internal_api_key: + logger.error( + "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured" + ) + raise HTTPException( + status_code=503, + detail="Webhook not available: internal API key not configured", + ) + if x_api_key != config.internal_api_key: + raise HTTPException(status_code=401, detail="Invalid API key") + + # Find task by operation_id + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None: + raise HTTPException( + status_code=404, + detail=f"Operation {operation_id} not found", + ) + + logger.info( + f"Received completion webhook for operation {operation_id} " + f"(task_id={task.task_id}, success={request.success})" + ) + + if request.success: + await process_operation_success(task, request.result) + else: + await process_operation_failure(task, request.error) + + return {"status": "ok", "task_id": task.task_id} + + +# ========== Configuration ========== + + +@router.get("/config/ttl", status_code=200) +async def get_ttl_config() -> dict: + """ + Get the stream TTL configuration. + + Returns the Time-To-Live settings for chat streams, which determines + how long clients can reconnect to an active stream. + + Returns: + dict: TTL configuration with seconds and milliseconds values. 
+ """ + return { + "stream_ttl_seconds": config.stream_ttl, + "stream_ttl_ms": config.stream_ttl * 1000, + } + + # ========== Health Check ========== diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 20216162b5..218575085b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -3,9 +3,13 @@ import logging import time from asyncio import CancelledError from collections.abc import AsyncGenerator -from typing import Any +from typing import TYPE_CHECKING, Any, cast import openai + +if TYPE_CHECKING: + from backend.util.prompt import CompressResult + import orjson from langfuse import get_client from openai import ( @@ -15,7 +19,13 @@ from openai import ( PermissionDeniedError, RateLimitError, ) -from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam +from openai.types.chat import ( + ChatCompletionChunk, + ChatCompletionMessageParam, + ChatCompletionStreamOptionsParam, + ChatCompletionSystemMessageParam, + ChatCompletionToolParam, +) from backend.data.redis_client import get_redis_async from backend.data.understanding import ( @@ -26,6 +36,7 @@ from backend.util.exceptions import NotFoundError from backend.util.settings import Settings from . import db as chat_db +from . import stream_registry from .config import ChatConfig from .model import ( ChatMessage, @@ -794,207 +805,58 @@ def _is_region_blocked_error(error: Exception) -> bool: return "not available in your region" in str(error).lower() -async def _summarize_messages( +async def _manage_context_window( messages: list, model: str, api_key: str | None = None, base_url: str | None = None, - timeout: float = 30.0, -) -> str: - """Summarize a list of messages into concise context. +) -> "CompressResult": + """ + Manage context window using the unified compress_context function. 
- Uses the same model as the chat for higher quality summaries. + This is a thin wrapper that creates an OpenAI client for summarization + and delegates to the shared compression logic in prompt.py. Args: - messages: List of message dicts to summarize - model: Model to use for summarization (same as chat model) - api_key: API key for OpenAI client - base_url: Base URL for OpenAI client - timeout: Request timeout in seconds (default: 30.0) + messages: List of messages in OpenAI format + model: Model name for token counting and summarization + api_key: API key for summarization calls + base_url: Base URL for summarization calls Returns: - Summarized text + CompressResult with compacted messages and metadata """ - # Format messages for summarization - conversation = [] - for msg in messages: - role = msg.get("role", "") - content = msg.get("content", "") - # Include user, assistant, and tool messages (tool outputs are important context) - if content and role in ("user", "assistant", "tool"): - conversation.append(f"{role.upper()}: {content}") - - conversation_text = "\n\n".join(conversation) - - # Handle empty conversation - if not conversation_text: - return "No conversation history available." - - # Truncate conversation to fit within summarization model's context - # gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety - MAX_CHARS = 100_000 - if len(conversation_text) > MAX_CHARS: - conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]" - - # Call LLM to summarize import openai - summarization_client = openai.AsyncOpenAI( - api_key=api_key, base_url=base_url, timeout=timeout - ) + from backend.util.prompt import compress_context - response = await summarization_client.chat.completions.create( - model=model, - messages=[ - { - "role": "system", - "content": ( - "Create a detailed summary of the conversation so far. 
" - "This summary will be used as context when continuing the conversation.\n\n" - "Before writing the summary, analyze each message chronologically to identify:\n" - "- User requests and their explicit goals\n" - "- Your approach and key decisions made\n" - "- Technical specifics (file names, tool outputs, function signatures)\n" - "- Errors encountered and resolutions applied\n\n" - "You MUST include ALL of the following sections:\n\n" - "## 1. Primary Request and Intent\n" - "The user's explicit goals and what they are trying to accomplish.\n\n" - "## 2. Key Technical Concepts\n" - "Technologies, frameworks, tools, and patterns being used or discussed.\n\n" - "## 3. Files and Resources Involved\n" - "Specific files examined or modified, with relevant snippets and identifiers.\n\n" - "## 4. Errors and Fixes\n" - "Problems encountered, error messages, and their resolutions. " - "Include any user feedback on fixes.\n\n" - "## 5. Problem Solving\n" - "Issues that have been resolved and how they were addressed.\n\n" - "## 6. All User Messages\n" - "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n" - "## 7. Pending Tasks\n" - "Work items the user explicitly requested that have not yet been completed.\n\n" - "## 8. Current Work\n" - "Precise description of what was being worked on most recently, including relevant context.\n\n" - "## 9. Next Steps\n" - "What should happen next, aligned with the user's most recent requests. " - "Include verbatim quotes of recent instructions if relevant." 
- ), - }, - {"role": "user", "content": f"Summarize:\n\n{conversation_text}"}, - ], - max_tokens=1500, - temperature=0.3, - ) + # Convert messages to dict format + messages_dict = [] + for msg in messages: + if isinstance(msg, dict): + msg_dict = {k: v for k, v in msg.items() if v is not None} + else: + msg_dict = dict(msg) + messages_dict.append(msg_dict) - summary = response.choices[0].message.content - return summary or "No summary available." - - -def _ensure_tool_pairs_intact( - recent_messages: list[dict], - all_messages: list[dict], - start_index: int, -) -> list[dict]: - """ - Ensure tool_call/tool_response pairs stay together after slicing. - - When slicing messages for context compaction, a naive slice can separate - an assistant message containing tool_calls from its corresponding tool - response messages. This causes API validation errors (e.g., Anthropic's - "unexpected tool_use_id found in tool_result blocks"). - - This function checks for orphan tool responses in the slice and extends - backwards to include their corresponding assistant messages. 
- - Args: - recent_messages: The sliced messages to validate - all_messages: The complete message list (for looking up missing assistants) - start_index: The index in all_messages where recent_messages begins - - Returns: - A potentially extended list of messages with tool pairs intact - """ - if not recent_messages: - return recent_messages - - # Collect all tool_call_ids from assistant messages in the slice - available_tool_call_ids: set[str] = set() - for msg in recent_messages: - if msg.get("role") == "assistant" and msg.get("tool_calls"): - for tc in msg["tool_calls"]: - tc_id = tc.get("id") - if tc_id: - available_tool_call_ids.add(tc_id) - - # Find orphan tool responses (tool messages whose tool_call_id is missing) - orphan_tool_call_ids: set[str] = set() - for msg in recent_messages: - if msg.get("role") == "tool": - tc_id = msg.get("tool_call_id") - if tc_id and tc_id not in available_tool_call_ids: - orphan_tool_call_ids.add(tc_id) - - if not orphan_tool_call_ids: - # No orphans, slice is valid - return recent_messages - - # Find the assistant messages that contain the orphan tool_call_ids - # Search backwards from start_index in all_messages - messages_to_prepend: list[dict] = [] - for i in range(start_index - 1, -1, -1): - msg = all_messages[i] - if msg.get("role") == "assistant" and msg.get("tool_calls"): - msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")} - if msg_tool_ids & orphan_tool_call_ids: - # This assistant message has tool_calls we need - # Also collect its contiguous tool responses that follow it - assistant_and_responses: list[dict] = [msg] - - # Scan forward from this assistant to collect tool responses - for j in range(i + 1, start_index): - following_msg = all_messages[j] - if following_msg.get("role") == "tool": - tool_id = following_msg.get("tool_call_id") - if tool_id and tool_id in msg_tool_ids: - assistant_and_responses.append(following_msg) - else: - # Stop at first non-tool message - break - - # Prepend the 
assistant and its tool responses (maintain order) - messages_to_prepend = assistant_and_responses + messages_to_prepend - # Mark these as found - orphan_tool_call_ids -= msg_tool_ids - # Also add this assistant's tool_call_ids to available set - available_tool_call_ids |= msg_tool_ids - - if not orphan_tool_call_ids: - # Found all missing assistants - break - - if orphan_tool_call_ids: - # Some tool_call_ids couldn't be resolved - remove those tool responses - # This shouldn't happen in normal operation but handles edge cases - logger.warning( - f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. " - "Removing orphan tool responses." - ) - recent_messages = [ - msg - for msg in recent_messages - if not ( - msg.get("role") == "tool" - and msg.get("tool_call_id") in orphan_tool_call_ids + # Only create client if api_key is provided (enables summarization) + # Use context manager to avoid socket leaks + if api_key: + async with openai.AsyncOpenAI( + api_key=api_key, base_url=base_url, timeout=30.0 + ) as client: + return await compress_context( + messages=messages_dict, + model=model, + client=client, ) - ] - - if messages_to_prepend: - logger.info( - f"Extended recent messages by {len(messages_to_prepend)} to preserve " - f"tool_call/tool_response pairs" + else: + # No API key - use truncation-only mode + return await compress_context( + messages=messages_dict, + model=model, + client=None, ) - return messages_to_prepend + recent_messages - - return recent_messages async def _stream_chat_chunks( @@ -1022,11 +884,8 @@ async def _stream_chat_chunks( logger.info("Starting pure chat stream") - # Build messages with system prompt prepended messages = session.to_openai_messages() if system_prompt: - from openai.types.chat import ChatCompletionSystemMessageParam - system_message = ChatCompletionSystemMessageParam( role="system", content=system_prompt, @@ -1034,314 +893,38 @@ async def _stream_chat_chunks( messages = [system_message] + messages # 
Apply context window management - token_count = 0 # Initialize for exception handler - try: - from backend.util.prompt import estimate_token_count + context_result = await _manage_context_window( + messages=messages, + model=model, + api_key=config.api_key, + base_url=config.base_url, + ) - # Convert to dict for token counting - # OpenAI message types are TypedDicts, so they're already dict-like - messages_dict = [] - for msg in messages: - # TypedDict objects are already dicts, just filter None values - if isinstance(msg, dict): - msg_dict = {k: v for k, v in msg.items() if v is not None} - else: - # Fallback for unexpected types - msg_dict = dict(msg) - messages_dict.append(msg_dict) - - # Estimate tokens using appropriate tokenizer - # Normalize model name for token counting (tiktoken only supports OpenAI models) - token_count_model = model - if "/" in model: - # Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5") - token_count_model = model.split("/")[-1] - - # For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer - # Most modern LLMs have similar tokenization (~1 token per 4 chars) - if "claude" in token_count_model.lower() or not any( - known in token_count_model.lower() - for known in ["gpt", "o1", "chatgpt", "text-"] - ): - token_count_model = "gpt-4o" - - # Attempt token counting with error handling - try: - token_count = estimate_token_count(messages_dict, model=token_count_model) - except Exception as token_error: - # If token counting fails, use gpt-4o as fallback approximation - logger.warning( - f"Token counting failed for model {token_count_model}: {token_error}. " - "Using gpt-4o approximation." 
- ) - token_count = estimate_token_count(messages_dict, model="gpt-4o") - - # If over threshold, summarize old messages - if token_count > 120_000: - KEEP_RECENT = 15 - - # Check if we have a system prompt at the start - has_system_prompt = ( - len(messages) > 0 and messages[0].get("role") == "system" - ) - - # Always attempt mitigation when over limit, even with few messages - if messages: - # Split messages based on whether system prompt exists - # Calculate start index for the slice - slice_start = max(0, len(messages_dict) - KEEP_RECENT) - recent_messages = messages_dict[-KEEP_RECENT:] - - # Ensure tool_call/tool_response pairs stay together - # This prevents API errors from orphan tool responses - recent_messages = _ensure_tool_pairs_intact( - recent_messages, messages_dict, slice_start - ) - - if has_system_prompt: - # Keep system prompt separate, summarize everything between system and recent - system_msg = messages[0] - old_messages_dict = messages_dict[1:-KEEP_RECENT] - else: - # No system prompt, summarize everything except recent - system_msg = None - old_messages_dict = messages_dict[:-KEEP_RECENT] - - # Summarize any non-empty old messages (no minimum threshold) - # If we're over the token limit, we need to compress whatever we can - if old_messages_dict: - # Summarize old messages using the same model as chat - summary_text = await _summarize_messages( - old_messages_dict, - model=model, - api_key=config.api_key, - base_url=config.base_url, - ) - - # Build new message list - # Use assistant role (not system) to prevent privilege escalation - # of user-influenced content to instruction-level authority - from openai.types.chat import ChatCompletionAssistantMessageParam - - summary_msg = ChatCompletionAssistantMessageParam( - role="assistant", - content=( - "[Previous conversation summary — for context only]: " - f"{summary_text}" - ), - ) - - # Rebuild messages based on whether we have a system prompt - if has_system_prompt: - # system_prompt + summary 
+ recent_messages - messages = [system_msg, summary_msg] + recent_messages - else: - # summary + recent_messages (no original system prompt) - messages = [summary_msg] + recent_messages - - logger.info( - f"Context summarized: {token_count} tokens, " - f"summarized {len(old_messages_dict)} old messages, " - f"kept last {KEEP_RECENT} messages" - ) - - # Fallback: If still over limit after summarization, progressively drop recent messages - # This handles edge cases where recent messages are extremely large - new_messages_dict = [] - for msg in messages: - if isinstance(msg, dict): - msg_dict = {k: v for k, v in msg.items() if v is not None} - else: - msg_dict = dict(msg) - new_messages_dict.append(msg_dict) - - new_token_count = estimate_token_count( - new_messages_dict, model=token_count_model - ) - - if new_token_count > 120_000: - # Still over limit - progressively reduce KEEP_RECENT - logger.warning( - f"Still over limit after summarization: {new_token_count} tokens. " - "Reducing number of recent messages kept." 
- ) - - for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]: - if keep_count == 0: - # Try with just system prompt + summary (no recent messages) - if has_system_prompt: - messages = [system_msg, summary_msg] - else: - messages = [summary_msg] - logger.info( - "Trying with 0 recent messages (system + summary only)" - ) - else: - # Slice from ORIGINAL recent_messages to avoid duplicating summary - reduced_recent = ( - recent_messages[-keep_count:] - if len(recent_messages) >= keep_count - else recent_messages - ) - # Ensure tool pairs stay intact in the reduced slice - reduced_slice_start = max( - 0, len(recent_messages) - keep_count - ) - reduced_recent = _ensure_tool_pairs_intact( - reduced_recent, recent_messages, reduced_slice_start - ) - if has_system_prompt: - messages = [ - system_msg, - summary_msg, - ] + reduced_recent - else: - messages = [summary_msg] + reduced_recent - - new_messages_dict = [] - for msg in messages: - if isinstance(msg, dict): - msg_dict = { - k: v for k, v in msg.items() if v is not None - } - else: - msg_dict = dict(msg) - new_messages_dict.append(msg_dict) - - new_token_count = estimate_token_count( - new_messages_dict, model=token_count_model - ) - - if new_token_count <= 120_000: - logger.info( - f"Reduced to {keep_count} recent messages, " - f"now {new_token_count} tokens" - ) - break - else: - logger.error( - f"Unable to reduce token count below threshold even with 0 messages. " - f"Final count: {new_token_count} tokens" - ) - # ABSOLUTE LAST RESORT: Drop system prompt - # This should only happen if summary itself is massive - if has_system_prompt and len(messages) > 1: - messages = messages[1:] # Drop system prompt - logger.critical( - "CRITICAL: Dropped system prompt as absolute last resort. " - "Behavioral consistency may be affected." - ) - # Yield error to user - yield StreamError( - errorText=( - "Warning: System prompt dropped due to size constraints. " - "Assistant behavior may be affected." 
- ) - ) - else: - # No old messages to summarize - all messages are "recent" - # Apply progressive truncation to reduce token count - logger.warning( - f"Token count {token_count} exceeds threshold but no old messages to summarize. " - f"Applying progressive truncation to recent messages." - ) - - # Create a base list excluding system prompt to avoid duplication - # This is the pool of messages we'll slice from in the loop - # Use messages_dict for type consistency with _ensure_tool_pairs_intact - base_msgs = ( - messages_dict[1:] if has_system_prompt else messages_dict - ) - - # Try progressively smaller keep counts - new_token_count = token_count # Initialize with current count - for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]: - if keep_count == 0: - # Try with just system prompt (no recent messages) - if has_system_prompt: - messages = [system_msg] - logger.info( - "Trying with 0 recent messages (system prompt only)" - ) - else: - # No system prompt and no recent messages = empty messages list - # This is invalid, skip this iteration - continue - else: - if len(base_msgs) < keep_count: - continue # Skip if we don't have enough messages - - # Slice from base_msgs to get recent messages (without system prompt) - recent_messages = base_msgs[-keep_count:] - - # Ensure tool pairs stay intact in the reduced slice - reduced_slice_start = max(0, len(base_msgs) - keep_count) - recent_messages = _ensure_tool_pairs_intact( - recent_messages, base_msgs, reduced_slice_start - ) - - if has_system_prompt: - messages = [system_msg] + recent_messages - else: - messages = recent_messages - - new_messages_dict = [] - for msg in messages: - if msg is None: - continue # Skip None messages (type safety) - if isinstance(msg, dict): - msg_dict = { - k: v for k, v in msg.items() if v is not None - } - else: - msg_dict = dict(msg) - new_messages_dict.append(msg_dict) - - new_token_count = estimate_token_count( - new_messages_dict, model=token_count_model - ) - - if new_token_count <= 
120_000: - logger.info( - f"Reduced to {keep_count} recent messages, " - f"now {new_token_count} tokens" - ) - break - else: - # Even with 0 messages still over limit - logger.error( - f"Unable to reduce token count below threshold even with 0 messages. " - f"Final count: {new_token_count} tokens. Messages may be extremely large." - ) - # ABSOLUTE LAST RESORT: Drop system prompt - if has_system_prompt and len(messages) > 1: - messages = messages[1:] # Drop system prompt - logger.critical( - "CRITICAL: Dropped system prompt as absolute last resort. " - "Behavioral consistency may be affected." - ) - # Yield error to user - yield StreamError( - errorText=( - "Warning: System prompt dropped due to size constraints. " - "Assistant behavior may be affected." - ) - ) - - except Exception as e: - logger.error(f"Context summarization failed: {e}", exc_info=True) - # If we were over the token limit, yield error to user - # Don't silently continue with oversized messages that will fail - if token_count > 120_000: + if context_result.error: + if "System prompt dropped" in context_result.error: + # Warning only - continue with reduced context yield StreamError( errorText=( - f"Unable to manage context window (token limit exceeded: {token_count} tokens). " - "Context summarization failed. Please start a new conversation." + "Warning: System prompt dropped due to size constraints. " + "Assistant behavior may be affected." + ) + ) + else: + # Any other error - abort to prevent failed LLM calls + yield StreamError( + errorText=( + f"Context window management failed: {context_result.error}. " + "Please start a new conversation." 
) ) yield StreamFinish() return - # Otherwise, continue with original messages (under limit) + + messages = context_result.messages + if context_result.was_compacted: + logger.info( + f"Context compacted for streaming: {context_result.token_count} tokens" + ) # Loop to handle tool calls and continue conversation while True: @@ -1369,14 +952,6 @@ async def _stream_chat_chunks( :128 ] # OpenRouter limit - # Create the stream with proper types - from typing import cast - - from openai.types.chat import ( - ChatCompletionMessageParam, - ChatCompletionStreamOptionsParam, - ) - stream = await client.chat.completions.create( model=model, messages=cast(list[ChatCompletionMessageParam], messages), @@ -1610,8 +1185,9 @@ async def _yield_tool_call( ) return - # Generate operation ID + # Generate operation ID and task ID operation_id = str(uuid_module.uuid4()) + task_id = str(uuid_module.uuid4()) # Build a user-friendly message based on tool and arguments if tool_name == "create_agent": @@ -1654,6 +1230,16 @@ async def _yield_tool_call( # Wrap session save and task creation in try-except to release lock on failure try: + # Create task in stream registry for SSE reconnection support + await stream_registry.create_task( + task_id=task_id, + session_id=session.session_id, + user_id=session.user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + # Save assistant message with tool_call FIRST (required by LLM) assistant_message = ChatMessage( role="assistant", @@ -1675,23 +1261,27 @@ async def _yield_tool_call( session.messages.append(pending_message) await upsert_chat_session(session) logger.info( - f"Saved pending operation {operation_id} for tool {tool_name} " - f"in session {session.session_id}" + f"Saved pending operation {operation_id} (task_id={task_id}) " + f"for tool {tool_name} in session {session.session_id}" ) # Store task reference in module-level set to prevent GC before completion - task = asyncio.create_task( - 
_execute_long_running_tool( + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( tool_name=tool_name, parameters=arguments, tool_call_id=tool_call_id, operation_id=operation_id, + task_id=task_id, session_id=session.session_id, user_id=session.user_id, ) ) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + + # Associate the asyncio task with the stream registry task + await stream_registry.set_task_asyncio_task(task_id, bg_task) except Exception as e: # Roll back appended messages to prevent data corruption on subsequent saves if ( @@ -1709,6 +1299,11 @@ async def _yield_tool_call( # Release the Redis lock since the background task won't be spawned await _mark_operation_completed(tool_call_id) + # Mark stream registry task as failed if it was created + try: + await stream_registry.mark_task_completed(task_id, status="failed") + except Exception: + pass logger.error( f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True ) @@ -1722,6 +1317,7 @@ async def _yield_tool_call( message=started_msg, operation_id=operation_id, tool_name=tool_name, + task_id=task_id, # Include task_id for SSE reconnection ).model_dump_json(), success=True, ) @@ -1791,6 +1387,9 @@ async def _execute_long_running_tool( This function runs independently of the SSE connection, so the operation survives if the user closes their browser tab. + + NOTE: This is the legacy function without stream registry support. + Use _execute_long_running_tool_with_streaming for new implementations. 
""" try: # Load fresh session (not stale reference) @@ -1834,10 +1433,142 @@ async def _execute_long_running_tool( tool_call_id=tool_call_id, result=error_response.model_dump_json(), ) + # Generate LLM continuation so user sees explanation even for errors + try: + await _generate_llm_continuation(session_id=session_id, user_id=user_id) + except Exception as llm_err: + logger.warning(f"Failed to generate LLM continuation for error: {llm_err}") finally: await _mark_operation_completed(tool_call_id) +async def _execute_long_running_tool_with_streaming( + tool_name: str, + parameters: dict[str, Any], + tool_call_id: str, + operation_id: str, + task_id: str, + session_id: str, + user_id: str | None, +) -> None: + """Execute a long-running tool with stream registry support for SSE reconnection. + + This function runs independently of the SSE connection, publishes progress + to the stream registry, and survives if the user closes their browser tab. + Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming. + + If the external service returns a 202 Accepted (async), this function exits + early and lets the Redis Streams completion consumer handle the rest. 
+ """ + # Track whether we delegated to async processing - if so, the Redis Streams + # completion consumer (stream_registry / completion_consumer) will handle cleanup, not us + delegated_to_async = False + + try: + # Load fresh session (not stale reference) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for background tool") + await stream_registry.mark_task_completed(task_id, status="failed") + return + + # Pass operation_id and task_id to the tool for async processing + enriched_parameters = { + **parameters, + "_operation_id": operation_id, + "_task_id": task_id, + } + + # Execute the actual tool + result = await execute_tool( + tool_name=tool_name, + parameters=enriched_parameters, + tool_call_id=tool_call_id, + user_id=user_id, + session=session, + ) + + # Check if the tool result indicates async processing + # (e.g., Agent Generator returned 202 Accepted) + try: + if isinstance(result.output, dict): + result_data = result.output + elif result.output: + result_data = orjson.loads(result.output) + else: + result_data = {} + if result_data.get("status") == "accepted": + logger.info( + f"Tool {tool_name} delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id}). " + f"Redis Streams completion consumer will handle the rest." 
+ ) + # Don't publish result, don't continue with LLM, and don't cleanup + # The Redis Streams consumer (completion_consumer) will handle + # everything when the external service completes via webhook + delegated_to_async = True + return + except (orjson.JSONDecodeError, TypeError): + pass # Not JSON or not async - continue normally + + # Publish tool result to stream registry + await stream_registry.publish_chunk(task_id, result) + + # Update the pending message with result + result_str = ( + result.output + if isinstance(result.output, str) + else orjson.dumps(result.output).decode("utf-8") + ) + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=result_str, + ) + + logger.info( + f"Background tool {tool_name} completed for session {session_id} " + f"(task_id={task_id})" + ) + + # Generate LLM continuation and stream chunks to registry + await _generate_llm_continuation_with_streaming( + session_id=session_id, + user_id=user_id, + task_id=task_id, + ) + + # Mark task as completed in stream registry + await stream_registry.mark_task_completed(task_id, status="completed") + + except Exception as e: + logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True) + error_response = ErrorResponse( + message=f"Tool {tool_name} failed: {str(e)}", + ) + + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=str(e)), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) + + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=error_response.model_dump_json(), + ) + + # Mark task as failed in stream registry + await stream_registry.mark_task_completed(task_id, status="failed") + finally: + # Only cleanup if we didn't delegate to async processing + # For async path, the Redis Streams completion consumer handles cleanup + if not delegated_to_async: + await 
_mark_operation_completed(tool_call_id) + + async def _update_pending_operation( session_id: str, tool_call_id: str, @@ -1895,17 +1626,36 @@ async def _generate_llm_continuation( # Build system prompt system_prompt, _ = await _build_system_prompt(user_id) - # Build messages in OpenAI format messages = session.to_openai_messages() if system_prompt: - from openai.types.chat import ChatCompletionSystemMessageParam - system_message = ChatCompletionSystemMessageParam( role="system", content=system_prompt, ) messages = [system_message] + messages + # Apply context window management to prevent oversized requests + context_result = await _manage_context_window( + messages=messages, + model=config.model, + api_key=config.api_key, + base_url=config.base_url, + ) + + if context_result.error and "System prompt dropped" not in context_result.error: + logger.error( + f"Context window management failed for session {session_id}: " + f"{context_result.error} (tokens={context_result.token_count})" + ) + return + + messages = context_result.messages + if context_result.was_compacted: + logger.info( + f"Context compacted for LLM continuation: " + f"{context_result.token_count} tokens" + ) + # Build extra_body for tracing extra_body: dict[str, Any] = { "posthogProperties": { @@ -1918,19 +1668,54 @@ async def _generate_llm_continuation( if session_id: extra_body["session_id"] = session_id[:128] - # Make non-streaming LLM call (no tools - just text response) - from typing import cast + retry_count = 0 + last_error: Exception | None = None + response = None - from openai.types.chat import ChatCompletionMessageParam + while retry_count <= MAX_RETRIES: + try: + logger.info( + f"Generating LLM continuation for session {session_id}" + f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}" + ) - # No tools parameter = text-only response (no tool calls) - response = await client.chat.completions.create( - model=config.model, - messages=cast(list[ChatCompletionMessageParam], 
messages), - extra_body=extra_body, - ) + response = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + ) + last_error = None # Clear any previous error on success + break # Success, exit retry loop + except Exception as e: + last_error = e + if _is_retryable_error(e) and retry_count < MAX_RETRIES: + retry_count += 1 + delay = min( + BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), + MAX_DELAY_SECONDS, + ) + logger.warning( + f"Retryable error in LLM continuation: {e!s}. " + f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})" + ) + await asyncio.sleep(delay) + continue + else: + # Non-retryable error - log and exit gracefully + logger.error( + f"Non-retryable error in LLM continuation: {e!s}", + exc_info=True, + ) + return - if response.choices and response.choices[0].message.content: + if last_error: + logger.error( + f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. " + f"Last error: {last_error!s}" + ) + return + + if response and response.choices and response.choices[0].message.content: assistant_content = response.choices[0].message.content # Reload session from DB to avoid race condition with user messages @@ -1964,3 +1749,128 @@ async def _generate_llm_continuation( except Exception as e: logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) + + +async def _generate_llm_continuation_with_streaming( + session_id: str, + user_id: str | None, + task_id: str, +) -> None: + """Generate an LLM response with streaming to the stream registry. + + This is called by background tasks to continue the conversation + after a tool result is saved. Chunks are published to the stream registry + so reconnecting clients can receive them. 
+ """ + import uuid as uuid_module + + try: + # Load fresh session from DB (bypass cache to get the updated tool result) + await invalidate_session_cache(session_id) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for LLM continuation") + return + + # Build system prompt + system_prompt, _ = await _build_system_prompt(user_id) + + # Build messages in OpenAI format + messages = session.to_openai_messages() + if system_prompt: + from openai.types.chat import ChatCompletionSystemMessageParam + + system_message = ChatCompletionSystemMessageParam( + role="system", + content=system_prompt, + ) + messages = [system_message] + messages + + # Build extra_body for tracing + extra_body: dict[str, Any] = { + "posthogProperties": { + "environment": settings.config.app_env.value, + }, + } + if user_id: + extra_body["user"] = user_id[:128] + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] + + # Make streaming LLM call (no tools - just text response) + from typing import cast + + from openai.types.chat import ChatCompletionMessageParam + + # Generate unique IDs for AI SDK protocol + message_id = str(uuid_module.uuid4()) + text_block_id = str(uuid_module.uuid4()) + + # Publish start event + await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) + await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) + + # Stream the response + stream = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + stream=True, + ) + + assistant_content = "" + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + delta = chunk.choices[0].delta.content + assistant_content += delta + # Publish delta to stream registry + await stream_registry.publish_chunk( + task_id, + StreamTextDelta(id=text_block_id, 
delta=delta), + ) + + # Publish end events + await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) + + if assistant_content: + # Reload session from DB to avoid race condition with user messages + fresh_session = await get_chat_session(session_id, user_id) + if not fresh_session: + logger.error( + f"Session {session_id} disappeared during LLM continuation" + ) + return + + # Save assistant message to database + assistant_message = ChatMessage( + role="assistant", + content=assistant_content, + ) + fresh_session.messages.append(assistant_message) + + # Save to database (not cache) to persist the response + await upsert_chat_session(fresh_session) + + # Invalidate cache so next poll/refresh gets fresh data + await invalidate_session_cache(session_id) + + logger.info( + f"Generated streaming LLM continuation for session {session_id} " + f"(task_id={task_id}), response length: {len(assistant_content)}" + ) + else: + logger.warning( + f"Streaming LLM continuation returned empty response for {session_id}" + ) + + except Exception as e: + logger.error( + f"Failed to generate streaming LLM continuation: {e}", exc_info=True + ) + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=f"Failed to generate response: {e}"), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py new file mode 100644 index 0000000000..88a5023e2b --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -0,0 +1,704 @@ +"""Stream registry for managing reconnectable SSE streams. + +This module provides a registry for tracking active streaming tasks and their +messages. It uses Redis for all state management (no in-memory state), making +pods stateless and horizontally scalable. 
+ +Architecture: +- Redis Stream: Persists all messages for replay and real-time delivery +- Redis Hash: Task metadata (status, session_id, etc.) + +Subscribers: +1. Replay missed messages from Redis Stream (XREAD) +2. Listen for live updates via blocking XREAD +3. No in-memory state required on the subscribing pod +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +import orjson + +from backend.data.redis_client import get_redis_async + +from .config import ChatConfig +from .response_model import StreamBaseResponse, StreamError, StreamFinish + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Track background tasks for this pod (just the asyncio.Task reference, not subscribers) +_local_tasks: dict[str, asyncio.Task] = {} + +# Track listener tasks per subscriber queue for cleanup +# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe +_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {} + +# Timeout for putting chunks into subscriber queues (seconds) +# If the queue is full and doesn't drain within this time, send an overflow error +QUEUE_PUT_TIMEOUT = 5.0 + +# Lua script for atomic compare-and-swap status update (idempotent completion) +# Returns 1 if status was updated, 0 if already completed/failed +COMPLETE_TASK_SCRIPT = """ +local current = redis.call("HGET", KEYS[1], "status") +if current == "running" then + redis.call("HSET", KEYS[1], "status", ARGV[1]) + return 1 +end +return 0 +""" + + +@dataclass +class ActiveTask: + """Represents an active streaming task (metadata only, no in-memory queues).""" + + task_id: str + session_id: str + user_id: str | None + tool_call_id: str + tool_name: str + operation_id: str + status: Literal["running", "completed", "failed"] = "running" + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + asyncio_task: asyncio.Task | None = None + + +def 
_get_task_meta_key(task_id: str) -> str: + """Get Redis key for task metadata.""" + return f"{config.task_meta_prefix}{task_id}" + + +def _get_task_stream_key(task_id: str) -> str: + """Get Redis key for task message stream.""" + return f"{config.task_stream_prefix}{task_id}" + + +def _get_operation_mapping_key(operation_id: str) -> str: + """Get Redis key for operation_id to task_id mapping.""" + return f"{config.task_op_prefix}{operation_id}" + + +async def create_task( + task_id: str, + session_id: str, + user_id: str | None, + tool_call_id: str, + tool_name: str, + operation_id: str, +) -> ActiveTask: + """Create a new streaming task in Redis. + + Args: + task_id: Unique identifier for the task + session_id: Chat session ID + user_id: User ID (may be None for anonymous) + tool_call_id: Tool call ID from the LLM + tool_name: Name of the tool being executed + operation_id: Operation ID for webhook callbacks + + Returns: + The created ActiveTask instance (metadata only) + """ + task = ActiveTask( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # Store metadata in Redis + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + op_key = _get_operation_mapping_key(operation_id) + + await redis.hset( # type: ignore[misc] + meta_key, + mapping={ + "task_id": task_id, + "session_id": session_id, + "user_id": user_id or "", + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "operation_id": operation_id, + "status": task.status, + "created_at": task.created_at.isoformat(), + }, + ) + await redis.expire(meta_key, config.stream_ttl) + + # Create operation_id -> task_id mapping for webhook lookups + await redis.set(op_key, task_id, ex=config.stream_ttl) + + logger.debug(f"Created task {task_id} for session {session_id}") + + return task + + +async def publish_chunk( + task_id: str, + chunk: StreamBaseResponse, +) -> str: + """Publish a chunk to 
 Redis Stream. + + All delivery is via Redis Streams - no in-memory state. + + Args: + task_id: Task ID to publish to + chunk: The stream response chunk to publish + + Returns: + The Redis Stream message ID + """ + chunk_json = chunk.model_dump_json() + message_id = "0-0" + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + + # Write to Redis Stream for persistence and real-time delivery + raw_id = await redis.xadd( + stream_key, + {"data": chunk_json}, + maxlen=config.stream_max_length, + ) + message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() + + # Set TTL on stream to match task metadata TTL + await redis.expire(stream_key, config.stream_ttl) + except Exception as e: + logger.error( + f"Failed to publish chunk for task {task_id}: {e}", + exc_info=True, + ) + + return message_id + + +async def subscribe_to_task( + task_id: str, + user_id: str | None, + last_message_id: str = "0-0", +) -> asyncio.Queue[StreamBaseResponse] | None: + """Subscribe to a task's stream with replay of missed messages. + + This is fully stateless - uses Redis Stream for replay (XREAD) and blocking XREAD for live updates.
+ + Args: + task_id: Task ID to subscribe to + user_id: User ID for ownership validation + last_message_id: Last Redis Stream message ID received ("0-0" for full replay) + + Returns: + An asyncio Queue that will receive stream chunks, or None if task not found + or user doesn't have access + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + logger.debug(f"Task {task_id} not found in Redis") + return None + + # Note: Redis client uses decode_responses=True, so keys are strings + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + + # Validate ownership - if task has an owner, requester must match + if task_user_id: + if user_id != task_user_id: + logger.warning( + f"User {user_id} denied access to task {task_id} " + f"owned by {task_user_id}" + ) + return None + + subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() + stream_key = _get_task_stream_key(task_id) + + # Step 1: Replay messages from Redis Stream + messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) + + replayed_count = 0 + replay_last_id = last_message_id + if messages: + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + # Note: Redis client uses decode_responses=True, so keys are strings + if "data" in msg_data: + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + await subscriber_queue.put(chunk) + replayed_count += 1 + except Exception as e: + logger.warning(f"Failed to replay message: {e}") + + logger.debug(f"Task {task_id}: replayed {replayed_count} messages") + + # Step 2: If task is still running, start stream listener for live updates + if task_status == "running": + listener_task = asyncio.create_task( + 
_stream_listener(task_id, subscriber_queue, replay_last_id) + ) + # Track listener task for cleanup on unsubscribe + _listener_tasks[id(subscriber_queue)] = (task_id, listener_task) + else: + # Task is completed/failed - add finish marker + await subscriber_queue.put(StreamFinish()) + + return subscriber_queue + + +async def _stream_listener( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], + last_replayed_id: str, +) -> None: + """Listen to Redis Stream for new messages using blocking XREAD. + + This approach avoids the duplicate message issue that can occur with pub/sub + when messages are published during the gap between replay and subscription. + + Args: + task_id: Task ID to listen for + subscriber_queue: Queue to deliver messages to + last_replayed_id: Last message ID from replay (continue from here) + """ + queue_id = id(subscriber_queue) + # Track the last successfully delivered message ID for recovery hints + last_delivered_id = last_replayed_id + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + current_id = last_replayed_id + + while True: + # Block for up to 30 seconds waiting for new messages + # This allows periodic checking if task is still running + messages = await redis.xread( + {stream_key: current_id}, block=30000, count=100 + ) + + if not messages: + # Timeout - check if task is still running + meta_key = _get_task_meta_key(task_id) + status = await redis.hget(meta_key, "status") # type: ignore[misc] + if status and status != "running": + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except asyncio.TimeoutError: + logger.warning( + f"Timeout delivering finish event for task {task_id}" + ) + break + continue + + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + current_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + + if "data" not in msg_data: + continue + + try: + 
chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + try: + await asyncio.wait_for( + subscriber_queue.put(chunk), + timeout=QUEUE_PUT_TIMEOUT, + ) + # Update last delivered ID on successful delivery + last_delivered_id = current_id + except asyncio.TimeoutError: + logger.warning( + f"Subscriber queue full for task {task_id}, " + f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s" + ) + # Send overflow error with recovery info + try: + overflow_error = StreamError( + errorText="Message delivery timeout - some messages may have been missed", + code="QUEUE_OVERFLOW", + details={ + "last_delivered_id": last_delivered_id, + "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}", + }, + ) + subscriber_queue.put_nowait(overflow_error) + except asyncio.QueueFull: + # Queue is completely stuck, nothing more we can do + logger.error( + f"Cannot deliver overflow error for task {task_id}, " + "queue completely blocked" + ) + + # Stop listening on finish + if isinstance(chunk, StreamFinish): + return + except Exception as e: + logger.warning(f"Error processing stream message: {e}") + + except asyncio.CancelledError: + logger.debug(f"Stream listener cancelled for task {task_id}") + raise # Re-raise to propagate cancellation + except Exception as e: + logger.error(f"Stream listener error for task {task_id}: {e}") + # On error, send finish to unblock subscriber + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except (asyncio.TimeoutError, asyncio.QueueFull): + logger.warning( + f"Could not deliver finish event for task {task_id} after error" + ) + finally: + # Clean up listener task mapping on exit + _listener_tasks.pop(queue_id, None) + + +async def mark_task_completed( + task_id: str, + status: Literal["completed", "failed"] = "completed", +) -> bool: + """Mark a task as completed and publish finish event. 
+ + This is idempotent - calling multiple times with the same task_id is safe. + Uses atomic compare-and-swap via Lua script to prevent race conditions. + Status is updated first (source of truth), then finish event is published (best-effort). + + Args: + task_id: Task ID to mark as completed + status: Final status ("completed" or "failed") + + Returns: + True if task was newly marked completed, False if already completed/failed + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + + # Atomic compare-and-swap: only update if status is "running" + # This prevents race conditions when multiple callers try to complete simultaneously + result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc] + + if result == 0: + logger.debug(f"Task {task_id} already completed/failed, skipping") + return False + + # THEN publish finish event (best-effort - listeners can detect via status polling) + try: + await publish_chunk(task_id, StreamFinish()) + except Exception as e: + logger.error( + f"Failed to publish finish event for task {task_id}: {e}. " + "Listeners will detect completion via status polling." + ) + + # Clean up local task reference if exists + _local_tasks.pop(task_id, None) + return True + + +async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: + """Find a task by its operation ID. + + Used by webhook callbacks to locate the task to update. + + Args: + operation_id: Operation ID to search for + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + op_key = _get_operation_mapping_key(operation_id) + task_id = await redis.get(op_key) + + if not task_id: + return None + + task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id + return await get_task(task_id_str) + + +async def get_task(task_id: str) -> ActiveTask | None: + """Get a task by its ID from Redis. 
+ + Args: + task_id: Task ID to look up + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + return None + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ) + + +async def get_task_with_expiry_info( + task_id: str, +) -> tuple[ActiveTask | None, str | None]: + """Get a task by its ID with expiration detection. + + Returns (task, error_code) where error_code is: + - None if task found + - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired) + - "TASK_NOT_FOUND" if neither exists + + Args: + task_id: Task ID to look up + + Returns: + Tuple of (ActiveTask or None, error_code or None) + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + stream_key = _get_task_stream_key(task_id) + + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + # Check if stream still has data (metadata expired but stream hasn't) + stream_len = await redis.xlen(stream_key) + if stream_len > 0: + return None, "TASK_EXPIRED" + return None, "TASK_NOT_FOUND" + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ( + ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ), + None, + ) + 
+ +async def get_active_task_for_session( + session_id: str, + user_id: str | None = None, +) -> tuple[ActiveTask | None, str]: + """Get the active (running) task for a session, if any. + + Scans Redis for tasks matching the session_id with status="running". + + Args: + session_id: Session ID to look up + user_id: User ID for ownership validation (optional) + + Returns: + Tuple of (ActiveTask if found and running, last_message_id from Redis Stream) + """ + + redis = await get_redis_async() + + # Scan Redis for task metadata keys + cursor = 0 + tasks_checked = 0 + + while True: + cursor, keys = await redis.scan( + cursor, match=f"{config.task_meta_prefix}*", count=100 + ) + + for key in keys: + tasks_checked += 1 + meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc] + if not meta: + continue + + # Note: Redis client uses decode_responses=True, so keys/values are strings + task_session_id = meta.get("session_id", "") + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + task_id = meta.get("task_id", "") + + if task_session_id == session_id and task_status == "running": + # Validate ownership - if task has an owner, requester must match + if task_user_id and user_id != task_user_id: + continue + + # Get the last message ID from Redis Stream + stream_key = _get_task_stream_key(task_id) + last_id = "0-0" + try: + messages = await redis.xrevrange(stream_key, count=1) + if messages: + msg_id = messages[0][0] + last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + except Exception as e: + logger.warning(f"Failed to get last message ID: {e}") + + return ( + ActiveTask( + task_id=task_id, + session_id=task_session_id, + user_id=task_user_id, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status="running", + ), + last_id, + ) + + if cursor == 0: + break + + return None, "0-0" + + +def _reconstruct_chunk(chunk_data: dict) -> 
StreamBaseResponse | None: + """Reconstruct a StreamBaseResponse from JSON data. + + Args: + chunk_data: Parsed JSON data from Redis + + Returns: + Reconstructed response object, or None if unknown type + """ + from .response_model import ( + ResponseType, + StreamError, + StreamFinish, + StreamHeartbeat, + StreamStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, + StreamUsage, + ) + + # Map response types to their corresponding classes + type_to_class: dict[str, type[StreamBaseResponse]] = { + ResponseType.START.value: StreamStart, + ResponseType.FINISH.value: StreamFinish, + ResponseType.TEXT_START.value: StreamTextStart, + ResponseType.TEXT_DELTA.value: StreamTextDelta, + ResponseType.TEXT_END.value: StreamTextEnd, + ResponseType.TOOL_INPUT_START.value: StreamToolInputStart, + ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable, + ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable, + ResponseType.ERROR.value: StreamError, + ResponseType.USAGE.value: StreamUsage, + ResponseType.HEARTBEAT.value: StreamHeartbeat, + } + + chunk_type = chunk_data.get("type") + chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type] + + if chunk_class is None: + logger.warning(f"Unknown chunk type: {chunk_type}") + return None + + try: + return chunk_class(**chunk_data) + except Exception as e: + logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}") + return None + + +async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None: + """Track the asyncio.Task for a task (local reference only). + + This is just for cleanup purposes - the task state is in Redis. 
+ + Args: + task_id: Task ID + asyncio_task: The asyncio Task to track + """ + _local_tasks[task_id] = asyncio_task + + +async def unsubscribe_from_task( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], +) -> None: + """Clean up when a subscriber disconnects. + + Cancels the XREAD-based listener task associated with this subscriber queue + to prevent resource leaks. + + Args: + task_id: Task ID + subscriber_queue: The subscriber's queue used to look up the listener task + """ + queue_id = id(subscriber_queue) + listener_entry = _listener_tasks.pop(queue_id, None) + + if listener_entry is None: + logger.debug( + f"No listener task found for task {task_id} queue {queue_id} " + "(may have already completed)" + ) + return + + stored_task_id, listener_task = listener_entry + + if stored_task_id != task_id: + logger.warning( + f"Task ID mismatch in unsubscribe: expected {task_id}, " + f"found {stored_task_id}" + ) + + if listener_task.done(): + logger.debug(f"Listener task for task {task_id} already completed") + return + + # Cancel the listener task + listener_task.cancel() + + try: + # Wait for the task to be cancelled with a timeout + await asyncio.wait_for(listener_task, timeout=5.0) + except asyncio.CancelledError: + # Expected - the task was successfully cancelled + pass + except asyncio.TimeoutError: + logger.warning( + f"Timeout waiting for listener task cancellation for task {task_id}" + ) + except Exception as e: + logger.error(f"Error during listener task cancellation for task {task_id}: {e}") + + logger.debug(f"Successfully unsubscribed from task {task_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index d078860c3a..dcbc35ef37 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -10,6 +10,7 @@ from .add_understanding 
import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool from .create_agent import CreateAgentTool +from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool from .find_agent import FindAgentTool from .find_block import FindBlockTool @@ -34,6 +35,7 @@ logger = logging.getLogger(__name__) TOOL_REGISTRY: dict[str, BaseTool] = { "add_understanding": AddUnderstandingTool(), "create_agent": CreateAgentTool(), + "customize_agent": CustomizeAgentTool(), "edit_agent": EditAgentTool(), "find_agent": FindAgentTool(), "find_block": FindBlockTool(), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py index 499025b7dc..4266834220 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py @@ -2,30 +2,58 @@ from .core import ( AgentGeneratorNotConfiguredError, + AgentJsonValidationError, + AgentSummary, + DecompositionResult, + DecompositionStep, + LibraryAgentSummary, + MarketplaceAgentSummary, + customize_template, decompose_goal, + enrich_library_agents_from_steps, + extract_search_terms_from_steps, + extract_uuids_from_text, generate_agent, generate_agent_patch, get_agent_as_json, + get_all_relevant_agents_for_generation, + get_library_agent_by_graph_id, + get_library_agent_by_id, + get_library_agents_for_generation, + graph_to_json, json_to_graph, save_agent_to_library, + search_marketplace_agents_for_generation, ) from .errors import get_user_message_for_error from .service import health_check as check_external_service_health from .service import is_external_service_configured __all__ = [ - # Core functions + "AgentGeneratorNotConfiguredError", + "AgentJsonValidationError", + "AgentSummary", + "DecompositionResult", + "DecompositionStep", + 
"LibraryAgentSummary", + "MarketplaceAgentSummary", + "check_external_service_health", + "customize_template", "decompose_goal", + "enrich_library_agents_from_steps", + "extract_search_terms_from_steps", + "extract_uuids_from_text", "generate_agent", "generate_agent_patch", - "save_agent_to_library", "get_agent_as_json", - "json_to_graph", - # Exceptions - "AgentGeneratorNotConfiguredError", - # Service - "is_external_service_configured", - "check_external_service_health", - # Error handling + "get_all_relevant_agents_for_generation", + "get_library_agent_by_graph_id", + "get_library_agent_by_id", + "get_library_agents_for_generation", "get_user_message_for_error", + "graph_to_json", + "is_external_service_configured", + "json_to_graph", + "save_agent_to_library", + "search_marketplace_agents_for_generation", ] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index d56e33cbb0..b88b9b2924 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -1,13 +1,25 @@ """Core agent generation functions.""" import logging +import re import uuid -from typing import Any +from typing import Any, NotRequired, TypedDict from backend.api.features.library import db as library_db -from backend.data.graph import Graph, Link, Node, create_graph +from backend.api.features.store import db as store_db +from backend.data.graph import ( + Graph, + Link, + Node, + create_graph, + get_graph, + get_graph_all_versions, + get_store_listed_graphs, +) +from backend.util.exceptions import DatabaseError, NotFoundError from .service import ( + customize_template_external, decompose_goal_external, generate_agent_external, generate_agent_patch_external, @@ -16,6 +28,74 @@ from .service import ( logger = logging.getLogger(__name__) 
+AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565" + + +class ExecutionSummary(TypedDict): + """Summary of a single execution for quality assessment.""" + + status: str + correctness_score: NotRequired[float] + activity_summary: NotRequired[str] + + +class LibraryAgentSummary(TypedDict): + """Summary of a library agent for sub-agent composition. + + Includes recent executions to help the LLM decide whether to use this agent. + Each execution shows status, correctness_score (0-1), and activity_summary. + """ + + graph_id: str + graph_version: int + name: str + description: str + input_schema: dict[str, Any] + output_schema: dict[str, Any] + recent_executions: NotRequired[list[ExecutionSummary]] + + +class MarketplaceAgentSummary(TypedDict): + """Summary of a marketplace agent for sub-agent composition.""" + + name: str + description: str + sub_heading: str + creator: str + is_marketplace_agent: bool + + +class DecompositionStep(TypedDict, total=False): + """A single step in decomposed instructions.""" + + description: str + action: str + block_name: str + tool: str + name: str + + +class DecompositionResult(TypedDict, total=False): + """Result from decompose_goal - can be instructions, questions, or error.""" + + type: str + steps: list[DecompositionStep] + questions: list[dict[str, Any]] + error: str + error_type: str + + +AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any] + + +def _to_dict_list( + agents: list[AgentSummary] | list[dict[str, Any]] | None, +) -> list[dict[str, Any]] | None: + """Convert typed agent summaries to plain dicts for external service calls.""" + if agents is None: + return None + return [dict(a) for a in agents] + class AgentGeneratorNotConfiguredError(Exception): """Raised when the external Agent Generator service is not configured.""" @@ -36,15 +116,422 @@ def _check_service_configured() -> None: ) -async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None: 
+_UUID_PATTERN = re.compile( + r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}", + re.IGNORECASE, +) + + +def extract_uuids_from_text(text: str) -> list[str]: + """Extract all UUID v4 strings from text. + + Args: + text: Text that may contain UUIDs (e.g., user's goal description) + + Returns: + List of unique UUIDs found in the text (lowercase) + """ + matches = _UUID_PATTERN.findall(text) + return list({m.lower() for m in matches}) + + +async def get_library_agent_by_id( + user_id: str, agent_id: str +) -> LibraryAgentSummary | None: + """Fetch a specific library agent by its ID (library agent ID or graph_id). + + This function tries multiple lookup strategies: + 1. First tries to find by graph_id (AgentGraph primary key) + 2. If not found, tries to find by library agent ID (LibraryAgent primary key) + + This handles both cases: + - User provides graph_id (e.g., from AgentExecutorBlock) + - User provides library agent ID (e.g., from library URL) + + Args: + user_id: The user ID + agent_id: The ID to look up (can be graph_id or library agent ID) + + Returns: + LibraryAgentSummary if found, None otherwise + """ + try: + agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + if agent: + logger.debug(f"Found library agent by graph_id: {agent.name}") + return LibraryAgentSummary( + graph_id=agent.graph_id, + graph_version=agent.graph_version, + name=agent.name, + description=agent.description, + input_schema=agent.input_schema, + output_schema=agent.output_schema, + ) + except DatabaseError: + raise + except Exception as e: + logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}") + + try: + agent = await library_db.get_library_agent(agent_id, user_id) + if agent: + logger.debug(f"Found library agent by library_id: {agent.name}") + return LibraryAgentSummary( + graph_id=agent.graph_id, + graph_version=agent.graph_version, + name=agent.name, + description=agent.description, + input_schema=agent.input_schema, 
+ output_schema=agent.output_schema, + ) + except NotFoundError: + logger.debug(f"Library agent not found by library_id: {agent_id}") + except DatabaseError: + raise + except Exception as e: + logger.warning( + f"Could not fetch library agent by library_id {agent_id}: {e}", + exc_info=True, + ) + + return None + + +get_library_agent_by_graph_id = get_library_agent_by_id + + +async def get_library_agents_for_generation( + user_id: str, + search_query: str | None = None, + exclude_graph_id: str | None = None, + max_results: int = 15, +) -> list[LibraryAgentSummary]: + """Fetch user's library agents formatted for Agent Generator. + + Uses search-based fetching to return relevant agents instead of all agents. + This is more scalable for users with large libraries. + + Includes recent_executions list to help the LLM assess agent quality: + - Each execution has status, correctness_score (0-1), and activity_summary + - This gives the LLM concrete examples of recent performance + + Args: + user_id: The user ID + search_query: Optional search term to find relevant agents (user's goal/description) + exclude_graph_id: Optional graph ID to exclude (prevents circular references) + max_results: Maximum number of agents to return (default 15) + + Returns: + List of LibraryAgentSummary with schemas and recent executions for sub-agent composition + """ + try: + response = await library_db.list_library_agents( + user_id=user_id, + search_term=search_query, + page=1, + page_size=max_results, + include_executions=True, + ) + + results: list[LibraryAgentSummary] = [] + for agent in response.agents: + if exclude_graph_id is not None and agent.graph_id == exclude_graph_id: + continue + + summary = LibraryAgentSummary( + graph_id=agent.graph_id, + graph_version=agent.graph_version, + name=agent.name, + description=agent.description, + input_schema=agent.input_schema, + output_schema=agent.output_schema, + ) + if agent.recent_executions: + exec_summaries: list[ExecutionSummary] = [] + for 
ex in agent.recent_executions: + exec_sum = ExecutionSummary(status=ex.status) + if ex.correctness_score is not None: + exec_sum["correctness_score"] = ex.correctness_score + if ex.activity_summary: + exec_sum["activity_summary"] = ex.activity_summary + exec_summaries.append(exec_sum) + summary["recent_executions"] = exec_summaries + results.append(summary) + return results + except DatabaseError: + raise + except Exception as e: + logger.warning(f"Failed to fetch library agents: {e}") + return [] + + +async def search_marketplace_agents_for_generation( + search_query: str, + max_results: int = 10, +) -> list[LibraryAgentSummary]: + """Search marketplace agents formatted for Agent Generator. + + Fetches marketplace agents and their full schemas so they can be used + as sub-agents in generated workflows. + + Args: + search_query: Search term to find relevant public agents + max_results: Maximum number of agents to return (default 10) + + Returns: + List of LibraryAgentSummary with full input/output schemas + """ + try: + response = await store_db.get_store_agents( + search_query=search_query, + page=1, + page_size=max_results, + ) + + agents_with_graphs = [ + agent for agent in response.agents if agent.agent_graph_id + ] + + if not agents_with_graphs: + return [] + + graph_ids = [agent.agent_graph_id for agent in agents_with_graphs] + graphs = await get_store_listed_graphs(*graph_ids) + + results: list[LibraryAgentSummary] = [] + for agent in agents_with_graphs: + graph_id = agent.agent_graph_id + if graph_id and graph_id in graphs: + graph = graphs[graph_id] + results.append( + LibraryAgentSummary( + graph_id=graph.id, + graph_version=graph.version, + name=agent.agent_name, + description=agent.description, + input_schema=graph.input_schema, + output_schema=graph.output_schema, + ) + ) + return results + except Exception as e: + logger.warning(f"Failed to search marketplace agents: {e}") + return [] + + +async def get_all_relevant_agents_for_generation( + user_id: 
str, + search_query: str | None = None, + exclude_graph_id: str | None = None, + include_library: bool = True, + include_marketplace: bool = True, + max_library_results: int = 15, + max_marketplace_results: int = 10, +) -> list[AgentSummary]: + """Fetch relevant agents from library and/or marketplace. + + Searches both user's library and marketplace by default. + Explicitly mentioned UUIDs in the search query are always looked up. + + Args: + user_id: The user ID + search_query: Search term to find relevant agents (user's goal/description) + exclude_graph_id: Optional graph ID to exclude (prevents circular references) + include_library: Whether to search user's library (default True) + include_marketplace: Whether to also search marketplace (default True) + max_library_results: Max library agents to return (default 15) + max_marketplace_results: Max marketplace agents to return (default 10) + + Returns: + List of AgentSummary with full schemas (both library and marketplace agents) + """ + agents: list[AgentSummary] = [] + seen_graph_ids: set[str] = set() + + if search_query: + mentioned_uuids = extract_uuids_from_text(search_query) + for graph_id in mentioned_uuids: + if graph_id == exclude_graph_id: + continue + agent = await get_library_agent_by_graph_id(user_id, graph_id) + agent_graph_id = agent.get("graph_id") if agent else None + if agent and agent_graph_id and agent_graph_id not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(agent_graph_id) + logger.debug( + f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}" + ) + + if include_library: + library_agents = await get_library_agents_for_generation( + user_id=user_id, + search_query=search_query, + exclude_graph_id=exclude_graph_id, + max_results=max_library_results, + ) + for agent in library_agents: + graph_id = agent.get("graph_id") + if graph_id and graph_id not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(graph_id) + + if include_marketplace and 
search_query: + marketplace_agents = await search_marketplace_agents_for_generation( + search_query=search_query, + max_results=max_marketplace_results, + ) + for agent in marketplace_agents: + graph_id = agent.get("graph_id") + if graph_id and graph_id not in seen_graph_ids: + agents.append(agent) + seen_graph_ids.add(graph_id) + + return agents + + +def extract_search_terms_from_steps( + decomposition_result: DecompositionResult | dict[str, Any], +) -> list[str]: + """Extract search terms from decomposed instruction steps. + + Analyzes the decomposition result to extract relevant keywords + for additional library agent searches. + + Args: + decomposition_result: Result from decompose_goal containing steps + + Returns: + List of unique search terms extracted from steps + """ + search_terms: list[str] = [] + + if decomposition_result.get("type") != "instructions": + return search_terms + + steps = decomposition_result.get("steps", []) + if not steps: + return search_terms + + step_keys: list[str] = ["description", "action", "block_name", "tool", "name"] + + for step in steps: + for key in step_keys: + value = step.get(key) # type: ignore[union-attr] + if isinstance(value, str) and len(value) > 3: + search_terms.append(value) + + seen: set[str] = set() + unique_terms: list[str] = [] + for term in search_terms: + term_lower = term.lower() + if term_lower not in seen: + seen.add(term_lower) + unique_terms.append(term) + + return unique_terms + + +async def enrich_library_agents_from_steps( + user_id: str, + decomposition_result: DecompositionResult | dict[str, Any], + existing_agents: list[AgentSummary] | list[dict[str, Any]], + exclude_graph_id: str | None = None, + include_marketplace: bool = True, + max_additional_results: int = 10, +) -> list[AgentSummary] | list[dict[str, Any]]: + """Enrich library agents list with additional searches based on decomposed steps. 
+ + This implements two-phase search: after decomposition, we search for additional + relevant agents based on the specific steps identified. + + Args: + user_id: The user ID + decomposition_result: Result from decompose_goal containing steps + existing_agents: Already fetched library agents from initial search + exclude_graph_id: Optional graph ID to exclude + include_marketplace: Whether to also search marketplace + max_additional_results: Max additional agents per search term (default 10) + + Returns: + Combined list of library agents (existing + newly discovered) + """ + search_terms = extract_search_terms_from_steps(decomposition_result) + + if not search_terms: + return existing_agents + + existing_ids: set[str] = set() + existing_names: set[str] = set() + + for agent in existing_agents: + agent_name = agent.get("name") + if agent_name and isinstance(agent_name, str): + existing_names.add(agent_name.lower()) + graph_id = agent.get("graph_id") # type: ignore[call-overload] + if graph_id and isinstance(graph_id, str): + existing_ids.add(graph_id) + + all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents) + + for term in search_terms[:3]: + try: + additional_agents = await get_all_relevant_agents_for_generation( + user_id=user_id, + search_query=term, + exclude_graph_id=exclude_graph_id, + include_marketplace=include_marketplace, + max_library_results=max_additional_results, + max_marketplace_results=5, + ) + + for agent in additional_agents: + agent_name = agent.get("name") + if not agent_name or not isinstance(agent_name, str): + continue + agent_name_lower = agent_name.lower() + + if agent_name_lower in existing_names: + continue + + graph_id = agent.get("graph_id") # type: ignore[call-overload] + if graph_id and graph_id in existing_ids: + continue + + all_agents.append(agent) + existing_names.add(agent_name_lower) + if graph_id and isinstance(graph_id, str): + existing_ids.add(graph_id) + + except DatabaseError: + 
logger.error(f"Database error searching for agents with term '{term}'") + raise + except Exception as e: + logger.warning( + f"Failed to search for additional agents with term '{term}': {e}" + ) + + logger.debug( + f"Enriched library agents: {len(existing_agents)} initial + " + f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total" + ) + + return all_agents + + +async def decompose_goal( + description: str, + context: str = "", + library_agents: list[AgentSummary] | None = None, +) -> DecompositionResult | None: """Break down a goal into steps or return clarifying questions. Args: description: Natural language goal description context: Additional context (e.g., answers to previous questions) + library_agents: User's library agents available for sub-agent composition Returns: - Dict with either: + DecompositionResult with either: - {"type": "clarifying_questions", "questions": [...]} - {"type": "instructions", "steps": [...]} Or None on error @@ -54,29 +541,47 @@ async def decompose_goal(description: str, context: str = "") -> dict[str, Any] """ _check_service_configured() logger.info("Calling external Agent Generator service for decompose_goal") - return await decompose_goal_external(description, context) + result = await decompose_goal_external( + description, context, _to_dict_list(library_agents) + ) + return result # type: ignore[return-value] -async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None: +async def generate_agent( + instructions: DecompositionResult | dict[str, Any], + library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, +) -> dict[str, Any] | None: """Generate agent JSON from instructions. 
Args: instructions: Structured instructions from decompose_goal + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams + completion notification) + task_id: Task ID for async processing (enables Redis Streams persistence + and SSE delivery) Returns: - Agent JSON dict, error dict {"type": "error", ...}, or None on error + Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. """ _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent") - result = await generate_agent_external(instructions) + result = await generate_agent_external( + dict(instructions), _to_dict_list(library_agents), operation_id, task_id + ) + + # Don't modify async response + if result and result.get("status") == "accepted": + return result + if result: - # Check if it's an error response - pass through as-is if isinstance(result, dict) and result.get("type") == "error": return result - # Ensure required fields for successful agent generation if "id" not in result: result["id"] = str(uuid.uuid4()) if "version" not in result: @@ -86,6 +591,12 @@ async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None: return result +class AgentJsonValidationError(Exception): + """Raised when agent JSON is invalid or missing required fields.""" + + pass + + def json_to_graph(agent_json: dict[str, Any]) -> Graph: """Convert agent JSON dict to Graph model. 
@@ -94,25 +605,55 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph: Returns: Graph ready for saving + + Raises: + AgentJsonValidationError: If required fields are missing from nodes or links """ nodes = [] - for n in agent_json.get("nodes", []): + for idx, n in enumerate(agent_json.get("nodes", [])): + block_id = n.get("block_id") + if not block_id: + node_id = n.get("id", f"index_{idx}") + raise AgentJsonValidationError( + f"Node '{node_id}' is missing required field 'block_id'" + ) node = Node( id=n.get("id", str(uuid.uuid4())), - block_id=n["block_id"], + block_id=block_id, input_default=n.get("input_default", {}), metadata=n.get("metadata", {}), ) nodes.append(node) links = [] - for link_data in agent_json.get("links", []): + for idx, link_data in enumerate(agent_json.get("links", [])): + source_id = link_data.get("source_id") + sink_id = link_data.get("sink_id") + source_name = link_data.get("source_name") + sink_name = link_data.get("sink_name") + + missing_fields = [] + if not source_id: + missing_fields.append("source_id") + if not sink_id: + missing_fields.append("sink_id") + if not source_name: + missing_fields.append("source_name") + if not sink_name: + missing_fields.append("sink_name") + + if missing_fields: + link_id = link_data.get("id", f"index_{idx}") + raise AgentJsonValidationError( + f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}" + ) + link = Link( id=link_data.get("id", str(uuid.uuid4())), - source_id=link_data["source_id"], - sink_id=link_data["sink_id"], - source_name=link_data["source_name"], - sink_name=link_data["sink_name"], + source_id=source_id, + sink_id=sink_id, + source_name=source_name, + sink_name=sink_name, is_static=link_data.get("is_static", False), ) links.append(link) @@ -133,22 +674,40 @@ def _reassign_node_ids(graph: Graph) -> None: This is needed when creating a new version to avoid unique constraint violations. 
""" - # Create mapping from old node IDs to new UUIDs id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes} - # Reassign node IDs for node in graph.nodes: node.id = id_map[node.id] - # Update link references to use new node IDs for link in graph.links: - link.id = str(uuid.uuid4()) # Also give links new IDs + link.id = str(uuid.uuid4()) if link.source_id in id_map: link.source_id = id_map[link.source_id] if link.sink_id in id_map: link.sink_id = id_map[link.sink_id] +def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None: + """Populate user_id in AgentExecutorBlock nodes. + + The external agent generator creates AgentExecutorBlock nodes with empty user_id. + This function fills in the actual user_id so sub-agents run with correct permissions. + + Args: + agent_json: Agent JSON dict (modified in place) + user_id: User ID to set + """ + for node in agent_json.get("nodes", []): + if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID: + input_default = node.get("input_default") or {} + if not input_default.get("user_id"): + input_default["user_id"] = user_id + node["input_default"] = input_default + logger.debug( + f"Set user_id for AgentExecutorBlock node {node.get('id')}" + ) + + async def save_agent_to_library( agent_json: dict[str, Any], user_id: str, is_update: bool = False ) -> tuple[Graph, Any]: @@ -162,33 +721,27 @@ async def save_agent_to_library( Returns: Tuple of (created Graph, LibraryAgent) """ - from backend.data.graph import get_graph_all_versions + # Populate user_id in AgentExecutorBlock nodes before conversion + _populate_agent_executor_user_ids(agent_json, user_id) graph = json_to_graph(agent_json) if is_update: - # For updates, keep the same graph ID but increment version - # and reassign node/link IDs to avoid conflicts if graph.id: existing_versions = await get_graph_all_versions(graph.id, user_id) if existing_versions: latest_version = max(v.version for v in existing_versions) graph.version = latest_version + 
1 - # Reassign node IDs (but keep graph ID the same) _reassign_node_ids(graph) logger.info(f"Updating agent {graph.id} to version {graph.version}") else: - # For new agents, always generate a fresh UUID to avoid collisions graph.id = str(uuid.uuid4()) graph.version = 1 - # Reassign all node IDs as well _reassign_node_ids(graph) logger.info(f"Creating new agent with ID {graph.id}") - # Save to database created_graph = await create_graph(graph, user_id) - # Add to user's library (or update existing library agent) library_agents = await library_db.create_library_agent( graph=created_graph, user_id=user_id, @@ -199,26 +752,15 @@ async def save_agent_to_library( return created_graph, library_agents[0] -async def get_agent_as_json( - graph_id: str, user_id: str | None -) -> dict[str, Any] | None: - """Fetch an agent and convert to JSON format for editing. +def graph_to_json(graph: Graph) -> dict[str, Any]: + """Convert a Graph object to JSON format for the agent generator. Args: - graph_id: Graph ID or library agent ID - user_id: User ID + graph: Graph object to convert Returns: - Agent as JSON dict or None if not found + Agent as JSON dict """ - from backend.data.graph import get_graph - - # Try to get the graph (version=None gets the active version) - graph = await get_graph(graph_id, version=None, user_id=user_id) - if not graph: - return None - - # Convert to JSON format nodes = [] for node in graph.nodes: nodes.append( @@ -255,8 +797,41 @@ async def get_agent_as_json( } +async def get_agent_as_json( + agent_id: str, user_id: str | None +) -> dict[str, Any] | None: + """Fetch an agent and convert to JSON format for editing. 
+ + Args: + agent_id: Graph ID or library agent ID + user_id: User ID + + Returns: + Agent as JSON dict or None if not found + """ + graph = await get_graph(agent_id, version=None, user_id=user_id) + + if not graph and user_id: + try: + library_agent = await library_db.get_library_agent(agent_id, user_id) + graph = await get_graph( + library_agent.graph_id, version=None, user_id=user_id + ) + except NotFoundError: + pass + + if not graph: + return None + + return graph_to_json(graph) + + async def generate_agent_patch( - update_request: str, current_agent: dict[str, Any] + update_request: str, + current_agent: dict[str, Any], + library_agents: list[AgentSummary] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Update an existing agent using natural language. @@ -268,14 +843,57 @@ async def generate_agent_patch( Args: update_request: Natural language description of changes current_agent: Current agent JSON + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, - error dict {"type": "error", ...}, or None on unexpected error + {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. 
""" _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent_patch") - return await generate_agent_patch_external(update_request, current_agent) + return await generate_agent_patch_external( + update_request, + current_agent, + _to_dict_list(library_agents), + operation_id, + task_id, + ) + + +async def customize_template( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Customize a template/marketplace agent using natural language. + + This is used when users want to modify a template or marketplace agent + to fit their specific needs before adding it to their library. + + The external Agent Generator service handles: + - Understanding the modification request + - Applying changes to the template + - Fixing and validating the result + + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, + error dict {"type": "error", ...}, or None on unexpected error + + Raises: + AgentGeneratorNotConfiguredError: If the external service is not configured. 
+ """ + _check_service_configured() + logger.info("Calling external Agent Generator service for customize_template") + return await customize_template_external( + template_agent, modification_request, context + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py index bf71a95df9..282d8cf9aa 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/errors.py @@ -1,11 +1,43 @@ """Error handling utilities for agent generator.""" +import re + + +def _sanitize_error_details(details: str) -> str: + """Sanitize error details to remove sensitive information. + + Strips common patterns that could expose internal system info: + - File paths (Unix and Windows) + - Database connection strings + - URLs with credentials + - Stack trace internals + + Args: + details: Raw error details string + + Returns: + Sanitized error details safe for user display + """ + sanitized = re.sub( + r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details + ) + sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized) + sanitized = re.sub( + r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized + ) + sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized) + sanitized = re.sub(r", line \d+", "", sanitized) + sanitized = re.sub(r'File "[^"]+",?', "", sanitized) + + return sanitized.strip() + def get_user_message_for_error( error_type: str, operation: str = "process the request", llm_parse_message: str | None = None, validation_message: str | None = None, + error_details: str | None = None, ) -> str: """Get a user-friendly error message based on error type. 
@@ -19,25 +51,45 @@ def get_user_message_for_error( message (e.g., "analyze the goal", "generate the agent") llm_parse_message: Custom message for llm_parse_error type validation_message: Custom message for validation_error type + error_details: Optional additional details about the error Returns: User-friendly error message suitable for display to the user """ + base_message = "" + if error_type == "llm_parse_error": - return ( + base_message = ( llm_parse_message or "The AI had trouble processing this request. Please try again." ) elif error_type == "validation_error": - return ( + base_message = ( validation_message - or "The request failed validation. Please try rephrasing." + or "The generated agent failed validation. " + "This usually happens when the agent structure doesn't match " + "what the platform expects. Please try simplifying your goal " + "or breaking it into smaller parts." ) elif error_type == "patch_error": - return "Failed to apply the changes. Please try a different approach." + base_message = ( + "Failed to apply the changes. The modification couldn't be " + "validated. Please try a different approach or simplify the change." + ) elif error_type in ("timeout", "llm_timeout"): - return "The request took too long. Please try again." + base_message = ( + "The request took too long to process. This can happen with " + "complex agents. Please try again or simplify your goal." + ) elif error_type in ("rate_limit", "llm_rate_limit"): - return "The service is currently busy. Please try again in a moment." + base_message = "The service is currently busy. Please try again in a moment." else: - return f"Failed to {operation}. Please try again." + base_message = f"Failed to {operation}. Please try again." + + if error_details: + details = _sanitize_error_details(error_details) + if len(details) > 200: + details = details[:200] + "..." 
+ base_message += f"\n\nTechnical details: {details}" + + return base_message diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py index 1df1faaaef..62411b4e1b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py @@ -117,13 +117,16 @@ def _get_client() -> httpx.AsyncClient: async def decompose_goal_external( - description: str, context: str = "" + description: str, + context: str = "", + library_agents: list[dict[str, Any]] | None = None, ) -> dict[str, Any] | None: """Call the external service to decompose a goal. Args: description: Natural language goal description context: Additional context (e.g., answers to previous questions) + library_agents: User's library agents available for sub-agent composition Returns: Dict with either: @@ -136,11 +139,12 @@ async def decompose_goal_external( """ client = _get_client() - # Build the request payload - payload: dict[str, Any] = {"description": description} if context: - # The external service uses user_instruction for additional context - payload["user_instruction"] = context + description = f"{description}\n\nAdditional context from user:\n{context}" + + payload: dict[str, Any] = {"description": description} + if library_agents: + payload["library_agents"] = library_agents try: response = await client.post("/api/decompose-description", json=payload) @@ -207,21 +211,46 @@ async def decompose_goal_external( async def generate_agent_external( instructions: dict[str, Any], + library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate an agent from instructions. 
Args: instructions: Structured instructions from decompose_goal + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Agent JSON dict on success, or error dict {"type": "error", ...} on error + Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error """ client = _get_client() + # Build request payload + payload: dict[str, Any] = {"instructions": instructions} + if library_agents: + payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id + try: - response = await client.post( - "/api/generate-agent", json={"instructions": instructions} - ) + response = await client.post("/api/generate-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() @@ -229,8 +258,7 @@ async def generate_agent_external( error_msg = data.get("error", "Unknown error from Agent Generator") error_type = data.get("error_type", "unknown") logger.error( - f"Agent Generator generation failed: {error_msg} " - f"(type: {error_type})" + f"Agent Generator generation failed: {error_msg} (type: {error_type})" ) return _create_error_response(error_msg, error_type) @@ -251,27 +279,52 @@ async def generate_agent_external( async def generate_agent_patch_external( - update_request: str, current_agent: dict[str, Any] + update_request: str, + current_agent: dict[str, Any], + library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> 
dict[str, Any] | None: """Call the external service to generate a patch for an existing agent. Args: update_request: Natural language description of changes current_agent: Current agent JSON + library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Updated agent JSON, clarifying questions dict, or error dict on error + Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error """ client = _get_client() + # Build request payload + payload: dict[str, Any] = { + "update_request": update_request, + "current_agent_json": current_agent, + } + if library_agents: + payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id + try: - response = await client.post( - "/api/update-agent", - json={ - "update_request": update_request, - "current_agent_json": current_agent, - }, - ) + response = await client.post("/api/update-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async update request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() @@ -315,6 +368,77 @@ async def generate_agent_patch_external( return _create_error_response(error_msg, "unexpected_error") +async def customize_template_external( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Call the external service to customize a template/marketplace agent. 
+ + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict, or error dict on error + """ + client = _get_client() + + request = modification_request + if context: + request = f"{modification_request}\n\nAdditional context from user:\n{context}" + + payload: dict[str, Any] = { + "template_agent_json": template_agent, + "modification_request": request, + } + + try: + response = await client.post("/api/template-modification", json=payload) + response.raise_for_status() + data = response.json() + + if not data.get("success"): + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator template customization failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) + + # Check if it's clarifying questions + if data.get("type") == "clarifying_questions": + return { + "type": "clarifying_questions", + "questions": data.get("questions", []), + } + + # Check if it's an error passed through + if data.get("type") == "error": + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) + + # Otherwise return the customized agent JSON + return data.get("agent_json") + + except httpx.HTTPStatusError as e: + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except httpx.RequestError as e: + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except Exception as e: + error_msg = f"Unexpected error calling Agent Generator: {e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") + + async 
def get_blocks_external() -> list[dict[str, Any]] | None: """Get available blocks from the external service. diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py index 5fa74ba04e..62d59c470e 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py @@ -1,6 +1,7 @@ """Shared agent search functionality for find_agent and find_library_agent tools.""" import logging +import re from typing import Literal from backend.api.features.library import db as library_db @@ -19,6 +20,85 @@ logger = logging.getLogger(__name__) SearchSource = Literal["marketplace", "library"] +_UUID_PATTERN = re.compile( + r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$", + re.IGNORECASE, +) + + +def _is_uuid(text: str) -> bool: + """Check if text is a valid UUID v4.""" + return bool(_UUID_PATTERN.match(text.strip())) + + +async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None: + """Fetch a library agent by ID (library agent ID or graph_id). + + Tries multiple lookup strategies: + 1. First by graph_id (AgentGraph primary key) + 2. 
Then by library agent ID (LibraryAgent primary key) + + Args: + user_id: The user ID + agent_id: The ID to look up (can be graph_id or library agent ID) + + Returns: + AgentInfo if found, None otherwise + """ + try: + agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id) + if agent: + logger.debug(f"Found library agent by graph_id: {agent.name}") + return AgentInfo( + id=agent.id, + name=agent.name, + description=agent.description or "", + source="library", + in_library=True, + creator=agent.creator_name, + status=agent.status.value, + can_access_graph=agent.can_access_graph, + has_external_trigger=agent.has_external_trigger, + new_output=agent.new_output, + graph_id=agent.graph_id, + ) + except DatabaseError: + raise + except Exception as e: + logger.warning( + f"Could not fetch library agent by graph_id {agent_id}: {e}", + exc_info=True, + ) + + try: + agent = await library_db.get_library_agent(agent_id, user_id) + if agent: + logger.debug(f"Found library agent by library_id: {agent.name}") + return AgentInfo( + id=agent.id, + name=agent.name, + description=agent.description or "", + source="library", + in_library=True, + creator=agent.creator_name, + status=agent.status.value, + can_access_graph=agent.can_access_graph, + has_external_trigger=agent.has_external_trigger, + new_output=agent.new_output, + graph_id=agent.graph_id, + ) + except NotFoundError: + logger.debug(f"Library agent not found by library_id: {agent_id}") + except DatabaseError: + raise + except Exception as e: + logger.warning( + f"Could not fetch library agent by library_id {agent_id}: {e}", + exc_info=True, + ) + + return None + async def search_agents( query: str, @@ -69,29 +149,37 @@ async def search_agents( is_featured=False, ) ) - else: # library - logger.info(f"Searching user library for: {query}") - results = await library_db.list_library_agents( - user_id=user_id, # type: ignore[arg-type] - search_term=query, - page_size=10, - ) - for agent in results.agents: - 
agents.append( - AgentInfo( - id=agent.id, - name=agent.name, - description=agent.description or "", - source="library", - in_library=True, - creator=agent.creator_name, - status=agent.status.value, - can_access_graph=agent.can_access_graph, - has_external_trigger=agent.has_external_trigger, - new_output=agent.new_output, - graph_id=agent.graph_id, - ) + else: + if _is_uuid(query): + logger.info(f"Query looks like UUID, trying direct lookup: {query}") + agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type] + if agent: + agents.append(agent) + logger.info(f"Found agent by direct ID lookup: {agent.name}") + + if not agents: + logger.info(f"Searching user library for: {query}") + results = await library_db.list_library_agents( + user_id=user_id, # type: ignore[arg-type] + search_term=query, + page_size=10, ) + for agent in results.agents: + agents.append( + AgentInfo( + id=agent.id, + name=agent.name, + description=agent.description or "", + source="library", + in_library=True, + creator=agent.creator_name, + status=agent.status.value, + can_access_graph=agent.can_access_graph, + has_external_trigger=agent.has_external_trigger, + new_output=agent.new_output, + graph_id=agent.graph_id, + ) + ) logger.info(f"Found {len(agents)} agents in {source}") except NotFoundError: pass diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py index 74011c7e95..7333851a5b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py @@ -8,7 +8,9 @@ from backend.api.features.chat.model import ChatSession from .agent_generator import ( AgentGeneratorNotConfiguredError, decompose_goal, + enrich_library_agents_from_steps, generate_agent, + get_all_relevant_agents_for_generation, get_user_message_for_error, save_agent_to_library, ) @@ -16,6 +18,7 @@ from 
.base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -96,6 +99,10 @@ class CreateAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not description: return ErrorResponse( message="Please provide a description of what the agent should do.", @@ -103,9 +110,24 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Step 1: Decompose goal into steps + library_agents = None + if user_id: + try: + library_agents = await get_all_relevant_agents_for_generation( + user_id=user_id, + search_query=description, + include_marketplace=True, + ) + logger.debug( + f"Found {len(library_agents)} relevant agents for sub-agent composition" + ) + except Exception as e: + logger.warning(f"Failed to fetch library agents: {e}") + try: - decomposition_result = await decompose_goal(description, context) + decomposition_result = await decompose_goal( + description, context, library_agents + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -124,7 +146,6 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Check if the result is an error from the external service if decomposition_result.get("type") == "error": error_msg = decomposition_result.get("error", "Unknown error") error_type = decomposition_result.get("error_type", "unknown") @@ -144,7 +165,6 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Check if LLM returned clarifying questions if decomposition_result.get("type") == "clarifying_questions": questions = decomposition_result.get("questions", []) return ClarificationNeededResponse( @@ -163,7 +183,6 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Check for 
unachievable/vague goals if decomposition_result.get("type") == "unachievable_goal": suggested = decomposition_result.get("suggested_goal", "") reason = decomposition_result.get("reason", "") @@ -190,9 +209,27 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Step 2: Generate agent JSON (external service handles fixing and validation) + if user_id and library_agents is not None: + try: + library_agents = await enrich_library_agents_from_steps( + user_id=user_id, + decomposition_result=decomposition_result, + existing_agents=library_agents, + include_marketplace=True, + ) + logger.debug( + f"After enrichment: {len(library_agents)} total agents for sub-agent composition" + ) + except Exception as e: + logger.warning(f"Failed to enrich library agents from steps: {e}") + try: - agent_json = await generate_agent(decomposition_result) + agent_json = await generate_agent( + decomposition_result, + library_agents, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -211,7 +248,6 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Check if the result is an error from the external service if isinstance(agent_json, dict) and agent_json.get("type") == "error": error_msg = agent_json.get("error", "Unknown error") error_type = agent_json.get("error_type", "unknown") @@ -219,7 +255,12 @@ class CreateAgentTool(BaseTool): error_type, operation="generate the agent", llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.", - validation_message="The generated agent failed validation. Please try rephrasing your goal.", + validation_message=( + "I wasn't able to create a valid agent for this request. " + "The generated workflow had some structural issues. " + "Please try simplifying your goal or breaking it into smaller steps." 
+ ), + error_details=error_msg, ) return ErrorResponse( message=user_message, @@ -232,12 +273,24 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if agent_json.get("status") == "accepted": + logger.info( + f"Agent generation delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent generation started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + agent_name = agent_json.get("name", "Generated Agent") agent_description = agent_json.get("description", "") node_count = len(agent_json.get("nodes", [])) link_count = len(agent_json.get("links", [])) - # Step 3: Preview or save if not save: return AgentPreviewResponse( message=( @@ -252,7 +305,6 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) - # Save to library if not user_id: return ErrorResponse( message="You must be logged in to save agents.", @@ -270,7 +322,7 @@ class CreateAgentTool(BaseTool): agent_id=created_graph.id, agent_name=created_graph.name, library_agent_id=library_agent.id, - library_agent_link=f"/library/{library_agent.id}", + library_agent_link=f"/library/agents/{library_agent.id}", agent_page_link=f"/build?flowID={created_graph.id}", session_id=session_id, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py new file mode 100644 index 0000000000..c0568bd936 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py @@ -0,0 +1,337 @@ +"""CustomizeAgentTool - Customizes marketplace/template agents using natural language.""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.store import db as store_db +from 
backend.api.features.store.exceptions import AgentNotFoundError + +from .agent_generator import ( + AgentGeneratorNotConfiguredError, + customize_template, + get_user_message_for_error, + graph_to_json, + save_agent_to_library, +) +from .base import BaseTool +from .models import ( + AgentPreviewResponse, + AgentSavedResponse, + ClarificationNeededResponse, + ClarifyingQuestion, + ErrorResponse, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class CustomizeAgentTool(BaseTool): + """Tool for customizing marketplace/template agents using natural language.""" + + @property + def name(self) -> str: + return "customize_agent" + + @property + def description(self) -> str: + return ( + "Customize a marketplace or template agent using natural language. " + "Takes an existing agent from the marketplace and modifies it based on " + "the user's requirements before adding to their library." + ) + + @property + def requires_auth(self) -> bool: + return True + + @property + def is_long_running(self) -> bool: + return True + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "agent_id": { + "type": "string", + "description": ( + "The marketplace agent ID in format 'creator/slug' " + "(e.g., 'autogpt/newsletter-writer'). " + "Get this from find_agent results." + ), + }, + "modifications": { + "type": "string", + "description": ( + "Natural language description of how to customize the agent. " + "Be specific about what changes you want to make." + ), + }, + "context": { + "type": "string", + "description": ( + "Additional context or answers to previous clarifying questions." + ), + }, + "save": { + "type": "boolean", + "description": ( + "Whether to save the customized agent to the user's library. " + "Default is true. Set to false for preview only." 
+ ), + "default": True, + }, + }, + "required": ["agent_id", "modifications"], + } + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + """Execute the customize_agent tool. + + Flow: + 1. Parse the agent ID to get creator/slug + 2. Fetch the template agent from the marketplace + 3. Call customize_template with the modification request + 4. Preview or save based on the save parameter + """ + agent_id = kwargs.get("agent_id", "").strip() + modifications = kwargs.get("modifications", "").strip() + context = kwargs.get("context", "") + save = kwargs.get("save", True) + session_id = session.session_id if session else None + + if not agent_id: + return ErrorResponse( + message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", + error="missing_agent_id", + session_id=session_id, + ) + + if not modifications: + return ErrorResponse( + message="Please describe how you want to customize this agent.", + error="missing_modifications", + session_id=session_id, + ) + + # Parse agent_id in format "creator/slug" + parts = [p.strip() for p in agent_id.split("/")] + if len(parts) != 2 or not parts[0] or not parts[1]: + return ErrorResponse( + message=( + f"Invalid agent ID format: '{agent_id}'. " + "Expected format is 'creator/agent-name' " + "(e.g., 'autogpt/newsletter-writer')." + ), + error="invalid_agent_id_format", + session_id=session_id, + ) + + creator_username, agent_slug = parts + + # Fetch the marketplace agent details + try: + agent_details = await store_db.get_store_agent_details( + username=creator_username, agent_name=agent_slug + ) + except AgentNotFoundError: + return ErrorResponse( + message=( + f"Could not find marketplace agent '{agent_id}'. " + "Please check the agent ID and try again." 
+ ), + error="agent_not_found", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error fetching marketplace agent {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the marketplace agent. Please try again.", + error="fetch_error", + session_id=session_id, + ) + + if not agent_details.store_listing_version_id: + return ErrorResponse( + message=( + f"The agent '{agent_id}' does not have an available version. " + "Please try a different agent." + ), + error="no_version_available", + session_id=session_id, + ) + + # Get the full agent graph + try: + graph = await store_db.get_agent(agent_details.store_listing_version_id) + template_agent = graph_to_json(graph) + except Exception as e: + logger.error(f"Error fetching agent graph for {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the agent configuration. Please try again.", + error="graph_fetch_error", + session_id=session_id, + ) + + # Call customize_template + try: + result = await customize_template( + template_agent=template_agent, + modification_request=modifications, + context=context, + ) + except AgentGeneratorNotConfiguredError: + return ErrorResponse( + message=( + "Agent customization is not available. " + "The Agent Generator service is not configured." + ), + error="service_not_configured", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error calling customize_template for {agent_id}: {e}") + return ErrorResponse( + message=( + "Failed to customize the agent due to a service error. " + "Please try again." + ), + error="customization_service_error", + session_id=session_id, + ) + + if result is None: + return ErrorResponse( + message=( + "Failed to customize the agent. " + "The agent generation service may be unavailable or timed out. " + "Please try again." 
+ ), + error="customization_failed", + session_id=session_id, + ) + + # Handle error response + if isinstance(result, dict) and result.get("type") == "error": + error_msg = result.get("error", "Unknown error") + error_type = result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="customize the agent", + llm_parse_message=( + "The AI had trouble customizing the agent. " + "Please try again or simplify your request." + ), + validation_message=( + "The customized agent failed validation. " + "Please try rephrasing your request." + ), + error_details=error_msg, + ) + return ErrorResponse( + message=user_message, + error=f"customization_failed:{error_type}", + session_id=session_id, + ) + + # Handle clarifying questions + if isinstance(result, dict) and result.get("type") == "clarifying_questions": + questions = result.get("questions") or [] + if not isinstance(questions, list): + logger.error( + f"Unexpected clarifying questions format: {type(questions)}" + ) + questions = [] + return ClarificationNeededResponse( + message=( + "I need some more information to customize this agent. 
" + "Please answer the following questions:" + ), + questions=[ + ClarifyingQuestion( + question=q.get("question", ""), + keyword=q.get("keyword", ""), + example=q.get("example"), + ) + for q in questions + if isinstance(q, dict) + ], + session_id=session_id, + ) + + # Result should be the customized agent JSON + if not isinstance(result, dict): + logger.error(f"Unexpected customize_template response type: {type(result)}") + return ErrorResponse( + message="Failed to customize the agent due to an unexpected response.", + error="unexpected_response_type", + session_id=session_id, + ) + + customized_agent = result + + agent_name = customized_agent.get( + "name", f"Customized {agent_details.agent_name}" + ) + agent_description = customized_agent.get("description", "") + nodes = customized_agent.get("nodes") + links = customized_agent.get("links") + node_count = len(nodes) if isinstance(nodes, list) else 0 + link_count = len(links) if isinstance(links, list) else 0 + + if not save: + return AgentPreviewResponse( + message=( + f"I've customized the agent '{agent_details.agent_name}'. " + f"The customized agent has {node_count} blocks. " + f"Review it and call customize_agent with save=true to save it." + ), + agent_json=customized_agent, + agent_name=agent_name, + description=agent_description, + node_count=node_count, + link_count=link_count, + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="You must be logged in to save agents.", + error="auth_required", + session_id=session_id, + ) + + # Save to user's library + try: + created_graph, library_agent = await save_agent_to_library( + customized_agent, user_id, is_update=False + ) + + return AgentSavedResponse( + message=( + f"Customized agent '{created_graph.name}' " + f"(based on '{agent_details.agent_name}') " + f"has been saved to your library!" 
+ ), + agent_id=created_graph.id, + agent_name=created_graph.name, + library_agent_id=library_agent.id, + library_agent_link=f"/library/agents/{library_agent.id}", + agent_page_link=f"/build?flowID={created_graph.id}", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error saving customized agent: {e}") + return ErrorResponse( + message="Failed to save the customized agent. Please try again.", + error="save_failed", + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py index ee8eee53ce..3ae56407a7 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py @@ -9,6 +9,7 @@ from .agent_generator import ( AgentGeneratorNotConfiguredError, generate_agent_patch, get_agent_as_json, + get_all_relevant_agents_for_generation, get_user_message_for_error, save_agent_to_library, ) @@ -16,6 +17,7 @@ from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -103,6 +105,10 @@ class EditAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not agent_id: return ErrorResponse( message="Please provide the agent ID to edit.", @@ -117,7 +123,6 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Step 1: Fetch current agent current_agent = await get_agent_as_json(agent_id, user_id) if current_agent is None: @@ -127,14 +132,34 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Build the update request with context + library_agents = None + if user_id: + try: + graph_id = 
current_agent.get("id") + library_agents = await get_all_relevant_agents_for_generation( + user_id=user_id, + search_query=changes, + exclude_graph_id=graph_id, + include_marketplace=True, + ) + logger.debug( + f"Found {len(library_agents)} relevant agents for sub-agent composition" + ) + except Exception as e: + logger.warning(f"Failed to fetch library agents: {e}") + update_request = changes if context: update_request = f"{changes}\n\nAdditional context:\n{context}" - # Step 2: Generate updated agent (external service handles fixing and validation) try: - result = await generate_agent_patch(update_request, current_agent) + result = await generate_agent_patch( + update_request, + current_agent, + library_agents, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -153,6 +178,19 @@ class EditAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if result.get("status") == "accepted": + logger.info( + f"Agent edit delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent edit started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + # Check if the result is an error from the external service if isinstance(result, dict) and result.get("type") == "error": error_msg = result.get("error", "Unknown error") @@ -162,6 +200,7 @@ class EditAgentTool(BaseTool): operation="generate the changes", llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.", validation_message="The generated changes failed validation. 
Please try rephrasing your request.", + error_details=error_msg, ) return ErrorResponse( message=user_message, @@ -175,7 +214,6 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Check if LLM returned clarifying questions if result.get("type") == "clarifying_questions": questions = result.get("questions", []) return ClarificationNeededResponse( @@ -194,7 +232,6 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Result is the updated agent JSON updated_agent = result agent_name = updated_agent.get("name", "Updated Agent") @@ -202,7 +239,6 @@ class EditAgentTool(BaseTool): node_count = len(updated_agent.get("nodes", [])) link_count = len(updated_agent.get("links", [])) - # Step 3: Preview or save if not save: return AgentPreviewResponse( message=( @@ -218,7 +254,6 @@ class EditAgentTool(BaseTool): session_id=session_id, ) - # Save to library (creates a new version) if not user_id: return ErrorResponse( message="You must be logged in to save agents.", @@ -236,7 +271,7 @@ class EditAgentTool(BaseTool): agent_id=created_graph.id, agent_name=created_graph.name, library_agent_id=library_agent.id, - library_agent_link=f"/library/{library_agent.id}", + library_agent_link=f"/library/agents/{library_agent.id}", agent_page_link=f"/build?flowID={created_graph.id}", session_id=session_id, ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index 49b233784e..69c8c6c684 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -38,6 +38,8 @@ class ResponseType(str, Enum): OPERATION_STARTED = "operation_started" OPERATION_PENDING = "operation_pending" OPERATION_IN_PROGRESS = "operation_in_progress" + # Input validation + INPUT_VALIDATION_ERROR = "input_validation_error" # Base response model @@ -68,6 +70,10 @@ class AgentInfo(BaseModel): has_external_trigger: bool 
| None = None new_output: bool | None = None graph_id: str | None = None + inputs: dict[str, Any] | None = Field( + default=None, + description="Input schema for the agent, including field names, types, and defaults", + ) class AgentsFoundResponse(ToolResponseBase): @@ -194,6 +200,20 @@ class ErrorResponse(ToolResponseBase): details: dict[str, Any] | None = None +class InputValidationErrorResponse(ToolResponseBase): + """Response when run_agent receives unknown input fields.""" + + type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR + unrecognized_fields: list[str] = Field( + description="List of input field names that were not recognized" + ) + inputs: dict[str, Any] = Field( + description="The agent's valid input schema for reference" + ) + graph_id: str | None = None + graph_version: int | None = None + + # Agent output models class ExecutionOutputInfo(BaseModel): """Summary of a single execution's outputs.""" @@ -352,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase): This is returned immediately to the client while the operation continues to execute. The user can close the tab and check back later. + + The task_id can be used to reconnect to the SSE stream via + GET /chat/tasks/{task_id}/stream?last_idx=0 """ type: ResponseType = ResponseType.OPERATION_STARTED operation_id: str tool_name: str + task_id: str | None = None # For SSE reconnection class OperationPendingResponse(ToolResponseBase): @@ -380,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase): type: ResponseType = ResponseType.OPERATION_IN_PROGRESS tool_call_id: str + + +class AsyncProcessingResponse(ToolResponseBase): + """Response when an operation has been delegated to async processing. + + This is returned by tools when the external service accepts the request + for async processing (HTTP 202 Accepted). The Redis Streams completion + consumer will handle the result when the external service completes. 
+ + The status field is specifically "accepted" to allow the long-running tool + handler to detect this response and skip LLM continuation. + """ + + type: ResponseType = ResponseType.OPERATION_STARTED + status: str = "accepted" # Must be "accepted" for detection + operation_id: str | None = None + task_id: str | None = None diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index a7fa65348a..73d4cf81f2 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -30,6 +30,7 @@ from .models import ( ErrorResponse, ExecutionOptions, ExecutionStartedResponse, + InputValidationErrorResponse, SetupInfo, SetupRequirementsResponse, ToolResponseBase, @@ -273,6 +274,22 @@ class RunAgentTool(BaseTool): input_properties = graph.input_schema.get("properties", {}) required_fields = set(graph.input_schema.get("required", [])) provided_inputs = set(params.inputs.keys()) + valid_fields = set(input_properties.keys()) + + # Check for unknown input fields + unrecognized_fields = provided_inputs - valid_fields + if unrecognized_fields: + return InputValidationErrorResponse( + message=( + f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. " + f"Agent was not executed. Please use the correct field names from the schema." 
+ ), + session_id=session_id, + unrecognized_fields=sorted(unrecognized_fields), + inputs=graph.input_schema, + graph_id=graph.id, + graph_version=graph.version, + ) # If agent has inputs but none were provided AND use_defaults is not set, # always show what's available first so user can decide diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py index 404df2adb6..d5da394fa6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py @@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data): # Should return error about missing schedule_name assert result_data.get("type") == "error" assert "schedule_name" in result_data["message"].lower() + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_agent_rejects_unknown_input_fields(setup_test_data): + """Test that run_agent returns input_validation_error for unknown input fields.""" + user = setup_test_data["user"] + store_submission = setup_test_data["store_submission"] + + tool = RunAgentTool() + agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}" + session = make_session(user_id=user.id) + + # Execute with unknown input field names + response = await tool.execute( + user_id=user.id, + session_id=str(uuid.uuid4()), + tool_call_id=str(uuid.uuid4()), + username_agent_slug=agent_marketplace_id, + inputs={ + "unknown_field": "some value", + "another_unknown": "another value", + }, + session=session, + ) + + assert response is not None + assert hasattr(response, "output") + assert isinstance(response.output, str) + result_data = orjson.loads(response.output) + + # Should return input_validation_error type with unrecognized fields + assert result_data.get("type") == "input_validation_error" + assert "unrecognized_fields" in result_data + assert 
set(result_data["unrecognized_fields"]) == { + "another_unknown", + "unknown_field", + } + assert "inputs" in result_data # Contains the valid schema + assert "Agent was not executed" in result_data["message"] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index a59082b399..51bb2c0575 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -5,6 +5,8 @@ import uuid from collections import defaultdict from typing import Any +from pydantic_core import PydanticUndefined + from backend.api.features.chat.model import ChatSession from backend.data.block import get_block from backend.data.execution import ExecutionContext @@ -75,15 +77,22 @@ class RunBlockTool(BaseTool): self, user_id: str, block: Any, + input_data: dict[str, Any] | None = None, ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: """ Check if user has required credentials for a block. 
+ Args: + user_id: User ID + block: Block to check credentials for + input_data: Input data for the block (used to determine provider via discriminator) + Returns: tuple[matched_credentials, missing_credentials] """ matched_credentials: dict[str, CredentialsMetaInput] = {} missing_credentials: list[CredentialsMetaInput] = [] + input_data = input_data or {} # Get credential field info from block's input schema credentials_fields_info = block.input_schema.get_credentials_fields_info() @@ -96,14 +105,33 @@ class RunBlockTool(BaseTool): available_creds = await creds_manager.store.get_all_creds(user_id) for field_name, field_info in credentials_fields_info.items(): - # field_info.provider is a frozenset of acceptable providers - # field_info.supported_types is a frozenset of acceptable types + effective_field_info = field_info + if field_info.discriminator and field_info.discriminator_mapping: + # Get discriminator from input, falling back to schema default + discriminator_value = input_data.get(field_info.discriminator) + if discriminator_value is None: + field = block.input_schema.model_fields.get( + field_info.discriminator + ) + if field and field.default is not PydanticUndefined: + discriminator_value = field.default + + if ( + discriminator_value + and discriminator_value in field_info.discriminator_mapping + ): + effective_field_info = field_info.discriminate(discriminator_value) + logger.debug( + f"Discriminated provider for {field_name}: " + f"{discriminator_value} -> {effective_field_info.provider}" + ) + matching_cred = next( ( cred for cred in available_creds - if cred.provider in field_info.provider - and cred.type in field_info.supported_types + if cred.provider in effective_field_info.provider + and cred.type in effective_field_info.supported_types ), None, ) @@ -117,8 +145,8 @@ class RunBlockTool(BaseTool): ) else: # Create a placeholder for the missing credential - provider = next(iter(field_info.provider), "unknown") - cred_type = 
next(iter(field_info.supported_types), "api_key") + provider = next(iter(effective_field_info.provider), "unknown") + cred_type = next(iter(effective_field_info.supported_types), "api_key") missing_credentials.append( CredentialsMetaInput( id=field_name, @@ -186,10 +214,9 @@ class RunBlockTool(BaseTool): logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") - # Check credentials creds_manager = IntegrationCredentialsManager() matched_credentials, missing_credentials = await self._check_block_credentials( - user_id, block + user_id, block, input_data ) if missing_credentials: diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index a2ac91dc65..bd25594b8a 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -8,7 +8,12 @@ from backend.api.features.library import model as library_model from backend.api.features.store import db as store_db from backend.data import graph as graph_db from backend.data.graph import GraphModel -from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput +from backend.data.model import ( + CredentialsFieldInfo, + CredentialsMetaInput, + HostScopedCredentials, + OAuth2Credentials, +) from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import NotFoundError @@ -266,13 +271,21 @@ async def match_user_credentials_to_graph( credential_requirements, _node_fields, ) in aggregated_creds.items(): - # Find first matching credential by provider and type + # Find first matching credential by provider, type, and scopes matching_cred = next( ( cred for cred in available_creds if cred.provider in credential_requirements.provider and cred.type in credential_requirements.supported_types + and ( + cred.type != "oauth2" + or _credential_has_required_scopes(cred, 
credential_requirements) + ) + and ( + cred.type != "host_scoped" + or _credential_is_for_host(cred, credential_requirements) + ) ), None, ) @@ -296,10 +309,17 @@ async def match_user_credentials_to_graph( f"{credential_field_name} (validation failed: {e})" ) else: + # Build a helpful error message including scope requirements + error_parts = [ + f"provider in {list(credential_requirements.provider)}", + f"type in {list(credential_requirements.supported_types)}", + ] + if credential_requirements.required_scopes: + error_parts.append( + f"scopes including {list(credential_requirements.required_scopes)}" + ) missing_creds.append( - f"{credential_field_name} " - f"(requires provider in {list(credential_requirements.provider)}, " - f"type in {list(credential_requirements.supported_types)})" + f"{credential_field_name} (requires {', '.join(error_parts)})" ) logger.info( @@ -309,6 +329,35 @@ async def match_user_credentials_to_graph( return graph_credentials_inputs, missing_creds +def _credential_has_required_scopes( + credential: OAuth2Credentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if an OAuth2 credential has all the scopes required by the input.""" + # If no scopes are required, any credential matches + if not requirements.required_scopes: + return True + + # Check that credential scopes are a superset of required scopes + return set(credential.scopes).issuperset(requirements.required_scopes) + + +def _credential_is_for_host( + credential: HostScopedCredentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if a host-scoped credential matches the host required by the input.""" + # We need to know the host to match host-scoped credentials to. + # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any) + # to discriminator_values. No discriminator_values -> no host to match against. + if not requirements.discriminator_values: + return True + + # Check that credential host matches required host. 
+ # Host-scoped credential inputs are grouped by host, so any item from the set works. + return credential.matches_url(list(requirements.discriminator_values)[0]) + + async def check_user_has_required_credentials( user_id: str, required_credentials: list[CredentialsMetaInput], diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 872fe66b28..394f959953 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -39,6 +39,7 @@ async def list_library_agents( sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT, page: int = 1, page_size: int = 50, + include_executions: bool = False, ) -> library_model.LibraryAgentResponse: """ Retrieves a paginated list of LibraryAgent records for a given user. @@ -49,6 +50,9 @@ async def list_library_agents( sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser). page: Current page (1-indexed). page_size: Number of items per page. + include_executions: Whether to include execution data for status calculation. + Defaults to False for performance (UI fetches status separately). + Set to True when accurate status/metrics are needed (e.g., agent generator). Returns: A LibraryAgentResponse containing the list of agents and pagination details. 
@@ -76,7 +80,6 @@ async def list_library_agents( "isArchived": False, } - # Build search filter if applicable if search_term: where_clause["OR"] = [ { @@ -93,7 +96,6 @@ async def list_library_agents( }, ] - # Determine sorting order_by: prisma.types.LibraryAgentOrderByInput | None = None if sort_by == library_model.LibraryAgentSort.CREATED_AT: @@ -105,7 +107,7 @@ async def list_library_agents( library_agents = await prisma.models.LibraryAgent.prisma().find_many( where=where_clause, include=library_agent_include( - user_id, include_nodes=False, include_executions=False + user_id, include_nodes=False, include_executions=include_executions ), order=order_by, skip=(page - 1) * page_size, diff --git a/autogpt_platform/backend/backend/api/features/library/model.py b/autogpt_platform/backend/backend/api/features/library/model.py index 14d7c7be81..c6bc0e0427 100644 --- a/autogpt_platform/backend/backend/api/features/library/model.py +++ b/autogpt_platform/backend/backend/api/features/library/model.py @@ -9,6 +9,7 @@ import pydantic from backend.data.block import BlockInput from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo from backend.data.model import CredentialsMetaInput, is_credentials_field_name +from backend.util.json import loads as json_loads from backend.util.models import Pagination if TYPE_CHECKING: @@ -16,10 +17,10 @@ if TYPE_CHECKING: class LibraryAgentStatus(str, Enum): - COMPLETED = "COMPLETED" # All runs completed - HEALTHY = "HEALTHY" # Agent is running (not all runs have completed) - WAITING = "WAITING" # Agent is queued or waiting to start - ERROR = "ERROR" # Agent is in an error state + COMPLETED = "COMPLETED" + HEALTHY = "HEALTHY" + WAITING = "WAITING" + ERROR = "ERROR" class MarketplaceListingCreator(pydantic.BaseModel): @@ -39,6 +40,30 @@ class MarketplaceListing(pydantic.BaseModel): creator: MarketplaceListingCreator +class RecentExecution(pydantic.BaseModel): + """Summary of a recent execution for quality assessment. 
+ + Used by the LLM to understand the agent's recent performance with specific examples + rather than just aggregate statistics. + """ + + status: str + correctness_score: float | None = None + activity_summary: str | None = None + + +def _parse_settings(settings: dict | str | None) -> GraphSettings: + """Parse settings from database, handling both dict and string formats.""" + if settings is None: + return GraphSettings() + try: + if isinstance(settings, str): + settings = json_loads(settings) + return GraphSettings.model_validate(settings) + except Exception: + return GraphSettings() + + class LibraryAgent(pydantic.BaseModel): """ Represents an agent in the library, including metadata for display and @@ -48,7 +73,7 @@ class LibraryAgent(pydantic.BaseModel): id: str graph_id: str graph_version: int - owner_user_id: str # ID of user who owns/created this agent graph + owner_user_id: str image_url: str | None @@ -64,7 +89,7 @@ class LibraryAgent(pydantic.BaseModel): description: str instructions: str | None = None - input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend + input_schema: dict[str, Any] output_schema: dict[str, Any] credentials_input_schema: dict[str, Any] | None = pydantic.Field( description="Input schema for credentials required by the agent", @@ -81,25 +106,19 @@ class LibraryAgent(pydantic.BaseModel): ) trigger_setup_info: Optional[GraphTriggerInfo] = None - # Indicates whether there's a new output (based on recent runs) new_output: bool - - # Whether the user can access the underlying graph + execution_count: int = 0 + success_rate: float | None = None + avg_correctness_score: float | None = None + recent_executions: list[RecentExecution] = pydantic.Field( + default_factory=list, + description="List of recent executions with status, score, and summary", + ) can_access_graph: bool - - # Indicates if this agent is the latest version is_latest_version: bool - - # Whether the agent is marked as favorite by the user is_favorite: 
bool - - # Recommended schedule cron (from marketplace agents) recommended_schedule_cron: str | None = None - - # User-specific settings for this library agent settings: GraphSettings = pydantic.Field(default_factory=GraphSettings) - - # Marketplace listing information if the agent has been published marketplace_listing: Optional["MarketplaceListing"] = None @staticmethod @@ -123,7 +142,6 @@ class LibraryAgent(pydantic.BaseModel): agent_updated_at = agent.AgentGraph.updatedAt lib_agent_updated_at = agent.updatedAt - # Compute updated_at as the latest between library agent and graph updated_at = ( max(agent_updated_at, lib_agent_updated_at) if agent_updated_at @@ -136,7 +154,6 @@ class LibraryAgent(pydantic.BaseModel): creator_name = agent.Creator.name or "Unknown" creator_image_url = agent.Creator.avatarUrl or "" - # Logic to calculate status and new_output week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( days=7 ) @@ -145,13 +162,55 @@ class LibraryAgent(pydantic.BaseModel): status = status_result.status new_output = status_result.new_output - # Check if user can access the graph - can_access_graph = agent.AgentGraph.userId == agent.userId + execution_count = len(executions) + success_rate: float | None = None + avg_correctness_score: float | None = None + if execution_count > 0: + success_count = sum( + 1 + for e in executions + if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED + ) + success_rate = (success_count / execution_count) * 100 - # Hard-coded to True until a method to check is implemented + correctness_scores = [] + for e in executions: + if e.stats and isinstance(e.stats, dict): + score = e.stats.get("correctness_score") + if score is not None and isinstance(score, (int, float)): + correctness_scores.append(float(score)) + if correctness_scores: + avg_correctness_score = sum(correctness_scores) / len( + correctness_scores + ) + + recent_executions: list[RecentExecution] = [] + for e in executions: + 
exec_score: float | None = None + exec_summary: str | None = None + if e.stats and isinstance(e.stats, dict): + score = e.stats.get("correctness_score") + if score is not None and isinstance(score, (int, float)): + exec_score = float(score) + summary = e.stats.get("activity_status") + if summary is not None and isinstance(summary, str): + exec_summary = summary + exec_status = ( + e.executionStatus.value + if hasattr(e.executionStatus, "value") + else str(e.executionStatus) + ) + recent_executions.append( + RecentExecution( + status=exec_status, + correctness_score=exec_score, + activity_summary=exec_summary, + ) + ) + + can_access_graph = agent.AgentGraph.userId == agent.userId is_latest_version = True - # Build marketplace_listing if available marketplace_listing_data = None if store_listing and store_listing.ActiveVersion and profile: creator_data = MarketplaceListingCreator( @@ -190,11 +249,15 @@ class LibraryAgent(pydantic.BaseModel): has_sensitive_action=graph.has_sensitive_action, trigger_setup_info=graph.trigger_setup_info, new_output=new_output, + execution_count=execution_count, + success_rate=success_rate, + avg_correctness_score=avg_correctness_score, + recent_executions=recent_executions, can_access_graph=can_access_graph, is_latest_version=is_latest_version, is_favorite=agent.isFavorite, recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron, - settings=GraphSettings.model_validate(agent.settings), + settings=_parse_settings(agent.settings), marketplace_listing=marketplace_listing_data, ) @@ -220,18 +283,15 @@ def _calculate_agent_status( if not executions: return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False) - # Track how many times each execution status appears status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus} new_output = False for execution in executions: - # Check if there's a completed run more recent than `recent_threshold` if execution.createdAt >= recent_threshold: if 
execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED: new_output = True status_counts[execution.executionStatus] += 1 - # Determine the final status based on counts if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0: return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output) elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0: diff --git a/autogpt_platform/backend/backend/api/features/store/db.py b/autogpt_platform/backend/backend/api/features/store/db.py index 956fdfa7da..850a2bc3e9 100644 --- a/autogpt_platform/backend/backend/api/features/store/db.py +++ b/autogpt_platform/backend/backend/api/features/store/db.py @@ -112,6 +112,7 @@ async def get_store_agents( description=agent["description"], runs=agent["runs"], rating=agent["rating"], + agent_graph_id=agent.get("agentGraphId", ""), ) store_agents.append(store_agent) except Exception as e: @@ -170,6 +171,7 @@ async def get_store_agents( description=agent.description, runs=agent.runs, rating=agent.rating, + agent_graph_id=agent.agentGraphId, ) # Add to the list only if creation was successful store_agents.append(store_agent) diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py index bae5b97cd6..86af457f50 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py @@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination( cleanup_embeddings: list, ): """Test unified search pagination works correctly.""" + # Use a unique search term to avoid matching other test data + unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}" + # Create multiple items content_ids = [] for i in range(5): @@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination( content_type=ContentType.BLOCK, content_id=content_id, 
embedding=mock_embedding, - searchable_text=f"pagination test item number {i}", + searchable_text=f"{unique_term} item number {i}", metadata={"index": i}, user_id=None, ) # Get first page page1_results, total1 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=1, page_size=2, @@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination( # Get second page page2_results, total2 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=2, page_size=2, diff --git a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py index 8b0884bb24..e1b8f402c8 100644 --- a/autogpt_platform/backend/backend/api/features/store/hybrid_search.py +++ b/autogpt_platform/backend/backend/api/features/store/hybrid_search.py @@ -600,6 +600,7 @@ async def hybrid_search( sa.featured, sa.is_available, sa.updated_at, + sa."agentGraphId", -- Searchable text for BM25 reranking COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text, -- Semantic score @@ -659,6 +660,7 @@ async def hybrid_search( featured, is_available, updated_at, + "agentGraphId", searchable_text, semantic_score, lexical_score, diff --git a/autogpt_platform/backend/backend/api/features/store/model.py b/autogpt_platform/backend/backend/api/features/store/model.py index a3310b96fc..d66b91807d 100644 --- a/autogpt_platform/backend/backend/api/features/store/model.py +++ b/autogpt_platform/backend/backend/api/features/store/model.py @@ -38,6 +38,7 @@ class StoreAgent(pydantic.BaseModel): description: str runs: int rating: float + agent_graph_id: str class StoreAgentsResponse(pydantic.BaseModel): diff --git a/autogpt_platform/backend/backend/api/features/store/model_test.py b/autogpt_platform/backend/backend/api/features/store/model_test.py index 
fd09a0cf77..c4109f4603 100644 --- a/autogpt_platform/backend/backend/api/features/store/model_test.py +++ b/autogpt_platform/backend/backend/api/features/store/model_test.py @@ -26,11 +26,13 @@ def test_store_agent(): description="Test description", runs=50, rating=4.5, + agent_graph_id="test-graph-id", ) assert agent.slug == "test-agent" assert agent.agent_name == "Test Agent" assert agent.runs == 50 assert agent.rating == 4.5 + assert agent.agent_graph_id == "test-graph-id" def test_store_agents_response(): @@ -46,6 +48,7 @@ def test_store_agents_response(): description="Test description", runs=50, rating=4.5, + agent_graph_id="test-graph-id", ) ], pagination=store_model.Pagination( diff --git a/autogpt_platform/backend/backend/api/features/store/routes_test.py b/autogpt_platform/backend/backend/api/features/store/routes_test.py index 36431c20ec..fcef3f845a 100644 --- a/autogpt_platform/backend/backend/api/features/store/routes_test.py +++ b/autogpt_platform/backend/backend/api/features/store/routes_test.py @@ -82,6 +82,7 @@ def test_get_agents_featured( description="Featured agent description", runs=100, rating=4.5, + agent_graph_id="test-graph-1", ) ], pagination=store_model.Pagination( @@ -127,6 +128,7 @@ def test_get_agents_by_creator( description="Creator agent description", runs=50, rating=4.0, + agent_graph_id="test-graph-2", ) ], pagination=store_model.Pagination( @@ -172,6 +174,7 @@ def test_get_agents_sorted( description="Top agent description", runs=1000, rating=5.0, + agent_graph_id="test-graph-3", ) ], pagination=store_model.Pagination( @@ -217,6 +220,7 @@ def test_get_agents_search( description="Specific search term description", runs=75, rating=4.2, + agent_graph_id="test-graph-search", ) ], pagination=store_model.Pagination( @@ -262,6 +266,7 @@ def test_get_agents_category( description="Category agent description", runs=60, rating=4.1, + agent_graph_id="test-graph-category", ) ], pagination=store_model.Pagination( @@ -306,6 +311,7 @@ def 
test_get_agents_pagination( description=f"Agent {i} description", runs=i * 10, rating=4.0, + agent_graph_id="test-graph-2", ) for i in range(5) ], diff --git a/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py b/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py index dd9be1f4ab..298c51d47c 100644 --- a/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py +++ b/autogpt_platform/backend/backend/api/features/store/test_cache_delete.py @@ -33,6 +33,7 @@ class TestCacheDeletion: description="Test description", runs=100, rating=4.5, + agent_graph_id="test-graph-id", ) ], pagination=Pagination( diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index b936312ce1..0eef76193e 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -40,6 +40,10 @@ import backend.data.user import backend.integrations.webhooks.utils import backend.util.service import backend.util.settings +from backend.api.features.chat.completion_consumer import ( + start_completion_consumer, + stop_completion_consumer, +) from backend.blocks.llm import DEFAULT_LLM_MODEL from backend.data.model import Credentials from backend.integrations.providers import ProviderName @@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI): await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL) await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs() + # Start chat completion consumer for Redis Streams notifications + try: + await start_completion_consumer() + except Exception as e: + logger.warning(f"Could not start chat completion consumer: {e}") + with launch_darkly_context(): yield + # Stop chat completion consumer + try: + await stop_completion_consumer() + except Exception as e: + logger.warning(f"Error stopping chat completion consumer: {e}") + try: await shutdown_cloud_storage_handler() 
except Exception as e: diff --git a/autogpt_platform/backend/backend/api/ws_api.py b/autogpt_platform/backend/backend/api/ws_api.py index b71fdb3526..e254d4b4db 100644 --- a/autogpt_platform/backend/backend/api/ws_api.py +++ b/autogpt_platform/backend/backend/api/ws_api.py @@ -66,18 +66,24 @@ async def event_broadcaster(manager: ConnectionManager): execution_bus = AsyncRedisExecutionEventBus() notification_bus = AsyncRedisNotificationEventBus() - async def execution_worker(): - async for event in execution_bus.listen("*"): - await manager.send_execution_update(event) + try: - async def notification_worker(): - async for notification in notification_bus.listen("*"): - await manager.send_notification( - user_id=notification.user_id, - payload=notification.payload, - ) + async def execution_worker(): + async for event in execution_bus.listen("*"): + await manager.send_execution_update(event) - await asyncio.gather(execution_worker(), notification_worker()) + async def notification_worker(): + async for notification in notification_bus.listen("*"): + await manager.send_notification( + user_id=notification.user_id, + payload=notification.payload, + ) + + await asyncio.gather(execution_worker(), notification_worker()) + finally: + # Ensure PubSub connections are closed on any exit to prevent leaks + await execution_bus.close() + await notification_bus.close() async def authenticate_websocket(websocket: WebSocket) -> str: diff --git a/autogpt_platform/backend/backend/blocks/linear/_api.py b/autogpt_platform/backend/backend/blocks/linear/_api.py index 477b8a209c..ea609d515a 100644 --- a/autogpt_platform/backend/backend/blocks/linear/_api.py +++ b/autogpt_platform/backend/backend/blocks/linear/_api.py @@ -162,8 +162,16 @@ class LinearClient: "searchTerm": team_name, } - team_id = await self.query(query, variables) - return team_id["teams"]["nodes"][0]["id"] + result = await self.query(query, variables) + nodes = result["teams"]["nodes"] + + if not nodes: + raise 
LinearAPIException( + f"Team '{team_name}' not found. Check the team name or key and try again.", + status_code=404, + ) + + return nodes[0]["id"] except LinearAPIException as e: raise e @@ -240,17 +248,44 @@ class LinearClient: except LinearAPIException as e: raise e - async def try_search_issues(self, term: str) -> list[Issue]: + async def try_search_issues( + self, + term: str, + max_results: int = 10, + team_id: str | None = None, + ) -> list[Issue]: try: query = """ - query SearchIssues($term: String!, $includeComments: Boolean!) { - searchIssues(term: $term, includeComments: $includeComments) { + query SearchIssues( + $term: String!, + $first: Int, + $teamId: String + ) { + searchIssues( + term: $term, + first: $first, + teamId: $teamId + ) { nodes { id identifier title description priority + createdAt + state { + id + name + type + } + project { + id + name + } + assignee { + id + name + } } } } @@ -258,7 +293,8 @@ class LinearClient: variables: dict[str, Any] = { "term": term, - "includeComments": True, + "first": max_results, + "teamId": team_id, } issues = await self.query(query, variables) diff --git a/autogpt_platform/backend/backend/blocks/linear/issues.py b/autogpt_platform/backend/backend/blocks/linear/issues.py index baac01214c..165178f8ee 100644 --- a/autogpt_platform/backend/backend/blocks/linear/issues.py +++ b/autogpt_platform/backend/backend/blocks/linear/issues.py @@ -17,7 +17,7 @@ from ._config import ( LinearScope, linear, ) -from .models import CreateIssueResponse, Issue +from .models import CreateIssueResponse, Issue, State class LinearCreateIssueBlock(Block): @@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block): description="Linear credentials with read permissions", required_scopes={LinearScope.READ}, ) + max_results: int = SchemaField( + description="Maximum number of results to return", + default=10, + ge=1, + le=100, + ) + team_name: str | None = SchemaField( + description="Optional team name to filter results (e.g., 'Internal', 
'Open Source')", + default=None, + ) class Output(BlockSchemaOutput): issues: list[Issue] = SchemaField(description="List of issues") + error: str = SchemaField(description="Error message if the search failed") def __init__(self): super().__init__( @@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block): description="Searches for issues on Linear", input_schema=self.Input, output_schema=self.Output, + categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING}, test_input={ "term": "Test issue", + "max_results": 10, + "team_name": None, "credentials": TEST_CREDENTIALS_INPUT_OAUTH, }, test_credentials=TEST_CREDENTIALS_OAUTH, @@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block): [ Issue( id="abc123", - identifier="abc123", + identifier="TST-123", title="Test issue", description="Test description", priority=1, + state=State( + id="state1", name="In Progress", type="started" + ), + createdAt="2026-01-15T10:00:00.000Z", ) ], ) @@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block): "search_issues": lambda *args, **kwargs: [ Issue( id="abc123", - identifier="abc123", + identifier="TST-123", title="Test issue", description="Test description", priority=1, + state=State(id="state1", name="In Progress", type="started"), + createdAt="2026-01-15T10:00:00.000Z", ) ] }, @@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block): async def search_issues( credentials: OAuth2Credentials | APIKeyCredentials, term: str, + max_results: int = 10, + team_name: str | None = None, ) -> list[Issue]: client = LinearClient(credentials=credentials) - response: list[Issue] = await client.try_search_issues(term=term) - return response + + # Resolve team name to ID if provided + # Raises LinearAPIException with descriptive message if team not found + team_id: str | None = None + if team_name: + team_id = await client.try_get_team_by_name(team_name=team_name) + + return await client.try_search_issues( + term=term, + max_results=max_results, + team_id=team_id, + ) async def 
run( self, @@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block): """Execute the issue search""" try: issues = await self.search_issues( - credentials=credentials, term=input_data.term + credentials=credentials, + term=input_data.term, + max_results=input_data.max_results, + team_name=input_data.team_name, ) yield "issues", issues except LinearAPIException as e: diff --git a/autogpt_platform/backend/backend/blocks/linear/models.py b/autogpt_platform/backend/backend/blocks/linear/models.py index bfeaa13656..dd1f603459 100644 --- a/autogpt_platform/backend/backend/blocks/linear/models.py +++ b/autogpt_platform/backend/backend/blocks/linear/models.py @@ -36,12 +36,21 @@ class Project(BaseModel): content: str | None = None +class State(BaseModel): + id: str + name: str + type: str | None = ( + None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled") + ) + + class Issue(BaseModel): id: str identifier: str title: str description: str | None priority: int + state: State | None = None project: Project | None = None createdAt: str | None = None comments: list[Comment] | None = None diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index fdcd7f3568..54295da1f1 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -32,7 +32,7 @@ from backend.data.model import ( from backend.integrations.providers import ProviderName from backend.util import json from backend.util.logging import TruncatedLogger -from backend.util.prompt import compress_prompt, estimate_token_count +from backend.util.prompt import compress_context, estimate_token_count from backend.util.text import TextFormatter logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]") @@ -115,7 +115,6 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta): CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101" CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" 
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001" - CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219" CLAUDE_3_HAIKU = "claude-3-haiku-20240307" # AI/ML API models AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo" @@ -280,9 +279,6 @@ MODEL_METADATA = { LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata( "anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2 ), # claude-haiku-4-5-20251001 - LlmModel.CLAUDE_3_7_SONNET: ModelMetadata( - "anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2 - ), # claude-3-7-sonnet-20250219 LlmModel.CLAUDE_3_HAIKU: ModelMetadata( "anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1 ), # claude-3-haiku-20240307 @@ -638,11 +634,18 @@ async def llm_call( context_window = llm_model.context_window if compress_prompt_to_fit: - prompt = compress_prompt( + result = await compress_context( messages=prompt, target_tokens=llm_model.context_window // 2, - lossy_ok=True, + client=None, # Truncation-only, no LLM summarization + reserve=0, # Caller handles response token budget separately ) + if result.error: + logger.warning( + f"Prompt compression did not meet target: {result.error}. " + f"Proceeding with {result.token_count} tokens." 
+ ) + prompt = result.messages # Calculate available tokens based on context window and input length estimated_input_tokens = estimate_token_count(prompt) diff --git a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py index be1d736962..91c096ffe4 100644 --- a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py +++ b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py @@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum): GPT41_MINI = "gpt-4.1-mini-2025-04-14" # Anthropic - CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219" + CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" @property def provider_name(self) -> str: @@ -137,7 +137,7 @@ class StagehandObserveBlock(Block): model: StagehandRecommendedLlmModel = SchemaField( title="LLM Model", description="LLM to use for Stagehand (provider is inferred)", - default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET, + default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET, advanced=False, ) model_credentials: AICredentials = AICredentialsField() @@ -182,10 +182,7 @@ class StagehandObserveBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -230,7 +227,7 @@ class StagehandActBlock(Block): model: StagehandRecommendedLlmModel = SchemaField( title="LLM Model", description="LLM to use for Stagehand (provider is inferred)", - default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET, + default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET, advanced=False, ) model_credentials: AICredentials = AICredentialsField() @@ -282,10 +279,7 @@ class StagehandActBlock(Block): **kwargs, ) 
-> BlockOutput: - logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"ACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -330,7 +324,7 @@ class StagehandExtractBlock(Block): model: StagehandRecommendedLlmModel = SchemaField( title="LLM Model", description="LLM to use for Stagehand (provider is inferred)", - default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET, + default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET, advanced=False, ) model_credentials: AICredentials = AICredentialsField() @@ -370,10 +364,7 @@ class StagehandExtractBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 8d9ecfff4c..eb9360b037 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -873,14 +873,13 @@ def is_block_auth_configured( async def initialize_blocks() -> None: - # First, sync all provider costs to blocks - # Imported here to avoid circular import from backend.sdk.cost_integration import sync_all_provider_costs + from backend.util.retry import func_retry sync_all_provider_costs() - for cls in get_blocks().values(): - block = cls() + @func_retry + async def sync_block_to_db(block: Block) -> None: existing_block = await AgentBlock.prisma().find_first( where={"OR": [{"id": block.id}, {"name": 
block.name}]} ) @@ -893,7 +892,7 @@ async def initialize_blocks() -> None: outputSchema=json.dumps(block.output_schema.jsonschema()), ) ) - continue + return input_schema = json.dumps(block.input_schema.jsonschema()) output_schema = json.dumps(block.output_schema.jsonschema()) @@ -913,6 +912,25 @@ async def initialize_blocks() -> None: }, ) + failed_blocks: list[str] = [] + for cls in get_blocks().values(): + block = cls() + try: + await sync_block_to_db(block) + except Exception as e: + logger.warning( + f"Failed to sync block {block.name} to database: {e}. " + "Block is still available in memory.", + exc_info=True, + ) + failed_blocks.append(block.name) + + if failed_blocks: + logger.error( + f"Failed to sync {len(failed_blocks)} block(s) to database: " + f"{', '.join(failed_blocks)}. These blocks are still available in memory." + ) + # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281 def get_block(block_id: str) -> AnyBlockSchema | None: diff --git a/autogpt_platform/backend/backend/data/block_cost_config.py b/autogpt_platform/backend/backend/data/block_cost_config.py index 96b9435c4f..c568fa2319 100644 --- a/autogpt_platform/backend/backend/data/block_cost_config.py +++ b/autogpt_platform/backend/backend/data/block_cost_config.py @@ -83,7 +83,6 @@ MODEL_COST: dict[LlmModel, int] = { LlmModel.CLAUDE_4_5_HAIKU: 4, LlmModel.CLAUDE_4_5_OPUS: 14, LlmModel.CLAUDE_4_5_SONNET: 9, - LlmModel.CLAUDE_3_7_SONNET: 5, LlmModel.CLAUDE_3_HAIKU: 1, LlmModel.AIML_API_QWEN2_5_72B: 1, LlmModel.AIML_API_LLAMA3_1_70B: 1, diff --git a/autogpt_platform/backend/backend/data/event_bus.py b/autogpt_platform/backend/backend/data/event_bus.py index d8a1c5b729..614fb158b2 100644 --- a/autogpt_platform/backend/backend/data/event_bus.py +++ b/autogpt_platform/backend/backend/data/event_bus.py @@ -133,10 +133,23 @@ class RedisEventBus(BaseRedisEventBus[M], ABC): class AsyncRedisEventBus(BaseRedisEventBus[M], ABC): + def __init__(self): + self._pubsub: 
AsyncPubSub | None = None + @property async def connection(self) -> redis.AsyncRedis: return await redis.get_redis_async() + async def close(self) -> None: + """Close the PubSub connection if it exists.""" + if self._pubsub is not None: + try: + await self._pubsub.close() + except Exception: + logger.warning("Failed to close PubSub connection", exc_info=True) + finally: + self._pubsub = None + async def publish_event(self, event: M, channel_key: str): """ Publish an event to Redis. Gracefully handles connection failures @@ -157,6 +170,7 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC): await self.connection, channel_key ) assert isinstance(pubsub, AsyncPubSub) + self._pubsub = pubsub if "*" in channel_key: await pubsub.psubscribe(full_channel_name) diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index c1f38f81d5..ee6cd2e4b0 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -1028,6 +1028,39 @@ async def get_graph( return GraphModel.from_db(graph, for_export) +async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]: + """Batch-fetch multiple store-listed graphs by their IDs. + + Only returns graphs that have approved store listings (publicly available). + Does not require permission checks since store-listed graphs are public. 
+ + Args: + *graph_ids: Variable number of graph IDs to fetch + + Returns: + Dict mapping graph_id to GraphModel for graphs with approved store listings + """ + if not graph_ids: + return {} + + store_listings = await StoreListingVersion.prisma().find_many( + where={ + "agentGraphId": {"in": list(graph_ids)}, + "submissionStatus": SubmissionStatus.APPROVED, + "isDeleted": False, + }, + include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}}, + distinct=["agentGraphId"], + order={"agentGraphVersion": "desc"}, + ) + + return { + listing.agentGraphId: GraphModel.from_db(listing.AgentGraph) + for listing in store_listings + if listing.AgentGraph + } + + async def get_graph_as_admin( graph_id: str, version: int | None = None, diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index 2cc73f6b7b..5a09c591c9 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -19,7 +19,6 @@ from typing import ( cast, get_args, ) -from urllib.parse import urlparse from uuid import uuid4 from prisma.enums import CreditTransactionType, OnboardingStep @@ -42,6 +41,7 @@ from typing_extensions import TypedDict from backend.integrations.providers import ProviderName from backend.util.json import loads as json_loads +from backend.util.request import parse_url from backend.util.settings import Secrets # Type alias for any provider name (including custom ones) @@ -397,19 +397,25 @@ class HostScopedCredentials(_BaseCredentials): def matches_url(self, url: str) -> bool: """Check if this credential should be applied to the given URL.""" - parsed_url = urlparse(url) - # Extract hostname without port - request_host = parsed_url.hostname + request_host, request_port = _extract_host_from_url(url) + cred_scope_host, cred_scope_port = _extract_host_from_url(self.host) if not request_host: return False - # Simple host matching - exact match or wildcard subdomain match - if self.host == 
request_host: + # If a port is specified in credential host, the request host port must match + if cred_scope_port is not None and request_port != cred_scope_port: + return False + # Non-standard ports are only allowed if explicitly specified in credential host + elif cred_scope_port is None and request_port not in (80, 443, None): + return False + + # Simple host matching + if cred_scope_host == request_host: return True # Support wildcard matching (e.g., "*.example.com" matches "api.example.com") - if self.host.startswith("*."): - domain = self.host[2:] # Remove "*." + if cred_scope_host.startswith("*."): + domain = cred_scope_host[2:] # Remove "*." return request_host.endswith(f".{domain}") or request_host == domain return False @@ -551,13 +557,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): ) -def _extract_host_from_url(url: str) -> str: - """Extract host from URL for grouping host-scoped credentials.""" +def _extract_host_from_url(url: str) -> tuple[str, int | None]: + """Extract host and port from URL for grouping host-scoped credentials.""" try: - parsed = urlparse(url) - return parsed.hostname or url + parsed = parse_url(url) + return parsed.hostname or url, parsed.port except Exception: - return "" + return "", None class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): @@ -606,7 +612,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): providers = frozenset( [cast(CP, "http")] + [ - cast(CP, _extract_host_from_url(str(value))) + cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values ] ) @@ -666,10 +672,16 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): if not (self.discriminator and self.discriminator_mapping): return self + try: + provider = self.discriminator_mapping[discriminator_value] + except KeyError: + raise ValueError( + f"Model '{discriminator_value}' is not supported. " + "It may have been deprecated. Please update your agent configuration." 
+ ) + return CredentialsFieldInfo( - credentials_provider=frozenset( - [self.discriminator_mapping[discriminator_value]] - ), + credentials_provider=frozenset([provider]), credentials_types=self.supported_types, credentials_scopes=self.required_scopes, discriminator=self.discriminator, diff --git a/autogpt_platform/backend/backend/data/model_test.py b/autogpt_platform/backend/backend/data/model_test.py index 37ec6be82f..e8e2ddfa35 100644 --- a/autogpt_platform/backend/backend/data/model_test.py +++ b/autogpt_platform/backend/backend/data/model_test.py @@ -79,10 +79,23 @@ class TestHostScopedCredentials: headers={"Authorization": SecretStr("Bearer token")}, ) - assert creds.matches_url("http://localhost:8080/api/v1") + # Non-standard ports require explicit port in credential host + assert not creds.matches_url("http://localhost:8080/api/v1") assert creds.matches_url("https://localhost:443/secure/endpoint") assert creds.matches_url("http://localhost/simple") + def test_matches_url_with_explicit_port(self): + """Test URL matching with explicit port in credential host.""" + creds = HostScopedCredentials( + provider="custom", + host="localhost:8080", + headers={"Authorization": SecretStr("Bearer token")}, + ) + + assert creds.matches_url("http://localhost:8080/api/v1") + assert not creds.matches_url("http://localhost:3000/api/v1") + assert not creds.matches_url("http://localhost/simple") + def test_empty_headers_dict(self): """Test HostScopedCredentials with empty headers.""" creds = HostScopedCredentials( @@ -128,8 +141,20 @@ class TestHostScopedCredentials: ("*.example.com", "https://sub.api.example.com/test", True), ("*.example.com", "https://example.com/test", True), ("*.example.com", "https://example.org/test", False), - ("localhost", "http://localhost:3000/test", True), + # Non-standard ports require explicit port in credential host + ("localhost", "http://localhost:3000/test", False), + ("localhost:3000", "http://localhost:3000/test", True), ("localhost", 
"http://127.0.0.1:3000/test", False), + # IPv6 addresses (frontend stores with brackets via URL.hostname) + ("[::1]", "http://[::1]/test", True), + ("[::1]", "http://[::1]:80/test", True), + ("[::1]", "https://[::1]:443/test", True), + ("[::1]", "http://[::1]:8080/test", False), # Non-standard port + ("[::1]:8080", "http://[::1]:8080/test", True), + ("[::1]:8080", "http://[::1]:9090/test", False), + ("[2001:db8::1]", "http://[2001:db8::1]/path", True), + ("[2001:db8::1]", "https://[2001:db8::1]:443/path", True), + ("[2001:db8::1]", "http://[2001:db8::ff]/path", False), ], ) def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool): diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index ae7474fc1d..d44439d51c 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -17,6 +17,7 @@ from backend.data.analytics import ( get_accuracy_trends_and_alerts, get_marketplace_graphs_for_monitoring, ) +from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.credit import UsageTransactionMetadata, get_user_credit_model from backend.data.execution import ( create_graph_execution, @@ -219,6 +220,9 @@ class DatabaseManager(AppService): # Onboarding increment_onboarding_runs = _(increment_onboarding_runs) + # OAuth + cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens) + # Store get_store_agents = _(get_store_agents) get_store_agent_details = _(get_store_agent_details) @@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient): # Onboarding increment_onboarding_runs = d.increment_onboarding_runs + # OAuth + cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens + # Store get_store_agents = d.get_store_agents get_store_agent_details = d.get_store_agent_details diff --git a/autogpt_platform/backend/backend/executor/scheduler.py 
b/autogpt_platform/backend/backend/executor/scheduler.py index 44b77fc018..cbdc441718 100644 --- a/autogpt_platform/backend/backend/executor/scheduler.py +++ b/autogpt_platform/backend/backend/executor/scheduler.py @@ -24,11 +24,9 @@ from dotenv import load_dotenv from pydantic import BaseModel, Field, ValidationError from sqlalchemy import MetaData, create_engine -from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.block import BlockInput from backend.data.execution import GraphExecutionWithNodes from backend.data.model import CredentialsMetaInput -from backend.data.onboarding import increment_onboarding_runs from backend.executor import utils as execution_utils from backend.monitoring import ( NotificationJobArgs, @@ -38,7 +36,11 @@ from backend.monitoring import ( report_execution_accuracy_alerts, report_late_executions, ) -from backend.util.clients import get_database_manager_client, get_scheduler_client +from backend.util.clients import ( + get_database_manager_async_client, + get_database_manager_client, + get_scheduler_client, +) from backend.util.cloud_storage import cleanup_expired_files_async from backend.util.exceptions import ( GraphNotFoundError, @@ -148,6 +150,7 @@ def execute_graph(**kwargs): async def _execute_graph(**kwargs): args = GraphExecutionJobArgs(**kwargs) start_time = asyncio.get_event_loop().time() + db = get_database_manager_async_client() try: logger.info(f"Executing recurring job for graph #{args.graph_id}") graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution( @@ -157,7 +160,7 @@ async def _execute_graph(**kwargs): inputs=args.input_data, graph_credentials_inputs=args.input_credentials, ) - await increment_onboarding_runs(args.user_id) + await db.increment_onboarding_runs(args.user_id) elapsed = asyncio.get_event_loop().time() - start_time logger.info( f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} " @@ -246,8 +249,13 @@ def cleanup_expired_files(): 
def cleanup_oauth_tokens(): """Clean up expired OAuth tokens from the database.""" + # Wait for completion - run_async(cleanup_expired_oauth_tokens()) + async def _cleanup(): + db = get_database_manager_async_client() + return await db.cleanup_expired_oauth_tokens() + + run_async(_cleanup()) def execution_accuracy_alerts(): diff --git a/autogpt_platform/backend/backend/integrations/webhooks/utils_test.py b/autogpt_platform/backend/backend/integrations/webhooks/utils_test.py new file mode 100644 index 0000000000..bc502a8e44 --- /dev/null +++ b/autogpt_platform/backend/backend/integrations/webhooks/utils_test.py @@ -0,0 +1,39 @@ +from urllib.parse import urlparse + +import fastapi +from fastapi.routing import APIRoute + +from backend.api.features.integrations.router import router as integrations_router +from backend.integrations.providers import ProviderName +from backend.integrations.webhooks import utils as webhooks_utils + + +def test_webhook_ingress_url_matches_route(monkeypatch) -> None: + app = fastapi.FastAPI() + app.include_router(integrations_router, prefix="/api/integrations") + + provider = ProviderName.GITHUB + webhook_id = "webhook_123" + base_url = "https://example.com" + + monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url) + + route = next( + route + for route in integrations_router.routes + if isinstance(route, APIRoute) + and route.path == "/{provider}/webhooks/{webhook_id}/ingress" + and "POST" in route.methods + ) + expected_path = f"/api/integrations{route.path}".format( + provider=provider.value, + webhook_id=webhook_id, + ) + actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id)) + expected_base = urlparse(base_url) + + assert (actual_url.scheme, actual_url.netloc) == ( + expected_base.scheme, + expected_base.netloc, + ) + assert actual_url.path == expected_path diff --git a/autogpt_platform/backend/backend/util/prompt.py b/autogpt_platform/backend/backend/util/prompt.py index 
775d1c932b..5f904bbc8a 100644 --- a/autogpt_platform/backend/backend/util/prompt.py +++ b/autogpt_platform/backend/backend/util/prompt.py @@ -1,10 +1,19 @@ +from __future__ import annotations + +import logging from copy import deepcopy -from typing import Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any from tiktoken import encoding_for_model from backend.util import json +if TYPE_CHECKING: + from openai import AsyncOpenAI + +logger = logging.getLogger(__name__) + # ---------------------------------------------------------------------------# # CONSTANTS # # ---------------------------------------------------------------------------# @@ -100,9 +109,17 @@ def _is_objective_message(msg: dict) -> bool: def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None: """ Carefully truncate tool message content while preserving tool structure. - Only truncates tool_result content, leaves tool_use intact. + Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages. """ content = msg.get("content") + + # OpenAI-style tool message: role="tool" with string content + if msg.get("role") == "tool" and isinstance(content, str): + if _tok_len(content, enc) > max_tokens: + msg["content"] = _truncate_middle_tokens(content, enc, max_tokens) + return + + # Anthropic-style: list content with tool_result items if not isinstance(content, list): return @@ -140,141 +157,6 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str: # ---------------------------------------------------------------------------# -def compress_prompt( - messages: list[dict], - target_tokens: int, - *, - model: str = "gpt-4o", - reserve: int = 2_048, - start_cap: int = 8_192, - floor_cap: int = 128, - lossy_ok: bool = True, -) -> list[dict]: - """ - Shrink *messages* so that:: - - token_count(prompt) + reserve ≤ target_tokens - - Strategy - -------- - 1. 
**Token-aware truncation** – progressively halve a per-message cap - (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the - *content* of every message except the first and last. Tool shells - are included: we keep the envelope but shorten huge payloads. - 2. **Middle-out deletion** – if still over the limit, delete whole - messages working outward from the centre, **skipping** any message - that contains ``tool_calls`` or has ``role == "tool"``. - 3. **Last-chance trim** – if still too big, truncate the *first* and - *last* message bodies down to `floor_cap` tokens. - 4. If the prompt is *still* too large: - • raise ``ValueError`` when ``lossy_ok == False`` (default) - • return the partially-trimmed prompt when ``lossy_ok == True`` - - Parameters - ---------- - messages Complete chat history (will be deep-copied). - model Model name; passed to tiktoken to pick the right - tokenizer (gpt-4o → 'o200k_base', others fallback). - target_tokens Hard ceiling for prompt size **excluding** the model's - forthcoming answer. - reserve How many tokens you want to leave available for that - answer (`max_tokens` in your subsequent completion call). - start_cap Initial per-message truncation ceiling (tokens). - floor_cap Lowest cap we'll accept before moving to deletions. - lossy_ok If *True* return best-effort prompt instead of raising - after all trim passes have been exhausted. - - Returns - ------- - list[dict] – A *new* messages list that abides by the rules above. - """ - enc = encoding_for_model(model) # best-match tokenizer - msgs = deepcopy(messages) # never mutate caller - - def total_tokens() -> int: - """Current size of *msgs* in tokens.""" - return sum(_msg_tokens(m, enc) for m in msgs) - - original_token_count = total_tokens() - - if original_token_count + reserve <= target_tokens: - return msgs - - # ---- STEP 0 : normalise content -------------------------------------- - # Convert non-string payloads to strings so token counting is coherent. 
- for i, m in enumerate(msgs): - if not isinstance(m.get("content"), str) and m.get("content") is not None: - if _is_tool_message(m): - continue - - # Keep first and last messages intact (unless they're tool messages) - if i == 0 or i == len(msgs) - 1: - continue - - # Reasonable 20k-char ceiling prevents pathological blobs - content_str = json.dumps(m["content"], separators=(",", ":")) - if len(content_str) > 20_000: - content_str = _truncate_middle_tokens(content_str, enc, 20_000) - m["content"] = content_str - - # ---- STEP 1 : token-aware truncation --------------------------------- - cap = start_cap - while total_tokens() + reserve > target_tokens and cap >= floor_cap: - for m in msgs[1:-1]: # keep first & last intact - if _is_tool_message(m): - # For tool messages, only truncate tool result content, preserve structure - _truncate_tool_message_content(m, enc, cap) - continue - - if _is_objective_message(m): - # Never truncate objective messages - they contain the core task - continue - - content = m.get("content") or "" - if _tok_len(content, enc) > cap: - m["content"] = _truncate_middle_tokens(content, enc, cap) - cap //= 2 # tighten the screw - - # ---- STEP 2 : middle-out deletion ----------------------------------- - while total_tokens() + reserve > target_tokens and len(msgs) > 2: - # Identify all deletable messages (not first/last, not tool messages, not objective messages) - deletable_indices = [] - for i in range(1, len(msgs) - 1): # Skip first and last - if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]): - deletable_indices.append(i) - - if not deletable_indices: - break # nothing more we can drop - - # Delete from center outward - find the index closest to center - centre = len(msgs) // 2 - to_delete = min(deletable_indices, key=lambda i: abs(i - centre)) - del msgs[to_delete] - - # ---- STEP 3 : final safety-net trim on first & last ------------------ - cap = start_cap - while total_tokens() + reserve > target_tokens and cap >= 
floor_cap: - for idx in (0, -1): # first and last - if _is_tool_message(msgs[idx]): - # For tool messages at first/last position, truncate tool result content only - _truncate_tool_message_content(msgs[idx], enc, cap) - continue - - text = msgs[idx].get("content") or "" - if _tok_len(text, enc) > cap: - msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap) - cap //= 2 # tighten the screw - - # ---- STEP 4 : success or fail-gracefully ----------------------------- - if total_tokens() + reserve > target_tokens and not lossy_ok: - raise ValueError( - "compress_prompt: prompt still exceeds budget " - f"({total_tokens() + reserve} > {target_tokens})." - ) - - return msgs - - def estimate_token_count( messages: list[dict], *, @@ -293,7 +175,8 @@ def estimate_token_count( ------- int – Token count. """ - enc = encoding_for_model(model) # best-match tokenizer + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) return sum(_msg_tokens(m, enc) for m in messages) @@ -315,6 +198,543 @@ def estimate_token_count_str( ------- int – Token count. 
""" - enc = encoding_for_model(model) # best-match tokenizer + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) text = json.dumps(text) if not isinstance(text, str) else text return _tok_len(text, enc) + + +# ---------------------------------------------------------------------------# +# UNIFIED CONTEXT COMPRESSION # +# ---------------------------------------------------------------------------# + +# Default thresholds +DEFAULT_TOKEN_THRESHOLD = 120_000 +DEFAULT_KEEP_RECENT = 15 + + +@dataclass +class CompressResult: + """Result of context compression.""" + + messages: list[dict] + token_count: int + was_compacted: bool + error: str | None = None + original_token_count: int = 0 + messages_summarized: int = 0 + messages_dropped: int = 0 + + +def _normalize_model_for_tokenizer(model: str) -> str: + """Normalize model name for tiktoken tokenizer selection.""" + if "/" in model: + model = model.split("/")[-1] + if "claude" in model.lower() or not any( + known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"] + ): + return "gpt-4o" + return model + + +def _extract_tool_call_ids_from_message(msg: dict) -> set[str]: + """ + Extract tool_call IDs from an assistant message. + + Supports both formats: + - OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]} + - Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]} + + Returns: + Set of tool_call IDs found in the message. 
+ """ + ids: set[str] = set() + if msg.get("role") != "assistant": + return ids + + # OpenAI format: tool_calls array + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + tc_id = tc.get("id") + if tc_id: + ids.add(tc_id) + + # Anthropic format: content list with tool_use blocks + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_use": + tc_id = block.get("id") + if tc_id: + ids.add(tc_id) + + return ids + + +def _extract_tool_response_ids_from_message(msg: dict) -> set[str]: + """ + Extract tool_call IDs that this message is responding to. + + Supports both formats: + - OpenAI: {"role": "tool", "tool_call_id": "..."} + - Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]} + + Returns: + Set of tool_call IDs this message responds to. + """ + ids: set[str] = set() + + # OpenAI format: role=tool with tool_call_id + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id") + if tc_id: + ids.add(tc_id) + + # Anthropic format: content list with tool_result blocks + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + tc_id = block.get("tool_use_id") + if tc_id: + ids.add(tc_id) + + return ids + + +def _is_tool_response_message(msg: dict) -> bool: + """Check if message is a tool response (OpenAI or Anthropic format).""" + # OpenAI format + if msg.get("role") == "tool": + return True + # Anthropic format + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + return True + return False + + +def _remove_orphan_tool_responses( + messages: list[dict], orphan_ids: set[str] +) -> list[dict]: + """ + Remove tool response messages/blocks that reference orphan tool_call IDs. + + Supports both OpenAI and Anthropic formats. 
+ For Anthropic messages with mixed valid/orphan tool_result blocks, + filters out only the orphan blocks instead of dropping the entire message. + """ + result = [] + for msg in messages: + # OpenAI format: role=tool - drop entire message if orphan + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id") + if tc_id and tc_id in orphan_ids: + continue + result.append(msg) + continue + + # Anthropic format: content list may have mixed tool_result blocks + content = msg.get("content") + if isinstance(content, list): + has_tool_results = any( + isinstance(b, dict) and b.get("type") == "tool_result" for b in content + ) + if has_tool_results: + # Filter out orphan tool_result blocks, keep valid ones + filtered_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") == "tool_result" + and block.get("tool_use_id") in orphan_ids + ) + ] + # Only keep message if it has remaining content + if filtered_content: + msg = msg.copy() + msg["content"] = filtered_content + result.append(msg) + continue + + result.append(msg) + return result + + +def _ensure_tool_pairs_intact( + recent_messages: list[dict], + all_messages: list[dict], + start_index: int, +) -> list[dict]: + """ + Ensure tool_call/tool_response pairs stay together after slicing. + + When slicing messages for context compaction, a naive slice can separate + an assistant message containing tool_calls from its corresponding tool + response messages. This causes API validation errors (e.g., Anthropic's + "unexpected tool_use_id found in tool_result blocks"). + + This function checks for orphan tool responses in the slice and extends + backwards to include their corresponding assistant messages. 
+ + Supports both formats: + - OpenAI: tool_calls array + role="tool" responses + - Anthropic: tool_use blocks + tool_result blocks + + Args: + recent_messages: The sliced messages to validate + all_messages: The complete message list (for looking up missing assistants) + start_index: The index in all_messages where recent_messages begins + + Returns: + A potentially extended list of messages with tool pairs intact + """ + if not recent_messages: + return recent_messages + + # Collect all tool_call_ids from assistant messages in the slice + available_tool_call_ids: set[str] = set() + for msg in recent_messages: + available_tool_call_ids |= _extract_tool_call_ids_from_message(msg) + + # Find orphan tool responses (responses whose tool_call_id is missing) + orphan_tool_call_ids: set[str] = set() + for msg in recent_messages: + response_ids = _extract_tool_response_ids_from_message(msg) + for tc_id in response_ids: + if tc_id not in available_tool_call_ids: + orphan_tool_call_ids.add(tc_id) + + if not orphan_tool_call_ids: + # No orphans, slice is valid + return recent_messages + + # Find the assistant messages that contain the orphan tool_call_ids + # Search backwards from start_index in all_messages + messages_to_prepend: list[dict] = [] + for i in range(start_index - 1, -1, -1): + msg = all_messages[i] + msg_tool_ids = _extract_tool_call_ids_from_message(msg) + if msg_tool_ids & orphan_tool_call_ids: + # This assistant message has tool_calls we need + # Also collect its contiguous tool responses that follow it + assistant_and_responses: list[dict] = [msg] + + # Scan forward from this assistant to collect tool responses + for j in range(i + 1, start_index): + following_msg = all_messages[j] + following_response_ids = _extract_tool_response_ids_from_message( + following_msg + ) + if following_response_ids and following_response_ids & msg_tool_ids: + assistant_and_responses.append(following_msg) + elif not _is_tool_response_message(following_msg): + # Stop at first 
non-tool-response message + break + + # Prepend the assistant and its tool responses (maintain order) + messages_to_prepend = assistant_and_responses + messages_to_prepend + # Mark these as found + orphan_tool_call_ids -= msg_tool_ids + # Also add this assistant's tool_call_ids to available set + available_tool_call_ids |= msg_tool_ids + + if not orphan_tool_call_ids: + # Found all missing assistants + break + + if orphan_tool_call_ids: + # Some tool_call_ids couldn't be resolved - remove those tool responses + # This shouldn't happen in normal operation but handles edge cases + logger.warning( + f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. " + "Removing orphan tool responses." + ) + recent_messages = _remove_orphan_tool_responses( + recent_messages, orphan_tool_call_ids + ) + + if messages_to_prepend: + logger.info( + f"Extended recent messages by {len(messages_to_prepend)} to preserve " + f"tool_call/tool_response pairs" + ) + return messages_to_prepend + recent_messages + + return recent_messages + + +async def _summarize_messages_llm( + messages: list[dict], + client: AsyncOpenAI, + model: str, + timeout: float = 30.0, +) -> str: + """Summarize messages using an LLM.""" + conversation = [] + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + if content and role in ("user", "assistant", "tool"): + conversation.append(f"{role.upper()}: {content}") + + conversation_text = "\n\n".join(conversation) + + if not conversation_text: + return "No conversation history available." + + # Limit to ~100k chars for safety + MAX_CHARS = 100_000 + if len(conversation_text) > MAX_CHARS: + conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]" + + response = await client.with_options(timeout=timeout).chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": ( + "Create a detailed summary of the conversation so far. 
" + "This summary will be used as context when continuing the conversation.\n\n" + "Before writing the summary, analyze each message chronologically to identify:\n" + "- User requests and their explicit goals\n" + "- Your approach and key decisions made\n" + "- Technical specifics (file names, tool outputs, function signatures)\n" + "- Errors encountered and resolutions applied\n\n" + "You MUST include ALL of the following sections:\n\n" + "## 1. Primary Request and Intent\n" + "The user's explicit goals and what they are trying to accomplish.\n\n" + "## 2. Key Technical Concepts\n" + "Technologies, frameworks, tools, and patterns being used or discussed.\n\n" + "## 3. Files and Resources Involved\n" + "Specific files examined or modified, with relevant snippets and identifiers.\n\n" + "## 4. Errors and Fixes\n" + "Problems encountered, error messages, and their resolutions. " + "Include any user feedback on fixes.\n\n" + "## 5. Problem Solving\n" + "Issues that have been resolved and how they were addressed.\n\n" + "## 6. All User Messages\n" + "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n" + "## 7. Pending Tasks\n" + "Work items the user explicitly requested that have not yet been completed.\n\n" + "## 8. Current Work\n" + "Precise description of what was being worked on most recently, including relevant context.\n\n" + "## 9. Next Steps\n" + "What should happen next, aligned with the user's most recent requests. " + "Include verbatim quotes of recent instructions if relevant." + ), + }, + {"role": "user", "content": f"Summarize:\n\n{conversation_text}"}, + ], + max_tokens=1500, + temperature=0.3, + ) + + return response.choices[0].message.content or "No summary available." 
+ + +async def compress_context( + messages: list[dict], + target_tokens: int = DEFAULT_TOKEN_THRESHOLD, + *, + model: str = "gpt-4o", + client: AsyncOpenAI | None = None, + keep_recent: int = DEFAULT_KEEP_RECENT, + reserve: int = 2_048, + start_cap: int = 8_192, + floor_cap: int = 128, +) -> CompressResult: + """ + Unified context compression that combines summarization and truncation strategies. + + Strategy (in order): + 1. **LLM summarization** – If client provided, summarize old messages into a + single context message while keeping recent messages intact. This is the + primary strategy for chat service. + 2. **Content truncation** – Progressively halve a per-message cap and truncate + bloated message content (tool outputs, large pastes). Preserves all messages + but shortens their content. Primary strategy when client=None (LLM blocks). + 3. **Middle-out deletion** – Delete whole messages one at a time from the center + outward, skipping tool messages and objective messages. + 4. **First/last trim** – Truncate first and last message content as last resort. + + Parameters + ---------- + messages Complete chat history (will be deep-copied). + target_tokens Hard ceiling for prompt size. + model Model name for tokenization and summarization. + client AsyncOpenAI client. If provided, enables LLM summarization + as the first strategy. If None, skips to truncation strategies. + keep_recent Number of recent messages to preserve during summarization. + reserve Tokens to reserve for model response. + start_cap Initial per-message truncation ceiling (tokens). + floor_cap Lowest cap before moving to deletions. + + Returns + ------- + CompressResult with compressed messages and metadata. 
+ """ + # Guard clause for empty messages + if not messages: + return CompressResult( + messages=[], + token_count=0, + was_compacted=False, + original_token_count=0, + ) + + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) + msgs = deepcopy(messages) + + def total_tokens() -> int: + return sum(_msg_tokens(m, enc) for m in msgs) + + original_count = total_tokens() + + # Already under limit + if original_count + reserve <= target_tokens: + return CompressResult( + messages=msgs, + token_count=original_count, + was_compacted=False, + original_token_count=original_count, + ) + + messages_summarized = 0 + messages_dropped = 0 + + # ---- STEP 1: LLM summarization (if client provided) ------------------- + # This is the primary compression strategy for chat service. + # Summarize old messages while keeping recent ones intact. + if client is not None: + has_system = len(msgs) > 0 and msgs[0].get("role") == "system" + system_msg = msgs[0] if has_system else None + + # Calculate old vs recent messages + if has_system: + if len(msgs) > keep_recent + 1: + old_msgs = msgs[1:-keep_recent] + recent_msgs = msgs[-keep_recent:] + else: + old_msgs = [] + recent_msgs = msgs[1:] if len(msgs) > 1 else [] + else: + if len(msgs) > keep_recent: + old_msgs = msgs[:-keep_recent] + recent_msgs = msgs[-keep_recent:] + else: + old_msgs = [] + recent_msgs = msgs + + # Ensure tool pairs stay intact + slice_start = max(0, len(msgs) - keep_recent) + recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start) + + if old_msgs: + try: + summary_text = await _summarize_messages_llm(old_msgs, client, model) + summary_msg = { + "role": "assistant", + "content": f"[Previous conversation summary — for context only]: {summary_text}", + } + messages_summarized = len(old_msgs) + + if has_system: + msgs = [system_msg, summary_msg] + recent_msgs + else: + msgs = [summary_msg] + recent_msgs + + logger.info( + f"Context summarized: {original_count} -> 
{total_tokens()} tokens, " + f"summarized {messages_summarized} messages" + ) + except Exception as e: + logger.warning(f"Summarization failed, continuing with truncation: {e}") + # Fall through to content truncation + + # ---- STEP 2: Normalize content ---------------------------------------- + # Convert non-string payloads to strings so token counting is coherent. + # Always run this before truncation to ensure consistent token counting. + for i, m in enumerate(msgs): + if not isinstance(m.get("content"), str) and m.get("content") is not None: + if _is_tool_message(m): + continue + if i == 0 or i == len(msgs) - 1: + continue + content_str = json.dumps(m["content"], separators=(",", ":")) + if len(content_str) > 20_000: + content_str = _truncate_middle_tokens(content_str, enc, 20_000) + m["content"] = content_str + + # ---- STEP 3: Token-aware content truncation --------------------------- + # Progressively halve per-message cap and truncate bloated content. + # This preserves all messages but shortens their content. + cap = start_cap + while total_tokens() + reserve > target_tokens and cap >= floor_cap: + for m in msgs[1:-1]: + if _is_tool_message(m): + _truncate_tool_message_content(m, enc, cap) + continue + if _is_objective_message(m): + continue + content = m.get("content") or "" + if _tok_len(content, enc) > cap: + m["content"] = _truncate_middle_tokens(content, enc, cap) + cap //= 2 + + # ---- STEP 4: Middle-out deletion -------------------------------------- + # Delete messages one at a time from the center outward. + # This is more granular than dropping all old messages at once. 
+ while total_tokens() + reserve > target_tokens and len(msgs) > 2: + deletable: list[int] = [] + for i in range(1, len(msgs) - 1): + msg = msgs[i] + if ( + msg is not None + and not _is_tool_message(msg) + and not _is_objective_message(msg) + ): + deletable.append(i) + if not deletable: + break + centre = len(msgs) // 2 + to_delete = min(deletable, key=lambda i: abs(i - centre)) + del msgs[to_delete] + messages_dropped += 1 + + # ---- STEP 5: Final trim on first/last --------------------------------- + cap = start_cap + while total_tokens() + reserve > target_tokens and cap >= floor_cap: + for idx in (0, -1): + msg = msgs[idx] + if msg is None: + continue + if _is_tool_message(msg): + _truncate_tool_message_content(msg, enc, cap) + continue + text = msg.get("content") or "" + if _tok_len(text, enc) > cap: + msg["content"] = _truncate_middle_tokens(text, enc, cap) + cap //= 2 + + # Filter out any None values that may have been introduced + final_msgs: list[dict] = [m for m in msgs if m is not None] + final_count = sum(_msg_tokens(m, enc) for m in final_msgs) + error = None + if final_count + reserve > target_tokens: + error = f"Could not compress below target ({final_count + reserve} > {target_tokens})" + logger.warning(error) + + return CompressResult( + messages=final_msgs, + token_count=final_count, + was_compacted=True, + error=error, + original_token_count=original_count, + messages_summarized=messages_summarized, + messages_dropped=messages_dropped, + ) diff --git a/autogpt_platform/backend/backend/util/prompt_test.py b/autogpt_platform/backend/backend/util/prompt_test.py index af6b230f8f..2d4bf090b3 100644 --- a/autogpt_platform/backend/backend/util/prompt_test.py +++ b/autogpt_platform/backend/backend/util/prompt_test.py @@ -1,10 +1,21 @@ """Tests for prompt utility functions, especially tool call token counting.""" +from unittest.mock import AsyncMock, MagicMock + import pytest from tiktoken import encoding_for_model from backend.util import json -from 
backend.util.prompt import _msg_tokens, estimate_token_count +from backend.util.prompt import ( + CompressResult, + _ensure_tool_pairs_intact, + _msg_tokens, + _normalize_model_for_tokenizer, + _truncate_middle_tokens, + _truncate_tool_message_content, + compress_context, + estimate_token_count, +) class TestMsgTokens: @@ -276,3 +287,690 @@ class TestEstimateTokenCount: assert total_tokens == expected_total assert total_tokens > 20 # Should be substantial + + +class TestNormalizeModelForTokenizer: + """Test model name normalization for tiktoken.""" + + def test_openai_models_unchanged(self): + """Test that OpenAI models are returned as-is.""" + assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o" + assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4" + assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo" + + def test_claude_models_normalized(self): + """Test that Claude models are normalized to gpt-4o.""" + assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o" + assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o" + assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o" + + def test_openrouter_paths_extracted(self): + """Test that OpenRouter model paths are handled.""" + assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o" + assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o" + + def test_unknown_models_default_to_gpt4o(self): + """Test that unknown models default to gpt-4o.""" + assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o" + assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o" + + +class TestTruncateToolMessageContent: + """Test tool message content truncation.""" + + @pytest.fixture + def enc(self): + return encoding_for_model("gpt-4o") + + def test_truncate_openai_tool_message(self, enc): + """Test truncation of OpenAI-style tool message with string content.""" + long_content = "x" * 10000 + msg = 
{"role": "tool", "tool_call_id": "call_123", "content": long_content} + + _truncate_tool_message_content(msg, enc, max_tokens=100) + + # Content should be truncated + assert len(msg["content"]) < len(long_content) + assert "…" in msg["content"] # Has ellipsis marker + + def test_truncate_anthropic_tool_result(self, enc): + """Test truncation of Anthropic-style tool_result.""" + long_content = "y" * 10000 + msg = { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_123", + "content": long_content, + } + ], + } + + _truncate_tool_message_content(msg, enc, max_tokens=100) + + # Content should be truncated + result_content = msg["content"][0]["content"] + assert len(result_content) < len(long_content) + assert "…" in result_content + + def test_preserve_tool_use_blocks(self, enc): + """Test that tool_use blocks are not truncated.""" + msg = { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_123", + "name": "some_function", + "input": {"key": "value" * 1000}, # Large input + } + ], + } + + original = json.dumps(msg["content"][0]["input"]) + _truncate_tool_message_content(msg, enc, max_tokens=10) + + # tool_use should be unchanged + assert json.dumps(msg["content"][0]["input"]) == original + + def test_no_truncation_when_under_limit(self, enc): + """Test that short content is not modified.""" + msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"} + + original = msg["content"] + _truncate_tool_message_content(msg, enc, max_tokens=1000) + + assert msg["content"] == original + + +class TestTruncateMiddleTokens: + """Test middle truncation of text.""" + + @pytest.fixture + def enc(self): + return encoding_for_model("gpt-4o") + + def test_truncates_long_text(self, enc): + """Test that long text is truncated with ellipsis in middle.""" + long_text = "word " * 1000 + result = _truncate_middle_tokens(long_text, enc, max_tok=50) + + assert len(enc.encode(result)) <= 52 # Allow some slack 
for ellipsis + assert "…" in result + assert result.startswith("word") # Head preserved + assert result.endswith("word ") # Tail preserved + + def test_preserves_short_text(self, enc): + """Test that short text is not modified.""" + short_text = "Hello world" + result = _truncate_middle_tokens(short_text, enc, max_tok=100) + + assert result == short_text + + +class TestEnsureToolPairsIntact: + """Test tool call/response pair preservation for both OpenAI and Anthropic formats.""" + + # ---- OpenAI Format Tests ---- + + def test_openai_adds_missing_tool_call(self): + """Test that orphaned OpenAI tool_response gets its tool_call prepended.""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool response) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_call message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert "tool_calls" in result[0] + + def test_openai_keeps_complete_pairs(self): + """Test that complete OpenAI pairs are unchanged.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + ] + recent = all_msgs[1:] # Include both tool_call and response + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert len(result) == 2 # No messages added + + def test_openai_multiple_tool_calls(self): + """Test multiple OpenAI tool calls in one assistant message.""" + all_msgs = [ + {"role": "system", "content": 
"System"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}}, + {"id": "call_2", "type": "function", "function": {"name": "f2"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result1"}, + {"role": "tool", "tool_call_id": "call_2", "content": "result2"}, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (first tool response) + recent = [all_msgs[2], all_msgs[3], all_msgs[4]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with both tool_calls + assert len(result) == 4 + assert result[0]["role"] == "assistant" + assert len(result[0]["tool_calls"]) == 2 + + # ---- Anthropic Format Tests ---- + + def test_anthropic_adds_missing_tool_use(self): + """Test that orphaned Anthropic tool_result gets its tool_use prepended.""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_123", + "name": "get_weather", + "input": {"location": "SF"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_123", + "content": "22°C and sunny", + } + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool_result) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_use message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert result[0]["content"][0]["type"] == "tool_use" + + def test_anthropic_keeps_complete_pairs(self): + """Test that complete Anthropic pairs are unchanged.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_456", + "name": "calculator", + "input": {"expr": "2+2"}, + } + 
], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_456", + "content": "4", + } + ], + }, + ] + recent = all_msgs[1:] # Include both tool_use and result + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert len(result) == 2 # No messages added + + def test_anthropic_multiple_tool_uses(self): + """Test multiple Anthropic tool_use blocks in one message.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me check both..."}, + { + "type": "tool_use", + "id": "toolu_1", + "name": "get_weather", + "input": {"city": "NYC"}, + }, + { + "type": "tool_use", + "id": "toolu_2", + "name": "get_weather", + "input": {"city": "LA"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_1", + "content": "Cold", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_2", + "content": "Warm", + }, + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (tool_result) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with both tool_uses + assert len(result) == 3 + assert result[0]["role"] == "assistant" + tool_use_count = sum( + 1 for b in result[0]["content"] if b.get("type") == "tool_use" + ) + assert tool_use_count == 2 + + # ---- Mixed/Edge Case Tests ---- + + def test_anthropic_with_type_message_field(self): + """Test Anthropic format with 'type': 'message' field (smart_decision_maker style).""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_abc", + "name": "search", + "input": {"q": "test"}, + } + ], + }, + { + "role": "user", + "type": "message", # Extra field from smart_decision_maker + "content": [ 
+ { + "type": "tool_result", + "tool_use_id": "toolu_abc", + "content": "Found results", + } + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool_result with 'type': 'message') + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_use message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert result[0]["content"][0]["type"] == "tool_use" + + def test_handles_no_tool_messages(self): + """Test messages without tool calls.""" + all_msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + recent = all_msgs + start_index = 0 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert result == all_msgs + + def test_handles_empty_messages(self): + """Test empty message list.""" + result = _ensure_tool_pairs_intact([], [], 0) + assert result == [] + + def test_mixed_text_and_tool_content(self): + """Test Anthropic message with mixed text and tool_use content.""" + all_msgs = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll help you with that."}, + { + "type": "tool_use", + "id": "toolu_mixed", + "name": "search", + "input": {"q": "test"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_mixed", + "content": "Found results", + } + ], + }, + {"role": "assistant", "content": "Here are the results..."}, + ] + # Start from tool_result + recent = [all_msgs[1], all_msgs[2]] + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with tool_use + assert len(result) == 3 + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][1]["type"] == "tool_use" + + +class TestCompressContext: + """Test the async compress_context function.""" + + @pytest.mark.asyncio + 
async def test_no_compression_needed(self): + """Test messages under limit return without compression.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello!"}, + ] + + result = await compress_context(messages, target_tokens=100000) + + assert isinstance(result, CompressResult) + assert result.was_compacted is False + assert len(result.messages) == 2 + assert result.error is None + + @pytest.mark.asyncio + async def test_truncation_without_client(self): + """Test that truncation works without LLM client.""" + long_content = "x" * 50000 + messages = [ + {"role": "system", "content": "System"}, + {"role": "user", "content": long_content}, + {"role": "assistant", "content": "Response"}, + ] + + result = await compress_context( + messages, target_tokens=1000, client=None, reserve=100 + ) + + assert result.was_compacted is True + # Should have truncated without summarization + assert result.messages_summarized == 0 + + @pytest.mark.asyncio + async def test_with_mocked_llm_client(self): + """Test summarization with mocked LLM client.""" + # Create many messages to trigger summarization + messages = [{"role": "system", "content": "System prompt"}] + for i in range(30): + messages.append({"role": "user", "content": f"User message {i} " * 100}) + messages.append( + {"role": "assistant", "content": f"Assistant response {i} " * 100} + ) + + # Mock the AsyncOpenAI client + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Summary of conversation" + mock_client.with_options.return_value.chat.completions.create = AsyncMock( + return_value=mock_response + ) + + result = await compress_context( + messages, + target_tokens=5000, + client=mock_client, + keep_recent=5, + reserve=500, + ) + + assert result.was_compacted is True + # Should have attempted summarization + assert mock_client.with_options.called or result.messages_summarized > 0 
+ + @pytest.mark.asyncio + async def test_preserves_tool_pairs(self): + """Test that tool call/response pairs stay together.""" + messages = [ + {"role": "system", "content": "System"}, + {"role": "user", "content": "Do something"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "func"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000}, + {"role": "assistant", "content": "Done!"}, + ] + + result = await compress_context( + messages, target_tokens=500, client=None, reserve=50 + ) + + # Check that if tool response exists, its call exists too + tool_call_ids = set() + tool_response_ids = set() + for msg in result.messages: + if "tool_calls" in msg: + for tc in msg["tool_calls"]: + tool_call_ids.add(tc["id"]) + if msg.get("role") == "tool": + tool_response_ids.add(msg.get("tool_call_id")) + + # All tool responses should have their calls + assert tool_response_ids <= tool_call_ids + + @pytest.mark.asyncio + async def test_returns_error_when_cannot_compress(self): + """Test that error is returned when compression fails.""" + # Single huge message that can't be compressed enough + messages = [ + {"role": "user", "content": "x" * 100000}, + ] + + result = await compress_context( + messages, target_tokens=100, client=None, reserve=50 + ) + + # Should have an error since we can't get below 100 tokens + assert result.error is not None + assert result.was_compacted is True + + @pytest.mark.asyncio + async def test_empty_messages(self): + """Test that empty messages list returns early without error.""" + result = await compress_context([], target_tokens=1000) + + assert result.messages == [] + assert result.token_count == 0 + assert result.was_compacted is False + assert result.error is None + + +class TestRemoveOrphanToolResponses: + """Test _remove_orphan_tool_responses helper function.""" + + def test_removes_openai_orphan(self): + """Test removal of orphan OpenAI tool response.""" + 
from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "tool", "tool_call_id": "call_orphan", "content": "result"}, + {"role": "user", "content": "Hello"}, + ] + orphan_ids = {"call_orphan"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + assert result[0]["role"] == "user" + + def test_keeps_valid_openai_tool(self): + """Test that valid OpenAI tool responses are kept.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "tool", "tool_call_id": "call_valid", "content": "result"}, + ] + orphan_ids = {"call_other"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + assert result[0]["tool_call_id"] == "call_valid" + + def test_filters_anthropic_mixed_blocks(self): + """Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", + "content": "valid result", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan", + "content": "orphan result", + }, + ], + }, + ] + orphan_ids = {"toolu_orphan"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + # Should only have the valid tool_result, orphan filtered out + assert len(result[0]["content"]) == 1 + assert result[0]["content"][0]["tool_use_id"] == "toolu_valid" + + def test_removes_anthropic_all_orphan(self): + """Test removal of Anthropic message when all tool_results are orphans.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_orphan1", + "content": "result1", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan2", + "content": "result2", + }, + ], + }, + ] + 
orphan_ids = {"toolu_orphan1", "toolu_orphan2"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + # Message should be completely removed since no content left + assert len(result) == 0 + + def test_preserves_non_tool_messages(self): + """Test that non-tool messages are preserved.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + orphan_ids = {"some_id"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert result == messages + + +class TestCompressResultDataclass: + """Test CompressResult dataclass.""" + + def test_default_values(self): + """Test default values are set correctly.""" + result = CompressResult( + messages=[{"role": "user", "content": "test"}], + token_count=10, + was_compacted=False, + ) + + assert result.error is None + assert result.original_token_count == 0 # Defaults to 0, not None + assert result.messages_summarized == 0 + assert result.messages_dropped == 0 + + def test_all_fields(self): + """Test all fields can be set.""" + result = CompressResult( + messages=[{"role": "user", "content": "test"}], + token_count=100, + was_compacted=True, + error="Some error", + original_token_count=500, + messages_summarized=10, + messages_dropped=5, + ) + + assert result.token_count == 100 + assert result.was_compacted is True + assert result.error == "Some error" + assert result.original_token_count == 500 + assert result.messages_summarized == 10 + assert result.messages_dropped == 5 diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 9744372b15..95e5ee32f7 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -157,12 +157,7 @@ async def validate_url( is_trusted: Boolean indicating if the hostname is in trusted_origins ip_addresses: List of IP addresses for 
the host; empty if the host is trusted """ - # Canonicalize URL - url = url.strip("/ ").replace("\\", "/") - parsed = urlparse(url) - if not parsed.scheme: - url = f"http://{url}" - parsed = urlparse(url) + parsed = parse_url(url) # Check scheme if parsed.scheme not in ALLOWED_SCHEMES: @@ -220,6 +215,17 @@ async def validate_url( ) +def parse_url(url: str) -> URL: + """Canonicalizes and parses a URL string.""" + url = url.strip("/ ").replace("\\", "/") + + # Ensure scheme is present for proper parsing + if not re.match(r"[a-z0-9+.\-]+://", url): + url = f"http://{url}" + + return urlparse(url) + + def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL: """ Pins a URL to a specific IP address to prevent DNS rebinding attacks. diff --git a/autogpt_platform/backend/migrations/20260126120000_migrate_claude_3_7_to_4_5_sonnet/migration.sql b/autogpt_platform/backend/migrations/20260126120000_migrate_claude_3_7_to_4_5_sonnet/migration.sql new file mode 100644 index 0000000000..5746c80820 --- /dev/null +++ b/autogpt_platform/backend/migrations/20260126120000_migrate_claude_3_7_to_4_5_sonnet/migration.sql @@ -0,0 +1,22 @@ +-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet +-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model +-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026 + +-- Update AgentNode constant inputs +UPDATE "AgentNode" +SET "constantInput" = JSONB_SET( + "constantInput"::jsonb, + '{model}', + '"claude-sonnet-4-5-20250929"'::jsonb + ) +WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219'; + +-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput) +UPDATE "AgentNodeExecutionInputOutput" +SET "data" = JSONB_SET( + "data"::jsonb, + '{model}', + '"claude-sonnet-4-5-20250929"'::jsonb + ) +WHERE "agentPresetId" IS NOT NULL + AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219'; diff --git a/autogpt_platform/backend/snapshots/agts_by_creator 
b/autogpt_platform/backend/snapshots/agts_by_creator index 4d6dd12920..3f2e128a0d 100644 --- a/autogpt_platform/backend/snapshots/agts_by_creator +++ b/autogpt_platform/backend/snapshots/agts_by_creator @@ -9,7 +9,8 @@ "sub_heading": "Creator agent subheading", "description": "Creator agent description", "runs": 50, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_category b/autogpt_platform/backend/snapshots/agts_category index f65925ead3..4d0531763c 100644 --- a/autogpt_platform/backend/snapshots/agts_category +++ b/autogpt_platform/backend/snapshots/agts_category @@ -9,7 +9,8 @@ "sub_heading": "Category agent subheading", "description": "Category agent description", "runs": 60, - "rating": 4.1 + "rating": 4.1, + "agent_graph_id": "test-graph-category" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_pagination b/autogpt_platform/backend/snapshots/agts_pagination index 82e7f5f9bf..7b946157fb 100644 --- a/autogpt_platform/backend/snapshots/agts_pagination +++ b/autogpt_platform/backend/snapshots/agts_pagination @@ -9,7 +9,8 @@ "sub_heading": "Agent 0 subheading", "description": "Agent 0 description", "runs": 0, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-1", @@ -20,7 +21,8 @@ "sub_heading": "Agent 1 subheading", "description": "Agent 1 description", "runs": 10, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-2", @@ -31,7 +33,8 @@ "sub_heading": "Agent 2 subheading", "description": "Agent 2 description", "runs": 20, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-3", @@ -42,7 +45,8 @@ "sub_heading": "Agent 3 subheading", "description": "Agent 3 description", "runs": 30, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" }, { "slug": "agent-4", @@ -53,7 +57,8 @@ "sub_heading": "Agent 4 subheading", 
"description": "Agent 4 description", "runs": 40, - "rating": 4.0 + "rating": 4.0, + "agent_graph_id": "test-graph-2" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_search b/autogpt_platform/backend/snapshots/agts_search index ca3f504584..ae9cc116bc 100644 --- a/autogpt_platform/backend/snapshots/agts_search +++ b/autogpt_platform/backend/snapshots/agts_search @@ -9,7 +9,8 @@ "sub_heading": "Search agent subheading", "description": "Specific search term description", "runs": 75, - "rating": 4.2 + "rating": 4.2, + "agent_graph_id": "test-graph-search" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/agts_sorted b/autogpt_platform/backend/snapshots/agts_sorted index cddead76a5..b182256b2c 100644 --- a/autogpt_platform/backend/snapshots/agts_sorted +++ b/autogpt_platform/backend/snapshots/agts_sorted @@ -9,7 +9,8 @@ "sub_heading": "Top agent subheading", "description": "Top agent description", "runs": 1000, - "rating": 5.0 + "rating": 5.0, + "agent_graph_id": "test-graph-3" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/feat_agts b/autogpt_platform/backend/snapshots/feat_agts index d57996a768..4f85786434 100644 --- a/autogpt_platform/backend/snapshots/feat_agts +++ b/autogpt_platform/backend/snapshots/feat_agts @@ -9,7 +9,8 @@ "sub_heading": "Featured agent subheading", "description": "Featured agent description", "runs": 100, - "rating": 4.5 + "rating": 4.5, + "agent_graph_id": "test-graph-1" } ], "pagination": { diff --git a/autogpt_platform/backend/snapshots/lib_agts_search b/autogpt_platform/backend/snapshots/lib_agts_search index 67c307b09e..3ce8402b63 100644 --- a/autogpt_platform/backend/snapshots/lib_agts_search +++ b/autogpt_platform/backend/snapshots/lib_agts_search @@ -31,6 +31,10 @@ "has_sensitive_action": false, "trigger_setup_info": null, "new_output": false, + "execution_count": 0, + "success_rate": null, + "avg_correctness_score": null, + "recent_executions": [], "can_access_graph": 
true, "is_latest_version": true, "is_favorite": false, @@ -72,6 +76,10 @@ "has_sensitive_action": false, "trigger_setup_info": null, "new_output": false, + "execution_count": 0, + "success_rate": null, + "avg_correctness_score": null, + "recent_executions": [], "can_access_graph": false, "is_latest_version": true, "is_favorite": false, diff --git a/autogpt_platform/backend/test/agent_generator/test_core_integration.py b/autogpt_platform/backend/test/agent_generator/test_core_integration.py index bdcc24ba79..528763e751 100644 --- a/autogpt_platform/backend/test/agent_generator/test_core_integration.py +++ b/autogpt_platform/backend/test/agent_generator/test_core_integration.py @@ -57,7 +57,8 @@ class TestDecomposeGoal: result = await core.decompose_goal("Build a chatbot") - mock_external.assert_called_once_with("Build a chatbot", "") + # library_agents defaults to None + mock_external.assert_called_once_with("Build a chatbot", "", None) assert result == expected_result @pytest.mark.asyncio @@ -74,7 +75,8 @@ class TestDecomposeGoal: await core.decompose_goal("Build a chatbot", "Use Python") - mock_external.assert_called_once_with("Build a chatbot", "Use Python") + # library_agents defaults to None + mock_external.assert_called_once_with("Build a chatbot", "Use Python", None) @pytest.mark.asyncio async def test_returns_none_on_service_failure(self): @@ -109,8 +111,7 @@ class TestGenerateAgent: instructions = {"type": "instructions", "steps": ["Step 1"]} result = await core.generate_agent(instructions) - mock_external.assert_called_once_with(instructions) - # Result should have id, version, is_active added if not present + mock_external.assert_called_once_with(instructions, None, None, None) assert result is not None assert result["name"] == "Test Agent" assert "id" in result @@ -174,7 +175,9 @@ class TestGenerateAgentPatch: current_agent = {"nodes": [], "links": []} result = await core.generate_agent_patch("Add a node", current_agent) - 
mock_external.assert_called_once_with("Add a node", current_agent) + mock_external.assert_called_once_with( + "Add a node", current_agent, None, None, None + ) assert result == expected_result @pytest.mark.asyncio diff --git a/autogpt_platform/backend/test/agent_generator/test_library_agents.py b/autogpt_platform/backend/test/agent_generator/test_library_agents.py new file mode 100644 index 0000000000..8387339582 --- /dev/null +++ b/autogpt_platform/backend/test/agent_generator/test_library_agents.py @@ -0,0 +1,857 @@ +""" +Tests for library agent fetching functionality in agent generator. + +This test suite verifies the search-based library agent fetching, +including the combination of library and marketplace agents. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from backend.api.features.chat.tools.agent_generator import core + + +class TestGetLibraryAgentsForGeneration: + """Test get_library_agents_for_generation function.""" + + @pytest.mark.asyncio + async def test_fetches_agents_with_search_term(self): + """Test that search_term is passed to the library db.""" + # Create a mock agent with proper attribute values + mock_agent = MagicMock() + mock_agent.graph_id = "agent-123" + mock_agent.graph_version = 1 + mock_agent.name = "Email Agent" + mock_agent.description = "Sends emails" + mock_agent.input_schema = {"properties": {}} + mock_agent.output_schema = {"properties": {}} + mock_agent.recent_executions = [] + + mock_response = MagicMock() + mock_response.agents = [mock_agent] + + with patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_list: + result = await core.get_library_agents_for_generation( + user_id="user-123", + search_query="send email", + ) + + mock_list.assert_called_once_with( + user_id="user-123", + search_term="send email", + page=1, + page_size=15, + include_executions=True, + ) + + # Verify result format + assert len(result) == 1 + 
assert result[0]["graph_id"] == "agent-123" + assert result[0]["name"] == "Email Agent" + + @pytest.mark.asyncio + async def test_excludes_specified_graph_id(self): + """Test that agents with excluded graph_id are filtered out.""" + mock_response = MagicMock() + mock_response.agents = [ + MagicMock( + graph_id="agent-123", + graph_version=1, + name="Agent 1", + description="First agent", + input_schema={}, + output_schema={}, + recent_executions=[], + ), + MagicMock( + graph_id="agent-456", + graph_version=1, + name="Agent 2", + description="Second agent", + input_schema={}, + output_schema={}, + recent_executions=[], + ), + ] + + with patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ): + result = await core.get_library_agents_for_generation( + user_id="user-123", + exclude_graph_id="agent-123", + ) + + # Verify the excluded agent is not in results + assert len(result) == 1 + assert result[0]["graph_id"] == "agent-456" + + @pytest.mark.asyncio + async def test_respects_max_results(self): + """Test that max_results parameter limits the page_size.""" + mock_response = MagicMock() + mock_response.agents = [] + + with patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_list: + await core.get_library_agents_for_generation( + user_id="user-123", + max_results=5, + ) + + mock_list.assert_called_once_with( + user_id="user-123", + search_term=None, + page=1, + page_size=5, + include_executions=True, + ) + + +class TestSearchMarketplaceAgentsForGeneration: + """Test search_marketplace_agents_for_generation function.""" + + @pytest.mark.asyncio + async def test_searches_marketplace_with_query(self): + """Test that marketplace is searched with the query.""" + mock_response = MagicMock() + mock_response.agents = [ + MagicMock( + agent_name="Public Agent", + description="A public agent", + sub_heading="Does something useful", + 
creator="creator-1", + agent_graph_id="graph-123", + ) + ] + + mock_graph = MagicMock() + mock_graph.id = "graph-123" + mock_graph.version = 1 + mock_graph.input_schema = {"type": "object"} + mock_graph.output_schema = {"type": "object"} + + with ( + patch( + "backend.api.features.store.db.get_store_agents", + new_callable=AsyncMock, + return_value=mock_response, + ) as mock_search, + patch( + "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs", + new_callable=AsyncMock, + return_value={"graph-123": mock_graph}, + ), + ): + result = await core.search_marketplace_agents_for_generation( + search_query="automation", + max_results=10, + ) + + mock_search.assert_called_once_with( + search_query="automation", + page=1, + page_size=10, + ) + + assert len(result) == 1 + assert result[0]["name"] == "Public Agent" + assert result[0]["graph_id"] == "graph-123" + + @pytest.mark.asyncio + async def test_handles_marketplace_error_gracefully(self): + """Test that marketplace errors don't crash the function.""" + with patch( + "backend.api.features.store.db.get_store_agents", + new_callable=AsyncMock, + side_effect=Exception("Marketplace unavailable"), + ): + result = await core.search_marketplace_agents_for_generation( + search_query="test" + ) + + # Should return empty list, not raise exception + assert result == [] + + +class TestGetAllRelevantAgentsForGeneration: + """Test get_all_relevant_agents_for_generation function.""" + + @pytest.mark.asyncio + async def test_combines_library_and_marketplace_agents(self): + """Test that agents from both sources are combined.""" + library_agents = [ + { + "graph_id": "lib-123", + "graph_version": 1, + "name": "Library Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + marketplace_agents = [ + { + "graph_id": "market-456", + "graph_version": 1, + "name": "Market Agent", + "description": "From marketplace", + "input_schema": {}, + "output_schema": {}, + } + ] + + with 
patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + return_value=marketplace_agents, + ): + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="test query", + include_marketplace=True, + ) + + # Library agents should come first + assert len(result) == 2 + assert result[0]["name"] == "Library Agent" + assert result[1]["name"] == "Market Agent" + + @pytest.mark.asyncio + async def test_deduplicates_by_graph_id(self): + """Test that marketplace agents with same graph_id as library are excluded.""" + library_agents = [ + { + "graph_id": "shared-123", + "graph_version": 1, + "name": "Shared Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + marketplace_agents = [ + { + "graph_id": "shared-123", # Same graph_id, should be deduplicated + "graph_version": 1, + "name": "Shared Agent", + "description": "From marketplace", + "input_schema": {}, + "output_schema": {}, + }, + { + "graph_id": "unique-456", + "graph_version": 1, + "name": "Unique Agent", + "description": "Only in marketplace", + "input_schema": {}, + "output_schema": {}, + }, + ] + + with patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + return_value=marketplace_agents, + ): + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="test", + include_marketplace=True, + ) + + # Shared Agent from marketplace should be excluded by graph_id + assert len(result) == 2 + names = [a["name"] for a in result] + assert "Shared Agent" in names + assert "Unique Agent" in names + + @pytest.mark.asyncio + async def 
test_skips_marketplace_when_disabled(self): + """Test that marketplace is not searched when include_marketplace=False.""" + library_agents = [ + { + "graph_id": "lib-123", + "graph_version": 1, + "name": "Library Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + with patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + ) as mock_marketplace: + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="test", + include_marketplace=False, + ) + + # Marketplace should not be called + mock_marketplace.assert_not_called() + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_skips_marketplace_when_no_search_query(self): + """Test that marketplace is not searched without a search query.""" + library_agents = [ + { + "graph_id": "lib-123", + "graph_version": 1, + "name": "Library Agent", + "description": "From library", + "input_schema": {}, + "output_schema": {}, + } + ] + + with patch.object( + core, + "get_library_agents_for_generation", + new_callable=AsyncMock, + return_value=library_agents, + ): + with patch.object( + core, + "search_marketplace_agents_for_generation", + new_callable=AsyncMock, + ) as mock_marketplace: + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query=None, # No search query + include_marketplace=True, + ) + + # Marketplace should not be called without search query + mock_marketplace.assert_not_called() + assert len(result) == 1 + + +class TestExtractSearchTermsFromSteps: + """Test extract_search_terms_from_steps function.""" + + def test_extracts_terms_from_instructions_type(self): + """Test extraction from valid instructions decomposition result.""" + decomposition_result = { + "type": "instructions", + "steps": [ + { + 
"description": "Send an email notification", + "block_name": "GmailSendBlock", + }, + {"description": "Fetch weather data", "action": "Get weather API"}, + ], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert "Send an email notification" in result + assert "GmailSendBlock" in result + assert "Fetch weather data" in result + assert "Get weather API" in result + + def test_returns_empty_for_non_instructions_type(self): + """Test that non-instructions types return empty list.""" + decomposition_result = { + "type": "clarifying_questions", + "questions": [{"question": "What email?"}], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert result == [] + + def test_deduplicates_terms_case_insensitively(self): + """Test that duplicate terms are removed (case-insensitive).""" + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "Send Email", "name": "send email"}, + {"description": "Other task"}, + ], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + # Should only have one "send email" variant + email_terms = [t for t in result if "email" in t.lower()] + assert len(email_terms) == 1 + + def test_filters_short_terms(self): + """Test that terms with 3 or fewer characters are filtered out.""" + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "ab", "action": "xyz"}, # Both too short + {"description": "Valid term here"}, + ], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert "ab" not in result + assert "xyz" not in result + assert "Valid term here" in result + + def test_handles_empty_steps(self): + """Test handling of empty steps list.""" + decomposition_result = { + "type": "instructions", + "steps": [], + } + + result = core.extract_search_terms_from_steps(decomposition_result) + + assert result == [] + + +class TestEnrichLibraryAgentsFromSteps: + """Test 
enrich_library_agents_from_steps function.""" + + @pytest.mark.asyncio + async def test_enriches_with_additional_agents(self): + """Test that additional agents are found based on steps.""" + existing_agents = [ + { + "graph_id": "existing-123", + "graph_version": 1, + "name": "Existing Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + additional_agents = [ + { + "graph_id": "new-456", + "graph_version": 1, + "name": "Email Agent", + "description": "For sending emails", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "Send email notification"}, + ], + } + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + new_callable=AsyncMock, + return_value=additional_agents, + ): + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should have both existing and new agents + assert len(result) == 2 + names = [a["name"] for a in result] + assert "Existing Agent" in names + assert "Email Agent" in names + + @pytest.mark.asyncio + async def test_deduplicates_by_graph_id(self): + """Test that agents with same graph_id are not duplicated.""" + existing_agents = [ + { + "graph_id": "agent-123", + "graph_version": 1, + "name": "Existing Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + # Additional search returns same agent + additional_agents = [ + { + "graph_id": "agent-123", # Same ID + "graph_version": 1, + "name": "Existing Agent Copy", + "description": "Same agent different name", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "instructions", + "steps": [{"description": "Some action"}], + } + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + new_callable=AsyncMock, + 
return_value=additional_agents, + ): + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should not duplicate + assert len(result) == 1 + + @pytest.mark.asyncio + async def test_deduplicates_by_name(self): + """Test that agents with same name are not duplicated.""" + existing_agents = [ + { + "graph_id": "agent-123", + "graph_version": 1, + "name": "Email Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + # Additional search returns agent with same name but different ID + additional_agents = [ + { + "graph_id": "agent-456", # Different ID + "graph_version": 1, + "name": "Email Agent", # Same name + "description": "Different agent same name", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "instructions", + "steps": [{"description": "Send email"}], + } + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + new_callable=AsyncMock, + return_value=additional_agents, + ): + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should not duplicate by name + assert len(result) == 1 + assert result[0].get("graph_id") == "agent-123" # Original kept + + @pytest.mark.asyncio + async def test_returns_existing_when_no_steps(self): + """Test that existing agents are returned when no search terms extracted.""" + existing_agents = [ + { + "graph_id": "existing-123", + "graph_version": 1, + "name": "Existing Agent", + "description": "Already fetched", + "input_schema": {}, + "output_schema": {}, + } + ] + + decomposition_result = { + "type": "clarifying_questions", # Not instructions type + "questions": [], + } + + result = await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + 
existing_agents=existing_agents, + ) + + # Should return existing unchanged + assert result == existing_agents + + @pytest.mark.asyncio + async def test_limits_search_terms_to_three(self): + """Test that only first 3 search terms are used.""" + existing_agents = [] + + decomposition_result = { + "type": "instructions", + "steps": [ + {"description": "First action"}, + {"description": "Second action"}, + {"description": "Third action"}, + {"description": "Fourth action"}, + {"description": "Fifth action"}, + ], + } + + call_count = 0 + + async def mock_get_agents(*args, **kwargs): + nonlocal call_count + call_count += 1 + return [] + + with patch.object( + core, + "get_all_relevant_agents_for_generation", + side_effect=mock_get_agents, + ): + await core.enrich_library_agents_from_steps( + user_id="user-123", + decomposition_result=decomposition_result, + existing_agents=existing_agents, + ) + + # Should only make 3 calls (limited to first 3 terms) + assert call_count == 3 + + +class TestExtractUuidsFromText: + """Test extract_uuids_from_text function.""" + + def test_extracts_single_uuid(self): + """Test extraction of a single UUID from text.""" + text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task" + result = core.extract_uuids_from_text(text) + assert len(result) == 1 + assert "46631191-e8a8-486f-ad90-84f89738321d" in result + + def test_extracts_multiple_uuids(self): + """Test extraction of multiple UUIDs from text.""" + text = ( + "Combine agents 11111111-1111-4111-8111-111111111111 " + "and 22222222-2222-4222-9222-222222222222" + ) + result = core.extract_uuids_from_text(text) + assert len(result) == 2 + assert "11111111-1111-4111-8111-111111111111" in result + assert "22222222-2222-4222-9222-222222222222" in result + + def test_deduplicates_uuids(self): + """Test that duplicate UUIDs are deduplicated.""" + text = ( + "Use 46631191-e8a8-486f-ad90-84f89738321d twice: " + "46631191-e8a8-486f-ad90-84f89738321d" + ) + result = 
core.extract_uuids_from_text(text) + assert len(result) == 1 + + def test_normalizes_to_lowercase(self): + """Test that UUIDs are normalized to lowercase.""" + text = "Use 46631191-E8A8-486F-AD90-84F89738321D" + result = core.extract_uuids_from_text(text) + assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d" + + def test_returns_empty_for_no_uuids(self): + """Test that empty list is returned when no UUIDs found.""" + text = "Create an email agent that sends notifications" + result = core.extract_uuids_from_text(text) + assert result == [] + + def test_ignores_invalid_uuids(self): + """Test that invalid UUID-like strings are ignored.""" + text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc" + result = core.extract_uuids_from_text(text) + # UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth) + assert len(result) == 0 + + +class TestGetLibraryAgentById: + """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id).""" + + @pytest.mark.asyncio + async def test_returns_agent_when_found_by_graph_id(self): + """Test that agent is returned when found by graph_id.""" + mock_agent = MagicMock() + mock_agent.graph_id = "agent-123" + mock_agent.graph_version = 1 + mock_agent.name = "Test Agent" + mock_agent.description = "Test description" + mock_agent.input_schema = {"properties": {}} + mock_agent.output_schema = {"properties": {}} + + with patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=mock_agent, + ): + result = await core.get_library_agent_by_id("user-123", "agent-123") + + assert result is not None + assert result["graph_id"] == "agent-123" + assert result["name"] == "Test Agent" + + @pytest.mark.asyncio + async def test_falls_back_to_library_agent_id(self): + """Test that lookup falls back to library agent ID when graph_id not found.""" + mock_agent = MagicMock() + mock_agent.graph_id = "graph-456" # Different from the lookup ID + 
mock_agent.graph_version = 1 + mock_agent.name = "Library Agent" + mock_agent.description = "Found by library ID" + mock_agent.input_schema = {"properties": {}} + mock_agent.output_schema = {"properties": {}} + + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=None, # Not found by graph_id + ), + patch.object( + core.library_db, + "get_library_agent", + new_callable=AsyncMock, + return_value=mock_agent, # Found by library ID + ), + ): + result = await core.get_library_agent_by_id("user-123", "library-id-123") + + assert result is not None + assert result["graph_id"] == "graph-456" + assert result["name"] == "Library Agent" + + @pytest.mark.asyncio + async def test_returns_none_when_not_found_by_either_method(self): + """Test that None is returned when agent not found by either method.""" + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=None, + ), + patch.object( + core.library_db, + "get_library_agent", + new_callable=AsyncMock, + side_effect=core.NotFoundError("Not found"), + ), + ): + result = await core.get_library_agent_by_id("user-123", "nonexistent") + + assert result is None + + @pytest.mark.asyncio + async def test_returns_none_on_exception(self): + """Test that None is returned when exception occurs in both lookups.""" + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + side_effect=Exception("Database error"), + ), + patch.object( + core.library_db, + "get_library_agent", + new_callable=AsyncMock, + side_effect=Exception("Database error"), + ), + ): + result = await core.get_library_agent_by_id("user-123", "agent-123") + + assert result is None + + @pytest.mark.asyncio + async def test_alias_works(self): + """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id.""" + assert core.get_library_agent_by_graph_id is 
core.get_library_agent_by_id + + +class TestGetAllRelevantAgentsWithUuids: + """Test UUID extraction in get_all_relevant_agents_for_generation.""" + + @pytest.mark.asyncio + async def test_fetches_explicitly_mentioned_agents(self): + """Test that agents mentioned by UUID are fetched directly.""" + mock_agent = MagicMock() + mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d" + mock_agent.graph_version = 1 + mock_agent.name = "Mentioned Agent" + mock_agent.description = "Explicitly mentioned" + mock_agent.input_schema = {} + mock_agent.output_schema = {} + + mock_response = MagicMock() + mock_response.agents = [] + + with ( + patch.object( + core.library_db, + "get_library_agent_by_graph_id", + new_callable=AsyncMock, + return_value=mock_agent, + ), + patch.object( + core.library_db, + "list_library_agents", + new_callable=AsyncMock, + return_value=mock_response, + ), + ): + result = await core.get_all_relevant_agents_for_generation( + user_id="user-123", + search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d", + include_marketplace=False, + ) + + assert len(result) == 1 + assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d" + assert result[0].get("name") == "Mentioned Agent" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/autogpt_platform/backend/test/agent_generator/test_service.py b/autogpt_platform/backend/test/agent_generator/test_service.py index fe7a1a7fdd..cc37c428c0 100644 --- a/autogpt_platform/backend/test/agent_generator/test_service.py +++ b/autogpt_platform/backend/test/agent_generator/test_service.py @@ -102,7 +102,7 @@ class TestDecomposeGoalExternal: @pytest.mark.asyncio async def test_decompose_goal_with_context(self): - """Test decomposition with additional context.""" + """Test decomposition with additional context enriched into description.""" mock_response = MagicMock() mock_response.json.return_value = { "success": True, @@ -119,9 +119,12 @@ class TestDecomposeGoalExternal: 
"Build a chatbot", context="Use Python" ) + expected_description = ( + "Build a chatbot\n\nAdditional context from user:\nUse Python" + ) mock_client.post.assert_called_once_with( "/api/decompose-description", - json={"description": "Build a chatbot", "user_instruction": "Use Python"}, + json={"description": expected_description}, ) @pytest.mark.asyncio @@ -433,5 +436,139 @@ class TestGetBlocksExternal: assert result is None +class TestLibraryAgentsPassthrough: + """Test that library_agents are passed correctly in all requests.""" + + def setup_method(self): + """Reset client singleton before each test.""" + service._settings = None + service._client = None + + @pytest.mark.asyncio + async def test_decompose_goal_passes_library_agents(self): + """Test that library_agents are included in decompose goal payload.""" + library_agents = [ + { + "graph_id": "agent-123", + "graph_version": 1, + "name": "Email Sender", + "description": "Sends emails", + "input_schema": {"properties": {"to": {"type": "string"}}}, + "output_schema": {"properties": {"sent": {"type": "boolean"}}}, + }, + ] + + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "instructions", + "steps": ["Step 1"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.decompose_goal_external( + "Send an email", + library_agents=library_agents, + ) + + # Verify library_agents was passed in the payload + call_args = mock_client.post.call_args + assert call_args[1]["json"]["library_agents"] == library_agents + + @pytest.mark.asyncio + async def test_generate_agent_passes_library_agents(self): + """Test that library_agents are included in generate agent payload.""" + library_agents = [ + { + "graph_id": "agent-456", + "graph_version": 2, + "name": "Data Fetcher", + "description": "Fetches data from API", + 
"input_schema": {"properties": {"url": {"type": "string"}}}, + "output_schema": {"properties": {"data": {"type": "object"}}}, + }, + ] + + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "agent_json": {"name": "Test Agent", "nodes": []}, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.generate_agent_external( + {"steps": ["Step 1"]}, + library_agents=library_agents, + ) + + # Verify library_agents was passed in the payload + call_args = mock_client.post.call_args + assert call_args[1]["json"]["library_agents"] == library_agents + + @pytest.mark.asyncio + async def test_generate_agent_patch_passes_library_agents(self): + """Test that library_agents are included in patch generation payload.""" + library_agents = [ + { + "graph_id": "agent-789", + "graph_version": 1, + "name": "Slack Notifier", + "description": "Sends Slack messages", + "input_schema": {"properties": {"message": {"type": "string"}}}, + "output_schema": {"properties": {"success": {"type": "boolean"}}}, + }, + ] + + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "agent_json": {"name": "Updated Agent", "nodes": []}, + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.generate_agent_patch_external( + "Add error handling", + {"name": "Original Agent", "nodes": []}, + library_agents=library_agents, + ) + + # Verify library_agents was passed in the payload + call_args = mock_client.post.call_args + assert call_args[1]["json"]["library_agents"] == library_agents + + @pytest.mark.asyncio + async def test_decompose_goal_without_library_agents(self): + """Test that decompose goal works without 
library_agents.""" + mock_response = MagicMock() + mock_response.json.return_value = { + "success": True, + "type": "instructions", + "steps": ["Step 1"], + } + mock_response.raise_for_status = MagicMock() + + mock_client = AsyncMock() + mock_client.post.return_value = mock_response + + with patch.object(service, "_get_client", return_value=mock_client): + await service.decompose_goal_external("Build a workflow") + + # Verify library_agents was NOT passed when not provided + call_args = mock_client.post.call_args + assert "library_agents" not in call_args[1]["json"] + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/autogpt_platform/backend/test/e2e_test_data.py b/autogpt_platform/backend/test/e2e_test_data.py index d7576cdad3..7288197a90 100644 --- a/autogpt_platform/backend/test/e2e_test_data.py +++ b/autogpt_platform/backend/test/e2e_test_data.py @@ -43,19 +43,24 @@ faker = Faker() # Constants for data generation limits (reduced for E2E tests) NUM_USERS = 15 NUM_AGENT_BLOCKS = 30 -MIN_GRAPHS_PER_USER = 15 -MAX_GRAPHS_PER_USER = 15 +MIN_GRAPHS_PER_USER = 25 +MAX_GRAPHS_PER_USER = 25 MIN_NODES_PER_GRAPH = 3 MAX_NODES_PER_GRAPH = 6 MIN_PRESETS_PER_USER = 2 MAX_PRESETS_PER_USER = 3 -MIN_AGENTS_PER_USER = 15 -MAX_AGENTS_PER_USER = 15 +MIN_AGENTS_PER_USER = 25 +MAX_AGENTS_PER_USER = 25 MIN_EXECUTIONS_PER_GRAPH = 2 MAX_EXECUTIONS_PER_GRAPH = 8 MIN_REVIEWS_PER_VERSION = 2 MAX_REVIEWS_PER_VERSION = 5 +# Guaranteed minimums for marketplace tests (deterministic) +GUARANTEED_FEATURED_AGENTS = 8 +GUARANTEED_FEATURED_CREATORS = 5 +GUARANTEED_TOP_AGENTS = 10 + def get_image(): """Generate a consistent image URL using picsum.photos service.""" @@ -385,7 +390,7 @@ class TestDataCreator: library_agents = [] for user in self.users: - num_agents = 10 # Create exactly 10 agents per user + num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER) # Get available graphs for this user user_graphs = [ @@ -507,14 +512,17 @@ class TestDataCreator: 
existing_profiles, min(num_creators, len(existing_profiles)) ) - # Mark about 50% of creators as featured (more for testing) - num_featured = max(2, int(num_creators * 0.5)) + # Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators + num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5)) num_featured = min( num_featured, len(selected_profiles) ) # Don't exceed available profiles featured_profile_ids = set( random.sample([p.id for p in selected_profiles], num_featured) ) + print( + f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})" + ) for profile in selected_profiles: try: @@ -545,21 +553,25 @@ class TestDataCreator: return profiles async def create_test_store_submissions(self) -> List[Dict[str, Any]]: - """Create test store submissions using the API function.""" + """Create test store submissions using the API function. + + DETERMINISTIC: Guarantees minimum featured agents for E2E tests. + """ print("Creating test store submissions...") submissions = [] approved_submissions = [] + featured_count = 0 + submission_counter = 0 - # Create a special test submission for test123@gmail.com + # Create a special test submission for test123@gmail.com (ALWAYS approved + featured) test_user = next( (user for user in self.users if user["email"] == "test123@gmail.com"), None ) - if test_user: - # Special test data for consistent testing + if test_user and self.agent_graphs: test_submission_data = { "user_id": test_user["id"], - "agent_id": self.agent_graphs[0]["id"], # Use first available graph + "agent_id": self.agent_graphs[0]["id"], "agent_version": 1, "slug": "test-agent-submission", "name": "Test Agent Submission", @@ -580,37 +592,24 @@ class TestDataCreator: submissions.append(test_submission.model_dump()) print("✅ Created special test store submission for test123@gmail.com") - # Randomly approve, reject, or leave pending the test submission + # ALWAYS approve and feature the test submission if 
test_submission.store_listing_version_id: - random_value = random.random() - if random_value < 0.4: # 40% chance to approve - approved_submission = await review_store_submission( - store_listing_version_id=test_submission.store_listing_version_id, - is_approved=True, - external_comments="Test submission approved", - internal_comments="Auto-approved test submission", - reviewer_id=test_user["id"], - ) - approved_submissions.append(approved_submission.model_dump()) - print("✅ Approved test store submission") + approved_submission = await review_store_submission( + store_listing_version_id=test_submission.store_listing_version_id, + is_approved=True, + external_comments="Test submission approved", + internal_comments="Auto-approved test submission", + reviewer_id=test_user["id"], + ) + approved_submissions.append(approved_submission.model_dump()) + print("✅ Approved test store submission") - # Mark approved submission as featured - await prisma.storelistingversion.update( - where={"id": test_submission.store_listing_version_id}, - data={"isFeatured": True}, - ) - print("🌟 Marked test agent as FEATURED") - elif random_value < 0.7: # 30% chance to reject (40% to 70%) - await review_store_submission( - store_listing_version_id=test_submission.store_listing_version_id, - is_approved=False, - external_comments="Test submission rejected - needs improvements", - internal_comments="Auto-rejected test submission for E2E testing", - reviewer_id=test_user["id"], - ) - print("❌ Rejected test store submission") - else: # 30% chance to leave pending (70% to 100%) - print("⏳ Left test submission pending for review") + await prisma.storelistingversion.update( + where={"id": test_submission.store_listing_version_id}, + data={"isFeatured": True}, + ) + featured_count += 1 + print("🌟 Marked test agent as FEATURED") except Exception as e: print(f"Error creating test store submission: {e}") @@ -620,7 +619,6 @@ class TestDataCreator: # Create regular submissions for all users for user in 
self.users: - # Get available graphs for this specific user user_graphs = [ g for g in self.agent_graphs if g.get("userId") == user["id"] ] @@ -631,18 +629,17 @@ class TestDataCreator: ) continue - # Create exactly 4 store submissions per user for submission_index in range(4): graph = random.choice(user_graphs) + submission_counter += 1 try: print( - f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})" + f"Creating store submission for user {user['id']} with graph {graph['id']}" ) - # Use the API function to create store submission with correct parameters submission = await create_store_submission( - user_id=user["id"], # Must match graph's userId + user_id=user["id"], agent_id=graph["id"], agent_version=graph.get("version", 1), slug=faker.slug(), @@ -651,22 +648,24 @@ class TestDataCreator: video_url=get_video_url() if random.random() < 0.3 else None, image_urls=[get_image() for _ in range(3)], description=faker.text(), - categories=[ - get_category() - ], # Single category from predefined list + categories=[get_category()], changes_summary="Initial E2E test submission", ) submissions.append(submission.model_dump()) print(f"✅ Created store submission: {submission.name}") - # Randomly approve, reject, or leave pending the submission if submission.store_listing_version_id: - random_value = random.random() - if random_value < 0.4: # 40% chance to approve - try: - # Pick a random user as the reviewer (admin) - reviewer_id = random.choice(self.users)["id"] + # DETERMINISTIC: First N submissions are always approved + # First GUARANTEED_FEATURED_AGENTS of those are always featured + should_approve = ( + submission_counter <= GUARANTEED_TOP_AGENTS + or random.random() < 0.4 + ) + should_feature = featured_count < GUARANTEED_FEATURED_AGENTS + if should_approve: + try: + reviewer_id = random.choice(self.users)["id"] approved_submission = await review_store_submission( 
store_listing_version_id=submission.store_listing_version_id, is_approved=True, @@ -681,16 +680,7 @@ class TestDataCreator: f"✅ Approved store submission: {submission.name}" ) - # Mark some agents as featured during creation (30% chance) - # More likely for creators and first submissions - is_creator = user["id"] in [ - p.get("userId") for p in self.profiles - ] - feature_chance = ( - 0.5 if is_creator else 0.2 - ) # 50% for creators, 20% for others - - if random.random() < feature_chance: + if should_feature: try: await prisma.storelistingversion.update( where={ @@ -698,8 +688,25 @@ class TestDataCreator: }, data={"isFeatured": True}, ) + featured_count += 1 print( - f"🌟 Marked agent as FEATURED: {submission.name}" + f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}" + ) + except Exception as e: + print( + f"Warning: Could not mark submission as featured: {e}" + ) + elif random.random() < 0.2: + try: + await prisma.storelistingversion.update( + where={ + "id": submission.store_listing_version_id + }, + data={"isFeatured": True}, + ) + featured_count += 1 + print( + f"🌟 Marked agent as FEATURED (bonus): {submission.name}" ) except Exception as e: print( @@ -710,11 +717,9 @@ class TestDataCreator: print( f"Warning: Could not approve submission {submission.name}: {e}" ) - elif random_value < 0.7: # 30% chance to reject (40% to 70%) + elif random.random() < 0.5: try: - # Pick a random user as the reviewer (admin) reviewer_id = random.choice(self.users)["id"] - await review_store_submission( store_listing_version_id=submission.store_listing_version_id, is_approved=False, @@ -729,7 +734,7 @@ class TestDataCreator: print( f"Warning: Could not reject submission {submission.name}: {e}" ) - else: # 30% chance to leave pending (70% to 100%) + else: print( f"⏳ Left submission pending for review: {submission.name}" ) @@ -743,9 +748,13 @@ class TestDataCreator: traceback.print_exc() continue + print("\n📊 Store Submissions Summary:") 
+ print(f" Created: {len(submissions)}") + print(f" Approved: {len(approved_submissions)}") print( - f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}" + f" Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})" ) + self.store_submissions = submissions return submissions @@ -825,12 +834,15 @@ class TestDataCreator: print(f"✅ Agent blocks available: {len(self.agent_blocks)}") print(f"✅ Agent graphs created: {len(self.agent_graphs)}") print(f"✅ Library agents created: {len(self.library_agents)}") - print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)") - print( - f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)" - ) + print(f"✅ Creator profiles updated: {len(self.profiles)}") + print(f"✅ Store submissions created: {len(self.store_submissions)}") print(f"✅ API keys created: {len(self.api_keys)}") print(f"✅ Presets created: {len(self.presets)}") + print("\n🎯 Deterministic Guarantees:") + print(f" • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}") + print(f" • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}") + print(f" • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}") + print(f" • Library agents per user: >= {MIN_AGENTS_PER_USER}") print("\n🚀 Your E2E test database is ready to use!") diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx index 70d9783ccd..246fe52826 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx @@ -1,10 +1,9 @@ "use client"; +import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding"; +import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { useRouter } from "next/navigation"; 
import { useEffect } from "react"; -import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers"; -import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding"; -import { getHomepageRoute } from "@/lib/constants"; export default function OnboardingPage() { const router = useRouter(); @@ -13,12 +12,10 @@ export default function OnboardingPage() { async function redirectToStep() { try { // Check if onboarding is enabled (also gets chat flag for redirect) - const { shouldShowOnboarding, isChatEnabled } = - await getOnboardingStatus(); - const homepageRoute = getHomepageRoute(isChatEnabled); + const { shouldShowOnboarding } = await getOnboardingStatus(); if (!shouldShowOnboarding) { - router.replace(homepageRoute); + router.replace("/"); return; } @@ -26,7 +23,7 @@ export default function OnboardingPage() { // Handle completed onboarding if (onboarding.completedSteps.includes("GET_RESULTS")) { - router.replace(homepageRoute); + router.replace("/"); return; } diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts index 15be137f63..e7e2997d0d 100644 --- a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts +++ b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts @@ -1,9 +1,8 @@ -import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; -import { getHomepageRoute } from "@/lib/constants"; -import BackendAPI from "@/lib/autogpt-server-api"; -import { NextResponse } from "next/server"; -import { revalidatePath } from "next/cache"; import { getOnboardingStatus } from "@/app/api/helpers"; +import BackendAPI from "@/lib/autogpt-server-api"; +import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; +import { revalidatePath } from "next/cache"; +import { NextResponse } from "next/server"; // Handle the callback to complete the user session login export async function 
GET(request: Request) { @@ -27,13 +26,12 @@ export async function GET(request: Request) { await api.createUser(); // Get onboarding status from backend (includes chat flag evaluated for this user) - const { shouldShowOnboarding, isChatEnabled } = - await getOnboardingStatus(); + const { shouldShowOnboarding } = await getOnboardingStatus(); if (shouldShowOnboarding) { next = "/onboarding"; revalidatePath("/onboarding", "layout"); } else { - next = getHomepageRoute(isChatEnabled); + next = "/"; revalidatePath(next, "layout"); } } catch (createUserError) { diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts index 41d05a9afb..fd67519957 100644 --- a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts +++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts @@ -1,6 +1,17 @@ import { OAuthPopupResultMessage } from "./types"; import { NextResponse } from "next/server"; +/** + * Safely encode a value as JSON for embedding in a script tag. + * Escapes characters that could break out of the script context to prevent XSS. + */ +function safeJsonStringify(value: unknown): string { + return JSON.stringify(value) + .replace(/</g, "\\u003c") + .replace(/>/g, "\\u003e") + .replace(/&/g, "\\u0026"); +} + // This route is intended to be used as the callback for integration OAuth flows, // controlled by the CredentialsInput component. The CredentialsInput opens the login // page in a pop-up window, which then redirects to this route to close the loop.
@@ -23,12 +34,13 @@ export async function GET(request: Request) { console.debug("Sending message to opener:", message); // Return a response with the message as JSON and a script to close the window + // Use safeJsonStringify to prevent XSS by escaping <, >, and & characters return new NextResponse( ` diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx index 94e917a4ac..834603cc4a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/CustomNode/CustomNode.tsx @@ -857,7 +857,7 @@ export const CustomNode = React.memo( })(); const hasAdvancedFields = - data.inputSchema && + data.inputSchema?.properties && Object.entries(data.inputSchema.properties).some(([key, value]) => { return ( value.advanced === true && !data.inputSchema.required?.includes(key) diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts index 74fd663ab2..913c4d7ded 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts @@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useQueryClient } from "@tanstack/react-query"; import { usePathname, useSearchParams } from "next/navigation"; -import { useRef } from "react"; import { useCopilotStore } from "../../copilot-page-store"; import { useCopilotSessionId } from "../../useCopilotSessionId"; import { useMobileDrawer } from 
"./components/MobileDrawer/useMobileDrawer"; @@ -70,41 +69,16 @@ export function useCopilotShell() { }); const stopStream = useChatStore((s) => s.stopStream); - const onStreamComplete = useChatStore((s) => s.onStreamComplete); - const isStreaming = useCopilotStore((s) => s.isStreaming); const isCreatingSession = useCopilotStore((s) => s.isCreatingSession); - const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession); - const openInterruptModal = useCopilotStore((s) => s.openInterruptModal); - const pendingActionRef = useRef<(() => void) | null>(null); - - async function stopCurrentStream() { - if (!currentSessionId) return; - - setIsSwitchingSession(true); - await new Promise((resolve) => { - const unsubscribe = onStreamComplete((completedId) => { - if (completedId === currentSessionId) { - clearTimeout(timeout); - unsubscribe(); - resolve(); - } - }); - const timeout = setTimeout(() => { - unsubscribe(); - resolve(); - }, 3000); - stopStream(currentSessionId); - }); - - queryClient.invalidateQueries({ - queryKey: getGetV2GetSessionQueryKey(currentSessionId), - }); - setIsSwitchingSession(false); - } - - function selectSession(sessionId: string) { + function handleSessionClick(sessionId: string) { if (sessionId === currentSessionId) return; + + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + stopStream(currentSessionId); + } + if (recentlyCreatedSessionsRef.current.has(sessionId)) { queryClient.invalidateQueries({ queryKey: getGetV2GetSessionQueryKey(sessionId), @@ -114,7 +88,12 @@ export function useCopilotShell() { if (isMobile) handleCloseDrawer(); } - function startNewChat() { + function handleNewChatClick() { + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + stopStream(currentSessionId); + } + resetPagination(); queryClient.invalidateQueries({ queryKey: getGetV2ListSessionsQueryKey(), @@ -123,32 +102,6 @@ export function useCopilotShell() { if (isMobile) 
handleCloseDrawer(); } - function handleSessionClick(sessionId: string) { - if (sessionId === currentSessionId) return; - - if (isStreaming) { - pendingActionRef.current = async () => { - await stopCurrentStream(); - selectSession(sessionId); - }; - openInterruptModal(pendingActionRef.current); - } else { - selectSession(sessionId); - } - } - - function handleNewChatClick() { - if (isStreaming) { - pendingActionRef.current = async () => { - await stopCurrentStream(); - startNewChat(); - }; - openInterruptModal(pendingActionRef.current); - } else { - startNewChat(); - } - } - return { isMobile, isDrawerOpen, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts index 692a5741f4..c6e479f896 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts @@ -26,8 +26,20 @@ export function buildCopilotChatUrl(prompt: string): string { export function getQuickActions(): string[] { return [ - "Show me what I can automate", - "Design a custom workflow", - "Help me with content creation", + "I don't know where to start, just ask me stuff", + "I do the same thing every week and it's killing me", + "Help me find where I'm wasting my time", ]; } + +export function getInputPlaceholder(width?: number) { + if (!width) return "What's your role and what eats up most of your day?"; + + if (width < 500) { + return "I'm a chef and I hate..."; + } + if (width <= 1080) { + return "What's your role and what eats up most of your day?"; + } + return "What's your role and what eats up most of your day? e.g. 
'I'm a recruiter and I hate...'"; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx index 89cf72e2ba..876e5accfb 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/layout.tsx @@ -1,6 +1,13 @@ -import type { ReactNode } from "react"; +"use client"; +import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage"; +import { Flag } from "@/services/feature-flags/use-get-flag"; +import { type ReactNode } from "react"; import { CopilotShell } from "./components/CopilotShell/CopilotShell"; export default function CopilotLayout({ children }: { children: ReactNode }) { - return {children}; + return ( + + {children} + + ); } diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx index 104b238895..542173a99c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx @@ -6,7 +6,9 @@ import { Text } from "@/components/atoms/Text/Text"; import { Chat } from "@/components/contextual/Chat/Chat"; import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; +import { useEffect, useState } from "react"; import { useCopilotStore } from "./copilot-page-store"; +import { getInputPlaceholder } from "./helpers"; import { useCopilotPage } from "./useCopilotPage"; export default function CopilotPage() { @@ -14,14 +16,25 @@ export default function CopilotPage() { const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen); const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt); const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt); - const { - greetingName, - quickActions, - isLoading, - hasSession, - 
initialPrompt, - isReady, - } = state; + + const [inputPlaceholder, setInputPlaceholder] = useState( + getInputPlaceholder(), + ); + + useEffect(() => { + const handleResize = () => { + setInputPlaceholder(getInputPlaceholder(window.innerWidth)); + }; + + handleResize(); + + window.addEventListener("resize", handleResize); + return () => window.removeEventListener("resize", handleResize); + }, []); + + const { greetingName, quickActions, isLoading, hasSession, initialPrompt } = + state; + const { handleQuickAction, startChatWithPrompt, @@ -29,8 +42,6 @@ export default function CopilotPage() { handleStreamingChange, } = handlers; - if (!isReady) return null; - if (hasSession) { return (
@@ -81,7 +92,7 @@ export default function CopilotPage() { } return ( -
+
{isLoading ? (
@@ -98,25 +109,25 @@ export default function CopilotPage() {
) : ( <> -
+
Hey, {greetingName} - What do you want to automate? + Tell me about your work — I'll find what to automate.
-
+
{quickActions.map((action) => ( diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts index e4713cd24a..9d99f8e7bd 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotPage.ts @@ -3,18 +3,11 @@ import { postV2CreateSession, } from "@/app/api/__generated__/endpoints/chat/chat"; import { useToast } from "@/components/molecules/Toast/use-toast"; -import { getHomepageRoute } from "@/lib/constants"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useOnboarding } from "@/providers/onboarding/onboarding-provider"; -import { - Flag, - type FlagValues, - useGetFlag, -} from "@/services/feature-flags/use-get-flag"; import { SessionKey, sessionStorage } from "@/services/storage/session-storage"; import * as Sentry from "@sentry/nextjs"; import { useQueryClient } from "@tanstack/react-query"; -import { useFlags } from "launchdarkly-react-client-sdk"; import { useRouter } from "next/navigation"; import { useEffect } from "react"; import { useCopilotStore } from "./copilot-page-store"; @@ -33,22 +26,6 @@ export function useCopilotPage() { const isCreating = useCopilotStore((s) => s.isCreatingSession); const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession); - // Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus - useEffect(() => { - if (isLoggedIn) { - completeStep("VISIT_COPILOT"); - } - }, [completeStep, isLoggedIn]); - - const isChatEnabled = useGetFlag(Flag.CHAT); - const flags = useFlags(); - const homepageRoute = getHomepageRoute(isChatEnabled); - const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true"; - const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID; - const isLaunchDarklyConfigured = envEnabled && Boolean(clientId); - const isFlagReady = - !isLaunchDarklyConfigured || 
flags[Flag.CHAT] !== undefined; - const greetingName = getGreetingName(user); const quickActions = getQuickActions(); @@ -58,11 +35,8 @@ export function useCopilotPage() { : undefined; useEffect(() => { - if (!isFlagReady) return; - if (isChatEnabled === false) { - router.replace(homepageRoute); - } - }, [homepageRoute, isChatEnabled, isFlagReady, router]); + if (isLoggedIn) completeStep("VISIT_COPILOT"); + }, [completeStep, isLoggedIn]); async function startChatWithPrompt(prompt: string) { if (!prompt?.trim()) return; @@ -116,7 +90,6 @@ export function useCopilotPage() { isLoading: isUserLoading, hasSession, initialPrompt, - isReady: isFlagReady && isChatEnabled !== false && isLoggedIn, }, handlers: { handleQuickAction, diff --git a/autogpt_platform/frontend/src/app/(platform)/error/page.tsx b/autogpt_platform/frontend/src/app/(platform)/error/page.tsx index b26ca4559b..3cf68178ad 100644 --- a/autogpt_platform/frontend/src/app/(platform)/error/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/error/page.tsx @@ -1,8 +1,6 @@ "use client"; import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard"; -import { getHomepageRoute } from "@/lib/constants"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { useSearchParams } from "next/navigation"; import { Suspense } from "react"; import { getErrorDetails } from "./helpers"; @@ -11,8 +9,6 @@ function ErrorPageContent() { const searchParams = useSearchParams(); const errorMessage = searchParams.get("message"); const errorDetails = getErrorDetails(errorMessage); - const isChatEnabled = useGetFlag(Flag.CHAT); - const homepageRoute = getHomepageRoute(isChatEnabled); function handleRetry() { // Auth-related errors should redirect to login @@ -30,7 +26,7 @@ function ErrorPageContent() { }, 2000); } else { // For server/network errors, go to home - window.location.href = homepageRoute; + window.location.href = "/"; } } diff --git 
a/autogpt_platform/frontend/src/app/(platform)/login/actions.ts b/autogpt_platform/frontend/src/app/(platform)/login/actions.ts index 447a25a41d..c4867dd123 100644 --- a/autogpt_platform/frontend/src/app/(platform)/login/actions.ts +++ b/autogpt_platform/frontend/src/app/(platform)/login/actions.ts @@ -1,6 +1,5 @@ "use server"; -import { getHomepageRoute } from "@/lib/constants"; import BackendAPI from "@/lib/autogpt-server-api"; import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; import { loginFormSchema } from "@/types/auth"; @@ -38,10 +37,8 @@ export async function login(email: string, password: string) { await api.createUser(); // Get onboarding status from backend (includes chat flag evaluated for this user) - const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus(); - const next = shouldShowOnboarding - ? "/onboarding" - : getHomepageRoute(isChatEnabled); + const { shouldShowOnboarding } = await getOnboardingStatus(); + const next = shouldShowOnboarding ? 
"/onboarding" : "/"; return { success: true, diff --git a/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts b/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts index e64cc1858d..9b81965c31 100644 --- a/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/login/useLoginPage.ts @@ -1,8 +1,6 @@ import { useToast } from "@/components/molecules/Toast/use-toast"; -import { getHomepageRoute } from "@/lib/constants"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { environment } from "@/services/environment"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { loginFormSchema, LoginProvider } from "@/types/auth"; import { zodResolver } from "@hookform/resolvers/zod"; import { useRouter, useSearchParams } from "next/navigation"; @@ -22,17 +20,15 @@ export function useLoginPage() { const [isGoogleLoading, setIsGoogleLoading] = useState(false); const [showNotAllowedModal, setShowNotAllowedModal] = useState(false); const isCloudEnv = environment.isCloud(); - const isChatEnabled = useGetFlag(Flag.CHAT); - const homepageRoute = getHomepageRoute(isChatEnabled); // Get redirect destination from 'next' query parameter const nextUrl = searchParams.get("next"); useEffect(() => { if (isLoggedIn && !isLoggingIn) { - router.push(nextUrl || homepageRoute); + router.push(nextUrl || "/"); } - }, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]); + }, [isLoggedIn, isLoggingIn, nextUrl, router]); const form = useForm<z.infer<typeof loginFormSchema>>({ resolver: zodResolver(loginFormSchema), @@ -98,7 +94,7 @@ export function useLoginPage() { } // Prefer URL's next parameter, then use backend-determined route - router.replace(nextUrl || result.next || homepageRoute); + router.replace(nextUrl || result.next || "/"); } catch (error) { toast({ title: diff --git a/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts
b/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts index 0fbba54b8e..204482dbe9 100644 --- a/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts +++ b/autogpt_platform/frontend/src/app/(platform)/signup/actions.ts @@ -1,6 +1,5 @@ "use server"; -import { getHomepageRoute } from "@/lib/constants"; import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase"; import { signupFormSchema } from "@/types/auth"; import * as Sentry from "@sentry/nextjs"; @@ -59,10 +58,8 @@ export async function signup( } // Get onboarding status from backend (includes chat flag evaluated for this user) - const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus(); - const next = shouldShowOnboarding - ? "/onboarding" - : getHomepageRoute(isChatEnabled); + const { shouldShowOnboarding } = await getOnboardingStatus(); + const next = shouldShowOnboarding ? "/onboarding" : "/"; return { success: true, next }; } catch (err) { diff --git a/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts b/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts index 5fa8c2c159..fd78b48735 100644 --- a/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts +++ b/autogpt_platform/frontend/src/app/(platform)/signup/useSignupPage.ts @@ -1,8 +1,6 @@ import { useToast } from "@/components/molecules/Toast/use-toast"; -import { getHomepageRoute } from "@/lib/constants"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { environment } from "@/services/environment"; -import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { LoginProvider, signupFormSchema } from "@/types/auth"; import { zodResolver } from "@hookform/resolvers/zod"; import { useRouter, useSearchParams } from "next/navigation"; @@ -22,17 +20,15 @@ export function useSignupPage() { const [isGoogleLoading, setIsGoogleLoading] = useState(false); const [showNotAllowedModal, setShowNotAllowedModal] = 
useState(false); const isCloudEnv = environment.isCloud(); - const isChatEnabled = useGetFlag(Flag.CHAT); - const homepageRoute = getHomepageRoute(isChatEnabled); // Get redirect destination from 'next' query parameter const nextUrl = searchParams.get("next"); useEffect(() => { if (isLoggedIn && !isSigningUp) { - router.push(nextUrl || homepageRoute); + router.push(nextUrl || "/"); } - }, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]); + }, [isLoggedIn, isSigningUp, nextUrl, router]); const form = useForm<z.infer<typeof signupFormSchema>>({ resolver: zodResolver(signupFormSchema), @@ -133,7 +129,7 @@ export function useSignupPage() { } // Prefer the URL's next parameter, then result.next (for onboarding), then default - const redirectTo = nextUrl || result.next || homepageRoute; + const redirectTo = nextUrl || result.next || "/"; router.replace(redirectTo); } catch (error) { setIsLoading(false); diff --git a/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts new file mode 100644 index 0000000000..336786bfdb --- /dev/null +++ b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts @@ -0,0 +1,81 @@ +import { environment } from "@/services/environment"; +import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers"; +import { NextRequest } from "next/server"; + +/** + * SSE Proxy for task stream reconnection. + * + * This endpoint allows clients to reconnect to an ongoing or recently completed + * background task's stream. It replays missed messages from Redis Streams and + * subscribes to live updates if the task is still running. + * + * Client contract: + * 1. When receiving an operation_started event, store the task_id + * 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx} + * 3. Messages are replayed from the last_message_id position + * 4.
Stream ends when "finish" event is received + */ +export async function GET( + request: NextRequest, + { params }: { params: Promise<{ taskId: string }> }, +) { + const { taskId } = await params; + const searchParams = request.nextUrl.searchParams; + const lastMessageId = searchParams.get("last_message_id") || "0-0"; + + try { + // Get auth token from server-side session + const token = await getServerAuthToken(); + + // Build backend URL + const backendUrl = environment.getAGPTServerBaseUrl(); + const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl); + streamUrl.searchParams.set("last_message_id", lastMessageId); + + // Forward request to backend with auth header + const headers: Record<string, string> = { + Accept: "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }; + + if (token) { + headers["Authorization"] = `Bearer ${token}`; + } + + const response = await fetch(streamUrl.toString(), { + method: "GET", + headers, + }); + + if (!response.ok) { + const error = await response.text(); + return new Response(error, { + status: response.status, + headers: { "Content-Type": "application/json" }, + }); + } + + // Return the SSE stream directly + return new Response(response.body, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + "X-Accel-Buffering": "no", + }, + }); + } catch (error) { + console.error("Task stream proxy error:", error); + return new Response( + JSON.stringify({ + error: "Failed to connect to task stream", + detail: error instanceof Error ?
error.message : String(error), + }), + { + status: 500, + headers: { "Content-Type": "application/json" }, + }, + ); + } +} diff --git a/autogpt_platform/frontend/src/app/api/helpers.ts b/autogpt_platform/frontend/src/app/api/helpers.ts index c2104d231a..226f5fa786 100644 --- a/autogpt_platform/frontend/src/app/api/helpers.ts +++ b/autogpt_platform/frontend/src/app/api/helpers.ts @@ -181,6 +181,5 @@ export async function getOnboardingStatus() { const isCompleted = onboarding.completedSteps.includes("CONGRATS"); return { shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted, - isChatEnabled: status.is_chat_enabled, }; } diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index 6692c30e72..5ed449829d 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -917,6 +917,28 @@ "security": [{ "HTTPBearerJWT": [] }] } }, + "/api/chat/config/ttl": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Ttl Config", + "description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.", + "operationId": "getV2GetTtlConfig", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "type": "object", + "title": "Response Getv2Getttlconfig" + } + } + } + } + } + } + }, "/api/chat/health": { "get": { "tags": ["v2", "chat", "chat"], @@ -939,6 +961,63 @@ } } }, + "/api/chat/operations/{operation_id}/complete": { + "post": { + "tags": ["v2", "chat", "chat"], + "summary": "Complete Operation", + "description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis 
triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.", + "operationId": "postV2CompleteOperation", + "parameters": [ + { + "name": "operation_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Operation Id" } + }, + { + "name": "x-api-key", + "in": "header", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "X-Api-Key" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OperationCompleteRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Postv2Completeoperation" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/chat/sessions": { "get": { "tags": ["v2", "chat", "chat"], @@ -1022,7 +1101,7 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Session", - "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.", + "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat 
session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.", "operationId": "getV2GetSession", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1157,7 +1236,7 @@ "post": { "tags": ["v2", "chat", "chat"], "summary": "Stream Chat Post", - "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.", + "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. 
If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.", "operationId": "postV2StreamChatPost", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1195,6 +1274,94 @@ } } }, + "/api/chat/tasks/{task_id}": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Task Status", + "description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.", + "operationId": "getV2GetTaskStatus", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Getv2Gettaskstatus" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/chat/tasks/{task_id}/stream": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Stream Task", + "description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the 
client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.", + "operationId": "getV2StreamTask", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + }, + { + "name": "last_message_id", + "in": "query", + "required": false, + "schema": { + "type": "string", + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + "default": "0-0", + "title": "Last Message Id" + }, + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay." + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/credits": { "get": { "tags": ["v1", "credits"], @@ -6168,6 +6335,18 @@ "title": "AccuracyTrendsResponse", "description": "Response model for accuracy trends and alerts." 
}, + "ActiveStreamInfo": { + "properties": { + "task_id": { "type": "string", "title": "Task Id" }, + "last_message_id": { "type": "string", "title": "Last Message Id" }, + "operation_id": { "type": "string", "title": "Operation Id" }, + "tool_name": { "type": "string", "title": "Tool Name" } + }, + "type": "object", + "required": ["task_id", "last_message_id", "operation_id", "tool_name"], + "title": "ActiveStreamInfo", + "description": "Information about an active stream for reconnection." + }, "AddUserCreditsResponse": { "properties": { "new_balance": { "type": "integer", "title": "New Balance" }, @@ -7981,6 +8160,25 @@ ] }, "new_output": { "type": "boolean", "title": "New Output" }, + "execution_count": { + "type": "integer", + "title": "Execution Count", + "default": 0 + }, + "success_rate": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Success Rate" + }, + "avg_correctness_score": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Avg Correctness Score" + }, + "recent_executions": { + "items": { "$ref": "#/components/schemas/RecentExecution" }, + "type": "array", + "title": "Recent Executions", + "description": "List of recent executions with status, score, and summary" + }, "can_access_graph": { "type": "boolean", "title": "Can Access Graph" @@ -8804,6 +9002,27 @@ ], "title": "OnboardingStep" }, + "OperationCompleteRequest": { + "properties": { + "success": { "type": "boolean", "title": "Success" }, + "result": { + "anyOf": [ + { "additionalProperties": true, "type": "object" }, + { "type": "string" }, + { "type": "null" } + ], + "title": "Result" + }, + "error": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Error" + } + }, + "type": "object", + "required": ["success"], + "title": "OperationCompleteRequest", + "description": "Request model for external completion webhook." 
+ }, "Pagination": { "properties": { "total_items": { @@ -9374,6 +9593,23 @@ "required": ["providers", "pagination"], "title": "ProviderResponse" }, + "RecentExecution": { + "properties": { + "status": { "type": "string", "title": "Status" }, + "correctness_score": { + "anyOf": [{ "type": "number" }, { "type": "null" }], + "title": "Correctness Score" + }, + "activity_summary": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Activity Summary" + } + }, + "type": "object", + "required": ["status"], + "title": "RecentExecution", + "description": "Summary of a recent execution for quality assessment.\n\nUsed by the LLM to understand the agent's recent performance with specific examples\nrather than just aggregate statistics." + }, "RefundRequest": { "properties": { "id": { "type": "string", "title": "Id" }, @@ -9642,6 +9878,12 @@ "items": { "additionalProperties": true, "type": "object" }, "type": "array", "title": "Messages" + }, + "active_stream": { + "anyOf": [ + { "$ref": "#/components/schemas/ActiveStreamInfo" }, + { "type": "null" } + ] } }, "type": "object", @@ -9797,7 +10039,8 @@ "sub_heading": { "type": "string", "title": "Sub Heading" }, "description": { "type": "string", "title": "Description" }, "runs": { "type": "integer", "title": "Runs" }, - "rating": { "type": "number", "title": "Rating" } + "rating": { "type": "number", "title": "Rating" }, + "agent_graph_id": { "type": "string", "title": "Agent Graph Id" } }, "type": "object", "required": [ @@ -9809,7 +10052,8 @@ "sub_heading", "description", "runs", - "rating" + "rating", + "agent_graph_id" ], "title": "StoreAgent" }, diff --git a/autogpt_platform/frontend/src/app/page.tsx b/autogpt_platform/frontend/src/app/page.tsx index dbfab49469..ce67760eda 100644 --- a/autogpt_platform/frontend/src/app/page.tsx +++ b/autogpt_platform/frontend/src/app/page.tsx @@ -1,27 +1,15 @@ "use client"; -import { getHomepageRoute } from "@/lib/constants"; -import { Flag, useGetFlag } from 
"@/services/feature-flags/use-get-flag"; +import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { useRouter } from "next/navigation"; import { useEffect } from "react"; export default function Page() { - const isChatEnabled = useGetFlag(Flag.CHAT); const router = useRouter(); - const homepageRoute = getHomepageRoute(isChatEnabled); - const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true"; - const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID; - const isLaunchDarklyConfigured = envEnabled && Boolean(clientId); - const isFlagReady = - !isLaunchDarklyConfigured || typeof isChatEnabled === "boolean"; - useEffect( - function redirectToHomepage() { - if (!isFlagReady) return; - router.replace(homepageRoute); - }, - [homepageRoute, isFlagReady, router], - ); + useEffect(() => { + router.replace("/copilot"); + }, [router]); - return null; + return ; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx index ada8c26231..da454150bf 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx @@ -1,7 +1,6 @@ "use client"; import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId"; -import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; @@ -25,8 +24,8 @@ export function Chat({ }: ChatProps) { const { urlSessionId } = useCopilotSessionId(); const hasHandledNotFoundRef = useRef(false); - const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession); const { + session, messages, isLoading, isCreating, @@ -38,6 +37,18 @@ export function Chat({ startPollingForOperation, } = useChat({ urlSessionId }); + // 
Extract active stream info for reconnection + const activeStream = ( + session as { + active_stream?: { + task_id: string; + last_message_id: string; + operation_id: string; + tool_name: string; + }; + } + )?.active_stream; + useEffect(() => { if (!onSessionNotFound) return; if (!urlSessionId) return; @@ -53,8 +64,7 @@ export function Chat({ isCreating, ]); - const shouldShowLoader = - (showLoader && (isLoading || isCreating)) || isSwitchingSession; + const shouldShowLoader = showLoader && (isLoading || isCreating); return (
@@ -66,21 +76,19 @@ export function Chat({
- {isSwitchingSession - ? "Switching chat..." - : "Loading your chat..."} + Loading your chat...
)} {/* Error State */} - {error && !isLoading && !isSwitchingSession && ( + {error && !isLoading && ( )} {/* Session Content */} - {sessionId && !isLoading && !error && !isSwitchingSession && ( + {sessionId && !isLoading && !error && ( )} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md new file mode 100644 index 0000000000..9e78679f4e --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md @@ -0,0 +1,159 @@ +# SSE Reconnection Contract for Long-Running Operations + +This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks. + +## Overview + +When a user triggers a long-running operation (like agent generation), the backend: + +1. Spawns a background task that survives SSE disconnections +2. Returns an `operation_started` response with a `task_id` +3. Stores stream messages in Redis Streams for replay + +Clients can reconnect to the task stream at any time to receive missed messages. + +## Client-Side Flow + +### 1. Receiving Operation Started + +When you receive an `operation_started` tool response: + +```typescript +// The response includes a task_id for reconnection +{ + type: "operation_started", + tool_name: "generate_agent", + operation_id: "uuid-...", + task_id: "task-uuid-...", // <-- Store this for reconnection + message: "Operation started. You can close this tab." +} +``` + +### 2. Storing Task Info + +Use the chat store to track the active task: + +```typescript +import { useChatStore } from "./chat-store"; + +// When operation_started is received: +useChatStore.getState().setActiveTask(sessionId, { + taskId: response.task_id, + operationId: response.operation_id, + toolName: response.tool_name, + lastMessageId: "0", +}); +``` + +### 3. 
Reconnecting to a Task + +To reconnect (e.g., after page refresh or tab reopen): + +```typescript +const { reconnectToTask, getActiveTask } = useChatStore.getState(); + +// Check if there's an active task for this session +const activeTask = getActiveTask(sessionId); + +if (activeTask) { + // Reconnect to the task stream + await reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, // Resume from last position + (chunk) => { + // Handle incoming chunks + console.log("Received chunk:", chunk); + }, + ); +} +``` + +### 4. Tracking Message Position + +To enable precise replay, update the last message ID as chunks arrive: + +```typescript +const { updateTaskLastMessageId } = useChatStore.getState(); + +function handleChunk(chunk: StreamChunk) { + // If chunk has an index/id, track it + if (chunk.idx !== undefined) { + updateTaskLastMessageId(sessionId, String(chunk.idx)); + } +} +``` + +## API Endpoints + +### Task Stream Reconnection + +``` +GET /api/chat/tasks/{taskId}/stream?last_message_id={idx} +``` + +- `taskId`: The task ID from `operation_started` +- `last_message_id`: Last received message index (default: "0" for full replay) + +Returns: SSE stream of missed messages + live updates + +## Chunk Types + +The reconnected stream follows the same Vercel AI SDK protocol: + +| Type | Description | +| ----------------------- | ----------------------- | +| `start` | Message lifecycle start | +| `text-delta` | Streaming text content | +| `text-end` | Text block completed | +| `tool-output-available` | Tool result available | +| `finish` | Stream completed | +| `error` | Error occurred | + +## Error Handling + +If reconnection fails: + +1. Check if task still exists (may have expired - default TTL: 1 hour) +2. Fall back to polling the session for final state +3. 
Show appropriate UI message to user + +## Persistence Considerations + +For robust reconnection across browser restarts: + +```typescript +// Store in localStorage/sessionStorage +const ACTIVE_TASKS_KEY = "chat_active_tasks"; + +function persistActiveTask(sessionId: string, task: ActiveTaskInfo) { + const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}"); + tasks[sessionId] = task; + localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks)); +} + +function loadPersistedTasks(): Record { + return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}"); +} +``` + +## Backend Configuration + +The following backend settings affect reconnection behavior: + +| Setting | Default | Description | +| ------------------- | ------- | ---------------------------------- | +| `stream_ttl` | 3600s | How long streams are kept in Redis | +| `stream_max_length` | 1000 | Max messages per stream | + +## Testing + +To test reconnection locally: + +1. Start a long-running operation (e.g., agent generation) +2. Note the `task_id` from the `operation_started` response +3. Close the browser tab +4. Reopen and call `reconnectToTask` with the saved `task_id` +5. Verify that missed messages are replayed + +See the main README for full local development setup. diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts new file mode 100644 index 0000000000..8802de2155 --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts @@ -0,0 +1,16 @@ +/** + * Constants for the chat system. + * + * Centralizes magic strings and values used across chat components. 
+ */ + +// LocalStorage keys +export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks"; + +// Redis Stream IDs +export const INITIAL_MESSAGE_ID = "0"; +export const INITIAL_STREAM_ID = "0-0"; + +// TTL values (in milliseconds) +export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes +export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts index 8229630e5d..3083f65d2c 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts @@ -1,6 +1,12 @@ "use client"; import { create } from "zustand"; +import { + ACTIVE_TASK_TTL_MS, + COMPLETED_STREAM_TTL_MS, + INITIAL_STREAM_ID, + STORAGE_KEY_ACTIVE_TASKS, +} from "./chat-constants"; import type { ActiveStream, StreamChunk, @@ -8,15 +14,59 @@ import type { StreamResult, StreamStatus, } from "./chat-types"; -import { executeStream } from "./stream-executor"; +import { executeStream, executeTaskReconnect } from "./stream-executor"; -const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes +export interface ActiveTaskInfo { + taskId: string; + sessionId: string; + operationId: string; + toolName: string; + lastMessageId: string; + startedAt: number; +} + +/** Load active tasks from localStorage */ +function loadPersistedTasks(): Map { + if (typeof window === "undefined") return new Map(); + try { + const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS); + if (!stored) return new Map(); + const parsed = JSON.parse(stored) as Record; + const now = Date.now(); + const tasks = new Map(); + // Filter out expired tasks + for (const [sessionId, task] of Object.entries(parsed)) { + if (now - task.startedAt < ACTIVE_TASK_TTL_MS) { + tasks.set(sessionId, task); + } + } + return tasks; + } catch { + return new Map(); + } +} + +/** Save active tasks to 
localStorage */ +function persistTasks(tasks: Map): void { + if (typeof window === "undefined") return; + try { + const obj: Record = {}; + for (const [sessionId, task] of tasks) { + obj[sessionId] = task; + } + localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj)); + } catch { + // Ignore storage errors + } +} interface ChatStoreState { activeStreams: Map; completedStreams: Map; activeSessions: Set; streamCompleteCallbacks: Set; + /** Active tasks for SSE reconnection - keyed by sessionId */ + activeTasks: Map; } interface ChatStoreActions { @@ -41,6 +91,24 @@ interface ChatStoreActions { unregisterActiveSession: (sessionId: string) => void; isSessionActive: (sessionId: string) => boolean; onStreamComplete: (callback: StreamCompleteCallback) => () => void; + /** Track active task for SSE reconnection */ + setActiveTask: ( + sessionId: string, + taskInfo: Omit, + ) => void; + /** Get active task for a session */ + getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined; + /** Clear active task when operation completes */ + clearActiveTask: (sessionId: string) => void; + /** Reconnect to an existing task stream */ + reconnectToTask: ( + sessionId: string, + taskId: string, + lastMessageId?: string, + onChunk?: (chunk: StreamChunk) => void, + ) => Promise; + /** Update last message ID for a task (for tracking replay position) */ + updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void; } type ChatStore = ChatStoreState & ChatStoreActions; @@ -64,18 +132,126 @@ function cleanupExpiredStreams( const now = Date.now(); const cleaned = new Map(completedStreams); for (const [sessionId, result] of cleaned) { - if (now - result.completedAt > COMPLETED_STREAM_TTL) { + if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) { cleaned.delete(sessionId); } } return cleaned; } +/** + * Finalize a stream by moving it from activeStreams to completedStreams. + * Also handles cleanup and notifications. 
+ */ +function finalizeStream( + sessionId: string, + stream: ActiveStream, + onChunk: ((chunk: StreamChunk) => void) | undefined, + get: () => ChatStoreState & ChatStoreActions, + set: (state: Partial) => void, +): void { + if (onChunk) stream.onChunkCallbacks.delete(onChunk); + + if (stream.status !== "streaming") { + const currentState = get(); + const finalActiveStreams = new Map(currentState.activeStreams); + let finalCompletedStreams = new Map(currentState.completedStreams); + + const storedStream = finalActiveStreams.get(sessionId); + if (storedStream === stream) { + const result: StreamResult = { + sessionId, + status: stream.status, + chunks: stream.chunks, + completedAt: Date.now(), + error: stream.error, + }; + finalCompletedStreams.set(sessionId, result); + finalActiveStreams.delete(sessionId); + finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams); + set({ + activeStreams: finalActiveStreams, + completedStreams: finalCompletedStreams, + }); + + if (stream.status === "completed" || stream.status === "error") { + notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId); + } + } + } +} + +/** + * Clean up an existing stream for a session and move it to completed streams. + * Returns updated maps for both active and completed streams. + */ +function cleanupExistingStream( + sessionId: string, + activeStreams: Map, + completedStreams: Map, + callbacks: Set, +): { + activeStreams: Map; + completedStreams: Map; +} { + const newActiveStreams = new Map(activeStreams); + let newCompletedStreams = new Map(completedStreams); + + const existingStream = newActiveStreams.get(sessionId); + if (existingStream) { + existingStream.abortController.abort(); + const normalizedStatus = + existingStream.status === "streaming" + ? 
"completed" + : existingStream.status; + const result: StreamResult = { + sessionId, + status: normalizedStatus, + chunks: existingStream.chunks, + completedAt: Date.now(), + error: existingStream.error, + }; + newCompletedStreams.set(sessionId, result); + newActiveStreams.delete(sessionId); + newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); + if (normalizedStatus === "completed" || normalizedStatus === "error") { + notifyStreamComplete(callbacks, sessionId); + } + } + + return { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }; +} + +/** + * Create a new active stream with initial state. + */ +function createActiveStream( + sessionId: string, + onChunk?: (chunk: StreamChunk) => void, +): ActiveStream { + const abortController = new AbortController(); + const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); + if (onChunk) initialCallbacks.add(onChunk); + + return { + sessionId, + abortController, + status: "streaming", + startedAt: Date.now(), + chunks: [], + onChunkCallbacks: initialCallbacks, + }; +} + export const useChatStore = create((set, get) => ({ activeStreams: new Map(), completedStreams: new Map(), activeSessions: new Set(), streamCompleteCallbacks: new Set(), + activeTasks: loadPersistedTasks(), startStream: async function startStream( sessionId, @@ -85,45 +261,21 @@ export const useChatStore = create((set, get) => ({ onChunk, ) { const state = get(); - const newActiveStreams = new Map(state.activeStreams); - let newCompletedStreams = new Map(state.completedStreams); const callbacks = state.streamCompleteCallbacks; - const existingStream = newActiveStreams.get(sessionId); - if (existingStream) { - existingStream.abortController.abort(); - const normalizedStatus = - existingStream.status === "streaming" - ? 
"completed" - : existingStream.status; - const result: StreamResult = { - sessionId, - status: normalizedStatus, - chunks: existingStream.chunks, - completedAt: Date.now(), - error: existingStream.error, - }; - newCompletedStreams.set(sessionId, result); - newActiveStreams.delete(sessionId); - newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); - if (normalizedStatus === "completed" || normalizedStatus === "error") { - notifyStreamComplete(callbacks, sessionId); - } - } - - const abortController = new AbortController(); - const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); - if (onChunk) initialCallbacks.add(onChunk); - - const stream: ActiveStream = { + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( sessionId, - abortController, - status: "streaming", - startedAt: Date.now(), - chunks: [], - onChunkCallbacks: initialCallbacks, - }; + state.activeStreams, + state.completedStreams, + callbacks, + ); + // Create new stream + const stream = createActiveStream(sessionId, onChunk); newActiveStreams.set(sessionId, stream); set({ activeStreams: newActiveStreams, @@ -133,36 +285,7 @@ export const useChatStore = create((set, get) => ({ try { await executeStream(stream, message, isUserMessage, context); } finally { - if (onChunk) stream.onChunkCallbacks.delete(onChunk); - if (stream.status !== "streaming") { - const currentState = get(); - const finalActiveStreams = new Map(currentState.activeStreams); - let finalCompletedStreams = new Map(currentState.completedStreams); - - const storedStream = finalActiveStreams.get(sessionId); - if (storedStream === stream) { - const result: StreamResult = { - sessionId, - status: stream.status, - chunks: stream.chunks, - completedAt: Date.now(), - error: stream.error, - }; - finalCompletedStreams.set(sessionId, result); - finalActiveStreams.delete(sessionId); - finalCompletedStreams = 
cleanupExpiredStreams(finalCompletedStreams); - set({ - activeStreams: finalActiveStreams, - completedStreams: finalCompletedStreams, - }); - if (stream.status === "completed" || stream.status === "error") { - notifyStreamComplete( - currentState.streamCompleteCallbacks, - sessionId, - ); - } - } - } + finalizeStream(sessionId, stream, onChunk, get, set); } }, @@ -286,4 +409,93 @@ export const useChatStore = create((set, get) => ({ set({ streamCompleteCallbacks: cleanedCallbacks }); }; }, + + setActiveTask: function setActiveTask(sessionId, taskInfo) { + const state = get(); + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...taskInfo, + sessionId, + startedAt: Date.now(), + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + getActiveTask: function getActiveTask(sessionId) { + return get().activeTasks.get(sessionId); + }, + + clearActiveTask: function clearActiveTask(sessionId) { + const state = get(); + if (!state.activeTasks.has(sessionId)) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + reconnectToTask: async function reconnectToTask( + sessionId, + taskId, + lastMessageId = INITIAL_STREAM_ID, + onChunk, + ) { + const state = get(); + const callbacks = state.streamCompleteCallbacks; + + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( + sessionId, + state.activeStreams, + state.completedStreams, + callbacks, + ); + + // Create new stream for reconnection + const stream = createActiveStream(sessionId, onChunk); + newActiveStreams.set(sessionId, stream); + set({ + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }); + + try { + await executeTaskReconnect(stream, taskId, lastMessageId); + } finally { + 
finalizeStream(sessionId, stream, onChunk, get, set); + + // Clear active task on completion + if (stream.status === "completed" || stream.status === "error") { + const taskState = get(); + if (taskState.activeTasks.has(sessionId)) { + const newActiveTasks = new Map(taskState.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + } + } + } + }, + + updateTaskLastMessageId: function updateTaskLastMessageId( + sessionId, + lastMessageId, + ) { + const state = get(); + const task = state.activeTasks.get(sessionId); + if (!task) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...task, + lastMessageId, + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, })); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts index 8c8aa7b704..34813e17fe 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts @@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error"; export interface StreamChunk { type: + | "stream_start" | "text_chunk" | "text_ended" | "tool_call" @@ -15,6 +16,7 @@ export interface StreamChunk { | "error" | "usage" | "stream_end"; + taskId?: string; timestamp?: string; content?: string; message?: string; @@ -41,7 +43,7 @@ export interface StreamChunk { } export type VercelStreamChunk = - | { type: "start"; messageId: string } + | { type: "start"; messageId: string; taskId?: string } | { type: "finish" } | { type: "text-start"; id: string } | { type: "text-delta"; id: string; delta: string } @@ -92,3 +94,70 @@ export interface StreamResult { } export type StreamCompleteCallback = (sessionId: string) => void; + +// Type guards for message types + +/** + * Check if a message has a toolId 
property. + */ +export function hasToolId( + msg: T, +): msg is T & { toolId: string } { + return ( + "toolId" in msg && + typeof (msg as Record).toolId === "string" + ); +} + +/** + * Check if a message has an operationId property. + */ +export function hasOperationId( + msg: T, +): msg is T & { operationId: string } { + return ( + "operationId" in msg && + typeof (msg as Record).operationId === "string" + ); +} + +/** + * Check if a message has a toolCallId property. + */ +export function hasToolCallId( + msg: T, +): msg is T & { toolCallId: string } { + return ( + "toolCallId" in msg && + typeof (msg as Record).toolCallId === "string" + ); +} + +/** + * Check if a message is an operation message type. + */ +export function isOperationMessage( + msg: T, +): msg is T & { + type: "operation_started" | "operation_pending" | "operation_in_progress"; +} { + return ( + msg.type === "operation_started" || + msg.type === "operation_pending" || + msg.type === "operation_in_progress" + ); +} + +/** + * Get the tool ID from a message if available. + * Checks toolId, operationId, and toolCallId properties. 
+ */ +export function getToolIdFromMessage( + msg: T, +): string | undefined { + const record = msg as Record; + if (typeof record.toolId === "string") return record.toolId; + if (typeof record.operationId === "string") return record.operationId; + if (typeof record.toolCallId === "string") return record.toolCallId; + return undefined; +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx index dec221338a..fbf2d5d143 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx @@ -2,7 +2,6 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi import { Button } from "@/components/atoms/Button/Button"; import { Text } from "@/components/atoms/Text/Text"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; -import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { cn } from "@/lib/utils"; import { GlobeHemisphereEastIcon } from "@phosphor-icons/react"; import { useEffect } from "react"; @@ -17,6 +16,13 @@ export interface ChatContainerProps { className?: string; onStreamingChange?: (isStreaming: boolean) => void; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + operationId: string; + toolName: string; + }; } export function ChatContainer({ @@ -26,6 +32,7 @@ export function ChatContainer({ className, onStreamingChange, onOperationStarted, + activeStream, }: ChatContainerProps) { const { messages, @@ -41,16 +48,13 @@ export function ChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }); useEffect(() => { onStreamingChange?.(isStreaming); }, [isStreaming, 
onStreamingChange]); - const breakpoint = useBreakpoint(); - const isMobile = - breakpoint === "base" || breakpoint === "sm" || breakpoint === "md"; - return (
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts index 82e9b05e88..af3b3329b7 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts @@ -2,6 +2,7 @@ import { toast } from "sonner"; import type { StreamChunk } from "../../chat-types"; import type { HandlerDependencies } from "./handlers"; import { + getErrorDisplayMessage, handleError, handleLoginNeeded, handleStreamEnd, @@ -24,16 +25,22 @@ export function createStreamEventDispatcher( chunk.type === "need_login" || chunk.type === "error" ) { - if (!deps.hasResponseRef.current) { - console.info("[ChatStream] First response chunk:", { - type: chunk.type, - sessionId: deps.sessionId, - }); - } deps.hasResponseRef.current = true; } switch (chunk.type) { + case "stream_start": + // Store task ID for SSE reconnection + if (chunk.taskId && deps.onActiveTaskStarted) { + deps.onActiveTaskStarted({ + taskId: chunk.taskId, + operationId: chunk.taskId, + toolName: "chat", + toolCallId: "chat_stream", + }); + } + break; + case "text_chunk": handleTextChunk(chunk, deps); break; @@ -56,11 +63,7 @@ export function createStreamEventDispatcher( break; case "stream_end": - console.info("[ChatStream] Stream ended:", { - sessionId: deps.sessionId, - hasResponse: deps.hasResponseRef.current, - chunkCount: deps.streamingChunksRef.current.length, - }); + // Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk handleStreamEnd(chunk, deps); break; @@ -70,7 +73,7 @@ export function createStreamEventDispatcher( // Show toast at dispatcher level to avoid circular dependencies if (!isRegionBlocked) { toast.error("Chat Error", { - 
description: chunk.message || chunk.content || "An error occurred", + description: getErrorDisplayMessage(chunk), }); } break; diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts index f3cac01f96..5aec5b9818 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts @@ -18,11 +18,19 @@ export interface HandlerDependencies { setStreamingChunks: Dispatch>; streamingChunksRef: MutableRefObject; hasResponseRef: MutableRefObject; + textFinalizedRef: MutableRefObject; + streamEndedRef: MutableRefObject; setMessages: Dispatch>; setIsStreamingInitiated: Dispatch>; setIsRegionBlockedModalOpen: Dispatch>; sessionId: string; onOperationStarted?: () => void; + onActiveTaskStarted?: (taskInfo: { + taskId: string; + operationId: string; + toolName: string; + toolCallId: string; + }) => void; } export function isRegionBlockedError(chunk: StreamChunk): boolean { @@ -32,6 +40,25 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean { return message.toLowerCase().includes("not available in your region"); } +export function getUserFriendlyErrorMessage( + code: string | undefined, +): string | undefined { + switch (code) { + case "TASK_EXPIRED": + return "This operation has expired. Please try again."; + case "TASK_NOT_FOUND": + return "Could not find the requested operation."; + case "ACCESS_DENIED": + return "You do not have access to this operation."; + case "QUEUE_OVERFLOW": + return "Connection was interrupted. 
Please refresh to continue."; + case "MODEL_NOT_AVAILABLE_REGION": + return "This model is not available in your region."; + default: + return undefined; + } +} + export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) { if (!chunk.content) return; deps.setHasTextChunks(true); @@ -46,10 +73,15 @@ export function handleTextEnded( _chunk: StreamChunk, deps: HandlerDependencies, ) { + if (deps.textFinalizedRef.current) { + return; + } + const completedText = deps.streamingChunksRef.current.join(""); if (completedText.trim()) { + deps.textFinalizedRef.current = true; + deps.setMessages((prev) => { - // Check if this exact message already exists to prevent duplicates const exists = prev.some( (msg) => msg.type === "message" && @@ -76,9 +108,14 @@ export function handleToolCallStart( chunk: StreamChunk, deps: HandlerDependencies, ) { + // Use deterministic fallback instead of Date.now() to ensure same ID on replay + const toolId = + chunk.tool_id || + `tool-${deps.sessionId}-${chunk.idx ?? 
"unknown"}-${chunk.tool_name || "unknown"}`; + const toolCallMessage: Extract = { type: "tool_call", - toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`, + toolId, toolName: chunk.tool_name || "Executing", arguments: chunk.arguments || {}, timestamp: new Date(), @@ -111,6 +148,29 @@ export function handleToolCallStart( deps.setMessages(updateToolCallMessages); } +const TOOL_RESPONSE_TYPES = new Set([ + "tool_response", + "operation_started", + "operation_pending", + "operation_in_progress", + "execution_started", + "agent_carousel", + "clarification_needed", +]); + +function hasResponseForTool( + messages: ChatMessageData[], + toolId: string, +): boolean { + return messages.some((msg) => { + if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false; + const msgToolId = + (msg as { toolId?: string }).toolId || + (msg as { toolCallId?: string }).toolCallId; + return msgToolId === toolId; + }); +} + export function handleToolResponse( chunk: StreamChunk, deps: HandlerDependencies, @@ -152,31 +212,49 @@ export function handleToolResponse( ) { const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name); if (inputsMessage) { - deps.setMessages((prev) => [...prev, inputsMessage]); + deps.setMessages((prev) => { + // Check for duplicate inputs_needed message + const exists = prev.some((msg) => msg.type === "inputs_needed"); + if (exists) return prev; + return [...prev, inputsMessage]; + }); } const credentialsMessage = extractCredentialsNeeded( parsedResult, chunk.tool_name, ); if (credentialsMessage) { - deps.setMessages((prev) => [...prev, credentialsMessage]); + deps.setMessages((prev) => { + // Check for duplicate credentials_needed message + const exists = prev.some((msg) => msg.type === "credentials_needed"); + if (exists) return prev; + return [...prev, credentialsMessage]; + }); } } return; } - // Trigger polling when operation_started is received if (responseMessage.type === "operation_started") { deps.onOperationStarted?.(); + const taskId = 
(responseMessage as { taskId?: string }).taskId; + if (taskId && deps.onActiveTaskStarted) { + deps.onActiveTaskStarted({ + taskId, + operationId: + (responseMessage as { operationId?: string }).operationId || "", + toolName: (responseMessage as { toolName?: string }).toolName || "", + toolCallId: (responseMessage as { toolId?: string }).toolId || "", + }); + } } deps.setMessages((prev) => { const toolCallIndex = prev.findIndex( (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id, ); - const hasResponse = prev.some( - (msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id, - ); - if (hasResponse) return prev; + if (hasResponseForTool(prev, chunk.tool_id!)) { + return prev; + } if (toolCallIndex !== -1) { const newMessages = [...prev]; newMessages.splice(toolCallIndex + 1, 0, responseMessage); @@ -198,28 +276,48 @@ export function handleLoginNeeded( agentInfo: chunk.agent_info, timestamp: new Date(), }; - deps.setMessages((prev) => [...prev, loginNeededMessage]); + deps.setMessages((prev) => { + // Check for duplicate login_needed message + const exists = prev.some((msg) => msg.type === "login_needed"); + if (exists) return prev; + return [...prev, loginNeededMessage]; + }); } export function handleStreamEnd( _chunk: StreamChunk, deps: HandlerDependencies, ) { + if (deps.streamEndedRef.current) { + return; + } + deps.streamEndedRef.current = true; + const completedContent = deps.streamingChunksRef.current.join(""); if (!completedContent.trim() && !deps.hasResponseRef.current) { - deps.setMessages((prev) => [ - ...prev, - { - type: "message", - role: "assistant", - content: "No response received. Please try again.", - timestamp: new Date(), - }, - ]); - } - if (completedContent.trim()) { deps.setMessages((prev) => { - // Check if this exact message already exists to prevent duplicates + const exists = prev.some( + (msg) => + msg.type === "message" && + msg.role === "assistant" && + msg.content === "No response received. 
Please try again.", + ); + if (exists) return prev; + return [ + ...prev, + { + type: "message", + role: "assistant", + content: "No response received. Please try again.", + timestamp: new Date(), + }, + ]; + }); + } + if (completedContent.trim() && !deps.textFinalizedRef.current) { + deps.textFinalizedRef.current = true; + + deps.setMessages((prev) => { const exists = prev.some( (msg) => msg.type === "message" && @@ -244,8 +342,6 @@ export function handleStreamEnd( } export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { - const errorMessage = chunk.message || chunk.content || "An error occurred"; - console.error("Stream error:", errorMessage); if (isRegionBlockedError(chunk)) { deps.setIsRegionBlockedModalOpen(true); } @@ -253,4 +349,14 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { deps.setHasTextChunks(false); deps.setStreamingChunks([]); deps.streamingChunksRef.current = []; + deps.textFinalizedRef.current = false; + deps.streamEndedRef.current = true; +} + +export function getErrorDisplayMessage(chunk: StreamChunk): string { + const friendlyMessage = getUserFriendlyErrorMessage(chunk.code); + if (friendlyMessage) { + return friendlyMessage; + } + return chunk.message || chunk.content || "An error occurred"; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts index e744c9bc34..f1e94cea17 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts @@ -349,6 +349,7 @@ export function parseToolResponse( toolName: (parsedResult.tool_name as string) || toolName, toolId, operationId: (parsedResult.operation_id as string) || "", + taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection message: 
(parsedResult.message as string) || "Operation started. You can close this tab.", diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts index 46f384d055..248383df42 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts @@ -1,10 +1,17 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; import { useEffect, useMemo, useRef, useState } from "react"; +import { INITIAL_STREAM_ID } from "../../chat-constants"; import { useChatStore } from "../../chat-store"; import { toast } from "sonner"; import { useChatStream } from "../../useChatStream"; import { usePageContext } from "../../usePageContext"; import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { + getToolIdFromMessage, + hasToolId, + isOperationMessage, + type StreamChunk, +} from "../../chat-types"; import { createStreamEventDispatcher } from "./createStreamEventDispatcher"; import { createUserMessage, @@ -14,6 +21,13 @@ import { processInitialMessages, } from "./helpers"; +const TOOL_RESULT_TYPES = new Set([ + "tool_response", + "agent_carousel", + "execution_started", + "clarification_needed", +]); + // Helper to generate deduplication key for a message function getMessageKey(msg: ChatMessageData): string { if (msg.type === "message") { @@ -23,14 +37,18 @@ function getMessageKey(msg: ChatMessageData): string { return `msg:${msg.role}:${msg.content}`; } else if (msg.type === "tool_call") { return `toolcall:${msg.toolId}`; - } else if (msg.type === "tool_response") { - return `toolresponse:${(msg as any).toolId}`; - } else if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type 
=== "operation_in_progress" - ) { - return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`; + } else if (TOOL_RESULT_TYPES.has(msg.type)) { + // Unified key for all tool result types - same toolId with different types + // (tool_response vs agent_carousel) should deduplicate to the same key + const toolId = getToolIdFromMessage(msg); + // If no toolId, fall back to content-based key to avoid empty key collisions + if (!toolId) { + return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`; + } + return `toolresult:${toolId}`; + } else if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg) || ""; + return `op:${toolId}:${msg.toolName}`; } else { return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`; } @@ -41,6 +59,13 @@ interface Args { initialMessages: SessionDetailResponse["messages"]; initialPrompt?: string; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + operationId: string; + toolName: string; + }; } export function useChatContainer({ @@ -48,6 +73,7 @@ export function useChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }: Args) { const [messages, setMessages] = useState([]); const [streamingChunks, setStreamingChunks] = useState([]); @@ -57,6 +83,8 @@ export function useChatContainer({ useState(false); const hasResponseRef = useRef(false); const streamingChunksRef = useRef([]); + const textFinalizedRef = useRef(false); + const streamEndedRef = useRef(false); const previousSessionIdRef = useRef(null); const { error, @@ -65,44 +93,182 @@ export function useChatContainer({ } = useChatStream(); const activeStreams = useChatStore((s) => s.activeStreams); const subscribeToStream = useChatStore((s) => s.subscribeToStream); + const setActiveTask = useChatStore((s) => s.setActiveTask); + const getActiveTask = useChatStore((s) => 
s.getActiveTask); + const reconnectToTask = useChatStore((s) => s.reconnectToTask); const isStreaming = isStreamingInitiated || hasTextChunks; + // Track whether we've already connected to this activeStream to avoid duplicate connections + const connectedActiveStreamRef = useRef(null); + // Track if component is mounted to prevent state updates after unmount + const isMountedRef = useRef(true); + // Track current dispatcher to prevent multiple dispatchers from adding messages + const currentDispatcherIdRef = useRef(0); + + // Set mounted flag - reset on every mount, cleanup on unmount + useEffect(function trackMountedState() { + isMountedRef.current = true; + return function cleanup() { + isMountedRef.current = false; + }; + }, []); + + // Callback to store active task info for SSE reconnection + function handleActiveTaskStarted(taskInfo: { + taskId: string; + operationId: string; + toolName: string; + toolCallId: string; + }) { + if (!sessionId) return; + setActiveTask(sessionId, { + taskId: taskInfo.taskId, + operationId: taskInfo.operationId, + toolName: taskInfo.toolName, + lastMessageId: INITIAL_STREAM_ID, + }); + } + + // Create dispatcher for stream events - stable reference for current sessionId + // Each dispatcher gets a unique ID to prevent stale dispatchers from updating state + function createDispatcher() { + if (!sessionId) return () => {}; + // Increment dispatcher ID - only the most recent dispatcher should update state + const dispatcherId = ++currentDispatcherIdRef.current; + + const baseDispatcher = createStreamEventDispatcher({ + setHasTextChunks, + setStreamingChunks, + streamingChunksRef, + hasResponseRef, + textFinalizedRef, + streamEndedRef, + setMessages, + setIsRegionBlockedModalOpen, + sessionId, + setIsStreamingInitiated, + onOperationStarted, + onActiveTaskStarted: handleActiveTaskStarted, + }); + + // Wrap dispatcher to check if it's still the current one + return function guardedDispatcher(chunk: StreamChunk) { + // Skip if component 
unmounted or this is a stale dispatcher + if (!isMountedRef.current) { + return; + } + if (dispatcherId !== currentDispatcherIdRef.current) { + return; + } + baseDispatcher(chunk); + }; + } useEffect( function handleSessionChange() { - if (sessionId === previousSessionIdRef.current) return; + const isSessionChange = sessionId !== previousSessionIdRef.current; - const prevSession = previousSessionIdRef.current; - if (prevSession) { - stopStreaming(prevSession); + // Handle session change - reset state + if (isSessionChange) { + const prevSession = previousSessionIdRef.current; + if (prevSession) { + stopStreaming(prevSession); + } + previousSessionIdRef.current = sessionId; + connectedActiveStreamRef.current = null; + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + setIsStreamingInitiated(false); + hasResponseRef.current = false; + textFinalizedRef.current = false; + streamEndedRef.current = false; } - previousSessionIdRef.current = sessionId; - setMessages([]); - setStreamingChunks([]); - streamingChunksRef.current = []; - setHasTextChunks(false); - setIsStreamingInitiated(false); - hasResponseRef.current = false; if (!sessionId) return; - const activeStream = activeStreams.get(sessionId); - if (!activeStream || activeStream.status !== "streaming") return; + // Priority 1: Check if server told us there's an active stream (most authoritative) + if (activeStream) { + const streamKey = `${sessionId}:${activeStream.taskId}`; - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - hasResponseRef, - setMessages, - setIsRegionBlockedModalOpen, - sessionId, - setIsStreamingInitiated, - onOperationStarted, - }); + if (connectedActiveStreamRef.current === streamKey) { + return; + } + + // Skip if there's already an active stream for this session in the store + const existingStream = activeStreams.get(sessionId); + if (existingStream && existingStream.status 
=== "streaming") { + connectedActiveStreamRef.current = streamKey; + return; + } + + connectedActiveStreamRef.current = streamKey; + + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + setActiveTask(sessionId, { + taskId: activeStream.taskId, + operationId: activeStream.operationId, + toolName: activeStream.toolName, + lastMessageId: activeStream.lastMessageId, + }); + reconnectToTask( + sessionId, + activeStream.taskId, + activeStream.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + // and the stream will complete naturally. Cleanup would prematurely stop + // the stream when effect re-runs due to activeStreams changing. 
+ return; + } + + // Only check localStorage/in-memory on session change + if (!isSessionChange) return; + + // Priority 2: Check localStorage for active task + const activeTask = getActiveTask(sessionId); + if (activeTask) { + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + return; + } + + // Priority 3: Check for an in-memory active stream (same-tab scenario) + const inMemoryStream = activeStreams.get(sessionId); + if (!inMemoryStream || inMemoryStream.status !== "streaming") { + return; + } setIsStreamingInitiated(true); const skipReplay = initialMessages.length > 0; - return subscribeToStream(sessionId, dispatcher, skipReplay); + return subscribeToStream(sessionId, createDispatcher(), skipReplay); }, [ sessionId, @@ -110,6 +276,10 @@ export function useChatContainer({ activeStreams, subscribeToStream, onOperationStarted, + getActiveTask, + reconnectToTask, + activeStream, + setActiveTask, ], ); @@ -124,7 +294,7 @@ export function useChatContainer({ msg.type === "agent_carousel" || msg.type === "execution_started" ) { - const toolId = (msg as any).toolId; + const toolId = hasToolId(msg) ? 
msg.toolId : undefined; if (toolId) { ids.add(toolId); } @@ -141,12 +311,8 @@ export function useChatContainer({ setMessages((prev) => { const filtered = prev.filter((msg) => { - if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg); if (toolId && completedToolIds.has(toolId)) { return false; // Remove - operation completed } @@ -174,12 +340,8 @@ export function useChatContainer({ // Filter local messages: remove duplicates and completed operation messages const newLocalMessages = messages.filter((msg) => { // Remove operation messages for completed tools - if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - const toolId = (msg as any).toolId || (msg as any).toolCallId; + if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg); if (toolId && completedToolIds.has(toolId)) { return false; } @@ -190,7 +352,70 @@ export function useChatContainer({ }); // Server messages first (correct order), then new local messages - return [...processedInitial, ...newLocalMessages]; + const combined = [...processedInitial, ...newLocalMessages]; + + // Post-processing: Remove duplicate assistant messages that can occur during + // race conditions (e.g., rapid screen switching during SSE reconnection). 
+ // Two assistant messages are considered duplicates if: + // - They are both text messages with role "assistant" + // - One message's content starts with the other's content (partial vs complete) + // - Or they have very similar content (>80% overlap at the start) + const deduplicated: ChatMessageData[] = []; + for (let i = 0; i < combined.length; i++) { + const current = combined[i]; + + // Check if this is an assistant text message + if (current.type !== "message" || current.role !== "assistant") { + deduplicated.push(current); + continue; + } + + // Look for duplicate assistant messages in the rest of the array + let dominated = false; + for (let j = 0; j < combined.length; j++) { + if (i === j) continue; + const other = combined[j]; + if (other.type !== "message" || other.role !== "assistant") continue; + + const currentContent = current.content || ""; + const otherContent = other.content || ""; + + // Skip empty messages + if (!currentContent.trim() || !otherContent.trim()) continue; + + // Check if current is a prefix of other (current is incomplete version) + if ( + otherContent.length > currentContent.length && + otherContent.startsWith(currentContent.slice(0, 100)) + ) { + // Current is a shorter/incomplete version of other - skip it + dominated = true; + break; + } + + // Check if messages are nearly identical (within a small difference) + // This catches cases where content differs only slightly + const minLen = Math.min(currentContent.length, otherContent.length); + const compareLen = Math.min(minLen, 200); // Compare first 200 chars + if ( + compareLen > 50 && + currentContent.slice(0, compareLen) === + otherContent.slice(0, compareLen) + ) { + // Same prefix - keep the longer one + if (otherContent.length > currentContent.length) { + dominated = true; + break; + } + } + } + + if (!dominated) { + deduplicated.push(current); + } + } + + return deduplicated; }, [initialMessages, messages, completedToolIds]); async function sendMessage( @@ -198,10 
+423,8 @@ export function useChatContainer({ isUserMessage: boolean = true, context?: { url: string; content: string }, ) { - if (!sessionId) { - console.error("[useChatContainer] Cannot send message: no session ID"); - return; - } + if (!sessionId) return; + setIsRegionBlockedModalOpen(false); if (isUserMessage) { const userMessage = createUserMessage(content); @@ -214,31 +437,19 @@ export function useChatContainer({ setHasTextChunks(false); setIsStreamingInitiated(true); hasResponseRef.current = false; - - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - hasResponseRef, - setMessages, - setIsRegionBlockedModalOpen, - sessionId, - setIsStreamingInitiated, - onOperationStarted, - }); + textFinalizedRef.current = false; + streamEndedRef.current = false; try { await sendStreamMessage( sessionId, content, - dispatcher, + createDispatcher(), isUserMessage, context, ); } catch (err) { - console.error("[useChatContainer] Failed to send message:", err); setIsStreamingInitiated(false); - if (err instanceof Error && err.name === "AbortError") return; const errorMessage = diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx index 521f6f6320..bac004f6ed 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx @@ -57,6 +57,7 @@ export function ChatInput({ isStreaming, value, baseHandleKeyDown, + inputId, }); return ( @@ -73,19 +74,20 @@ export function ChatInput({ hasMultipleLines ? "rounded-xlarge" : "rounded-full", )} > + {!value && !isRecording && ( + + )}