diff --git a/.gitignore b/.gitignore
index 1a2291b516..012a0b5227 100644
--- a/.gitignore
+++ b/.gitignore
@@ -180,3 +180,4 @@ autogpt_platform/backend/settings.py
 .claude/settings.local.json
 CLAUDE.local.md
 /autogpt_platform/backend/logs
+.next
\ No newline at end of file
diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py
new file mode 100644
index 0000000000..f447d46bd7
--- /dev/null
+++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py
@@ -0,0 +1,368 @@
+"""Redis Streams consumer for operation completion messages.
+
+This module provides a consumer (ChatCompletionConsumer) that listens for
+completion notifications (OperationCompleteMessage) from external services
+(like Agent Generator) and triggers the appropriate stream registry and
+chat service updates via process_operation_success/process_operation_failure.
+
+Why Redis Streams instead of RabbitMQ?
+--------------------------------------
+While the project typically uses RabbitMQ for async task queues (e.g., execution
+queue), Redis Streams was chosen for chat completion notifications because:
+
+1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
+   Streams (via stream_registry) for message persistence and replay. Using Redis
+   Streams for completion notifications keeps all chat streaming infrastructure
+   in one system, simplifying operations and reducing cross-system coordination.
+
+2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
+   allowing consumers to replay missed messages after reconnection. This aligns
+   with the SSE reconnection pattern where clients can resume from last_message_id.
+
+3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
+   load balancing across pods with explicit message claiming (XAUTOCLAIM) for
+   recovering from dead consumers - ideal for the completion callback pattern.
+
+4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
+   stream_registry) provides lower latency than an additional RabbitMQ hop.
+
+5. **Atomicity with Task State**: Completion processing often needs to update
+   task metadata stored in Redis. Keeping both in Redis enables simpler
+   transactional semantics without distributed coordination.
+
+The consumer uses Redis Streams with consumer groups for reliable message
+processing across multiple platform pods, with XAUTOCLAIM for reclaiming
+stale pending messages from dead consumers.
+"""
+
+import asyncio
+import logging
+import os
+import uuid
+from typing import Any
+
+import orjson
+from prisma import Prisma
+from pydantic import BaseModel
+from redis.exceptions import ResponseError
+
+from backend.data.redis_client import get_redis_async
+
+from . import stream_registry
+from .completion_handler import process_operation_failure, process_operation_success
+from .config import ChatConfig
+
+logger = logging.getLogger(__name__)
+config = ChatConfig()
+
+
+class OperationCompleteMessage(BaseModel):
+    """Message format for operation completion notifications."""
+
+    operation_id: str
+    task_id: str
+    success: bool
+    result: dict | str | None = None
+    error: str | None = None
+
+
+class ChatCompletionConsumer:
+    """Consumer for chat operation completion messages from Redis Streams.
+
+    This consumer initializes its own Prisma client lazily on first use (see
+    _ensure_prisma) so database operations work correctly within this async context.
+ + Uses Redis consumer groups to allow multiple platform pods to consume + messages reliably with automatic redelivery on failure. + """ + + def __init__(self): + self._consumer_task: asyncio.Task | None = None + self._running = False + self._prisma: Prisma | None = None + self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" + + async def start(self) -> None: + """Start the completion consumer.""" + if self._running: + logger.warning("Completion consumer already running") + return + + # Create consumer group if it doesn't exist + try: + redis = await get_redis_async() + await redis.xgroup_create( + config.stream_completion_name, + config.stream_consumer_group, + id="0", + mkstream=True, + ) + logger.info( + f"Created consumer group '{config.stream_consumer_group}' " + f"on stream '{config.stream_completion_name}'" + ) + except ResponseError as e: + if "BUSYGROUP" in str(e): + logger.debug( + f"Consumer group '{config.stream_consumer_group}' already exists" + ) + else: + raise + + self._running = True + self._consumer_task = asyncio.create_task(self._consume_messages()) + logger.info( + f"Chat completion consumer started (consumer: {self._consumer_name})" + ) + + async def _ensure_prisma(self) -> Prisma: + """Lazily initialize Prisma client on first use.""" + if self._prisma is None: + database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432") + self._prisma = Prisma(datasource={"url": database_url}) + await self._prisma.connect() + logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") + return self._prisma + + async def stop(self) -> None: + """Stop the completion consumer.""" + self._running = False + + if self._consumer_task: + self._consumer_task.cancel() + try: + await self._consumer_task + except asyncio.CancelledError: + pass + self._consumer_task = None + + if self._prisma: + await self._prisma.disconnect() + self._prisma = None + logger.info("[COMPLETION] Consumer Prisma client disconnected") + + logger.info("Chat completion consumer stopped") + + async def _consume_messages(self) -> None: + """Main message consumption loop with retry logic.""" + max_retries = 10 + retry_delay = 5 # seconds + retry_count = 0 + block_timeout = 5000 # milliseconds + + while self._running and retry_count < max_retries: + try: + redis = await get_redis_async() + + # Reset retry count on successful connection + retry_count = 0 + + while self._running: + # First, claim any stale pending messages from dead consumers + # Redis does NOT auto-redeliver pending messages; we must explicitly + # claim them using XAUTOCLAIM + try: + claimed_result = await redis.xautoclaim( + name=config.stream_completion_name, + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + min_idle_time=config.stream_claim_min_idle_ms, + start_id="0-0", + count=10, + ) + # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids]) + if claimed_result and len(claimed_result) >= 2: + claimed_entries = claimed_result[1] + if claimed_entries: + logger.info( + f"Claimed {len(claimed_entries)} stale pending messages" + ) + for entry_id, data in claimed_entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + except Exception as e: + logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}") + + # Read new messages from the stream + messages = await redis.xreadgroup( + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + streams={config.stream_completion_name: ">"}, + block=block_timeout, + count=10, + ) + + if not 
messages: + continue + + for stream_name, entries in messages: + for entry_id, data in entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + + except asyncio.CancelledError: + logger.info("Consumer cancelled") + return + except Exception as e: + retry_count += 1 + logger.error( + f"Consumer error (retry {retry_count}/{max_retries}): {e}", + exc_info=True, + ) + if self._running and retry_count < max_retries: + await asyncio.sleep(retry_delay) + else: + logger.error("Max retries reached, stopping consumer") + return + + async def _process_entry( + self, redis: Any, entry_id: str, data: dict[str, Any] + ) -> None: + """Process a single stream entry and acknowledge it on success. + + Args: + redis: Redis client connection + entry_id: The stream entry ID + data: The entry data dict + """ + try: + # Handle the message + message_data = data.get("data") + if message_data: + await self._handle_message( + message_data.encode() + if isinstance(message_data, str) + else message_data + ) + + # Acknowledge the message after successful processing + await redis.xack( + config.stream_completion_name, + config.stream_consumer_group, + entry_id, + ) + except Exception as e: + logger.error( + f"Error processing completion message {entry_id}: {e}", + exc_info=True, + ) + # Message remains in pending state and will be claimed by + # XAUTOCLAIM after min_idle_time expires + + async def _handle_message(self, body: bytes) -> None: + """Handle a completion message using our own Prisma client.""" + try: + data = orjson.loads(body) + message = OperationCompleteMessage(**data) + except Exception as e: + logger.error(f"Failed to parse completion message: {e}") + return + + logger.info( + f"[COMPLETION] Received completion for operation {message.operation_id} " + f"(task_id={message.task_id}, success={message.success})" + ) + + # Find task in registry + task = await stream_registry.find_task_by_operation_id(message.operation_id) + if task is None: + task = await stream_registry.get_task(message.task_id) + + if task is None: + logger.warning( + f"[COMPLETION] Task not found for operation {message.operation_id} " + f"(task_id={message.task_id})" + ) + return + + logger.info( + f"[COMPLETION] Found task: task_id={task.task_id}, " + f"session_id={task.session_id}, tool_call_id={task.tool_call_id}" + ) + + # Guard against empty task fields + if not task.task_id or not task.session_id or not task.tool_call_id: + logger.error( + f"[COMPLETION] Task has empty critical fields! 
" + f"task_id={task.task_id!r}, session_id={task.session_id!r}, " + f"tool_call_id={task.tool_call_id!r}" + ) + return + + if message.success: + await self._handle_success(task, message) + else: + await self._handle_failure(task, message) + + async def _handle_success( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle successful operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_success(task, message.result, prisma) + + async def _handle_failure( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle failed operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_failure(task, message.error, prisma) + + +# Module-level consumer instance +_consumer: ChatCompletionConsumer | None = None + + +async def start_completion_consumer() -> None: + """Start the global completion consumer.""" + global _consumer + if _consumer is None: + _consumer = ChatCompletionConsumer() + await _consumer.start() + + +async def stop_completion_consumer() -> None: + """Stop the global completion consumer.""" + global _consumer + if _consumer: + await _consumer.stop() + _consumer = None + + +async def publish_operation_complete( + operation_id: str, + task_id: str, + success: bool, + result: dict | str | None = None, + error: str | None = None, +) -> None: + """Publish an operation completion message to Redis Streams. + + Args: + operation_id: The operation ID that completed. + task_id: The task ID associated with the operation. + success: Whether the operation succeeded. + result: The result data (for success). + error: The error message (for failure). + """ + message = OperationCompleteMessage( + operation_id=operation_id, + task_id=task_id, + success=success, + result=result, + error=error, + ) + + redis = await get_redis_async() + await redis.xadd( + config.stream_completion_name, + {"data": message.model_dump_json()}, + maxlen=config.stream_max_length, + ) + logger.info(f"Published completion for operation {operation_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_handler.py b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py new file mode 100644 index 0000000000..905fa2ddba --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py @@ -0,0 +1,344 @@ +"""Shared completion handling for operation success and failure. + +This module provides common logic for handling operation completion from both: +- The Redis Streams consumer (completion_consumer.py) +- The HTTP webhook endpoint (routes.py) +""" + +import logging +from typing import Any + +import orjson +from prisma import Prisma + +from . import service as chat_service +from . 
import stream_registry +from .response_model import StreamError, StreamToolOutputAvailable +from .tools.models import ErrorResponse + +logger = logging.getLogger(__name__) + +# Tools that produce agent_json that needs to be saved to library +AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"} + +# Keys that should be stripped from agent_json when returning in error responses +SENSITIVE_KEYS = frozenset( + { + "api_key", + "apikey", + "api_secret", + "password", + "secret", + "credentials", + "credential", + "token", + "access_token", + "refresh_token", + "private_key", + "privatekey", + "auth", + "authorization", + } +) + + +def _sanitize_agent_json(obj: Any) -> Any: + """Recursively sanitize agent_json by removing sensitive keys. + + Args: + obj: The object to sanitize (dict, list, or primitive) + + Returns: + Sanitized copy with sensitive keys removed/redacted + """ + if isinstance(obj, dict): + return { + k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v) + for k, v in obj.items() + } + elif isinstance(obj, list): + return [_sanitize_agent_json(item) for item in obj] + else: + return obj + + +class ToolMessageUpdateError(Exception): + """Raised when updating a tool message in the database fails.""" + + pass + + +async def _update_tool_message( + session_id: str, + tool_call_id: str, + content: str, + prisma_client: Prisma | None, +) -> None: + """Update tool message in database. + + Args: + session_id: The session ID + tool_call_id: The tool call ID to update + content: The new content for the message + prisma_client: Optional Prisma client. If None, uses chat_service. + + Raises: + ToolMessageUpdateError: If the database update fails. The caller should + handle this to avoid marking the task as completed with inconsistent state. + """ + try: + if prisma_client: + # Use provided Prisma client (for consumer with its own connection) + updated_count = await prisma_client.chatmessage.update_many( + where={ + "sessionId": session_id, + "toolCallId": tool_call_id, + }, + data={"content": content}, + ) + # Check if any rows were updated - 0 means message not found + if updated_count == 0: + raise ToolMessageUpdateError( + f"No message found with tool_call_id={tool_call_id} in session {session_id}" + ) + else: + # Use service function (for webhook endpoint) + await chat_service._update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=content, + ) + except ToolMessageUpdateError: + raise + except Exception as e: + logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True) + raise ToolMessageUpdateError( + f"Failed to update tool message for tool_call_id={tool_call_id}: {e}" + ) from e + + +def serialize_result(result: dict | list | str | int | float | bool | None) -> str: + """Serialize result to JSON string with sensible defaults. + + Args: + result: The result to serialize. Can be a dict, list, string, + number, boolean, or None. + + Returns: + JSON string representation of the result. Returns '{"status": "completed"}' + only when result is explicitly None. + """ + if isinstance(result, str): + return result + if result is None: + return '{"status": "completed"}' + return orjson.dumps(result).decode("utf-8") + + +async def _save_agent_from_result( + result: dict[str, Any], + user_id: str | None, + tool_name: str, +) -> dict[str, Any]: + """Save agent to library if result contains agent_json. 
+ + Args: + result: The result dict that may contain agent_json + user_id: The user ID to save the agent for + tool_name: The tool name (create_agent or edit_agent) + + Returns: + Updated result dict with saved agent details, or original result if no agent_json + """ + if not user_id: + logger.warning("[COMPLETION] Cannot save agent: no user_id in task") + return result + + agent_json = result.get("agent_json") + if not agent_json: + logger.warning( + f"[COMPLETION] {tool_name} completed but no agent_json in result" + ) + return result + + try: + from .tools.agent_generator import save_agent_to_library + + is_update = tool_name == "edit_agent" + created_graph, library_agent = await save_agent_to_library( + agent_json, user_id, is_update=is_update + ) + + logger.info( + f"[COMPLETION] Saved agent '{created_graph.name}' to library " + f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})" + ) + + # Return a response similar to AgentSavedResponse + return { + "type": "agent_saved", + "message": f"Agent '{created_graph.name}' has been saved to your library!", + "agent_id": created_graph.id, + "agent_name": created_graph.name, + "library_agent_id": library_agent.id, + "library_agent_link": f"/library/agents/{library_agent.id}", + "agent_page_link": f"/build?flowID={created_graph.id}", + } + except Exception as e: + logger.error( + f"[COMPLETION] Failed to save agent to library: {e}", + exc_info=True, + ) + # Return error but don't fail the whole operation + # Sanitize agent_json to remove sensitive keys before returning + return { + "type": "error", + "message": f"Agent was generated but failed to save: {str(e)}", + "error": str(e), + "agent_json": _sanitize_agent_json(agent_json), + } + + +async def process_operation_success( + task: stream_registry.ActiveTask, + result: dict | str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle successful operation completion. + + Publishes the result to the stream registry, updates the database, + generates LLM continuation, and marks the task as completed. + + Args: + task: The active task that completed + result: The result data from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. + + Raises: + ToolMessageUpdateError: If the database update fails. The task will be + marked as failed instead of completed to avoid inconsistent state. 
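+
+    Example (illustrative): for create_agent, a result of
+        {"agent_json": {...}}
+    is replaced, before streaming and persistence, with an agent_saved payload
+    such as
+        {"type": "agent_saved", "agent_id": "...", "agent_name": "...",
+         "library_agent_id": "...", "library_agent_link": "/library/agents/..."}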
+ """ + # For agent generation tools, save the agent to library + if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): + result = await _save_agent_from_result(result, task.user_id, task.tool_name) + + # Serialize result for output (only substitute default when result is exactly None) + result_output = result if result is not None else {"status": "completed"} + output_str = ( + result_output + if isinstance(result_output, str) + else orjson.dumps(result_output).decode("utf-8") + ) + + # Publish result to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamToolOutputAvailable( + toolCallId=task.tool_call_id, + toolName=task.tool_name, + output=output_str, + success=True, + ), + ) + + # Update pending operation in database + # If this fails, we must not continue to mark the task as completed + result_str = serialize_result(result) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=result_str, + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - mark task as failed to avoid inconsistent state + logger.error( + f"[COMPLETION] DB update failed for task {task.task_id}, " + "marking as failed instead of completed" + ) + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText="Failed to save operation result to database"), + ) + await stream_registry.mark_task_completed(task.task_id, status="failed") + raise + + # Generate LLM continuation with streaming + try: + await chat_service._generate_llm_continuation_with_streaming( + session_id=task.session_id, + user_id=task.user_id, + task_id=task.task_id, + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to generate LLM continuation: {e}", + exc_info=True, + ) + + # Mark task as completed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="completed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info( + f"[COMPLETION] Successfully processed completion for task {task.task_id}" + ) + + +async def process_operation_failure( + task: stream_registry.ActiveTask, + error: str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle failed operation completion. + + Publishes the error to the stream registry, updates the database with + the error response, and marks the task as failed. + + Args: + task: The active task that failed + error: The error message from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. 
+ """ + error_msg = error or "Operation failed" + + # Publish error to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText=error_msg), + ) + + # Update pending operation with error + # If this fails, we still continue to mark the task as failed + error_response = ErrorResponse( + message=error_msg, + error=error, + ) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=error_response.model_dump_json(), + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - log but continue with cleanup + logger.error( + f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, " + "continuing with cleanup" + ) + + # Mark task as failed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="failed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}") diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index dba7934877..2e8dbf5413 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -44,6 +44,48 @@ class ChatConfig(BaseSettings): description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)", ) + # Stream registry configuration for SSE reconnection + stream_ttl: int = Field( + default=3600, + description="TTL in seconds for stream data in Redis (1 hour)", + ) + stream_max_length: int = Field( + default=10000, + description="Maximum number of messages to store per stream", + ) + + # Redis Streams configuration for completion consumer + stream_completion_name: str = Field( + default="chat:completions", + description="Redis Stream name for operation completions", + ) + stream_consumer_group: str = Field( + default="chat_consumers", + description="Consumer group name for completion stream", + ) + stream_claim_min_idle_ms: int = Field( + default=60000, + description="Minimum idle time in milliseconds before claiming pending messages from dead consumers", + ) + + # Redis key prefixes for stream registry + task_meta_prefix: str = Field( + default="chat:task:meta:", + description="Prefix for task metadata hash keys", + ) + task_stream_prefix: str = Field( + default="chat:stream:", + description="Prefix for task message stream keys", + ) + task_op_prefix: str = Field( + default="chat:task:op:", + description="Prefix for operation ID to task ID mapping keys", + ) + internal_api_key: str | None = Field( + default=None, + description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)", + ) + # Langfuse Prompt Management Configuration # Note: Langfuse credentials are in Settings().secrets (settings.py) langfuse_prompt_name: str = Field( @@ -82,6 +124,14 @@ class ChatConfig(BaseSettings): v = "https://openrouter.ai/api/v1" return v + @field_validator("internal_api_key", mode="before") + @classmethod + def get_internal_api_key(cls, v): + """Get internal API key from environment if not provided.""" + if v is None: + v = os.getenv("CHAT_INTERNAL_API_KEY") + return v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git 
a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/api/features/chat/response_model.py index 53a8cf3a1f..f627a42fcc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/response_model.py +++ b/autogpt_platform/backend/backend/api/features/chat/response_model.py @@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse): type: ResponseType = ResponseType.START messageId: str = Field(..., description="Unique message ID") + taskId: str | None = Field( + default=None, + description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", + ) class StreamFinish(StreamBaseResponse): diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index cab51543b1..3e731d86ac 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,19 +1,23 @@ """Chat API routes for chat session management and streaming via SSE.""" import logging +import uuid as uuid_module from collections.abc import AsyncGenerator from typing import Annotated from autogpt_libs import auth -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security from fastapi.responses import StreamingResponse from pydantic import BaseModel from backend.util.exceptions import NotFoundError from . import service as chat_service +from . import stream_registry +from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions +from .response_model import StreamFinish, StreamHeartbeat, StreamStart config = ChatConfig() @@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel): user_id: str | None +class ActiveStreamInfo(BaseModel): + """Information about an active stream for reconnection.""" + + task_id: str + last_message_id: str # Redis Stream message ID for resumption + operation_id: str # Operation ID for completion tracking + tool_name: str # Name of the tool being executed + + class SessionDetailResponse(BaseModel): """Response model providing complete details for a chat session, including messages.""" @@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel): updated_at: str user_id: str | None messages: list[dict] + active_stream: ActiveStreamInfo | None = None # Present if stream is still active class SessionSummaryResponse(BaseModel): @@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel): total: int +class OperationCompleteRequest(BaseModel): + """Request model for external completion webhook.""" + + success: bool + result: dict | str | None = None + error: str | None = None + + # ========== Routes ========== @@ -166,13 +188,14 @@ async def get_session( Retrieve the details of a specific chat session. Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages. + If there's an active stream for this session, returns the task_id for reconnection. Args: session_id: The unique identifier for the desired chat session. user_id: The optional authenticated user ID, or None for anonymous access. Returns: - SessionDetailResponse: Details for the requested session, or None if not found. + SessionDetailResponse: Details for the requested session, including active_stream info if applicable. 
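+
+    Example (illustrative values) of active_stream when a stream is live:
+        {"task_id": "3f6c...", "last_message_id": "0-0",
+         "operation_id": "b41e...", "tool_name": "create_agent"}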
""" session = await get_chat_session(session_id, user_id) @@ -180,11 +203,28 @@ async def get_session( raise NotFoundError(f"Session {session_id} not found.") messages = [message.model_dump() for message in session.messages] - logger.info( - f"Returning session {session_id}: " - f"message_count={len(messages)}, " - f"roles={[m.get('role') for m in messages]}" + + # Check if there's an active stream for this session + active_stream_info = None + active_task, last_message_id = await stream_registry.get_active_task_for_session( + session_id, user_id ) + if active_task: + # Filter out the in-progress assistant message from the session response. + # The client will receive the complete assistant response through the SSE + # stream replay instead, preventing duplicate content. + if messages and messages[-1].get("role") == "assistant": + messages = messages[:-1] + + # Use "0-0" as last_message_id to replay the stream from the beginning. + # Since we filtered out the cached assistant message, the client needs + # the full stream to reconstruct the response. + active_stream_info = ActiveStreamInfo( + task_id=active_task.task_id, + last_message_id="0-0", + operation_id=active_task.operation_id, + tool_name=active_task.tool_name, + ) return SessionDetailResponse( id=session.session_id, @@ -192,6 +232,7 @@ async def get_session( updated_at=session.updated_at.isoformat(), user_id=session.user_id or None, messages=messages, + active_stream=active_stream_info, ) @@ -211,49 +252,112 @@ async def stream_chat_post( - Tool call UI elements (if invoked) - Tool execution results + The AI generation runs in a background task that continues even if the client disconnects. + All chunks are written to Redis for reconnection support. If the client disconnects, + they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off. + Args: session_id: The chat session identifier to associate with the streamed messages. request: Request body containing message, is_user_message, and optional context. user_id: Optional authenticated user ID. Returns: - StreamingResponse: SSE-formatted response chunks. + StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event + containing the task_id for reconnection. 
""" + import asyncio + session = await _validate_and_get_session(session_id, user_id) + # Create a task in the stream registry for reconnection support + task_id = str(uuid_module.uuid4()) + operation_id = str(uuid_module.uuid4()) + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", # Not a tool call, but needed for the model + tool_name="chat", + operation_id=operation_id, + ) + + # Background task that runs the AI generation independently of SSE connection + async def run_ai_generation(): + try: + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + + async for chunk in chat_service.stream_chat_completion( + session_id, + request.message, + is_user_message=request.is_user_message, + user_id=user_id, + session=session, # Pass pre-fetched session to avoid double-fetch + context=request.context, + ): + # Write to Redis (subscribers will receive via XREAD) + await stream_registry.publish_chunk(task_id, chunk) + + # Mark task as completed + await stream_registry.mark_task_completed(task_id, "completed") + except Exception as e: + logger.error( + f"Error in background AI generation for session {session_id}: {e}" + ) + await stream_registry.mark_task_completed(task_id, "failed") + + # Start the AI generation in a background task + bg_task = asyncio.create_task(run_ai_generation()) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" + subscriber_queue = None + try: + # Subscribe to the task stream (this replays existing messages + live updates) + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id="0-0", # Get all messages from the beginning + ) + + if subscriber_queue is None: + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + return + + # Read from the subscriber queue and yield to SSE + while True: + try: + chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + + except GeneratorExit: + pass # Client disconnected - background task continues + except Exception as e: + logger.error(f"Error in SSE stream for task {task_id}: {e}") + finally: + # Unsubscribe when client disconnects or stream ends to prevent resource leak + if subscriber_queue is not None: + try: + await stream_registry.unsubscribe_from_task( + task_id, 
subscriber_queue + ) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" return StreamingResponse( event_generator(), @@ -366,6 +470,251 @@ async def session_assign_user( return {"status": "ok"} +# ========== Task Streaming (SSE Reconnection) ========== + + +@router.get( + "/tasks/{task_id}/stream", +) +async def stream_task( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), + last_message_id: str = Query( + default="0-0", + description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + ), +): + """ + Reconnect to a long-running task's SSE stream. + + When a long-running operation (like agent generation) starts, the client + receives a task_id. If the connection drops, the client can reconnect + using this endpoint to resume receiving updates. + + Args: + task_id: The task ID from the operation_started response. + user_id: Authenticated user ID for ownership validation. + last_message_id: Last Redis Stream message ID received ("0-0" for full replay). + + Returns: + StreamingResponse: SSE-formatted response chunks starting after last_message_id. + + Raises: + HTTPException: 404 if task not found, 410 if task expired, 403 if access denied. + """ + # Check task existence and expiry before subscribing + task, error_code = await stream_registry.get_task_with_expiry_info(task_id) + + if error_code == "TASK_EXPIRED": + raise HTTPException( + status_code=410, + detail={ + "code": "TASK_EXPIRED", + "message": "This operation has expired. Please try again.", + }, + ) + + if error_code == "TASK_NOT_FOUND": + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found.", + }, + ) + + # Validate ownership if task has an owner + if task and task.user_id and user_id != task.user_id: + raise HTTPException( + status_code=403, + detail={ + "code": "ACCESS_DENIED", + "message": "You do not have access to this task.", + }, + ) + + # Get subscriber queue from stream registry + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id=last_message_id, + ) + + if subscriber_queue is None: + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found or access denied.", + }, + ) + + async def event_generator() -> AsyncGenerator[str, None]: + import asyncio + + heartbeat_interval = 15.0 # Send heartbeat every 15 seconds + try: + while True: + try: + # Wait for next chunk with timeout for heartbeats + chunk = await asyncio.wait_for( + subscriber_queue.get(), timeout=heartbeat_interval + ) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + except Exception as e: + logger.error(f"Error in task stream {task_id}: {e}", exc_info=True) + finally: + # Unsubscribe when client disconnects or stream ends + try: + await stream_registry.unsubscribe_from_task(task_id, subscriber_queue) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + 
media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "x-vercel-ai-ui-message-stream": "v1", + }, + ) + + +@router.get( + "/tasks/{task_id}", +) +async def get_task_status( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), +) -> dict: + """ + Get the status of a long-running task. + + Args: + task_id: The task ID to check. + user_id: Authenticated user ID for ownership validation. + + Returns: + dict: Task status including task_id, status, tool_name, and operation_id. + + Raises: + NotFoundError: If task_id is not found or user doesn't have access. + """ + task = await stream_registry.get_task(task_id) + + if task is None: + raise NotFoundError(f"Task {task_id} not found.") + + # Validate ownership - if task has an owner, requester must match + if task.user_id and user_id != task.user_id: + raise NotFoundError(f"Task {task_id} not found.") + + return { + "task_id": task.task_id, + "session_id": task.session_id, + "status": task.status, + "tool_name": task.tool_name, + "operation_id": task.operation_id, + "created_at": task.created_at.isoformat(), + } + + +# ========== External Completion Webhook ========== + + +@router.post( + "/operations/{operation_id}/complete", + status_code=200, +) +async def complete_operation( + operation_id: str, + request: OperationCompleteRequest, + x_api_key: str | None = Header(default=None), +) -> dict: + """ + External completion webhook for long-running operations. + + Called by Agent Generator (or other services) when an operation completes. + This triggers the stream registry to publish completion and continue LLM generation. + + Args: + operation_id: The operation ID to complete. + request: Completion payload with success status and result/error. + x_api_key: Internal API key for authentication. + + Returns: + dict: Status of the completion. + + Raises: + HTTPException: If API key is invalid or operation not found. + """ + # Validate internal API key - reject if not configured or invalid + if not config.internal_api_key: + logger.error( + "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured" + ) + raise HTTPException( + status_code=503, + detail="Webhook not available: internal API key not configured", + ) + if x_api_key != config.internal_api_key: + raise HTTPException(status_code=401, detail="Invalid API key") + + # Find task by operation_id + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None: + raise HTTPException( + status_code=404, + detail=f"Operation {operation_id} not found", + ) + + logger.info( + f"Received completion webhook for operation {operation_id} " + f"(task_id={task.task_id}, success={request.success})" + ) + + if request.success: + await process_operation_success(task, request.result) + else: + await process_operation_failure(task, request.error) + + return {"status": "ok", "task_id": task.task_id} + + +# ========== Configuration ========== + + +@router.get("/config/ttl", status_code=200) +async def get_ttl_config() -> dict: + """ + Get the stream TTL configuration. + + Returns the Time-To-Live settings for chat streams, which determines + how long clients can reconnect to an active stream. + + Returns: + dict: TTL configuration with seconds and milliseconds values. 
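+
+    Example response (assuming the default stream_ttl of 3600 seconds):
+        {"stream_ttl_seconds": 3600, "stream_ttl_ms": 3600000}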
+ """ + return { + "stream_ttl_seconds": config.stream_ttl, + "stream_ttl_ms": config.stream_ttl * 1000, + } + + # ========== Health Check ========== diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index f1f3156713..218575085b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -3,10 +3,13 @@ import logging import time from asyncio import CancelledError from collections.abc import AsyncGenerator -from dataclasses import dataclass -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast import openai + +if TYPE_CHECKING: + from backend.util.prompt import CompressResult + import orjson from langfuse import get_client from openai import ( @@ -17,7 +20,6 @@ from openai import ( RateLimitError, ) from openai.types.chat import ( - ChatCompletionAssistantMessageParam, ChatCompletionChunk, ChatCompletionMessageParam, ChatCompletionStreamOptionsParam, @@ -31,10 +33,10 @@ from backend.data.understanding import ( get_business_understanding, ) from backend.util.exceptions import NotFoundError -from backend.util.prompt import estimate_token_count from backend.util.settings import Settings from . import db as chat_db +from . import stream_registry from .config import ChatConfig from .model import ( ChatMessage, @@ -803,402 +805,58 @@ def _is_region_blocked_error(error: Exception) -> bool: return "not available in your region" in str(error).lower() -# Context window management constants -TOKEN_THRESHOLD = 120_000 -KEEP_RECENT_MESSAGES = 15 - - -@dataclass -class ContextWindowResult: - """Result of context window management.""" - - messages: list[dict[str, Any]] - token_count: int - was_compacted: bool - error: str | None = None - - -def _messages_to_dicts(messages: list) -> list[dict[str, Any]]: - """Convert message objects to dicts, filtering None values. - - Handles both TypedDict (dict-like) and other message formats. - """ - result = [] - for msg in messages: - if msg is None: - continue - if isinstance(msg, dict): - msg_dict = {k: v for k, v in msg.items() if v is not None} - else: - msg_dict = dict(msg) - result.append(msg_dict) - return result - - async def _manage_context_window( messages: list, model: str, api_key: str | None = None, base_url: str | None = None, -) -> ContextWindowResult: +) -> "CompressResult": """ - Manage context window by summarizing old messages if token count exceeds threshold. + Manage context window using the unified compress_context function. - This function handles context compaction for LLM calls by: - 1. Counting tokens in the message list - 2. If over threshold, summarizing old messages while keeping recent ones - 3. Ensuring tool_call/tool_response pairs stay intact - 4. Progressively reducing message count if still over limit + This is a thin wrapper that creates an OpenAI client for summarization + and delegates to the shared compression logic in prompt.py. 
Args: - messages: List of messages in OpenAI format (with system prompt if present) - model: Model name for token counting + messages: List of messages in OpenAI format + model: Model name for token counting and summarization api_key: API key for summarization calls base_url: Base URL for summarization calls Returns: - ContextWindowResult with compacted messages and metadata + CompressResult with compacted messages and metadata """ - if not messages: - return ContextWindowResult([], 0, False, "No messages to compact") - - messages_dict = _messages_to_dicts(messages) - - # Normalize model name for token counting (tiktoken only supports OpenAI models) - token_count_model = model.split("/")[-1] if "/" in model else model - if "claude" in token_count_model.lower() or not any( - known in token_count_model.lower() - for known in ["gpt", "o1", "chatgpt", "text-"] - ): - token_count_model = "gpt-4o" - - try: - token_count = estimate_token_count(messages_dict, model=token_count_model) - except Exception as e: - logger.warning(f"Token counting failed: {e}. Using gpt-4o approximation.") - token_count_model = "gpt-4o" - token_count = estimate_token_count(messages_dict, model=token_count_model) - - if token_count <= TOKEN_THRESHOLD: - return ContextWindowResult(messages, token_count, False) - - has_system_prompt = messages[0].get("role") == "system" - slice_start = max(0, len(messages_dict) - KEEP_RECENT_MESSAGES) - recent_messages = _ensure_tool_pairs_intact( - messages_dict[-KEEP_RECENT_MESSAGES:], messages_dict, slice_start - ) - - # Determine old messages to summarize (explicit bounds to avoid slice edge cases) - system_msg = messages[0] if has_system_prompt else None - if has_system_prompt: - old_messages_dict = ( - messages_dict[1:-KEEP_RECENT_MESSAGES] - if len(messages_dict) > KEEP_RECENT_MESSAGES + 1 - else [] - ) - else: - old_messages_dict = ( - messages_dict[:-KEEP_RECENT_MESSAGES] - if len(messages_dict) > KEEP_RECENT_MESSAGES - else [] - ) - - # Try to summarize old messages, fall back to truncation on failure - summary_msg = None - if old_messages_dict: - try: - summary_text = await _summarize_messages( - old_messages_dict, model=model, api_key=api_key, base_url=base_url - ) - summary_msg = ChatCompletionAssistantMessageParam( - role="assistant", - content=f"[Previous conversation summary — for context only]: {summary_text}", - ) - base = [system_msg, summary_msg] if has_system_prompt else [summary_msg] - messages = base + recent_messages - logger.info( - f"Context summarized: {token_count} tokens, " - f"summarized {len(old_messages_dict)} msgs, kept {KEEP_RECENT_MESSAGES}" - ) - except Exception as e: - logger.warning(f"Summarization failed, falling back to truncation: {e}") - messages = ( - [system_msg] + recent_messages if has_system_prompt else recent_messages - ) - else: - logger.warning( - f"Token count {token_count} exceeds threshold but no old messages to summarize" - ) - - new_token_count = estimate_token_count( - _messages_to_dicts(messages), model=token_count_model - ) - - # Progressive truncation if still over limit - if new_token_count > TOKEN_THRESHOLD: - logger.warning( - f"Still over limit: {new_token_count} tokens. Reducing messages." 
- ) - base_msgs = ( - recent_messages - if old_messages_dict - else (messages_dict[1:] if has_system_prompt else messages_dict) - ) - - def build_messages(recent: list) -> list: - """Build message list with optional system prompt and summary.""" - prefix = [] - if has_system_prompt and system_msg: - prefix.append(system_msg) - if summary_msg: - prefix.append(summary_msg) - return prefix + recent - - for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]: - if keep_count == 0: - messages = build_messages([]) - if not messages: - continue - elif len(base_msgs) < keep_count: - continue - else: - reduced = _ensure_tool_pairs_intact( - base_msgs[-keep_count:], - base_msgs, - max(0, len(base_msgs) - keep_count), - ) - messages = build_messages(reduced) - - new_token_count = estimate_token_count( - _messages_to_dicts(messages), model=token_count_model - ) - if new_token_count <= TOKEN_THRESHOLD: - logger.info( - f"Reduced to {keep_count} messages, {new_token_count} tokens" - ) - break - else: - logger.error( - f"Cannot reduce below threshold. Final: {new_token_count} tokens" - ) - if has_system_prompt and len(messages) > 1: - messages = messages[1:] - logger.critical("Dropped system prompt as last resort") - return ContextWindowResult( - messages, new_token_count, True, "System prompt dropped" - ) - # No system prompt to drop - return error so callers don't proceed with oversized context - return ContextWindowResult( - messages, - new_token_count, - True, - "Unable to reduce context below token limit", - ) - - return ContextWindowResult(messages, new_token_count, True) - - -async def _summarize_messages( - messages: list, - model: str, - api_key: str | None = None, - base_url: str | None = None, - timeout: float = 30.0, -) -> str: - """Summarize a list of messages into concise context. - - Uses the same model as the chat for higher quality summaries. - - Args: - messages: List of message dicts to summarize - model: Model to use for summarization (same as chat model) - api_key: API key for OpenAI client - base_url: Base URL for OpenAI client - timeout: Request timeout in seconds (default: 30.0) - - Returns: - Summarized text - """ - # Format messages for summarization - conversation = [] - for msg in messages: - role = msg.get("role", "") - content = msg.get("content", "") - # Include user, assistant, and tool messages (tool outputs are important context) - if content and role in ("user", "assistant", "tool"): - conversation.append(f"{role.upper()}: {content}") - - conversation_text = "\n\n".join(conversation) - - # Handle empty conversation - if not conversation_text: - return "No conversation history available." - - # Truncate conversation to fit within summarization model's context - # gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety - MAX_CHARS = 100_000 - if len(conversation_text) > MAX_CHARS: - conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]" - - # Call LLM to summarize import openai - summarization_client = openai.AsyncOpenAI( - api_key=api_key, base_url=base_url, timeout=timeout - ) + from backend.util.prompt import compress_context - response = await summarization_client.chat.completions.create( - model=model, - messages=[ - { - "role": "system", - "content": ( - "Create a detailed summary of the conversation so far. 
" - "This summary will be used as context when continuing the conversation.\n\n" - "Before writing the summary, analyze each message chronologically to identify:\n" - "- User requests and their explicit goals\n" - "- Your approach and key decisions made\n" - "- Technical specifics (file names, tool outputs, function signatures)\n" - "- Errors encountered and resolutions applied\n\n" - "You MUST include ALL of the following sections:\n\n" - "## 1. Primary Request and Intent\n" - "The user's explicit goals and what they are trying to accomplish.\n\n" - "## 2. Key Technical Concepts\n" - "Technologies, frameworks, tools, and patterns being used or discussed.\n\n" - "## 3. Files and Resources Involved\n" - "Specific files examined or modified, with relevant snippets and identifiers.\n\n" - "## 4. Errors and Fixes\n" - "Problems encountered, error messages, and their resolutions. " - "Include any user feedback on fixes.\n\n" - "## 5. Problem Solving\n" - "Issues that have been resolved and how they were addressed.\n\n" - "## 6. All User Messages\n" - "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n" - "## 7. Pending Tasks\n" - "Work items the user explicitly requested that have not yet been completed.\n\n" - "## 8. Current Work\n" - "Precise description of what was being worked on most recently, including relevant context.\n\n" - "## 9. Next Steps\n" - "What should happen next, aligned with the user's most recent requests. " - "Include verbatim quotes of recent instructions if relevant." - ), - }, - {"role": "user", "content": f"Summarize:\n\n{conversation_text}"}, - ], - max_tokens=1500, - temperature=0.3, - ) + # Convert messages to dict format + messages_dict = [] + for msg in messages: + if isinstance(msg, dict): + msg_dict = {k: v for k, v in msg.items() if v is not None} + else: + msg_dict = dict(msg) + messages_dict.append(msg_dict) - summary = response.choices[0].message.content - return summary or "No summary available." - - -def _ensure_tool_pairs_intact( - recent_messages: list[dict], - all_messages: list[dict], - start_index: int, -) -> list[dict]: - """ - Ensure tool_call/tool_response pairs stay together after slicing. - - When slicing messages for context compaction, a naive slice can separate - an assistant message containing tool_calls from its corresponding tool - response messages. This causes API validation errors (e.g., Anthropic's - "unexpected tool_use_id found in tool_result blocks"). - - This function checks for orphan tool responses in the slice and extends - backwards to include their corresponding assistant messages. 
- - Args: - recent_messages: The sliced messages to validate - all_messages: The complete message list (for looking up missing assistants) - start_index: The index in all_messages where recent_messages begins - - Returns: - A potentially extended list of messages with tool pairs intact - """ - if not recent_messages: - return recent_messages - - # Collect all tool_call_ids from assistant messages in the slice - available_tool_call_ids: set[str] = set() - for msg in recent_messages: - if msg.get("role") == "assistant" and msg.get("tool_calls"): - for tc in msg["tool_calls"]: - tc_id = tc.get("id") - if tc_id: - available_tool_call_ids.add(tc_id) - - # Find orphan tool responses (tool messages whose tool_call_id is missing) - orphan_tool_call_ids: set[str] = set() - for msg in recent_messages: - if msg.get("role") == "tool": - tc_id = msg.get("tool_call_id") - if tc_id and tc_id not in available_tool_call_ids: - orphan_tool_call_ids.add(tc_id) - - if not orphan_tool_call_ids: - # No orphans, slice is valid - return recent_messages - - # Find the assistant messages that contain the orphan tool_call_ids - # Search backwards from start_index in all_messages - messages_to_prepend: list[dict] = [] - for i in range(start_index - 1, -1, -1): - msg = all_messages[i] - if msg.get("role") == "assistant" and msg.get("tool_calls"): - msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")} - if msg_tool_ids & orphan_tool_call_ids: - # This assistant message has tool_calls we need - # Also collect its contiguous tool responses that follow it - assistant_and_responses: list[dict] = [msg] - - # Scan forward from this assistant to collect tool responses - for j in range(i + 1, start_index): - following_msg = all_messages[j] - if following_msg.get("role") == "tool": - tool_id = following_msg.get("tool_call_id") - if tool_id and tool_id in msg_tool_ids: - assistant_and_responses.append(following_msg) - else: - # Stop at first non-tool message - break - - # Prepend the assistant and its tool responses (maintain order) - messages_to_prepend = assistant_and_responses + messages_to_prepend - # Mark these as found - orphan_tool_call_ids -= msg_tool_ids - # Also add this assistant's tool_call_ids to available set - available_tool_call_ids |= msg_tool_ids - - if not orphan_tool_call_ids: - # Found all missing assistants - break - - if orphan_tool_call_ids: - # Some tool_call_ids couldn't be resolved - remove those tool responses - # This shouldn't happen in normal operation but handles edge cases - logger.warning( - f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. " - "Removing orphan tool responses." 
- ) - recent_messages = [ - msg - for msg in recent_messages - if not ( - msg.get("role") == "tool" - and msg.get("tool_call_id") in orphan_tool_call_ids + # Only create client if api_key is provided (enables summarization) + # Use context manager to avoid socket leaks + if api_key: + async with openai.AsyncOpenAI( + api_key=api_key, base_url=base_url, timeout=30.0 + ) as client: + return await compress_context( + messages=messages_dict, + model=model, + client=client, ) - ] - - if messages_to_prepend: - logger.info( - f"Extended recent messages by {len(messages_to_prepend)} to preserve " - f"tool_call/tool_response pairs" + else: + # No API key - use truncation-only mode + return await compress_context( + messages=messages_dict, + model=model, + client=None, ) - return messages_to_prepend + recent_messages - - return recent_messages async def _stream_chat_chunks( @@ -1527,8 +1185,9 @@ async def _yield_tool_call( ) return - # Generate operation ID + # Generate operation ID and task ID operation_id = str(uuid_module.uuid4()) + task_id = str(uuid_module.uuid4()) # Build a user-friendly message based on tool and arguments if tool_name == "create_agent": @@ -1571,6 +1230,16 @@ async def _yield_tool_call( # Wrap session save and task creation in try-except to release lock on failure try: + # Create task in stream registry for SSE reconnection support + await stream_registry.create_task( + task_id=task_id, + session_id=session.session_id, + user_id=session.user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + # Save assistant message with tool_call FIRST (required by LLM) assistant_message = ChatMessage( role="assistant", @@ -1592,23 +1261,27 @@ async def _yield_tool_call( session.messages.append(pending_message) await upsert_chat_session(session) logger.info( - f"Saved pending operation {operation_id} for tool {tool_name} " - f"in session {session.session_id}" + f"Saved pending operation {operation_id} (task_id={task_id}) " + f"for tool {tool_name} in session {session.session_id}" ) # Store task reference in module-level set to prevent GC before completion - task = asyncio.create_task( - _execute_long_running_tool( + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( tool_name=tool_name, parameters=arguments, tool_call_id=tool_call_id, operation_id=operation_id, + task_id=task_id, session_id=session.session_id, user_id=session.user_id, ) ) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) + _background_tasks.add(bg_task) + bg_task.add_done_callback(_background_tasks.discard) + + # Associate the asyncio task with the stream registry task + await stream_registry.set_task_asyncio_task(task_id, bg_task) except Exception as e: # Roll back appended messages to prevent data corruption on subsequent saves if ( @@ -1626,6 +1299,11 @@ async def _yield_tool_call( # Release the Redis lock since the background task won't be spawned await _mark_operation_completed(tool_call_id) + # Mark stream registry task as failed if it was created + try: + await stream_registry.mark_task_completed(task_id, status="failed") + except Exception: + pass logger.error( f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True ) @@ -1639,6 +1317,7 @@ async def _yield_tool_call( message=started_msg, operation_id=operation_id, tool_name=tool_name, + task_id=task_id, # Include task_id for SSE reconnection ).model_dump_json(), success=True, ) @@ -1708,6 +1387,9 @@ async def _execute_long_running_tool( This function runs 
independently of the SSE connection, so the operation survives if the user closes their browser tab. + + NOTE: This is the legacy function without stream registry support. + Use _execute_long_running_tool_with_streaming for new implementations. """ try: # Load fresh session (not stale reference) @@ -1760,6 +1442,133 @@ async def _execute_long_running_tool( await _mark_operation_completed(tool_call_id) +async def _execute_long_running_tool_with_streaming( + tool_name: str, + parameters: dict[str, Any], + tool_call_id: str, + operation_id: str, + task_id: str, + session_id: str, + user_id: str | None, +) -> None: + """Execute a long-running tool with stream registry support for SSE reconnection. + + This function runs independently of the SSE connection, publishes progress + to the stream registry, and survives if the user closes their browser tab. + Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming. + + If the external service returns a 202 Accepted (async), this function exits + early and lets the Redis Streams completion consumer handle the rest. + """ + # Track whether we delegated to async processing - if so, the Redis Streams + # completion consumer (stream_registry / completion_consumer) will handle cleanup, not us + delegated_to_async = False + + try: + # Load fresh session (not stale reference) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for background tool") + await stream_registry.mark_task_completed(task_id, status="failed") + return + + # Pass operation_id and task_id to the tool for async processing + enriched_parameters = { + **parameters, + "_operation_id": operation_id, + "_task_id": task_id, + } + + # Execute the actual tool + result = await execute_tool( + tool_name=tool_name, + parameters=enriched_parameters, + tool_call_id=tool_call_id, + user_id=user_id, + session=session, + ) + + # Check if the tool result indicates async processing + # (e.g., Agent Generator returned 202 Accepted) + try: + if isinstance(result.output, dict): + result_data = result.output + elif result.output: + result_data = orjson.loads(result.output) + else: + result_data = {} + if result_data.get("status") == "accepted": + logger.info( + f"Tool {tool_name} delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id}). " + f"Redis Streams completion consumer will handle the rest." 
+ ) + # Don't publish result, don't continue with LLM, and don't cleanup + # The Redis Streams consumer (completion_consumer) will handle + # everything when the external service completes via webhook + delegated_to_async = True + return + except (orjson.JSONDecodeError, TypeError): + pass # Not JSON or not async - continue normally + + # Publish tool result to stream registry + await stream_registry.publish_chunk(task_id, result) + + # Update the pending message with result + result_str = ( + result.output + if isinstance(result.output, str) + else orjson.dumps(result.output).decode("utf-8") + ) + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=result_str, + ) + + logger.info( + f"Background tool {tool_name} completed for session {session_id} " + f"(task_id={task_id})" + ) + + # Generate LLM continuation and stream chunks to registry + await _generate_llm_continuation_with_streaming( + session_id=session_id, + user_id=user_id, + task_id=task_id, + ) + + # Mark task as completed in stream registry + await stream_registry.mark_task_completed(task_id, status="completed") + + except Exception as e: + logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True) + error_response = ErrorResponse( + message=f"Tool {tool_name} failed: {str(e)}", + ) + + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=str(e)), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) + + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=error_response.model_dump_json(), + ) + + # Mark task as failed in stream registry + await stream_registry.mark_task_completed(task_id, status="failed") + finally: + # Only cleanup if we didn't delegate to async processing + # For async path, the Redis Streams completion consumer handles cleanup + if not delegated_to_async: + await _mark_operation_completed(tool_call_id) + + async def _update_pending_operation( session_id: str, tool_call_id: str, @@ -1940,3 +1749,128 @@ async def _generate_llm_continuation( except Exception as e: logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) + + +async def _generate_llm_continuation_with_streaming( + session_id: str, + user_id: str | None, + task_id: str, +) -> None: + """Generate an LLM response with streaming to the stream registry. + + This is called by background tasks to continue the conversation + after a tool result is saved. Chunks are published to the stream registry + so reconnecting clients can receive them. 
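For reference, the chunk sequence published for one continuation looks like the sketch below (IDs and delta text are illustrative; the models are the same ones used in the body of this function, and StreamFinish is appended afterwards when the caller invokes stream_registry.mark_task_completed):

    # Illustrative sketch, not part of the patch: the order of events a
    # reconnecting client replays from the task stream for one continuation.
    message_id, text_block_id = "msg-1", "txt-1"  # illustrative IDs
    expected_sequence = [
        StreamStart(messageId=message_id),            # one per assistant message
        StreamTextStart(id=text_block_id),            # opens the text block
        StreamTextDelta(id=text_block_id, delta="Hel"),
        StreamTextDelta(id=text_block_id, delta="lo"),
        StreamTextEnd(id=text_block_id),              # closes the text block
        # StreamFinish() follows once the background task calls
        # stream_registry.mark_task_completed(task_id)
    ]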
+ """ + import uuid as uuid_module + + try: + # Load fresh session from DB (bypass cache to get the updated tool result) + await invalidate_session_cache(session_id) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for LLM continuation") + return + + # Build system prompt + system_prompt, _ = await _build_system_prompt(user_id) + + # Build messages in OpenAI format + messages = session.to_openai_messages() + if system_prompt: + from openai.types.chat import ChatCompletionSystemMessageParam + + system_message = ChatCompletionSystemMessageParam( + role="system", + content=system_prompt, + ) + messages = [system_message] + messages + + # Build extra_body for tracing + extra_body: dict[str, Any] = { + "posthogProperties": { + "environment": settings.config.app_env.value, + }, + } + if user_id: + extra_body["user"] = user_id[:128] + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] + + # Make streaming LLM call (no tools - just text response) + from typing import cast + + from openai.types.chat import ChatCompletionMessageParam + + # Generate unique IDs for AI SDK protocol + message_id = str(uuid_module.uuid4()) + text_block_id = str(uuid_module.uuid4()) + + # Publish start event + await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) + await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) + + # Stream the response + stream = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + stream=True, + ) + + assistant_content = "" + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + delta = chunk.choices[0].delta.content + assistant_content += delta + # Publish delta to stream registry + await stream_registry.publish_chunk( + task_id, + StreamTextDelta(id=text_block_id, delta=delta), + ) + + # Publish end events + await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) + + if assistant_content: + # Reload session from DB to avoid race condition with user messages + fresh_session = await get_chat_session(session_id, user_id) + if not fresh_session: + logger.error( + f"Session {session_id} disappeared during LLM continuation" + ) + return + + # Save assistant message to database + assistant_message = ChatMessage( + role="assistant", + content=assistant_content, + ) + fresh_session.messages.append(assistant_message) + + # Save to database (not cache) to persist the response + await upsert_chat_session(fresh_session) + + # Invalidate cache so next poll/refresh gets fresh data + await invalidate_session_cache(session_id) + + logger.info( + f"Generated streaming LLM continuation for session {session_id} " + f"(task_id={task_id}), response length: {len(assistant_content)}" + ) + else: + logger.warning( + f"Streaming LLM continuation returned empty response for {session_id}" + ) + + except Exception as e: + logger.error( + f"Failed to generate streaming LLM continuation: {e}", exc_info=True + ) + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=f"Failed to generate response: {e}"), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py new 
file mode 100644 index 0000000000..88a5023e2b --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -0,0 +1,704 @@ +"""Stream registry for managing reconnectable SSE streams. + +This module provides a registry for tracking active streaming tasks and their +messages. It uses Redis for all state management (no in-memory state), making +pods stateless and horizontally scalable. + +Architecture: +- Redis Stream: Persists all messages for replay and real-time delivery +- Redis Hash: Task metadata (status, session_id, etc.) + +Subscribers: +1. Replay missed messages from Redis Stream (XREAD) +2. Listen for live updates via blocking XREAD +3. No in-memory state required on the subscribing pod +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +import orjson + +from backend.data.redis_client import get_redis_async + +from .config import ChatConfig +from .response_model import StreamBaseResponse, StreamError, StreamFinish + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Track background tasks for this pod (just the asyncio.Task reference, not subscribers) +_local_tasks: dict[str, asyncio.Task] = {} + +# Track listener tasks per subscriber queue for cleanup +# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe +_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {} + +# Timeout for putting chunks into subscriber queues (seconds) +# If the queue is full and doesn't drain within this time, send an overflow error +QUEUE_PUT_TIMEOUT = 5.0 + +# Lua script for atomic compare-and-swap status update (idempotent completion) +# Returns 1 if status was updated, 0 if already completed/failed +COMPLETE_TASK_SCRIPT = """ +local current = redis.call("HGET", KEYS[1], "status") +if current == "running" then + redis.call("HSET", KEYS[1], "status", ARGV[1]) + return 1 +end +return 0 +""" + + +@dataclass +class ActiveTask: + """Represents an active streaming task (metadata only, no in-memory queues).""" + + task_id: str + session_id: str + user_id: str | None + tool_call_id: str + tool_name: str + operation_id: str + status: Literal["running", "completed", "failed"] = "running" + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + asyncio_task: asyncio.Task | None = None + + +def _get_task_meta_key(task_id: str) -> str: + """Get Redis key for task metadata.""" + return f"{config.task_meta_prefix}{task_id}" + + +def _get_task_stream_key(task_id: str) -> str: + """Get Redis key for task message stream.""" + return f"{config.task_stream_prefix}{task_id}" + + +def _get_operation_mapping_key(operation_id: str) -> str: + """Get Redis key for operation_id to task_id mapping.""" + return f"{config.task_op_prefix}{operation_id}" + + +async def create_task( + task_id: str, + session_id: str, + user_id: str | None, + tool_call_id: str, + tool_name: str, + operation_id: str, +) -> ActiveTask: + """Create a new streaming task in Redis. 
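Together with publish_chunk and mark_task_completed defined below, the producer side of this registry is used roughly as in the following sketch (IDs are illustrative; StreamTextDelta is assumed importable from .response_model like the other response models):

    task = await create_task(
        task_id="t1", session_id="s1", user_id="u1",
        tool_call_id="call_1", tool_name="create_agent", operation_id="op1",
    )
    # Any StreamBaseResponse subclass can be published as a chunk
    await publish_chunk("t1", StreamTextDelta(id="blk-1", delta="working..."))
    await mark_task_completed("t1", status="completed")  # appends StreamFinish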
+ + Args: + task_id: Unique identifier for the task + session_id: Chat session ID + user_id: User ID (may be None for anonymous) + tool_call_id: Tool call ID from the LLM + tool_name: Name of the tool being executed + operation_id: Operation ID for webhook callbacks + + Returns: + The created ActiveTask instance (metadata only) + """ + task = ActiveTask( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # Store metadata in Redis + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + op_key = _get_operation_mapping_key(operation_id) + + await redis.hset( # type: ignore[misc] + meta_key, + mapping={ + "task_id": task_id, + "session_id": session_id, + "user_id": user_id or "", + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "operation_id": operation_id, + "status": task.status, + "created_at": task.created_at.isoformat(), + }, + ) + await redis.expire(meta_key, config.stream_ttl) + + # Create operation_id -> task_id mapping for webhook lookups + await redis.set(op_key, task_id, ex=config.stream_ttl) + + logger.debug(f"Created task {task_id} for session {session_id}") + + return task + + +async def publish_chunk( + task_id: str, + chunk: StreamBaseResponse, +) -> str: + """Publish a chunk to Redis Stream. + + All delivery is via Redis Streams - no in-memory state. + + Args: + task_id: Task ID to publish to + chunk: The stream response chunk to publish + + Returns: + The Redis Stream message ID + """ + chunk_json = chunk.model_dump_json() + message_id = "0-0" + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + + # Write to Redis Stream for persistence and real-time delivery + raw_id = await redis.xadd( + stream_key, + {"data": chunk_json}, + maxlen=config.stream_max_length, + ) + message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() + + # Set TTL on stream to match task metadata TTL + await redis.expire(stream_key, config.stream_ttl) + except Exception as e: + logger.error( + f"Failed to publish chunk for task {task_id}: {e}", + exc_info=True, + ) + + return message_id + + +async def subscribe_to_task( + task_id: str, + user_id: str | None, + last_message_id: str = "0-0", +) -> asyncio.Queue[StreamBaseResponse] | None: + """Subscribe to a task's stream with replay of missed messages. + + This is fully stateless - uses Redis Stream for replay and pub/sub for live updates. 
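Consumer-side sketch (the generator shape is an assumption about how the SSE endpoint drains the queue; the real route is GET /chat/tasks/{task_id}/stream, and live updates come from the blocking-XREAD listener below): subscribe, read until StreamFinish, and always unsubscribe so the listener task is cancelled:

    async def stream_task_events(task_id: str, user_id: str | None, last_message_id: str = "0-0"):
        queue = await subscribe_to_task(task_id, user_id, last_message_id)
        if queue is None:
            return  # unknown task or ownership check failed
        try:
            while True:
                chunk = await queue.get()
                yield f"data: {chunk.model_dump_json()}\n\n"  # SSE framing
                if isinstance(chunk, StreamFinish):
                    break
        finally:
            await unsubscribe_from_task(task_id, queue)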
+ + Args: + task_id: Task ID to subscribe to + user_id: User ID for ownership validation + last_message_id: Last Redis Stream message ID received ("0-0" for full replay) + + Returns: + An asyncio Queue that will receive stream chunks, or None if task not found + or user doesn't have access + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + logger.debug(f"Task {task_id} not found in Redis") + return None + + # Note: Redis client uses decode_responses=True, so keys are strings + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + + # Validate ownership - if task has an owner, requester must match + if task_user_id: + if user_id != task_user_id: + logger.warning( + f"User {user_id} denied access to task {task_id} " + f"owned by {task_user_id}" + ) + return None + + subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() + stream_key = _get_task_stream_key(task_id) + + # Step 1: Replay messages from Redis Stream + messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) + + replayed_count = 0 + replay_last_id = last_message_id + if messages: + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + # Note: Redis client uses decode_responses=True, so keys are strings + if "data" in msg_data: + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + await subscriber_queue.put(chunk) + replayed_count += 1 + except Exception as e: + logger.warning(f"Failed to replay message: {e}") + + logger.debug(f"Task {task_id}: replayed {replayed_count} messages") + + # Step 2: If task is still running, start stream listener for live updates + if task_status == "running": + listener_task = asyncio.create_task( + _stream_listener(task_id, subscriber_queue, replay_last_id) + ) + # Track listener task for cleanup on unsubscribe + _listener_tasks[id(subscriber_queue)] = (task_id, listener_task) + else: + # Task is completed/failed - add finish marker + await subscriber_queue.put(StreamFinish()) + + return subscriber_queue + + +async def _stream_listener( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], + last_replayed_id: str, +) -> None: + """Listen to Redis Stream for new messages using blocking XREAD. + + This approach avoids the duplicate message issue that can occur with pub/sub + when messages are published during the gap between replay and subscription. 
+ + Args: + task_id: Task ID to listen for + subscriber_queue: Queue to deliver messages to + last_replayed_id: Last message ID from replay (continue from here) + """ + queue_id = id(subscriber_queue) + # Track the last successfully delivered message ID for recovery hints + last_delivered_id = last_replayed_id + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + current_id = last_replayed_id + + while True: + # Block for up to 30 seconds waiting for new messages + # This allows periodic checking if task is still running + messages = await redis.xread( + {stream_key: current_id}, block=30000, count=100 + ) + + if not messages: + # Timeout - check if task is still running + meta_key = _get_task_meta_key(task_id) + status = await redis.hget(meta_key, "status") # type: ignore[misc] + if status and status != "running": + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except asyncio.TimeoutError: + logger.warning( + f"Timeout delivering finish event for task {task_id}" + ) + break + continue + + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + current_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + + if "data" not in msg_data: + continue + + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + try: + await asyncio.wait_for( + subscriber_queue.put(chunk), + timeout=QUEUE_PUT_TIMEOUT, + ) + # Update last delivered ID on successful delivery + last_delivered_id = current_id + except asyncio.TimeoutError: + logger.warning( + f"Subscriber queue full for task {task_id}, " + f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s" + ) + # Send overflow error with recovery info + try: + overflow_error = StreamError( + errorText="Message delivery timeout - some messages may have been missed", + code="QUEUE_OVERFLOW", + details={ + "last_delivered_id": last_delivered_id, + "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}", + }, + ) + subscriber_queue.put_nowait(overflow_error) + except asyncio.QueueFull: + # Queue is completely stuck, nothing more we can do + logger.error( + f"Cannot deliver overflow error for task {task_id}, " + "queue completely blocked" + ) + + # Stop listening on finish + if isinstance(chunk, StreamFinish): + return + except Exception as e: + logger.warning(f"Error processing stream message: {e}") + + except asyncio.CancelledError: + logger.debug(f"Stream listener cancelled for task {task_id}") + raise # Re-raise to propagate cancellation + except Exception as e: + logger.error(f"Stream listener error for task {task_id}: {e}") + # On error, send finish to unblock subscriber + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except (asyncio.TimeoutError, asyncio.QueueFull): + logger.warning( + f"Could not deliver finish event for task {task_id} after error" + ) + finally: + # Clean up listener task mapping on exit + _listener_tasks.pop(queue_id, None) + + +async def mark_task_completed( + task_id: str, + status: Literal["completed", "failed"] = "completed", +) -> bool: + """Mark a task as completed and publish finish event. + + This is idempotent - calling multiple times with the same task_id is safe. + Uses atomic compare-and-swap via Lua script to prevent race conditions. + Status is updated first (source of truth), then finish event is published (best-effort). 
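Because the status flip is a Lua compare-and-swap, concurrent completion attempts are safe, as in this sketch (run inside an async context; the task ID is illustrative):

    first, second = await asyncio.gather(
        mark_task_completed("t1", status="completed"),
        mark_task_completed("t1", status="failed"),
    )
    # Exactly one call wins the compare-and-swap and returns True; the other
    # sees a non-"running" status and returns False without re-publishing finish.
    assert {first, second} == {True, False}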
+ + Args: + task_id: Task ID to mark as completed + status: Final status ("completed" or "failed") + + Returns: + True if task was newly marked completed, False if already completed/failed + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + + # Atomic compare-and-swap: only update if status is "running" + # This prevents race conditions when multiple callers try to complete simultaneously + result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc] + + if result == 0: + logger.debug(f"Task {task_id} already completed/failed, skipping") + return False + + # THEN publish finish event (best-effort - listeners can detect via status polling) + try: + await publish_chunk(task_id, StreamFinish()) + except Exception as e: + logger.error( + f"Failed to publish finish event for task {task_id}: {e}. " + "Listeners will detect completion via status polling." + ) + + # Clean up local task reference if exists + _local_tasks.pop(task_id, None) + return True + + +async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: + """Find a task by its operation ID. + + Used by webhook callbacks to locate the task to update. + + Args: + operation_id: Operation ID to search for + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + op_key = _get_operation_mapping_key(operation_id) + task_id = await redis.get(op_key) + + if not task_id: + return None + + task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id + return await get_task(task_id_str) + + +async def get_task(task_id: str) -> ActiveTask | None: + """Get a task by its ID from Redis. + + Args: + task_id: Task ID to look up + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + return None + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ) + + +async def get_task_with_expiry_info( + task_id: str, +) -> tuple[ActiveTask | None, str | None]: + """Get a task by its ID with expiration detection. 
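A route handler might map the two error codes onto HTTP responses as in this sketch (the status-code choices and the FastAPI HTTPException are assumptions, not part of this patch):

    from fastapi import HTTPException

    task, error_code = await get_task_with_expiry_info(task_id)
    if task is None:
        if error_code == "TASK_EXPIRED":
            raise HTTPException(status_code=410, detail="Task stream expired; please retry the request")
        raise HTTPException(status_code=404, detail="Task not found")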
+ + Returns (task, error_code) where error_code is: + - None if task found + - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired) + - "TASK_NOT_FOUND" if neither exists + + Args: + task_id: Task ID to look up + + Returns: + Tuple of (ActiveTask or None, error_code or None) + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + stream_key = _get_task_stream_key(task_id) + + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + # Check if stream still has data (metadata expired but stream hasn't) + stream_len = await redis.xlen(stream_key) + if stream_len > 0: + return None, "TASK_EXPIRED" + return None, "TASK_NOT_FOUND" + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ( + ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ), + None, + ) + + +async def get_active_task_for_session( + session_id: str, + user_id: str | None = None, +) -> tuple[ActiveTask | None, str]: + """Get the active (running) task for a session, if any. + + Scans Redis for tasks matching the session_id with status="running". + + Args: + session_id: Session ID to look up + user_id: User ID for ownership validation (optional) + + Returns: + Tuple of (ActiveTask if found and running, last_message_id from Redis Stream) + """ + + redis = await get_redis_async() + + # Scan Redis for task metadata keys + cursor = 0 + tasks_checked = 0 + + while True: + cursor, keys = await redis.scan( + cursor, match=f"{config.task_meta_prefix}*", count=100 + ) + + for key in keys: + tasks_checked += 1 + meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc] + if not meta: + continue + + # Note: Redis client uses decode_responses=True, so keys/values are strings + task_session_id = meta.get("session_id", "") + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + task_id = meta.get("task_id", "") + + if task_session_id == session_id and task_status == "running": + # Validate ownership - if task has an owner, requester must match + if task_user_id and user_id != task_user_id: + continue + + # Get the last message ID from Redis Stream + stream_key = _get_task_stream_key(task_id) + last_id = "0-0" + try: + messages = await redis.xrevrange(stream_key, count=1) + if messages: + msg_id = messages[0][0] + last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + except Exception as e: + logger.warning(f"Failed to get last message ID: {e}") + + return ( + ActiveTask( + task_id=task_id, + session_id=task_session_id, + user_id=task_user_id, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status="running", + ), + last_id, + ) + + if cursor == 0: + break + + return None, "0-0" + + +def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None: + """Reconstruct a StreamBaseResponse from JSON data. 
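Round-trip sketch: publish_chunk stores model_dump_json() under the stream entry's "data" field, and this function rebuilds the typed model from the parsed dict (field values are illustrative):

    original = StreamTextDelta(id="blk-1", delta="partial text")
    stored = original.model_dump_json()            # what publish_chunk writes as {"data": ...}
    restored = _reconstruct_chunk(orjson.loads(stored))
    assert isinstance(restored, StreamTextDelta)
    assert restored.delta == "partial text"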
+ + Args: + chunk_data: Parsed JSON data from Redis + + Returns: + Reconstructed response object, or None if unknown type + """ + from .response_model import ( + ResponseType, + StreamError, + StreamFinish, + StreamHeartbeat, + StreamStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, + StreamUsage, + ) + + # Map response types to their corresponding classes + type_to_class: dict[str, type[StreamBaseResponse]] = { + ResponseType.START.value: StreamStart, + ResponseType.FINISH.value: StreamFinish, + ResponseType.TEXT_START.value: StreamTextStart, + ResponseType.TEXT_DELTA.value: StreamTextDelta, + ResponseType.TEXT_END.value: StreamTextEnd, + ResponseType.TOOL_INPUT_START.value: StreamToolInputStart, + ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable, + ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable, + ResponseType.ERROR.value: StreamError, + ResponseType.USAGE.value: StreamUsage, + ResponseType.HEARTBEAT.value: StreamHeartbeat, + } + + chunk_type = chunk_data.get("type") + chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type] + + if chunk_class is None: + logger.warning(f"Unknown chunk type: {chunk_type}") + return None + + try: + return chunk_class(**chunk_data) + except Exception as e: + logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}") + return None + + +async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None: + """Track the asyncio.Task for a task (local reference only). + + This is just for cleanup purposes - the task state is in Redis. + + Args: + task_id: Task ID + asyncio_task: The asyncio Task to track + """ + _local_tasks[task_id] = asyncio_task + + +async def unsubscribe_from_task( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], +) -> None: + """Clean up when a subscriber disconnects. + + Cancels the XREAD-based listener task associated with this subscriber queue + to prevent resource leaks. 
+ + Args: + task_id: Task ID + subscriber_queue: The subscriber's queue used to look up the listener task + """ + queue_id = id(subscriber_queue) + listener_entry = _listener_tasks.pop(queue_id, None) + + if listener_entry is None: + logger.debug( + f"No listener task found for task {task_id} queue {queue_id} " + "(may have already completed)" + ) + return + + stored_task_id, listener_task = listener_entry + + if stored_task_id != task_id: + logger.warning( + f"Task ID mismatch in unsubscribe: expected {task_id}, " + f"found {stored_task_id}" + ) + + if listener_task.done(): + logger.debug(f"Listener task for task {task_id} already completed") + return + + # Cancel the listener task + listener_task.cancel() + + try: + # Wait for the task to be cancelled with a timeout + await asyncio.wait_for(listener_task, timeout=5.0) + except asyncio.CancelledError: + # Expected - the task was successfully cancelled + pass + except asyncio.TimeoutError: + logger.warning( + f"Timeout waiting for listener task cancellation for task {task_id}" + ) + except Exception as e: + logger.error(f"Error during listener task cancellation for task {task_id}: {e}") + + logger.debug(f"Successfully unsubscribed from task {task_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index d078860c3a..dcbc35ef37 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool from .create_agent import CreateAgentTool +from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool from .find_agent import FindAgentTool from .find_block import FindBlockTool @@ -34,6 +35,7 @@ logger = logging.getLogger(__name__) TOOL_REGISTRY: dict[str, BaseTool] = { "add_understanding": AddUnderstandingTool(), "create_agent": CreateAgentTool(), + "customize_agent": CustomizeAgentTool(), "edit_agent": EditAgentTool(), "find_agent": FindAgentTool(), "find_block": FindBlockTool(), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py index b7650b3cbd..4266834220 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py @@ -8,6 +8,7 @@ from .core import ( DecompositionStep, LibraryAgentSummary, MarketplaceAgentSummary, + customize_template, decompose_goal, enrich_library_agents_from_steps, extract_search_terms_from_steps, @@ -19,6 +20,7 @@ from .core import ( get_library_agent_by_graph_id, get_library_agent_by_id, get_library_agents_for_generation, + graph_to_json, json_to_graph, save_agent_to_library, search_marketplace_agents_for_generation, @@ -36,6 +38,7 @@ __all__ = [ "LibraryAgentSummary", "MarketplaceAgentSummary", "check_external_service_health", + "customize_template", "decompose_goal", "enrich_library_agents_from_steps", "extract_search_terms_from_steps", @@ -48,6 +51,7 @@ __all__ = [ "get_library_agent_by_id", "get_library_agents_for_generation", "get_user_message_for_error", + "graph_to_json", "is_external_service_configured", "json_to_graph", "save_agent_to_library", diff --git 
a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index 0ddd2aa86b..b88b9b2924 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -19,6 +19,7 @@ from backend.data.graph import ( from backend.util.exceptions import DatabaseError, NotFoundError from .service import ( + customize_template_external, decompose_goal_external, generate_agent_external, generate_agent_patch_external, @@ -549,15 +550,21 @@ async def decompose_goal( async def generate_agent( instructions: DecompositionResult | dict[str, Any], library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Generate agent JSON from instructions. Args: instructions: Structured instructions from decompose_goal library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams + completion notification) + task_id: Task ID for async processing (enables Redis Streams persistence + and SSE delivery) Returns: - Agent JSON dict, error dict {"type": "error", ...}, or None on error + Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. @@ -565,8 +572,13 @@ async def generate_agent( _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent") result = await generate_agent_external( - dict(instructions), _to_dict_list(library_agents) + dict(instructions), _to_dict_list(library_agents), operation_id, task_id ) + + # Don't modify async response + if result and result.get("status") == "accepted": + return result + if result: if isinstance(result, dict) and result.get("type") == "error": return result @@ -740,32 +752,15 @@ async def save_agent_to_library( return created_graph, library_agents[0] -async def get_agent_as_json( - agent_id: str, user_id: str | None -) -> dict[str, Any] | None: - """Fetch an agent and convert to JSON format for editing. +def graph_to_json(graph: Graph) -> dict[str, Any]: + """Convert a Graph object to JSON format for the agent generator. Args: - agent_id: Graph ID or library agent ID - user_id: User ID + graph: Graph object to convert Returns: - Agent as JSON dict or None if not found + Agent as JSON dict """ - graph = await get_graph(agent_id, version=None, user_id=user_id) - - if not graph and user_id: - try: - library_agent = await library_db.get_library_agent(agent_id, user_id) - graph = await get_graph( - library_agent.graph_id, version=None, user_id=user_id - ) - except NotFoundError: - pass - - if not graph: - return None - nodes = [] for node in graph.nodes: nodes.append( @@ -802,10 +797,41 @@ async def get_agent_as_json( } +async def get_agent_as_json( + agent_id: str, user_id: str | None +) -> dict[str, Any] | None: + """Fetch an agent and convert to JSON format for editing. 
+ + Args: + agent_id: Graph ID or library agent ID + user_id: User ID + + Returns: + Agent as JSON dict or None if not found + """ + graph = await get_graph(agent_id, version=None, user_id=user_id) + + if not graph and user_id: + try: + library_agent = await library_db.get_library_agent(agent_id, user_id) + graph = await get_graph( + library_agent.graph_id, version=None, user_id=user_id + ) + except NotFoundError: + pass + + if not graph: + return None + + return graph_to_json(graph) + + async def generate_agent_patch( update_request: str, current_agent: dict[str, Any], library_agents: list[AgentSummary] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Update an existing agent using natural language. @@ -818,10 +844,12 @@ async def generate_agent_patch( update_request: Natural language description of changes current_agent: Current agent JSON library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, - error dict {"type": "error", ...}, or None on unexpected error + {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. @@ -829,5 +857,43 @@ async def generate_agent_patch( _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent_patch") return await generate_agent_patch_external( - update_request, current_agent, _to_dict_list(library_agents) + update_request, + current_agent, + _to_dict_list(library_agents), + operation_id, + task_id, + ) + + +async def customize_template( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Customize a template/marketplace agent using natural language. + + This is used when users want to modify a template or marketplace agent + to fit their specific needs before adding it to their library. + + The external Agent Generator service handles: + - Understanding the modification request + - Applying changes to the template + - Fixing and validating the result + + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, + error dict {"type": "error", ...}, or None on unexpected error + + Raises: + AgentGeneratorNotConfiguredError: If the external service is not configured. 
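Example (a caller sketch covering the documented result shapes; the modification text is illustrative and template_agent is assumed to be an agent JSON dict already in scope):

    result = await customize_template(
        template_agent=template_agent,
        modification_request="Post the weekly summary to Slack instead of email",
    )
    if result is None:
        ...  # service unavailable or unexpected failure
    elif result.get("type") == "clarifying_questions":
        questions = result.get("questions", [])
    elif result.get("type") == "error":
        error_type = result.get("error_type", "unknown")
    else:
        customized_agent = result  # plain agent JSON, ready to preview or save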
+ """ + _check_service_configured() + logger.info("Calling external Agent Generator service for customize_template") + return await customize_template_external( + template_agent, modification_request, context ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py index c9c960d1ae..62411b4e1b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py @@ -212,24 +212,45 @@ async def decompose_goal_external( async def generate_agent_external( instructions: dict[str, Any], library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate an agent from instructions. Args: instructions: Structured instructions from decompose_goal library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Agent JSON dict on success, or error dict {"type": "error", ...} on error + Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error """ client = _get_client() + # Build request payload payload: dict[str, Any] = {"instructions": instructions} if library_agents: payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id try: response = await client.post("/api/generate-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() @@ -261,6 +282,8 @@ async def generate_agent_patch_external( update_request: str, current_agent: dict[str, Any], library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate a patch for an existing agent. 
@@ -268,21 +291,40 @@ async def generate_agent_patch_external( update_request: Natural language description of changes current_agent: Current agent JSON library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Updated agent JSON, clarifying questions dict, or error dict on error + Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error """ client = _get_client() + # Build request payload payload: dict[str, Any] = { "update_request": update_request, "current_agent_json": current_agent, } if library_agents: payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id try: response = await client.post("/api/update-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async update request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() @@ -326,6 +368,77 @@ async def generate_agent_patch_external( return _create_error_response(error_msg, "unexpected_error") +async def customize_template_external( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Call the external service to customize a template/marketplace agent. + + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict, or error dict on error + """ + client = _get_client() + + request = modification_request + if context: + request = f"{modification_request}\n\nAdditional context from user:\n{context}" + + payload: dict[str, Any] = { + "template_agent_json": template_agent, + "modification_request": request, + } + + try: + response = await client.post("/api/template-modification", json=payload) + response.raise_for_status() + data = response.json() + + if not data.get("success"): + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator template customization failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) + + # Check if it's clarifying questions + if data.get("type") == "clarifying_questions": + return { + "type": "clarifying_questions", + "questions": data.get("questions", []), + } + + # Check if it's an error passed through + if data.get("type") == "error": + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) + + # Otherwise return the customized agent JSON + return data.get("agent_json") + + except httpx.HTTPStatusError as e: + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except httpx.RequestError as e: + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except Exception as e: + error_msg = f"Unexpected error calling Agent Generator: 
{e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") + + async def get_blocks_external() -> list[dict[str, Any]] | None: """Get available blocks from the external service. diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py index adb2c78fce..7333851a5b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py @@ -18,6 +18,7 @@ from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not description: return ErrorResponse( message="Please provide a description of what the agent should do.", @@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool): logger.warning(f"Failed to enrich library agents from steps: {e}") try: - agent_json = await generate_agent(decomposition_result, library_agents) + agent_json = await generate_agent( + decomposition_result, + library_agents, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if agent_json.get("status") == "accepted": + logger.info( + f"Agent generation delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent generation started. 
You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + agent_name = agent_json.get("name", "Generated Agent") agent_description = agent_json.get("description", "") node_count = len(agent_json.get("nodes", [])) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py new file mode 100644 index 0000000000..c0568bd936 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py @@ -0,0 +1,337 @@ +"""CustomizeAgentTool - Customizes marketplace/template agents using natural language.""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.store import db as store_db +from backend.api.features.store.exceptions import AgentNotFoundError + +from .agent_generator import ( + AgentGeneratorNotConfiguredError, + customize_template, + get_user_message_for_error, + graph_to_json, + save_agent_to_library, +) +from .base import BaseTool +from .models import ( + AgentPreviewResponse, + AgentSavedResponse, + ClarificationNeededResponse, + ClarifyingQuestion, + ErrorResponse, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class CustomizeAgentTool(BaseTool): + """Tool for customizing marketplace/template agents using natural language.""" + + @property + def name(self) -> str: + return "customize_agent" + + @property + def description(self) -> str: + return ( + "Customize a marketplace or template agent using natural language. " + "Takes an existing agent from the marketplace and modifies it based on " + "the user's requirements before adding to their library." + ) + + @property + def requires_auth(self) -> bool: + return True + + @property + def is_long_running(self) -> bool: + return True + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "agent_id": { + "type": "string", + "description": ( + "The marketplace agent ID in format 'creator/slug' " + "(e.g., 'autogpt/newsletter-writer'). " + "Get this from find_agent results." + ), + }, + "modifications": { + "type": "string", + "description": ( + "Natural language description of how to customize the agent. " + "Be specific about what changes you want to make." + ), + }, + "context": { + "type": "string", + "description": ( + "Additional context or answers to previous clarifying questions." + ), + }, + "save": { + "type": "boolean", + "description": ( + "Whether to save the customized agent to the user's library. " + "Default is true. Set to false for preview only." + ), + "default": True, + }, + }, + "required": ["agent_id", "modifications"], + } + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + """Execute the customize_agent tool. + + Flow: + 1. Parse the agent ID to get creator/slug + 2. Fetch the template agent from the marketplace + 3. Call customize_template with the modification request + 4. 
Preview or save based on the save parameter + """ + agent_id = kwargs.get("agent_id", "").strip() + modifications = kwargs.get("modifications", "").strip() + context = kwargs.get("context", "") + save = kwargs.get("save", True) + session_id = session.session_id if session else None + + if not agent_id: + return ErrorResponse( + message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", + error="missing_agent_id", + session_id=session_id, + ) + + if not modifications: + return ErrorResponse( + message="Please describe how you want to customize this agent.", + error="missing_modifications", + session_id=session_id, + ) + + # Parse agent_id in format "creator/slug" + parts = [p.strip() for p in agent_id.split("/")] + if len(parts) != 2 or not parts[0] or not parts[1]: + return ErrorResponse( + message=( + f"Invalid agent ID format: '{agent_id}'. " + "Expected format is 'creator/agent-name' " + "(e.g., 'autogpt/newsletter-writer')." + ), + error="invalid_agent_id_format", + session_id=session_id, + ) + + creator_username, agent_slug = parts + + # Fetch the marketplace agent details + try: + agent_details = await store_db.get_store_agent_details( + username=creator_username, agent_name=agent_slug + ) + except AgentNotFoundError: + return ErrorResponse( + message=( + f"Could not find marketplace agent '{agent_id}'. " + "Please check the agent ID and try again." + ), + error="agent_not_found", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error fetching marketplace agent {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the marketplace agent. Please try again.", + error="fetch_error", + session_id=session_id, + ) + + if not agent_details.store_listing_version_id: + return ErrorResponse( + message=( + f"The agent '{agent_id}' does not have an available version. " + "Please try a different agent." + ), + error="no_version_available", + session_id=session_id, + ) + + # Get the full agent graph + try: + graph = await store_db.get_agent(agent_details.store_listing_version_id) + template_agent = graph_to_json(graph) + except Exception as e: + logger.error(f"Error fetching agent graph for {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the agent configuration. Please try again.", + error="graph_fetch_error", + session_id=session_id, + ) + + # Call customize_template + try: + result = await customize_template( + template_agent=template_agent, + modification_request=modifications, + context=context, + ) + except AgentGeneratorNotConfiguredError: + return ErrorResponse( + message=( + "Agent customization is not available. " + "The Agent Generator service is not configured." + ), + error="service_not_configured", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error calling customize_template for {agent_id}: {e}") + return ErrorResponse( + message=( + "Failed to customize the agent due to a service error. " + "Please try again." + ), + error="customization_service_error", + session_id=session_id, + ) + + if result is None: + return ErrorResponse( + message=( + "Failed to customize the agent. " + "The agent generation service may be unavailable or timed out. " + "Please try again." 
+ ), + error="customization_failed", + session_id=session_id, + ) + + # Handle error response + if isinstance(result, dict) and result.get("type") == "error": + error_msg = result.get("error", "Unknown error") + error_type = result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="customize the agent", + llm_parse_message=( + "The AI had trouble customizing the agent. " + "Please try again or simplify your request." + ), + validation_message=( + "The customized agent failed validation. " + "Please try rephrasing your request." + ), + error_details=error_msg, + ) + return ErrorResponse( + message=user_message, + error=f"customization_failed:{error_type}", + session_id=session_id, + ) + + # Handle clarifying questions + if isinstance(result, dict) and result.get("type") == "clarifying_questions": + questions = result.get("questions") or [] + if not isinstance(questions, list): + logger.error( + f"Unexpected clarifying questions format: {type(questions)}" + ) + questions = [] + return ClarificationNeededResponse( + message=( + "I need some more information to customize this agent. " + "Please answer the following questions:" + ), + questions=[ + ClarifyingQuestion( + question=q.get("question", ""), + keyword=q.get("keyword", ""), + example=q.get("example"), + ) + for q in questions + if isinstance(q, dict) + ], + session_id=session_id, + ) + + # Result should be the customized agent JSON + if not isinstance(result, dict): + logger.error(f"Unexpected customize_template response type: {type(result)}") + return ErrorResponse( + message="Failed to customize the agent due to an unexpected response.", + error="unexpected_response_type", + session_id=session_id, + ) + + customized_agent = result + + agent_name = customized_agent.get( + "name", f"Customized {agent_details.agent_name}" + ) + agent_description = customized_agent.get("description", "") + nodes = customized_agent.get("nodes") + links = customized_agent.get("links") + node_count = len(nodes) if isinstance(nodes, list) else 0 + link_count = len(links) if isinstance(links, list) else 0 + + if not save: + return AgentPreviewResponse( + message=( + f"I've customized the agent '{agent_details.agent_name}'. " + f"The customized agent has {node_count} blocks. " + f"Review it and call customize_agent with save=true to save it." + ), + agent_json=customized_agent, + agent_name=agent_name, + description=agent_description, + node_count=node_count, + link_count=link_count, + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="You must be logged in to save agents.", + error="auth_required", + session_id=session_id, + ) + + # Save to user's library + try: + created_graph, library_agent = await save_agent_to_library( + customized_agent, user_id, is_update=False + ) + + return AgentSavedResponse( + message=( + f"Customized agent '{created_graph.name}' " + f"(based on '{agent_details.agent_name}') " + f"has been saved to your library!" + ), + agent_id=created_graph.id, + agent_name=created_graph.name, + library_agent_id=library_agent.id, + library_agent_link=f"/library/agents/{library_agent.id}", + agent_page_link=f"/build?flowID={created_graph.id}", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error saving customized agent: {e}") + return ErrorResponse( + message="Failed to save the customized agent. 
Please try again.", + error="save_failed", + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py index 2c2c48226b..3ae56407a7 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py @@ -17,6 +17,7 @@ from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -104,6 +105,10 @@ class EditAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not agent_id: return ErrorResponse( message="Please provide the agent ID to edit.", @@ -149,7 +154,11 @@ class EditAgentTool(BaseTool): try: result = await generate_agent_patch( - update_request, current_agent, library_agents + update_request, + current_agent, + library_agents, + operation_id=operation_id, + task_id=task_id, ) except AgentGeneratorNotConfiguredError: return ErrorResponse( @@ -169,6 +178,20 @@ class EditAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if result.get("status") == "accepted": + logger.info( + f"Agent edit delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent edit started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + + # Check if the result is an error from the external service if isinstance(result, dict) and result.get("type") == "error": error_msg = result.get("error", "Unknown error") error_type = result.get("error_type", "unknown") diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index 49b233784e..69c8c6c684 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -38,6 +38,8 @@ class ResponseType(str, Enum): OPERATION_STARTED = "operation_started" OPERATION_PENDING = "operation_pending" OPERATION_IN_PROGRESS = "operation_in_progress" + # Input validation + INPUT_VALIDATION_ERROR = "input_validation_error" # Base response model @@ -68,6 +70,10 @@ class AgentInfo(BaseModel): has_external_trigger: bool | None = None new_output: bool | None = None graph_id: str | None = None + inputs: dict[str, Any] | None = Field( + default=None, + description="Input schema for the agent, including field names, types, and defaults", + ) class AgentsFoundResponse(ToolResponseBase): @@ -194,6 +200,20 @@ class ErrorResponse(ToolResponseBase): details: dict[str, Any] | None = None +class InputValidationErrorResponse(ToolResponseBase): + """Response when run_agent receives unknown input fields.""" + + type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR + unrecognized_fields: list[str] = Field( + description="List of input field names that were not recognized" + ) + inputs: dict[str, Any] = Field( + description="The agent's valid input schema for reference" + ) + graph_id: str | None = None + graph_version: int | None = None + + # Agent output models class 
ExecutionOutputInfo(BaseModel): """Summary of a single execution's outputs.""" @@ -352,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase): This is returned immediately to the client while the operation continues to execute. The user can close the tab and check back later. + + The task_id can be used to reconnect to the SSE stream via + GET /chat/tasks/{task_id}/stream?last_idx=0 """ type: ResponseType = ResponseType.OPERATION_STARTED operation_id: str tool_name: str + task_id: str | None = None # For SSE reconnection class OperationPendingResponse(ToolResponseBase): @@ -380,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase): type: ResponseType = ResponseType.OPERATION_IN_PROGRESS tool_call_id: str + + +class AsyncProcessingResponse(ToolResponseBase): + """Response when an operation has been delegated to async processing. + + This is returned by tools when the external service accepts the request + for async processing (HTTP 202 Accepted). The Redis Streams completion + consumer will handle the result when the external service completes. + + The status field is specifically "accepted" to allow the long-running tool + handler to detect this response and skip LLM continuation. + """ + + type: ResponseType = ResponseType.OPERATION_STARTED + status: str = "accepted" # Must be "accepted" for detection + operation_id: str | None = None + task_id: str | None = None diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py index a7fa65348a..73d4cf81f2 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent.py @@ -30,6 +30,7 @@ from .models import ( ErrorResponse, ExecutionOptions, ExecutionStartedResponse, + InputValidationErrorResponse, SetupInfo, SetupRequirementsResponse, ToolResponseBase, @@ -273,6 +274,22 @@ class RunAgentTool(BaseTool): input_properties = graph.input_schema.get("properties", {}) required_fields = set(graph.input_schema.get("required", [])) provided_inputs = set(params.inputs.keys()) + valid_fields = set(input_properties.keys()) + + # Check for unknown input fields + unrecognized_fields = provided_inputs - valid_fields + if unrecognized_fields: + return InputValidationErrorResponse( + message=( + f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. " + f"Agent was not executed. Please use the correct field names from the schema." 
+ ), + session_id=session_id, + unrecognized_fields=sorted(unrecognized_fields), + inputs=graph.input_schema, + graph_id=graph.id, + graph_version=graph.version, + ) # If agent has inputs but none were provided AND use_defaults is not set, # always show what's available first so user can decide diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py index 404df2adb6..d5da394fa6 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_agent_test.py @@ -402,3 +402,42 @@ async def test_run_agent_schedule_without_name(setup_test_data): # Should return error about missing schedule_name assert result_data.get("type") == "error" assert "schedule_name" in result_data["message"].lower() + + +@pytest.mark.asyncio(loop_scope="session") +async def test_run_agent_rejects_unknown_input_fields(setup_test_data): + """Test that run_agent returns input_validation_error for unknown input fields.""" + user = setup_test_data["user"] + store_submission = setup_test_data["store_submission"] + + tool = RunAgentTool() + agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}" + session = make_session(user_id=user.id) + + # Execute with unknown input field names + response = await tool.execute( + user_id=user.id, + session_id=str(uuid.uuid4()), + tool_call_id=str(uuid.uuid4()), + username_agent_slug=agent_marketplace_id, + inputs={ + "unknown_field": "some value", + "another_unknown": "another value", + }, + session=session, + ) + + assert response is not None + assert hasattr(response, "output") + assert isinstance(response.output, str) + result_data = orjson.loads(response.output) + + # Should return input_validation_error type with unrecognized fields + assert result_data.get("type") == "input_validation_error" + assert "unrecognized_fields" in result_data + assert set(result_data["unrecognized_fields"]) == { + "another_unknown", + "unknown_field", + } + assert "inputs" in result_data # Contains the valid schema + assert "Agent was not executed" in result_data["message"] diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py index a59082b399..51bb2c0575 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/run_block.py @@ -5,6 +5,8 @@ import uuid from collections import defaultdict from typing import Any +from pydantic_core import PydanticUndefined + from backend.api.features.chat.model import ChatSession from backend.data.block import get_block from backend.data.execution import ExecutionContext @@ -75,15 +77,22 @@ class RunBlockTool(BaseTool): self, user_id: str, block: Any, + input_data: dict[str, Any] | None = None, ) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]: """ Check if user has required credentials for a block. 
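+        For credentials fields that declare a discriminator (e.g. a provider chosen
+        by a model selector input), the provider set is narrowed via
+        field_info.discriminate() using the discriminator value from input_data,
+        falling back to the schema default, before matching the user's stored
+        credentials.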
+ Args: + user_id: User ID + block: Block to check credentials for + input_data: Input data for the block (used to determine provider via discriminator) + Returns: tuple[matched_credentials, missing_credentials] """ matched_credentials: dict[str, CredentialsMetaInput] = {} missing_credentials: list[CredentialsMetaInput] = [] + input_data = input_data or {} # Get credential field info from block's input schema credentials_fields_info = block.input_schema.get_credentials_fields_info() @@ -96,14 +105,33 @@ class RunBlockTool(BaseTool): available_creds = await creds_manager.store.get_all_creds(user_id) for field_name, field_info in credentials_fields_info.items(): - # field_info.provider is a frozenset of acceptable providers - # field_info.supported_types is a frozenset of acceptable types + effective_field_info = field_info + if field_info.discriminator and field_info.discriminator_mapping: + # Get discriminator from input, falling back to schema default + discriminator_value = input_data.get(field_info.discriminator) + if discriminator_value is None: + field = block.input_schema.model_fields.get( + field_info.discriminator + ) + if field and field.default is not PydanticUndefined: + discriminator_value = field.default + + if ( + discriminator_value + and discriminator_value in field_info.discriminator_mapping + ): + effective_field_info = field_info.discriminate(discriminator_value) + logger.debug( + f"Discriminated provider for {field_name}: " + f"{discriminator_value} -> {effective_field_info.provider}" + ) + matching_cred = next( ( cred for cred in available_creds - if cred.provider in field_info.provider - and cred.type in field_info.supported_types + if cred.provider in effective_field_info.provider + and cred.type in effective_field_info.supported_types ), None, ) @@ -117,8 +145,8 @@ class RunBlockTool(BaseTool): ) else: # Create a placeholder for the missing credential - provider = next(iter(field_info.provider), "unknown") - cred_type = next(iter(field_info.supported_types), "api_key") + provider = next(iter(effective_field_info.provider), "unknown") + cred_type = next(iter(effective_field_info.supported_types), "api_key") missing_credentials.append( CredentialsMetaInput( id=field_name, @@ -186,10 +214,9 @@ class RunBlockTool(BaseTool): logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}") - # Check credentials creds_manager = IntegrationCredentialsManager() matched_credentials, missing_credentials = await self._check_block_credentials( - user_id, block + user_id, block, input_data ) if missing_credentials: diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py index bae5b97cd6..86af457f50 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py @@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination( cleanup_embeddings: list, ): """Test unified search pagination works correctly.""" + # Use a unique search term to avoid matching other test data + unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}" + # Create multiple items content_ids = [] for i in range(5): @@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination( content_type=ContentType.BLOCK, content_id=content_id, embedding=mock_embedding, - searchable_text=f"pagination test item number {i}", + searchable_text=f"{unique_term} item number {i}", metadata={"index": 
i}, user_id=None, ) # Get first page page1_results, total1 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=1, page_size=2, @@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination( # Get second page page2_results, total2 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=2, page_size=2, diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index b936312ce1..0eef76193e 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -40,6 +40,10 @@ import backend.data.user import backend.integrations.webhooks.utils import backend.util.service import backend.util.settings +from backend.api.features.chat.completion_consumer import ( + start_completion_consumer, + stop_completion_consumer, +) from backend.blocks.llm import DEFAULT_LLM_MODEL from backend.data.model import Credentials from backend.integrations.providers import ProviderName @@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI): await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL) await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs() + # Start chat completion consumer for Redis Streams notifications + try: + await start_completion_consumer() + except Exception as e: + logger.warning(f"Could not start chat completion consumer: {e}") + with launch_darkly_context(): yield + # Stop chat completion consumer + try: + await stop_completion_consumer() + except Exception as e: + logger.warning(f"Error stopping chat completion consumer: {e}") + try: await shutdown_cloud_storage_handler() except Exception as e: diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index 732fb1354c..54295da1f1 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -32,7 +32,7 @@ from backend.data.model import ( from backend.integrations.providers import ProviderName from backend.util import json from backend.util.logging import TruncatedLogger -from backend.util.prompt import compress_prompt, estimate_token_count +from backend.util.prompt import compress_context, estimate_token_count from backend.util.text import TextFormatter logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]") @@ -634,11 +634,18 @@ async def llm_call( context_window = llm_model.context_window if compress_prompt_to_fit: - prompt = compress_prompt( + result = await compress_context( messages=prompt, target_tokens=llm_model.context_window // 2, - lossy_ok=True, + client=None, # Truncation-only, no LLM summarization + reserve=0, # Caller handles response token budget separately ) + if result.error: + logger.warning( + f"Prompt compression did not meet target: {result.error}. " + f"Proceeding with {result.token_count} tokens." 
+ ) + prompt = result.messages # Calculate available tokens based on context window and input length estimated_input_tokens = estimate_token_count(prompt) diff --git a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py index 4d5d6bf4f3..91c096ffe4 100644 --- a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py +++ b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py @@ -182,10 +182,7 @@ class StagehandObserveBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -282,10 +279,7 @@ class StagehandActBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"ACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -370,10 +364,7 @@ class StagehandExtractBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 8d9ecfff4c..eb9360b037 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -873,14 +873,13 @@ def is_block_auth_configured( async def initialize_blocks() -> None: - # First, sync all provider costs to blocks - # Imported here to avoid circular import from backend.sdk.cost_integration import sync_all_provider_costs + from backend.util.retry import func_retry sync_all_provider_costs() - for cls in get_blocks().values(): - block = cls() + @func_retry + async def sync_block_to_db(block: Block) -> None: existing_block = await AgentBlock.prisma().find_first( where={"OR": [{"id": block.id}, {"name": block.name}]} ) @@ -893,7 +892,7 @@ async def initialize_blocks() -> None: outputSchema=json.dumps(block.output_schema.jsonschema()), ) ) - continue + return input_schema = json.dumps(block.input_schema.jsonschema()) output_schema = json.dumps(block.output_schema.jsonschema()) @@ -913,6 +912,25 @@ async def initialize_blocks() -> None: }, ) + failed_blocks: list[str] = [] + for cls in get_blocks().values(): + block = cls() + try: + await sync_block_to_db(block) + except Exception as e: + logger.warning( + f"Failed to sync block {block.name} to database: {e}. " + "Block is still available in memory.", + exc_info=True, + ) + failed_blocks.append(block.name) + + if failed_blocks: + logger.error( + f"Failed to sync {len(failed_blocks)} block(s) to database: " + f"{', '.join(failed_blocks)}. These blocks are still available in memory." 
+ ) + # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281 def get_block(block_id: str) -> AnyBlockSchema | None: diff --git a/autogpt_platform/backend/backend/executor/database.py b/autogpt_platform/backend/backend/executor/database.py index ae7474fc1d..d44439d51c 100644 --- a/autogpt_platform/backend/backend/executor/database.py +++ b/autogpt_platform/backend/backend/executor/database.py @@ -17,6 +17,7 @@ from backend.data.analytics import ( get_accuracy_trends_and_alerts, get_marketplace_graphs_for_monitoring, ) +from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.credit import UsageTransactionMetadata, get_user_credit_model from backend.data.execution import ( create_graph_execution, @@ -219,6 +220,9 @@ class DatabaseManager(AppService): # Onboarding increment_onboarding_runs = _(increment_onboarding_runs) + # OAuth + cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens) + # Store get_store_agents = _(get_store_agents) get_store_agent_details = _(get_store_agent_details) @@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient): # Onboarding increment_onboarding_runs = d.increment_onboarding_runs + # OAuth + cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens + # Store get_store_agents = d.get_store_agents get_store_agent_details = d.get_store_agent_details diff --git a/autogpt_platform/backend/backend/executor/scheduler.py b/autogpt_platform/backend/backend/executor/scheduler.py index 44b77fc018..cbdc441718 100644 --- a/autogpt_platform/backend/backend/executor/scheduler.py +++ b/autogpt_platform/backend/backend/executor/scheduler.py @@ -24,11 +24,9 @@ from dotenv import load_dotenv from pydantic import BaseModel, Field, ValidationError from sqlalchemy import MetaData, create_engine -from backend.data.auth.oauth import cleanup_expired_oauth_tokens from backend.data.block import BlockInput from backend.data.execution import GraphExecutionWithNodes from backend.data.model import CredentialsMetaInput -from backend.data.onboarding import increment_onboarding_runs from backend.executor import utils as execution_utils from backend.monitoring import ( NotificationJobArgs, @@ -38,7 +36,11 @@ from backend.monitoring import ( report_execution_accuracy_alerts, report_late_executions, ) -from backend.util.clients import get_database_manager_client, get_scheduler_client +from backend.util.clients import ( + get_database_manager_async_client, + get_database_manager_client, + get_scheduler_client, +) from backend.util.cloud_storage import cleanup_expired_files_async from backend.util.exceptions import ( GraphNotFoundError, @@ -148,6 +150,7 @@ def execute_graph(**kwargs): async def _execute_graph(**kwargs): args = GraphExecutionJobArgs(**kwargs) start_time = asyncio.get_event_loop().time() + db = get_database_manager_async_client() try: logger.info(f"Executing recurring job for graph #{args.graph_id}") graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution( @@ -157,7 +160,7 @@ async def _execute_graph(**kwargs): inputs=args.input_data, graph_credentials_inputs=args.input_credentials, ) - await increment_onboarding_runs(args.user_id) + await db.increment_onboarding_runs(args.user_id) elapsed = asyncio.get_event_loop().time() - start_time logger.info( f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} " @@ -246,8 +249,13 @@ def cleanup_expired_files(): def cleanup_oauth_tokens(): """Clean up expired OAuth tokens from the database.""" + # Wait for completion 
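+    # Route the call through the DatabaseManager service client rather than
+    # importing cleanup_expired_oauth_tokens into the scheduler process directly.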
- run_async(cleanup_expired_oauth_tokens()) + async def _cleanup(): + db = get_database_manager_async_client() + return await db.cleanup_expired_oauth_tokens() + + run_async(_cleanup()) def execution_accuracy_alerts(): diff --git a/autogpt_platform/backend/backend/util/prompt.py b/autogpt_platform/backend/backend/util/prompt.py index 775d1c932b..5f904bbc8a 100644 --- a/autogpt_platform/backend/backend/util/prompt.py +++ b/autogpt_platform/backend/backend/util/prompt.py @@ -1,10 +1,19 @@ +from __future__ import annotations + +import logging from copy import deepcopy -from typing import Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any from tiktoken import encoding_for_model from backend.util import json +if TYPE_CHECKING: + from openai import AsyncOpenAI + +logger = logging.getLogger(__name__) + # ---------------------------------------------------------------------------# # CONSTANTS # # ---------------------------------------------------------------------------# @@ -100,9 +109,17 @@ def _is_objective_message(msg: dict) -> bool: def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None: """ Carefully truncate tool message content while preserving tool structure. - Only truncates tool_result content, leaves tool_use intact. + Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages. """ content = msg.get("content") + + # OpenAI-style tool message: role="tool" with string content + if msg.get("role") == "tool" and isinstance(content, str): + if _tok_len(content, enc) > max_tokens: + msg["content"] = _truncate_middle_tokens(content, enc, max_tokens) + return + + # Anthropic-style: list content with tool_result items if not isinstance(content, list): return @@ -140,141 +157,6 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str: # ---------------------------------------------------------------------------# -def compress_prompt( - messages: list[dict], - target_tokens: int, - *, - model: str = "gpt-4o", - reserve: int = 2_048, - start_cap: int = 8_192, - floor_cap: int = 128, - lossy_ok: bool = True, -) -> list[dict]: - """ - Shrink *messages* so that:: - - token_count(prompt) + reserve ≤ target_tokens - - Strategy - -------- - 1. **Token-aware truncation** – progressively halve a per-message cap - (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the - *content* of every message except the first and last. Tool shells - are included: we keep the envelope but shorten huge payloads. - 2. **Middle-out deletion** – if still over the limit, delete whole - messages working outward from the centre, **skipping** any message - that contains ``tool_calls`` or has ``role == "tool"``. - 3. **Last-chance trim** – if still too big, truncate the *first* and - *last* message bodies down to `floor_cap` tokens. - 4. If the prompt is *still* too large: - • raise ``ValueError`` when ``lossy_ok == False`` (default) - • return the partially-trimmed prompt when ``lossy_ok == True`` - - Parameters - ---------- - messages Complete chat history (will be deep-copied). - model Model name; passed to tiktoken to pick the right - tokenizer (gpt-4o → 'o200k_base', others fallback). - target_tokens Hard ceiling for prompt size **excluding** the model's - forthcoming answer. - reserve How many tokens you want to leave available for that - answer (`max_tokens` in your subsequent completion call). - start_cap Initial per-message truncation ceiling (tokens). - floor_cap Lowest cap we'll accept before moving to deletions. 
- lossy_ok If *True* return best-effort prompt instead of raising - after all trim passes have been exhausted. - - Returns - ------- - list[dict] – A *new* messages list that abides by the rules above. - """ - enc = encoding_for_model(model) # best-match tokenizer - msgs = deepcopy(messages) # never mutate caller - - def total_tokens() -> int: - """Current size of *msgs* in tokens.""" - return sum(_msg_tokens(m, enc) for m in msgs) - - original_token_count = total_tokens() - - if original_token_count + reserve <= target_tokens: - return msgs - - # ---- STEP 0 : normalise content -------------------------------------- - # Convert non-string payloads to strings so token counting is coherent. - for i, m in enumerate(msgs): - if not isinstance(m.get("content"), str) and m.get("content") is not None: - if _is_tool_message(m): - continue - - # Keep first and last messages intact (unless they're tool messages) - if i == 0 or i == len(msgs) - 1: - continue - - # Reasonable 20k-char ceiling prevents pathological blobs - content_str = json.dumps(m["content"], separators=(",", ":")) - if len(content_str) > 20_000: - content_str = _truncate_middle_tokens(content_str, enc, 20_000) - m["content"] = content_str - - # ---- STEP 1 : token-aware truncation --------------------------------- - cap = start_cap - while total_tokens() + reserve > target_tokens and cap >= floor_cap: - for m in msgs[1:-1]: # keep first & last intact - if _is_tool_message(m): - # For tool messages, only truncate tool result content, preserve structure - _truncate_tool_message_content(m, enc, cap) - continue - - if _is_objective_message(m): - # Never truncate objective messages - they contain the core task - continue - - content = m.get("content") or "" - if _tok_len(content, enc) > cap: - m["content"] = _truncate_middle_tokens(content, enc, cap) - cap //= 2 # tighten the screw - - # ---- STEP 2 : middle-out deletion ----------------------------------- - while total_tokens() + reserve > target_tokens and len(msgs) > 2: - # Identify all deletable messages (not first/last, not tool messages, not objective messages) - deletable_indices = [] - for i in range(1, len(msgs) - 1): # Skip first and last - if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]): - deletable_indices.append(i) - - if not deletable_indices: - break # nothing more we can drop - - # Delete from center outward - find the index closest to center - centre = len(msgs) // 2 - to_delete = min(deletable_indices, key=lambda i: abs(i - centre)) - del msgs[to_delete] - - # ---- STEP 3 : final safety-net trim on first & last ------------------ - cap = start_cap - while total_tokens() + reserve > target_tokens and cap >= floor_cap: - for idx in (0, -1): # first and last - if _is_tool_message(msgs[idx]): - # For tool messages at first/last position, truncate tool result content only - _truncate_tool_message_content(msgs[idx], enc, cap) - continue - - text = msgs[idx].get("content") or "" - if _tok_len(text, enc) > cap: - msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap) - cap //= 2 # tighten the screw - - # ---- STEP 4 : success or fail-gracefully ----------------------------- - if total_tokens() + reserve > target_tokens and not lossy_ok: - raise ValueError( - "compress_prompt: prompt still exceeds budget " - f"({total_tokens() + reserve} > {target_tokens})." - ) - - return msgs - - def estimate_token_count( messages: list[dict], *, @@ -293,7 +175,8 @@ def estimate_token_count( ------- int – Token count. 
""" - enc = encoding_for_model(model) # best-match tokenizer + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) return sum(_msg_tokens(m, enc) for m in messages) @@ -315,6 +198,543 @@ def estimate_token_count_str( ------- int – Token count. """ - enc = encoding_for_model(model) # best-match tokenizer + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) text = json.dumps(text) if not isinstance(text, str) else text return _tok_len(text, enc) + + +# ---------------------------------------------------------------------------# +# UNIFIED CONTEXT COMPRESSION # +# ---------------------------------------------------------------------------# + +# Default thresholds +DEFAULT_TOKEN_THRESHOLD = 120_000 +DEFAULT_KEEP_RECENT = 15 + + +@dataclass +class CompressResult: + """Result of context compression.""" + + messages: list[dict] + token_count: int + was_compacted: bool + error: str | None = None + original_token_count: int = 0 + messages_summarized: int = 0 + messages_dropped: int = 0 + + +def _normalize_model_for_tokenizer(model: str) -> str: + """Normalize model name for tiktoken tokenizer selection.""" + if "/" in model: + model = model.split("/")[-1] + if "claude" in model.lower() or not any( + known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"] + ): + return "gpt-4o" + return model + + +def _extract_tool_call_ids_from_message(msg: dict) -> set[str]: + """ + Extract tool_call IDs from an assistant message. + + Supports both formats: + - OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]} + - Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]} + + Returns: + Set of tool_call IDs found in the message. + """ + ids: set[str] = set() + if msg.get("role") != "assistant": + return ids + + # OpenAI format: tool_calls array + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + tc_id = tc.get("id") + if tc_id: + ids.add(tc_id) + + # Anthropic format: content list with tool_use blocks + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_use": + tc_id = block.get("id") + if tc_id: + ids.add(tc_id) + + return ids + + +def _extract_tool_response_ids_from_message(msg: dict) -> set[str]: + """ + Extract tool_call IDs that this message is responding to. + + Supports both formats: + - OpenAI: {"role": "tool", "tool_call_id": "..."} + - Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]} + + Returns: + Set of tool_call IDs this message responds to. 
+ """ + ids: set[str] = set() + + # OpenAI format: role=tool with tool_call_id + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id") + if tc_id: + ids.add(tc_id) + + # Anthropic format: content list with tool_result blocks + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + tc_id = block.get("tool_use_id") + if tc_id: + ids.add(tc_id) + + return ids + + +def _is_tool_response_message(msg: dict) -> bool: + """Check if message is a tool response (OpenAI or Anthropic format).""" + # OpenAI format + if msg.get("role") == "tool": + return True + # Anthropic format + content = msg.get("content") + if isinstance(content, list): + for block in content: + if isinstance(block, dict) and block.get("type") == "tool_result": + return True + return False + + +def _remove_orphan_tool_responses( + messages: list[dict], orphan_ids: set[str] +) -> list[dict]: + """ + Remove tool response messages/blocks that reference orphan tool_call IDs. + + Supports both OpenAI and Anthropic formats. + For Anthropic messages with mixed valid/orphan tool_result blocks, + filters out only the orphan blocks instead of dropping the entire message. + """ + result = [] + for msg in messages: + # OpenAI format: role=tool - drop entire message if orphan + if msg.get("role") == "tool": + tc_id = msg.get("tool_call_id") + if tc_id and tc_id in orphan_ids: + continue + result.append(msg) + continue + + # Anthropic format: content list may have mixed tool_result blocks + content = msg.get("content") + if isinstance(content, list): + has_tool_results = any( + isinstance(b, dict) and b.get("type") == "tool_result" for b in content + ) + if has_tool_results: + # Filter out orphan tool_result blocks, keep valid ones + filtered_content = [ + block + for block in content + if not ( + isinstance(block, dict) + and block.get("type") == "tool_result" + and block.get("tool_use_id") in orphan_ids + ) + ] + # Only keep message if it has remaining content + if filtered_content: + msg = msg.copy() + msg["content"] = filtered_content + result.append(msg) + continue + + result.append(msg) + return result + + +def _ensure_tool_pairs_intact( + recent_messages: list[dict], + all_messages: list[dict], + start_index: int, +) -> list[dict]: + """ + Ensure tool_call/tool_response pairs stay together after slicing. + + When slicing messages for context compaction, a naive slice can separate + an assistant message containing tool_calls from its corresponding tool + response messages. This causes API validation errors (e.g., Anthropic's + "unexpected tool_use_id found in tool_result blocks"). + + This function checks for orphan tool responses in the slice and extends + backwards to include their corresponding assistant messages. 
+ + Supports both formats: + - OpenAI: tool_calls array + role="tool" responses + - Anthropic: tool_use blocks + tool_result blocks + + Args: + recent_messages: The sliced messages to validate + all_messages: The complete message list (for looking up missing assistants) + start_index: The index in all_messages where recent_messages begins + + Returns: + A potentially extended list of messages with tool pairs intact + """ + if not recent_messages: + return recent_messages + + # Collect all tool_call_ids from assistant messages in the slice + available_tool_call_ids: set[str] = set() + for msg in recent_messages: + available_tool_call_ids |= _extract_tool_call_ids_from_message(msg) + + # Find orphan tool responses (responses whose tool_call_id is missing) + orphan_tool_call_ids: set[str] = set() + for msg in recent_messages: + response_ids = _extract_tool_response_ids_from_message(msg) + for tc_id in response_ids: + if tc_id not in available_tool_call_ids: + orphan_tool_call_ids.add(tc_id) + + if not orphan_tool_call_ids: + # No orphans, slice is valid + return recent_messages + + # Find the assistant messages that contain the orphan tool_call_ids + # Search backwards from start_index in all_messages + messages_to_prepend: list[dict] = [] + for i in range(start_index - 1, -1, -1): + msg = all_messages[i] + msg_tool_ids = _extract_tool_call_ids_from_message(msg) + if msg_tool_ids & orphan_tool_call_ids: + # This assistant message has tool_calls we need + # Also collect its contiguous tool responses that follow it + assistant_and_responses: list[dict] = [msg] + + # Scan forward from this assistant to collect tool responses + for j in range(i + 1, start_index): + following_msg = all_messages[j] + following_response_ids = _extract_tool_response_ids_from_message( + following_msg + ) + if following_response_ids and following_response_ids & msg_tool_ids: + assistant_and_responses.append(following_msg) + elif not _is_tool_response_message(following_msg): + # Stop at first non-tool-response message + break + + # Prepend the assistant and its tool responses (maintain order) + messages_to_prepend = assistant_and_responses + messages_to_prepend + # Mark these as found + orphan_tool_call_ids -= msg_tool_ids + # Also add this assistant's tool_call_ids to available set + available_tool_call_ids |= msg_tool_ids + + if not orphan_tool_call_ids: + # Found all missing assistants + break + + if orphan_tool_call_ids: + # Some tool_call_ids couldn't be resolved - remove those tool responses + # This shouldn't happen in normal operation but handles edge cases + logger.warning( + f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. " + "Removing orphan tool responses." 
+ ) + recent_messages = _remove_orphan_tool_responses( + recent_messages, orphan_tool_call_ids + ) + + if messages_to_prepend: + logger.info( + f"Extended recent messages by {len(messages_to_prepend)} to preserve " + f"tool_call/tool_response pairs" + ) + return messages_to_prepend + recent_messages + + return recent_messages + + +async def _summarize_messages_llm( + messages: list[dict], + client: AsyncOpenAI, + model: str, + timeout: float = 30.0, +) -> str: + """Summarize messages using an LLM.""" + conversation = [] + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + if content and role in ("user", "assistant", "tool"): + conversation.append(f"{role.upper()}: {content}") + + conversation_text = "\n\n".join(conversation) + + if not conversation_text: + return "No conversation history available." + + # Limit to ~100k chars for safety + MAX_CHARS = 100_000 + if len(conversation_text) > MAX_CHARS: + conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]" + + response = await client.with_options(timeout=timeout).chat.completions.create( + model=model, + messages=[ + { + "role": "system", + "content": ( + "Create a detailed summary of the conversation so far. " + "This summary will be used as context when continuing the conversation.\n\n" + "Before writing the summary, analyze each message chronologically to identify:\n" + "- User requests and their explicit goals\n" + "- Your approach and key decisions made\n" + "- Technical specifics (file names, tool outputs, function signatures)\n" + "- Errors encountered and resolutions applied\n\n" + "You MUST include ALL of the following sections:\n\n" + "## 1. Primary Request and Intent\n" + "The user's explicit goals and what they are trying to accomplish.\n\n" + "## 2. Key Technical Concepts\n" + "Technologies, frameworks, tools, and patterns being used or discussed.\n\n" + "## 3. Files and Resources Involved\n" + "Specific files examined or modified, with relevant snippets and identifiers.\n\n" + "## 4. Errors and Fixes\n" + "Problems encountered, error messages, and their resolutions. " + "Include any user feedback on fixes.\n\n" + "## 5. Problem Solving\n" + "Issues that have been resolved and how they were addressed.\n\n" + "## 6. All User Messages\n" + "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n" + "## 7. Pending Tasks\n" + "Work items the user explicitly requested that have not yet been completed.\n\n" + "## 8. Current Work\n" + "Precise description of what was being worked on most recently, including relevant context.\n\n" + "## 9. Next Steps\n" + "What should happen next, aligned with the user's most recent requests. " + "Include verbatim quotes of recent instructions if relevant." + ), + }, + {"role": "user", "content": f"Summarize:\n\n{conversation_text}"}, + ], + max_tokens=1500, + temperature=0.3, + ) + + return response.choices[0].message.content or "No summary available." + + +async def compress_context( + messages: list[dict], + target_tokens: int = DEFAULT_TOKEN_THRESHOLD, + *, + model: str = "gpt-4o", + client: AsyncOpenAI | None = None, + keep_recent: int = DEFAULT_KEEP_RECENT, + reserve: int = 2_048, + start_cap: int = 8_192, + floor_cap: int = 128, +) -> CompressResult: + """ + Unified context compression that combines summarization and truncation strategies. + + Strategy (in order): + 1. **LLM summarization** – If client provided, summarize old messages into a + single context message while keeping recent messages intact. 
This is the + primary strategy for chat service. + 2. **Content truncation** – Progressively halve a per-message cap and truncate + bloated message content (tool outputs, large pastes). Preserves all messages + but shortens their content. Primary strategy when client=None (LLM blocks). + 3. **Middle-out deletion** – Delete whole messages one at a time from the center + outward, skipping tool messages and objective messages. + 4. **First/last trim** – Truncate first and last message content as last resort. + + Parameters + ---------- + messages Complete chat history (will be deep-copied). + target_tokens Hard ceiling for prompt size. + model Model name for tokenization and summarization. + client AsyncOpenAI client. If provided, enables LLM summarization + as the first strategy. If None, skips to truncation strategies. + keep_recent Number of recent messages to preserve during summarization. + reserve Tokens to reserve for model response. + start_cap Initial per-message truncation ceiling (tokens). + floor_cap Lowest cap before moving to deletions. + + Returns + ------- + CompressResult with compressed messages and metadata. + """ + # Guard clause for empty messages + if not messages: + return CompressResult( + messages=[], + token_count=0, + was_compacted=False, + original_token_count=0, + ) + + token_model = _normalize_model_for_tokenizer(model) + enc = encoding_for_model(token_model) + msgs = deepcopy(messages) + + def total_tokens() -> int: + return sum(_msg_tokens(m, enc) for m in msgs) + + original_count = total_tokens() + + # Already under limit + if original_count + reserve <= target_tokens: + return CompressResult( + messages=msgs, + token_count=original_count, + was_compacted=False, + original_token_count=original_count, + ) + + messages_summarized = 0 + messages_dropped = 0 + + # ---- STEP 1: LLM summarization (if client provided) ------------------- + # This is the primary compression strategy for chat service. + # Summarize old messages while keeping recent ones intact. + if client is not None: + has_system = len(msgs) > 0 and msgs[0].get("role") == "system" + system_msg = msgs[0] if has_system else None + + # Calculate old vs recent messages + if has_system: + if len(msgs) > keep_recent + 1: + old_msgs = msgs[1:-keep_recent] + recent_msgs = msgs[-keep_recent:] + else: + old_msgs = [] + recent_msgs = msgs[1:] if len(msgs) > 1 else [] + else: + if len(msgs) > keep_recent: + old_msgs = msgs[:-keep_recent] + recent_msgs = msgs[-keep_recent:] + else: + old_msgs = [] + recent_msgs = msgs + + # Ensure tool pairs stay intact + slice_start = max(0, len(msgs) - keep_recent) + recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start) + + if old_msgs: + try: + summary_text = await _summarize_messages_llm(old_msgs, client, model) + summary_msg = { + "role": "assistant", + "content": f"[Previous conversation summary — for context only]: {summary_text}", + } + messages_summarized = len(old_msgs) + + if has_system: + msgs = [system_msg, summary_msg] + recent_msgs + else: + msgs = [summary_msg] + recent_msgs + + logger.info( + f"Context summarized: {original_count} -> {total_tokens()} tokens, " + f"summarized {messages_summarized} messages" + ) + except Exception as e: + logger.warning(f"Summarization failed, continuing with truncation: {e}") + # Fall through to content truncation + + # ---- STEP 2: Normalize content ---------------------------------------- + # Convert non-string payloads to strings so token counting is coherent. 
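+    # (Non-tool list/dict content is dumped to compact JSON before measuring.)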
+ # Always run this before truncation to ensure consistent token counting. + for i, m in enumerate(msgs): + if not isinstance(m.get("content"), str) and m.get("content") is not None: + if _is_tool_message(m): + continue + if i == 0 or i == len(msgs) - 1: + continue + content_str = json.dumps(m["content"], separators=(",", ":")) + if len(content_str) > 20_000: + content_str = _truncate_middle_tokens(content_str, enc, 20_000) + m["content"] = content_str + + # ---- STEP 3: Token-aware content truncation --------------------------- + # Progressively halve per-message cap and truncate bloated content. + # This preserves all messages but shortens their content. + cap = start_cap + while total_tokens() + reserve > target_tokens and cap >= floor_cap: + for m in msgs[1:-1]: + if _is_tool_message(m): + _truncate_tool_message_content(m, enc, cap) + continue + if _is_objective_message(m): + continue + content = m.get("content") or "" + if _tok_len(content, enc) > cap: + m["content"] = _truncate_middle_tokens(content, enc, cap) + cap //= 2 + + # ---- STEP 4: Middle-out deletion -------------------------------------- + # Delete messages one at a time from the center outward. + # This is more granular than dropping all old messages at once. + while total_tokens() + reserve > target_tokens and len(msgs) > 2: + deletable: list[int] = [] + for i in range(1, len(msgs) - 1): + msg = msgs[i] + if ( + msg is not None + and not _is_tool_message(msg) + and not _is_objective_message(msg) + ): + deletable.append(i) + if not deletable: + break + centre = len(msgs) // 2 + to_delete = min(deletable, key=lambda i: abs(i - centre)) + del msgs[to_delete] + messages_dropped += 1 + + # ---- STEP 5: Final trim on first/last --------------------------------- + cap = start_cap + while total_tokens() + reserve > target_tokens and cap >= floor_cap: + for idx in (0, -1): + msg = msgs[idx] + if msg is None: + continue + if _is_tool_message(msg): + _truncate_tool_message_content(msg, enc, cap) + continue + text = msg.get("content") or "" + if _tok_len(text, enc) > cap: + msg["content"] = _truncate_middle_tokens(text, enc, cap) + cap //= 2 + + # Filter out any None values that may have been introduced + final_msgs: list[dict] = [m for m in msgs if m is not None] + final_count = sum(_msg_tokens(m, enc) for m in final_msgs) + error = None + if final_count + reserve > target_tokens: + error = f"Could not compress below target ({final_count + reserve} > {target_tokens})" + logger.warning(error) + + return CompressResult( + messages=final_msgs, + token_count=final_count, + was_compacted=True, + error=error, + original_token_count=original_count, + messages_summarized=messages_summarized, + messages_dropped=messages_dropped, + ) diff --git a/autogpt_platform/backend/backend/util/prompt_test.py b/autogpt_platform/backend/backend/util/prompt_test.py index af6b230f8f..2d4bf090b3 100644 --- a/autogpt_platform/backend/backend/util/prompt_test.py +++ b/autogpt_platform/backend/backend/util/prompt_test.py @@ -1,10 +1,21 @@ """Tests for prompt utility functions, especially tool call token counting.""" +from unittest.mock import AsyncMock, MagicMock + import pytest from tiktoken import encoding_for_model from backend.util import json -from backend.util.prompt import _msg_tokens, estimate_token_count +from backend.util.prompt import ( + CompressResult, + _ensure_tool_pairs_intact, + _msg_tokens, + _normalize_model_for_tokenizer, + _truncate_middle_tokens, + _truncate_tool_message_content, + compress_context, + estimate_token_count, +) class 
TestMsgTokens: @@ -276,3 +287,690 @@ class TestEstimateTokenCount: assert total_tokens == expected_total assert total_tokens > 20 # Should be substantial + + +class TestNormalizeModelForTokenizer: + """Test model name normalization for tiktoken.""" + + def test_openai_models_unchanged(self): + """Test that OpenAI models are returned as-is.""" + assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o" + assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4" + assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo" + + def test_claude_models_normalized(self): + """Test that Claude models are normalized to gpt-4o.""" + assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o" + assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o" + assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o" + + def test_openrouter_paths_extracted(self): + """Test that OpenRouter model paths are handled.""" + assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o" + assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o" + + def test_unknown_models_default_to_gpt4o(self): + """Test that unknown models default to gpt-4o.""" + assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o" + assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o" + + +class TestTruncateToolMessageContent: + """Test tool message content truncation.""" + + @pytest.fixture + def enc(self): + return encoding_for_model("gpt-4o") + + def test_truncate_openai_tool_message(self, enc): + """Test truncation of OpenAI-style tool message with string content.""" + long_content = "x" * 10000 + msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content} + + _truncate_tool_message_content(msg, enc, max_tokens=100) + + # Content should be truncated + assert len(msg["content"]) < len(long_content) + assert "…" in msg["content"] # Has ellipsis marker + + def test_truncate_anthropic_tool_result(self, enc): + """Test truncation of Anthropic-style tool_result.""" + long_content = "y" * 10000 + msg = { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_123", + "content": long_content, + } + ], + } + + _truncate_tool_message_content(msg, enc, max_tokens=100) + + # Content should be truncated + result_content = msg["content"][0]["content"] + assert len(result_content) < len(long_content) + assert "…" in result_content + + def test_preserve_tool_use_blocks(self, enc): + """Test that tool_use blocks are not truncated.""" + msg = { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_123", + "name": "some_function", + "input": {"key": "value" * 1000}, # Large input + } + ], + } + + original = json.dumps(msg["content"][0]["input"]) + _truncate_tool_message_content(msg, enc, max_tokens=10) + + # tool_use should be unchanged + assert json.dumps(msg["content"][0]["input"]) == original + + def test_no_truncation_when_under_limit(self, enc): + """Test that short content is not modified.""" + msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"} + + original = msg["content"] + _truncate_tool_message_content(msg, enc, max_tokens=1000) + + assert msg["content"] == original + + +class TestTruncateMiddleTokens: + """Test middle truncation of text.""" + + @pytest.fixture + def enc(self): + return encoding_for_model("gpt-4o") + + def test_truncates_long_text(self, enc): + """Test that long text is truncated with ellipsis in middle.""" + long_text = "word " 
* 1000 + result = _truncate_middle_tokens(long_text, enc, max_tok=50) + + assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis + assert "…" in result + assert result.startswith("word") # Head preserved + assert result.endswith("word ") # Tail preserved + + def test_preserves_short_text(self, enc): + """Test that short text is not modified.""" + short_text = "Hello world" + result = _truncate_middle_tokens(short_text, enc, max_tok=100) + + assert result == short_text + + +class TestEnsureToolPairsIntact: + """Test tool call/response pair preservation for both OpenAI and Anthropic formats.""" + + # ---- OpenAI Format Tests ---- + + def test_openai_adds_missing_tool_call(self): + """Test that orphaned OpenAI tool_response gets its tool_call prepended.""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool response) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_call message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert "tool_calls" in result[0] + + def test_openai_keeps_complete_pairs(self): + """Test that complete OpenAI pairs are unchanged.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result"}, + ] + recent = all_msgs[1:] # Include both tool_call and response + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert len(result) == 2 # No messages added + + def test_openai_multiple_tool_calls(self): + """Test multiple OpenAI tool calls in one assistant message.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "f1"}}, + {"id": "call_2", "type": "function", "function": {"name": "f2"}}, + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "result1"}, + {"role": "tool", "tool_call_id": "call_2", "content": "result2"}, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (first tool response) + recent = [all_msgs[2], all_msgs[3], all_msgs[4]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with both tool_calls + assert len(result) == 4 + assert result[0]["role"] == "assistant" + assert len(result[0]["tool_calls"]) == 2 + + # ---- Anthropic Format Tests ---- + + def test_anthropic_adds_missing_tool_use(self): + """Test that orphaned Anthropic tool_result gets its tool_use prepended.""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_123", + "name": "get_weather", + "input": {"location": "SF"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_123", + "content": "22°C and sunny", + } + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool_result) + recent = [all_msgs[2], all_msgs[3]] 
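+        # The assistant tool_use at index 1 falls outside this slice and must be restored.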
+ start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_use message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert result[0]["content"][0]["type"] == "tool_use" + + def test_anthropic_keeps_complete_pairs(self): + """Test that complete Anthropic pairs are unchanged.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_456", + "name": "calculator", + "input": {"expr": "2+2"}, + } + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_456", + "content": "4", + } + ], + }, + ] + recent = all_msgs[1:] # Include both tool_use and result + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert len(result) == 2 # No messages added + + def test_anthropic_multiple_tool_uses(self): + """Test multiple Anthropic tool_use blocks in one message.""" + all_msgs = [ + {"role": "system", "content": "System"}, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Let me check both..."}, + { + "type": "tool_use", + "id": "toolu_1", + "name": "get_weather", + "input": {"city": "NYC"}, + }, + { + "type": "tool_use", + "id": "toolu_2", + "name": "get_weather", + "input": {"city": "LA"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_1", + "content": "Cold", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_2", + "content": "Warm", + }, + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (tool_result) + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with both tool_uses + assert len(result) == 3 + assert result[0]["role"] == "assistant" + tool_use_count = sum( + 1 for b in result[0]["content"] if b.get("type") == "tool_use" + ) + assert tool_use_count == 2 + + # ---- Mixed/Edge Case Tests ---- + + def test_anthropic_with_type_message_field(self): + """Test Anthropic format with 'type': 'message' field (smart_decision_maker style).""" + all_msgs = [ + {"role": "system", "content": "You are helpful."}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "toolu_abc", + "name": "search", + "input": {"q": "test"}, + } + ], + }, + { + "role": "user", + "type": "message", # Extra field from smart_decision_maker + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_abc", + "content": "Found results", + } + ], + }, + {"role": "user", "content": "Thanks!"}, + ] + # Recent messages start at index 2 (the tool_result with 'type': 'message') + recent = [all_msgs[2], all_msgs[3]] + start_index = 2 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the tool_use message + assert len(result) == 3 + assert result[0]["role"] == "assistant" + assert result[0]["content"][0]["type"] == "tool_use" + + def test_handles_no_tool_messages(self): + """Test messages without tool calls.""" + all_msgs = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + recent = all_msgs + start_index = 0 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + assert result == all_msgs + + def test_handles_empty_messages(self): + """Test empty message list.""" + result = _ensure_tool_pairs_intact([], [], 0) + assert 
result == [] + + def test_mixed_text_and_tool_content(self): + """Test Anthropic message with mixed text and tool_use content.""" + all_msgs = [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "I'll help you with that."}, + { + "type": "tool_use", + "id": "toolu_mixed", + "name": "search", + "input": {"q": "test"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_mixed", + "content": "Found results", + } + ], + }, + {"role": "assistant", "content": "Here are the results..."}, + ] + # Start from tool_result + recent = [all_msgs[1], all_msgs[2]] + start_index = 1 + + result = _ensure_tool_pairs_intact(recent, all_msgs, start_index) + + # Should prepend the assistant message with tool_use + assert len(result) == 3 + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][1]["type"] == "tool_use" + + +class TestCompressContext: + """Test the async compress_context function.""" + + @pytest.mark.asyncio + async def test_no_compression_needed(self): + """Test messages under limit return without compression.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello!"}, + ] + + result = await compress_context(messages, target_tokens=100000) + + assert isinstance(result, CompressResult) + assert result.was_compacted is False + assert len(result.messages) == 2 + assert result.error is None + + @pytest.mark.asyncio + async def test_truncation_without_client(self): + """Test that truncation works without LLM client.""" + long_content = "x" * 50000 + messages = [ + {"role": "system", "content": "System"}, + {"role": "user", "content": long_content}, + {"role": "assistant", "content": "Response"}, + ] + + result = await compress_context( + messages, target_tokens=1000, client=None, reserve=100 + ) + + assert result.was_compacted is True + # Should have truncated without summarization + assert result.messages_summarized == 0 + + @pytest.mark.asyncio + async def test_with_mocked_llm_client(self): + """Test summarization with mocked LLM client.""" + # Create many messages to trigger summarization + messages = [{"role": "system", "content": "System prompt"}] + for i in range(30): + messages.append({"role": "user", "content": f"User message {i} " * 100}) + messages.append( + {"role": "assistant", "content": f"Assistant response {i} " * 100} + ) + + # Mock the AsyncOpenAI client + mock_client = AsyncMock() + mock_response = MagicMock() + mock_response.choices = [MagicMock()] + mock_response.choices[0].message.content = "Summary of conversation" + mock_client.with_options.return_value.chat.completions.create = AsyncMock( + return_value=mock_response + ) + + result = await compress_context( + messages, + target_tokens=5000, + client=mock_client, + keep_recent=5, + reserve=500, + ) + + assert result.was_compacted is True + # Should have attempted summarization + assert mock_client.with_options.called or result.messages_summarized > 0 + + @pytest.mark.asyncio + async def test_preserves_tool_pairs(self): + """Test that tool call/response pairs stay together.""" + messages = [ + {"role": "system", "content": "System"}, + {"role": "user", "content": "Do something"}, + { + "role": "assistant", + "tool_calls": [ + {"id": "call_1", "type": "function", "function": {"name": "func"}} + ], + }, + {"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000}, + {"role": "assistant", "content": "Done!"}, + ] + + result = await compress_context( + messages, 
target_tokens=500, client=None, reserve=50 + ) + + # Check that if tool response exists, its call exists too + tool_call_ids = set() + tool_response_ids = set() + for msg in result.messages: + if "tool_calls" in msg: + for tc in msg["tool_calls"]: + tool_call_ids.add(tc["id"]) + if msg.get("role") == "tool": + tool_response_ids.add(msg.get("tool_call_id")) + + # All tool responses should have their calls + assert tool_response_ids <= tool_call_ids + + @pytest.mark.asyncio + async def test_returns_error_when_cannot_compress(self): + """Test that error is returned when compression fails.""" + # Single huge message that can't be compressed enough + messages = [ + {"role": "user", "content": "x" * 100000}, + ] + + result = await compress_context( + messages, target_tokens=100, client=None, reserve=50 + ) + + # Should have an error since we can't get below 100 tokens + assert result.error is not None + assert result.was_compacted is True + + @pytest.mark.asyncio + async def test_empty_messages(self): + """Test that empty messages list returns early without error.""" + result = await compress_context([], target_tokens=1000) + + assert result.messages == [] + assert result.token_count == 0 + assert result.was_compacted is False + assert result.error is None + + +class TestRemoveOrphanToolResponses: + """Test _remove_orphan_tool_responses helper function.""" + + def test_removes_openai_orphan(self): + """Test removal of orphan OpenAI tool response.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "tool", "tool_call_id": "call_orphan", "content": "result"}, + {"role": "user", "content": "Hello"}, + ] + orphan_ids = {"call_orphan"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + assert result[0]["role"] == "user" + + def test_keeps_valid_openai_tool(self): + """Test that valid OpenAI tool responses are kept.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "tool", "tool_call_id": "call_valid", "content": "result"}, + ] + orphan_ids = {"call_other"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + assert result[0]["tool_call_id"] == "call_valid" + + def test_filters_anthropic_mixed_blocks(self): + """Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_valid", + "content": "valid result", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan", + "content": "orphan result", + }, + ], + }, + ] + orphan_ids = {"toolu_orphan"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert len(result) == 1 + # Should only have the valid tool_result, orphan filtered out + assert len(result[0]["content"]) == 1 + assert result[0]["content"][0]["tool_use_id"] == "toolu_valid" + + def test_removes_anthropic_all_orphan(self): + """Test removal of Anthropic message when all tool_results are orphans.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "toolu_orphan1", + "content": "result1", + }, + { + "type": "tool_result", + "tool_use_id": "toolu_orphan2", + "content": "result2", + }, + ], + }, + ] + orphan_ids = {"toolu_orphan1", "toolu_orphan2"} + + result = 
_remove_orphan_tool_responses(messages, orphan_ids) + + # Message should be completely removed since no content left + assert len(result) == 0 + + def test_preserves_non_tool_messages(self): + """Test that non-tool messages are preserved.""" + from backend.util.prompt import _remove_orphan_tool_responses + + messages = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + ] + orphan_ids = {"some_id"} + + result = _remove_orphan_tool_responses(messages, orphan_ids) + + assert result == messages + + +class TestCompressResultDataclass: + """Test CompressResult dataclass.""" + + def test_default_values(self): + """Test default values are set correctly.""" + result = CompressResult( + messages=[{"role": "user", "content": "test"}], + token_count=10, + was_compacted=False, + ) + + assert result.error is None + assert result.original_token_count == 0 # Defaults to 0, not None + assert result.messages_summarized == 0 + assert result.messages_dropped == 0 + + def test_all_fields(self): + """Test all fields can be set.""" + result = CompressResult( + messages=[{"role": "user", "content": "test"}], + token_count=100, + was_compacted=True, + error="Some error", + original_token_count=500, + messages_summarized=10, + messages_dropped=5, + ) + + assert result.token_count == 100 + assert result.was_compacted is True + assert result.error == "Some error" + assert result.original_token_count == 500 + assert result.messages_summarized == 10 + assert result.messages_dropped == 5 diff --git a/autogpt_platform/backend/test/agent_generator/test_core_integration.py b/autogpt_platform/backend/test/agent_generator/test_core_integration.py index 05ce4a3aff..528763e751 100644 --- a/autogpt_platform/backend/test/agent_generator/test_core_integration.py +++ b/autogpt_platform/backend/test/agent_generator/test_core_integration.py @@ -111,9 +111,7 @@ class TestGenerateAgent: instructions = {"type": "instructions", "steps": ["Step 1"]} result = await core.generate_agent(instructions) - # library_agents defaults to None - mock_external.assert_called_once_with(instructions, None) - # Result should have id, version, is_active added if not present + mock_external.assert_called_once_with(instructions, None, None, None) assert result is not None assert result["name"] == "Test Agent" assert "id" in result @@ -177,8 +175,9 @@ class TestGenerateAgentPatch: current_agent = {"nodes": [], "links": []} result = await core.generate_agent_patch("Add a node", current_agent) - # library_agents defaults to None - mock_external.assert_called_once_with("Add a node", current_agent, None) + mock_external.assert_called_once_with( + "Add a node", current_agent, None, None, None + ) assert result == expected_result @pytest.mark.asyncio diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx index 70d9783ccd..246fe52826 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/page.tsx @@ -1,10 +1,9 @@ "use client"; +import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding"; +import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { useRouter } from "next/navigation"; import { useEffect } from "react"; -import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers"; -import { 
getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
-import { getHomepageRoute } from "@/lib/constants";
 
 export default function OnboardingPage() {
   const router = useRouter();
@@ -13,12 +12,10 @@ export default function OnboardingPage() {
     async function redirectToStep() {
       try {
         // Check if onboarding is enabled (also gets chat flag for redirect)
-        const { shouldShowOnboarding, isChatEnabled } =
-          await getOnboardingStatus();
-        const homepageRoute = getHomepageRoute(isChatEnabled);
+        const { shouldShowOnboarding } = await getOnboardingStatus();
 
         if (!shouldShowOnboarding) {
-          router.replace(homepageRoute);
+          router.replace("/");
           return;
         }
@@ -26,7 +23,7 @@ export default function OnboardingPage() {
 
         // Handle completed onboarding
         if (onboarding.completedSteps.includes("GET_RESULTS")) {
-          router.replace(homepageRoute);
+          router.replace("/");
           return;
         }
diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts
index 15be137f63..e7e2997d0d 100644
--- a/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/auth/callback/route.ts
@@ -1,9 +1,8 @@
-import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
-import { getHomepageRoute } from "@/lib/constants";
-import BackendAPI from "@/lib/autogpt-server-api";
-import { NextResponse } from "next/server";
-import { revalidatePath } from "next/cache";
 import { getOnboardingStatus } from "@/app/api/helpers";
+import BackendAPI from "@/lib/autogpt-server-api";
+import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
+import { revalidatePath } from "next/cache";
+import { NextResponse } from "next/server";
 
 // Handle the callback to complete the user session login
 export async function GET(request: Request) {
@@ -27,13 +26,12 @@ export async function GET(request: Request) {
         await api.createUser();
 
         // Get onboarding status from backend (includes chat flag evaluated for this user)
-        const { shouldShowOnboarding, isChatEnabled } =
-          await getOnboardingStatus();
+        const { shouldShowOnboarding } = await getOnboardingStatus();
         if (shouldShowOnboarding) {
           next = "/onboarding";
           revalidatePath("/onboarding", "layout");
         } else {
-          next = getHomepageRoute(isChatEnabled);
+          next = "/";
           revalidatePath(next, "layout");
         }
       } catch (createUserError) {
diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts
index 41d05a9afb..fd67519957 100644
--- a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts
+++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts
@@ -1,6 +1,17 @@
 import { OAuthPopupResultMessage } from "./types";
 import { NextResponse } from "next/server";
 
+/**
+ * Safely encode a value as JSON for embedding in a script tag.
+ * Escapes characters that could break out of the script context to prevent XSS.
+ */
+function safeJsonStringify(value: unknown): string {
+  return JSON.stringify(value)
+    .replace(/</g, "\\u003c")
+    .replace(/>/g, "\\u003e")
+    .replace(/&/g, "\\u0026");
+}
+
 // This route is intended to be used as the callback for integration OAuth flows,
 // controlled by the CredentialsInput component. The CredentialsInput opens the login
 // page in a pop-up window, which then redirects to this route to close the loop.
@@ -23,12 +34,13 @@ export async function GET(request: Request) { console.debug("Sending message to opener:", message); // Return a response with the message as JSON and a script to close the window + // Use safeJsonStringify to prevent XSS by escaping <, >, and & characters return new NextResponse( `
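For context, here is a minimal, hypothetical sketch of how a payload escaped with `safeJsonStringify` can be embedded in the inline script that posts the OAuth result back to the opener window. The message shape and the `renderCallbackPage` helper are illustrative assumptions, not code from this PR:

```ts
// Hypothetical sketch: embed a JSON payload in an inline <script> so the pop-up
// can post it back to window.opener. The message shape below is assumed; the real
// OAuthPopupResultMessage type lives in "./types".
type OAuthPopupResultMessage = {
  message_type: "oauth_popup_result";
  success: boolean;
  code?: string;
  state?: string;
  message?: string;
};

function safeJsonStringify(value: unknown): string {
  // Same escaping as the route above: <, > and & can no longer terminate the
  // surrounding <script> element or introduce new HTML.
  return JSON.stringify(value)
    .replace(/</g, "\\u003c")
    .replace(/>/g, "\\u003e")
    .replace(/&/g, "\\u0026");
}

// Illustrative helper (not part of the diff): render the callback page HTML.
function renderCallbackPage(message: OAuthPopupResultMessage): string {
  // A malicious value such as '</script><script>alert(1)</script>' is emitted as
  // \u003c/script\u003e..., so it stays inside the JSON string literal instead of
  // closing the script element.
  return `<html><body><script>
    const result = ${safeJsonStringify(message)};
    window.opener?.postMessage(result, window.location.origin);
    window.close();
  </script></body></html>`;
}
```

Because the escapes are applied to the serialized JSON rather than the raw values, the parsed object in the pop-up is byte-for-byte identical to the original message; only the HTML-significant characters in the page source change.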
diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts index 74fd663ab2..913c4d7ded 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts @@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useQueryClient } from "@tanstack/react-query"; import { usePathname, useSearchParams } from "next/navigation"; -import { useRef } from "react"; import { useCopilotStore } from "../../copilot-page-store"; import { useCopilotSessionId } from "../../useCopilotSessionId"; import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer"; @@ -70,41 +69,16 @@ export function useCopilotShell() { }); const stopStream = useChatStore((s) => s.stopStream); - const onStreamComplete = useChatStore((s) => s.onStreamComplete); - const isStreaming = useCopilotStore((s) => s.isStreaming); const isCreatingSession = useCopilotStore((s) => s.isCreatingSession); - const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession); - const openInterruptModal = useCopilotStore((s) => s.openInterruptModal); - const pendingActionRef = useRef<(() => void) | null>(null); - - async function stopCurrentStream() { - if (!currentSessionId) return; - - setIsSwitchingSession(true); - await new Promise