diff --git a/.github/workflows/platform-frontend-ci.yml b/.github/workflows/platform-frontend-ci.yml index 499bb03170..14676a6a1f 100644 --- a/.github/workflows/platform-frontend-ci.yml +++ b/.github/workflows/platform-frontend-ci.yml @@ -27,11 +27,20 @@ jobs: runs-on: ubuntu-latest outputs: cache-key: ${{ steps.cache-key.outputs.key }} + components-changed: ${{ steps.filter.outputs.components }} steps: - name: Checkout repository uses: actions/checkout@v4 + - name: Check for component changes + uses: dorny/paths-filter@v3 + id: filter + with: + filters: | + components: + - 'autogpt_platform/frontend/src/components/**' + - name: Set up Node.js uses: actions/setup-node@v4 with: @@ -90,8 +99,11 @@ jobs: chromatic: runs-on: ubuntu-latest needs: setup - # Only run on dev branch pushes or PRs targeting dev - if: github.ref == 'refs/heads/dev' || github.base_ref == 'dev' + # Disabled: to re-enable, remove 'false &&' from the condition below + if: >- + false + && (github.ref == 'refs/heads/dev' || github.base_ref == 'dev') + && needs.setup.outputs.components-changed == 'true' steps: - name: Checkout repository diff --git a/autogpt_platform/backend/.env.default b/autogpt_platform/backend/.env.default index b393f13017..fa52ba812a 100644 --- a/autogpt_platform/backend/.env.default +++ b/autogpt_platform/backend/.env.default @@ -152,6 +152,7 @@ REPLICATE_API_KEY= REVID_API_KEY= SCREENSHOTONE_API_KEY= UNREAL_SPEECH_API_KEY= +ELEVENLABS_API_KEY= # Data & Search Services E2B_API_KEY= diff --git a/autogpt_platform/backend/.gitignore b/autogpt_platform/backend/.gitignore index 9224c07d9e..6e688311a6 100644 --- a/autogpt_platform/backend/.gitignore +++ b/autogpt_platform/backend/.gitignore @@ -19,3 +19,6 @@ load-tests/*.json load-tests/*.log load-tests/node_modules/* migrations/*/rollback*.sql + +# Workspace files +workspaces/ diff --git a/autogpt_platform/backend/Dockerfile b/autogpt_platform/backend/Dockerfile index 103226d079..9bd455e490 100644 --- a/autogpt_platform/backend/Dockerfile +++ b/autogpt_platform/backend/Dockerfile @@ -62,10 +62,12 @@ ENV POETRY_HOME=/opt/poetry \ DEBIAN_FRONTEND=noninteractive ENV PATH=/opt/poetry/bin:$PATH -# Install Python without upgrading system-managed packages +# Install Python, FFmpeg, and ImageMagick (required for video processing blocks) RUN apt-get update && apt-get install -y \ python3.13 \ python3-pip \ + ffmpeg \ + imagemagick \ && rm -rf /var/lib/apt/lists/* # Copy only necessary files from builder diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py new file mode 100644 index 0000000000..f447d46bd7 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_consumer.py @@ -0,0 +1,368 @@ +"""Redis Streams consumer for operation completion messages. + +This module provides a consumer (ChatCompletionConsumer) that listens for +completion notifications (OperationCompleteMessage) from external services +(like Agent Generator) and triggers the appropriate stream registry and +chat service updates via process_operation_success/process_operation_failure. + +Why Redis Streams instead of RabbitMQ? +-------------------------------------- +While the project typically uses RabbitMQ for async task queues (e.g., execution +queue), Redis Streams was chosen for chat completion notifications because: + +1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis + Streams (via stream_registry) for message persistence and replay. 
Using Redis + Streams for completion notifications keeps all chat streaming infrastructure + in one system, simplifying operations and reducing cross-system coordination. + +2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs, + allowing consumers to replay missed messages after reconnection. This aligns + with the SSE reconnection pattern where clients can resume from last_message_id. + +3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic + load balancing across pods with explicit message claiming (XAUTOCLAIM) for + recovering from dead consumers - ideal for the completion callback pattern. + +4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for + stream_registry) provides lower latency than an additional RabbitMQ hop. + +5. **Atomicity with Task State**: Completion processing often needs to update + task metadata stored in Redis. Keeping both in Redis enables simpler + transactional semantics without distributed coordination. + +The consumer uses Redis Streams with consumer groups for reliable message +processing across multiple platform pods, with XAUTOCLAIM for reclaiming +stale pending messages from dead consumers. +""" + +import asyncio +import logging +import os +import uuid +from typing import Any + +import orjson +from prisma import Prisma +from pydantic import BaseModel +from redis.exceptions import ResponseError + +from backend.data.redis_client import get_redis_async + +from . import stream_registry +from .completion_handler import process_operation_failure, process_operation_success +from .config import ChatConfig + +logger = logging.getLogger(__name__) +config = ChatConfig() + + +class OperationCompleteMessage(BaseModel): + """Message format for operation completion notifications.""" + + operation_id: str + task_id: str + success: bool + result: dict | str | None = None + error: str | None = None + + +class ChatCompletionConsumer: + """Consumer for chat operation completion messages from Redis Streams. + + This consumer initializes its own Prisma client in start() to ensure + database operations work correctly within this async context. + + Uses Redis consumer groups to allow multiple platform pods to consume + messages reliably with automatic redelivery on failure. 
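    (Redelivery here is via the consumer group's pending-entries list: messages that are
    never XACK'd stay pending and are reclaimed with XAUTOCLAIM after
    stream_claim_min_idle_ms, rather than being pushed again by Redis automatically.)

    Intended lifecycle (illustrative sketch; the module-level helpers below are part of
    this module, but wiring them into a FastAPI lifespan is an assumption of this example):

        from backend.api.features.chat.completion_consumer import (
            start_completion_consumer,
            stop_completion_consumer,
        )

        async def lifespan(app):
            await start_completion_consumer()  # creates the consumer group, starts the XREADGROUP loop
            yield
            await stop_completion_consumer()   # cancels the loop and disconnects Prisma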
+ """ + + def __init__(self): + self._consumer_task: asyncio.Task | None = None + self._running = False + self._prisma: Prisma | None = None + self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}" + + async def start(self) -> None: + """Start the completion consumer.""" + if self._running: + logger.warning("Completion consumer already running") + return + + # Create consumer group if it doesn't exist + try: + redis = await get_redis_async() + await redis.xgroup_create( + config.stream_completion_name, + config.stream_consumer_group, + id="0", + mkstream=True, + ) + logger.info( + f"Created consumer group '{config.stream_consumer_group}' " + f"on stream '{config.stream_completion_name}'" + ) + except ResponseError as e: + if "BUSYGROUP" in str(e): + logger.debug( + f"Consumer group '{config.stream_consumer_group}' already exists" + ) + else: + raise + + self._running = True + self._consumer_task = asyncio.create_task(self._consume_messages()) + logger.info( + f"Chat completion consumer started (consumer: {self._consumer_name})" + ) + + async def _ensure_prisma(self) -> Prisma: + """Lazily initialize Prisma client on first use.""" + if self._prisma is None: + database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432") + self._prisma = Prisma(datasource={"url": database_url}) + await self._prisma.connect() + logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)") + return self._prisma + + async def stop(self) -> None: + """Stop the completion consumer.""" + self._running = False + + if self._consumer_task: + self._consumer_task.cancel() + try: + await self._consumer_task + except asyncio.CancelledError: + pass + self._consumer_task = None + + if self._prisma: + await self._prisma.disconnect() + self._prisma = None + logger.info("[COMPLETION] Consumer Prisma client disconnected") + + logger.info("Chat completion consumer stopped") + + async def _consume_messages(self) -> None: + """Main message consumption loop with retry logic.""" + max_retries = 10 + retry_delay = 5 # seconds + retry_count = 0 + block_timeout = 5000 # milliseconds + + while self._running and retry_count < max_retries: + try: + redis = await get_redis_async() + + # Reset retry count on successful connection + retry_count = 0 + + while self._running: + # First, claim any stale pending messages from dead consumers + # Redis does NOT auto-redeliver pending messages; we must explicitly + # claim them using XAUTOCLAIM + try: + claimed_result = await redis.xautoclaim( + name=config.stream_completion_name, + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + min_idle_time=config.stream_claim_min_idle_ms, + start_id="0-0", + count=10, + ) + # xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids]) + if claimed_result and len(claimed_result) >= 2: + claimed_entries = claimed_result[1] + if claimed_entries: + logger.info( + f"Claimed {len(claimed_entries)} stale pending messages" + ) + for entry_id, data in claimed_entries: + if not self._running: + return + await self._process_entry(redis, entry_id, data) + except Exception as e: + logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}") + + # Read new messages from the stream + messages = await redis.xreadgroup( + groupname=config.stream_consumer_group, + consumername=self._consumer_name, + streams={config.stream_completion_name: ">"}, + block=block_timeout, + count=10, + ) + + if not messages: + continue + + for stream_name, entries in messages: + for entry_id, data in entries: + if not self._running: + return + 
await self._process_entry(redis, entry_id, data) + + except asyncio.CancelledError: + logger.info("Consumer cancelled") + return + except Exception as e: + retry_count += 1 + logger.error( + f"Consumer error (retry {retry_count}/{max_retries}): {e}", + exc_info=True, + ) + if self._running and retry_count < max_retries: + await asyncio.sleep(retry_delay) + else: + logger.error("Max retries reached, stopping consumer") + return + + async def _process_entry( + self, redis: Any, entry_id: str, data: dict[str, Any] + ) -> None: + """Process a single stream entry and acknowledge it on success. + + Args: + redis: Redis client connection + entry_id: The stream entry ID + data: The entry data dict + """ + try: + # Handle the message + message_data = data.get("data") + if message_data: + await self._handle_message( + message_data.encode() + if isinstance(message_data, str) + else message_data + ) + + # Acknowledge the message after successful processing + await redis.xack( + config.stream_completion_name, + config.stream_consumer_group, + entry_id, + ) + except Exception as e: + logger.error( + f"Error processing completion message {entry_id}: {e}", + exc_info=True, + ) + # Message remains in pending state and will be claimed by + # XAUTOCLAIM after min_idle_time expires + + async def _handle_message(self, body: bytes) -> None: + """Handle a completion message using our own Prisma client.""" + try: + data = orjson.loads(body) + message = OperationCompleteMessage(**data) + except Exception as e: + logger.error(f"Failed to parse completion message: {e}") + return + + logger.info( + f"[COMPLETION] Received completion for operation {message.operation_id} " + f"(task_id={message.task_id}, success={message.success})" + ) + + # Find task in registry + task = await stream_registry.find_task_by_operation_id(message.operation_id) + if task is None: + task = await stream_registry.get_task(message.task_id) + + if task is None: + logger.warning( + f"[COMPLETION] Task not found for operation {message.operation_id} " + f"(task_id={message.task_id})" + ) + return + + logger.info( + f"[COMPLETION] Found task: task_id={task.task_id}, " + f"session_id={task.session_id}, tool_call_id={task.tool_call_id}" + ) + + # Guard against empty task fields + if not task.task_id or not task.session_id or not task.tool_call_id: + logger.error( + f"[COMPLETION] Task has empty critical fields! 
" + f"task_id={task.task_id!r}, session_id={task.session_id!r}, " + f"tool_call_id={task.tool_call_id!r}" + ) + return + + if message.success: + await self._handle_success(task, message) + else: + await self._handle_failure(task, message) + + async def _handle_success( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle successful operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_success(task, message.result, prisma) + + async def _handle_failure( + self, + task: stream_registry.ActiveTask, + message: OperationCompleteMessage, + ) -> None: + """Handle failed operation completion.""" + prisma = await self._ensure_prisma() + await process_operation_failure(task, message.error, prisma) + + +# Module-level consumer instance +_consumer: ChatCompletionConsumer | None = None + + +async def start_completion_consumer() -> None: + """Start the global completion consumer.""" + global _consumer + if _consumer is None: + _consumer = ChatCompletionConsumer() + await _consumer.start() + + +async def stop_completion_consumer() -> None: + """Stop the global completion consumer.""" + global _consumer + if _consumer: + await _consumer.stop() + _consumer = None + + +async def publish_operation_complete( + operation_id: str, + task_id: str, + success: bool, + result: dict | str | None = None, + error: str | None = None, +) -> None: + """Publish an operation completion message to Redis Streams. + + Args: + operation_id: The operation ID that completed. + task_id: The task ID associated with the operation. + success: Whether the operation succeeded. + result: The result data (for success). + error: The error message (for failure). + """ + message = OperationCompleteMessage( + operation_id=operation_id, + task_id=task_id, + success=success, + result=result, + error=error, + ) + + redis = await get_redis_async() + await redis.xadd( + config.stream_completion_name, + {"data": message.model_dump_json()}, + maxlen=config.stream_max_length, + ) + logger.info(f"Published completion for operation {operation_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/completion_handler.py b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py new file mode 100644 index 0000000000..905fa2ddba --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/completion_handler.py @@ -0,0 +1,344 @@ +"""Shared completion handling for operation success and failure. + +This module provides common logic for handling operation completion from both: +- The Redis Streams consumer (completion_consumer.py) +- The HTTP webhook endpoint (routes.py) +""" + +import logging +from typing import Any + +import orjson +from prisma import Prisma + +from . import service as chat_service +from . 
import stream_registry +from .response_model import StreamError, StreamToolOutputAvailable +from .tools.models import ErrorResponse + +logger = logging.getLogger(__name__) + +# Tools that produce agent_json that needs to be saved to library +AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"} + +# Keys that should be stripped from agent_json when returning in error responses +SENSITIVE_KEYS = frozenset( + { + "api_key", + "apikey", + "api_secret", + "password", + "secret", + "credentials", + "credential", + "token", + "access_token", + "refresh_token", + "private_key", + "privatekey", + "auth", + "authorization", + } +) + + +def _sanitize_agent_json(obj: Any) -> Any: + """Recursively sanitize agent_json by removing sensitive keys. + + Args: + obj: The object to sanitize (dict, list, or primitive) + + Returns: + Sanitized copy with sensitive keys removed/redacted + """ + if isinstance(obj, dict): + return { + k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v) + for k, v in obj.items() + } + elif isinstance(obj, list): + return [_sanitize_agent_json(item) for item in obj] + else: + return obj + + +class ToolMessageUpdateError(Exception): + """Raised when updating a tool message in the database fails.""" + + pass + + +async def _update_tool_message( + session_id: str, + tool_call_id: str, + content: str, + prisma_client: Prisma | None, +) -> None: + """Update tool message in database. + + Args: + session_id: The session ID + tool_call_id: The tool call ID to update + content: The new content for the message + prisma_client: Optional Prisma client. If None, uses chat_service. + + Raises: + ToolMessageUpdateError: If the database update fails. The caller should + handle this to avoid marking the task as completed with inconsistent state. + """ + try: + if prisma_client: + # Use provided Prisma client (for consumer with its own connection) + updated_count = await prisma_client.chatmessage.update_many( + where={ + "sessionId": session_id, + "toolCallId": tool_call_id, + }, + data={"content": content}, + ) + # Check if any rows were updated - 0 means message not found + if updated_count == 0: + raise ToolMessageUpdateError( + f"No message found with tool_call_id={tool_call_id} in session {session_id}" + ) + else: + # Use service function (for webhook endpoint) + await chat_service._update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=content, + ) + except ToolMessageUpdateError: + raise + except Exception as e: + logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True) + raise ToolMessageUpdateError( + f"Failed to update tool message for tool_call_id={tool_call_id}: {e}" + ) from e + + +def serialize_result(result: dict | list | str | int | float | bool | None) -> str: + """Serialize result to JSON string with sensible defaults. + + Args: + result: The result to serialize. Can be a dict, list, string, + number, boolean, or None. + + Returns: + JSON string representation of the result. Returns '{"status": "completed"}' + only when result is explicitly None. + """ + if isinstance(result, str): + return result + if result is None: + return '{"status": "completed"}' + return orjson.dumps(result).decode("utf-8") + + +async def _save_agent_from_result( + result: dict[str, Any], + user_id: str | None, + tool_name: str, +) -> dict[str, Any]: + """Save agent to library if result contains agent_json. 
+ + Args: + result: The result dict that may contain agent_json + user_id: The user ID to save the agent for + tool_name: The tool name (create_agent or edit_agent) + + Returns: + Updated result dict with saved agent details, or original result if no agent_json + """ + if not user_id: + logger.warning("[COMPLETION] Cannot save agent: no user_id in task") + return result + + agent_json = result.get("agent_json") + if not agent_json: + logger.warning( + f"[COMPLETION] {tool_name} completed but no agent_json in result" + ) + return result + + try: + from .tools.agent_generator import save_agent_to_library + + is_update = tool_name == "edit_agent" + created_graph, library_agent = await save_agent_to_library( + agent_json, user_id, is_update=is_update + ) + + logger.info( + f"[COMPLETION] Saved agent '{created_graph.name}' to library " + f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})" + ) + + # Return a response similar to AgentSavedResponse + return { + "type": "agent_saved", + "message": f"Agent '{created_graph.name}' has been saved to your library!", + "agent_id": created_graph.id, + "agent_name": created_graph.name, + "library_agent_id": library_agent.id, + "library_agent_link": f"/library/agents/{library_agent.id}", + "agent_page_link": f"/build?flowID={created_graph.id}", + } + except Exception as e: + logger.error( + f"[COMPLETION] Failed to save agent to library: {e}", + exc_info=True, + ) + # Return error but don't fail the whole operation + # Sanitize agent_json to remove sensitive keys before returning + return { + "type": "error", + "message": f"Agent was generated but failed to save: {str(e)}", + "error": str(e), + "agent_json": _sanitize_agent_json(agent_json), + } + + +async def process_operation_success( + task: stream_registry.ActiveTask, + result: dict | str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle successful operation completion. + + Publishes the result to the stream registry, updates the database, + generates LLM continuation, and marks the task as completed. + + Args: + task: The active task that completed + result: The result data from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. + + Raises: + ToolMessageUpdateError: If the database update fails. The task will be + marked as failed instead of completed to avoid inconsistent state. 
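    Example (illustrative only; the real call sites are the Redis Streams consumer in
    completion_consumer.py and the /operations/{operation_id}/complete webhook in
    routes.py, and the result payload here is made up):

        task = await stream_registry.find_task_by_operation_id(operation_id)
        if task is not None:
            await process_operation_success(task, {"status": "done"}, prisma_client=None)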
+ """ + # For agent generation tools, save the agent to library + if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict): + result = await _save_agent_from_result(result, task.user_id, task.tool_name) + + # Serialize result for output (only substitute default when result is exactly None) + result_output = result if result is not None else {"status": "completed"} + output_str = ( + result_output + if isinstance(result_output, str) + else orjson.dumps(result_output).decode("utf-8") + ) + + # Publish result to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamToolOutputAvailable( + toolCallId=task.tool_call_id, + toolName=task.tool_name, + output=output_str, + success=True, + ), + ) + + # Update pending operation in database + # If this fails, we must not continue to mark the task as completed + result_str = serialize_result(result) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=result_str, + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - mark task as failed to avoid inconsistent state + logger.error( + f"[COMPLETION] DB update failed for task {task.task_id}, " + "marking as failed instead of completed" + ) + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText="Failed to save operation result to database"), + ) + await stream_registry.mark_task_completed(task.task_id, status="failed") + raise + + # Generate LLM continuation with streaming + try: + await chat_service._generate_llm_continuation_with_streaming( + session_id=task.session_id, + user_id=task.user_id, + task_id=task.task_id, + ) + except Exception as e: + logger.error( + f"[COMPLETION] Failed to generate LLM continuation: {e}", + exc_info=True, + ) + + # Mark task as completed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="completed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info( + f"[COMPLETION] Successfully processed completion for task {task.task_id}" + ) + + +async def process_operation_failure( + task: stream_registry.ActiveTask, + error: str | None, + prisma_client: Prisma | None = None, +) -> None: + """Handle failed operation completion. + + Publishes the error to the stream registry, updates the database with + the error response, and marks the task as failed. + + Args: + task: The active task that failed + error: The error message from the operation + prisma_client: Optional Prisma client for database operations. + If None, uses chat_service._update_pending_operation instead. 
+ """ + error_msg = error or "Operation failed" + + # Publish error to stream registry + await stream_registry.publish_chunk( + task.task_id, + StreamError(errorText=error_msg), + ) + + # Update pending operation with error + # If this fails, we still continue to mark the task as failed + error_response = ErrorResponse( + message=error_msg, + error=error, + ) + try: + await _update_tool_message( + session_id=task.session_id, + tool_call_id=task.tool_call_id, + content=error_response.model_dump_json(), + prisma_client=prisma_client, + ) + except ToolMessageUpdateError: + # DB update failed - log but continue with cleanup + logger.error( + f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, " + "continuing with cleanup" + ) + + # Mark task as failed and release Redis lock + await stream_registry.mark_task_completed(task.task_id, status="failed") + try: + await chat_service._mark_operation_completed(task.tool_call_id) + except Exception as e: + logger.error(f"[COMPLETION] Failed to mark operation completed: {e}") + + logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}") diff --git a/autogpt_platform/backend/backend/api/features/chat/config.py b/autogpt_platform/backend/backend/api/features/chat/config.py index dba7934877..0b37e42df8 100644 --- a/autogpt_platform/backend/backend/api/features/chat/config.py +++ b/autogpt_platform/backend/backend/api/features/chat/config.py @@ -11,7 +11,7 @@ class ChatConfig(BaseSettings): # OpenAI API Configuration model: str = Field( - default="anthropic/claude-opus-4.5", description="Default model to use" + default="anthropic/claude-opus-4.6", description="Default model to use" ) title_model: str = Field( default="openai/gpt-4o-mini", @@ -44,6 +44,48 @@ class ChatConfig(BaseSettings): description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)", ) + # Stream registry configuration for SSE reconnection + stream_ttl: int = Field( + default=3600, + description="TTL in seconds for stream data in Redis (1 hour)", + ) + stream_max_length: int = Field( + default=10000, + description="Maximum number of messages to store per stream", + ) + + # Redis Streams configuration for completion consumer + stream_completion_name: str = Field( + default="chat:completions", + description="Redis Stream name for operation completions", + ) + stream_consumer_group: str = Field( + default="chat_consumers", + description="Consumer group name for completion stream", + ) + stream_claim_min_idle_ms: int = Field( + default=60000, + description="Minimum idle time in milliseconds before claiming pending messages from dead consumers", + ) + + # Redis key prefixes for stream registry + task_meta_prefix: str = Field( + default="chat:task:meta:", + description="Prefix for task metadata hash keys", + ) + task_stream_prefix: str = Field( + default="chat:stream:", + description="Prefix for task message stream keys", + ) + task_op_prefix: str = Field( + default="chat:task:op:", + description="Prefix for operation ID to task ID mapping keys", + ) + internal_api_key: str | None = Field( + default=None, + description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)", + ) + # Langfuse Prompt Management Configuration # Note: Langfuse credentials are in Settings().secrets (settings.py) langfuse_prompt_name: str = Field( @@ -82,6 +124,14 @@ class ChatConfig(BaseSettings): v = "https://openrouter.ai/api/v1" return v + @field_validator("internal_api_key", mode="before") + @classmethod + def 
get_internal_api_key(cls, v): + """Get internal API key from environment if not provided.""" + if v is None: + v = os.getenv("CHAT_INTERNAL_API_KEY") + return v + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/api/features/chat/response_model.py b/autogpt_platform/backend/backend/api/features/chat/response_model.py index 53a8cf3a1f..f627a42fcc 100644 --- a/autogpt_platform/backend/backend/api/features/chat/response_model.py +++ b/autogpt_platform/backend/backend/api/features/chat/response_model.py @@ -52,6 +52,10 @@ class StreamStart(StreamBaseResponse): type: ResponseType = ResponseType.START messageId: str = Field(..., description="Unique message ID") + taskId: str | None = Field( + default=None, + description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream", + ) class StreamFinish(StreamBaseResponse): diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index cab51543b1..3e731d86ac 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -1,19 +1,23 @@ """Chat API routes for chat session management and streaming via SSE.""" import logging +import uuid as uuid_module from collections.abc import AsyncGenerator from typing import Annotated from autogpt_libs import auth -from fastapi import APIRouter, Depends, Query, Security +from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security from fastapi.responses import StreamingResponse from pydantic import BaseModel from backend.util.exceptions import NotFoundError from . import service as chat_service +from . import stream_registry +from .completion_handler import process_operation_failure, process_operation_success from .config import ChatConfig from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions +from .response_model import StreamFinish, StreamHeartbeat, StreamStart config = ChatConfig() @@ -55,6 +59,15 @@ class CreateSessionResponse(BaseModel): user_id: str | None +class ActiveStreamInfo(BaseModel): + """Information about an active stream for reconnection.""" + + task_id: str + last_message_id: str # Redis Stream message ID for resumption + operation_id: str # Operation ID for completion tracking + tool_name: str # Name of the tool being executed + + class SessionDetailResponse(BaseModel): """Response model providing complete details for a chat session, including messages.""" @@ -63,6 +76,7 @@ class SessionDetailResponse(BaseModel): updated_at: str user_id: str | None messages: list[dict] + active_stream: ActiveStreamInfo | None = None # Present if stream is still active class SessionSummaryResponse(BaseModel): @@ -81,6 +95,14 @@ class ListSessionsResponse(BaseModel): total: int +class OperationCompleteRequest(BaseModel): + """Request model for external completion webhook.""" + + success: bool + result: dict | str | None = None + error: str | None = None + + # ========== Routes ========== @@ -166,13 +188,14 @@ async def get_session( Retrieve the details of a specific chat session. Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages. + If there's an active stream for this session, returns the task_id for reconnection. Args: session_id: The unique identifier for the desired chat session. 
user_id: The optional authenticated user ID, or None for anonymous access. Returns: - SessionDetailResponse: Details for the requested session, or None if not found. + SessionDetailResponse: Details for the requested session, including active_stream info if applicable. """ session = await get_chat_session(session_id, user_id) @@ -180,11 +203,28 @@ async def get_session( raise NotFoundError(f"Session {session_id} not found.") messages = [message.model_dump() for message in session.messages] - logger.info( - f"Returning session {session_id}: " - f"message_count={len(messages)}, " - f"roles={[m.get('role') for m in messages]}" + + # Check if there's an active stream for this session + active_stream_info = None + active_task, last_message_id = await stream_registry.get_active_task_for_session( + session_id, user_id ) + if active_task: + # Filter out the in-progress assistant message from the session response. + # The client will receive the complete assistant response through the SSE + # stream replay instead, preventing duplicate content. + if messages and messages[-1].get("role") == "assistant": + messages = messages[:-1] + + # Use "0-0" as last_message_id to replay the stream from the beginning. + # Since we filtered out the cached assistant message, the client needs + # the full stream to reconstruct the response. + active_stream_info = ActiveStreamInfo( + task_id=active_task.task_id, + last_message_id="0-0", + operation_id=active_task.operation_id, + tool_name=active_task.tool_name, + ) return SessionDetailResponse( id=session.session_id, @@ -192,6 +232,7 @@ async def get_session( updated_at=session.updated_at.isoformat(), user_id=session.user_id or None, messages=messages, + active_stream=active_stream_info, ) @@ -211,49 +252,112 @@ async def stream_chat_post( - Tool call UI elements (if invoked) - Tool execution results + The AI generation runs in a background task that continues even if the client disconnects. + All chunks are written to Redis for reconnection support. If the client disconnects, + they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off. + Args: session_id: The chat session identifier to associate with the streamed messages. request: Request body containing message, is_user_message, and optional context. user_id: Optional authenticated user ID. Returns: - StreamingResponse: SSE-formatted response chunks. + StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event + containing the task_id for reconnection. 
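    Reconnection example (illustrative sketch using httpx; the /chat prefix, the base URL,
    and the omission of auth headers are assumptions, and last_message_id is whatever Redis
    Stream ID the client last processed):

        import httpx

        async def resume(task_id: str, last_id: str = "0-0") -> None:
            async with httpx.AsyncClient(base_url="http://localhost:8006") as client:
                async with client.stream(
                    "GET",
                    f"/chat/tasks/{task_id}/stream",
                    params={"last_message_id": last_id},
                ) as resp:
                    async for line in resp.aiter_lines():
                        if line.startswith("data: "):
                            ...  # each line is one SSE-framed stream chunk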
""" + import asyncio + session = await _validate_and_get_session(session_id, user_id) + # Create a task in the stream registry for reconnection support + task_id = str(uuid_module.uuid4()) + operation_id = str(uuid_module.uuid4()) + await stream_registry.create_task( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id="chat_stream", # Not a tool call, but needed for the model + tool_name="chat", + operation_id=operation_id, + ) + + # Background task that runs the AI generation independently of SSE connection + async def run_ai_generation(): + try: + # Emit a start event with task_id for reconnection + start_chunk = StreamStart(messageId=task_id, taskId=task_id) + await stream_registry.publish_chunk(task_id, start_chunk) + + async for chunk in chat_service.stream_chat_completion( + session_id, + request.message, + is_user_message=request.is_user_message, + user_id=user_id, + session=session, # Pass pre-fetched session to avoid double-fetch + context=request.context, + ): + # Write to Redis (subscribers will receive via XREAD) + await stream_registry.publish_chunk(task_id, chunk) + + # Mark task as completed + await stream_registry.mark_task_completed(task_id, "completed") + except Exception as e: + logger.error( + f"Error in background AI generation for session {session_id}: {e}" + ) + await stream_registry.mark_task_completed(task_id, "failed") + + # Start the AI generation in a background task + bg_task = asyncio.create_task(run_ai_generation()) + await stream_registry.set_task_asyncio_task(task_id, bg_task) + + # SSE endpoint that subscribes to the task's stream async def event_generator() -> AsyncGenerator[str, None]: - chunk_count = 0 - first_chunk_type: str | None = None - async for chunk in chat_service.stream_chat_completion( - session_id, - request.message, - is_user_message=request.is_user_message, - user_id=user_id, - session=session, # Pass pre-fetched session to avoid double-fetch - context=request.context, - ): - if chunk_count < 3: - logger.info( - "Chat stream chunk", - extra={ - "session_id": session_id, - "chunk_type": str(chunk.type), - }, - ) - if not first_chunk_type: - first_chunk_type = str(chunk.type) - chunk_count += 1 - yield chunk.to_sse() - logger.info( - "Chat stream completed", - extra={ - "session_id": session_id, - "chunk_count": chunk_count, - "first_chunk_type": first_chunk_type, - }, - ) - # AI SDK protocol termination - yield "data: [DONE]\n\n" + subscriber_queue = None + try: + # Subscribe to the task stream (this replays existing messages + live updates) + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id="0-0", # Get all messages from the beginning + ) + + if subscriber_queue is None: + yield StreamFinish().to_sse() + yield "data: [DONE]\n\n" + return + + # Read from the subscriber queue and yield to SSE + while True: + try: + chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + + except GeneratorExit: + pass # Client disconnected - background task continues + except Exception as e: + logger.error(f"Error in SSE stream for task {task_id}: {e}") + finally: + # Unsubscribe when client disconnects or stream ends to prevent resource leak + if subscriber_queue is not None: + try: + await stream_registry.unsubscribe_from_task( + task_id, 
subscriber_queue + ) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" return StreamingResponse( event_generator(), @@ -366,6 +470,251 @@ async def session_assign_user( return {"status": "ok"} +# ========== Task Streaming (SSE Reconnection) ========== + + +@router.get( + "/tasks/{task_id}/stream", +) +async def stream_task( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), + last_message_id: str = Query( + default="0-0", + description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + ), +): + """ + Reconnect to a long-running task's SSE stream. + + When a long-running operation (like agent generation) starts, the client + receives a task_id. If the connection drops, the client can reconnect + using this endpoint to resume receiving updates. + + Args: + task_id: The task ID from the operation_started response. + user_id: Authenticated user ID for ownership validation. + last_message_id: Last Redis Stream message ID received ("0-0" for full replay). + + Returns: + StreamingResponse: SSE-formatted response chunks starting after last_message_id. + + Raises: + HTTPException: 404 if task not found, 410 if task expired, 403 if access denied. + """ + # Check task existence and expiry before subscribing + task, error_code = await stream_registry.get_task_with_expiry_info(task_id) + + if error_code == "TASK_EXPIRED": + raise HTTPException( + status_code=410, + detail={ + "code": "TASK_EXPIRED", + "message": "This operation has expired. Please try again.", + }, + ) + + if error_code == "TASK_NOT_FOUND": + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found.", + }, + ) + + # Validate ownership if task has an owner + if task and task.user_id and user_id != task.user_id: + raise HTTPException( + status_code=403, + detail={ + "code": "ACCESS_DENIED", + "message": "You do not have access to this task.", + }, + ) + + # Get subscriber queue from stream registry + subscriber_queue = await stream_registry.subscribe_to_task( + task_id=task_id, + user_id=user_id, + last_message_id=last_message_id, + ) + + if subscriber_queue is None: + raise HTTPException( + status_code=404, + detail={ + "code": "TASK_NOT_FOUND", + "message": f"Task {task_id} not found or access denied.", + }, + ) + + async def event_generator() -> AsyncGenerator[str, None]: + import asyncio + + heartbeat_interval = 15.0 # Send heartbeat every 15 seconds + try: + while True: + try: + # Wait for next chunk with timeout for heartbeats + chunk = await asyncio.wait_for( + subscriber_queue.get(), timeout=heartbeat_interval + ) + yield chunk.to_sse() + + # Check for finish signal + if isinstance(chunk, StreamFinish): + break + except asyncio.TimeoutError: + # Send heartbeat to keep connection alive + yield StreamHeartbeat().to_sse() + except Exception as e: + logger.error(f"Error in task stream {task_id}: {e}", exc_info=True) + finally: + # Unsubscribe when client disconnects or stream ends + try: + await stream_registry.unsubscribe_from_task(task_id, subscriber_queue) + except Exception as unsub_err: + logger.error( + f"Error unsubscribing from task {task_id}: {unsub_err}", + exc_info=True, + ) + # AI SDK protocol termination - always yield even if unsubscribe fails + yield "data: [DONE]\n\n" + + return StreamingResponse( + event_generator(), + 
media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "X-Accel-Buffering": "no", + "x-vercel-ai-ui-message-stream": "v1", + }, + ) + + +@router.get( + "/tasks/{task_id}", +) +async def get_task_status( + task_id: str, + user_id: str | None = Depends(auth.get_user_id), +) -> dict: + """ + Get the status of a long-running task. + + Args: + task_id: The task ID to check. + user_id: Authenticated user ID for ownership validation. + + Returns: + dict: Task status including task_id, status, tool_name, and operation_id. + + Raises: + NotFoundError: If task_id is not found or user doesn't have access. + """ + task = await stream_registry.get_task(task_id) + + if task is None: + raise NotFoundError(f"Task {task_id} not found.") + + # Validate ownership - if task has an owner, requester must match + if task.user_id and user_id != task.user_id: + raise NotFoundError(f"Task {task_id} not found.") + + return { + "task_id": task.task_id, + "session_id": task.session_id, + "status": task.status, + "tool_name": task.tool_name, + "operation_id": task.operation_id, + "created_at": task.created_at.isoformat(), + } + + +# ========== External Completion Webhook ========== + + +@router.post( + "/operations/{operation_id}/complete", + status_code=200, +) +async def complete_operation( + operation_id: str, + request: OperationCompleteRequest, + x_api_key: str | None = Header(default=None), +) -> dict: + """ + External completion webhook for long-running operations. + + Called by Agent Generator (or other services) when an operation completes. + This triggers the stream registry to publish completion and continue LLM generation. + + Args: + operation_id: The operation ID to complete. + request: Completion payload with success status and result/error. + x_api_key: Internal API key for authentication. + + Returns: + dict: Status of the completion. + + Raises: + HTTPException: If API key is invalid or operation not found. + """ + # Validate internal API key - reject if not configured or invalid + if not config.internal_api_key: + logger.error( + "Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured" + ) + raise HTTPException( + status_code=503, + detail="Webhook not available: internal API key not configured", + ) + if x_api_key != config.internal_api_key: + raise HTTPException(status_code=401, detail="Invalid API key") + + # Find task by operation_id + task = await stream_registry.find_task_by_operation_id(operation_id) + if task is None: + raise HTTPException( + status_code=404, + detail=f"Operation {operation_id} not found", + ) + + logger.info( + f"Received completion webhook for operation {operation_id} " + f"(task_id={task.task_id}, success={request.success})" + ) + + if request.success: + await process_operation_success(task, request.result) + else: + await process_operation_failure(task, request.error) + + return {"status": "ok", "task_id": task.task_id} + + +# ========== Configuration ========== + + +@router.get("/config/ttl", status_code=200) +async def get_ttl_config() -> dict: + """ + Get the stream TTL configuration. + + Returns the Time-To-Live settings for chat streams, which determines + how long clients can reconnect to an active stream. + + Returns: + dict: TTL configuration with seconds and milliseconds values. 
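    Example response (with the ChatConfig default stream_ttl=3600):

        {"stream_ttl_seconds": 3600, "stream_ttl_ms": 3600000}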
+ """ + return { + "stream_ttl_seconds": config.stream_ttl, + "stream_ttl_ms": config.stream_ttl * 1000, + } + + # ========== Health Check ========== diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py index 6336d1c5af..06da6bdf2b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/service.py @@ -33,9 +33,10 @@ from backend.data.understanding import ( get_business_understanding, ) from backend.util.exceptions import NotFoundError -from backend.util.settings import Settings +from backend.util.settings import AppEnvironment, Settings from . import db as chat_db +from . import stream_registry from .config import ChatConfig from .model import ( ChatMessage, @@ -221,8 +222,18 @@ async def _get_system_prompt_template(context: str) -> str: try: # cache_ttl_seconds=0 disables SDK caching to always get the latest prompt # Use asyncio.to_thread to avoid blocking the event loop + # In non-production environments, fetch the latest prompt version + # instead of the production-labeled version for easier testing + label = ( + None + if settings.config.app_env == AppEnvironment.PRODUCTION + else "latest" + ) prompt = await asyncio.to_thread( - langfuse.get_prompt, config.langfuse_prompt_name, cache_ttl_seconds=0 + langfuse.get_prompt, + config.langfuse_prompt_name, + label=label, + cache_ttl_seconds=0, ) return prompt.compile(users_information=context) except Exception as e: @@ -617,6 +628,9 @@ async def stream_chat_completion( total_tokens=chunk.totalTokens, ) ) + elif isinstance(chunk, StreamHeartbeat): + # Pass through heartbeat to keep SSE connection alive + yield chunk else: logger.error(f"Unknown chunk type: {type(chunk)}", exc_info=True) @@ -1184,8 +1198,9 @@ async def _yield_tool_call( ) return - # Generate operation ID + # Generate operation ID and task ID operation_id = str(uuid_module.uuid4()) + task_id = str(uuid_module.uuid4()) # Build a user-friendly message based on tool and arguments if tool_name == "create_agent": @@ -1228,6 +1243,16 @@ async def _yield_tool_call( # Wrap session save and task creation in try-except to release lock on failure try: + # Create task in stream registry for SSE reconnection support + await stream_registry.create_task( + task_id=task_id, + session_id=session.session_id, + user_id=session.user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + # Save assistant message with tool_call FIRST (required by LLM) assistant_message = ChatMessage( role="assistant", @@ -1249,23 +1274,27 @@ async def _yield_tool_call( session.messages.append(pending_message) await upsert_chat_session(session) logger.info( - f"Saved pending operation {operation_id} for tool {tool_name} " - f"in session {session.session_id}" + f"Saved pending operation {operation_id} (task_id={task_id}) " + f"for tool {tool_name} in session {session.session_id}" ) # Store task reference in module-level set to prevent GC before completion - task = asyncio.create_task( - _execute_long_running_tool( + bg_task = asyncio.create_task( + _execute_long_running_tool_with_streaming( tool_name=tool_name, parameters=arguments, tool_call_id=tool_call_id, operation_id=operation_id, + task_id=task_id, session_id=session.session_id, user_id=session.user_id, ) ) - _background_tasks.add(task) - task.add_done_callback(_background_tasks.discard) + _background_tasks.add(bg_task) + 
bg_task.add_done_callback(_background_tasks.discard) + + # Associate the asyncio task with the stream registry task + await stream_registry.set_task_asyncio_task(task_id, bg_task) except Exception as e: # Roll back appended messages to prevent data corruption on subsequent saves if ( @@ -1283,6 +1312,11 @@ async def _yield_tool_call( # Release the Redis lock since the background task won't be spawned await _mark_operation_completed(tool_call_id) + # Mark stream registry task as failed if it was created + try: + await stream_registry.mark_task_completed(task_id, status="failed") + except Exception: + pass logger.error( f"Failed to setup long-running tool {tool_name}: {e}", exc_info=True ) @@ -1296,6 +1330,7 @@ async def _yield_tool_call( message=started_msg, operation_id=operation_id, tool_name=tool_name, + task_id=task_id, # Include task_id for SSE reconnection ).model_dump_json(), success=True, ) @@ -1365,6 +1400,9 @@ async def _execute_long_running_tool( This function runs independently of the SSE connection, so the operation survives if the user closes their browser tab. + + NOTE: This is the legacy function without stream registry support. + Use _execute_long_running_tool_with_streaming for new implementations. """ try: # Load fresh session (not stale reference) @@ -1417,6 +1455,133 @@ async def _execute_long_running_tool( await _mark_operation_completed(tool_call_id) +async def _execute_long_running_tool_with_streaming( + tool_name: str, + parameters: dict[str, Any], + tool_call_id: str, + operation_id: str, + task_id: str, + session_id: str, + user_id: str | None, +) -> None: + """Execute a long-running tool with stream registry support for SSE reconnection. + + This function runs independently of the SSE connection, publishes progress + to the stream registry, and survives if the user closes their browser tab. + Clients can reconnect via GET /chat/tasks/{task_id}/stream to resume streaming. + + If the external service returns a 202 Accepted (async), this function exits + early and lets the Redis Streams completion consumer handle the rest. + """ + # Track whether we delegated to async processing - if so, the Redis Streams + # completion consumer (stream_registry / completion_consumer) will handle cleanup, not us + delegated_to_async = False + + try: + # Load fresh session (not stale reference) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for background tool") + await stream_registry.mark_task_completed(task_id, status="failed") + return + + # Pass operation_id and task_id to the tool for async processing + enriched_parameters = { + **parameters, + "_operation_id": operation_id, + "_task_id": task_id, + } + + # Execute the actual tool + result = await execute_tool( + tool_name=tool_name, + parameters=enriched_parameters, + tool_call_id=tool_call_id, + user_id=user_id, + session=session, + ) + + # Check if the tool result indicates async processing + # (e.g., Agent Generator returned 202 Accepted) + try: + if isinstance(result.output, dict): + result_data = result.output + elif result.output: + result_data = orjson.loads(result.output) + else: + result_data = {} + if result_data.get("status") == "accepted": + logger.info( + f"Tool {tool_name} delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id}). " + f"Redis Streams completion consumer will handle the rest." 
+ ) + # Don't publish result, don't continue with LLM, and don't cleanup + # The Redis Streams consumer (completion_consumer) will handle + # everything when the external service completes via webhook + delegated_to_async = True + return + except (orjson.JSONDecodeError, TypeError): + pass # Not JSON or not async - continue normally + + # Publish tool result to stream registry + await stream_registry.publish_chunk(task_id, result) + + # Update the pending message with result + result_str = ( + result.output + if isinstance(result.output, str) + else orjson.dumps(result.output).decode("utf-8") + ) + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=result_str, + ) + + logger.info( + f"Background tool {tool_name} completed for session {session_id} " + f"(task_id={task_id})" + ) + + # Generate LLM continuation and stream chunks to registry + await _generate_llm_continuation_with_streaming( + session_id=session_id, + user_id=user_id, + task_id=task_id, + ) + + # Mark task as completed in stream registry + await stream_registry.mark_task_completed(task_id, status="completed") + + except Exception as e: + logger.error(f"Background tool {tool_name} failed: {e}", exc_info=True) + error_response = ErrorResponse( + message=f"Tool {tool_name} failed: {str(e)}", + ) + + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=str(e)), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) + + await _update_pending_operation( + session_id=session_id, + tool_call_id=tool_call_id, + result=error_response.model_dump_json(), + ) + + # Mark task as failed in stream registry + await stream_registry.mark_task_completed(task_id, status="failed") + finally: + # Only cleanup if we didn't delegate to async processing + # For async path, the Redis Streams completion consumer handles cleanup + if not delegated_to_async: + await _mark_operation_completed(tool_call_id) + + async def _update_pending_operation( session_id: str, tool_call_id: str, @@ -1597,3 +1762,128 @@ async def _generate_llm_continuation( except Exception as e: logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True) + + +async def _generate_llm_continuation_with_streaming( + session_id: str, + user_id: str | None, + task_id: str, +) -> None: + """Generate an LLM response with streaming to the stream registry. + + This is called by background tasks to continue the conversation + after a tool result is saved. Chunks are published to the stream registry + so reconnecting clients can receive them. 
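    Event order published to the registry (matches the implementation below):

        StreamStart(messageId=...)           # new assistant message
        StreamTextStart(id=text_block_id)    # open a text block
        StreamTextDelta(id=..., delta=...)   # one per streamed content delta
        StreamTextEnd(id=text_block_id)      # close the text block
        # on error: StreamError(...) followed by StreamFinish()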
+ """ + import uuid as uuid_module + + try: + # Load fresh session from DB (bypass cache to get the updated tool result) + await invalidate_session_cache(session_id) + session = await get_chat_session(session_id, user_id) + if not session: + logger.error(f"Session {session_id} not found for LLM continuation") + return + + # Build system prompt + system_prompt, _ = await _build_system_prompt(user_id) + + # Build messages in OpenAI format + messages = session.to_openai_messages() + if system_prompt: + from openai.types.chat import ChatCompletionSystemMessageParam + + system_message = ChatCompletionSystemMessageParam( + role="system", + content=system_prompt, + ) + messages = [system_message] + messages + + # Build extra_body for tracing + extra_body: dict[str, Any] = { + "posthogProperties": { + "environment": settings.config.app_env.value, + }, + } + if user_id: + extra_body["user"] = user_id[:128] + extra_body["posthogDistinctId"] = user_id + if session_id: + extra_body["session_id"] = session_id[:128] + + # Make streaming LLM call (no tools - just text response) + from typing import cast + + from openai.types.chat import ChatCompletionMessageParam + + # Generate unique IDs for AI SDK protocol + message_id = str(uuid_module.uuid4()) + text_block_id = str(uuid_module.uuid4()) + + # Publish start event + await stream_registry.publish_chunk(task_id, StreamStart(messageId=message_id)) + await stream_registry.publish_chunk(task_id, StreamTextStart(id=text_block_id)) + + # Stream the response + stream = await client.chat.completions.create( + model=config.model, + messages=cast(list[ChatCompletionMessageParam], messages), + extra_body=extra_body, + stream=True, + ) + + assistant_content = "" + async for chunk in stream: + if chunk.choices and chunk.choices[0].delta.content: + delta = chunk.choices[0].delta.content + assistant_content += delta + # Publish delta to stream registry + await stream_registry.publish_chunk( + task_id, + StreamTextDelta(id=text_block_id, delta=delta), + ) + + # Publish end events + await stream_registry.publish_chunk(task_id, StreamTextEnd(id=text_block_id)) + + if assistant_content: + # Reload session from DB to avoid race condition with user messages + fresh_session = await get_chat_session(session_id, user_id) + if not fresh_session: + logger.error( + f"Session {session_id} disappeared during LLM continuation" + ) + return + + # Save assistant message to database + assistant_message = ChatMessage( + role="assistant", + content=assistant_content, + ) + fresh_session.messages.append(assistant_message) + + # Save to database (not cache) to persist the response + await upsert_chat_session(fresh_session) + + # Invalidate cache so next poll/refresh gets fresh data + await invalidate_session_cache(session_id) + + logger.info( + f"Generated streaming LLM continuation for session {session_id} " + f"(task_id={task_id}), response length: {len(assistant_content)}" + ) + else: + logger.warning( + f"Streaming LLM continuation returned empty response for {session_id}" + ) + + except Exception as e: + logger.error( + f"Failed to generate streaming LLM continuation: {e}", exc_info=True + ) + # Publish error to stream registry followed by finish event + await stream_registry.publish_chunk( + task_id, + StreamError(errorText=f"Failed to generate response: {e}"), + ) + await stream_registry.publish_chunk(task_id, StreamFinish()) diff --git a/autogpt_platform/backend/backend/api/features/chat/stream_registry.py b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py new 
file mode 100644 index 0000000000..88a5023e2b --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/stream_registry.py @@ -0,0 +1,704 @@ +"""Stream registry for managing reconnectable SSE streams. + +This module provides a registry for tracking active streaming tasks and their +messages. It uses Redis for all state management (no in-memory state), making +pods stateless and horizontally scalable. + +Architecture: +- Redis Stream: Persists all messages for replay and real-time delivery +- Redis Hash: Task metadata (status, session_id, etc.) + +Subscribers: +1. Replay missed messages from Redis Stream (XREAD) +2. Listen for live updates via blocking XREAD +3. No in-memory state required on the subscribing pod +""" + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Literal + +import orjson + +from backend.data.redis_client import get_redis_async + +from .config import ChatConfig +from .response_model import StreamBaseResponse, StreamError, StreamFinish + +logger = logging.getLogger(__name__) +config = ChatConfig() + +# Track background tasks for this pod (just the asyncio.Task reference, not subscribers) +_local_tasks: dict[str, asyncio.Task] = {} + +# Track listener tasks per subscriber queue for cleanup +# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe +_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {} + +# Timeout for putting chunks into subscriber queues (seconds) +# If the queue is full and doesn't drain within this time, send an overflow error +QUEUE_PUT_TIMEOUT = 5.0 + +# Lua script for atomic compare-and-swap status update (idempotent completion) +# Returns 1 if status was updated, 0 if already completed/failed +COMPLETE_TASK_SCRIPT = """ +local current = redis.call("HGET", KEYS[1], "status") +if current == "running" then + redis.call("HSET", KEYS[1], "status", ARGV[1]) + return 1 +end +return 0 +""" + + +@dataclass +class ActiveTask: + """Represents an active streaming task (metadata only, no in-memory queues).""" + + task_id: str + session_id: str + user_id: str | None + tool_call_id: str + tool_name: str + operation_id: str + status: Literal["running", "completed", "failed"] = "running" + created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc)) + asyncio_task: asyncio.Task | None = None + + +def _get_task_meta_key(task_id: str) -> str: + """Get Redis key for task metadata.""" + return f"{config.task_meta_prefix}{task_id}" + + +def _get_task_stream_key(task_id: str) -> str: + """Get Redis key for task message stream.""" + return f"{config.task_stream_prefix}{task_id}" + + +def _get_operation_mapping_key(operation_id: str) -> str: + """Get Redis key for operation_id to task_id mapping.""" + return f"{config.task_op_prefix}{operation_id}" + + +async def create_task( + task_id: str, + session_id: str, + user_id: str | None, + tool_call_id: str, + tool_name: str, + operation_id: str, +) -> ActiveTask: + """Create a new streaming task in Redis. 
+ + Args: + task_id: Unique identifier for the task + session_id: Chat session ID + user_id: User ID (may be None for anonymous) + tool_call_id: Tool call ID from the LLM + tool_name: Name of the tool being executed + operation_id: Operation ID for webhook callbacks + + Returns: + The created ActiveTask instance (metadata only) + """ + task = ActiveTask( + task_id=task_id, + session_id=session_id, + user_id=user_id, + tool_call_id=tool_call_id, + tool_name=tool_name, + operation_id=operation_id, + ) + + # Store metadata in Redis + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + op_key = _get_operation_mapping_key(operation_id) + + await redis.hset( # type: ignore[misc] + meta_key, + mapping={ + "task_id": task_id, + "session_id": session_id, + "user_id": user_id or "", + "tool_call_id": tool_call_id, + "tool_name": tool_name, + "operation_id": operation_id, + "status": task.status, + "created_at": task.created_at.isoformat(), + }, + ) + await redis.expire(meta_key, config.stream_ttl) + + # Create operation_id -> task_id mapping for webhook lookups + await redis.set(op_key, task_id, ex=config.stream_ttl) + + logger.debug(f"Created task {task_id} for session {session_id}") + + return task + + +async def publish_chunk( + task_id: str, + chunk: StreamBaseResponse, +) -> str: + """Publish a chunk to Redis Stream. + + All delivery is via Redis Streams - no in-memory state. + + Args: + task_id: Task ID to publish to + chunk: The stream response chunk to publish + + Returns: + The Redis Stream message ID + """ + chunk_json = chunk.model_dump_json() + message_id = "0-0" + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + + # Write to Redis Stream for persistence and real-time delivery + raw_id = await redis.xadd( + stream_key, + {"data": chunk_json}, + maxlen=config.stream_max_length, + ) + message_id = raw_id if isinstance(raw_id, str) else raw_id.decode() + + # Set TTL on stream to match task metadata TTL + await redis.expire(stream_key, config.stream_ttl) + except Exception as e: + logger.error( + f"Failed to publish chunk for task {task_id}: {e}", + exc_info=True, + ) + + return message_id + + +async def subscribe_to_task( + task_id: str, + user_id: str | None, + last_message_id: str = "0-0", +) -> asyncio.Queue[StreamBaseResponse] | None: + """Subscribe to a task's stream with replay of missed messages. + + This is fully stateless - uses Redis Stream for replay and pub/sub for live updates. 
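# Illustrative sketch: a typical call sequence for the two functions above,
# register the task, then publish protocol events that Redis persists for
# later replay. The orchestration and generated IDs are assumptions for
# illustration; StreamStart comes from the sibling response_model module, as
# used elsewhere in this feature.
import uuid

from .response_model import StreamStart  # assumed import path


async def start_tracked_operation(session_id: str, user_id: str | None) -> str:
    task_id = str(uuid.uuid4())
    await create_task(
        task_id=task_id,
        session_id=session_id,
        user_id=user_id,
        tool_call_id=str(uuid.uuid4()),
        tool_name="create_agent",
        operation_id=str(uuid.uuid4()),
    )
    # Every chunk published here is appended to the task's Redis Stream, so a
    # client that connects (or reconnects) later can replay it.
    await publish_chunk(task_id, StreamStart(messageId=str(uuid.uuid4())))
    return task_id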
+ + Args: + task_id: Task ID to subscribe to + user_id: User ID for ownership validation + last_message_id: Last Redis Stream message ID received ("0-0" for full replay) + + Returns: + An asyncio Queue that will receive stream chunks, or None if task not found + or user doesn't have access + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + logger.debug(f"Task {task_id} not found in Redis") + return None + + # Note: Redis client uses decode_responses=True, so keys are strings + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + + # Validate ownership - if task has an owner, requester must match + if task_user_id: + if user_id != task_user_id: + logger.warning( + f"User {user_id} denied access to task {task_id} " + f"owned by {task_user_id}" + ) + return None + + subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() + stream_key = _get_task_stream_key(task_id) + + # Step 1: Replay messages from Redis Stream + messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000) + + replayed_count = 0 + replay_last_id = last_message_id + if messages: + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + # Note: Redis client uses decode_responses=True, so keys are strings + if "data" in msg_data: + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + await subscriber_queue.put(chunk) + replayed_count += 1 + except Exception as e: + logger.warning(f"Failed to replay message: {e}") + + logger.debug(f"Task {task_id}: replayed {replayed_count} messages") + + # Step 2: If task is still running, start stream listener for live updates + if task_status == "running": + listener_task = asyncio.create_task( + _stream_listener(task_id, subscriber_queue, replay_last_id) + ) + # Track listener task for cleanup on unsubscribe + _listener_tasks[id(subscriber_queue)] = (task_id, listener_task) + else: + # Task is completed/failed - add finish marker + await subscriber_queue.put(StreamFinish()) + + return subscriber_queue + + +async def _stream_listener( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], + last_replayed_id: str, +) -> None: + """Listen to Redis Stream for new messages using blocking XREAD. + + This approach avoids the duplicate message issue that can occur with pub/sub + when messages are published during the gap between replay and subscription. 
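# Illustrative sketch: how an SSE endpoint might drain the queue returned by
# subscribe_to_task(), forwarding every chunk as a server-sent event and
# stopping at StreamFinish. The generator shape and SSE framing are
# assumptions for illustration, not the project's actual endpoint.
async def sse_event_stream(task_id: str, user_id: str | None, last_message_id: str):
    queue = await subscribe_to_task(task_id, user_id, last_message_id)
    if queue is None:
        return  # unknown task, or the requester does not own it

    try:
        while True:
            chunk = await queue.get()
            yield f"data: {chunk.model_dump_json()}\n\n"
            if isinstance(chunk, StreamFinish):
                break
    finally:
        # Cancels the per-subscriber XREAD listener (see unsubscribe_from_task).
        await unsubscribe_from_task(task_id, queue)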
+ + Args: + task_id: Task ID to listen for + subscriber_queue: Queue to deliver messages to + last_replayed_id: Last message ID from replay (continue from here) + """ + queue_id = id(subscriber_queue) + # Track the last successfully delivered message ID for recovery hints + last_delivered_id = last_replayed_id + + try: + redis = await get_redis_async() + stream_key = _get_task_stream_key(task_id) + current_id = last_replayed_id + + while True: + # Block for up to 30 seconds waiting for new messages + # This allows periodic checking if task is still running + messages = await redis.xread( + {stream_key: current_id}, block=30000, count=100 + ) + + if not messages: + # Timeout - check if task is still running + meta_key = _get_task_meta_key(task_id) + status = await redis.hget(meta_key, "status") # type: ignore[misc] + if status and status != "running": + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except asyncio.TimeoutError: + logger.warning( + f"Timeout delivering finish event for task {task_id}" + ) + break + continue + + for _stream_name, stream_messages in messages: + for msg_id, msg_data in stream_messages: + current_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + + if "data" not in msg_data: + continue + + try: + chunk_data = orjson.loads(msg_data["data"]) + chunk = _reconstruct_chunk(chunk_data) + if chunk: + try: + await asyncio.wait_for( + subscriber_queue.put(chunk), + timeout=QUEUE_PUT_TIMEOUT, + ) + # Update last delivered ID on successful delivery + last_delivered_id = current_id + except asyncio.TimeoutError: + logger.warning( + f"Subscriber queue full for task {task_id}, " + f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s" + ) + # Send overflow error with recovery info + try: + overflow_error = StreamError( + errorText="Message delivery timeout - some messages may have been missed", + code="QUEUE_OVERFLOW", + details={ + "last_delivered_id": last_delivered_id, + "recovery_hint": f"Reconnect with last_message_id={last_delivered_id}", + }, + ) + subscriber_queue.put_nowait(overflow_error) + except asyncio.QueueFull: + # Queue is completely stuck, nothing more we can do + logger.error( + f"Cannot deliver overflow error for task {task_id}, " + "queue completely blocked" + ) + + # Stop listening on finish + if isinstance(chunk, StreamFinish): + return + except Exception as e: + logger.warning(f"Error processing stream message: {e}") + + except asyncio.CancelledError: + logger.debug(f"Stream listener cancelled for task {task_id}") + raise # Re-raise to propagate cancellation + except Exception as e: + logger.error(f"Stream listener error for task {task_id}: {e}") + # On error, send finish to unblock subscriber + try: + await asyncio.wait_for( + subscriber_queue.put(StreamFinish()), + timeout=QUEUE_PUT_TIMEOUT, + ) + except (asyncio.TimeoutError, asyncio.QueueFull): + logger.warning( + f"Could not deliver finish event for task {task_id} after error" + ) + finally: + # Clean up listener task mapping on exit + _listener_tasks.pop(queue_id, None) + + +async def mark_task_completed( + task_id: str, + status: Literal["completed", "failed"] = "completed", +) -> bool: + """Mark a task as completed and publish finish event. + + This is idempotent - calling multiple times with the same task_id is safe. + Uses atomic compare-and-swap via Lua script to prevent race conditions. + Status is updated first (source of truth), then finish event is published (best-effort). 
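# Illustrative sketch: consumer-side recovery for the QUEUE_OVERFLOW error
# emitted above. When the error arrives, resubscribe from the
# details["last_delivered_id"] hint so the missed range is replayed from the
# Redis Stream. The outer loop and the print() stand-in are assumptions for
# illustration.
async def consume_with_recovery(task_id: str, user_id: str | None) -> None:
    last_id = "0-0"
    while True:
        queue = await subscribe_to_task(task_id, user_id, last_message_id=last_id)
        if queue is None:
            return
        try:
            while True:
                chunk = await queue.get()
                if isinstance(chunk, StreamFinish):
                    return
                if (
                    isinstance(chunk, StreamError)
                    and getattr(chunk, "code", None) == "QUEUE_OVERFLOW"
                ):
                    details = getattr(chunk, "details", None) or {}
                    last_id = details.get("last_delivered_id", last_id)
                    break  # resubscribe from last_id in the outer loop
                print(chunk.model_dump_json())  # stand-in for real handling
        finally:
            await unsubscribe_from_task(task_id, queue)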
+ + Args: + task_id: Task ID to mark as completed + status: Final status ("completed" or "failed") + + Returns: + True if task was newly marked completed, False if already completed/failed + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + + # Atomic compare-and-swap: only update if status is "running" + # This prevents race conditions when multiple callers try to complete simultaneously + result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc] + + if result == 0: + logger.debug(f"Task {task_id} already completed/failed, skipping") + return False + + # THEN publish finish event (best-effort - listeners can detect via status polling) + try: + await publish_chunk(task_id, StreamFinish()) + except Exception as e: + logger.error( + f"Failed to publish finish event for task {task_id}: {e}. " + "Listeners will detect completion via status polling." + ) + + # Clean up local task reference if exists + _local_tasks.pop(task_id, None) + return True + + +async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None: + """Find a task by its operation ID. + + Used by webhook callbacks to locate the task to update. + + Args: + operation_id: Operation ID to search for + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + op_key = _get_operation_mapping_key(operation_id) + task_id = await redis.get(op_key) + + if not task_id: + return None + + task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id + return await get_task(task_id_str) + + +async def get_task(task_id: str) -> ActiveTask | None: + """Get a task by its ID from Redis. + + Args: + task_id: Task ID to look up + + Returns: + ActiveTask if found, None otherwise + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + return None + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ) + + +async def get_task_with_expiry_info( + task_id: str, +) -> tuple[ActiveTask | None, str | None]: + """Get a task by its ID with expiration detection. 
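# Illustrative sketch: how a completion callback can use the lookups above,
# resolving the task from its operation_id and then closing the stream.
# mark_task_completed() is idempotent, so duplicate callbacks for the same
# operation are harmless. The handler name and the `succeeded` flag are
# assumptions for illustration.
async def handle_completion_callback(operation_id: str, succeeded: bool) -> None:
    task = await find_task_by_operation_id(operation_id)
    if task is None:
        return  # unknown operation, or its metadata TTL already expired

    await mark_task_completed(
        task.task_id,
        status="completed" if succeeded else "failed",
    )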
+ + Returns (task, error_code) where error_code is: + - None if task found + - "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired) + - "TASK_NOT_FOUND" if neither exists + + Args: + task_id: Task ID to look up + + Returns: + Tuple of (ActiveTask or None, error_code or None) + """ + redis = await get_redis_async() + meta_key = _get_task_meta_key(task_id) + stream_key = _get_task_stream_key(task_id) + + meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc] + + if not meta: + # Check if stream still has data (metadata expired but stream hasn't) + stream_len = await redis.xlen(stream_key) + if stream_len > 0: + return None, "TASK_EXPIRED" + return None, "TASK_NOT_FOUND" + + # Note: Redis client uses decode_responses=True, so keys/values are strings + return ( + ActiveTask( + task_id=meta.get("task_id", ""), + session_id=meta.get("session_id", ""), + user_id=meta.get("user_id", "") or None, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status=meta.get("status", "running"), # type: ignore[arg-type] + ), + None, + ) + + +async def get_active_task_for_session( + session_id: str, + user_id: str | None = None, +) -> tuple[ActiveTask | None, str]: + """Get the active (running) task for a session, if any. + + Scans Redis for tasks matching the session_id with status="running". + + Args: + session_id: Session ID to look up + user_id: User ID for ownership validation (optional) + + Returns: + Tuple of (ActiveTask if found and running, last_message_id from Redis Stream) + """ + + redis = await get_redis_async() + + # Scan Redis for task metadata keys + cursor = 0 + tasks_checked = 0 + + while True: + cursor, keys = await redis.scan( + cursor, match=f"{config.task_meta_prefix}*", count=100 + ) + + for key in keys: + tasks_checked += 1 + meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc] + if not meta: + continue + + # Note: Redis client uses decode_responses=True, so keys/values are strings + task_session_id = meta.get("session_id", "") + task_status = meta.get("status", "") + task_user_id = meta.get("user_id", "") or None + task_id = meta.get("task_id", "") + + if task_session_id == session_id and task_status == "running": + # Validate ownership - if task has an owner, requester must match + if task_user_id and user_id != task_user_id: + continue + + # Get the last message ID from Redis Stream + stream_key = _get_task_stream_key(task_id) + last_id = "0-0" + try: + messages = await redis.xrevrange(stream_key, count=1) + if messages: + msg_id = messages[0][0] + last_id = msg_id if isinstance(msg_id, str) else msg_id.decode() + except Exception as e: + logger.warning(f"Failed to get last message ID: {e}") + + return ( + ActiveTask( + task_id=task_id, + session_id=task_session_id, + user_id=task_user_id, + tool_call_id=meta.get("tool_call_id", ""), + tool_name=meta.get("tool_name", ""), + operation_id=meta.get("operation_id", ""), + status="running", + ), + last_id, + ) + + if cursor == 0: + break + + return None, "0-0" + + +def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None: + """Reconstruct a StreamBaseResponse from JSON data. 
+ + Args: + chunk_data: Parsed JSON data from Redis + + Returns: + Reconstructed response object, or None if unknown type + """ + from .response_model import ( + ResponseType, + StreamError, + StreamFinish, + StreamHeartbeat, + StreamStart, + StreamTextDelta, + StreamTextEnd, + StreamTextStart, + StreamToolInputAvailable, + StreamToolInputStart, + StreamToolOutputAvailable, + StreamUsage, + ) + + # Map response types to their corresponding classes + type_to_class: dict[str, type[StreamBaseResponse]] = { + ResponseType.START.value: StreamStart, + ResponseType.FINISH.value: StreamFinish, + ResponseType.TEXT_START.value: StreamTextStart, + ResponseType.TEXT_DELTA.value: StreamTextDelta, + ResponseType.TEXT_END.value: StreamTextEnd, + ResponseType.TOOL_INPUT_START.value: StreamToolInputStart, + ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable, + ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable, + ResponseType.ERROR.value: StreamError, + ResponseType.USAGE.value: StreamUsage, + ResponseType.HEARTBEAT.value: StreamHeartbeat, + } + + chunk_type = chunk_data.get("type") + chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type] + + if chunk_class is None: + logger.warning(f"Unknown chunk type: {chunk_type}") + return None + + try: + return chunk_class(**chunk_data) + except Exception as e: + logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}") + return None + + +async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None: + """Track the asyncio.Task for a task (local reference only). + + This is just for cleanup purposes - the task state is in Redis. + + Args: + task_id: Task ID + asyncio_task: The asyncio Task to track + """ + _local_tasks[task_id] = asyncio_task + + +async def unsubscribe_from_task( + task_id: str, + subscriber_queue: asyncio.Queue[StreamBaseResponse], +) -> None: + """Clean up when a subscriber disconnects. + + Cancels the XREAD-based listener task associated with this subscriber queue + to prevent resource leaks. 
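# Illustrative sketch: a round trip of a chunk through the same path the
# registry uses, serialized with model_dump_json() on publish, parsed with
# orjson on read, and rebuilt by _reconstruct_chunk() from its "type" field.
# StreamTextDelta is from the sibling response_model module; the import path
# is assumed for the example.
import orjson

from .response_model import StreamTextDelta  # assumed import path


def _roundtrip_example() -> None:
    original = StreamTextDelta(id="block-1", delta="Hello")
    restored = _reconstruct_chunk(orjson.loads(original.model_dump_json()))
    assert isinstance(restored, StreamTextDelta)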
+ + Args: + task_id: Task ID + subscriber_queue: The subscriber's queue used to look up the listener task + """ + queue_id = id(subscriber_queue) + listener_entry = _listener_tasks.pop(queue_id, None) + + if listener_entry is None: + logger.debug( + f"No listener task found for task {task_id} queue {queue_id} " + "(may have already completed)" + ) + return + + stored_task_id, listener_task = listener_entry + + if stored_task_id != task_id: + logger.warning( + f"Task ID mismatch in unsubscribe: expected {task_id}, " + f"found {stored_task_id}" + ) + + if listener_task.done(): + logger.debug(f"Listener task for task {task_id} already completed") + return + + # Cancel the listener task + listener_task.cancel() + + try: + # Wait for the task to be cancelled with a timeout + await asyncio.wait_for(listener_task, timeout=5.0) + except asyncio.CancelledError: + # Expected - the task was successfully cancelled + pass + except asyncio.TimeoutError: + logger.warning( + f"Timeout waiting for listener task cancellation for task {task_id}" + ) + except Exception as e: + logger.error(f"Error during listener task cancellation for task {task_id}: {e}") + + logger.debug(f"Successfully unsubscribed from task {task_id}") diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py index d078860c3a..dcbc35ef37 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/__init__.py @@ -10,6 +10,7 @@ from .add_understanding import AddUnderstandingTool from .agent_output import AgentOutputTool from .base import BaseTool from .create_agent import CreateAgentTool +from .customize_agent import CustomizeAgentTool from .edit_agent import EditAgentTool from .find_agent import FindAgentTool from .find_block import FindBlockTool @@ -34,6 +35,7 @@ logger = logging.getLogger(__name__) TOOL_REGISTRY: dict[str, BaseTool] = { "add_understanding": AddUnderstandingTool(), "create_agent": CreateAgentTool(), + "customize_agent": CustomizeAgentTool(), "edit_agent": EditAgentTool(), "find_agent": FindAgentTool(), "find_block": FindBlockTool(), diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py index b7650b3cbd..4266834220 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/__init__.py @@ -8,6 +8,7 @@ from .core import ( DecompositionStep, LibraryAgentSummary, MarketplaceAgentSummary, + customize_template, decompose_goal, enrich_library_agents_from_steps, extract_search_terms_from_steps, @@ -19,6 +20,7 @@ from .core import ( get_library_agent_by_graph_id, get_library_agent_by_id, get_library_agents_for_generation, + graph_to_json, json_to_graph, save_agent_to_library, search_marketplace_agents_for_generation, @@ -36,6 +38,7 @@ __all__ = [ "LibraryAgentSummary", "MarketplaceAgentSummary", "check_external_service_health", + "customize_template", "decompose_goal", "enrich_library_agents_from_steps", "extract_search_terms_from_steps", @@ -48,6 +51,7 @@ __all__ = [ "get_library_agent_by_id", "get_library_agents_for_generation", "get_user_message_for_error", + "graph_to_json", "is_external_service_configured", "json_to_graph", "save_agent_to_library", diff --git 
a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py index 0ddd2aa86b..f83ca30b5c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/core.py @@ -7,18 +7,11 @@ from typing import Any, NotRequired, TypedDict from backend.api.features.library import db as library_db from backend.api.features.store import db as store_db -from backend.data.graph import ( - Graph, - Link, - Node, - create_graph, - get_graph, - get_graph_all_versions, - get_store_listed_graphs, -) +from backend.data.graph import Graph, Link, Node, get_graph, get_store_listed_graphs from backend.util.exceptions import DatabaseError, NotFoundError from .service import ( + customize_template_external, decompose_goal_external, generate_agent_external, generate_agent_patch_external, @@ -27,8 +20,6 @@ from .service import ( logger = logging.getLogger(__name__) -AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565" - class ExecutionSummary(TypedDict): """Summary of a single execution for quality assessment.""" @@ -549,15 +540,21 @@ async def decompose_goal( async def generate_agent( instructions: DecompositionResult | dict[str, Any], library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Generate agent JSON from instructions. Args: instructions: Structured instructions from decompose_goal library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams + completion notification) + task_id: Task ID for async processing (enables Redis Streams persistence + and SSE delivery) Returns: - Agent JSON dict, error dict {"type": "error", ...}, or None on error + Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. @@ -565,8 +562,13 @@ async def generate_agent( _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent") result = await generate_agent_external( - dict(instructions), _to_dict_list(library_agents) + dict(instructions), _to_dict_list(library_agents), operation_id, task_id ) + + # Don't modify async response + if result and result.get("status") == "accepted": + return result + if result: if isinstance(result, dict) and result.get("type") == "error": return result @@ -657,45 +659,6 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph: ) -def _reassign_node_ids(graph: Graph) -> None: - """Reassign all node and link IDs to new UUIDs. - - This is needed when creating a new version to avoid unique constraint violations. - """ - id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes} - - for node in graph.nodes: - node.id = id_map[node.id] - - for link in graph.links: - link.id = str(uuid.uuid4()) - if link.source_id in id_map: - link.source_id = id_map[link.source_id] - if link.sink_id in id_map: - link.sink_id = id_map[link.sink_id] - - -def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None: - """Populate user_id in AgentExecutorBlock nodes. - - The external agent generator creates AgentExecutorBlock nodes with empty user_id. 
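# Illustrative sketch: generate_agent() now has two success shapes, the agent
# JSON (synchronous path) or {"status": "accepted"} when the external service
# took the request for async processing (HTTP 202). A caller therefore
# branches roughly like this; the wrapper name and return values are
# assumptions for illustration.
from typing import Any


async def generate_or_defer(
    instructions: dict[str, Any],
    operation_id: str | None,
    task_id: str | None,
) -> dict[str, Any] | None:
    result = await generate_agent(
        instructions, None, operation_id=operation_id, task_id=task_id
    )
    if result is None or result.get("type") == "error":
        return result  # hard failure; surface to the caller
    if result.get("status") == "accepted":
        # Async path: the Redis Streams completion consumer picks up the
        # result later, keyed by operation_id/task_id.
        return {"status": "accepted", "operation_id": operation_id, "task_id": task_id}
    return result  # synchronous path: result is the generated agent JSON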
- This function fills in the actual user_id so sub-agents run with correct permissions. - - Args: - agent_json: Agent JSON dict (modified in place) - user_id: User ID to set - """ - for node in agent_json.get("nodes", []): - if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID: - input_default = node.get("input_default") or {} - if not input_default.get("user_id"): - input_default["user_id"] = user_id - node["input_default"] = input_default - logger.debug( - f"Set user_id for AgentExecutorBlock node {node.get('id')}" - ) - - async def save_agent_to_library( agent_json: dict[str, Any], user_id: str, is_update: bool = False ) -> tuple[Graph, Any]: @@ -709,63 +672,21 @@ async def save_agent_to_library( Returns: Tuple of (created Graph, LibraryAgent) """ - # Populate user_id in AgentExecutorBlock nodes before conversion - _populate_agent_executor_user_ids(agent_json, user_id) - graph = json_to_graph(agent_json) - if is_update: - if graph.id: - existing_versions = await get_graph_all_versions(graph.id, user_id) - if existing_versions: - latest_version = max(v.version for v in existing_versions) - graph.version = latest_version + 1 - _reassign_node_ids(graph) - logger.info(f"Updating agent {graph.id} to version {graph.version}") - else: - graph.id = str(uuid.uuid4()) - graph.version = 1 - _reassign_node_ids(graph) - logger.info(f"Creating new agent with ID {graph.id}") - - created_graph = await create_graph(graph, user_id) - - library_agents = await library_db.create_library_agent( - graph=created_graph, - user_id=user_id, - sensitive_action_safe_mode=True, - create_library_agents_for_sub_graphs=False, - ) - - return created_graph, library_agents[0] + return await library_db.update_graph_in_library(graph, user_id) + return await library_db.create_graph_in_library(graph, user_id) -async def get_agent_as_json( - agent_id: str, user_id: str | None -) -> dict[str, Any] | None: - """Fetch an agent and convert to JSON format for editing. +def graph_to_json(graph: Graph) -> dict[str, Any]: + """Convert a Graph object to JSON format for the agent generator. Args: - agent_id: Graph ID or library agent ID - user_id: User ID + graph: Graph object to convert Returns: - Agent as JSON dict or None if not found + Agent as JSON dict """ - graph = await get_graph(agent_id, version=None, user_id=user_id) - - if not graph and user_id: - try: - library_agent = await library_db.get_library_agent(agent_id, user_id) - graph = await get_graph( - library_agent.graph_id, version=None, user_id=user_id - ) - except NotFoundError: - pass - - if not graph: - return None - nodes = [] for node in graph.nodes: nodes.append( @@ -802,10 +723,41 @@ async def get_agent_as_json( } +async def get_agent_as_json( + agent_id: str, user_id: str | None +) -> dict[str, Any] | None: + """Fetch an agent and convert to JSON format for editing. 
+ + Args: + agent_id: Graph ID or library agent ID + user_id: User ID + + Returns: + Agent as JSON dict or None if not found + """ + graph = await get_graph(agent_id, version=None, user_id=user_id) + + if not graph and user_id: + try: + library_agent = await library_db.get_library_agent(agent_id, user_id) + graph = await get_graph( + library_agent.graph_id, version=None, user_id=user_id + ) + except NotFoundError: + pass + + if not graph: + return None + + return graph_to_json(graph) + + async def generate_agent_patch( update_request: str, current_agent: dict[str, Any], library_agents: list[AgentSummary] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Update an existing agent using natural language. @@ -818,10 +770,12 @@ async def generate_agent_patch( update_request: Natural language description of changes current_agent: Current agent JSON library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, - error dict {"type": "error", ...}, or None on unexpected error + {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error Raises: AgentGeneratorNotConfiguredError: If the external service is not configured. @@ -829,5 +783,43 @@ async def generate_agent_patch( _check_service_configured() logger.info("Calling external Agent Generator service for generate_agent_patch") return await generate_agent_patch_external( - update_request, current_agent, _to_dict_list(library_agents) + update_request, + current_agent, + _to_dict_list(library_agents), + operation_id, + task_id, + ) + + +async def customize_template( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Customize a template/marketplace agent using natural language. + + This is used when users want to modify a template or marketplace agent + to fit their specific needs before adding it to their library. + + The external Agent Generator service handles: + - Understanding the modification request + - Applying changes to the template + - Fixing and validating the result + + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...}, + error dict {"type": "error", ...}, or None on unexpected error + + Raises: + AgentGeneratorNotConfiguredError: If the external service is not configured. 
+ """ + _check_service_configured() + logger.info("Calling external Agent Generator service for customize_template") + return await customize_template_external( + template_agent, modification_request, context ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py index c9c960d1ae..62411b4e1b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_generator/service.py @@ -212,24 +212,45 @@ async def decompose_goal_external( async def generate_agent_external( instructions: dict[str, Any], library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate an agent from instructions. Args: instructions: Structured instructions from decompose_goal library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Agent JSON dict on success, or error dict {"type": "error", ...} on error + Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error """ client = _get_client() + # Build request payload payload: dict[str, Any] = {"instructions": instructions} if library_agents: payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id try: response = await client.post("/api/generate-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() @@ -261,6 +282,8 @@ async def generate_agent_patch_external( update_request: str, current_agent: dict[str, Any], library_agents: list[dict[str, Any]] | None = None, + operation_id: str | None = None, + task_id: str | None = None, ) -> dict[str, Any] | None: """Call the external service to generate a patch for an existing agent. 
@@ -268,21 +291,40 @@ async def generate_agent_patch_external( update_request: Natural language description of changes current_agent: Current agent JSON library_agents: User's library agents available for sub-agent composition + operation_id: Operation ID for async processing (enables Redis Streams callback) + task_id: Task ID for async processing (enables Redis Streams callback) Returns: - Updated agent JSON, clarifying questions dict, or error dict on error + Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error """ client = _get_client() + # Build request payload payload: dict[str, Any] = { "update_request": update_request, "current_agent_json": current_agent, } if library_agents: payload["library_agents"] = library_agents + if operation_id and task_id: + payload["operation_id"] = operation_id + payload["task_id"] = task_id try: response = await client.post("/api/update-agent", json=payload) + + # Handle 202 Accepted for async processing + if response.status_code == 202: + logger.info( + f"Agent Generator accepted async update request " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return { + "status": "accepted", + "operation_id": operation_id, + "task_id": task_id, + } + response.raise_for_status() data = response.json() @@ -326,6 +368,77 @@ async def generate_agent_patch_external( return _create_error_response(error_msg, "unexpected_error") +async def customize_template_external( + template_agent: dict[str, Any], + modification_request: str, + context: str = "", +) -> dict[str, Any] | None: + """Call the external service to customize a template/marketplace agent. + + Args: + template_agent: The template agent JSON to customize + modification_request: Natural language description of customizations + context: Additional context (e.g., answers to previous questions) + + Returns: + Customized agent JSON, clarifying questions dict, or error dict on error + """ + client = _get_client() + + request = modification_request + if context: + request = f"{modification_request}\n\nAdditional context from user:\n{context}" + + payload: dict[str, Any] = { + "template_agent_json": template_agent, + "modification_request": request, + } + + try: + response = await client.post("/api/template-modification", json=payload) + response.raise_for_status() + data = response.json() + + if not data.get("success"): + error_msg = data.get("error", "Unknown error from Agent Generator") + error_type = data.get("error_type", "unknown") + logger.error( + f"Agent Generator template customization failed: {error_msg} " + f"(type: {error_type})" + ) + return _create_error_response(error_msg, error_type) + + # Check if it's clarifying questions + if data.get("type") == "clarifying_questions": + return { + "type": "clarifying_questions", + "questions": data.get("questions", []), + } + + # Check if it's an error passed through + if data.get("type") == "error": + return _create_error_response( + data.get("error", "Unknown error"), + data.get("error_type", "unknown"), + ) + + # Otherwise return the customized agent JSON + return data.get("agent_json") + + except httpx.HTTPStatusError as e: + error_type, error_msg = _classify_http_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except httpx.RequestError as e: + error_type, error_msg = _classify_request_error(e) + logger.error(error_msg) + return _create_error_response(error_msg, error_type) + except Exception as e: + error_msg = f"Unexpected error calling Agent Generator: 
{e}" + logger.error(error_msg) + return _create_error_response(error_msg, "unexpected_error") + + async def get_blocks_external() -> list[dict[str, Any]] | None: """Get available blocks from the external service. diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py b/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py index 62d59c470e..61cdba1ef9 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/agent_search.py @@ -206,9 +206,9 @@ async def search_agents( ] ) no_results_msg = ( - f"No agents found matching '{query}'. Try different keywords or browse the marketplace." + f"No agents found matching '{query}'. Let the user know they can try different keywords or browse the marketplace. Also let them know you can create a custom agent for them based on their needs." if source == "marketplace" - else f"No agents matching '{query}' found in your library." + else f"No agents matching '{query}' found in your library. Let the user know you can create a custom agent for them based on their needs." ) return NoResultsResponse( message=no_results_msg, session_id=session_id, suggestions=suggestions @@ -224,10 +224,10 @@ async def search_agents( message = ( "Now you have found some options for the user to choose from. " "You can add a link to a recommended agent at: /marketplace/agent/agent_id " - "Please ask the user if they would like to use any of these agents." + "Please ask the user if they would like to use any of these agents. Let the user know we can create a custom agent for them based on their needs." if source == "marketplace" else "Found agents in the user's library. You can provide a link to view an agent at: " - "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute." + "/library/agents/{agent_id}. Use agent_output to get execution results, or run_agent to execute. Let the user know we can create a custom agent for them based on their needs." 
) return AgentsFoundResponse( diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py index adb2c78fce..7333851a5b 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/create_agent.py @@ -18,6 +18,7 @@ from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -98,6 +99,10 @@ class CreateAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not description: return ErrorResponse( message="Please provide a description of what the agent should do.", @@ -219,7 +224,12 @@ class CreateAgentTool(BaseTool): logger.warning(f"Failed to enrich library agents from steps: {e}") try: - agent_json = await generate_agent(decomposition_result, library_agents) + agent_json = await generate_agent( + decomposition_result, + library_agents, + operation_id=operation_id, + task_id=task_id, + ) except AgentGeneratorNotConfiguredError: return ErrorResponse( message=( @@ -263,6 +273,19 @@ class CreateAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if agent_json.get("status") == "accepted": + logger.info( + f"Agent generation delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent generation started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + agent_name = agent_json.get("name", "Generated Agent") agent_description = agent_json.get("description", "") node_count = len(agent_json.get("nodes", [])) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py new file mode 100644 index 0000000000..c0568bd936 --- /dev/null +++ b/autogpt_platform/backend/backend/api/features/chat/tools/customize_agent.py @@ -0,0 +1,337 @@ +"""CustomizeAgentTool - Customizes marketplace/template agents using natural language.""" + +import logging +from typing import Any + +from backend.api.features.chat.model import ChatSession +from backend.api.features.store import db as store_db +from backend.api.features.store.exceptions import AgentNotFoundError + +from .agent_generator import ( + AgentGeneratorNotConfiguredError, + customize_template, + get_user_message_for_error, + graph_to_json, + save_agent_to_library, +) +from .base import BaseTool +from .models import ( + AgentPreviewResponse, + AgentSavedResponse, + ClarificationNeededResponse, + ClarifyingQuestion, + ErrorResponse, + ToolResponseBase, +) + +logger = logging.getLogger(__name__) + + +class CustomizeAgentTool(BaseTool): + """Tool for customizing marketplace/template agents using natural language.""" + + @property + def name(self) -> str: + return "customize_agent" + + @property + def description(self) -> str: + return ( + "Customize a marketplace or template agent using natural language. " + "Takes an existing agent from the marketplace and modifies it based on " + "the user's requirements before adding to their library." 
+ ) + + @property + def requires_auth(self) -> bool: + return True + + @property + def is_long_running(self) -> bool: + return True + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "agent_id": { + "type": "string", + "description": ( + "The marketplace agent ID in format 'creator/slug' " + "(e.g., 'autogpt/newsletter-writer'). " + "Get this from find_agent results." + ), + }, + "modifications": { + "type": "string", + "description": ( + "Natural language description of how to customize the agent. " + "Be specific about what changes you want to make." + ), + }, + "context": { + "type": "string", + "description": ( + "Additional context or answers to previous clarifying questions." + ), + }, + "save": { + "type": "boolean", + "description": ( + "Whether to save the customized agent to the user's library. " + "Default is true. Set to false for preview only." + ), + "default": True, + }, + }, + "required": ["agent_id", "modifications"], + } + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs, + ) -> ToolResponseBase: + """Execute the customize_agent tool. + + Flow: + 1. Parse the agent ID to get creator/slug + 2. Fetch the template agent from the marketplace + 3. Call customize_template with the modification request + 4. Preview or save based on the save parameter + """ + agent_id = kwargs.get("agent_id", "").strip() + modifications = kwargs.get("modifications", "").strip() + context = kwargs.get("context", "") + save = kwargs.get("save", True) + session_id = session.session_id if session else None + + if not agent_id: + return ErrorResponse( + message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').", + error="missing_agent_id", + session_id=session_id, + ) + + if not modifications: + return ErrorResponse( + message="Please describe how you want to customize this agent.", + error="missing_modifications", + session_id=session_id, + ) + + # Parse agent_id in format "creator/slug" + parts = [p.strip() for p in agent_id.split("/")] + if len(parts) != 2 or not parts[0] or not parts[1]: + return ErrorResponse( + message=( + f"Invalid agent ID format: '{agent_id}'. " + "Expected format is 'creator/agent-name' " + "(e.g., 'autogpt/newsletter-writer')." + ), + error="invalid_agent_id_format", + session_id=session_id, + ) + + creator_username, agent_slug = parts + + # Fetch the marketplace agent details + try: + agent_details = await store_db.get_store_agent_details( + username=creator_username, agent_name=agent_slug + ) + except AgentNotFoundError: + return ErrorResponse( + message=( + f"Could not find marketplace agent '{agent_id}'. " + "Please check the agent ID and try again." + ), + error="agent_not_found", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error fetching marketplace agent {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the marketplace agent. Please try again.", + error="fetch_error", + session_id=session_id, + ) + + if not agent_details.store_listing_version_id: + return ErrorResponse( + message=( + f"The agent '{agent_id}' does not have an available version. " + "Please try a different agent." 
+ ), + error="no_version_available", + session_id=session_id, + ) + + # Get the full agent graph + try: + graph = await store_db.get_agent(agent_details.store_listing_version_id) + template_agent = graph_to_json(graph) + except Exception as e: + logger.error(f"Error fetching agent graph for {agent_id}: {e}") + return ErrorResponse( + message="Failed to fetch the agent configuration. Please try again.", + error="graph_fetch_error", + session_id=session_id, + ) + + # Call customize_template + try: + result = await customize_template( + template_agent=template_agent, + modification_request=modifications, + context=context, + ) + except AgentGeneratorNotConfiguredError: + return ErrorResponse( + message=( + "Agent customization is not available. " + "The Agent Generator service is not configured." + ), + error="service_not_configured", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error calling customize_template for {agent_id}: {e}") + return ErrorResponse( + message=( + "Failed to customize the agent due to a service error. " + "Please try again." + ), + error="customization_service_error", + session_id=session_id, + ) + + if result is None: + return ErrorResponse( + message=( + "Failed to customize the agent. " + "The agent generation service may be unavailable or timed out. " + "Please try again." + ), + error="customization_failed", + session_id=session_id, + ) + + # Handle error response + if isinstance(result, dict) and result.get("type") == "error": + error_msg = result.get("error", "Unknown error") + error_type = result.get("error_type", "unknown") + user_message = get_user_message_for_error( + error_type, + operation="customize the agent", + llm_parse_message=( + "The AI had trouble customizing the agent. " + "Please try again or simplify your request." + ), + validation_message=( + "The customized agent failed validation. " + "Please try rephrasing your request." + ), + error_details=error_msg, + ) + return ErrorResponse( + message=user_message, + error=f"customization_failed:{error_type}", + session_id=session_id, + ) + + # Handle clarifying questions + if isinstance(result, dict) and result.get("type") == "clarifying_questions": + questions = result.get("questions") or [] + if not isinstance(questions, list): + logger.error( + f"Unexpected clarifying questions format: {type(questions)}" + ) + questions = [] + return ClarificationNeededResponse( + message=( + "I need some more information to customize this agent. 
" + "Please answer the following questions:" + ), + questions=[ + ClarifyingQuestion( + question=q.get("question", ""), + keyword=q.get("keyword", ""), + example=q.get("example"), + ) + for q in questions + if isinstance(q, dict) + ], + session_id=session_id, + ) + + # Result should be the customized agent JSON + if not isinstance(result, dict): + logger.error(f"Unexpected customize_template response type: {type(result)}") + return ErrorResponse( + message="Failed to customize the agent due to an unexpected response.", + error="unexpected_response_type", + session_id=session_id, + ) + + customized_agent = result + + agent_name = customized_agent.get( + "name", f"Customized {agent_details.agent_name}" + ) + agent_description = customized_agent.get("description", "") + nodes = customized_agent.get("nodes") + links = customized_agent.get("links") + node_count = len(nodes) if isinstance(nodes, list) else 0 + link_count = len(links) if isinstance(links, list) else 0 + + if not save: + return AgentPreviewResponse( + message=( + f"I've customized the agent '{agent_details.agent_name}'. " + f"The customized agent has {node_count} blocks. " + f"Review it and call customize_agent with save=true to save it." + ), + agent_json=customized_agent, + agent_name=agent_name, + description=agent_description, + node_count=node_count, + link_count=link_count, + session_id=session_id, + ) + + if not user_id: + return ErrorResponse( + message="You must be logged in to save agents.", + error="auth_required", + session_id=session_id, + ) + + # Save to user's library + try: + created_graph, library_agent = await save_agent_to_library( + customized_agent, user_id, is_update=False + ) + + return AgentSavedResponse( + message=( + f"Customized agent '{created_graph.name}' " + f"(based on '{agent_details.agent_name}') " + f"has been saved to your library!" + ), + agent_id=created_graph.id, + agent_name=created_graph.name, + library_agent_id=library_agent.id, + library_agent_link=f"/library/agents/{library_agent.id}", + agent_page_link=f"/build?flowID={created_graph.id}", + session_id=session_id, + ) + except Exception as e: + logger.error(f"Error saving customized agent: {e}") + return ErrorResponse( + message="Failed to save the customized agent. 
Please try again.", + error="save_failed", + session_id=session_id, + ) diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py index 2c2c48226b..3ae56407a7 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/edit_agent.py @@ -17,6 +17,7 @@ from .base import BaseTool from .models import ( AgentPreviewResponse, AgentSavedResponse, + AsyncProcessingResponse, ClarificationNeededResponse, ClarifyingQuestion, ErrorResponse, @@ -104,6 +105,10 @@ class EditAgentTool(BaseTool): save = kwargs.get("save", True) session_id = session.session_id if session else None + # Extract async processing params (passed by long-running tool handler) + operation_id = kwargs.get("_operation_id") + task_id = kwargs.get("_task_id") + if not agent_id: return ErrorResponse( message="Please provide the agent ID to edit.", @@ -149,7 +154,11 @@ class EditAgentTool(BaseTool): try: result = await generate_agent_patch( - update_request, current_agent, library_agents + update_request, + current_agent, + library_agents, + operation_id=operation_id, + task_id=task_id, ) except AgentGeneratorNotConfiguredError: return ErrorResponse( @@ -169,6 +178,20 @@ class EditAgentTool(BaseTool): session_id=session_id, ) + # Check if Agent Generator accepted for async processing + if result.get("status") == "accepted": + logger.info( + f"Agent edit delegated to async processing " + f"(operation_id={operation_id}, task_id={task_id})" + ) + return AsyncProcessingResponse( + message="Agent edit started. You'll be notified when it's complete.", + operation_id=operation_id, + task_id=task_id, + session_id=session_id, + ) + + # Check if the result is an error from the external service if isinstance(result, dict) and result.get("type") == "error": error_msg = result.get("error", "Unknown error") error_type = result.get("error_type", "unknown") diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/models.py b/autogpt_platform/backend/backend/api/features/chat/tools/models.py index 5ff8190c31..69c8c6c684 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/models.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/models.py @@ -372,11 +372,15 @@ class OperationStartedResponse(ToolResponseBase): This is returned immediately to the client while the operation continues to execute. The user can close the tab and check back later. + + The task_id can be used to reconnect to the SSE stream via + GET /chat/tasks/{task_id}/stream?last_idx=0 """ type: ResponseType = ResponseType.OPERATION_STARTED operation_id: str tool_name: str + task_id: str | None = None # For SSE reconnection class OperationPendingResponse(ToolResponseBase): @@ -400,3 +404,20 @@ class OperationInProgressResponse(ToolResponseBase): type: ResponseType = ResponseType.OPERATION_IN_PROGRESS tool_call_id: str + + +class AsyncProcessingResponse(ToolResponseBase): + """Response when an operation has been delegated to async processing. + + This is returned by tools when the external service accepts the request + for async processing (HTTP 202 Accepted). The Redis Streams completion + consumer will handle the result when the external service completes. + + The status field is specifically "accepted" to allow the long-running tool + handler to detect this response and skip LLM continuation. 
+ """ + + type: ResponseType = ResponseType.OPERATION_STARTED + status: str = "accepted" # Must be "accepted" for detection + operation_id: str | None = None + task_id: str | None = None diff --git a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py index 0046d0b249..cda0914809 100644 --- a/autogpt_platform/backend/backend/api/features/chat/tools/utils.py +++ b/autogpt_platform/backend/backend/api/features/chat/tools/utils.py @@ -6,9 +6,13 @@ from typing import Any from backend.api.features.library import db as library_db from backend.api.features.library import model as library_model from backend.api.features.store import db as store_db -from backend.data import graph as graph_db from backend.data.graph import GraphModel -from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput +from backend.data.model import ( + CredentialsFieldInfo, + CredentialsMetaInput, + HostScopedCredentials, + OAuth2Credentials, +) from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.util.exceptions import NotFoundError @@ -39,14 +43,8 @@ async def fetch_graph_from_store_slug( return None, None # Get the graph from store listing version - graph_meta = await store_db.get_available_graph( - store_agent.store_listing_version_id - ) - graph = await graph_db.get_graph( - graph_id=graph_meta.id, - version=graph_meta.version, - user_id=None, # Public access - include_subgraphs=True, + graph = await store_db.get_available_graph( + store_agent.store_listing_version_id, hide_nodes=False ) return graph, store_agent @@ -123,7 +121,7 @@ def build_missing_credentials_from_graph( return { field_key: _serialize_missing_credential(field_key, field_info) - for field_key, (field_info, _node_fields) in aggregated_fields.items() + for field_key, (field_info, _, _) in aggregated_fields.items() if field_key not in matched_keys } @@ -264,7 +262,8 @@ async def match_user_credentials_to_graph( # provider is in the set of acceptable providers. for credential_field_name, ( credential_requirements, - _node_fields, + _, + _, ) in aggregated_creds.items(): # Find first matching credential by provider, type, and scopes matching_cred = next( @@ -273,7 +272,14 @@ async def match_user_credentials_to_graph( for cred in available_creds if cred.provider in credential_requirements.provider and cred.type in credential_requirements.supported_types - and _credential_has_required_scopes(cred, credential_requirements) + and ( + cred.type != "oauth2" + or _credential_has_required_scopes(cred, credential_requirements) + ) + and ( + cred.type != "host_scoped" + or _credential_is_for_host(cred, credential_requirements) + ) ), None, ) @@ -318,19 +324,10 @@ async def match_user_credentials_to_graph( def _credential_has_required_scopes( - credential: Credentials, + credential: OAuth2Credentials, requirements: CredentialsFieldInfo, ) -> bool: - """ - Check if a credential has all the scopes required by the block. - - For OAuth2 credentials, verifies that the credential's scopes are a superset - of the required scopes. For other credential types, returns True (no scope check). 
- """ - # Only OAuth2 credentials have scopes to check - if credential.type != "oauth2": - return True - + """Check if an OAuth2 credential has all the scopes required by the input.""" # If no scopes are required, any credential matches if not requirements.required_scopes: return True @@ -339,6 +336,22 @@ def _credential_has_required_scopes( return set(credential.scopes).issuperset(requirements.required_scopes) +def _credential_is_for_host( + credential: HostScopedCredentials, + requirements: CredentialsFieldInfo, +) -> bool: + """Check if a host-scoped credential matches the host required by the input.""" + # We need to know the host to match host-scoped credentials to. + # Graph.aggregate_credentials_inputs() adds the node's set URL value (if any) + # to discriminator_values. No discriminator_values -> no host to match against. + if not requirements.discriminator_values: + return True + + # Check that credential host matches required host. + # Host-scoped credential inputs are grouped by host, so any item from the set works. + return credential.matches_url(list(requirements.discriminator_values)[0]) + + async def check_user_has_required_credentials( user_id: str, required_credentials: list[CredentialsMetaInput], diff --git a/autogpt_platform/backend/backend/api/features/library/db.py b/autogpt_platform/backend/backend/api/features/library/db.py index 394f959953..32479c18a3 100644 --- a/autogpt_platform/backend/backend/api/features/library/db.py +++ b/autogpt_platform/backend/backend/api/features/library/db.py @@ -19,7 +19,10 @@ from backend.data.graph import GraphSettings from backend.data.includes import AGENT_PRESET_INCLUDE, library_agent_include from backend.data.model import CredentialsMetaInput from backend.integrations.creds_manager import IntegrationCredentialsManager -from backend.integrations.webhooks.graph_lifecycle_hooks import on_graph_activate +from backend.integrations.webhooks.graph_lifecycle_hooks import ( + on_graph_activate, + on_graph_deactivate, +) from backend.util.clients import get_scheduler_client from backend.util.exceptions import DatabaseError, InvalidInputError, NotFoundError from backend.util.json import SafeJson @@ -371,7 +374,7 @@ async def get_library_agent_by_graph_id( async def add_generated_agent_image( - graph: graph_db.BaseGraph, + graph: graph_db.GraphBaseMeta, user_id: str, library_agent_id: str, ) -> Optional[prisma.models.LibraryAgent]: @@ -537,6 +540,92 @@ async def update_agent_version_in_library( return library_model.LibraryAgent.from_db(lib) +async def create_graph_in_library( + graph: graph_db.Graph, + user_id: str, +) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]: + """Create a new graph and add it to the user's library.""" + graph.version = 1 + graph_model = graph_db.make_graph_model(graph, user_id) + graph_model.reassign_ids(user_id=user_id, reassign_graph_id=True) + + created_graph = await graph_db.create_graph(graph_model, user_id) + + library_agents = await create_library_agent( + graph=created_graph, + user_id=user_id, + sensitive_action_safe_mode=True, + create_library_agents_for_sub_graphs=False, + ) + + if created_graph.is_active: + created_graph = await on_graph_activate(created_graph, user_id=user_id) + + return created_graph, library_agents[0] + + +async def update_graph_in_library( + graph: graph_db.Graph, + user_id: str, +) -> tuple[graph_db.GraphModel, library_model.LibraryAgent]: + """Create a new version of an existing graph and update the library entry.""" + existing_versions = await 
graph_db.get_graph_all_versions(graph.id, user_id) + current_active_version = ( + next((v for v in existing_versions if v.is_active), None) + if existing_versions + else None + ) + graph.version = ( + max(v.version for v in existing_versions) + 1 if existing_versions else 1 + ) + + graph_model = graph_db.make_graph_model(graph, user_id) + graph_model.reassign_ids(user_id=user_id, reassign_graph_id=False) + + created_graph = await graph_db.create_graph(graph_model, user_id) + + library_agent = await get_library_agent_by_graph_id(user_id, created_graph.id) + if not library_agent: + raise NotFoundError(f"Library agent not found for graph {created_graph.id}") + + library_agent = await update_library_agent_version_and_settings( + user_id, created_graph + ) + + if created_graph.is_active: + created_graph = await on_graph_activate(created_graph, user_id=user_id) + await graph_db.set_graph_active_version( + graph_id=created_graph.id, + version=created_graph.version, + user_id=user_id, + ) + if current_active_version: + await on_graph_deactivate(current_active_version, user_id=user_id) + + return created_graph, library_agent + + +async def update_library_agent_version_and_settings( + user_id: str, agent_graph: graph_db.GraphModel +) -> library_model.LibraryAgent: + """Update library agent to point to new graph version and sync settings.""" + library = await update_agent_version_in_library( + user_id, agent_graph.id, agent_graph.version + ) + updated_settings = GraphSettings.from_graph( + graph=agent_graph, + hitl_safe_mode=library.settings.human_in_the_loop_safe_mode, + sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode, + ) + if updated_settings != library.settings: + library = await update_library_agent( + library_agent_id=library.id, + user_id=user_id, + settings=updated_settings, + ) + return library + + async def update_library_agent( library_agent_id: str, user_id: str, diff --git a/autogpt_platform/backend/backend/api/features/store/db.py b/autogpt_platform/backend/backend/api/features/store/db.py index 850a2bc3e9..87b72d6a9c 100644 --- a/autogpt_platform/backend/backend/api/features/store/db.py +++ b/autogpt_platform/backend/backend/api/features/store/db.py @@ -1,7 +1,7 @@ import asyncio import logging from datetime import datetime, timezone -from typing import Any, Literal +from typing import Any, Literal, overload import fastapi import prisma.enums @@ -11,8 +11,8 @@ import prisma.types from backend.data.db import transaction from backend.data.graph import ( - GraphMeta, GraphModel, + GraphModelWithoutNodes, get_graph, get_graph_as_admin, get_sub_graphs, @@ -334,7 +334,22 @@ async def get_store_agent_details( raise DatabaseError("Failed to fetch agent details") from e -async def get_available_graph(store_listing_version_id: str) -> GraphMeta: +@overload +async def get_available_graph( + store_listing_version_id: str, hide_nodes: Literal[False] +) -> GraphModel: ... + + +@overload +async def get_available_graph( + store_listing_version_id: str, hide_nodes: Literal[True] = True +) -> GraphModelWithoutNodes: ... 
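+
+# Illustrative sketch of how callers are expected to pick an overload above
+# (assumed usage, not part of the change itself; the hide_nodes=False form is the
+# one used by the chat tools earlier in this diff, and the default form matches the
+# GraphModelWithoutNodes return type exposed by the store route):
+#
+#   meta_only  = await get_available_graph(store_listing_version_id)
+#   # -> GraphModelWithoutNodes
+#   full_graph = await get_available_graph(store_listing_version_id, hide_nodes=False)
+#   # -> GraphModel
+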
+ + +async def get_available_graph( + store_listing_version_id: str, + hide_nodes: bool = True, +) -> GraphModelWithoutNodes | GraphModel: try: # Get avaialble, non-deleted store listing version store_listing_version = ( @@ -344,7 +359,7 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta: "isAvailable": True, "isDeleted": False, }, - include={"AgentGraph": {"include": {"Nodes": True}}}, + include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}}, ) ) @@ -354,7 +369,9 @@ async def get_available_graph(store_listing_version_id: str) -> GraphMeta: detail=f"Store listing version {store_listing_version_id} not found", ) - return GraphModel.from_db(store_listing_version.AgentGraph).meta() + return (GraphModelWithoutNodes if hide_nodes else GraphModel).from_db( + store_listing_version.AgentGraph + ) except Exception as e: logger.error(f"Error getting agent: {e}") diff --git a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py index bae5b97cd6..86af457f50 100644 --- a/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py +++ b/autogpt_platform/backend/backend/api/features/store/embeddings_e2e_test.py @@ -454,6 +454,9 @@ async def test_unified_hybrid_search_pagination( cleanup_embeddings: list, ): """Test unified search pagination works correctly.""" + # Use a unique search term to avoid matching other test data + unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}" + # Create multiple items content_ids = [] for i in range(5): @@ -465,14 +468,14 @@ async def test_unified_hybrid_search_pagination( content_type=ContentType.BLOCK, content_id=content_id, embedding=mock_embedding, - searchable_text=f"pagination test item number {i}", + searchable_text=f"{unique_term} item number {i}", metadata={"index": i}, user_id=None, ) # Get first page page1_results, total1 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=1, page_size=2, @@ -480,7 +483,7 @@ async def test_unified_hybrid_search_pagination( # Get second page page2_results, total2 = await unified_hybrid_search( - query="pagination test", + query=unique_term, content_types=[ContentType.BLOCK], page=2, page_size=2, diff --git a/autogpt_platform/backend/backend/api/features/store/image_gen.py b/autogpt_platform/backend/backend/api/features/store/image_gen.py index 87b7b601df..087a7895ba 100644 --- a/autogpt_platform/backend/backend/api/features/store/image_gen.py +++ b/autogpt_platform/backend/backend/api/features/store/image_gen.py @@ -16,7 +16,7 @@ from backend.blocks.ideogram import ( StyleType, UpscaleOption, ) -from backend.data.graph import BaseGraph +from backend.data.graph import GraphBaseMeta from backend.data.model import CredentialsMetaInput, ProviderName from backend.integrations.credentials_store import ideogram_credentials from backend.util.request import Requests @@ -34,14 +34,14 @@ class ImageStyle(str, Enum): DIGITAL_ART = "digital art" -async def generate_agent_image(agent: BaseGraph | AgentGraph) -> io.BytesIO: +async def generate_agent_image(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO: if settings.config.use_agent_image_generation_v2: return await generate_agent_image_v2(graph=agent) else: return await generate_agent_image_v1(agent=agent) -async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO: +async def generate_agent_image_v2(graph: GraphBaseMeta | AgentGraph) -> io.BytesIO: """ Generate an 
image for an agent using Ideogram model. Returns: @@ -54,14 +54,17 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO: description = f"{name} ({graph.description})" if graph.description else name prompt = ( - f"Create a visually striking retro-futuristic vector pop art illustration prominently featuring " - f'"{name}" in bold typography. The image clearly and literally depicts a {description}, ' - f"along with recognizable objects directly associated with the primary function of a {name}. " - f"Ensure the imagery is concrete, intuitive, and immediately understandable, clearly conveying the " - f"purpose of a {name}. Maintain vibrant, limited-palette colors, sharp vector lines, geometric " - f"shapes, flat illustration techniques, and solid colors without gradients or shading. Preserve a " - f"retro-futuristic aesthetic influenced by mid-century futurism and 1960s psychedelia, " - f"prioritizing clear visual storytelling and thematic clarity above all else." + "Create a visually striking retro-futuristic vector pop art illustration " + f'prominently featuring "{name}" in bold typography. The image clearly and ' + f"literally depicts a {description}, along with recognizable objects directly " + f"associated with the primary function of a {name}. " + f"Ensure the imagery is concrete, intuitive, and immediately understandable, " + f"clearly conveying the purpose of a {name}. " + "Maintain vibrant, limited-palette colors, sharp vector lines, " + "geometric shapes, flat illustration techniques, and solid colors " + "without gradients or shading. Preserve a retro-futuristic aesthetic " + "influenced by mid-century futurism and 1960s psychedelia, " + "prioritizing clear visual storytelling and thematic clarity above all else." ) custom_colors = [ @@ -99,12 +102,12 @@ async def generate_agent_image_v2(graph: BaseGraph | AgentGraph) -> io.BytesIO: return io.BytesIO(response.content) -async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO: +async def generate_agent_image_v1(agent: GraphBaseMeta | AgentGraph) -> io.BytesIO: """ Generate an image for an agent using Flux model via Replicate API. Args: - agent (Graph): The agent to generate an image for + agent (GraphBaseMeta | AgentGraph): The agent to generate an image for Returns: io.BytesIO: The generated image as bytes @@ -114,7 +117,13 @@ async def generate_agent_image_v1(agent: BaseGraph | AgentGraph) -> io.BytesIO: raise ValueError("Missing Replicate API key in settings") # Construct prompt from agent details - prompt = f"Create a visually engaging app store thumbnail for the AI agent that highlights what it does in a clear and captivating way:\n- **Name**: {agent.name}\n- **Description**: {agent.description}\nFocus on showcasing its core functionality with an appealing design." + prompt = ( + "Create a visually engaging app store thumbnail for the AI agent " + "that highlights what it does in a clear and captivating way:\n" + f"- **Name**: {agent.name}\n" + f"- **Description**: {agent.description}\n" + f"Focus on showcasing its core functionality with an appealing design." 
+ ) # Set up Replicate client client = ReplicateClient(api_token=settings.secrets.replicate_api_key) diff --git a/autogpt_platform/backend/backend/api/features/store/routes.py b/autogpt_platform/backend/backend/api/features/store/routes.py index 2f3c7bfb04..d93fe60f15 100644 --- a/autogpt_platform/backend/backend/api/features/store/routes.py +++ b/autogpt_platform/backend/backend/api/features/store/routes.py @@ -278,7 +278,7 @@ async def get_agent( ) async def get_graph_meta_by_store_listing_version_id( store_listing_version_id: str, -) -> backend.data.graph.GraphMeta: +) -> backend.data.graph.GraphModelWithoutNodes: """ Get Agent Graph from Store Listing Version ID. """ diff --git a/autogpt_platform/backend/backend/api/features/v1.py b/autogpt_platform/backend/backend/api/features/v1.py index 09d3759a65..a8610702cc 100644 --- a/autogpt_platform/backend/backend/api/features/v1.py +++ b/autogpt_platform/backend/backend/api/features/v1.py @@ -101,7 +101,6 @@ from backend.util.timezone_utils import ( from backend.util.virus_scanner import scan_content_safe from .library import db as library_db -from .library import model as library_model from .store.model import StoreAgentDetails @@ -823,18 +822,16 @@ async def update_graph( graph: graph_db.Graph, user_id: Annotated[str, Security(get_user_id)], ) -> graph_db.GraphModel: - # Sanity check if graph.id and graph.id != graph_id: raise HTTPException(400, detail="Graph ID does not match ID in URI") - # Determine new version existing_versions = await graph_db.get_graph_all_versions(graph_id, user_id=user_id) if not existing_versions: raise HTTPException(404, detail=f"Graph #{graph_id} not found") - latest_version_number = max(g.version for g in existing_versions) - graph.version = latest_version_number + 1 + graph.version = max(g.version for g in existing_versions) + 1 current_active_version = next((v for v in existing_versions if v.is_active), None) + graph = graph_db.make_graph_model(graph, user_id) graph.reassign_ids(user_id=user_id, reassign_graph_id=False) graph.validate_graph(for_run=False) @@ -842,27 +839,23 @@ async def update_graph( new_graph_version = await graph_db.create_graph(graph, user_id=user_id) if new_graph_version.is_active: - # Keep the library agent up to date with the new active version - await _update_library_agent_version_and_settings(user_id, new_graph_version) - - # Handle activation of the new graph first to ensure continuity + await library_db.update_library_agent_version_and_settings( + user_id, new_graph_version + ) new_graph_version = await on_graph_activate(new_graph_version, user_id=user_id) - # Ensure new version is the only active version await graph_db.set_graph_active_version( graph_id=graph_id, version=new_graph_version.version, user_id=user_id ) if current_active_version: - # Handle deactivation of the previously active version await on_graph_deactivate(current_active_version, user_id=user_id) - # Fetch new graph version *with sub-graphs* (needed for credentials input schema) new_graph_version_with_subgraphs = await graph_db.get_graph( graph_id, new_graph_version.version, user_id=user_id, include_subgraphs=True, ) - assert new_graph_version_with_subgraphs # make type checker happy + assert new_graph_version_with_subgraphs return new_graph_version_with_subgraphs @@ -900,33 +893,15 @@ async def set_graph_active_version( ) # Keep the library agent up to date with the new active version - await _update_library_agent_version_and_settings(user_id, new_active_graph) + await 
library_db.update_library_agent_version_and_settings( + user_id, new_active_graph + ) if current_active_graph and current_active_graph.version != new_active_version: # Handle deactivation of the previously active version await on_graph_deactivate(current_active_graph, user_id=user_id) -async def _update_library_agent_version_and_settings( - user_id: str, agent_graph: graph_db.GraphModel -) -> library_model.LibraryAgent: - library = await library_db.update_agent_version_in_library( - user_id, agent_graph.id, agent_graph.version - ) - updated_settings = GraphSettings.from_graph( - graph=agent_graph, - hitl_safe_mode=library.settings.human_in_the_loop_safe_mode, - sensitive_action_safe_mode=library.settings.sensitive_action_safe_mode, - ) - if updated_settings != library.settings: - library = await library_db.update_library_agent( - library_agent_id=library.id, - user_id=user_id, - settings=updated_settings, - ) - return library - - @v1_router.patch( path="/graphs/{graph_id}/settings", summary="Update graph settings", diff --git a/autogpt_platform/backend/backend/api/rest_api.py b/autogpt_platform/backend/backend/api/rest_api.py index b936312ce1..0eef76193e 100644 --- a/autogpt_platform/backend/backend/api/rest_api.py +++ b/autogpt_platform/backend/backend/api/rest_api.py @@ -40,6 +40,10 @@ import backend.data.user import backend.integrations.webhooks.utils import backend.util.service import backend.util.settings +from backend.api.features.chat.completion_consumer import ( + start_completion_consumer, + stop_completion_consumer, +) from backend.blocks.llm import DEFAULT_LLM_MODEL from backend.data.model import Credentials from backend.integrations.providers import ProviderName @@ -118,9 +122,21 @@ async def lifespan_context(app: fastapi.FastAPI): await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL) await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs() + # Start chat completion consumer for Redis Streams notifications + try: + await start_completion_consumer() + except Exception as e: + logger.warning(f"Could not start chat completion consumer: {e}") + with launch_darkly_context(): yield + # Stop chat completion consumer + try: + await stop_completion_consumer() + except Exception as e: + logger.warning(f"Error stopping chat completion consumer: {e}") + try: await shutdown_cloud_storage_handler() except Exception as e: diff --git a/autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py b/autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py new file mode 100644 index 0000000000..b823627b43 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py @@ -0,0 +1,28 @@ +"""ElevenLabs integration blocks - test credentials and shared utilities.""" + +from typing import Literal + +from pydantic import SecretStr + +from backend.data.model import APIKeyCredentials, CredentialsMetaInput +from backend.integrations.providers import ProviderName + +TEST_CREDENTIALS = APIKeyCredentials( + id="01234567-89ab-cdef-0123-456789abcdef", + provider="elevenlabs", + api_key=SecretStr("mock-elevenlabs-api-key"), + title="Mock ElevenLabs API key", + expires_at=None, +) + +TEST_CREDENTIALS_INPUT = { + "provider": TEST_CREDENTIALS.provider, + "id": TEST_CREDENTIALS.id, + "type": TEST_CREDENTIALS.type, + "title": TEST_CREDENTIALS.title, +} + +ElevenLabsCredentials = APIKeyCredentials +ElevenLabsCredentialsInput = CredentialsMetaInput[ + Literal[ProviderName.ELEVENLABS], Literal["api_key"] +] diff --git 
a/autogpt_platform/backend/backend/blocks/encoder_block.py b/autogpt_platform/backend/backend/blocks/encoder_block.py new file mode 100644 index 0000000000..b60a4ae828 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/encoder_block.py @@ -0,0 +1,77 @@ +"""Text encoding block for converting special characters to escape sequences.""" + +import codecs + +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.model import SchemaField + + +class TextEncoderBlock(Block): + """ + Encodes a string by converting special characters into escape sequences. + + This block is the inverse of TextDecoderBlock. It takes text containing + special characters (like newlines, tabs, etc.) and converts them into + their escape sequence representations (e.g., newline becomes \\n). + """ + + class Input(BlockSchemaInput): + """Input schema for TextEncoderBlock.""" + + text: str = SchemaField( + description="A string containing special characters to be encoded", + placeholder="Your text with newlines and quotes to encode", + ) + + class Output(BlockSchemaOutput): + """Output schema for TextEncoderBlock.""" + + encoded_text: str = SchemaField( + description="The encoded text with special characters converted to escape sequences" + ) + error: str = SchemaField(description="Error message if encoding fails") + + def __init__(self): + super().__init__( + id="5185f32e-4b65-4ecf-8fbb-873f003f09d6", + description="Encodes a string by converting special characters into escape sequences", + categories={BlockCategory.TEXT}, + input_schema=TextEncoderBlock.Input, + output_schema=TextEncoderBlock.Output, + test_input={ + "text": """Hello +World! +This is a "quoted" string.""" + }, + test_output=[ + ( + "encoded_text", + """Hello\\nWorld!\\nThis is a "quoted" string.""", + ) + ], + ) + + async def run(self, input_data: Input, **kwargs) -> BlockOutput: + """ + Encode the input text by converting special characters to escape sequences. + + Args: + input_data: The input containing the text to encode. + **kwargs: Additional keyword arguments (unused). + + Yields: + The encoded text with escape sequences, or an error message if encoding fails. + """ + try: + encoded_text = codecs.encode(input_data.text, "unicode_escape").decode( + "utf-8" + ) + yield "encoded_text", encoded_text + except Exception as e: + yield "error", f"Encoding error: {str(e)}" diff --git a/autogpt_platform/backend/backend/blocks/linear/_api.py b/autogpt_platform/backend/backend/blocks/linear/_api.py index 477b8a209c..ea609d515a 100644 --- a/autogpt_platform/backend/backend/blocks/linear/_api.py +++ b/autogpt_platform/backend/backend/blocks/linear/_api.py @@ -162,8 +162,16 @@ class LinearClient: "searchTerm": team_name, } - team_id = await self.query(query, variables) - return team_id["teams"]["nodes"][0]["id"] + result = await self.query(query, variables) + nodes = result["teams"]["nodes"] + + if not nodes: + raise LinearAPIException( + f"Team '{team_name}' not found. Check the team name or key and try again.", + status_code=404, + ) + + return nodes[0]["id"] except LinearAPIException as e: raise e @@ -240,17 +248,44 @@ class LinearClient: except LinearAPIException as e: raise e - async def try_search_issues(self, term: str) -> list[Issue]: + async def try_search_issues( + self, + term: str, + max_results: int = 10, + team_id: str | None = None, + ) -> list[Issue]: try: query = """ - query SearchIssues($term: String!, $includeComments: Boolean!) 
{ - searchIssues(term: $term, includeComments: $includeComments) { + query SearchIssues( + $term: String!, + $first: Int, + $teamId: String + ) { + searchIssues( + term: $term, + first: $first, + teamId: $teamId + ) { nodes { id identifier title description priority + createdAt + state { + id + name + type + } + project { + id + name + } + assignee { + id + name + } } } } @@ -258,7 +293,8 @@ class LinearClient: variables: dict[str, Any] = { "term": term, - "includeComments": True, + "first": max_results, + "teamId": team_id, } issues = await self.query(query, variables) diff --git a/autogpt_platform/backend/backend/blocks/linear/issues.py b/autogpt_platform/backend/backend/blocks/linear/issues.py index baac01214c..165178f8ee 100644 --- a/autogpt_platform/backend/backend/blocks/linear/issues.py +++ b/autogpt_platform/backend/backend/blocks/linear/issues.py @@ -17,7 +17,7 @@ from ._config import ( LinearScope, linear, ) -from .models import CreateIssueResponse, Issue +from .models import CreateIssueResponse, Issue, State class LinearCreateIssueBlock(Block): @@ -135,9 +135,20 @@ class LinearSearchIssuesBlock(Block): description="Linear credentials with read permissions", required_scopes={LinearScope.READ}, ) + max_results: int = SchemaField( + description="Maximum number of results to return", + default=10, + ge=1, + le=100, + ) + team_name: str | None = SchemaField( + description="Optional team name to filter results (e.g., 'Internal', 'Open Source')", + default=None, + ) class Output(BlockSchemaOutput): issues: list[Issue] = SchemaField(description="List of issues") + error: str = SchemaField(description="Error message if the search failed") def __init__(self): super().__init__( @@ -145,8 +156,11 @@ class LinearSearchIssuesBlock(Block): description="Searches for issues on Linear", input_schema=self.Input, output_schema=self.Output, + categories={BlockCategory.PRODUCTIVITY, BlockCategory.ISSUE_TRACKING}, test_input={ "term": "Test issue", + "max_results": 10, + "team_name": None, "credentials": TEST_CREDENTIALS_INPUT_OAUTH, }, test_credentials=TEST_CREDENTIALS_OAUTH, @@ -156,10 +170,14 @@ class LinearSearchIssuesBlock(Block): [ Issue( id="abc123", - identifier="abc123", + identifier="TST-123", title="Test issue", description="Test description", priority=1, + state=State( + id="state1", name="In Progress", type="started" + ), + createdAt="2026-01-15T10:00:00.000Z", ) ], ) @@ -168,10 +186,12 @@ class LinearSearchIssuesBlock(Block): "search_issues": lambda *args, **kwargs: [ Issue( id="abc123", - identifier="abc123", + identifier="TST-123", title="Test issue", description="Test description", priority=1, + state=State(id="state1", name="In Progress", type="started"), + createdAt="2026-01-15T10:00:00.000Z", ) ] }, @@ -181,10 +201,22 @@ class LinearSearchIssuesBlock(Block): async def search_issues( credentials: OAuth2Credentials | APIKeyCredentials, term: str, + max_results: int = 10, + team_name: str | None = None, ) -> list[Issue]: client = LinearClient(credentials=credentials) - response: list[Issue] = await client.try_search_issues(term=term) - return response + + # Resolve team name to ID if provided + # Raises LinearAPIException with descriptive message if team not found + team_id: str | None = None + if team_name: + team_id = await client.try_get_team_by_name(team_name=team_name) + + return await client.try_search_issues( + term=term, + max_results=max_results, + team_id=team_id, + ) async def run( self, @@ -196,7 +228,10 @@ class LinearSearchIssuesBlock(Block): """Execute the issue 
search""" try: issues = await self.search_issues( - credentials=credentials, term=input_data.term + credentials=credentials, + term=input_data.term, + max_results=input_data.max_results, + team_name=input_data.team_name, ) yield "issues", issues except LinearAPIException as e: diff --git a/autogpt_platform/backend/backend/blocks/linear/models.py b/autogpt_platform/backend/backend/blocks/linear/models.py index bfeaa13656..dd1f603459 100644 --- a/autogpt_platform/backend/backend/blocks/linear/models.py +++ b/autogpt_platform/backend/backend/blocks/linear/models.py @@ -36,12 +36,21 @@ class Project(BaseModel): content: str | None = None +class State(BaseModel): + id: str + name: str + type: str | None = ( + None # Workflow state type (e.g., "triage", "backlog", "started", "completed", "canceled") + ) + + class Issue(BaseModel): id: str identifier: str title: str description: str | None priority: int + state: State | None = None project: Project | None = None createdAt: str | None = None comments: list[Comment] | None = None diff --git a/autogpt_platform/backend/backend/blocks/llm.py b/autogpt_platform/backend/backend/blocks/llm.py index 54295da1f1..be2b85949e 100644 --- a/autogpt_platform/backend/backend/blocks/llm.py +++ b/autogpt_platform/backend/backend/blocks/llm.py @@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta): CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101" CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929" CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001" + CLAUDE_4_6_OPUS = "claude-opus-4-6" CLAUDE_3_HAIKU = "claude-3-haiku-20240307" # AI/ML API models AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo" @@ -270,6 +271,9 @@ MODEL_METADATA = { LlmModel.CLAUDE_4_SONNET: ModelMetadata( "anthropic", 200000, 64000, "Claude Sonnet 4", "Anthropic", "Anthropic", 2 ), # claude-4-sonnet-20250514 + LlmModel.CLAUDE_4_6_OPUS: ModelMetadata( + "anthropic", 200000, 128000, "Claude Opus 4.6", "Anthropic", "Anthropic", 3 + ), # claude-opus-4-6 LlmModel.CLAUDE_4_5_OPUS: ModelMetadata( "anthropic", 200000, 64000, "Claude Opus 4.5", "Anthropic", "Anthropic", 3 ), # claude-opus-4-5-20251101 diff --git a/autogpt_platform/backend/backend/blocks/media.py b/autogpt_platform/backend/backend/blocks/media.py deleted file mode 100644 index a8d145bc64..0000000000 --- a/autogpt_platform/backend/backend/blocks/media.py +++ /dev/null @@ -1,246 +0,0 @@ -import os -import tempfile -from typing import Optional - -from moviepy.audio.io.AudioFileClip import AudioFileClip -from moviepy.video.fx.Loop import Loop -from moviepy.video.io.VideoFileClip import VideoFileClip - -from backend.data.block import ( - Block, - BlockCategory, - BlockOutput, - BlockSchemaInput, - BlockSchemaOutput, -) -from backend.data.execution import ExecutionContext -from backend.data.model import SchemaField -from backend.util.file import MediaFileType, get_exec_file_path, store_media_file - - -class MediaDurationBlock(Block): - - class Input(BlockSchemaInput): - media_in: MediaFileType = SchemaField( - description="Media input (URL, data URI, or local path)." - ) - is_video: bool = SchemaField( - description="Whether the media is a video (True) or audio (False).", - default=True, - ) - - class Output(BlockSchemaOutput): - duration: float = SchemaField( - description="Duration of the media file (in seconds)." 
- ) - - def __init__(self): - super().__init__( - id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6", - description="Block to get the duration of a media file.", - categories={BlockCategory.MULTIMEDIA}, - input_schema=MediaDurationBlock.Input, - output_schema=MediaDurationBlock.Output, - ) - - async def run( - self, - input_data: Input, - *, - execution_context: ExecutionContext, - **kwargs, - ) -> BlockOutput: - # 1) Store the input media locally - local_media_path = await store_media_file( - file=input_data.media_in, - execution_context=execution_context, - return_format="for_local_processing", - ) - assert execution_context.graph_exec_id is not None - media_abspath = get_exec_file_path( - execution_context.graph_exec_id, local_media_path - ) - - # 2) Load the clip - if input_data.is_video: - clip = VideoFileClip(media_abspath) - else: - clip = AudioFileClip(media_abspath) - - yield "duration", clip.duration - - -class LoopVideoBlock(Block): - """ - Block for looping (repeating) a video clip until a given duration or number of loops. - """ - - class Input(BlockSchemaInput): - video_in: MediaFileType = SchemaField( - description="The input video (can be a URL, data URI, or local path)." - ) - # Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`. - duration: Optional[float] = SchemaField( - description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.", - default=None, - ge=0.0, - ) - n_loops: Optional[int] = SchemaField( - description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).", - default=None, - ge=1, - ) - - class Output(BlockSchemaOutput): - video_out: str = SchemaField( - description="Looped video returned either as a relative path or a data URI." - ) - - def __init__(self): - super().__init__( - id="8bf9eef6-5451-4213-b265-25306446e94b", - description="Block to loop a video to a given duration or number of repeats.", - categories={BlockCategory.MULTIMEDIA}, - input_schema=LoopVideoBlock.Input, - output_schema=LoopVideoBlock.Output, - ) - - async def run( - self, - input_data: Input, - *, - execution_context: ExecutionContext, - **kwargs, - ) -> BlockOutput: - assert execution_context.graph_exec_id is not None - assert execution_context.node_exec_id is not None - graph_exec_id = execution_context.graph_exec_id - node_exec_id = execution_context.node_exec_id - - # 1) Store the input video locally - local_video_path = await store_media_file( - file=input_data.video_in, - execution_context=execution_context, - return_format="for_local_processing", - ) - input_abspath = get_exec_file_path(graph_exec_id, local_video_path) - - # 2) Load the clip - clip = VideoFileClip(input_abspath) - - # 3) Apply the loop effect - looped_clip = clip - if input_data.duration: - # Loop until we reach the specified duration - looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)]) - elif input_data.n_loops: - looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)]) - else: - raise ValueError("Either 'duration' or 'n_loops' must be provided.") - - assert isinstance(looped_clip, VideoFileClip) - - # 4) Save the looped output - output_filename = MediaFileType( - f"{node_exec_id}_looped_{os.path.basename(local_video_path)}" - ) - output_abspath = get_exec_file_path(graph_exec_id, output_filename) - - looped_clip = looped_clip.with_audio(clip.audio) - looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac") - - # Return output - for_block_output returns workspace:// if 
available, else data URI - video_out = await store_media_file( - file=output_filename, - execution_context=execution_context, - return_format="for_block_output", - ) - - yield "video_out", video_out - - -class AddAudioToVideoBlock(Block): - """ - Block that adds (attaches) an audio track to an existing video. - Optionally scale the volume of the new track. - """ - - class Input(BlockSchemaInput): - video_in: MediaFileType = SchemaField( - description="Video input (URL, data URI, or local path)." - ) - audio_in: MediaFileType = SchemaField( - description="Audio input (URL, data URI, or local path)." - ) - volume: float = SchemaField( - description="Volume scale for the newly attached audio track (1.0 = original).", - default=1.0, - ) - - class Output(BlockSchemaOutput): - video_out: MediaFileType = SchemaField( - description="Final video (with attached audio), as a path or data URI." - ) - - def __init__(self): - super().__init__( - id="3503748d-62b6-4425-91d6-725b064af509", - description="Block to attach an audio file to a video file using moviepy.", - categories={BlockCategory.MULTIMEDIA}, - input_schema=AddAudioToVideoBlock.Input, - output_schema=AddAudioToVideoBlock.Output, - ) - - async def run( - self, - input_data: Input, - *, - execution_context: ExecutionContext, - **kwargs, - ) -> BlockOutput: - assert execution_context.graph_exec_id is not None - assert execution_context.node_exec_id is not None - graph_exec_id = execution_context.graph_exec_id - node_exec_id = execution_context.node_exec_id - - # 1) Store the inputs locally - local_video_path = await store_media_file( - file=input_data.video_in, - execution_context=execution_context, - return_format="for_local_processing", - ) - local_audio_path = await store_media_file( - file=input_data.audio_in, - execution_context=execution_context, - return_format="for_local_processing", - ) - - abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id) - video_abspath = os.path.join(abs_temp_dir, local_video_path) - audio_abspath = os.path.join(abs_temp_dir, local_audio_path) - - # 2) Load video + audio with moviepy - video_clip = VideoFileClip(video_abspath) - audio_clip = AudioFileClip(audio_abspath) - # Optionally scale volume - if input_data.volume != 1.0: - audio_clip = audio_clip.with_volume_scaled(input_data.volume) - - # 3) Attach the new audio track - final_clip = video_clip.with_audio(audio_clip) - - # 4) Write to output file - output_filename = MediaFileType( - f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}" - ) - output_abspath = os.path.join(abs_temp_dir, output_filename) - final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac") - - # 5) Return output - for_block_output returns workspace:// if available, else data URI - video_out = await store_media_file( - file=output_filename, - execution_context=execution_context, - return_format="for_block_output", - ) - - yield "video_out", video_out diff --git a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py index 4d5d6bf4f3..91c096ffe4 100644 --- a/autogpt_platform/backend/backend/blocks/stagehand/blocks.py +++ b/autogpt_platform/backend/backend/blocks/stagehand/blocks.py @@ -182,10 +182,7 @@ class StagehandObserveBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: 
{model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -282,10 +279,7 @@ class StagehandActBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"ACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( @@ -370,10 +364,7 @@ class StagehandExtractBlock(Block): **kwargs, ) -> BlockOutput: - logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}") - logger.info( - f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}" - ) + logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}") with disable_signal_handling(): stagehand = Stagehand( diff --git a/autogpt_platform/backend/backend/blocks/test/test_text_encoder.py b/autogpt_platform/backend/backend/blocks/test/test_text_encoder.py new file mode 100644 index 0000000000..1e9b9fed4f --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/test/test_text_encoder.py @@ -0,0 +1,77 @@ +import pytest + +from backend.blocks.encoder_block import TextEncoderBlock + + +@pytest.mark.asyncio +async def test_text_encoder_basic(): + """Test basic encoding of newlines and special characters.""" + block = TextEncoderBlock() + result = [] + async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")): + result.append(output) + + assert len(result) == 1 + assert result[0][0] == "encoded_text" + assert result[0][1] == "Hello\\nWorld" + + +@pytest.mark.asyncio +async def test_text_encoder_multiple_escapes(): + """Test encoding of multiple escape sequences.""" + block = TextEncoderBlock() + result = [] + async for output in block.run( + TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage") + ): + result.append(output) + + assert len(result) == 1 + assert result[0][0] == "encoded_text" + assert "\\n" in result[0][1] + assert "\\t" in result[0][1] + assert "\\r" in result[0][1] + + +@pytest.mark.asyncio +async def test_text_encoder_unicode(): + """Test that unicode characters are handled correctly.""" + block = TextEncoderBlock() + result = [] + async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")): + result.append(output) + + assert len(result) == 1 + assert result[0][0] == "encoded_text" + # Unicode characters should be escaped as \uXXXX sequences + assert "\\n" in result[0][1] + + +@pytest.mark.asyncio +async def test_text_encoder_empty_string(): + """Test encoding of an empty string.""" + block = TextEncoderBlock() + result = [] + async for output in block.run(TextEncoderBlock.Input(text="")): + result.append(output) + + assert len(result) == 1 + assert result[0][0] == "encoded_text" + assert result[0][1] == "" + + +@pytest.mark.asyncio +async def test_text_encoder_error_handling(): + """Test that encoding errors are handled gracefully.""" + from unittest.mock import patch + + block = TextEncoderBlock() + result = [] + + with patch("codecs.encode", side_effect=Exception("Mocked encoding error")): + async for output in block.run(TextEncoderBlock.Input(text="test")): + result.append(output) + + assert len(result) == 1 + assert result[0][0] == "error" + assert "Mocked encoding 
error" in result[0][1] diff --git a/autogpt_platform/backend/backend/blocks/video/__init__.py b/autogpt_platform/backend/backend/blocks/video/__init__.py new file mode 100644 index 0000000000..4974ae8a87 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/__init__.py @@ -0,0 +1,37 @@ +"""Video editing blocks for AutoGPT Platform. + +This module provides blocks for: +- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links) +- Clipping/trimming video segments +- Concatenating multiple videos +- Adding text overlays +- Adding AI-generated narration +- Getting media duration +- Looping videos +- Adding audio to videos + +Dependencies: +- yt-dlp: For video downloading +- moviepy: For video editing operations +- elevenlabs: For AI narration (optional) +""" + +from backend.blocks.video.add_audio import AddAudioToVideoBlock +from backend.blocks.video.clip import VideoClipBlock +from backend.blocks.video.concat import VideoConcatBlock +from backend.blocks.video.download import VideoDownloadBlock +from backend.blocks.video.duration import MediaDurationBlock +from backend.blocks.video.loop import LoopVideoBlock +from backend.blocks.video.narration import VideoNarrationBlock +from backend.blocks.video.text_overlay import VideoTextOverlayBlock + +__all__ = [ + "AddAudioToVideoBlock", + "LoopVideoBlock", + "MediaDurationBlock", + "VideoClipBlock", + "VideoConcatBlock", + "VideoDownloadBlock", + "VideoNarrationBlock", + "VideoTextOverlayBlock", +] diff --git a/autogpt_platform/backend/backend/blocks/video/_utils.py b/autogpt_platform/backend/backend/blocks/video/_utils.py new file mode 100644 index 0000000000..9ebf195078 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/_utils.py @@ -0,0 +1,131 @@ +"""Shared utilities for video blocks.""" + +from __future__ import annotations + +import logging +import os +import re +import subprocess +from pathlib import Path + +logger = logging.getLogger(__name__) + +# Known operation tags added by video blocks +_VIDEO_OPS = ( + r"(?:clip|overlay|narrated|looped|concat|audio_attached|with_audio|narration)" +) + +# Matches: {node_exec_id}_{operation}_ where node_exec_id contains a UUID +_BLOCK_PREFIX_RE = re.compile( + r"^[a-zA-Z0-9_-]*" + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" + r"[a-zA-Z0-9_-]*" + r"_" + _VIDEO_OPS + r"_" +) + +# Matches: a lone {node_exec_id}_ prefix (no operation keyword, e.g. download output) +_UUID_PREFIX_RE = re.compile( + r"^[a-zA-Z0-9_-]*" + r"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}" + r"[a-zA-Z0-9_-]*_" +) + + +def extract_source_name(input_path: str, max_length: int = 50) -> str: + """Extract the original source filename by stripping block-generated prefixes. + + Iteratively removes {node_exec_id}_{operation}_ prefixes that accumulate + when chaining video blocks, recovering the original human-readable name. + + Safe for plain filenames (no UUID -> no stripping). + Falls back to "video" if everything is stripped. + """ + stem = Path(input_path).stem + + # Pass 1: strip {node_exec_id}_{operation}_ prefixes iteratively + while _BLOCK_PREFIX_RE.match(stem): + stem = _BLOCK_PREFIX_RE.sub("", stem, count=1) + + # Pass 2: strip a lone {node_exec_id}_ prefix (e.g. 
from download block) + if _UUID_PREFIX_RE.match(stem): + stem = _UUID_PREFIX_RE.sub("", stem, count=1) + + if not stem: + return "video" + + return stem[:max_length] + + +def get_video_codecs(output_path: str) -> tuple[str, str]: + """Get appropriate video and audio codecs based on output file extension. + + Args: + output_path: Path to the output file (used to determine extension) + + Returns: + Tuple of (video_codec, audio_codec) + + Codec mappings: + - .mp4: H.264 + AAC (universal compatibility) + - .webm: VP8 + Vorbis (web streaming) + - .mkv: H.264 + AAC (container supports many codecs) + - .mov: H.264 + AAC (Apple QuickTime, widely compatible) + - .m4v: H.264 + AAC (Apple iTunes/devices) + - .avi: MPEG-4 + MP3 (legacy Windows) + """ + ext = os.path.splitext(output_path)[1].lower() + + codec_map: dict[str, tuple[str, str]] = { + ".mp4": ("libx264", "aac"), + ".webm": ("libvpx", "libvorbis"), + ".mkv": ("libx264", "aac"), + ".mov": ("libx264", "aac"), + ".m4v": ("libx264", "aac"), + ".avi": ("mpeg4", "libmp3lame"), + } + + return codec_map.get(ext, ("libx264", "aac")) + + +def strip_chapters_inplace(video_path: str) -> None: + """Strip chapter metadata from a media file in-place using ffmpeg. + + MoviePy 2.x crashes with IndexError when parsing files with embedded + chapter metadata (https://github.com/Zulko/moviepy/issues/2419). + This strips chapters without re-encoding. + + Args: + video_path: Absolute path to the media file to strip chapters from. + """ + base, ext = os.path.splitext(video_path) + tmp_path = base + ".tmp" + ext + try: + result = subprocess.run( + [ + "ffmpeg", + "-y", + "-i", + video_path, + "-map_chapters", + "-1", + "-codec", + "copy", + tmp_path, + ], + capture_output=True, + text=True, + timeout=300, + ) + if result.returncode != 0: + logger.warning( + "ffmpeg chapter strip failed (rc=%d): %s", + result.returncode, + result.stderr, + ) + return + os.replace(tmp_path, video_path) + except FileNotFoundError: + logger.warning("ffmpeg not found; skipping chapter strip") + finally: + if os.path.exists(tmp_path): + os.unlink(tmp_path) diff --git a/autogpt_platform/backend/backend/blocks/video/add_audio.py b/autogpt_platform/backend/backend/blocks/video/add_audio.py new file mode 100644 index 0000000000..ebd4ab94f2 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/add_audio.py @@ -0,0 +1,113 @@ +"""AddAudioToVideoBlock - Attach an audio track to a video file.""" + +from moviepy.audio.io.AudioFileClip import AudioFileClip +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import SchemaField +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class AddAudioToVideoBlock(Block): + """Add (attach) an audio track to an existing video.""" + + class Input(BlockSchemaInput): + video_in: MediaFileType = SchemaField( + description="Video input (URL, data URI, or local path)." + ) + audio_in: MediaFileType = SchemaField( + description="Audio input (URL, data URI, or local path)." 
+ ) + volume: float = SchemaField( + description="Volume scale for the newly attached audio track (1.0 = original).", + default=1.0, + ) + + class Output(BlockSchemaOutput): + video_out: MediaFileType = SchemaField( + description="Final video (with attached audio), as a path or data URI." + ) + + def __init__(self): + super().__init__( + id="3503748d-62b6-4425-91d6-725b064af509", + description="Block to attach an audio file to a video file using moviepy.", + categories={BlockCategory.MULTIMEDIA}, + input_schema=AddAudioToVideoBlock.Input, + output_schema=AddAudioToVideoBlock.Output, + ) + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, + ) -> BlockOutput: + assert execution_context.graph_exec_id is not None + assert execution_context.node_exec_id is not None + graph_exec_id = execution_context.graph_exec_id + node_exec_id = execution_context.node_exec_id + + # 1) Store the inputs locally + local_video_path = await store_media_file( + file=input_data.video_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + local_audio_path = await store_media_file( + file=input_data.audio_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + + video_abspath = get_exec_file_path(graph_exec_id, local_video_path) + audio_abspath = get_exec_file_path(graph_exec_id, local_audio_path) + + # 2) Load video + audio with moviepy + strip_chapters_inplace(video_abspath) + strip_chapters_inplace(audio_abspath) + video_clip = None + audio_clip = None + final_clip = None + try: + video_clip = VideoFileClip(video_abspath) + audio_clip = AudioFileClip(audio_abspath) + # Optionally scale volume + if input_data.volume != 1.0: + audio_clip = audio_clip.with_volume_scaled(input_data.volume) + + # 3) Attach the new audio track + final_clip = video_clip.with_audio(audio_clip) + + # 4) Write to output file + source = extract_source_name(local_video_path) + output_filename = MediaFileType(f"{node_exec_id}_with_audio_{source}.mp4") + output_abspath = get_exec_file_path(graph_exec_id, output_filename) + final_clip.write_videofile( + output_abspath, codec="libx264", audio_codec="aac" + ) + finally: + if final_clip: + final_clip.close() + if audio_clip: + audio_clip.close() + if video_clip: + video_clip.close() + + # 5) Return output - for_block_output returns workspace:// if available, else data URI + video_out = await store_media_file( + file=output_filename, + execution_context=execution_context, + return_format="for_block_output", + ) + + yield "video_out", video_out diff --git a/autogpt_platform/backend/backend/blocks/video/clip.py b/autogpt_platform/backend/backend/blocks/video/clip.py new file mode 100644 index 0000000000..05deea6530 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/clip.py @@ -0,0 +1,167 @@ +"""VideoClipBlock - Extract a segment from a video file.""" + +from typing import Literal + +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.blocks.video._utils import ( + extract_source_name, + get_video_codecs, + strip_chapters_inplace, +) +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import SchemaField +from backend.util.exceptions import BlockExecutionError +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class VideoClipBlock(Block): + """Extract a time segment from a 
video.""" + + class Input(BlockSchemaInput): + video_in: MediaFileType = SchemaField( + description="Input video (URL, data URI, or local path)" + ) + start_time: float = SchemaField(description="Start time in seconds", ge=0.0) + end_time: float = SchemaField(description="End time in seconds", ge=0.0) + output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField( + description="Output format", default="mp4", advanced=True + ) + + class Output(BlockSchemaOutput): + video_out: MediaFileType = SchemaField( + description="Clipped video file (path or data URI)" + ) + duration: float = SchemaField(description="Clip duration in seconds") + + def __init__(self): + super().__init__( + id="8f539119-e580-4d86-ad41-86fbcb22abb1", + description="Extract a time segment from a video", + categories={BlockCategory.MULTIMEDIA}, + input_schema=self.Input, + output_schema=self.Output, + test_input={ + "video_in": "/tmp/test.mp4", + "start_time": 0.0, + "end_time": 10.0, + }, + test_output=[("video_out", str), ("duration", float)], + test_mock={ + "_clip_video": lambda *args: 10.0, + "_store_input_video": lambda *args, **kwargs: "test.mp4", + "_store_output_video": lambda *args, **kwargs: "clip_test.mp4", + }, + ) + + async def _store_input_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store input video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_local_processing", + ) + + async def _store_output_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store output video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_block_output", + ) + + def _clip_video( + self, + video_abspath: str, + output_abspath: str, + start_time: float, + end_time: float, + ) -> float: + """Extract a clip from a video. 
Extracted for testability.""" + clip = None + subclip = None + try: + strip_chapters_inplace(video_abspath) + clip = VideoFileClip(video_abspath) + subclip = clip.subclipped(start_time, end_time) + video_codec, audio_codec = get_video_codecs(output_abspath) + subclip.write_videofile( + output_abspath, codec=video_codec, audio_codec=audio_codec + ) + return subclip.duration + finally: + if subclip: + subclip.close() + if clip: + clip.close() + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + node_exec_id: str, + **kwargs, + ) -> BlockOutput: + # Validate time range + if input_data.end_time <= input_data.start_time: + raise BlockExecutionError( + message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})", + block_name=self.name, + block_id=str(self.id), + ) + + try: + assert execution_context.graph_exec_id is not None + + # Store the input video locally + local_video_path = await self._store_input_video( + execution_context, input_data.video_in + ) + video_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_video_path + ) + + # Build output path + source = extract_source_name(local_video_path) + output_filename = MediaFileType( + f"{node_exec_id}_clip_{source}.{input_data.output_format}" + ) + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) + + duration = self._clip_video( + video_abspath, + output_abspath, + input_data.start_time, + input_data.end_time, + ) + + # Return as workspace path or data URI based on context + video_out = await self._store_output_video( + execution_context, output_filename + ) + + yield "video_out", video_out + yield "duration", duration + + except BlockExecutionError: + raise + except Exception as e: + raise BlockExecutionError( + message=f"Failed to clip video: {e}", + block_name=self.name, + block_id=str(self.id), + ) from e diff --git a/autogpt_platform/backend/backend/blocks/video/concat.py b/autogpt_platform/backend/backend/blocks/video/concat.py new file mode 100644 index 0000000000..b49854fb40 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/concat.py @@ -0,0 +1,227 @@ +"""VideoConcatBlock - Concatenate multiple video clips into one.""" + +from typing import Literal + +from moviepy import concatenate_videoclips +from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.blocks.video._utils import ( + extract_source_name, + get_video_codecs, + strip_chapters_inplace, +) +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import SchemaField +from backend.util.exceptions import BlockExecutionError +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class VideoConcatBlock(Block): + """Merge multiple video clips into one continuous video.""" + + class Input(BlockSchemaInput): + videos: list[MediaFileType] = SchemaField( + description="List of video files to concatenate (in order)" + ) + transition: Literal["none", "crossfade", "fade_black"] = SchemaField( + description="Transition between clips", default="none" + ) + transition_duration: int = SchemaField( + description="Transition duration in seconds", + default=1, + ge=0, + advanced=True, + ) + output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField( + description="Output 
format", default="mp4", advanced=True + ) + + class Output(BlockSchemaOutput): + video_out: MediaFileType = SchemaField( + description="Concatenated video file (path or data URI)" + ) + total_duration: float = SchemaField(description="Total duration in seconds") + + def __init__(self): + super().__init__( + id="9b0f531a-1118-487f-aeec-3fa63ea8900a", + description="Merge multiple video clips into one continuous video", + categories={BlockCategory.MULTIMEDIA}, + input_schema=self.Input, + output_schema=self.Output, + test_input={ + "videos": ["/tmp/a.mp4", "/tmp/b.mp4"], + }, + test_output=[ + ("video_out", str), + ("total_duration", float), + ], + test_mock={ + "_concat_videos": lambda *args: 20.0, + "_store_input_video": lambda *args, **kwargs: "test.mp4", + "_store_output_video": lambda *args, **kwargs: "concat_test.mp4", + }, + ) + + async def _store_input_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store input video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_local_processing", + ) + + async def _store_output_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store output video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_block_output", + ) + + def _concat_videos( + self, + video_abspaths: list[str], + output_abspath: str, + transition: str, + transition_duration: int, + ) -> float: + """Concatenate videos. Extracted for testability. + + Returns: + Total duration of the concatenated video. + """ + clips = [] + faded_clips = [] + final = None + try: + # Load clips + for v in video_abspaths: + strip_chapters_inplace(v) + clips.append(VideoFileClip(v)) + + # Validate transition_duration against shortest clip + if transition in {"crossfade", "fade_black"} and transition_duration > 0: + min_duration = min(c.duration for c in clips) + if transition_duration >= min_duration: + raise BlockExecutionError( + message=( + f"transition_duration ({transition_duration}s) must be " + f"shorter than the shortest clip ({min_duration:.2f}s)" + ), + block_name=self.name, + block_id=str(self.id), + ) + + if transition == "crossfade": + for i, clip in enumerate(clips): + effects = [] + if i > 0: + effects.append(CrossFadeIn(transition_duration)) + if i < len(clips) - 1: + effects.append(CrossFadeOut(transition_duration)) + if effects: + clip = clip.with_effects(effects) + faded_clips.append(clip) + final = concatenate_videoclips( + faded_clips, + method="compose", + padding=-transition_duration, + ) + elif transition == "fade_black": + for clip in clips: + faded = clip.with_effects( + [FadeIn(transition_duration), FadeOut(transition_duration)] + ) + faded_clips.append(faded) + final = concatenate_videoclips(faded_clips) + else: + final = concatenate_videoclips(clips) + + video_codec, audio_codec = get_video_codecs(output_abspath) + final.write_videofile( + output_abspath, codec=video_codec, audio_codec=audio_codec + ) + + return final.duration + finally: + if final: + final.close() + for clip in faded_clips: + clip.close() + for clip in clips: + clip.close() + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + node_exec_id: str, + **kwargs, + ) -> BlockOutput: + # Validate minimum clips + if len(input_data.videos) < 2: + raise BlockExecutionError( + message="At least 2 videos are 
required for concatenation", + block_name=self.name, + block_id=str(self.id), + ) + + try: + assert execution_context.graph_exec_id is not None + + # Store all input videos locally + video_abspaths = [] + for video in input_data.videos: + local_path = await self._store_input_video(execution_context, video) + video_abspaths.append( + get_exec_file_path(execution_context.graph_exec_id, local_path) + ) + + # Build output path + source = ( + extract_source_name(video_abspaths[0]) if video_abspaths else "video" + ) + output_filename = MediaFileType( + f"{node_exec_id}_concat_{source}.{input_data.output_format}" + ) + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) + + total_duration = self._concat_videos( + video_abspaths, + output_abspath, + input_data.transition, + input_data.transition_duration, + ) + + # Return as workspace path or data URI based on context + video_out = await self._store_output_video( + execution_context, output_filename + ) + + yield "video_out", video_out + yield "total_duration", total_duration + + except BlockExecutionError: + raise + except Exception as e: + raise BlockExecutionError( + message=f"Failed to concatenate videos: {e}", + block_name=self.name, + block_id=str(self.id), + ) from e diff --git a/autogpt_platform/backend/backend/blocks/video/download.py b/autogpt_platform/backend/backend/blocks/video/download.py new file mode 100644 index 0000000000..4046d5df42 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/download.py @@ -0,0 +1,172 @@ +"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links).""" + +import os +import typing +from typing import Literal + +import yt_dlp + +if typing.TYPE_CHECKING: + from yt_dlp import _Params + +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import SchemaField +from backend.util.exceptions import BlockExecutionError +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class VideoDownloadBlock(Block): + """Download video from URL using yt-dlp.""" + + class Input(BlockSchemaInput): + url: str = SchemaField( + description="URL of the video to download (YouTube, Vimeo, direct link, etc.)", + placeholder="https://www.youtube.com/watch?v=...", + ) + quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField( + description="Video quality preference", default="720p" + ) + output_format: Literal["mp4", "webm", "mkv"] = SchemaField( + description="Output video format", default="mp4", advanced=True + ) + + class Output(BlockSchemaOutput): + video_file: MediaFileType = SchemaField( + description="Downloaded video (path or data URI)" + ) + duration: float = SchemaField(description="Video duration in seconds") + title: str = SchemaField(description="Video title from source") + source_url: str = SchemaField(description="Original source URL") + + def __init__(self): + super().__init__( + id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4", + description="Download video from URL (YouTube, Vimeo, news sites, direct links)", + categories={BlockCategory.MULTIMEDIA}, + input_schema=self.Input, + output_schema=self.Output, + disabled=True, # Disable until we can sandbox yt-dlp and handle security implications + test_input={ + "url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ", + "quality": "480p", + }, + test_output=[ + ("video_file", str), + ("duration", float), 
+ ("title", str), + ("source_url", str), + ], + test_mock={ + "_download_video": lambda *args: ( + "video.mp4", + 212.0, + "Test Video", + ), + "_store_output_video": lambda *args, **kwargs: "video.mp4", + }, + ) + + async def _store_output_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store output video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_block_output", + ) + + def _get_format_string(self, quality: str) -> str: + formats = { + "best": "bestvideo+bestaudio/best", + "1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]", + "720p": "bestvideo[height<=720]+bestaudio/best[height<=720]", + "480p": "bestvideo[height<=480]+bestaudio/best[height<=480]", + "audio_only": "bestaudio/best", + } + return formats.get(quality, formats["720p"]) + + def _download_video( + self, + url: str, + quality: str, + output_format: str, + output_dir: str, + node_exec_id: str, + ) -> tuple[str, float, str]: + """Download video. Extracted for testability.""" + output_template = os.path.join( + output_dir, f"{node_exec_id}_%(title).50s.%(ext)s" + ) + + ydl_opts: "_Params" = { + "format": f"{self._get_format_string(quality)}/best", + "outtmpl": output_template, + "merge_output_format": output_format, + "quiet": True, + "no_warnings": True, + } + + with yt_dlp.YoutubeDL(ydl_opts) as ydl: + info = ydl.extract_info(url, download=True) + video_path = ydl.prepare_filename(info) + + # Handle format conversion in filename + if not video_path.endswith(f".{output_format}"): + video_path = video_path.rsplit(".", 1)[0] + f".{output_format}" + + # Return just the filename, not the full path + filename = os.path.basename(video_path) + + return ( + filename, + info.get("duration") or 0.0, + info.get("title") or "Unknown", + ) + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + node_exec_id: str, + **kwargs, + ) -> BlockOutput: + try: + assert execution_context.graph_exec_id is not None + + # Get the exec file directory + output_dir = get_exec_file_path(execution_context.graph_exec_id, "") + os.makedirs(output_dir, exist_ok=True) + + filename, duration, title = self._download_video( + input_data.url, + input_data.quality, + input_data.output_format, + output_dir, + node_exec_id, + ) + + # Return as workspace path or data URI based on context + video_out = await self._store_output_video( + execution_context, MediaFileType(filename) + ) + + yield "video_file", video_out + yield "duration", duration + yield "title", title + yield "source_url", input_data.url + + except Exception as e: + raise BlockExecutionError( + message=f"Failed to download video: {e}", + block_name=self.name, + block_id=str(self.id), + ) from e diff --git a/autogpt_platform/backend/backend/blocks/video/duration.py b/autogpt_platform/backend/backend/blocks/video/duration.py new file mode 100644 index 0000000000..9e05d35b00 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/duration.py @@ -0,0 +1,77 @@ +"""MediaDurationBlock - Get the duration of a media file.""" + +from moviepy.audio.io.AudioFileClip import AudioFileClip +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.blocks.video._utils import strip_chapters_inplace +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import 
SchemaField +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class MediaDurationBlock(Block): + """Get the duration of a media file (video or audio).""" + + class Input(BlockSchemaInput): + media_in: MediaFileType = SchemaField( + description="Media input (URL, data URI, or local path)." + ) + is_video: bool = SchemaField( + description="Whether the media is a video (True) or audio (False).", + default=True, + ) + + class Output(BlockSchemaOutput): + duration: float = SchemaField( + description="Duration of the media file (in seconds)." + ) + + def __init__(self): + super().__init__( + id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6", + description="Block to get the duration of a media file.", + categories={BlockCategory.MULTIMEDIA}, + input_schema=MediaDurationBlock.Input, + output_schema=MediaDurationBlock.Output, + ) + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, + ) -> BlockOutput: + # 1) Store the input media locally + local_media_path = await store_media_file( + file=input_data.media_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + assert execution_context.graph_exec_id is not None + media_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_media_path + ) + + # 2) Strip chapters to avoid MoviePy crash, then load the clip + strip_chapters_inplace(media_abspath) + clip = None + try: + if input_data.is_video: + clip = VideoFileClip(media_abspath) + else: + clip = AudioFileClip(media_abspath) + + duration = clip.duration + finally: + if clip: + clip.close() + + yield "duration", duration diff --git a/autogpt_platform/backend/backend/blocks/video/loop.py b/autogpt_platform/backend/backend/blocks/video/loop.py new file mode 100644 index 0000000000..461610f713 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/loop.py @@ -0,0 +1,115 @@ +"""LoopVideoBlock - Loop a video to a given duration or number of repeats.""" + +from typing import Optional + +from moviepy.video.fx.Loop import Loop +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.blocks.video._utils import extract_source_name, strip_chapters_inplace +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import SchemaField +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class LoopVideoBlock(Block): + """Loop (repeat) a video clip until a given duration or number of loops.""" + + class Input(BlockSchemaInput): + video_in: MediaFileType = SchemaField( + description="The input video (can be a URL, data URI, or local path)." + ) + duration: Optional[float] = SchemaField( + description="Target duration (in seconds) to loop the video to. Either duration or n_loops must be provided.", + default=None, + ge=0.0, + le=3600.0, # Max 1 hour to prevent disk exhaustion + ) + n_loops: Optional[int] = SchemaField( + description="Number of times to repeat the video. Either n_loops or duration must be provided.", + default=None, + ge=1, + le=10, # Max 10 loops to prevent disk exhaustion + ) + + class Output(BlockSchemaOutput): + video_out: MediaFileType = SchemaField( + description="Looped video returned either as a relative path or a data URI." 
+ ) + + def __init__(self): + super().__init__( + id="8bf9eef6-5451-4213-b265-25306446e94b", + description="Block to loop a video to a given duration or number of repeats.", + categories={BlockCategory.MULTIMEDIA}, + input_schema=LoopVideoBlock.Input, + output_schema=LoopVideoBlock.Output, + ) + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + **kwargs, + ) -> BlockOutput: + assert execution_context.graph_exec_id is not None + assert execution_context.node_exec_id is not None + graph_exec_id = execution_context.graph_exec_id + node_exec_id = execution_context.node_exec_id + + # 1) Store the input video locally + local_video_path = await store_media_file( + file=input_data.video_in, + execution_context=execution_context, + return_format="for_local_processing", + ) + input_abspath = get_exec_file_path(graph_exec_id, local_video_path) + + # 2) Load the clip + strip_chapters_inplace(input_abspath) + clip = None + looped_clip = None + try: + clip = VideoFileClip(input_abspath) + + # 3) Apply the loop effect + if input_data.duration: + # Loop until we reach the specified duration + looped_clip = clip.with_effects([Loop(duration=input_data.duration)]) + elif input_data.n_loops: + looped_clip = clip.with_effects([Loop(n=input_data.n_loops)]) + else: + raise ValueError("Either 'duration' or 'n_loops' must be provided.") + + assert isinstance(looped_clip, VideoFileClip) + + # 4) Save the looped output + source = extract_source_name(local_video_path) + output_filename = MediaFileType(f"{node_exec_id}_looped_{source}.mp4") + output_abspath = get_exec_file_path(graph_exec_id, output_filename) + + looped_clip = looped_clip.with_audio(clip.audio) + looped_clip.write_videofile( + output_abspath, codec="libx264", audio_codec="aac" + ) + finally: + if looped_clip: + looped_clip.close() + if clip: + clip.close() + + # Return output - for_block_output returns workspace:// if available, else data URI + video_out = await store_media_file( + file=output_filename, + execution_context=execution_context, + return_format="for_block_output", + ) + + yield "video_out", video_out diff --git a/autogpt_platform/backend/backend/blocks/video/narration.py b/autogpt_platform/backend/backend/blocks/video/narration.py new file mode 100644 index 0000000000..adf41753c8 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/narration.py @@ -0,0 +1,267 @@ +"""VideoNarrationBlock - Generate AI voice narration and add to video.""" + +import os +from typing import Literal + +from elevenlabs import ElevenLabs +from moviepy import CompositeAudioClip +from moviepy.audio.io.AudioFileClip import AudioFileClip +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.blocks.elevenlabs._auth import ( + TEST_CREDENTIALS, + TEST_CREDENTIALS_INPUT, + ElevenLabsCredentials, + ElevenLabsCredentialsInput, +) +from backend.blocks.video._utils import ( + extract_source_name, + get_video_codecs, + strip_chapters_inplace, +) +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import CredentialsField, SchemaField +from backend.util.exceptions import BlockExecutionError +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class VideoNarrationBlock(Block): + """Generate AI narration and add to video.""" + + class Input(BlockSchemaInput): + credentials: ElevenLabsCredentialsInput = CredentialsField( + 
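The Loop effect used above takes either a target duration or a repeat count; the block caps both (3600 s, 10 loops) to bound disk usage and re-attaches the source clip's audio before writing. A minimal sketch of the two modes, with illustrative file names:

# Sketch of the two looping modes used by LoopVideoBlock.
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

clip = VideoFileClip("intro.mp4")
by_duration = clip.with_effects([Loop(duration=30)])  # repeat until ~30s total
by_count = clip.with_effects([Loop(n=3)])             # play the clip 3 times

# Mirror the block: re-attach the source audio, then export.
out = by_count.with_audio(clip.audio)
out.write_videofile("intro_x3.mp4", codec="libx264", audio_codec="aac")
for c in (out, by_count, by_duration, clip):
    c.close()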
description="ElevenLabs API key for voice synthesis" + ) + video_in: MediaFileType = SchemaField( + description="Input video (URL, data URI, or local path)" + ) + script: str = SchemaField(description="Narration script text") + voice_id: str = SchemaField( + description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel + ) + model_id: Literal[ + "eleven_multilingual_v2", + "eleven_flash_v2_5", + "eleven_turbo_v2_5", + "eleven_turbo_v2", + ] = SchemaField( + description="ElevenLabs TTS model", + default="eleven_multilingual_v2", + ) + mix_mode: Literal["replace", "mix", "ducking"] = SchemaField( + description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.", + default="ducking", + ) + narration_volume: float = SchemaField( + description="Narration volume (0.0 to 2.0)", + default=1.0, + ge=0.0, + le=2.0, + advanced=True, + ) + original_volume: float = SchemaField( + description="Original audio volume when mixing (0.0 to 1.0)", + default=0.3, + ge=0.0, + le=1.0, + advanced=True, + ) + + class Output(BlockSchemaOutput): + video_out: MediaFileType = SchemaField( + description="Video with narration (path or data URI)" + ) + audio_file: MediaFileType = SchemaField( + description="Generated audio file (path or data URI)" + ) + + def __init__(self): + super().__init__( + id="3d036b53-859c-4b17-9826-ca340f736e0e", + description="Generate AI narration and add to video", + categories={BlockCategory.MULTIMEDIA, BlockCategory.AI}, + input_schema=self.Input, + output_schema=self.Output, + test_input={ + "video_in": "/tmp/test.mp4", + "script": "Hello world", + "credentials": TEST_CREDENTIALS_INPUT, + }, + test_credentials=TEST_CREDENTIALS, + test_output=[("video_out", str), ("audio_file", str)], + test_mock={ + "_generate_narration_audio": lambda *args: b"mock audio content", + "_add_narration_to_video": lambda *args: None, + "_store_input_video": lambda *args, **kwargs: "test.mp4", + "_store_output_video": lambda *args, **kwargs: "narrated_test.mp4", + }, + ) + + async def _store_input_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store input video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_local_processing", + ) + + async def _store_output_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store output video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_block_output", + ) + + def _generate_narration_audio( + self, api_key: str, script: str, voice_id: str, model_id: str + ) -> bytes: + """Generate narration audio via ElevenLabs API.""" + client = ElevenLabs(api_key=api_key) + audio_generator = client.text_to_speech.convert( + voice_id=voice_id, + text=script, + model_id=model_id, + ) + # The SDK returns a generator, collect all chunks + return b"".join(audio_generator) + + def _add_narration_to_video( + self, + video_abspath: str, + audio_abspath: str, + output_abspath: str, + mix_mode: str, + narration_volume: float, + original_volume: float, + ) -> None: + """Add narration audio to video. 
Extracted for testability.""" + video = None + final = None + narration_original = None + narration_scaled = None + original = None + + try: + strip_chapters_inplace(video_abspath) + video = VideoFileClip(video_abspath) + narration_original = AudioFileClip(audio_abspath) + narration_scaled = narration_original.with_volume_scaled(narration_volume) + narration = narration_scaled + + if mix_mode == "replace": + final_audio = narration + elif mix_mode == "mix": + if video.audio: + original = video.audio.with_volume_scaled(original_volume) + final_audio = CompositeAudioClip([original, narration]) + else: + final_audio = narration + else: # ducking - apply stronger attenuation + if video.audio: + # Ducking uses a much lower volume for original audio + ducking_volume = original_volume * 0.3 + original = video.audio.with_volume_scaled(ducking_volume) + final_audio = CompositeAudioClip([original, narration]) + else: + final_audio = narration + + final = video.with_audio(final_audio) + video_codec, audio_codec = get_video_codecs(output_abspath) + final.write_videofile( + output_abspath, codec=video_codec, audio_codec=audio_codec + ) + + finally: + if original: + original.close() + if narration_scaled: + narration_scaled.close() + if narration_original: + narration_original.close() + if final: + final.close() + if video: + video.close() + + async def run( + self, + input_data: Input, + *, + credentials: ElevenLabsCredentials, + execution_context: ExecutionContext, + node_exec_id: str, + **kwargs, + ) -> BlockOutput: + try: + assert execution_context.graph_exec_id is not None + + # Store the input video locally + local_video_path = await self._store_input_video( + execution_context, input_data.video_in + ) + video_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_video_path + ) + + # Generate narration audio via ElevenLabs + audio_content = self._generate_narration_audio( + credentials.api_key.get_secret_value(), + input_data.script, + input_data.voice_id, + input_data.model_id, + ) + + # Save audio to exec file path + audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3") + audio_abspath = get_exec_file_path( + execution_context.graph_exec_id, audio_filename + ) + os.makedirs(os.path.dirname(audio_abspath), exist_ok=True) + with open(audio_abspath, "wb") as f: + f.write(audio_content) + + # Add narration to video + source = extract_source_name(local_video_path) + output_filename = MediaFileType(f"{node_exec_id}_narrated_{source}.mp4") + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) + + self._add_narration_to_video( + video_abspath, + audio_abspath, + output_abspath, + input_data.mix_mode, + input_data.narration_volume, + input_data.original_volume, + ) + + # Return as workspace path or data URI based on context + video_out = await self._store_output_video( + execution_context, output_filename + ) + audio_out = await self._store_output_video( + execution_context, audio_filename + ) + + yield "video_out", video_out + yield "audio_file", audio_out + + except Exception as e: + raise BlockExecutionError( + message=f"Failed to add narration: {e}", + block_name=self.name, + block_id=str(self.id), + ) from e diff --git a/autogpt_platform/backend/backend/blocks/video/text_overlay.py b/autogpt_platform/backend/backend/blocks/video/text_overlay.py new file mode 100644 index 0000000000..cb7cfe0420 --- /dev/null +++ b/autogpt_platform/backend/backend/blocks/video/text_overlay.py @@ -0,0 +1,231 @@ +"""VideoTextOverlayBlock - Add text 
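The 'ducking' mode above is effectively the 'mix' path with the original track attenuated further (original_volume scaled by an extra 0.3) before it is layered under the narration with CompositeAudioClip. A minimal sketch of that mix, assuming the source video has an audio track and using illustrative file names:

# Sketch of the ducking mix: scale the original audio down, then composite it under the narration.
from moviepy import CompositeAudioClip
from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

video = VideoFileClip("clip.mp4")
narration = AudioFileClip("narration.mp3").with_volume_scaled(1.0)  # narration_volume
ducked_original = video.audio.with_volume_scaled(0.3 * 0.3)         # original_volume * ducking factor

final = video.with_audio(CompositeAudioClip([ducked_original, narration]))
final.write_videofile("narrated.mp4", codec="libx264", audio_codec="aac")
for c in (final, video, narration):
    c.close()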
overlay to video.""" + +from typing import Literal + +from moviepy import CompositeVideoClip, TextClip +from moviepy.video.io.VideoFileClip import VideoFileClip + +from backend.blocks.video._utils import ( + extract_source_name, + get_video_codecs, + strip_chapters_inplace, +) +from backend.data.block import ( + Block, + BlockCategory, + BlockOutput, + BlockSchemaInput, + BlockSchemaOutput, +) +from backend.data.execution import ExecutionContext +from backend.data.model import SchemaField +from backend.util.exceptions import BlockExecutionError +from backend.util.file import MediaFileType, get_exec_file_path, store_media_file + + +class VideoTextOverlayBlock(Block): + """Add text overlay/caption to video.""" + + class Input(BlockSchemaInput): + video_in: MediaFileType = SchemaField( + description="Input video (URL, data URI, or local path)" + ) + text: str = SchemaField(description="Text to overlay on video") + position: Literal[ + "top", + "center", + "bottom", + "top-left", + "top-right", + "bottom-left", + "bottom-right", + ] = SchemaField(description="Position of text on screen", default="bottom") + start_time: float | None = SchemaField( + description="When to show text (seconds). None = entire video", + default=None, + advanced=True, + ) + end_time: float | None = SchemaField( + description="When to hide text (seconds). None = until end", + default=None, + advanced=True, + ) + font_size: int = SchemaField( + description="Font size", default=48, ge=12, le=200, advanced=True + ) + font_color: str = SchemaField( + description="Font color (hex or name)", default="white", advanced=True + ) + bg_color: str | None = SchemaField( + description="Background color behind text (None for transparent)", + default=None, + advanced=True, + ) + + class Output(BlockSchemaOutput): + video_out: MediaFileType = SchemaField( + description="Video with text overlay (path or data URI)" + ) + + def __init__(self): + super().__init__( + id="8ef14de6-cc90-430a-8cfa-3a003be92454", + description="Add text overlay/caption to video", + categories={BlockCategory.MULTIMEDIA}, + input_schema=self.Input, + output_schema=self.Output, + disabled=True, # Disable until we can lockdown imagemagick security policy + test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"}, + test_output=[("video_out", str)], + test_mock={ + "_add_text_overlay": lambda *args: None, + "_store_input_video": lambda *args, **kwargs: "test.mp4", + "_store_output_video": lambda *args, **kwargs: "overlay_test.mp4", + }, + ) + + async def _store_input_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store input video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_local_processing", + ) + + async def _store_output_video( + self, execution_context: ExecutionContext, file: MediaFileType + ) -> MediaFileType: + """Store output video. Extracted for testability.""" + return await store_media_file( + file=file, + execution_context=execution_context, + return_format="for_block_output", + ) + + def _add_text_overlay( + self, + video_abspath: str, + output_abspath: str, + text: str, + position: str, + start_time: float | None, + end_time: float | None, + font_size: int, + font_color: str, + bg_color: str | None, + ) -> None: + """Add text overlay to video. 
Extracted for testability.""" + video = None + final = None + txt_clip = None + try: + strip_chapters_inplace(video_abspath) + video = VideoFileClip(video_abspath) + + txt_clip = TextClip( + text=text, + font_size=font_size, + color=font_color, + bg_color=bg_color, + ) + + # Position mapping + pos_map = { + "top": ("center", "top"), + "center": ("center", "center"), + "bottom": ("center", "bottom"), + "top-left": ("left", "top"), + "top-right": ("right", "top"), + "bottom-left": ("left", "bottom"), + "bottom-right": ("right", "bottom"), + } + + txt_clip = txt_clip.with_position(pos_map[position]) + + # Set timing + start = start_time or 0 + end = end_time or video.duration + duration = max(0, end - start) + txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration) + + final = CompositeVideoClip([video, txt_clip]) + video_codec, audio_codec = get_video_codecs(output_abspath) + final.write_videofile( + output_abspath, codec=video_codec, audio_codec=audio_codec + ) + + finally: + if txt_clip: + txt_clip.close() + if final: + final.close() + if video: + video.close() + + async def run( + self, + input_data: Input, + *, + execution_context: ExecutionContext, + node_exec_id: str, + **kwargs, + ) -> BlockOutput: + # Validate time range if both are provided + if ( + input_data.start_time is not None + and input_data.end_time is not None + and input_data.end_time <= input_data.start_time + ): + raise BlockExecutionError( + message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})", + block_name=self.name, + block_id=str(self.id), + ) + + try: + assert execution_context.graph_exec_id is not None + + # Store the input video locally + local_video_path = await self._store_input_video( + execution_context, input_data.video_in + ) + video_abspath = get_exec_file_path( + execution_context.graph_exec_id, local_video_path + ) + + # Build output path + source = extract_source_name(local_video_path) + output_filename = MediaFileType(f"{node_exec_id}_overlay_{source}.mp4") + output_abspath = get_exec_file_path( + execution_context.graph_exec_id, output_filename + ) + + self._add_text_overlay( + video_abspath, + output_abspath, + input_data.text, + input_data.position, + input_data.start_time, + input_data.end_time, + input_data.font_size, + input_data.font_color, + input_data.bg_color, + ) + + # Return as workspace path or data URI based on context + video_out = await self._store_output_video( + execution_context, output_filename + ) + + yield "video_out", video_out + + except BlockExecutionError: + raise + except Exception as e: + raise BlockExecutionError( + message=f"Failed to add text overlay: {e}", + block_name=self.name, + block_id=str(self.id), + ) from e diff --git a/autogpt_platform/backend/backend/blocks/youtube.py b/autogpt_platform/backend/backend/blocks/youtube.py index e79be3e99b..6d81a86b4c 100644 --- a/autogpt_platform/backend/backend/blocks/youtube.py +++ b/autogpt_platform/backend/backend/blocks/youtube.py @@ -165,10 +165,13 @@ class TranscribeYoutubeVideoBlock(Block): credentials: WebshareProxyCredentials, **kwargs, ) -> BlockOutput: - video_id = self.extract_video_id(input_data.youtube_url) - yield "video_id", video_id + try: + video_id = self.extract_video_id(input_data.youtube_url) + transcript = self.get_transcript(video_id, credentials) + transcript_text = self.format_transcript(transcript=transcript) - transcript = self.get_transcript(video_id, credentials) - transcript_text = self.format_transcript(transcript=transcript) - - 
yield "transcript", transcript_text + # Only yield after all operations succeed + yield "video_id", video_id + yield "transcript", transcript_text + except Exception as e: + yield "error", str(e) diff --git a/autogpt_platform/backend/backend/data/block.py b/autogpt_platform/backend/backend/data/block.py index 8d9ecfff4c..f67134ceb3 100644 --- a/autogpt_platform/backend/backend/data/block.py +++ b/autogpt_platform/backend/backend/data/block.py @@ -246,7 +246,9 @@ class BlockSchema(BaseModel): f"is not of type {CredentialsMetaInput.__name__}" ) - credentials_fields[field_name].validate_credentials_field_schema(cls) + CredentialsMetaInput.validate_credentials_field_schema( + cls.get_field_schema(field_name), field_name + ) elif field_name in credentials_fields: raise KeyError( @@ -873,14 +875,13 @@ def is_block_auth_configured( async def initialize_blocks() -> None: - # First, sync all provider costs to blocks - # Imported here to avoid circular import from backend.sdk.cost_integration import sync_all_provider_costs + from backend.util.retry import func_retry sync_all_provider_costs() - for cls in get_blocks().values(): - block = cls() + @func_retry + async def sync_block_to_db(block: Block) -> None: existing_block = await AgentBlock.prisma().find_first( where={"OR": [{"id": block.id}, {"name": block.name}]} ) @@ -893,7 +894,7 @@ async def initialize_blocks() -> None: outputSchema=json.dumps(block.output_schema.jsonschema()), ) ) - continue + return input_schema = json.dumps(block.input_schema.jsonschema()) output_schema = json.dumps(block.output_schema.jsonschema()) @@ -913,6 +914,25 @@ async def initialize_blocks() -> None: }, ) + failed_blocks: list[str] = [] + for cls in get_blocks().values(): + block = cls() + try: + await sync_block_to_db(block) + except Exception as e: + logger.warning( + f"Failed to sync block {block.name} to database: {e}. " + "Block is still available in memory.", + exc_info=True, + ) + failed_blocks.append(block.name) + + if failed_blocks: + logger.error( + f"Failed to sync {len(failed_blocks)} block(s) to database: " + f"{', '.join(failed_blocks)}. These blocks are still available in memory." 
+ ) + # Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281 def get_block(block_id: str) -> AnyBlockSchema | None: diff --git a/autogpt_platform/backend/backend/data/block_cost_config.py b/autogpt_platform/backend/backend/data/block_cost_config.py index f46cc726f0..ec35afa401 100644 --- a/autogpt_platform/backend/backend/data/block_cost_config.py +++ b/autogpt_platform/backend/backend/data/block_cost_config.py @@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock +from backend.blocks.video.narration import VideoNarrationBlock from backend.data.block import Block, BlockCost, BlockCostType from backend.integrations.credentials_store import ( aiml_api_credentials, anthropic_credentials, apollo_credentials, did_credentials, + elevenlabs_credentials, enrichlayer_credentials, groq_credentials, ideogram_credentials, @@ -78,6 +80,7 @@ MODEL_COST: dict[LlmModel, int] = { LlmModel.CLAUDE_4_1_OPUS: 21, LlmModel.CLAUDE_4_OPUS: 21, LlmModel.CLAUDE_4_SONNET: 5, + LlmModel.CLAUDE_4_6_OPUS: 14, LlmModel.CLAUDE_4_5_HAIKU: 4, LlmModel.CLAUDE_4_5_OPUS: 14, LlmModel.CLAUDE_4_5_SONNET: 9, @@ -639,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = { }, ), ], + VideoNarrationBlock: [ + BlockCost( + cost_amount=5, # ElevenLabs TTS cost + cost_filter={ + "credentials": { + "id": elevenlabs_credentials.id, + "provider": elevenlabs_credentials.provider, + "type": elevenlabs_credentials.type, + } + }, + ) + ], } diff --git a/autogpt_platform/backend/backend/data/credit_test.py b/autogpt_platform/backend/backend/data/credit_test.py index 391a373b86..2b10c62882 100644 --- a/autogpt_platform/backend/backend/data/credit_test.py +++ b/autogpt_platform/backend/backend/data/credit_test.py @@ -134,6 +134,16 @@ async def test_block_credit_reset(server: SpinTestServer): month1 = datetime.now(timezone.utc).replace(month=1, day=1) user_credit.time_now = lambda: month1 + # IMPORTANT: Set updatedAt to December of previous year to ensure it's + # in a different month than month1 (January). This fixes a timing bug + # where if the test runs in early February, 35 days ago would be January, + # matching the mocked month1 and preventing the refill from triggering. 
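To make the hazard in the comment above concrete: with the 35-day offset the comment refers to, a run in early-to-mid February yields an updatedAt that already falls in January, the same month as the mocked month1, so the monthly refill is skipped. A small worked example of the date arithmetic (dates illustrative, and assuming the refill check compares the month of updatedAt against the mocked clock):

# Worked example of the flaky-date scenario described above (illustrative values).
from datetime import datetime, timedelta, timezone

month1 = datetime(2025, 1, 1, tzinfo=timezone.utc)              # mocked "now" in the test
real_run_date = datetime(2025, 2, 10, tzinfo=timezone.utc)      # when the test actually runs

stale_updated_at = real_run_date - timedelta(days=35)           # 2025-01-06, still January
assert stale_updated_at.month == month1.month                   # refill would be skipped

pinned_updated_at = month1.replace(year=month1.year - 1, month=12, day=15)  # 2024-12-15
assert pinned_updated_at.month != month1.month                  # refill reliably triggers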
+ dec_previous_year = month1.replace(year=month1.year - 1, month=12, day=15) + await UserBalance.prisma().update( + where={"userId": DEFAULT_USER_ID}, + data={"updatedAt": dec_previous_year}, + ) + # First call in month 1 should trigger refill balance = await user_credit.get_credits(DEFAULT_USER_ID) assert balance == REFILL_VALUE # Should get 1000 credits diff --git a/autogpt_platform/backend/backend/data/graph.py b/autogpt_platform/backend/backend/data/graph.py index ee6cd2e4b0..0dc3eea887 100644 --- a/autogpt_platform/backend/backend/data/graph.py +++ b/autogpt_platform/backend/backend/data/graph.py @@ -3,7 +3,7 @@ import logging import uuid from collections import defaultdict from datetime import datetime, timezone -from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, cast +from typing import TYPE_CHECKING, Annotated, Any, Literal, Optional, Self, cast from prisma.enums import SubmissionStatus from prisma.models import ( @@ -20,7 +20,7 @@ from prisma.types import ( AgentNodeLinkCreateInput, StoreListingVersionWhereInput, ) -from pydantic import BaseModel, BeforeValidator, Field, create_model +from pydantic import BaseModel, BeforeValidator, Field from pydantic.fields import computed_field from backend.blocks.agent import AgentExecutorBlock @@ -30,7 +30,6 @@ from backend.data.db import prisma as db from backend.data.dynamic_fields import is_tool_pin, sanitize_pin_name from backend.data.includes import MAX_GRAPH_VERSIONS_FETCH from backend.data.model import ( - CredentialsField, CredentialsFieldInfo, CredentialsMetaInput, is_credentials_field_name, @@ -45,7 +44,6 @@ from .block import ( AnyBlockSchema, Block, BlockInput, - BlockSchema, BlockType, EmptySchema, get_block, @@ -113,10 +111,12 @@ class Link(BaseDbModel): class Node(BaseDbModel): block_id: str - input_default: BlockInput = {} # dict[input_name, default_value] - metadata: dict[str, Any] = {} - input_links: list[Link] = [] - output_links: list[Link] = [] + input_default: BlockInput = Field( # dict[input_name, default_value] + default_factory=dict + ) + metadata: dict[str, Any] = Field(default_factory=dict) + input_links: list[Link] = Field(default_factory=list) + output_links: list[Link] = Field(default_factory=list) @property def credentials_optional(self) -> bool: @@ -221,18 +221,33 @@ class NodeModel(Node): return result -class BaseGraph(BaseDbModel): +class GraphBaseMeta(BaseDbModel): + """ + Shared base for `GraphMeta` and `BaseGraph`, with core graph metadata fields. + """ + version: int = 1 is_active: bool = True name: str description: str instructions: str | None = None recommended_schedule_cron: str | None = None - nodes: list[Node] = [] - links: list[Link] = [] forked_from_id: str | None = None forked_from_version: int | None = None + +class BaseGraph(GraphBaseMeta): + """ + Graph with nodes, links, and computed I/O schema fields. + + Used to represent sub-graphs within a `Graph`. Contains the full graph + structure including nodes and links, plus computed fields for schemas + and trigger info. Does NOT include user_id or created_at (see GraphModel). 
+ """ + + nodes: list[Node] = Field(default_factory=list) + links: list[Link] = Field(default_factory=list) + @computed_field @property def input_schema(self) -> dict[str, Any]: @@ -361,44 +376,79 @@ class GraphTriggerInfo(BaseModel): class Graph(BaseGraph): - sub_graphs: list[BaseGraph] = [] # Flattened sub-graphs + """Creatable graph model used in API create/update endpoints.""" + + sub_graphs: list[BaseGraph] = Field(default_factory=list) # Flattened sub-graphs + + +class GraphMeta(GraphBaseMeta): + """ + Lightweight graph metadata model representing an existing graph from the database, + for use in listings and summaries. + + Lacks `GraphModel`'s nodes, links, and expensive computed fields. + Use for list endpoints where full graph data is not needed and performance matters. + """ + + id: str # type: ignore + version: int # type: ignore + user_id: str + created_at: datetime + + @classmethod + def from_db(cls, graph: "AgentGraph") -> Self: + return cls( + id=graph.id, + version=graph.version, + is_active=graph.isActive, + name=graph.name or "", + description=graph.description or "", + instructions=graph.instructions, + recommended_schedule_cron=graph.recommendedScheduleCron, + forked_from_id=graph.forkedFromId, + forked_from_version=graph.forkedFromVersion, + user_id=graph.userId, + created_at=graph.createdAt, + ) + + +class GraphModel(Graph, GraphMeta): + """ + Full graph model representing an existing graph from the database. + + This is the primary model for working with persisted graphs. Includes all + graph data (nodes, links, sub_graphs) plus user ownership and timestamps. + Provides computed fields (input_schema, output_schema, etc.) used during + set-up (frontend) and execution (backend). + + Inherits from: + - `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas + - `GraphMeta`: provides user_id, created_at for database records + """ + + nodes: list[NodeModel] = Field(default_factory=list) # type: ignore + + @property + def starting_nodes(self) -> list[NodeModel]: + outbound_nodes = {link.sink_id for link in self.links} + input_nodes = { + node.id for node in self.nodes if node.block.block_type == BlockType.INPUT + } + return [ + node + for node in self.nodes + if node.id not in outbound_nodes or node.id in input_nodes + ] + + @property + def webhook_input_node(self) -> NodeModel | None: # type: ignore + return cast(NodeModel, super().webhook_input_node) @computed_field @property def credentials_input_schema(self) -> dict[str, Any]: - schema = self._credentials_input_schema.jsonschema() - - # Determine which credential fields are required based on credentials_optional metadata graph_credentials_inputs = self.aggregate_credentials_inputs() - required_fields = [] - # Build a map of node_id -> node for quick lookup - all_nodes = {node.id: node for node in self.nodes} - for sub_graph in self.sub_graphs: - for node in sub_graph.nodes: - all_nodes[node.id] = node - - for field_key, ( - _field_info, - node_field_pairs, - ) in graph_credentials_inputs.items(): - # A field is required if ANY node using it has credentials_optional=False - is_required = False - for node_id, _field_name in node_field_pairs: - node = all_nodes.get(node_id) - if node and not node.credentials_optional: - is_required = True - break - - if is_required: - required_fields.append(field_key) - - schema["required"] = required_fields - return schema - - @property - def _credentials_input_schema(self) -> type[BlockSchema]: - graph_credentials_inputs = self.aggregate_credentials_inputs() 
logger.debug( f"Combined credentials input fields for graph #{self.id} ({self.name}): " f"{graph_credentials_inputs}" @@ -406,8 +456,8 @@ class Graph(BaseGraph): # Warn if same-provider credentials inputs can't be combined (= bad UX) graph_cred_fields = list(graph_credentials_inputs.values()) - for i, (field, keys) in enumerate(graph_cred_fields): - for other_field, other_keys in list(graph_cred_fields)[i + 1 :]: + for i, (field, keys, _) in enumerate(graph_cred_fields): + for other_field, other_keys, _ in list(graph_cred_fields)[i + 1 :]: if field.provider != other_field.provider: continue if ProviderName.HTTP in field.provider: @@ -423,31 +473,78 @@ class Graph(BaseGraph): f"keys: {keys} <> {other_keys}." ) - fields: dict[str, tuple[type[CredentialsMetaInput], CredentialsMetaInput]] = { - agg_field_key: ( - CredentialsMetaInput[ - Literal[tuple(field_info.provider)], # type: ignore - Literal[tuple(field_info.supported_types)], # type: ignore - ], - CredentialsField( - required_scopes=set(field_info.required_scopes or []), - discriminator=field_info.discriminator, - discriminator_mapping=field_info.discriminator_mapping, - discriminator_values=field_info.discriminator_values, - ), - ) - for agg_field_key, (field_info, _) in graph_credentials_inputs.items() - } + # Build JSON schema directly to avoid expensive create_model + validation overhead + properties = {} + required_fields = [] - return create_model( - self.name.replace(" ", "") + "CredentialsInputSchema", - __base__=BlockSchema, - **fields, # type: ignore - ) + for agg_field_key, ( + field_info, + _, + is_required, + ) in graph_credentials_inputs.items(): + providers = list(field_info.provider) + cred_types = list(field_info.supported_types) + + field_schema: dict[str, Any] = { + "credentials_provider": providers, + "credentials_types": cred_types, + "type": "object", + "properties": { + "id": {"title": "Id", "type": "string"}, + "title": { + "anyOf": [{"type": "string"}, {"type": "null"}], + "default": None, + "title": "Title", + }, + "provider": { + "title": "Provider", + "type": "string", + **( + {"enum": providers} + if len(providers) > 1 + else {"const": providers[0]} + ), + }, + "type": { + "title": "Type", + "type": "string", + **( + {"enum": cred_types} + if len(cred_types) > 1 + else {"const": cred_types[0]} + ), + }, + }, + "required": ["id", "provider", "type"], + } + + # Add other (optional) field info items + field_schema.update( + field_info.model_dump( + by_alias=True, + exclude_defaults=True, + exclude={"provider", "supported_types"}, # already included above + ) + ) + + # Ensure field schema is well-formed + CredentialsMetaInput.validate_credentials_field_schema( + field_schema, agg_field_key + ) + + properties[agg_field_key] = field_schema + if is_required: + required_fields.append(agg_field_key) + + return { + "type": "object", + "properties": properties, + "required": required_fields, + } def aggregate_credentials_inputs( self, - ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]]]]: + ) -> dict[str, tuple[CredentialsFieldInfo, set[tuple[str, str]], bool]]: """ Returns: dict[aggregated_field_key, tuple( @@ -455,13 +552,19 @@ class Graph(BaseGraph): (now includes discriminator_values from matching nodes) set[(node_id, field_name)]: Node credentials fields that are compatible with this aggregated field spec + bool: True if the field is required (any node has credentials_optional=False) )] """ # First collect all credential field data with input defaults - node_credential_data = [] + # Track 
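For a single-provider credentials field, the builder above pins provider and type with `const`, switching to `enum` only when multiple values are allowed. Roughly, a one-provider API-key field comes out shaped like this (inferred from the builder above, not captured output; the provider name is illustrative):

# Approximate shape of one generated credentials field schema (illustrative, not captured output).
example_field_schema = {
    "credentials_provider": ["elevenlabs"],
    "credentials_types": ["api_key"],
    "type": "object",
    "properties": {
        "id": {"title": "Id", "type": "string"},
        "title": {"anyOf": [{"type": "string"}, {"type": "null"}], "default": None, "title": "Title"},
        "provider": {"title": "Provider", "type": "string", "const": "elevenlabs"},
        "type": {"title": "Type", "type": "string", "const": "api_key"},
    },
    "required": ["id", "provider", "type"],
}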
(field_info, (node_id, field_name), is_required) for each credential field + node_credential_data: list[tuple[CredentialsFieldInfo, tuple[str, str]]] = [] + node_required_map: dict[str, bool] = {} # node_id -> is_required for graph in [self] + self.sub_graphs: for node in graph.nodes: + # Track if this node requires credentials (credentials_optional=False means required) + node_required_map[node.id] = not node.credentials_optional + for ( field_name, field_info, @@ -485,37 +588,21 @@ class Graph(BaseGraph): ) # Combine credential field info (this will merge discriminator_values automatically) - return CredentialsFieldInfo.combine(*node_credential_data) + combined = CredentialsFieldInfo.combine(*node_credential_data) - -class GraphModel(Graph): - user_id: str - nodes: list[NodeModel] = [] # type: ignore - - created_at: datetime - - @property - def starting_nodes(self) -> list[NodeModel]: - outbound_nodes = {link.sink_id for link in self.links} - input_nodes = { - node.id for node in self.nodes if node.block.block_type == BlockType.INPUT + # Add is_required flag to each aggregated field + # A field is required if ANY node using it has credentials_optional=False + return { + key: ( + field_info, + node_field_pairs, + any( + node_required_map.get(node_id, True) + for node_id, _ in node_field_pairs + ), + ) + for key, (field_info, node_field_pairs) in combined.items() } - return [ - node - for node in self.nodes - if node.id not in outbound_nodes or node.id in input_nodes - ] - - @property - def webhook_input_node(self) -> NodeModel | None: # type: ignore - return cast(NodeModel, super().webhook_input_node) - - def meta(self) -> "GraphMeta": - """ - Returns a GraphMeta object with metadata about the graph. - This is used to return metadata about the graph without exposing nodes and links. - """ - return GraphMeta.from_graph(self) def reassign_ids(self, user_id: str, reassign_graph_id: bool = False): """ @@ -799,13 +886,14 @@ class GraphModel(Graph): if is_static_output_block(link.source_id): link.is_static = True # Each value block output should be static. - @staticmethod - def from_db( + @classmethod + def from_db( # type: ignore[reportIncompatibleMethodOverride] + cls, graph: AgentGraph, for_export: bool = False, sub_graphs: list[AgentGraph] | None = None, - ) -> "GraphModel": - return GraphModel( + ) -> Self: + return cls( id=graph.id, user_id=graph.userId if not for_export else "", version=graph.version, @@ -831,17 +919,28 @@ class GraphModel(Graph): ], ) + def hide_nodes(self) -> "GraphModelWithoutNodes": + """ + Returns a copy of the `GraphModel` with nodes, links, and sub-graphs hidden + (excluded from serialization). They are still present in the model instance + so all computed fields (e.g. `credentials_input_schema`) still work. + """ + return GraphModelWithoutNodes.model_validate(self, from_attributes=True) -class GraphMeta(Graph): - user_id: str - # Easy work-around to prevent exposing nodes and links in the API response - nodes: list[NodeModel] = Field(default=[], exclude=True) # type: ignore - links: list[Link] = Field(default=[], exclude=True) +class GraphModelWithoutNodes(GraphModel): + """ + GraphModel variant that excludes nodes, links, and sub-graphs from serialization. - @staticmethod - def from_graph(graph: GraphModel) -> "GraphMeta": - return GraphMeta(**graph.model_dump()) + Used in contexts like the store where exposing internal graph structure + is not desired. Inherits all computed fields from GraphModel but marks + nodes and links as excluded from JSON output. 
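Callers of `aggregate_credentials_inputs()` now unpack a third element, the required flag, which is True as soon as any node using the field has credentials_optional=False. A small sketch of consuming the new shape (assumes `graph` is an existing GraphModel instance):

# Sketch of consuming the 3-tuple now returned by aggregate_credentials_inputs().
for field_key, (field_info, node_field_pairs, is_required) in graph.aggregate_credentials_inputs().items():
    nodes_using_field = ", ".join(node_id for node_id, _ in node_field_pairs)
    print(
        f"{field_key}: providers={sorted(field_info.provider)} "
        f"required={is_required} nodes=[{nodes_using_field}]"
    )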
+ """ + + nodes: list[NodeModel] = Field(default_factory=list, exclude=True) + links: list[Link] = Field(default_factory=list, exclude=True) + + sub_graphs: list[BaseGraph] = Field(default_factory=list, exclude=True) class GraphsPaginated(BaseModel): @@ -912,21 +1011,11 @@ async def list_graphs_paginated( where=where_clause, distinct=["id"], order={"version": "desc"}, - include=AGENT_GRAPH_INCLUDE, skip=offset, take=page_size, ) - graph_models: list[GraphMeta] = [] - for graph in graphs: - try: - graph_meta = GraphModel.from_db(graph).meta() - # Trigger serialization to validate that the graph is well formed - graph_meta.model_dump() - graph_models.append(graph_meta) - except Exception as e: - logger.error(f"Error processing graph {graph.id}: {e}") - continue + graph_models = [GraphMeta.from_db(graph) for graph in graphs] return GraphsPaginated( graphs=graph_models, diff --git a/autogpt_platform/backend/backend/data/model.py b/autogpt_platform/backend/backend/data/model.py index 331126fbd6..7bdfef059b 100644 --- a/autogpt_platform/backend/backend/data/model.py +++ b/autogpt_platform/backend/backend/data/model.py @@ -19,7 +19,6 @@ from typing import ( cast, get_args, ) -from urllib.parse import urlparse from uuid import uuid4 from prisma.enums import CreditTransactionType, OnboardingStep @@ -42,6 +41,7 @@ from typing_extensions import TypedDict from backend.integrations.providers import ProviderName from backend.util.json import loads as json_loads +from backend.util.request import parse_url from backend.util.settings import Secrets # Type alias for any provider name (including custom ones) @@ -163,7 +163,6 @@ class User(BaseModel): if TYPE_CHECKING: from prisma.models import User as PrismaUser - from backend.data.block import BlockSchema T = TypeVar("T") logger = logging.getLogger(__name__) @@ -397,19 +396,25 @@ class HostScopedCredentials(_BaseCredentials): def matches_url(self, url: str) -> bool: """Check if this credential should be applied to the given URL.""" - parsed_url = urlparse(url) - # Extract hostname without port - request_host = parsed_url.hostname + request_host, request_port = _extract_host_from_url(url) + cred_scope_host, cred_scope_port = _extract_host_from_url(self.host) if not request_host: return False - # Simple host matching - exact match or wildcard subdomain match - if self.host == request_host: + # If a port is specified in credential host, the request host port must match + if cred_scope_port is not None and request_port != cred_scope_port: + return False + # Non-standard ports are only allowed if explicitly specified in credential host + elif cred_scope_port is None and request_port not in (80, 443, None): + return False + + # Simple host matching + if cred_scope_host == request_host: return True # Support wildcard matching (e.g., "*.example.com" matches "api.example.com") - if self.host.startswith("*."): - domain = self.host[2:] # Remove "*." + if cred_scope_host.startswith("*."): + domain = cred_scope_host[2:] # Remove "*." 
return request_host.endswith(f".{domain}") or request_host == domain return False @@ -502,15 +507,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): def allowed_cred_types(cls) -> tuple[CredentialsType, ...]: return get_args(cls.model_fields["type"].annotation) - @classmethod - def validate_credentials_field_schema(cls, model: type["BlockSchema"]): + @staticmethod + def validate_credentials_field_schema( + field_schema: dict[str, Any], field_name: str + ): """Validates the schema of a credentials input field""" - field_name = next( - name for name, type in model.get_credentials_fields().items() if type is cls - ) - field_schema = model.jsonschema()["properties"][field_name] try: - schema_extra = CredentialsFieldInfo[CP, CT].model_validate(field_schema) + field_info = CredentialsFieldInfo[CP, CT].model_validate(field_schema) except ValidationError as e: if "Field required [type=missing" not in str(e): raise @@ -520,11 +523,11 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): f"{field_schema}" ) from e - providers = cls.allowed_providers() + providers = field_info.provider if ( providers is not None and len(providers) > 1 - and not schema_extra.discriminator + and not field_info.discriminator ): raise TypeError( f"Multi-provider CredentialsField '{field_name}' " @@ -551,13 +554,13 @@ class CredentialsMetaInput(BaseModel, Generic[CP, CT]): ) -def _extract_host_from_url(url: str) -> str: - """Extract host from URL for grouping host-scoped credentials.""" +def _extract_host_from_url(url: str) -> tuple[str, int | None]: + """Extract host and port from URL for grouping host-scoped credentials.""" try: - parsed = urlparse(url) - return parsed.hostname or url + parsed = parse_url(url) + return parsed.hostname or url, parsed.port except Exception: - return "" + return "", None class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): @@ -606,7 +609,7 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]): providers = frozenset( [cast(CP, "http")] + [ - cast(CP, _extract_host_from_url(str(value))) + cast(CP, parse_url(str(value)).netloc) for value in field.discriminator_values ] ) diff --git a/autogpt_platform/backend/backend/data/model_test.py b/autogpt_platform/backend/backend/data/model_test.py index 37ec6be82f..e8e2ddfa35 100644 --- a/autogpt_platform/backend/backend/data/model_test.py +++ b/autogpt_platform/backend/backend/data/model_test.py @@ -79,10 +79,23 @@ class TestHostScopedCredentials: headers={"Authorization": SecretStr("Bearer token")}, ) - assert creds.matches_url("http://localhost:8080/api/v1") + # Non-standard ports require explicit port in credential host + assert not creds.matches_url("http://localhost:8080/api/v1") assert creds.matches_url("https://localhost:443/secure/endpoint") assert creds.matches_url("http://localhost/simple") + def test_matches_url_with_explicit_port(self): + """Test URL matching with explicit port in credential host.""" + creds = HostScopedCredentials( + provider="custom", + host="localhost:8080", + headers={"Authorization": SecretStr("Bearer token")}, + ) + + assert creds.matches_url("http://localhost:8080/api/v1") + assert not creds.matches_url("http://localhost:3000/api/v1") + assert not creds.matches_url("http://localhost/simple") + def test_empty_headers_dict(self): """Test HostScopedCredentials with empty headers.""" creds = HostScopedCredentials( @@ -128,8 +141,20 @@ class TestHostScopedCredentials: ("*.example.com", "https://sub.api.example.com/test", True), ("*.example.com", "https://example.com/test", True), ("*.example.com", 
"https://example.org/test", False), - ("localhost", "http://localhost:3000/test", True), + # Non-standard ports require explicit port in credential host + ("localhost", "http://localhost:3000/test", False), + ("localhost:3000", "http://localhost:3000/test", True), ("localhost", "http://127.0.0.1:3000/test", False), + # IPv6 addresses (frontend stores with brackets via URL.hostname) + ("[::1]", "http://[::1]/test", True), + ("[::1]", "http://[::1]:80/test", True), + ("[::1]", "https://[::1]:443/test", True), + ("[::1]", "http://[::1]:8080/test", False), # Non-standard port + ("[::1]:8080", "http://[::1]:8080/test", True), + ("[::1]:8080", "http://[::1]:9090/test", False), + ("[2001:db8::1]", "http://[2001:db8::1]/path", True), + ("[2001:db8::1]", "https://[2001:db8::1]:443/path", True), + ("[2001:db8::1]", "http://[2001:db8::ff]/path", False), ], ) def test_url_matching_parametrized(self, host: str, test_url: str, expected: bool): diff --git a/autogpt_platform/backend/backend/executor/utils.py b/autogpt_platform/backend/backend/executor/utils.py index fa264c30a7..d26424aefc 100644 --- a/autogpt_platform/backend/backend/executor/utils.py +++ b/autogpt_platform/backend/backend/executor/utils.py @@ -373,7 +373,7 @@ def make_node_credentials_input_map( # Get aggregated credentials fields for the graph graph_cred_inputs = graph.aggregate_credentials_inputs() - for graph_input_name, (_, compatible_node_fields) in graph_cred_inputs.items(): + for graph_input_name, (_, compatible_node_fields, _) in graph_cred_inputs.items(): # Best-effort map: skip missing items if graph_input_name not in graph_credentials_input: continue diff --git a/autogpt_platform/backend/backend/integrations/credentials_store.py b/autogpt_platform/backend/backend/integrations/credentials_store.py index 40a6f7269c..384405b0c7 100644 --- a/autogpt_platform/backend/backend/integrations/credentials_store.py +++ b/autogpt_platform/backend/backend/integrations/credentials_store.py @@ -224,6 +224,14 @@ openweathermap_credentials = APIKeyCredentials( expires_at=None, ) +elevenlabs_credentials = APIKeyCredentials( + id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c", + provider="elevenlabs", + api_key=SecretStr(settings.secrets.elevenlabs_api_key), + title="Use Credits for ElevenLabs", + expires_at=None, +) + DEFAULT_CREDENTIALS = [ ollama_credentials, revid_credentials, @@ -252,6 +260,7 @@ DEFAULT_CREDENTIALS = [ v0_credentials, webshare_proxy_credentials, openweathermap_credentials, + elevenlabs_credentials, ] SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS} @@ -366,6 +375,8 @@ class IntegrationCredentialsStore: all_credentials.append(webshare_proxy_credentials) if settings.secrets.openweathermap_api_key: all_credentials.append(openweathermap_credentials) + if settings.secrets.elevenlabs_api_key: + all_credentials.append(elevenlabs_credentials) return all_credentials async def get_creds_by_id( diff --git a/autogpt_platform/backend/backend/integrations/providers.py b/autogpt_platform/backend/backend/integrations/providers.py index 3af5006ca4..8a0d6fd183 100644 --- a/autogpt_platform/backend/backend/integrations/providers.py +++ b/autogpt_platform/backend/backend/integrations/providers.py @@ -18,6 +18,7 @@ class ProviderName(str, Enum): DISCORD = "discord" D_ID = "d_id" E2B = "e2b" + ELEVENLABS = "elevenlabs" FAL = "fal" GITHUB = "github" GOOGLE = "google" diff --git a/autogpt_platform/backend/backend/util/file.py b/autogpt_platform/backend/backend/util/file.py index 0dfdb5bd29..70e354a29c 100644 --- 
a/autogpt_platform/backend/backend/util/file.py +++ b/autogpt_platform/backend/backend/util/file.py @@ -8,6 +8,8 @@ from pathlib import Path from typing import TYPE_CHECKING, Literal from urllib.parse import urlparse +from pydantic import BaseModel + from backend.util.cloud_storage import get_cloud_storage_handler from backend.util.request import Requests from backend.util.settings import Config @@ -17,6 +19,35 @@ from backend.util.virus_scanner import scan_content_safe if TYPE_CHECKING: from backend.data.execution import ExecutionContext + +class WorkspaceUri(BaseModel): + """Parsed workspace:// URI.""" + + file_ref: str # File ID or path (e.g. "abc123" or "/path/to/file.txt") + mime_type: str | None = None # MIME type from fragment (e.g. "video/mp4") + is_path: bool = False # True if file_ref is a path (starts with "/") + + +def parse_workspace_uri(uri: str) -> WorkspaceUri: + """Parse a workspace:// URI into its components. + + Examples: + "workspace://abc123" → WorkspaceUri(file_ref="abc123", mime_type=None, is_path=False) + "workspace://abc123#video/mp4" → WorkspaceUri(file_ref="abc123", mime_type="video/mp4", is_path=False) + "workspace:///path/to/file.txt" → WorkspaceUri(file_ref="/path/to/file.txt", mime_type=None, is_path=True) + """ + raw = uri.removeprefix("workspace://") + mime_type: str | None = None + if "#" in raw: + raw, fragment = raw.split("#", 1) + mime_type = fragment or None + return WorkspaceUri( + file_ref=raw, + mime_type=mime_type, + is_path=raw.startswith("/"), + ) + + # Return format options for store_media_file # - "for_local_processing": Returns local file path - use with ffmpeg, MoviePy, PIL, etc. # - "for_external_api": Returns data URI (base64) - use when sending content to external APIs @@ -183,22 +214,20 @@ async def store_media_file( "This file type is only available in CoPilot sessions." 
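A quick check of the three URI forms the parser above accepts, mirroring its docstring examples:

# Usage sketch for parse_workspace_uri, mirroring the docstring examples above.
from backend.util.file import parse_workspace_uri

ws = parse_workspace_uri("workspace://abc123#video/mp4")
assert (ws.file_ref, ws.mime_type, ws.is_path) == ("abc123", "video/mp4", False)

ws = parse_workspace_uri("workspace:///path/to/file.txt")
assert (ws.file_ref, ws.mime_type, ws.is_path) == ("/path/to/file.txt", None, True)

ws = parse_workspace_uri("workspace://abc123")
assert (ws.file_ref, ws.mime_type, ws.is_path) == ("abc123", None, False)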
) - # Parse workspace reference - # workspace://abc123 - by file ID - # workspace:///path/to/file.txt - by virtual path - file_ref = file[12:] # Remove "workspace://" + # Parse workspace reference (strips #mimeType fragment from file ID) + ws = parse_workspace_uri(file) - if file_ref.startswith("/"): - # Path reference - workspace_content = await workspace_manager.read_file(file_ref) - file_info = await workspace_manager.get_file_info_by_path(file_ref) + if ws.is_path: + # Path reference: workspace:///path/to/file.txt + workspace_content = await workspace_manager.read_file(ws.file_ref) + file_info = await workspace_manager.get_file_info_by_path(ws.file_ref) filename = sanitize_filename( file_info.name if file_info else f"{uuid.uuid4()}.bin" ) else: - # ID reference - workspace_content = await workspace_manager.read_file_by_id(file_ref) - file_info = await workspace_manager.get_file_info(file_ref) + # ID reference: workspace://abc123 or workspace://abc123#video/mp4 + workspace_content = await workspace_manager.read_file_by_id(ws.file_ref) + file_info = await workspace_manager.get_file_info(ws.file_ref) filename = sanitize_filename( file_info.name if file_info else f"{uuid.uuid4()}.bin" ) @@ -342,7 +371,21 @@ async def store_media_file( # Don't re-save if input was already from workspace if is_from_workspace: - # Return original workspace reference + # Return original workspace reference, ensuring MIME type fragment + ws = parse_workspace_uri(file) + if not ws.mime_type: + # Add MIME type fragment if missing (older refs without it) + try: + if ws.is_path: + info = await workspace_manager.get_file_info_by_path( + ws.file_ref + ) + else: + info = await workspace_manager.get_file_info(ws.file_ref) + if info: + return MediaFileType(f"{file}#{info.mimeType}") + except Exception: + pass return MediaFileType(file) # Save new content to workspace @@ -354,7 +397,7 @@ async def store_media_file( filename=filename, overwrite=True, ) - return MediaFileType(f"workspace://{file_record.id}") + return MediaFileType(f"workspace://{file_record.id}#{file_record.mimeType}") else: raise ValueError(f"Invalid return_format: {return_format}") diff --git a/autogpt_platform/backend/backend/util/request.py b/autogpt_platform/backend/backend/util/request.py index 9744372b15..95e5ee32f7 100644 --- a/autogpt_platform/backend/backend/util/request.py +++ b/autogpt_platform/backend/backend/util/request.py @@ -157,12 +157,7 @@ async def validate_url( is_trusted: Boolean indicating if the hostname is in trusted_origins ip_addresses: List of IP addresses for the host; empty if the host is trusted """ - # Canonicalize URL - url = url.strip("/ ").replace("\\", "/") - parsed = urlparse(url) - if not parsed.scheme: - url = f"http://{url}" - parsed = urlparse(url) + parsed = parse_url(url) # Check scheme if parsed.scheme not in ALLOWED_SCHEMES: @@ -220,6 +215,17 @@ async def validate_url( ) +def parse_url(url: str) -> URL: + """Canonicalizes and parses a URL string.""" + url = url.strip("/ ").replace("\\", "/") + + # Ensure scheme is present for proper parsing + if not re.match(r"[a-z0-9+.\-]+://", url): + url = f"http://{url}" + + return urlparse(url) + + def pin_url(url: URL, ip_addresses: Optional[list[str]] = None) -> URL: """ Pins a URL to a specific IP address to prevent DNS rebinding attacks. 
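For consumers of the refs handled above, here is a minimal TypeScript sketch (not part of this patch; parseWorkspaceRef is a hypothetical helper) of the workspace:// reference convention that parse_workspace_uri and the updated store_media_file define, including the new #mimeType fragment:

// Hypothetical helper, not part of this patch; mirrors the backend's
// parse_workspace_uri so consumer code can read the same refs.
type WorkspaceRef = {
  fileRef: string; // file ID ("abc123") or virtual path ("/path/to/file.txt")
  mimeType: string | null; // from the optional "#<mime>" fragment
  isPath: boolean; // true when fileRef is a virtual path
};

function parseWorkspaceRef(uri: string): WorkspaceRef {
  // Strip the scheme, then split off an optional "#<mime>" fragment.
  const raw = uri.replace(/^workspace:\/\//, "");
  const hashIndex = raw.indexOf("#");
  const fileRef = hashIndex === -1 ? raw : raw.slice(0, hashIndex);
  const mimeType = hashIndex === -1 ? null : raw.slice(hashIndex + 1) || null;
  return { fileRef, mimeType, isPath: fileRef.startsWith("/") };
}

// store_media_file now returns refs like "workspace://<id>#<mimeType>", so the
// MIME type is recoverable without another file-info lookup:
parseWorkspaceRef("workspace://abc123#video/mp4");
// -> { fileRef: "abc123", mimeType: "video/mp4", isPath: false }
parseWorkspaceRef("workspace:///path/to/file.txt");
// -> { fileRef: "/path/to/file.txt", mimeType: null, isPath: true }

Carrying the MIME type in the fragment lets downstream code choose how to handle a file without an extra get_file_info round-trip, which is why store_media_file now appends it and backfills it on older refs.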
diff --git a/autogpt_platform/backend/backend/util/settings.py b/autogpt_platform/backend/backend/util/settings.py index aa28a4c9ac..50b7428160 100644 --- a/autogpt_platform/backend/backend/util/settings.py +++ b/autogpt_platform/backend/backend/util/settings.py @@ -656,6 +656,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings): e2b_api_key: str = Field(default="", description="E2B API key") nvidia_api_key: str = Field(default="", description="Nvidia API key") mem0_api_key: str = Field(default="", description="Mem0 API key") + elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key") linear_client_id: str = Field(default="", description="Linear client ID") linear_client_secret: str = Field(default="", description="Linear client secret") diff --git a/autogpt_platform/backend/backend/util/workspace.py b/autogpt_platform/backend/backend/util/workspace.py index a2f1a61b9e..86413b640a 100644 --- a/autogpt_platform/backend/backend/util/workspace.py +++ b/autogpt_platform/backend/backend/util/workspace.py @@ -22,6 +22,7 @@ from backend.data.workspace import ( soft_delete_workspace_file, ) from backend.util.settings import Config +from backend.util.virus_scanner import scan_content_safe from backend.util.workspace_storage import compute_file_checksum, get_workspace_storage logger = logging.getLogger(__name__) @@ -187,6 +188,9 @@ class WorkspaceManager: f"{Config().max_file_size_mb}MB limit" ) + # Virus scan content before persisting (defense in depth) + await scan_content_safe(content, filename=filename) + # Determine path with session scoping if path is None: path = f"/{filename}" diff --git a/autogpt_platform/backend/poetry.lock b/autogpt_platform/backend/poetry.lock index 91ac358ade..61da8c974f 100644 --- a/autogpt_platform/backend/poetry.lock +++ b/autogpt_platform/backend/poetry.lock @@ -1169,6 +1169,29 @@ attrs = ">=21.3.0" e2b = ">=1.5.4,<2.0.0" httpx = ">=0.20.0,<1.0.0" +[[package]] +name = "elevenlabs" +version = "1.59.0" +description = "" +optional = false +python-versions = "<4.0,>=3.8" +groups = ["main"] +files = [ + {file = "elevenlabs-1.59.0-py3-none-any.whl", hash = "sha256:468145db81a0bc867708b4a8619699f75583e9481b395ec1339d0b443da771ed"}, + {file = "elevenlabs-1.59.0.tar.gz", hash = "sha256:16e735bd594e86d415dd445d249c8cc28b09996cfd627fbc10102c0a84698859"}, +] + +[package.dependencies] +httpx = ">=0.21.2" +pydantic = ">=1.9.2" +pydantic-core = ">=2.18.2,<3.0.0" +requests = ">=2.20" +typing_extensions = ">=4.0.0" +websockets = ">=11.0" + +[package.extras] +pyaudio = ["pyaudio (>=0.2.14)"] + [[package]] name = "email-validator" version = "2.2.0" @@ -7361,6 +7384,28 @@ files = [ defusedxml = ">=0.7.1,<0.8.0" requests = "*" +[[package]] +name = "yt-dlp" +version = "2025.12.8" +description = "A feature-rich command-line audio/video downloader" +optional = false +python-versions = ">=3.10" +groups = ["main"] +files = [ + {file = "yt_dlp-2025.12.8-py3-none-any.whl", hash = "sha256:36e2584342e409cfbfa0b5e61448a1c5189e345cf4564294456ee509e7d3e065"}, + {file = "yt_dlp-2025.12.8.tar.gz", hash = "sha256:b773c81bb6b71cb2c111cfb859f453c7a71cf2ef44eff234ff155877184c3e4f"}, +] + +[package.extras] +build = ["build", "hatchling (>=1.27.0)", "pip", "setuptools (>=71.0.2)", "wheel"] +curl-cffi = ["curl-cffi (>=0.5.10,<0.6.dev0 || >=0.10.dev0,<0.14) ; implementation_name == \"cpython\""] +default = ["brotli ; implementation_name == \"cpython\"", "brotlicffi ; implementation_name != \"cpython\"", "certifi", "mutagen", "pycryptodomex", "requests (>=2.32.2,<3)", "urllib3 
(>=2.0.2,<3)", "websockets (>=13.0)", "yt-dlp-ejs (==0.3.2)"] +dev = ["autopep8 (>=2.0,<3.0)", "pre-commit", "pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)", "ruff (>=0.14.0,<0.15.0)"] +pyinstaller = ["pyinstaller (>=6.17.0)"] +secretstorage = ["cffi", "secretstorage"] +static-analysis = ["autopep8 (>=2.0,<3.0)", "ruff (>=0.14.0,<0.15.0)"] +test = ["pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)"] + [[package]] name = "zerobouncesdk" version = "1.1.2" @@ -7512,4 +7557,4 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf" +content-hash = "8239323f9ae6713224dffd1fe8ba8b449fe88b6c3c7a90940294a74f43a0387a" diff --git a/autogpt_platform/backend/pyproject.toml b/autogpt_platform/backend/pyproject.toml index fe263e47c0..24aea39f33 100644 --- a/autogpt_platform/backend/pyproject.toml +++ b/autogpt_platform/backend/pyproject.toml @@ -20,6 +20,7 @@ click = "^8.2.0" cryptography = "^45.0" discord-py = "^2.5.2" e2b-code-interpreter = "^1.5.2" +elevenlabs = "^1.50.0" fastapi = "^0.116.1" feedparser = "^6.0.11" flake8 = "^7.3.0" @@ -71,6 +72,7 @@ tweepy = "^4.16.0" uvicorn = { extras = ["standard"], version = "^0.35.0" } websockets = "^15.0" youtube-transcript-api = "^1.2.1" +yt-dlp = "2025.12.08" zerobouncesdk = "^1.1.2" # NOTE: please insert new dependencies in their alphabetical location pytest-snapshot = "^0.9.0" diff --git a/autogpt_platform/backend/snapshots/grph_single b/autogpt_platform/backend/snapshots/grph_single index 1811a57ec8..7fa5783577 100644 --- a/autogpt_platform/backend/snapshots/grph_single +++ b/autogpt_platform/backend/snapshots/grph_single @@ -3,7 +3,6 @@ "credentials_input_schema": { "properties": {}, "required": [], - "title": "TestGraphCredentialsInputSchema", "type": "object" }, "description": "A test graph", diff --git a/autogpt_platform/backend/snapshots/grphs_all b/autogpt_platform/backend/snapshots/grphs_all index 0b314d96f9..9ccb4a6dc8 100644 --- a/autogpt_platform/backend/snapshots/grphs_all +++ b/autogpt_platform/backend/snapshots/grphs_all @@ -1,34 +1,14 @@ [ { - "credentials_input_schema": { - "properties": {}, - "required": [], - "title": "TestGraphCredentialsInputSchema", - "type": "object" - }, + "created_at": "2025-09-04T13:37:00", "description": "A test graph", "forked_from_id": null, "forked_from_version": null, - "has_external_trigger": false, - "has_human_in_the_loop": false, - "has_sensitive_action": false, "id": "graph-123", - "input_schema": { - "properties": {}, - "required": [], - "type": "object" - }, "instructions": null, "is_active": true, "name": "Test Graph", - "output_schema": { - "properties": {}, - "required": [], - "type": "object" - }, "recommended_schedule_cron": null, - "sub_graphs": [], - "trigger_setup_info": null, "user_id": "3e53486c-cf57-477e-ba2a-cb02dc828e1a", "version": 1 } diff --git a/autogpt_platform/backend/test/agent_generator/test_core_integration.py b/autogpt_platform/backend/test/agent_generator/test_core_integration.py index 05ce4a3aff..528763e751 100644 --- a/autogpt_platform/backend/test/agent_generator/test_core_integration.py +++ b/autogpt_platform/backend/test/agent_generator/test_core_integration.py @@ -111,9 +111,7 @@ class TestGenerateAgent: instructions = {"type": "instructions", "steps": ["Step 1"]} result = await core.generate_agent(instructions) - # library_agents defaults to None - mock_external.assert_called_once_with(instructions, None) - # Result should have id, 
version, is_active added if not present + mock_external.assert_called_once_with(instructions, None, None, None) assert result is not None assert result["name"] == "Test Agent" assert "id" in result @@ -177,8 +175,9 @@ class TestGenerateAgentPatch: current_agent = {"nodes": [], "links": []} result = await core.generate_agent_patch("Add a node", current_agent) - # library_agents defaults to None - mock_external.assert_called_once_with("Add a node", current_agent, None) + mock_external.assert_called_once_with( + "Add a node", current_agent, None, None, None + ) assert result == expected_result @pytest.mark.asyncio diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx index f0bb652a06..a8efa344a2 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/AgentOnboardingCredentials.tsx @@ -1,5 +1,5 @@ import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput"; -import { GraphMeta } from "@/app/api/__generated__/models/graphMeta"; +import { GraphModel } from "@/app/api/__generated__/models/graphModel"; import { CredentialsInput } from "@/components/contextual/CredentialsInput/CredentialsInput"; import { useState } from "react"; import { getSchemaDefaultCredentials } from "../../helpers"; @@ -9,7 +9,7 @@ type Credential = CredentialsMetaInput | undefined; type Credentials = Record; type Props = { - agent: GraphMeta | null; + agent: GraphModel | null; siblingInputs?: Record; onCredentialsChange: ( credentials: Record, diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/helpers.ts b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/helpers.ts index 7a456d63e4..a4947015c4 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/helpers.ts +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/components/AgentOnboardingCredentials/helpers.ts @@ -1,9 +1,9 @@ import { CredentialsMetaInput } from "@/app/api/__generated__/models/credentialsMetaInput"; -import { GraphMeta } from "@/app/api/__generated__/models/graphMeta"; +import { GraphModel } from "@/app/api/__generated__/models/graphModel"; import { BlockIOCredentialsSubSchema } from "@/lib/autogpt-server-api/types"; export function getCredentialFields( - agent: GraphMeta | null, + agent: GraphModel | null, ): AgentCredentialsFields { if (!agent) return {}; diff --git a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/helpers.ts b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/helpers.ts index 62f5c564ff..ff1f8d452c 100644 --- a/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/helpers.ts +++ b/autogpt_platform/frontend/src/app/(no-navbar)/onboarding/5-run/helpers.ts @@ -3,10 +3,10 @@ import type { CredentialsMetaInput, } from "@/lib/autogpt-server-api/types"; import type { InputValues } from "./types"; -import { GraphMeta } from "@/app/api/__generated__/models/graphMeta"; +import { GraphModel } from "@/app/api/__generated__/models/graphModel"; export function computeInitialAgentInputs( - 
agent: GraphMeta | null, + agent: GraphModel | null, existingInputs?: InputValues | null, ): InputValues { const properties = agent?.input_schema?.properties || {}; @@ -29,7 +29,7 @@ export function computeInitialAgentInputs( } type IsRunDisabledParams = { - agent: GraphMeta | null; + agent: GraphModel | null; isRunning: boolean; agentInputs: InputValues | null | undefined; }; diff --git a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts index 41d05a9afb..fd67519957 100644 --- a/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts +++ b/autogpt_platform/frontend/src/app/(platform)/auth/integrations/oauth_callback/route.ts @@ -1,6 +1,17 @@ import { OAuthPopupResultMessage } from "./types"; import { NextResponse } from "next/server"; +/** + * Safely encode a value as JSON for embedding in a script tag. + * Escapes characters that could break out of the script context to prevent XSS. + */ +function safeJsonStringify(value: unknown): string { + return JSON.stringify(value) + .replace(/</g, "\\u003c") + .replace(/>/g, "\\u003e") + .replace(/&/g, "\\u0026"); +} + // This route is intended to be used as the callback for integration OAuth flows, // controlled by the CredentialsInput component. The CredentialsInput opens the login // page in a pop-up window, which then redirects to this route to close the loop. @@ -23,12 +34,13 @@ export async function GET(request: Request) { console.debug("Sending message to opener:", message); // Return a response with the message as JSON and a script to close the window + // Use safeJsonStringify to prevent XSS by escaping <, >, and & characters return new NextResponse( ` diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/BlocksControl.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/BlocksControl.tsx index f5451e6d4d..99b66fe1dc 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/BlocksControl.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/BlocksControl.tsx @@ -30,6 +30,8 @@ import { } from "@/components/atoms/Tooltip/BaseTooltip"; import { GraphMeta } from "@/lib/autogpt-server-api"; import jaro from "jaro-winkler"; +import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs"; +import { okData } from "@/app/api/helpers"; type _Block = Omit & { uiKey?: string; @@ -107,6 +109,8 @@ export function BlocksControl({ .filter((b) => b.uiType !== BlockUIType.AGENT) .sort((a, b) => a.name.localeCompare(b.name)); + // Agent blocks are created from GraphMeta which doesn't include schemas. + // Schemas will be fetched on-demand when the block is actually added. const agentBlockList = flows .map((flow): _Block => { return { @@ -116,8 +120,9 @@ `Ver.${flow.version}` + (flow.description ?
` | ${flow.description}` : ""), categories: [{ category: "AGENT", description: "" }], - inputSchema: flow.input_schema, - outputSchema: flow.output_schema, + // Empty schemas - will be populated when block is added + inputSchema: { type: "object", properties: {} }, + outputSchema: { type: "object", properties: {} }, staticOutput: false, uiType: BlockUIType.AGENT, costs: [], @@ -125,8 +130,7 @@ export function BlocksControl({ hardcodedValues: { graph_id: flow.id, graph_version: flow.version, - input_schema: flow.input_schema, - output_schema: flow.output_schema, + // Schemas will be fetched on-demand when block is added }, }; }) @@ -182,6 +186,37 @@ export function BlocksControl({ setSelectedCategory(null); }, []); + // Handler to add a block, fetching graph data on-demand for agent blocks + const handleAddBlock = useCallback( + async (block: _Block & { notAvailable: string | null }) => { + if (block.notAvailable) return; + + // For agent blocks, fetch the full graph to get schemas + if (block.uiType === BlockUIType.AGENT && block.hardcodedValues) { + const graphID = block.hardcodedValues.graph_id as string; + const graphVersion = block.hardcodedValues.graph_version as number; + const graphData = okData( + await getV1GetSpecificGraph(graphID, { version: graphVersion }), + ); + + if (graphData) { + addBlock(block.id, block.name, { + ...block.hardcodedValues, + input_schema: graphData.input_schema, + output_schema: graphData.output_schema, + }); + } else { + // Fallback: add without schemas (will be incomplete) + console.error("Failed to fetch graph data for agent block"); + addBlock(block.id, block.name, block.hardcodedValues || {}); + } + } else { + addBlock(block.id, block.name, block.hardcodedValues || {}); + } + }, + [addBlock], + ); + // Extract unique categories from blocks const categories = useMemo(() => { return Array.from( @@ -303,10 +338,7 @@ export function BlocksControl({ }), ); }} - onClick={() => - !block.notAvailable && - addBlock(block.id, block.name, block?.hardcodedValues || {}) - } + onClick={() => handleAddBlock(block)} title={block.notAvailable ?? undefined} >
(null); + // Prepare renderers for each item when enhanced mode is enabled + const getItemRenderer = useMemo(() => { + if (!enableEnhancedOutputHandling) return null; + return (item: unknown) => { + const metadata: OutputMetadata = {}; + return globalRegistry.getRenderer(item, metadata); + }; + }, [enableEnhancedOutputHandling]); + const copyData = (pin: string, data: string) => { navigator.clipboard.writeText(data).then(() => { toast({ @@ -102,15 +120,31 @@ export default function DataTable({
- {value.map((item, index) => ( - - - {index < value.length - 1 && ", "} - - ))} + {value.map((item, index) => { + const renderer = getItemRenderer?.(item); + if (enableEnhancedOutputHandling && renderer) { + const metadata: OutputMetadata = {}; + return ( + + + {index < value.length - 1 && ", "} + + ); + } + return ( + + + {index < value.length - 1 && ", "} + + ); + })} diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/Flow/Flow.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/Flow/Flow.tsx index a54a9ef386..67b3cad9af 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/Flow/Flow.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/Flow/Flow.tsx @@ -29,13 +29,17 @@ import "@xyflow/react/dist/style.css"; import { ConnectedEdge, CustomNode } from "../CustomNode/CustomNode"; import "./flow.css"; import { + BlockIORootSchema, BlockUIType, formatEdgeID, GraphExecutionID, GraphID, GraphMeta, LibraryAgent, + SpecialBlockID, } from "@/lib/autogpt-server-api"; +import { getV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs"; +import { okData } from "@/app/api/helpers"; import { IncompatibilityInfo } from "../../../hooks/useSubAgentUpdate/types"; import { Key, storage } from "@/services/storage/local-storage"; import { findNewlyAddedBlockCoordinates, getTypeColor } from "@/lib/utils"; @@ -687,8 +691,94 @@ const FlowEditor: React.FC<{ [getNode, updateNode, nodes], ); + /* Shared helper to create and add a node */ + const createAndAddNode = useCallback( + async ( + blockID: string, + blockName: string, + hardcodedValues: Record, + position: { x: number; y: number }, + ): Promise => { + const nodeSchema = availableBlocks.find((node) => node.id === blockID); + if (!nodeSchema) { + console.error(`Schema not found for block ID: ${blockID}`); + return null; + } + + // For agent blocks, fetch the full graph to get schemas + let inputSchema: BlockIORootSchema = nodeSchema.inputSchema; + let outputSchema: BlockIORootSchema = nodeSchema.outputSchema; + let finalHardcodedValues = hardcodedValues; + + if (blockID === SpecialBlockID.AGENT) { + const graphID = hardcodedValues.graph_id as string; + const graphVersion = hardcodedValues.graph_version as number; + const graphData = okData( + await getV1GetSpecificGraph(graphID, { version: graphVersion }), + ); + + if (graphData) { + inputSchema = graphData.input_schema as BlockIORootSchema; + outputSchema = graphData.output_schema as BlockIORootSchema; + finalHardcodedValues = { + ...hardcodedValues, + input_schema: graphData.input_schema, + output_schema: graphData.output_schema, + }; + } else { + console.error("Failed to fetch graph data for agent block"); + } + } + + const newNode: CustomNode = { + id: nodeId.toString(), + type: "custom", + position, + data: { + blockType: blockName, + blockCosts: nodeSchema.costs || [], + title: `${blockName} ${nodeId}`, + description: nodeSchema.description, + categories: nodeSchema.categories, + inputSchema: inputSchema, + outputSchema: outputSchema, + hardcodedValues: finalHardcodedValues, + connections: [], + isOutputOpen: false, + block_id: blockID, + isOutputStatic: nodeSchema.staticOutput, + uiType: nodeSchema.uiType, + }, + }; + + addNodes(newNode); + setNodeId((prevId) => prevId + 1); + clearNodesStatusAndOutput(); + + history.push({ + type: "ADD_NODE", + payload: { node: { ...newNode, ...newNode.data } }, + undo: () => deleteElements({ nodes: [{ id: 
newNode.id }] }), + redo: () => addNodes(newNode), + }); + + return newNode; + }, + [ + availableBlocks, + nodeId, + addNodes, + deleteElements, + clearNodesStatusAndOutput, + ], + ); + const addNode = useCallback( - (blockId: string, nodeType: string, hardcodedValues: any = {}) => { + async ( + blockId: string, + nodeType: string, + hardcodedValues: Record = {}, + ) => { const nodeSchema = availableBlocks.find((node) => node.id === blockId); if (!nodeSchema) { console.error(`Schema not found for block ID: ${blockId}`); @@ -707,73 +797,42 @@ const FlowEditor: React.FC<{ // Alternative: We could also use D3 force, Intersection for this (React flow Pro examples) const { x, y } = getViewport(); - const viewportCoordinates = + const position = nodeDimensions && Object.keys(nodeDimensions).length > 0 - ? // we will get all the dimension of nodes, then store - findNewlyAddedBlockCoordinates( + ? findNewlyAddedBlockCoordinates( nodeDimensions, nodeSchema.uiType == BlockUIType.NOTE ? 300 : 500, 60, 1.0, ) - : // we will get all the dimension of nodes, then store - { + : { x: window.innerWidth / 2 - x, y: window.innerHeight / 2 - y, }; - const newNode: CustomNode = { - id: nodeId.toString(), - type: "custom", - position: viewportCoordinates, // Set the position to the calculated viewport center - data: { - blockType: nodeType, - blockCosts: nodeSchema.costs, - title: `${nodeType} ${nodeId}`, - description: nodeSchema.description, - categories: nodeSchema.categories, - inputSchema: nodeSchema.inputSchema, - outputSchema: nodeSchema.outputSchema, - hardcodedValues: hardcodedValues, - connections: [], - isOutputOpen: false, - block_id: blockId, - isOutputStatic: nodeSchema.staticOutput, - uiType: nodeSchema.uiType, - }, - }; - - addNodes(newNode); - setNodeId((prevId) => prevId + 1); - clearNodesStatusAndOutput(); // Clear status and output when a new node is added + const newNode = await createAndAddNode( + blockId, + nodeType, + hardcodedValues, + position, + ); + if (!newNode) return; setViewport( { - // Rough estimate of the dimension of the node is: 500x400px. - // Though we skip shifting the X, considering the block menu side-bar. 
- x: -viewportCoordinates.x * 0.8 + (window.innerWidth - 0.0) / 2, - y: -viewportCoordinates.y * 0.8 + (window.innerHeight - 400) / 2, + x: -position.x * 0.8 + (window.innerWidth - 0.0) / 2, + y: -position.y * 0.8 + (window.innerHeight - 400) / 2, zoom: 0.8, }, { duration: 500 }, ); - - history.push({ - type: "ADD_NODE", - payload: { node: { ...newNode, ...newNode.data } }, - undo: () => deleteElements({ nodes: [{ id: newNode.id }] }), - redo: () => addNodes(newNode), - }); }, [ - nodeId, getViewport, setViewport, availableBlocks, - addNodes, nodeDimensions, - deleteElements, - clearNodesStatusAndOutput, + createAndAddNode, ], ); @@ -920,7 +979,7 @@ const FlowEditor: React.FC<{ }, []); const onDrop = useCallback( - (event: React.DragEvent) => { + async (event: React.DragEvent) => { event.preventDefault(); const blockData = event.dataTransfer.getData("application/reactflow"); @@ -935,62 +994,17 @@ const FlowEditor: React.FC<{ y: event.clientY, }); - // Find the block schema - const nodeSchema = availableBlocks.find((node) => node.id === blockId); - if (!nodeSchema) { - console.error(`Schema not found for block ID: ${blockId}`); - return; - } - - // Create the new node at the drop position - const newNode: CustomNode = { - id: nodeId.toString(), - type: "custom", + await createAndAddNode( + blockId, + blockName, + hardcodedValues || {}, position, - data: { - blockType: blockName, - blockCosts: nodeSchema.costs || [], - title: `${blockName} ${nodeId}`, - description: nodeSchema.description, - categories: nodeSchema.categories, - inputSchema: nodeSchema.inputSchema, - outputSchema: nodeSchema.outputSchema, - hardcodedValues: hardcodedValues, - connections: [], - isOutputOpen: false, - block_id: blockId, - uiType: nodeSchema.uiType, - }, - }; - - history.push({ - type: "ADD_NODE", - payload: { node: { ...newNode, ...newNode.data } }, - undo: () => { - deleteElements({ nodes: [{ id: newNode.id } as any], edges: [] }); - }, - redo: () => { - addNodes([newNode]); - }, - }); - addNodes([newNode]); - clearNodesStatusAndOutput(); - - setNodeId((prevId) => prevId + 1); + ); } catch (error) { console.error("Failed to drop block:", error); } }, - [ - nodeId, - availableBlocks, - nodes, - edges, - addNodes, - screenToFlowPosition, - deleteElements, - clearNodesStatusAndOutput, - ], + [screenToFlowPosition, createAndAddNode], ); const buildContextValue: BuilderContextType = useMemo( diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeOutputs.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeOutputs.tsx index d90b7d6a4c..2111db7d99 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeOutputs.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/NodeOutputs.tsx @@ -1,8 +1,14 @@ -import React, { useContext, useState } from "react"; +import React, { useContext, useMemo, useState } from "react"; import { Button } from "@/components/__legacy__/ui/button"; import { Maximize2 } from "lucide-react"; import * as Separator from "@radix-ui/react-separator"; import { ContentRenderer } from "@/components/__legacy__/ui/render"; +import type { OutputMetadata } from "@/components/contextual/OutputRenderers"; +import { + globalRegistry, + OutputItem, +} from "@/components/contextual/OutputRenderers"; +import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag"; import { beautifyString } from "@/lib/utils"; @@ -21,6 +27,9 @@ export default function 
NodeOutputs({ data, }: NodeOutputsProps) { const builderContext = useContext(BuilderContext); + const enableEnhancedOutputHandling = useGetFlag( + Flag.ENABLE_ENHANCED_OUTPUT_HANDLING, + ); const [expandedDialog, setExpandedDialog] = useState<{ isOpen: boolean; @@ -37,6 +46,15 @@ export default function NodeOutputs({ const { getNodeTitle } = builderContext; + // Prepare renderers for each item when enhanced mode is enabled + const getItemRenderer = useMemo(() => { + if (!enableEnhancedOutputHandling) return null; + return (item: unknown) => { + const metadata: OutputMetadata = {}; + return globalRegistry.getRenderer(item, metadata); + }; + }, [enableEnhancedOutputHandling]); + const getBeautifiedPinName = (pin: string) => { if (!pin.startsWith("tools_^_")) { return beautifyString(pin); @@ -87,15 +105,31 @@ export default function NodeOutputs({
Data:
- {dataArray.slice(0, 10).map((item, index) => ( - - - {index < Math.min(dataArray.length, 10) - 1 && ", "} - - ))} + {dataArray.slice(0, 10).map((item, index) => { + const renderer = getItemRenderer?.(item); + if (enableEnhancedOutputHandling && renderer) { + const metadata: OutputMetadata = {}; + return ( + + + {index < Math.min(dataArray.length, 10) - 1 && ", "} + + ); + } + return ( + + + {index < Math.min(dataArray.length, 10) - 1 && ", "} + + ); + })} {dataArray.length > 10 && (
diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx index 15983be9f5..cb06a79683 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerInputUI.tsx @@ -4,13 +4,13 @@ import { AgentRunDraftView } from "@/app/(platform)/library/agents/[id]/componen import { Dialog } from "@/components/molecules/Dialog/Dialog"; import type { CredentialsMetaInput, - GraphMeta, + Graph, } from "@/lib/autogpt-server-api/types"; interface RunInputDialogProps { isOpen: boolean; doClose: () => void; - graph: GraphMeta; + graph: Graph; doRun?: ( inputs: Record, credentialsInputs: Record, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerUIWrapper.tsx b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerUIWrapper.tsx index a9af065a5d..b1d40fb919 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerUIWrapper.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/build/components/legacy-builder/RunnerUIWrapper.tsx @@ -9,13 +9,13 @@ import { CustomNodeData } from "@/app/(platform)/build/components/legacy-builder import { BlockUIType, CredentialsMetaInput, - GraphMeta, + Graph, } from "@/lib/autogpt-server-api/types"; import RunnerOutputUI, { OutputNodeInfo } from "./RunnerOutputUI"; import { RunnerInputDialog } from "./RunnerInputUI"; interface RunnerUIWrapperProps { - graph: GraphMeta; + graph: Graph; nodes: Node[]; graphExecutionError?: string | null; saveAndRun: ( diff --git a/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/helpers.ts index aece7e9811..69593a142b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/helpers.ts @@ -1,5 +1,5 @@ import { GraphInputSchema } from "@/lib/autogpt-server-api"; -import { GraphMetaLike, IncompatibilityInfo } from "./types"; +import { GraphLike, IncompatibilityInfo } from "./types"; // Helper type for schema properties - the generated types are too loose type SchemaProperties = Record; @@ -36,7 +36,7 @@ export function getSchemaRequired(schema: unknown): SchemaRequired { */ export function createUpdatedAgentNodeInputs( currentInputs: Record, - latestSubGraphVersion: GraphMetaLike, + latestSubGraphVersion: GraphLike, ): Record { return { ...currentInputs, diff --git a/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/types.ts b/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/types.ts index 83f83155db..6c115f20a3 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/types.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/types.ts @@ -1,7 +1,11 @@ -import type { GraphMeta as LegacyGraphMeta } from "@/lib/autogpt-server-api"; +import type { + Graph as LegacyGraph, + GraphMeta as LegacyGraphMeta, +} from "@/lib/autogpt-server-api"; +import type { GraphModel as GeneratedGraph } from "@/app/api/__generated__/models/graphModel"; import type { GraphMeta as GeneratedGraphMeta } from "@/app/api/__generated__/models/graphMeta"; 
-export type SubAgentUpdateInfo = { +export type SubAgentUpdateInfo = { hasUpdate: boolean; currentVersion: number; latestVersion: number; @@ -10,7 +14,10 @@ export type SubAgentUpdateInfo = { incompatibilities: IncompatibilityInfo | null; }; -// Union type for GraphMeta that works with both legacy and new builder +// Union type for Graph (with schemas) that works with both legacy and new builder +export type GraphLike = LegacyGraph | GeneratedGraph; + +// Union type for GraphMeta (without schemas) for version detection export type GraphMetaLike = LegacyGraphMeta | GeneratedGraphMeta; export type IncompatibilityInfo = { diff --git a/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/useSubAgentUpdate.ts b/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/useSubAgentUpdate.ts index 315e337cd6..7ad10ea697 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/useSubAgentUpdate.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/hooks/useSubAgentUpdate/useSubAgentUpdate.ts @@ -1,5 +1,11 @@ import { useMemo } from "react"; -import { GraphInputSchema, GraphOutputSchema } from "@/lib/autogpt-server-api"; +import type { + GraphInputSchema, + GraphOutputSchema, +} from "@/lib/autogpt-server-api"; +import type { GraphModel } from "@/app/api/__generated__/models/graphModel"; +import { useGetV1GetSpecificGraph } from "@/app/api/__generated__/endpoints/graphs/graphs"; +import { okData } from "@/app/api/helpers"; import { getEffectiveType } from "@/lib/utils"; import { EdgeLike, getSchemaProperties, getSchemaRequired } from "./helpers"; import { @@ -11,26 +17,38 @@ import { /** * Checks if a newer version of a sub-agent is available and determines compatibility */ -export function useSubAgentUpdate( +export function useSubAgentUpdate( nodeID: string, graphID: string | undefined, graphVersion: number | undefined, currentInputSchema: GraphInputSchema | undefined, currentOutputSchema: GraphOutputSchema | undefined, connections: EdgeLike[], - availableGraphs: T[], -): SubAgentUpdateInfo { + availableGraphs: GraphMetaLike[], +): SubAgentUpdateInfo { // Find the latest version of the same graph - const latestGraph = useMemo(() => { + const latestGraphInfo = useMemo(() => { if (!graphID) return null; return availableGraphs.find((graph) => graph.id === graphID) || null; }, [graphID, availableGraphs]); - // Check if there's an update available + // Check if there's a newer version available const hasUpdate = useMemo(() => { - if (!latestGraph || graphVersion === undefined) return false; - return latestGraph.version! > graphVersion; - }, [latestGraph, graphVersion]); + if (!latestGraphInfo || graphVersion === undefined) return false; + return latestGraphInfo.version! > graphVersion; + }, [latestGraphInfo, graphVersion]); + + // Fetch full graph IF an update is detected + const { data: latestGraph } = useGetV1GetSpecificGraph( + graphID ?? 
"", + { version: latestGraphInfo?.version }, + { + query: { + enabled: hasUpdate && !!graphID && !!latestGraphInfo?.version, + select: okData, + }, + }, + ); // Get connected input and output handles for this specific node const connectedHandles = useMemo(() => { @@ -152,8 +170,8 @@ export function useSubAgentUpdate( return { hasUpdate, currentVersion: graphVersion || 0, - latestVersion: latestGraph?.version || 0, - latestGraph, + latestVersion: latestGraphInfo?.version || 0, + latestGraph: latestGraph || null, isCompatible: compatibilityResult.isCompatible, incompatibilities: compatibilityResult.incompatibilities, }; diff --git a/autogpt_platform/frontend/src/app/(platform)/build/stores/graphStore.ts b/autogpt_platform/frontend/src/app/(platform)/build/stores/graphStore.ts index 6961884732..c1eba556d2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/build/stores/graphStore.ts +++ b/autogpt_platform/frontend/src/app/(platform)/build/stores/graphStore.ts @@ -18,7 +18,7 @@ interface GraphStore { outputSchema: Record | null, ) => void; - // Available graphs; used for sub-graph updates + // Available graphs; used for sub-graph updated version detection availableSubGraphs: GraphMeta[]; setAvailableSubGraphs: (graphs: GraphMeta[]) => void; diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts index 74fd663ab2..913c4d7ded 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/components/CopilotShell/useCopilotShell.ts @@ -11,7 +11,6 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { useSupabase } from "@/lib/supabase/hooks/useSupabase"; import { useQueryClient } from "@tanstack/react-query"; import { usePathname, useSearchParams } from "next/navigation"; -import { useRef } from "react"; import { useCopilotStore } from "../../copilot-page-store"; import { useCopilotSessionId } from "../../useCopilotSessionId"; import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer"; @@ -70,41 +69,16 @@ export function useCopilotShell() { }); const stopStream = useChatStore((s) => s.stopStream); - const onStreamComplete = useChatStore((s) => s.onStreamComplete); - const isStreaming = useCopilotStore((s) => s.isStreaming); const isCreatingSession = useCopilotStore((s) => s.isCreatingSession); - const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession); - const openInterruptModal = useCopilotStore((s) => s.openInterruptModal); - const pendingActionRef = useRef<(() => void) | null>(null); - - async function stopCurrentStream() { - if (!currentSessionId) return; - - setIsSwitchingSession(true); - await new Promise((resolve) => { - const unsubscribe = onStreamComplete((completedId) => { - if (completedId === currentSessionId) { - clearTimeout(timeout); - unsubscribe(); - resolve(); - } - }); - const timeout = setTimeout(() => { - unsubscribe(); - resolve(); - }, 3000); - stopStream(currentSessionId); - }); - - queryClient.invalidateQueries({ - queryKey: getGetV2GetSessionQueryKey(currentSessionId), - }); - setIsSwitchingSession(false); - } - - function selectSession(sessionId: string) { + function handleSessionClick(sessionId: string) { if (sessionId === currentSessionId) return; + + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + 
stopStream(currentSessionId); + } + if (recentlyCreatedSessionsRef.current.has(sessionId)) { queryClient.invalidateQueries({ queryKey: getGetV2GetSessionQueryKey(sessionId), @@ -114,7 +88,12 @@ export function useCopilotShell() { if (isMobile) handleCloseDrawer(); } - function startNewChat() { + function handleNewChatClick() { + // Stop current stream - SSE reconnection allows resuming later + if (currentSessionId) { + stopStream(currentSessionId); + } + resetPagination(); queryClient.invalidateQueries({ queryKey: getGetV2ListSessionsQueryKey(), @@ -123,32 +102,6 @@ export function useCopilotShell() { if (isMobile) handleCloseDrawer(); } - function handleSessionClick(sessionId: string) { - if (sessionId === currentSessionId) return; - - if (isStreaming) { - pendingActionRef.current = async () => { - await stopCurrentStream(); - selectSession(sessionId); - }; - openInterruptModal(pendingActionRef.current); - } else { - selectSession(sessionId); - } - } - - function handleNewChatClick() { - if (isStreaming) { - pendingActionRef.current = async () => { - await stopCurrentStream(); - startNewChat(); - }; - openInterruptModal(pendingActionRef.current); - } else { - startNewChat(); - } - } - return { isMobile, isDrawerOpen, diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts index 692a5741f4..c6e479f896 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts @@ -26,8 +26,20 @@ export function buildCopilotChatUrl(prompt: string): string { export function getQuickActions(): string[] { return [ - "Show me what I can automate", - "Design a custom workflow", - "Help me with content creation", + "I don't know where to start, just ask me stuff", + "I do the same thing every week and it's killing me", + "Help me find where I'm wasting my time", ]; } + +export function getInputPlaceholder(width?: number) { + if (!width) return "What's your role and what eats up most of your day?"; + + if (width < 500) { + return "I'm a chef and I hate..."; + } + if (width <= 1080) { + return "What's your role and what eats up most of your day?"; + } + return "What's your role and what eats up most of your day? e.g. 
'I'm a recruiter and I hate...'"; +} diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx index e9bc018c1b..542173a99c 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/page.tsx @@ -6,7 +6,9 @@ import { Text } from "@/components/atoms/Text/Text"; import { Chat } from "@/components/contextual/Chat/Chat"; import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; +import { useEffect, useState } from "react"; import { useCopilotStore } from "./copilot-page-store"; +import { getInputPlaceholder } from "./helpers"; import { useCopilotPage } from "./useCopilotPage"; export default function CopilotPage() { @@ -14,8 +16,25 @@ export default function CopilotPage() { const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen); const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt); const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt); + + const [inputPlaceholder, setInputPlaceholder] = useState( + getInputPlaceholder(), + ); + + useEffect(() => { + const handleResize = () => { + setInputPlaceholder(getInputPlaceholder(window.innerWidth)); + }; + + handleResize(); + + window.addEventListener("resize", handleResize); + return () => window.removeEventListener("resize", handleResize); + }, []); + const { greetingName, quickActions, isLoading, hasSession, initialPrompt } = state; + const { handleQuickAction, startChatWithPrompt, @@ -73,7 +92,7 @@ export default function CopilotPage() { } return ( -
+
{isLoading ? (
@@ -90,25 +109,25 @@ export default function CopilotPage() {
) : ( <> -
+
Hey, {greetingName} - What do you want to automate? + Tell me about your work — I'll find what to automate.
-
+
{quickActions.map((action) => ( diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx index 0147c19a5c..b0c3a6ff7b 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-run-draft-view.tsx @@ -10,8 +10,8 @@ import React, { import { CredentialsMetaInput, CredentialsType, + Graph, GraphExecutionID, - GraphMeta, LibraryAgentPreset, LibraryAgentPresetID, LibraryAgentPresetUpdatable, @@ -69,7 +69,7 @@ export function AgentRunDraftView({ className, recommendedScheduleCron, }: { - graph: GraphMeta; + graph: Graph; agentActions?: ButtonAction[]; recommendedScheduleCron?: string | null; doRun?: ( diff --git a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-schedule-details-view.tsx b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-schedule-details-view.tsx index 61161088fc..30b0a82e65 100644 --- a/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-schedule-details-view.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/library/agents/[id]/components/OldAgentLibraryView/components/agent-schedule-details-view.tsx @@ -2,8 +2,8 @@ import React, { useCallback, useMemo } from "react"; import { + Graph, GraphExecutionID, - GraphMeta, Schedule, ScheduleID, } from "@/lib/autogpt-server-api"; @@ -35,7 +35,7 @@ export function AgentScheduleDetailsView({ onForcedRun, doDeleteSchedule, }: { - graph: GraphMeta; + graph: Graph; schedule: Schedule; agentActions: ButtonAction[]; onForcedRun: (runID: GraphExecutionID) => void; diff --git a/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts new file mode 100644 index 0000000000..336786bfdb --- /dev/null +++ b/autogpt_platform/frontend/src/app/api/chat/tasks/[taskId]/stream/route.ts @@ -0,0 +1,81 @@ +import { environment } from "@/services/environment"; +import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers"; +import { NextRequest } from "next/server"; + +/** + * SSE Proxy for task stream reconnection. + * + * This endpoint allows clients to reconnect to an ongoing or recently completed + * background task's stream. It replays missed messages from Redis Streams and + * subscribes to live updates if the task is still running. + * + * Client contract: + * 1. When receiving an operation_started event, store the task_id + * 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx} + * 3. Messages are replayed from the last_message_id position + * 4. 
Stream ends when "finish" event is received + */ +export async function GET( + request: NextRequest, + { params }: { params: Promise<{ taskId: string }> }, +) { + const { taskId } = await params; + const searchParams = request.nextUrl.searchParams; + const lastMessageId = searchParams.get("last_message_id") || "0-0"; + + try { + // Get auth token from server-side session + const token = await getServerAuthToken(); + + // Build backend URL + const backendUrl = environment.getAGPTServerBaseUrl(); + const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl); + streamUrl.searchParams.set("last_message_id", lastMessageId); + + // Forward request to backend with auth header + const headers: Record = { + Accept: "text/event-stream", + "Cache-Control": "no-cache", + Connection: "keep-alive", + }; + + if (token) { + headers["Authorization"] = `Bearer ${token}`; + } + + const response = await fetch(streamUrl.toString(), { + method: "GET", + headers, + }); + + if (!response.ok) { + const error = await response.text(); + return new Response(error, { + status: response.status, + headers: { "Content-Type": "application/json" }, + }); + } + + // Return the SSE stream directly + return new Response(response.body, { + headers: { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache, no-transform", + Connection: "keep-alive", + "X-Accel-Buffering": "no", + }, + }); + } catch (error) { + console.error("Task stream proxy error:", error); + return new Response( + JSON.stringify({ + error: "Failed to connect to task stream", + detail: error instanceof Error ? error.message : String(error), + }), + { + status: 500, + headers: { "Content-Type": "application/json" }, + }, + ); + } +} diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index aa4c49b1a2..0e9020272d 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -917,6 +917,28 @@ "security": [{ "HTTPBearerJWT": [] }] } }, + "/api/chat/config/ttl": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Ttl Config", + "description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.", + "operationId": "getV2GetTtlConfig", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "additionalProperties": true, + "type": "object", + "title": "Response Getv2Getttlconfig" + } + } + } + } + } + } + }, "/api/chat/health": { "get": { "tags": ["v2", "chat", "chat"], @@ -939,6 +961,63 @@ } } }, + "/api/chat/operations/{operation_id}/complete": { + "post": { + "tags": ["v2", "chat", "chat"], + "summary": "Complete Operation", + "description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.", + "operationId": "postV2CompleteOperation", + "parameters": [ + { + "name": "operation_id", + "in": "path", + 
"required": true, + "schema": { "type": "string", "title": "Operation Id" } + }, + { + "name": "x-api-key", + "in": "header", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "X-Api-Key" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/OperationCompleteRequest" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Postv2Completeoperation" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/chat/sessions": { "get": { "tags": ["v2", "chat", "chat"], @@ -1022,7 +1101,7 @@ "get": { "tags": ["v2", "chat", "chat"], "summary": "Get Session", - "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.", + "description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.", "operationId": "getV2GetSession", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1157,7 +1236,7 @@ "post": { "tags": ["v2", "chat", "chat"], "summary": "Stream Chat Post", - "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.", + "description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. 
First chunk is a \"start\" event\n containing the task_id for reconnection.", "operationId": "postV2StreamChatPost", "security": [{ "HTTPBearerJWT": [] }], "parameters": [ @@ -1195,6 +1274,94 @@ } } }, + "/api/chat/tasks/{task_id}": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Get Task Status", + "description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.", + "operationId": "getV2GetTaskStatus", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": true, + "title": "Response Getv2Gettaskstatus" + } + } + } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/chat/tasks/{task_id}/stream": { + "get": { + "tags": ["v2", "chat", "chat"], + "summary": "Stream Task", + "description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.", + "operationId": "getV2StreamTask", + "security": [{ "HTTPBearerJWT": [] }], + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Task Id" } + }, + { + "name": "last_message_id", + "in": "query", + "required": false, + "schema": { + "type": "string", + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.", + "default": "0-0", + "title": "Last Message Id" + }, + "description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay." + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "401": { + "$ref": "#/components/responses/HTTP401NotAuthenticatedError" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, "/api/credits": { "get": { "tags": ["v1", "credits"], @@ -5462,7 +5629,9 @@ "description": "Successful Response", "content": { "application/json": { - "schema": { "$ref": "#/components/schemas/GraphMeta" } + "schema": { + "$ref": "#/components/schemas/GraphModelWithoutNodes" + } } } }, @@ -6168,6 +6337,18 @@ "title": "AccuracyTrendsResponse", "description": "Response model for accuracy trends and alerts." 
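Putting the reconnection contract together (the task-stream proxy route above plus the /api/chat/tasks/{task_id}/stream endpoint described here), a minimal browser-side sketch follows. Assumptions are flagged in the comments: reconnectToTask is a hypothetical helper, the SSE id: field is assumed to carry the Redis Stream message ID, and the terminal "finish" signal is handled both as a data chunk and as a named event since its exact shape is not pinned down here.

// Hypothetical helper, not part of this patch. Assumes the backend sets each
// SSE "id:" field to the Redis Stream message ID so EventSource exposes it as
// event.lastEventId; adjust if the ID is carried inside the JSON payload.
function reconnectToTask(
  taskId: string,
  lastMessageId: string = "0-0",
  onChunk: (chunk: { type?: string }) => void = console.log,
): EventSource {
  // The Next.js proxy route added in this PR injects auth server-side,
  // so a plain EventSource from the browser is enough.
  const url =
    `/api/chat/tasks/${taskId}/stream` +
    `?last_message_id=${encodeURIComponent(lastMessageId)}`;
  const source = new EventSource(url);

  source.onmessage = (event) => {
    if (event.lastEventId) {
      // Persist the position so a later reconnect resumes after this message.
      localStorage.setItem(`task-stream:${taskId}`, event.lastEventId);
    }
    const chunk = JSON.parse(event.data) as { type?: string };
    onChunk(chunk);
    // Per the client contract above, the stream ends on "finish"; handled here
    // as a data chunk, with a named-event listener below as a fallback.
    if (chunk.type === "finish") source.close();
  };
  source.addEventListener("finish", () => source.close());

  return source;
}

// Usage: resume from the last persisted position, or replay fully with "0-0".
// reconnectToTask("task-123", localStorage.getItem("task-stream:task-123") ?? "0-0");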
}, + "ActiveStreamInfo": { + "properties": { + "task_id": { "type": "string", "title": "Task Id" }, + "last_message_id": { "type": "string", "title": "Last Message Id" }, + "operation_id": { "type": "string", "title": "Operation Id" }, + "tool_name": { "type": "string", "title": "Tool Name" } + }, + "type": "object", + "required": ["task_id", "last_message_id", "operation_id", "tool_name"], + "title": "ActiveStreamInfo", + "description": "Information about an active stream for reconnection." + }, "AddUserCreditsResponse": { "properties": { "new_balance": { "type": "integer", "title": "New Balance" }, @@ -6316,18 +6497,6 @@ "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Recommended Schedule Cron" }, - "nodes": { - "items": { "$ref": "#/components/schemas/Node" }, - "type": "array", - "title": "Nodes", - "default": [] - }, - "links": { - "items": { "$ref": "#/components/schemas/Link" }, - "type": "array", - "title": "Links", - "default": [] - }, "forked_from_id": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Forked From Id" @@ -6335,11 +6504,22 @@ "forked_from_version": { "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Forked From Version" + }, + "nodes": { + "items": { "$ref": "#/components/schemas/Node" }, + "type": "array", + "title": "Nodes" + }, + "links": { + "items": { "$ref": "#/components/schemas/Link" }, + "type": "array", + "title": "Links" } }, "type": "object", "required": ["name", "description"], - "title": "BaseGraph" + "title": "BaseGraph", + "description": "Graph with nodes, links, and computed I/O schema fields.\n\nUsed to represent sub-graphs within a `Graph`. Contains the full graph\nstructure including nodes and links, plus computed fields for schemas\nand trigger info. Does NOT include user_id or created_at (see GraphModel)." }, "BaseGraph-Output": { "properties": { @@ -6360,18 +6540,6 @@ "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Recommended Schedule Cron" }, - "nodes": { - "items": { "$ref": "#/components/schemas/Node" }, - "type": "array", - "title": "Nodes", - "default": [] - }, - "links": { - "items": { "$ref": "#/components/schemas/Link" }, - "type": "array", - "title": "Links", - "default": [] - }, "forked_from_id": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Forked From Id" @@ -6380,6 +6548,16 @@ "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Forked From Version" }, + "nodes": { + "items": { "$ref": "#/components/schemas/Node" }, + "type": "array", + "title": "Nodes" + }, + "links": { + "items": { "$ref": "#/components/schemas/Link" }, + "type": "array", + "title": "Links" + }, "input_schema": { "additionalProperties": true, "type": "object", @@ -6426,7 +6604,8 @@ "has_sensitive_action", "trigger_setup_info" ], - "title": "BaseGraph" + "title": "BaseGraph", + "description": "Graph with nodes, links, and computed I/O schema fields.\n\nUsed to represent sub-graphs within a `Graph`. Contains the full graph\nstructure including nodes and links, plus computed fields for schemas\nand trigger info. Does NOT include user_id or created_at (see GraphModel)." 
}, "BlockCategoryResponse": { "properties": { @@ -7220,18 +7399,6 @@ "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Recommended Schedule Cron" }, - "nodes": { - "items": { "$ref": "#/components/schemas/Node" }, - "type": "array", - "title": "Nodes", - "default": [] - }, - "links": { - "items": { "$ref": "#/components/schemas/Link" }, - "type": "array", - "title": "Links", - "default": [] - }, "forked_from_id": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Forked From Id" @@ -7240,16 +7407,26 @@ "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Forked From Version" }, + "nodes": { + "items": { "$ref": "#/components/schemas/Node" }, + "type": "array", + "title": "Nodes" + }, + "links": { + "items": { "$ref": "#/components/schemas/Link" }, + "type": "array", + "title": "Links" + }, "sub_graphs": { "items": { "$ref": "#/components/schemas/BaseGraph-Input" }, "type": "array", - "title": "Sub Graphs", - "default": [] + "title": "Sub Graphs" } }, "type": "object", "required": ["name", "description"], - "title": "Graph" + "title": "Graph", + "description": "Creatable graph model used in API create/update endpoints." }, "GraphExecution": { "properties": { @@ -7599,6 +7776,52 @@ "description": "Response schema for paginated graph executions." }, "GraphMeta": { + "properties": { + "id": { "type": "string", "title": "Id" }, + "version": { "type": "integer", "title": "Version" }, + "is_active": { + "type": "boolean", + "title": "Is Active", + "default": true + }, + "name": { "type": "string", "title": "Name" }, + "description": { "type": "string", "title": "Description" }, + "instructions": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Instructions" + }, + "recommended_schedule_cron": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Recommended Schedule Cron" + }, + "forked_from_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Forked From Id" + }, + "forked_from_version": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Forked From Version" + }, + "user_id": { "type": "string", "title": "User Id" }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + } + }, + "type": "object", + "required": [ + "id", + "version", + "name", + "description", + "user_id", + "created_at" + ], + "title": "GraphMeta", + "description": "Lightweight graph metadata model representing an existing graph from the database,\nfor use in listings and summaries.\n\nLacks `GraphModel`'s nodes, links, and expensive computed fields.\nUse for list endpoints where full graph data is not needed and performance matters." 
+ }, + "GraphModel": { "properties": { "id": { "type": "string", "title": "Id" }, "version": { "type": "integer", "title": "Version", "default": 1 }, @@ -7625,13 +7848,27 @@ "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Forked From Version" }, + "user_id": { "type": "string", "title": "User Id" }, + "created_at": { + "type": "string", + "format": "date-time", + "title": "Created At" + }, + "nodes": { + "items": { "$ref": "#/components/schemas/NodeModel" }, + "type": "array", + "title": "Nodes" + }, + "links": { + "items": { "$ref": "#/components/schemas/Link" }, + "type": "array", + "title": "Links" + }, "sub_graphs": { "items": { "$ref": "#/components/schemas/BaseGraph-Output" }, "type": "array", - "title": "Sub Graphs", - "default": [] + "title": "Sub Graphs" }, - "user_id": { "type": "string", "title": "User Id" }, "input_schema": { "additionalProperties": true, "type": "object", @@ -7678,6 +7915,7 @@ "name", "description", "user_id", + "created_at", "input_schema", "output_schema", "has_external_trigger", @@ -7686,9 +7924,10 @@ "trigger_setup_info", "credentials_input_schema" ], - "title": "GraphMeta" + "title": "GraphModel", + "description": "Full graph model representing an existing graph from the database.\n\nThis is the primary model for working with persisted graphs. Includes all\ngraph data (nodes, links, sub_graphs) plus user ownership and timestamps.\nProvides computed fields (input_schema, output_schema, etc.) used during\nset-up (frontend) and execution (backend).\n\nInherits from:\n- `Graph`: provides structure (nodes, links, sub_graphs) and computed schemas\n- `GraphMeta`: provides user_id, created_at for database records" }, - "GraphModel": { + "GraphModelWithoutNodes": { "properties": { "id": { "type": "string", "title": "Id" }, "version": { "type": "integer", "title": "Version", "default": 1 }, @@ -7707,18 +7946,6 @@ "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Recommended Schedule Cron" }, - "nodes": { - "items": { "$ref": "#/components/schemas/NodeModel" }, - "type": "array", - "title": "Nodes", - "default": [] - }, - "links": { - "items": { "$ref": "#/components/schemas/Link" }, - "type": "array", - "title": "Links", - "default": [] - }, "forked_from_id": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Forked From Id" @@ -7727,12 +7954,6 @@ "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Forked From Version" }, - "sub_graphs": { - "items": { "$ref": "#/components/schemas/BaseGraph-Output" }, - "type": "array", - "title": "Sub Graphs", - "default": [] - }, "user_id": { "type": "string", "title": "User Id" }, "created_at": { "type": "string", @@ -7794,7 +8015,8 @@ "trigger_setup_info", "credentials_input_schema" ], - "title": "GraphModel" + "title": "GraphModelWithoutNodes", + "description": "GraphModel variant that excludes nodes, links, and sub-graphs from serialization.\n\nUsed in contexts like the store where exposing internal graph structure\nis not desired. Inherits all computed fields from GraphModel but marks\nnodes and links as excluded from JSON output." 
}, "GraphSettings": { "properties": { @@ -8434,26 +8656,22 @@ "input_default": { "additionalProperties": true, "type": "object", - "title": "Input Default", - "default": {} + "title": "Input Default" }, "metadata": { "additionalProperties": true, "type": "object", - "title": "Metadata", - "default": {} + "title": "Metadata" }, "input_links": { "items": { "$ref": "#/components/schemas/Link" }, "type": "array", - "title": "Input Links", - "default": [] + "title": "Input Links" }, "output_links": { "items": { "$ref": "#/components/schemas/Link" }, "type": "array", - "title": "Output Links", - "default": [] + "title": "Output Links" } }, "type": "object", @@ -8533,26 +8751,22 @@ "input_default": { "additionalProperties": true, "type": "object", - "title": "Input Default", - "default": {} + "title": "Input Default" }, "metadata": { "additionalProperties": true, "type": "object", - "title": "Metadata", - "default": {} + "title": "Metadata" }, "input_links": { "items": { "$ref": "#/components/schemas/Link" }, "type": "array", - "title": "Input Links", - "default": [] + "title": "Input Links" }, "output_links": { "items": { "$ref": "#/components/schemas/Link" }, "type": "array", - "title": "Output Links", - "default": [] + "title": "Output Links" }, "graph_id": { "type": "string", "title": "Graph Id" }, "graph_version": { "type": "integer", "title": "Graph Version" }, @@ -8823,6 +9037,27 @@ ], "title": "OnboardingStep" }, + "OperationCompleteRequest": { + "properties": { + "success": { "type": "boolean", "title": "Success" }, + "result": { + "anyOf": [ + { "additionalProperties": true, "type": "object" }, + { "type": "string" }, + { "type": "null" } + ], + "title": "Result" + }, + "error": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Error" + } + }, + "type": "object", + "required": ["success"], + "title": "OperationCompleteRequest", + "description": "Request model for external completion webhook." 
+ }, "Pagination": { "properties": { "total_items": { @@ -9678,6 +9913,12 @@ "items": { "additionalProperties": true, "type": "object" }, "type": "array", "title": "Messages" + }, + "active_stream": { + "anyOf": [ + { "$ref": "#/components/schemas/ActiveStreamInfo" }, + { "type": "null" } + ] } }, "type": "object", diff --git a/autogpt_platform/frontend/src/components/__legacy__/ui/render.tsx b/autogpt_platform/frontend/src/components/__legacy__/ui/render.tsx index 5173326f23..b290c51809 100644 --- a/autogpt_platform/frontend/src/components/__legacy__/ui/render.tsx +++ b/autogpt_platform/frontend/src/components/__legacy__/ui/render.tsx @@ -22,7 +22,7 @@ const isValidVideoUrl = (url: string): boolean => { if (url.startsWith("data:video")) { return true; } - const videoExtensions = /\.(mp4|webm|ogg)$/i; + const videoExtensions = /\.(mp4|webm|ogg|mov|avi|mkv|m4v)$/i; const youtubeRegex = /^(https?:\/\/)?(www\.)?(youtube\.com|youtu\.?be)\/.+$/; const cleanedUrl = url.split("?")[0]; return ( @@ -44,11 +44,29 @@ const isValidAudioUrl = (url: string): boolean => { if (url.startsWith("data:audio")) { return true; } - const audioExtensions = /\.(mp3|wav)$/i; + const audioExtensions = /\.(mp3|wav|ogg|m4a|aac|flac)$/i; const cleanedUrl = url.split("?")[0]; return isValidMediaUri(url) && audioExtensions.test(cleanedUrl); }; +const getVideoMimeType = (url: string): string => { + if (url.startsWith("data:video/")) { + const match = url.match(/^data:(video\/[^;]+)/); + return match?.[1] || "video/mp4"; + } + const extension = url.split("?")[0].split(".").pop()?.toLowerCase(); + const mimeMap: Record = { + mp4: "video/mp4", + webm: "video/webm", + ogg: "video/ogg", + mov: "video/quicktime", + avi: "video/x-msvideo", + mkv: "video/x-matroska", + m4v: "video/mp4", + }; + return mimeMap[extension || ""] || "video/mp4"; +}; + const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => { const videoId = getYouTubeVideoId(videoUrl); return ( @@ -63,7 +81,7 @@ const VideoRenderer: React.FC<{ videoUrl: string }> = ({ videoUrl }) => { > ) : ( )} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx index ada8c26231..da454150bf 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/Chat.tsx @@ -1,7 +1,6 @@ "use client"; import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId"; -import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store"; import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner"; import { Text } from "@/components/atoms/Text/Text"; import { cn } from "@/lib/utils"; @@ -25,8 +24,8 @@ export function Chat({ }: ChatProps) { const { urlSessionId } = useCopilotSessionId(); const hasHandledNotFoundRef = useRef(false); - const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession); const { + session, messages, isLoading, isCreating, @@ -38,6 +37,18 @@ export function Chat({ startPollingForOperation, } = useChat({ urlSessionId }); + // Extract active stream info for reconnection + const activeStream = ( + session as { + active_stream?: { + task_id: string; + last_message_id: string; + operation_id: string; + tool_name: string; + }; + } + )?.active_stream; + useEffect(() => { if (!onSessionNotFound) return; if (!urlSessionId) return; @@ -53,8 +64,7 @@ export function Chat({ isCreating, ]); - const shouldShowLoader = - (showLoader && 
(isLoading || isCreating)) || isSwitchingSession; + const shouldShowLoader = showLoader && (isLoading || isCreating); return (
@@ -66,21 +76,19 @@ export function Chat({
- {isSwitchingSession - ? "Switching chat..." - : "Loading your chat..."} + Loading your chat...
)} {/* Error State */} - {error && !isLoading && !isSwitchingSession && ( + {error && !isLoading && ( )} {/* Session Content */} - {sessionId && !isLoading && !error && !isSwitchingSession && ( + {sessionId && !isLoading && !error && ( )} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md new file mode 100644 index 0000000000..9e78679f4e --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/SSE_RECONNECTION.md @@ -0,0 +1,159 @@ +# SSE Reconnection Contract for Long-Running Operations + +This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks. + +## Overview + +When a user triggers a long-running operation (like agent generation), the backend: + +1. Spawns a background task that survives SSE disconnections +2. Returns an `operation_started` response with a `task_id` +3. Stores stream messages in Redis Streams for replay + +Clients can reconnect to the task stream at any time to receive missed messages. + +## Client-Side Flow + +### 1. Receiving Operation Started + +When you receive an `operation_started` tool response: + +```typescript +// The response includes a task_id for reconnection +{ + type: "operation_started", + tool_name: "generate_agent", + operation_id: "uuid-...", + task_id: "task-uuid-...", // <-- Store this for reconnection + message: "Operation started. You can close this tab." +} +``` + +### 2. Storing Task Info + +Use the chat store to track the active task: + +```typescript +import { useChatStore } from "./chat-store"; + +// When operation_started is received: +useChatStore.getState().setActiveTask(sessionId, { + taskId: response.task_id, + operationId: response.operation_id, + toolName: response.tool_name, + lastMessageId: "0", +}); +``` + +### 3. Reconnecting to a Task + +To reconnect (e.g., after page refresh or tab reopen): + +```typescript +const { reconnectToTask, getActiveTask } = useChatStore.getState(); + +// Check if there's an active task for this session +const activeTask = getActiveTask(sessionId); + +if (activeTask) { + // Reconnect to the task stream + await reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, // Resume from last position + (chunk) => { + // Handle incoming chunks + console.log("Received chunk:", chunk); + }, + ); +} +``` + +### 4. 
Tracking Message Position + +To enable precise replay, update the last message ID as chunks arrive: + +```typescript +const { updateTaskLastMessageId } = useChatStore.getState(); + +function handleChunk(chunk: StreamChunk) { + // If chunk has an index/id, track it + if (chunk.idx !== undefined) { + updateTaskLastMessageId(sessionId, String(chunk.idx)); + } +} +``` + +## API Endpoints + +### Task Stream Reconnection + +``` +GET /api/chat/tasks/{taskId}/stream?last_message_id={idx} +``` + +- `taskId`: The task ID from `operation_started` +- `last_message_id`: Last received message index (default: "0" for full replay) + +Returns: SSE stream of missed messages + live updates + +## Chunk Types + +The reconnected stream follows the same Vercel AI SDK protocol: + +| Type | Description | +| ----------------------- | ----------------------- | +| `start` | Message lifecycle start | +| `text-delta` | Streaming text content | +| `text-end` | Text block completed | +| `tool-output-available` | Tool result available | +| `finish` | Stream completed | +| `error` | Error occurred | + +## Error Handling + +If reconnection fails: + +1. Check if task still exists (may have expired - default TTL: 1 hour) +2. Fall back to polling the session for final state +3. Show appropriate UI message to user + +## Persistence Considerations + +For robust reconnection across browser restarts: + +```typescript +// Store in localStorage/sessionStorage +const ACTIVE_TASKS_KEY = "chat_active_tasks"; + +function persistActiveTask(sessionId: string, task: ActiveTaskInfo) { + const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}"); + tasks[sessionId] = task; + localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks)); +} + +function loadPersistedTasks(): Record { + return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}"); +} +``` + +## Backend Configuration + +The following backend settings affect reconnection behavior: + +| Setting | Default | Description | +| ------------------- | ------- | ---------------------------------- | +| `stream_ttl` | 3600s | How long streams are kept in Redis | +| `stream_max_length` | 1000 | Max messages per stream | + +## Testing + +To test reconnection locally: + +1. Start a long-running operation (e.g., agent generation) +2. Note the `task_id` from the `operation_started` response +3. Close the browser tab +4. Reopen and call `reconnectToTask` with the saved `task_id` +5. Verify that missed messages are replayed + +See the main README for full local development setup. diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts new file mode 100644 index 0000000000..8802de2155 --- /dev/null +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-constants.ts @@ -0,0 +1,16 @@ +/** + * Constants for the chat system. + * + * Centralizes magic strings and values used across chat components. 
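+ *
+ * INITIAL_MESSAGE_ID ("0") and INITIAL_STREAM_ID ("0-0") both denote the start of a
+ * Redis Stream; passing either as `last_message_id` requests a full replay of the
+ * stream on reconnection (see SSE_RECONNECTION.md for the reconnection contract).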
+ */ + +// LocalStorage keys +export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks"; + +// Redis Stream IDs +export const INITIAL_MESSAGE_ID = "0"; +export const INITIAL_STREAM_ID = "0-0"; + +// TTL values (in milliseconds) +export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes +export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts index 8229630e5d..3083f65d2c 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-store.ts @@ -1,6 +1,12 @@ "use client"; import { create } from "zustand"; +import { + ACTIVE_TASK_TTL_MS, + COMPLETED_STREAM_TTL_MS, + INITIAL_STREAM_ID, + STORAGE_KEY_ACTIVE_TASKS, +} from "./chat-constants"; import type { ActiveStream, StreamChunk, @@ -8,15 +14,59 @@ import type { StreamResult, StreamStatus, } from "./chat-types"; -import { executeStream } from "./stream-executor"; +import { executeStream, executeTaskReconnect } from "./stream-executor"; -const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes +export interface ActiveTaskInfo { + taskId: string; + sessionId: string; + operationId: string; + toolName: string; + lastMessageId: string; + startedAt: number; +} + +/** Load active tasks from localStorage */ +function loadPersistedTasks(): Map { + if (typeof window === "undefined") return new Map(); + try { + const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS); + if (!stored) return new Map(); + const parsed = JSON.parse(stored) as Record; + const now = Date.now(); + const tasks = new Map(); + // Filter out expired tasks + for (const [sessionId, task] of Object.entries(parsed)) { + if (now - task.startedAt < ACTIVE_TASK_TTL_MS) { + tasks.set(sessionId, task); + } + } + return tasks; + } catch { + return new Map(); + } +} + +/** Save active tasks to localStorage */ +function persistTasks(tasks: Map): void { + if (typeof window === "undefined") return; + try { + const obj: Record = {}; + for (const [sessionId, task] of tasks) { + obj[sessionId] = task; + } + localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj)); + } catch { + // Ignore storage errors + } +} interface ChatStoreState { activeStreams: Map; completedStreams: Map; activeSessions: Set; streamCompleteCallbacks: Set; + /** Active tasks for SSE reconnection - keyed by sessionId */ + activeTasks: Map; } interface ChatStoreActions { @@ -41,6 +91,24 @@ interface ChatStoreActions { unregisterActiveSession: (sessionId: string) => void; isSessionActive: (sessionId: string) => boolean; onStreamComplete: (callback: StreamCompleteCallback) => () => void; + /** Track active task for SSE reconnection */ + setActiveTask: ( + sessionId: string, + taskInfo: Omit, + ) => void; + /** Get active task for a session */ + getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined; + /** Clear active task when operation completes */ + clearActiveTask: (sessionId: string) => void; + /** Reconnect to an existing task stream */ + reconnectToTask: ( + sessionId: string, + taskId: string, + lastMessageId?: string, + onChunk?: (chunk: StreamChunk) => void, + ) => Promise; + /** Update last message ID for a task (for tracking replay position) */ + updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void; } type ChatStore = ChatStoreState & ChatStoreActions; @@ -64,18 +132,126 @@ function cleanupExpiredStreams( const 
now = Date.now(); const cleaned = new Map(completedStreams); for (const [sessionId, result] of cleaned) { - if (now - result.completedAt > COMPLETED_STREAM_TTL) { + if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) { cleaned.delete(sessionId); } } return cleaned; } +/** + * Finalize a stream by moving it from activeStreams to completedStreams. + * Also handles cleanup and notifications. + */ +function finalizeStream( + sessionId: string, + stream: ActiveStream, + onChunk: ((chunk: StreamChunk) => void) | undefined, + get: () => ChatStoreState & ChatStoreActions, + set: (state: Partial) => void, +): void { + if (onChunk) stream.onChunkCallbacks.delete(onChunk); + + if (stream.status !== "streaming") { + const currentState = get(); + const finalActiveStreams = new Map(currentState.activeStreams); + let finalCompletedStreams = new Map(currentState.completedStreams); + + const storedStream = finalActiveStreams.get(sessionId); + if (storedStream === stream) { + const result: StreamResult = { + sessionId, + status: stream.status, + chunks: stream.chunks, + completedAt: Date.now(), + error: stream.error, + }; + finalCompletedStreams.set(sessionId, result); + finalActiveStreams.delete(sessionId); + finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams); + set({ + activeStreams: finalActiveStreams, + completedStreams: finalCompletedStreams, + }); + + if (stream.status === "completed" || stream.status === "error") { + notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId); + } + } + } +} + +/** + * Clean up an existing stream for a session and move it to completed streams. + * Returns updated maps for both active and completed streams. + */ +function cleanupExistingStream( + sessionId: string, + activeStreams: Map, + completedStreams: Map, + callbacks: Set, +): { + activeStreams: Map; + completedStreams: Map; +} { + const newActiveStreams = new Map(activeStreams); + let newCompletedStreams = new Map(completedStreams); + + const existingStream = newActiveStreams.get(sessionId); + if (existingStream) { + existingStream.abortController.abort(); + const normalizedStatus = + existingStream.status === "streaming" + ? "completed" + : existingStream.status; + const result: StreamResult = { + sessionId, + status: normalizedStatus, + chunks: existingStream.chunks, + completedAt: Date.now(), + error: existingStream.error, + }; + newCompletedStreams.set(sessionId, result); + newActiveStreams.delete(sessionId); + newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); + if (normalizedStatus === "completed" || normalizedStatus === "error") { + notifyStreamComplete(callbacks, sessionId); + } + } + + return { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }; +} + +/** + * Create a new active stream with initial state. 
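+ *
+ * The returned stream starts in "streaming" status with a fresh AbortController, an
+ * empty chunk buffer, and the optional onChunk callback already registered, so both
+ * startStream and reconnectToTask can hand it straight to the stream executor.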
+ */ +function createActiveStream( + sessionId: string, + onChunk?: (chunk: StreamChunk) => void, +): ActiveStream { + const abortController = new AbortController(); + const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); + if (onChunk) initialCallbacks.add(onChunk); + + return { + sessionId, + abortController, + status: "streaming", + startedAt: Date.now(), + chunks: [], + onChunkCallbacks: initialCallbacks, + }; +} + export const useChatStore = create((set, get) => ({ activeStreams: new Map(), completedStreams: new Map(), activeSessions: new Set(), streamCompleteCallbacks: new Set(), + activeTasks: loadPersistedTasks(), startStream: async function startStream( sessionId, @@ -85,45 +261,21 @@ export const useChatStore = create((set, get) => ({ onChunk, ) { const state = get(); - const newActiveStreams = new Map(state.activeStreams); - let newCompletedStreams = new Map(state.completedStreams); const callbacks = state.streamCompleteCallbacks; - const existingStream = newActiveStreams.get(sessionId); - if (existingStream) { - existingStream.abortController.abort(); - const normalizedStatus = - existingStream.status === "streaming" - ? "completed" - : existingStream.status; - const result: StreamResult = { - sessionId, - status: normalizedStatus, - chunks: existingStream.chunks, - completedAt: Date.now(), - error: existingStream.error, - }; - newCompletedStreams.set(sessionId, result); - newActiveStreams.delete(sessionId); - newCompletedStreams = cleanupExpiredStreams(newCompletedStreams); - if (normalizedStatus === "completed" || normalizedStatus === "error") { - notifyStreamComplete(callbacks, sessionId); - } - } - - const abortController = new AbortController(); - const initialCallbacks = new Set<(chunk: StreamChunk) => void>(); - if (onChunk) initialCallbacks.add(onChunk); - - const stream: ActiveStream = { + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( sessionId, - abortController, - status: "streaming", - startedAt: Date.now(), - chunks: [], - onChunkCallbacks: initialCallbacks, - }; + state.activeStreams, + state.completedStreams, + callbacks, + ); + // Create new stream + const stream = createActiveStream(sessionId, onChunk); newActiveStreams.set(sessionId, stream); set({ activeStreams: newActiveStreams, @@ -133,36 +285,7 @@ export const useChatStore = create((set, get) => ({ try { await executeStream(stream, message, isUserMessage, context); } finally { - if (onChunk) stream.onChunkCallbacks.delete(onChunk); - if (stream.status !== "streaming") { - const currentState = get(); - const finalActiveStreams = new Map(currentState.activeStreams); - let finalCompletedStreams = new Map(currentState.completedStreams); - - const storedStream = finalActiveStreams.get(sessionId); - if (storedStream === stream) { - const result: StreamResult = { - sessionId, - status: stream.status, - chunks: stream.chunks, - completedAt: Date.now(), - error: stream.error, - }; - finalCompletedStreams.set(sessionId, result); - finalActiveStreams.delete(sessionId); - finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams); - set({ - activeStreams: finalActiveStreams, - completedStreams: finalCompletedStreams, - }); - if (stream.status === "completed" || stream.status === "error") { - notifyStreamComplete( - currentState.streamCompleteCallbacks, - sessionId, - ); - } - } - } + finalizeStream(sessionId, stream, onChunk, get, set); } }, @@ -286,4 +409,93 @@ export const 
useChatStore = create((set, get) => ({ set({ streamCompleteCallbacks: cleanedCallbacks }); }; }, + + setActiveTask: function setActiveTask(sessionId, taskInfo) { + const state = get(); + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...taskInfo, + sessionId, + startedAt: Date.now(), + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + getActiveTask: function getActiveTask(sessionId) { + return get().activeTasks.get(sessionId); + }, + + clearActiveTask: function clearActiveTask(sessionId) { + const state = get(); + if (!state.activeTasks.has(sessionId)) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, + + reconnectToTask: async function reconnectToTask( + sessionId, + taskId, + lastMessageId = INITIAL_STREAM_ID, + onChunk, + ) { + const state = get(); + const callbacks = state.streamCompleteCallbacks; + + // Clean up any existing stream for this session + const { + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + } = cleanupExistingStream( + sessionId, + state.activeStreams, + state.completedStreams, + callbacks, + ); + + // Create new stream for reconnection + const stream = createActiveStream(sessionId, onChunk); + newActiveStreams.set(sessionId, stream); + set({ + activeStreams: newActiveStreams, + completedStreams: newCompletedStreams, + }); + + try { + await executeTaskReconnect(stream, taskId, lastMessageId); + } finally { + finalizeStream(sessionId, stream, onChunk, get, set); + + // Clear active task on completion + if (stream.status === "completed" || stream.status === "error") { + const taskState = get(); + if (taskState.activeTasks.has(sessionId)) { + const newActiveTasks = new Map(taskState.activeTasks); + newActiveTasks.delete(sessionId); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + } + } + } + }, + + updateTaskLastMessageId: function updateTaskLastMessageId( + sessionId, + lastMessageId, + ) { + const state = get(); + const task = state.activeTasks.get(sessionId); + if (!task) return; + + const newActiveTasks = new Map(state.activeTasks); + newActiveTasks.set(sessionId, { + ...task, + lastMessageId, + }); + set({ activeTasks: newActiveTasks }); + persistTasks(newActiveTasks); + }, })); diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts index 8c8aa7b704..34813e17fe 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/chat-types.ts @@ -4,6 +4,7 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error"; export interface StreamChunk { type: + | "stream_start" | "text_chunk" | "text_ended" | "tool_call" @@ -15,6 +16,7 @@ export interface StreamChunk { | "error" | "usage" | "stream_end"; + taskId?: string; timestamp?: string; content?: string; message?: string; @@ -41,7 +43,7 @@ export interface StreamChunk { } export type VercelStreamChunk = - | { type: "start"; messageId: string } + | { type: "start"; messageId: string; taskId?: string } | { type: "finish" } | { type: "text-start"; id: string } | { type: "text-delta"; id: string; delta: string } @@ -92,3 +94,70 @@ export interface StreamResult { } export type StreamCompleteCallback = (sessionId: string) => void; + +// Type guards for message types + +/** + * Check 
if a message has a toolId property. + */ +export function hasToolId( + msg: T, +): msg is T & { toolId: string } { + return ( + "toolId" in msg && + typeof (msg as Record).toolId === "string" + ); +} + +/** + * Check if a message has an operationId property. + */ +export function hasOperationId( + msg: T, +): msg is T & { operationId: string } { + return ( + "operationId" in msg && + typeof (msg as Record).operationId === "string" + ); +} + +/** + * Check if a message has a toolCallId property. + */ +export function hasToolCallId( + msg: T, +): msg is T & { toolCallId: string } { + return ( + "toolCallId" in msg && + typeof (msg as Record).toolCallId === "string" + ); +} + +/** + * Check if a message is an operation message type. + */ +export function isOperationMessage( + msg: T, +): msg is T & { + type: "operation_started" | "operation_pending" | "operation_in_progress"; +} { + return ( + msg.type === "operation_started" || + msg.type === "operation_pending" || + msg.type === "operation_in_progress" + ); +} + +/** + * Get the tool ID from a message if available. + * Checks toolId, operationId, and toolCallId properties. + */ +export function getToolIdFromMessage( + msg: T, +): string | undefined { + const record = msg as Record; + if (typeof record.toolId === "string") return record.toolId; + if (typeof record.operationId === "string") return record.operationId; + if (typeof record.toolCallId === "string") return record.toolCallId; + return undefined; +} diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx index dec221338a..fbf2d5d143 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/ChatContainer.tsx @@ -2,7 +2,6 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessi import { Button } from "@/components/atoms/Button/Button"; import { Text } from "@/components/atoms/Text/Text"; import { Dialog } from "@/components/molecules/Dialog/Dialog"; -import { useBreakpoint } from "@/lib/hooks/useBreakpoint"; import { cn } from "@/lib/utils"; import { GlobeHemisphereEastIcon } from "@phosphor-icons/react"; import { useEffect } from "react"; @@ -17,6 +16,13 @@ export interface ChatContainerProps { className?: string; onStreamingChange?: (isStreaming: boolean) => void; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + operationId: string; + toolName: string; + }; } export function ChatContainer({ @@ -26,6 +32,7 @@ export function ChatContainer({ className, onStreamingChange, onOperationStarted, + activeStream, }: ChatContainerProps) { const { messages, @@ -41,16 +48,13 @@ export function ChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }); useEffect(() => { onStreamingChange?.(isStreaming); }, [isStreaming, onStreamingChange]); - const breakpoint = useBreakpoint(); - const isMobile = - breakpoint === "base" || breakpoint === "sm" || breakpoint === "md"; - return (
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts index 82e9b05e88..af3b3329b7 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/createStreamEventDispatcher.ts @@ -2,6 +2,7 @@ import { toast } from "sonner"; import type { StreamChunk } from "../../chat-types"; import type { HandlerDependencies } from "./handlers"; import { + getErrorDisplayMessage, handleError, handleLoginNeeded, handleStreamEnd, @@ -24,16 +25,22 @@ export function createStreamEventDispatcher( chunk.type === "need_login" || chunk.type === "error" ) { - if (!deps.hasResponseRef.current) { - console.info("[ChatStream] First response chunk:", { - type: chunk.type, - sessionId: deps.sessionId, - }); - } deps.hasResponseRef.current = true; } switch (chunk.type) { + case "stream_start": + // Store task ID for SSE reconnection + if (chunk.taskId && deps.onActiveTaskStarted) { + deps.onActiveTaskStarted({ + taskId: chunk.taskId, + operationId: chunk.taskId, + toolName: "chat", + toolCallId: "chat_stream", + }); + } + break; + case "text_chunk": handleTextChunk(chunk, deps); break; @@ -56,11 +63,7 @@ export function createStreamEventDispatcher( break; case "stream_end": - console.info("[ChatStream] Stream ended:", { - sessionId: deps.sessionId, - hasResponse: deps.hasResponseRef.current, - chunkCount: deps.streamingChunksRef.current.length, - }); + // Note: "finish" type from backend gets normalized to "stream_end" by normalizeStreamChunk handleStreamEnd(chunk, deps); break; @@ -70,7 +73,7 @@ export function createStreamEventDispatcher( // Show toast at dispatcher level to avoid circular dependencies if (!isRegionBlocked) { toast.error("Chat Error", { - description: chunk.message || chunk.content || "An error occurred", + description: getErrorDisplayMessage(chunk), }); } break; diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts index f3cac01f96..5aec5b9818 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/handlers.ts @@ -18,11 +18,19 @@ export interface HandlerDependencies { setStreamingChunks: Dispatch>; streamingChunksRef: MutableRefObject; hasResponseRef: MutableRefObject; + textFinalizedRef: MutableRefObject; + streamEndedRef: MutableRefObject; setMessages: Dispatch>; setIsStreamingInitiated: Dispatch>; setIsRegionBlockedModalOpen: Dispatch>; sessionId: string; onOperationStarted?: () => void; + onActiveTaskStarted?: (taskInfo: { + taskId: string; + operationId: string; + toolName: string; + toolCallId: string; + }) => void; } export function isRegionBlockedError(chunk: StreamChunk): boolean { @@ -32,6 +40,25 @@ export function isRegionBlockedError(chunk: StreamChunk): boolean { return message.toLowerCase().includes("not available in your region"); } +export function getUserFriendlyErrorMessage( + code: string | undefined, +): string | undefined { + switch (code) { + case "TASK_EXPIRED": + return "This operation has expired. 
Please try again."; + case "TASK_NOT_FOUND": + return "Could not find the requested operation."; + case "ACCESS_DENIED": + return "You do not have access to this operation."; + case "QUEUE_OVERFLOW": + return "Connection was interrupted. Please refresh to continue."; + case "MODEL_NOT_AVAILABLE_REGION": + return "This model is not available in your region."; + default: + return undefined; + } +} + export function handleTextChunk(chunk: StreamChunk, deps: HandlerDependencies) { if (!chunk.content) return; deps.setHasTextChunks(true); @@ -46,10 +73,15 @@ export function handleTextEnded( _chunk: StreamChunk, deps: HandlerDependencies, ) { + if (deps.textFinalizedRef.current) { + return; + } + const completedText = deps.streamingChunksRef.current.join(""); if (completedText.trim()) { + deps.textFinalizedRef.current = true; + deps.setMessages((prev) => { - // Check if this exact message already exists to prevent duplicates const exists = prev.some( (msg) => msg.type === "message" && @@ -76,9 +108,14 @@ export function handleToolCallStart( chunk: StreamChunk, deps: HandlerDependencies, ) { + // Use deterministic fallback instead of Date.now() to ensure same ID on replay + const toolId = + chunk.tool_id || + `tool-${deps.sessionId}-${chunk.idx ?? "unknown"}-${chunk.tool_name || "unknown"}`; + const toolCallMessage: Extract = { type: "tool_call", - toolId: chunk.tool_id || `tool-${Date.now()}-${chunk.idx || 0}`, + toolId, toolName: chunk.tool_name || "Executing", arguments: chunk.arguments || {}, timestamp: new Date(), @@ -111,6 +148,29 @@ export function handleToolCallStart( deps.setMessages(updateToolCallMessages); } +const TOOL_RESPONSE_TYPES = new Set([ + "tool_response", + "operation_started", + "operation_pending", + "operation_in_progress", + "execution_started", + "agent_carousel", + "clarification_needed", +]); + +function hasResponseForTool( + messages: ChatMessageData[], + toolId: string, +): boolean { + return messages.some((msg) => { + if (!TOOL_RESPONSE_TYPES.has(msg.type)) return false; + const msgToolId = + (msg as { toolId?: string }).toolId || + (msg as { toolCallId?: string }).toolCallId; + return msgToolId === toolId; + }); +} + export function handleToolResponse( chunk: StreamChunk, deps: HandlerDependencies, @@ -152,31 +212,49 @@ export function handleToolResponse( ) { const inputsMessage = extractInputsNeeded(parsedResult, chunk.tool_name); if (inputsMessage) { - deps.setMessages((prev) => [...prev, inputsMessage]); + deps.setMessages((prev) => { + // Check for duplicate inputs_needed message + const exists = prev.some((msg) => msg.type === "inputs_needed"); + if (exists) return prev; + return [...prev, inputsMessage]; + }); } const credentialsMessage = extractCredentialsNeeded( parsedResult, chunk.tool_name, ); if (credentialsMessage) { - deps.setMessages((prev) => [...prev, credentialsMessage]); + deps.setMessages((prev) => { + // Check for duplicate credentials_needed message + const exists = prev.some((msg) => msg.type === "credentials_needed"); + if (exists) return prev; + return [...prev, credentialsMessage]; + }); } } return; } - // Trigger polling when operation_started is received if (responseMessage.type === "operation_started") { deps.onOperationStarted?.(); + const taskId = (responseMessage as { taskId?: string }).taskId; + if (taskId && deps.onActiveTaskStarted) { + deps.onActiveTaskStarted({ + taskId, + operationId: + (responseMessage as { operationId?: string }).operationId || "", + toolName: (responseMessage as { toolName?: string }).toolName || "", + 
toolCallId: (responseMessage as { toolId?: string }).toolId || "", + }); + } } deps.setMessages((prev) => { const toolCallIndex = prev.findIndex( (msg) => msg.type === "tool_call" && msg.toolId === chunk.tool_id, ); - const hasResponse = prev.some( - (msg) => msg.type === "tool_response" && msg.toolId === chunk.tool_id, - ); - if (hasResponse) return prev; + if (hasResponseForTool(prev, chunk.tool_id!)) { + return prev; + } if (toolCallIndex !== -1) { const newMessages = [...prev]; newMessages.splice(toolCallIndex + 1, 0, responseMessage); @@ -198,28 +276,48 @@ export function handleLoginNeeded( agentInfo: chunk.agent_info, timestamp: new Date(), }; - deps.setMessages((prev) => [...prev, loginNeededMessage]); + deps.setMessages((prev) => { + // Check for duplicate login_needed message + const exists = prev.some((msg) => msg.type === "login_needed"); + if (exists) return prev; + return [...prev, loginNeededMessage]; + }); } export function handleStreamEnd( _chunk: StreamChunk, deps: HandlerDependencies, ) { + if (deps.streamEndedRef.current) { + return; + } + deps.streamEndedRef.current = true; + const completedContent = deps.streamingChunksRef.current.join(""); if (!completedContent.trim() && !deps.hasResponseRef.current) { - deps.setMessages((prev) => [ - ...prev, - { - type: "message", - role: "assistant", - content: "No response received. Please try again.", - timestamp: new Date(), - }, - ]); - } - if (completedContent.trim()) { deps.setMessages((prev) => { - // Check if this exact message already exists to prevent duplicates + const exists = prev.some( + (msg) => + msg.type === "message" && + msg.role === "assistant" && + msg.content === "No response received. Please try again.", + ); + if (exists) return prev; + return [ + ...prev, + { + type: "message", + role: "assistant", + content: "No response received. 
Please try again.", + timestamp: new Date(), + }, + ]; + }); + } + if (completedContent.trim() && !deps.textFinalizedRef.current) { + deps.textFinalizedRef.current = true; + + deps.setMessages((prev) => { const exists = prev.some( (msg) => msg.type === "message" && @@ -244,8 +342,6 @@ export function handleStreamEnd( } export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { - const errorMessage = chunk.message || chunk.content || "An error occurred"; - console.error("Stream error:", errorMessage); if (isRegionBlockedError(chunk)) { deps.setIsRegionBlockedModalOpen(true); } @@ -253,4 +349,14 @@ export function handleError(chunk: StreamChunk, deps: HandlerDependencies) { deps.setHasTextChunks(false); deps.setStreamingChunks([]); deps.streamingChunksRef.current = []; + deps.textFinalizedRef.current = false; + deps.streamEndedRef.current = true; +} + +export function getErrorDisplayMessage(chunk: StreamChunk): string { + const friendlyMessage = getUserFriendlyErrorMessage(chunk.code); + if (friendlyMessage) { + return friendlyMessage; + } + return chunk.message || chunk.content || "An error occurred"; } diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts index e744c9bc34..f1e94cea17 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/helpers.ts @@ -349,6 +349,7 @@ export function parseToolResponse( toolName: (parsedResult.tool_name as string) || toolName, toolId, operationId: (parsedResult.operation_id as string) || "", + taskId: (parsedResult.task_id as string) || undefined, // For SSE reconnection message: (parsedResult.message as string) || "Operation started. 
You can close this tab.", diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts index 46f384d055..248383df42 100644 --- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts +++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatContainer/useChatContainer.ts @@ -1,10 +1,17 @@ import type { SessionDetailResponse } from "@/app/api/__generated__/models/sessionDetailResponse"; import { useEffect, useMemo, useRef, useState } from "react"; +import { INITIAL_STREAM_ID } from "../../chat-constants"; import { useChatStore } from "../../chat-store"; import { toast } from "sonner"; import { useChatStream } from "../../useChatStream"; import { usePageContext } from "../../usePageContext"; import type { ChatMessageData } from "../ChatMessage/useChatMessage"; +import { + getToolIdFromMessage, + hasToolId, + isOperationMessage, + type StreamChunk, +} from "../../chat-types"; import { createStreamEventDispatcher } from "./createStreamEventDispatcher"; import { createUserMessage, @@ -14,6 +21,13 @@ import { processInitialMessages, } from "./helpers"; +const TOOL_RESULT_TYPES = new Set([ + "tool_response", + "agent_carousel", + "execution_started", + "clarification_needed", +]); + // Helper to generate deduplication key for a message function getMessageKey(msg: ChatMessageData): string { if (msg.type === "message") { @@ -23,14 +37,18 @@ function getMessageKey(msg: ChatMessageData): string { return `msg:${msg.role}:${msg.content}`; } else if (msg.type === "tool_call") { return `toolcall:${msg.toolId}`; - } else if (msg.type === "tool_response") { - return `toolresponse:${(msg as any).toolId}`; - } else if ( - msg.type === "operation_started" || - msg.type === "operation_pending" || - msg.type === "operation_in_progress" - ) { - return `op:${(msg as any).toolId || (msg as any).operationId || (msg as any).toolCallId || ""}:${msg.toolName}`; + } else if (TOOL_RESULT_TYPES.has(msg.type)) { + // Unified key for all tool result types - same toolId with different types + // (tool_response vs agent_carousel) should deduplicate to the same key + const toolId = getToolIdFromMessage(msg); + // If no toolId, fall back to content-based key to avoid empty key collisions + if (!toolId) { + return `toolresult:content:${JSON.stringify(msg).slice(0, 200)}`; + } + return `toolresult:${toolId}`; + } else if (isOperationMessage(msg)) { + const toolId = getToolIdFromMessage(msg) || ""; + return `op:${toolId}:${msg.toolName}`; } else { return `${msg.type}:${JSON.stringify(msg).slice(0, 100)}`; } @@ -41,6 +59,13 @@ interface Args { initialMessages: SessionDetailResponse["messages"]; initialPrompt?: string; onOperationStarted?: () => void; + /** Active stream info from the server for reconnection */ + activeStream?: { + taskId: string; + lastMessageId: string; + operationId: string; + toolName: string; + }; } export function useChatContainer({ @@ -48,6 +73,7 @@ export function useChatContainer({ initialMessages, initialPrompt, onOperationStarted, + activeStream, }: Args) { const [messages, setMessages] = useState([]); const [streamingChunks, setStreamingChunks] = useState([]); @@ -57,6 +83,8 @@ export function useChatContainer({ useState(false); const hasResponseRef = useRef(false); const streamingChunksRef = useRef([]); + const textFinalizedRef = useRef(false); + const streamEndedRef = 
useRef(false); const previousSessionIdRef = useRef(null); const { error, @@ -65,44 +93,182 @@ export function useChatContainer({ } = useChatStream(); const activeStreams = useChatStore((s) => s.activeStreams); const subscribeToStream = useChatStore((s) => s.subscribeToStream); + const setActiveTask = useChatStore((s) => s.setActiveTask); + const getActiveTask = useChatStore((s) => s.getActiveTask); + const reconnectToTask = useChatStore((s) => s.reconnectToTask); const isStreaming = isStreamingInitiated || hasTextChunks; + // Track whether we've already connected to this activeStream to avoid duplicate connections + const connectedActiveStreamRef = useRef(null); + // Track if component is mounted to prevent state updates after unmount + const isMountedRef = useRef(true); + // Track current dispatcher to prevent multiple dispatchers from adding messages + const currentDispatcherIdRef = useRef(0); + + // Set mounted flag - reset on every mount, cleanup on unmount + useEffect(function trackMountedState() { + isMountedRef.current = true; + return function cleanup() { + isMountedRef.current = false; + }; + }, []); + + // Callback to store active task info for SSE reconnection + function handleActiveTaskStarted(taskInfo: { + taskId: string; + operationId: string; + toolName: string; + toolCallId: string; + }) { + if (!sessionId) return; + setActiveTask(sessionId, { + taskId: taskInfo.taskId, + operationId: taskInfo.operationId, + toolName: taskInfo.toolName, + lastMessageId: INITIAL_STREAM_ID, + }); + } + + // Create dispatcher for stream events - stable reference for current sessionId + // Each dispatcher gets a unique ID to prevent stale dispatchers from updating state + function createDispatcher() { + if (!sessionId) return () => {}; + // Increment dispatcher ID - only the most recent dispatcher should update state + const dispatcherId = ++currentDispatcherIdRef.current; + + const baseDispatcher = createStreamEventDispatcher({ + setHasTextChunks, + setStreamingChunks, + streamingChunksRef, + hasResponseRef, + textFinalizedRef, + streamEndedRef, + setMessages, + setIsRegionBlockedModalOpen, + sessionId, + setIsStreamingInitiated, + onOperationStarted, + onActiveTaskStarted: handleActiveTaskStarted, + }); + + // Wrap dispatcher to check if it's still the current one + return function guardedDispatcher(chunk: StreamChunk) { + // Skip if component unmounted or this is a stale dispatcher + if (!isMountedRef.current) { + return; + } + if (dispatcherId !== currentDispatcherIdRef.current) { + return; + } + baseDispatcher(chunk); + }; + } useEffect( function handleSessionChange() { - if (sessionId === previousSessionIdRef.current) return; + const isSessionChange = sessionId !== previousSessionIdRef.current; - const prevSession = previousSessionIdRef.current; - if (prevSession) { - stopStreaming(prevSession); + // Handle session change - reset state + if (isSessionChange) { + const prevSession = previousSessionIdRef.current; + if (prevSession) { + stopStreaming(prevSession); + } + previousSessionIdRef.current = sessionId; + connectedActiveStreamRef.current = null; + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + setIsStreamingInitiated(false); + hasResponseRef.current = false; + textFinalizedRef.current = false; + streamEndedRef.current = false; } - previousSessionIdRef.current = sessionId; - setMessages([]); - setStreamingChunks([]); - streamingChunksRef.current = []; - setHasTextChunks(false); - setIsStreamingInitiated(false); - 
hasResponseRef.current = false; if (!sessionId) return; - const activeStream = activeStreams.get(sessionId); - if (!activeStream || activeStream.status !== "streaming") return; + // Priority 1: Check if server told us there's an active stream (most authoritative) + if (activeStream) { + const streamKey = `${sessionId}:${activeStream.taskId}`; - const dispatcher = createStreamEventDispatcher({ - setHasTextChunks, - setStreamingChunks, - streamingChunksRef, - hasResponseRef, - setMessages, - setIsRegionBlockedModalOpen, - sessionId, - setIsStreamingInitiated, - onOperationStarted, - }); + if (connectedActiveStreamRef.current === streamKey) { + return; + } + + // Skip if there's already an active stream for this session in the store + const existingStream = activeStreams.get(sessionId); + if (existingStream && existingStream.status === "streaming") { + connectedActiveStreamRef.current = streamKey; + return; + } + + connectedActiveStreamRef.current = streamKey; + + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + setActiveTask(sessionId, { + taskId: activeStream.taskId, + operationId: activeStream.operationId, + toolName: activeStream.toolName, + lastMessageId: activeStream.lastMessageId, + }); + reconnectToTask( + sessionId, + activeStream.taskId, + activeStream.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + // and the stream will complete naturally. Cleanup would prematurely stop + // the stream when effect re-runs due to activeStreams changing. + return; + } + + // Only check localStorage/in-memory on session change + if (!isSessionChange) return; + + // Priority 2: Check localStorage for active task + const activeTask = getActiveTask(sessionId); + if (activeTask) { + // Clear all state before reconnection to prevent duplicates + // Server's initialMessages is authoritative; local state will be rebuilt from SSE replay + setMessages([]); + setStreamingChunks([]); + streamingChunksRef.current = []; + setHasTextChunks(false); + textFinalizedRef.current = false; + streamEndedRef.current = false; + hasResponseRef.current = false; + + setIsStreamingInitiated(true); + reconnectToTask( + sessionId, + activeTask.taskId, + activeTask.lastMessageId, + createDispatcher(), + ); + // Don't return cleanup here - the guarded dispatcher handles stale events + return; + } + + // Priority 3: Check for an in-memory active stream (same-tab scenario) + const inMemoryStream = activeStreams.get(sessionId); + if (!inMemoryStream || inMemoryStream.status !== "streaming") { + return; + } setIsStreamingInitiated(true); const skipReplay = initialMessages.length > 0; - return subscribeToStream(sessionId, dispatcher, skipReplay); + return subscribeToStream(sessionId, createDispatcher(), skipReplay); }, [ sessionId, @@ -110,6 +276,10 @@ export function useChatContainer({ activeStreams, subscribeToStream, onOperationStarted, + getActiveTask, + reconnectToTask, + activeStream, + setActiveTask, ], ); @@ -124,7 +294,7 @@ export function useChatContainer({ msg.type === "agent_carousel" || msg.type === "execution_started" ) { - const toolId = (msg as any).toolId; + const toolId = hasToolId(msg) ? 
         if (toolId) {
           ids.add(toolId);
         }
@@ -141,12 +311,8 @@
     setMessages((prev) => {
       const filtered = prev.filter((msg) => {
-        if (
-          msg.type === "operation_started" ||
-          msg.type === "operation_pending" ||
-          msg.type === "operation_in_progress"
-        ) {
-          const toolId = (msg as any).toolId || (msg as any).toolCallId;
+        if (isOperationMessage(msg)) {
+          const toolId = getToolIdFromMessage(msg);
          if (toolId && completedToolIds.has(toolId)) {
            return false; // Remove - operation completed
          }
@@ -174,12 +340,8 @@
     // Filter local messages: remove duplicates and completed operation messages
     const newLocalMessages = messages.filter((msg) => {
       // Remove operation messages for completed tools
-      if (
-        msg.type === "operation_started" ||
-        msg.type === "operation_pending" ||
-        msg.type === "operation_in_progress"
-      ) {
-        const toolId = (msg as any).toolId || (msg as any).toolCallId;
+      if (isOperationMessage(msg)) {
+        const toolId = getToolIdFromMessage(msg);
        if (toolId && completedToolIds.has(toolId)) {
          return false;
        }
@@ -190,7 +352,70 @@
     });

     // Server messages first (correct order), then new local messages
-    return [...processedInitial, ...newLocalMessages];
+    const combined = [...processedInitial, ...newLocalMessages];
+
+    // Post-processing: Remove duplicate assistant messages that can occur during
+    // race conditions (e.g., rapid screen switching during SSE reconnection).
+    // Two assistant messages are considered duplicates if:
+    // - They are both text messages with role "assistant"
+    // - One message's content starts with the other's content (partial vs complete)
+    // - Or they have very similar content (>80% overlap at the start)
+    const deduplicated: ChatMessageData[] = [];
+    for (let i = 0; i < combined.length; i++) {
+      const current = combined[i];
+
+      // Check if this is an assistant text message
+      if (current.type !== "message" || current.role !== "assistant") {
+        deduplicated.push(current);
+        continue;
+      }
+
+      // Look for duplicate assistant messages in the rest of the array
+      let dominated = false;
+      for (let j = 0; j < combined.length; j++) {
+        if (i === j) continue;
+        const other = combined[j];
+        if (other.type !== "message" || other.role !== "assistant") continue;
+
+        const currentContent = current.content || "";
+        const otherContent = other.content || "";
+
+        // Skip empty messages
+        if (!currentContent.trim() || !otherContent.trim()) continue;
+
+        // Check if current is a prefix of other (current is incomplete version)
+        if (
+          otherContent.length > currentContent.length &&
+          otherContent.startsWith(currentContent.slice(0, 100))
+        ) {
+          // Current is a shorter/incomplete version of other - skip it
+          dominated = true;
+          break;
+        }
+
+        // Check if messages are nearly identical (within a small difference)
+        // This catches cases where content differs only slightly
+        const minLen = Math.min(currentContent.length, otherContent.length);
+        const compareLen = Math.min(minLen, 200); // Compare first 200 chars
+        if (
+          compareLen > 50 &&
+          currentContent.slice(0, compareLen) ===
+            otherContent.slice(0, compareLen)
+        ) {
+          // Same prefix - keep the longer one
+          if (otherContent.length > currentContent.length) {
+            dominated = true;
+            break;
+          }
+        }
+      }
+
+      if (!dominated) {
+        deduplicated.push(current);
+      }
+    }
+
+    return deduplicated;
   }, [initialMessages, messages, completedToolIds]);

   async function sendMessage(
     content: string,
     isUserMessage: boolean = true,
     context?: { url: string; content: string },
   ) {
-    if (!sessionId) {
-      console.error("[useChatContainer] Cannot send message: no session ID");
-      return;
-    }
+    if (!sessionId) return;
+    setIsRegionBlockedModalOpen(false);
     if (isUserMessage) {
       const userMessage = createUserMessage(content);
@@ -214,31 +437,19 @@
     setHasTextChunks(false);
     setIsStreamingInitiated(true);
     hasResponseRef.current = false;
-
-    const dispatcher = createStreamEventDispatcher({
-      setHasTextChunks,
-      setStreamingChunks,
-      streamingChunksRef,
-      hasResponseRef,
-      setMessages,
-      setIsRegionBlockedModalOpen,
-      sessionId,
-      setIsStreamingInitiated,
-      onOperationStarted,
-    });
+    textFinalizedRef.current = false;
+    streamEndedRef.current = false;
     try {
       await sendStreamMessage(
         sessionId,
         content,
-        dispatcher,
+        createDispatcher(),
         isUserMessage,
         context,
       );
     } catch (err) {
-      console.error("[useChatContainer] Failed to send message:", err);
       setIsStreamingInitiated(false);
-      if (err instanceof Error && err.name === "AbortError") return;
       const errorMessage =
diff --git a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx
index beb4678e73..bac004f6ed 100644
--- a/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx
+++ b/autogpt_platform/frontend/src/components/contextual/Chat/components/ChatInput/ChatInput.tsx
@@ -74,19 +74,20 @@ export function ChatInput({
           hasMultipleLines ? "rounded-xlarge" : "rounded-full",
         )}
       >
+        {!value && !isRecording && (
+        )}