mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-02-04 11:55:11 -05:00
Compare commits
31 Commits
autogpt-pl
...
feature/vi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
059c94afac | ||
|
|
3ee7c9bfa8 | ||
|
|
0fde14bf23 | ||
|
|
e8b33f9dbe | ||
|
|
6d6d3b820e | ||
|
|
8b5c018032 | ||
|
|
b5611b00b3 | ||
|
|
6cd62c4d50 | ||
|
|
9f4c33a695 | ||
|
|
b0debe9488 | ||
|
|
b20767bde9 | ||
|
|
b9a9481381 | ||
|
|
d2d2a0c0c9 | ||
|
|
521f69220d | ||
|
|
368adc985d | ||
|
|
8c3216f0a2 | ||
|
|
94063616e5 | ||
|
|
2433a86cb1 | ||
|
|
0ede203f8e | ||
|
|
dc751316c5 | ||
|
|
e7fb54e6af | ||
|
|
7b76f4d1e4 | ||
|
|
3cc56de0fa | ||
|
|
d2bead0f7a | ||
|
|
f8d3893c16 | ||
|
|
1cfbc0dd08 | ||
|
|
ff84643b48 | ||
|
|
c19c3c834a | ||
|
|
d0f7ba8cfd | ||
|
|
2a855f4bd0 | ||
|
|
b93bb3b9f8 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -180,4 +180,3 @@ autogpt_platform/backend/settings.py
|
||||
.claude/settings.local.json
|
||||
CLAUDE.local.md
|
||||
/autogpt_platform/backend/logs
|
||||
.next
|
||||
@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following
|
||||
### Updated Setup Instructions:
|
||||
We've moved to a fully maintained and regularly updated documentation site.
|
||||
|
||||
👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
|
||||
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)
|
||||
|
||||
|
||||
This tutorial assumes you have Docker, VSCode, git and npm installed.
|
||||
|
||||
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
|
||||
REVID_API_KEY=
|
||||
SCREENSHOTONE_API_KEY=
|
||||
UNREAL_SPEECH_API_KEY=
|
||||
ELEVENLABS_API_KEY=
|
||||
|
||||
# Data & Search Services
|
||||
E2B_API_KEY=
|
||||
|
||||
@@ -62,10 +62,11 @@ ENV POETRY_HOME=/opt/poetry \
|
||||
DEBIAN_FRONTEND=noninteractive
|
||||
ENV PATH=/opt/poetry/bin:$PATH
|
||||
|
||||
# Install Python without upgrading system-managed packages
|
||||
# Install Python and FFmpeg (required for video processing blocks)
|
||||
RUN apt-get update && apt-get install -y \
|
||||
python3.13 \
|
||||
python3-pip \
|
||||
ffmpeg \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Copy only necessary files from builder
|
||||
|
||||
@@ -1,368 +0,0 @@
|
||||
"""Redis Streams consumer for operation completion messages.
|
||||
|
||||
This module provides a consumer (ChatCompletionConsumer) that listens for
|
||||
completion notifications (OperationCompleteMessage) from external services
|
||||
(like Agent Generator) and triggers the appropriate stream registry and
|
||||
chat service updates via process_operation_success/process_operation_failure.
|
||||
|
||||
Why Redis Streams instead of RabbitMQ?
|
||||
--------------------------------------
|
||||
While the project typically uses RabbitMQ for async task queues (e.g., execution
|
||||
queue), Redis Streams was chosen for chat completion notifications because:
|
||||
|
||||
1. **Unified Infrastructure**: The SSE reconnection feature already uses Redis
|
||||
Streams (via stream_registry) for message persistence and replay. Using Redis
|
||||
Streams for completion notifications keeps all chat streaming infrastructure
|
||||
in one system, simplifying operations and reducing cross-system coordination.
|
||||
|
||||
2. **Message Replay**: Redis Streams support XREAD with arbitrary message IDs,
|
||||
allowing consumers to replay missed messages after reconnection. This aligns
|
||||
with the SSE reconnection pattern where clients can resume from last_message_id.
|
||||
|
||||
3. **Consumer Groups with XAUTOCLAIM**: Redis consumer groups provide automatic
|
||||
load balancing across pods with explicit message claiming (XAUTOCLAIM) for
|
||||
recovering from dead consumers - ideal for the completion callback pattern.
|
||||
|
||||
4. **Lower Latency**: For real-time SSE updates, Redis (already in-memory for
|
||||
stream_registry) provides lower latency than an additional RabbitMQ hop.
|
||||
|
||||
5. **Atomicity with Task State**: Completion processing often needs to update
|
||||
task metadata stored in Redis. Keeping both in Redis enables simpler
|
||||
transactional semantics without distributed coordination.
|
||||
|
||||
The consumer uses Redis Streams with consumer groups for reliable message
|
||||
processing across multiple platform pods, with XAUTOCLAIM for reclaiming
|
||||
stale pending messages from dead consumers.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
from pydantic import BaseModel
|
||||
from redis.exceptions import ResponseError
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
class OperationCompleteMessage(BaseModel):
|
||||
"""Message format for operation completion notifications."""
|
||||
|
||||
operation_id: str
|
||||
task_id: str
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
class ChatCompletionConsumer:
|
||||
"""Consumer for chat operation completion messages from Redis Streams.
|
||||
|
||||
This consumer initializes its own Prisma client in start() to ensure
|
||||
database operations work correctly within this async context.
|
||||
|
||||
Uses Redis consumer groups to allow multiple platform pods to consume
|
||||
messages reliably with automatic redelivery on failure.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._consumer_task: asyncio.Task | None = None
|
||||
self._running = False
|
||||
self._prisma: Prisma | None = None
|
||||
self._consumer_name = f"consumer-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the completion consumer."""
|
||||
if self._running:
|
||||
logger.warning("Completion consumer already running")
|
||||
return
|
||||
|
||||
# Create consumer group if it doesn't exist
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
await redis.xgroup_create(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
id="0",
|
||||
mkstream=True,
|
||||
)
|
||||
logger.info(
|
||||
f"Created consumer group '{config.stream_consumer_group}' "
|
||||
f"on stream '{config.stream_completion_name}'"
|
||||
)
|
||||
except ResponseError as e:
|
||||
if "BUSYGROUP" in str(e):
|
||||
logger.debug(
|
||||
f"Consumer group '{config.stream_consumer_group}' already exists"
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
self._running = True
|
||||
self._consumer_task = asyncio.create_task(self._consume_messages())
|
||||
logger.info(
|
||||
f"Chat completion consumer started (consumer: {self._consumer_name})"
|
||||
)
|
||||
|
||||
async def _ensure_prisma(self) -> Prisma:
|
||||
"""Lazily initialize Prisma client on first use."""
|
||||
if self._prisma is None:
|
||||
database_url = os.getenv("DATABASE_URL", "postgresql://localhost:5432")
|
||||
self._prisma = Prisma(datasource={"url": database_url})
|
||||
await self._prisma.connect()
|
||||
logger.info("[COMPLETION] Consumer Prisma client connected (lazy init)")
|
||||
return self._prisma
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""Stop the completion consumer."""
|
||||
self._running = False
|
||||
|
||||
if self._consumer_task:
|
||||
self._consumer_task.cancel()
|
||||
try:
|
||||
await self._consumer_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._consumer_task = None
|
||||
|
||||
if self._prisma:
|
||||
await self._prisma.disconnect()
|
||||
self._prisma = None
|
||||
logger.info("[COMPLETION] Consumer Prisma client disconnected")
|
||||
|
||||
logger.info("Chat completion consumer stopped")
|
||||
|
||||
async def _consume_messages(self) -> None:
|
||||
"""Main message consumption loop with retry logic."""
|
||||
max_retries = 10
|
||||
retry_delay = 5 # seconds
|
||||
retry_count = 0
|
||||
block_timeout = 5000 # milliseconds
|
||||
|
||||
while self._running and retry_count < max_retries:
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Reset retry count on successful connection
|
||||
retry_count = 0
|
||||
|
||||
while self._running:
|
||||
# First, claim any stale pending messages from dead consumers
|
||||
# Redis does NOT auto-redeliver pending messages; we must explicitly
|
||||
# claim them using XAUTOCLAIM
|
||||
try:
|
||||
claimed_result = await redis.xautoclaim(
|
||||
name=config.stream_completion_name,
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
min_idle_time=config.stream_claim_min_idle_ms,
|
||||
start_id="0-0",
|
||||
count=10,
|
||||
)
|
||||
# xautoclaim returns: (next_start_id, [(id, data), ...], [deleted_ids])
|
||||
if claimed_result and len(claimed_result) >= 2:
|
||||
claimed_entries = claimed_result[1]
|
||||
if claimed_entries:
|
||||
logger.info(
|
||||
f"Claimed {len(claimed_entries)} stale pending messages"
|
||||
)
|
||||
for entry_id, data in claimed_entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
except Exception as e:
|
||||
logger.warning(f"XAUTOCLAIM failed (non-fatal): {e}")
|
||||
|
||||
# Read new messages from the stream
|
||||
messages = await redis.xreadgroup(
|
||||
groupname=config.stream_consumer_group,
|
||||
consumername=self._consumer_name,
|
||||
streams={config.stream_completion_name: ">"},
|
||||
block=block_timeout,
|
||||
count=10,
|
||||
)
|
||||
|
||||
if not messages:
|
||||
continue
|
||||
|
||||
for stream_name, entries in messages:
|
||||
for entry_id, data in entries:
|
||||
if not self._running:
|
||||
return
|
||||
await self._process_entry(redis, entry_id, data)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Consumer cancelled")
|
||||
return
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
logger.error(
|
||||
f"Consumer error (retry {retry_count}/{max_retries}): {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
if self._running and retry_count < max_retries:
|
||||
await asyncio.sleep(retry_delay)
|
||||
else:
|
||||
logger.error("Max retries reached, stopping consumer")
|
||||
return
|
||||
|
||||
async def _process_entry(
|
||||
self, redis: Any, entry_id: str, data: dict[str, Any]
|
||||
) -> None:
|
||||
"""Process a single stream entry and acknowledge it on success.
|
||||
|
||||
Args:
|
||||
redis: Redis client connection
|
||||
entry_id: The stream entry ID
|
||||
data: The entry data dict
|
||||
"""
|
||||
try:
|
||||
# Handle the message
|
||||
message_data = data.get("data")
|
||||
if message_data:
|
||||
await self._handle_message(
|
||||
message_data.encode()
|
||||
if isinstance(message_data, str)
|
||||
else message_data
|
||||
)
|
||||
|
||||
# Acknowledge the message after successful processing
|
||||
await redis.xack(
|
||||
config.stream_completion_name,
|
||||
config.stream_consumer_group,
|
||||
entry_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing completion message {entry_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Message remains in pending state and will be claimed by
|
||||
# XAUTOCLAIM after min_idle_time expires
|
||||
|
||||
async def _handle_message(self, body: bytes) -> None:
|
||||
"""Handle a completion message using our own Prisma client."""
|
||||
try:
|
||||
data = orjson.loads(body)
|
||||
message = OperationCompleteMessage(**data)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse completion message: {e}")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Received completion for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id}, success={message.success})"
|
||||
)
|
||||
|
||||
# Find task in registry
|
||||
task = await stream_registry.find_task_by_operation_id(message.operation_id)
|
||||
if task is None:
|
||||
task = await stream_registry.get_task(message.task_id)
|
||||
|
||||
if task is None:
|
||||
logger.warning(
|
||||
f"[COMPLETION] Task not found for operation {message.operation_id} "
|
||||
f"(task_id={message.task_id})"
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Found task: task_id={task.task_id}, "
|
||||
f"session_id={task.session_id}, tool_call_id={task.tool_call_id}"
|
||||
)
|
||||
|
||||
# Guard against empty task fields
|
||||
if not task.task_id or not task.session_id or not task.tool_call_id:
|
||||
logger.error(
|
||||
f"[COMPLETION] Task has empty critical fields! "
|
||||
f"task_id={task.task_id!r}, session_id={task.session_id!r}, "
|
||||
f"tool_call_id={task.tool_call_id!r}"
|
||||
)
|
||||
return
|
||||
|
||||
if message.success:
|
||||
await self._handle_success(task, message)
|
||||
else:
|
||||
await self._handle_failure(task, message)
|
||||
|
||||
async def _handle_success(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle successful operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_success(task, message.result, prisma)
|
||||
|
||||
async def _handle_failure(
|
||||
self,
|
||||
task: stream_registry.ActiveTask,
|
||||
message: OperationCompleteMessage,
|
||||
) -> None:
|
||||
"""Handle failed operation completion."""
|
||||
prisma = await self._ensure_prisma()
|
||||
await process_operation_failure(task, message.error, prisma)
|
||||
|
||||
|
||||
# Module-level consumer instance
|
||||
_consumer: ChatCompletionConsumer | None = None
|
||||
|
||||
|
||||
async def start_completion_consumer() -> None:
|
||||
"""Start the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer is None:
|
||||
_consumer = ChatCompletionConsumer()
|
||||
await _consumer.start()
|
||||
|
||||
|
||||
async def stop_completion_consumer() -> None:
|
||||
"""Stop the global completion consumer."""
|
||||
global _consumer
|
||||
if _consumer:
|
||||
await _consumer.stop()
|
||||
_consumer = None
|
||||
|
||||
|
||||
async def publish_operation_complete(
|
||||
operation_id: str,
|
||||
task_id: str,
|
||||
success: bool,
|
||||
result: dict | str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Publish an operation completion message to Redis Streams.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID that completed.
|
||||
task_id: The task ID associated with the operation.
|
||||
success: Whether the operation succeeded.
|
||||
result: The result data (for success).
|
||||
error: The error message (for failure).
|
||||
"""
|
||||
message = OperationCompleteMessage(
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
success=success,
|
||||
result=result,
|
||||
error=error,
|
||||
)
|
||||
|
||||
redis = await get_redis_async()
|
||||
await redis.xadd(
|
||||
config.stream_completion_name,
|
||||
{"data": message.model_dump_json()},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
logger.info(f"Published completion for operation {operation_id}")
|
||||
@@ -1,344 +0,0 @@
|
||||
"""Shared completion handling for operation success and failure.
|
||||
|
||||
This module provides common logic for handling operation completion from both:
|
||||
- The Redis Streams consumer (completion_consumer.py)
|
||||
- The HTTP webhook endpoint (routes.py)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
import orjson
|
||||
from prisma import Prisma
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .response_model import StreamError, StreamToolOutputAvailable
|
||||
from .tools.models import ErrorResponse
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Tools that produce agent_json that needs to be saved to library
|
||||
AGENT_GENERATION_TOOLS = {"create_agent", "edit_agent"}
|
||||
|
||||
# Keys that should be stripped from agent_json when returning in error responses
|
||||
SENSITIVE_KEYS = frozenset(
|
||||
{
|
||||
"api_key",
|
||||
"apikey",
|
||||
"api_secret",
|
||||
"password",
|
||||
"secret",
|
||||
"credentials",
|
||||
"credential",
|
||||
"token",
|
||||
"access_token",
|
||||
"refresh_token",
|
||||
"private_key",
|
||||
"privatekey",
|
||||
"auth",
|
||||
"authorization",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_agent_json(obj: Any) -> Any:
|
||||
"""Recursively sanitize agent_json by removing sensitive keys.
|
||||
|
||||
Args:
|
||||
obj: The object to sanitize (dict, list, or primitive)
|
||||
|
||||
Returns:
|
||||
Sanitized copy with sensitive keys removed/redacted
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {
|
||||
k: "[REDACTED]" if k.lower() in SENSITIVE_KEYS else _sanitize_agent_json(v)
|
||||
for k, v in obj.items()
|
||||
}
|
||||
elif isinstance(obj, list):
|
||||
return [_sanitize_agent_json(item) for item in obj]
|
||||
else:
|
||||
return obj
|
||||
|
||||
|
||||
class ToolMessageUpdateError(Exception):
|
||||
"""Raised when updating a tool message in the database fails."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
async def _update_tool_message(
|
||||
session_id: str,
|
||||
tool_call_id: str,
|
||||
content: str,
|
||||
prisma_client: Prisma | None,
|
||||
) -> None:
|
||||
"""Update tool message in database.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
tool_call_id: The tool call ID to update
|
||||
content: The new content for the message
|
||||
prisma_client: Optional Prisma client. If None, uses chat_service.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The caller should
|
||||
handle this to avoid marking the task as completed with inconsistent state.
|
||||
"""
|
||||
try:
|
||||
if prisma_client:
|
||||
# Use provided Prisma client (for consumer with its own connection)
|
||||
updated_count = await prisma_client.chatmessage.update_many(
|
||||
where={
|
||||
"sessionId": session_id,
|
||||
"toolCallId": tool_call_id,
|
||||
},
|
||||
data={"content": content},
|
||||
)
|
||||
# Check if any rows were updated - 0 means message not found
|
||||
if updated_count == 0:
|
||||
raise ToolMessageUpdateError(
|
||||
f"No message found with tool_call_id={tool_call_id} in session {session_id}"
|
||||
)
|
||||
else:
|
||||
# Use service function (for webhook endpoint)
|
||||
await chat_service._update_pending_operation(
|
||||
session_id=session_id,
|
||||
tool_call_id=tool_call_id,
|
||||
result=content,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to update tool message: {e}", exc_info=True)
|
||||
raise ToolMessageUpdateError(
|
||||
f"Failed to update tool message for tool_call_id={tool_call_id}: {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def serialize_result(result: dict | list | str | int | float | bool | None) -> str:
|
||||
"""Serialize result to JSON string with sensible defaults.
|
||||
|
||||
Args:
|
||||
result: The result to serialize. Can be a dict, list, string,
|
||||
number, boolean, or None.
|
||||
|
||||
Returns:
|
||||
JSON string representation of the result. Returns '{"status": "completed"}'
|
||||
only when result is explicitly None.
|
||||
"""
|
||||
if isinstance(result, str):
|
||||
return result
|
||||
if result is None:
|
||||
return '{"status": "completed"}'
|
||||
return orjson.dumps(result).decode("utf-8")
|
||||
|
||||
|
||||
async def _save_agent_from_result(
|
||||
result: dict[str, Any],
|
||||
user_id: str | None,
|
||||
tool_name: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Save agent to library if result contains agent_json.
|
||||
|
||||
Args:
|
||||
result: The result dict that may contain agent_json
|
||||
user_id: The user ID to save the agent for
|
||||
tool_name: The tool name (create_agent or edit_agent)
|
||||
|
||||
Returns:
|
||||
Updated result dict with saved agent details, or original result if no agent_json
|
||||
"""
|
||||
if not user_id:
|
||||
logger.warning("[COMPLETION] Cannot save agent: no user_id in task")
|
||||
return result
|
||||
|
||||
agent_json = result.get("agent_json")
|
||||
if not agent_json:
|
||||
logger.warning(
|
||||
f"[COMPLETION] {tool_name} completed but no agent_json in result"
|
||||
)
|
||||
return result
|
||||
|
||||
try:
|
||||
from .tools.agent_generator import save_agent_to_library
|
||||
|
||||
is_update = tool_name == "edit_agent"
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
agent_json, user_id, is_update=is_update
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Saved agent '{created_graph.name}' to library "
|
||||
f"(graph_id={created_graph.id}, library_agent_id={library_agent.id})"
|
||||
)
|
||||
|
||||
# Return a response similar to AgentSavedResponse
|
||||
return {
|
||||
"type": "agent_saved",
|
||||
"message": f"Agent '{created_graph.name}' has been saved to your library!",
|
||||
"agent_id": created_graph.id,
|
||||
"agent_name": created_graph.name,
|
||||
"library_agent_id": library_agent.id,
|
||||
"library_agent_link": f"/library/agents/{library_agent.id}",
|
||||
"agent_page_link": f"/build?flowID={created_graph.id}",
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to save agent to library: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
# Return error but don't fail the whole operation
|
||||
# Sanitize agent_json to remove sensitive keys before returning
|
||||
return {
|
||||
"type": "error",
|
||||
"message": f"Agent was generated but failed to save: {str(e)}",
|
||||
"error": str(e),
|
||||
"agent_json": _sanitize_agent_json(agent_json),
|
||||
}
|
||||
|
||||
|
||||
async def process_operation_success(
|
||||
task: stream_registry.ActiveTask,
|
||||
result: dict | str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle successful operation completion.
|
||||
|
||||
Publishes the result to the stream registry, updates the database,
|
||||
generates LLM continuation, and marks the task as completed.
|
||||
|
||||
Args:
|
||||
task: The active task that completed
|
||||
result: The result data from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
|
||||
Raises:
|
||||
ToolMessageUpdateError: If the database update fails. The task will be
|
||||
marked as failed instead of completed to avoid inconsistent state.
|
||||
"""
|
||||
# For agent generation tools, save the agent to library
|
||||
if task.tool_name in AGENT_GENERATION_TOOLS and isinstance(result, dict):
|
||||
result = await _save_agent_from_result(result, task.user_id, task.tool_name)
|
||||
|
||||
# Serialize result for output (only substitute default when result is exactly None)
|
||||
result_output = result if result is not None else {"status": "completed"}
|
||||
output_str = (
|
||||
result_output
|
||||
if isinstance(result_output, str)
|
||||
else orjson.dumps(result_output).decode("utf-8")
|
||||
)
|
||||
|
||||
# Publish result to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamToolOutputAvailable(
|
||||
toolCallId=task.tool_call_id,
|
||||
toolName=task.tool_name,
|
||||
output=output_str,
|
||||
success=True,
|
||||
),
|
||||
)
|
||||
|
||||
# Update pending operation in database
|
||||
# If this fails, we must not continue to mark the task as completed
|
||||
result_str = serialize_result(result)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=result_str,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - mark task as failed to avoid inconsistent state
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed for task {task.task_id}, "
|
||||
"marking as failed instead of completed"
|
||||
)
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText="Failed to save operation result to database"),
|
||||
)
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
raise
|
||||
|
||||
# Generate LLM continuation with streaming
|
||||
try:
|
||||
await chat_service._generate_llm_continuation_with_streaming(
|
||||
session_id=task.session_id,
|
||||
user_id=task.user_id,
|
||||
task_id=task.task_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[COMPLETION] Failed to generate LLM continuation: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# Mark task as completed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="completed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(
|
||||
f"[COMPLETION] Successfully processed completion for task {task.task_id}"
|
||||
)
|
||||
|
||||
|
||||
async def process_operation_failure(
|
||||
task: stream_registry.ActiveTask,
|
||||
error: str | None,
|
||||
prisma_client: Prisma | None = None,
|
||||
) -> None:
|
||||
"""Handle failed operation completion.
|
||||
|
||||
Publishes the error to the stream registry, updates the database with
|
||||
the error response, and marks the task as failed.
|
||||
|
||||
Args:
|
||||
task: The active task that failed
|
||||
error: The error message from the operation
|
||||
prisma_client: Optional Prisma client for database operations.
|
||||
If None, uses chat_service._update_pending_operation instead.
|
||||
"""
|
||||
error_msg = error or "Operation failed"
|
||||
|
||||
# Publish error to stream registry
|
||||
await stream_registry.publish_chunk(
|
||||
task.task_id,
|
||||
StreamError(errorText=error_msg),
|
||||
)
|
||||
|
||||
# Update pending operation with error
|
||||
# If this fails, we still continue to mark the task as failed
|
||||
error_response = ErrorResponse(
|
||||
message=error_msg,
|
||||
error=error,
|
||||
)
|
||||
try:
|
||||
await _update_tool_message(
|
||||
session_id=task.session_id,
|
||||
tool_call_id=task.tool_call_id,
|
||||
content=error_response.model_dump_json(),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
except ToolMessageUpdateError:
|
||||
# DB update failed - log but continue with cleanup
|
||||
logger.error(
|
||||
f"[COMPLETION] DB update failed while processing failure for task {task.task_id}, "
|
||||
"continuing with cleanup"
|
||||
)
|
||||
|
||||
# Mark task as failed and release Redis lock
|
||||
await stream_registry.mark_task_completed(task.task_id, status="failed")
|
||||
try:
|
||||
await chat_service._mark_operation_completed(task.tool_call_id)
|
||||
except Exception as e:
|
||||
logger.error(f"[COMPLETION] Failed to mark operation completed: {e}")
|
||||
|
||||
logger.info(f"[COMPLETION] Processed failure for task {task.task_id}: {error_msg}")
|
||||
@@ -44,48 +44,6 @@ class ChatConfig(BaseSettings):
|
||||
description="TTL in seconds for long-running operation tracking in Redis (safety net if pod dies)",
|
||||
)
|
||||
|
||||
# Stream registry configuration for SSE reconnection
|
||||
stream_ttl: int = Field(
|
||||
default=3600,
|
||||
description="TTL in seconds for stream data in Redis (1 hour)",
|
||||
)
|
||||
stream_max_length: int = Field(
|
||||
default=10000,
|
||||
description="Maximum number of messages to store per stream",
|
||||
)
|
||||
|
||||
# Redis Streams configuration for completion consumer
|
||||
stream_completion_name: str = Field(
|
||||
default="chat:completions",
|
||||
description="Redis Stream name for operation completions",
|
||||
)
|
||||
stream_consumer_group: str = Field(
|
||||
default="chat_consumers",
|
||||
description="Consumer group name for completion stream",
|
||||
)
|
||||
stream_claim_min_idle_ms: int = Field(
|
||||
default=60000,
|
||||
description="Minimum idle time in milliseconds before claiming pending messages from dead consumers",
|
||||
)
|
||||
|
||||
# Redis key prefixes for stream registry
|
||||
task_meta_prefix: str = Field(
|
||||
default="chat:task:meta:",
|
||||
description="Prefix for task metadata hash keys",
|
||||
)
|
||||
task_stream_prefix: str = Field(
|
||||
default="chat:stream:",
|
||||
description="Prefix for task message stream keys",
|
||||
)
|
||||
task_op_prefix: str = Field(
|
||||
default="chat:task:op:",
|
||||
description="Prefix for operation ID to task ID mapping keys",
|
||||
)
|
||||
internal_api_key: str | None = Field(
|
||||
default=None,
|
||||
description="API key for internal webhook callbacks (env: CHAT_INTERNAL_API_KEY)",
|
||||
)
|
||||
|
||||
# Langfuse Prompt Management Configuration
|
||||
# Note: Langfuse credentials are in Settings().secrets (settings.py)
|
||||
langfuse_prompt_name: str = Field(
|
||||
@@ -124,14 +82,6 @@ class ChatConfig(BaseSettings):
|
||||
v = "https://openrouter.ai/api/v1"
|
||||
return v
|
||||
|
||||
@field_validator("internal_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_internal_api_key(cls, v):
|
||||
"""Get internal API key from environment if not provided."""
|
||||
if v is None:
|
||||
v = os.getenv("CHAT_INTERNAL_API_KEY")
|
||||
return v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
|
||||
@@ -52,10 +52,6 @@ class StreamStart(StreamBaseResponse):
|
||||
|
||||
type: ResponseType = ResponseType.START
|
||||
messageId: str = Field(..., description="Unique message ID")
|
||||
taskId: str | None = Field(
|
||||
default=None,
|
||||
description="Task ID for SSE reconnection. Clients can reconnect using GET /tasks/{taskId}/stream",
|
||||
)
|
||||
|
||||
|
||||
class StreamFinish(StreamBaseResponse):
|
||||
|
||||
@@ -1,23 +1,19 @@
|
||||
"""Chat API routes for chat session management and streaming via SSE."""
|
||||
|
||||
import logging
|
||||
import uuid as uuid_module
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated
|
||||
|
||||
from autogpt_libs import auth
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query, Security
|
||||
from fastapi import APIRouter, Depends, Query, Security
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
from . import service as chat_service
|
||||
from . import stream_registry
|
||||
from .completion_handler import process_operation_failure, process_operation_success
|
||||
from .config import ChatConfig
|
||||
from .model import ChatSession, create_chat_session, get_chat_session, get_user_sessions
|
||||
from .response_model import StreamFinish, StreamHeartbeat, StreamStart
|
||||
|
||||
config = ChatConfig()
|
||||
|
||||
@@ -59,15 +55,6 @@ class CreateSessionResponse(BaseModel):
|
||||
user_id: str | None
|
||||
|
||||
|
||||
class ActiveStreamInfo(BaseModel):
|
||||
"""Information about an active stream for reconnection."""
|
||||
|
||||
task_id: str
|
||||
last_message_id: str # Redis Stream message ID for resumption
|
||||
operation_id: str # Operation ID for completion tracking
|
||||
tool_name: str # Name of the tool being executed
|
||||
|
||||
|
||||
class SessionDetailResponse(BaseModel):
|
||||
"""Response model providing complete details for a chat session, including messages."""
|
||||
|
||||
@@ -76,7 +63,6 @@ class SessionDetailResponse(BaseModel):
|
||||
updated_at: str
|
||||
user_id: str | None
|
||||
messages: list[dict]
|
||||
active_stream: ActiveStreamInfo | None = None # Present if stream is still active
|
||||
|
||||
|
||||
class SessionSummaryResponse(BaseModel):
|
||||
@@ -95,14 +81,6 @@ class ListSessionsResponse(BaseModel):
|
||||
total: int
|
||||
|
||||
|
||||
class OperationCompleteRequest(BaseModel):
|
||||
"""Request model for external completion webhook."""
|
||||
|
||||
success: bool
|
||||
result: dict | str | None = None
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ========== Routes ==========
|
||||
|
||||
|
||||
@@ -188,14 +166,13 @@ async def get_session(
|
||||
Retrieve the details of a specific chat session.
|
||||
|
||||
Looks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.
|
||||
If there's an active stream for this session, returns the task_id for reconnection.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier for the desired chat session.
|
||||
user_id: The optional authenticated user ID, or None for anonymous access.
|
||||
|
||||
Returns:
|
||||
SessionDetailResponse: Details for the requested session, including active_stream info if applicable.
|
||||
SessionDetailResponse: Details for the requested session, or None if not found.
|
||||
|
||||
"""
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
@@ -203,28 +180,11 @@ async def get_session(
|
||||
raise NotFoundError(f"Session {session_id} not found.")
|
||||
|
||||
messages = [message.model_dump() for message in session.messages]
|
||||
|
||||
# Check if there's an active stream for this session
|
||||
active_stream_info = None
|
||||
active_task, last_message_id = await stream_registry.get_active_task_for_session(
|
||||
session_id, user_id
|
||||
logger.info(
|
||||
f"Returning session {session_id}: "
|
||||
f"message_count={len(messages)}, "
|
||||
f"roles={[m.get('role') for m in messages]}"
|
||||
)
|
||||
if active_task:
|
||||
# Filter out the in-progress assistant message from the session response.
|
||||
# The client will receive the complete assistant response through the SSE
|
||||
# stream replay instead, preventing duplicate content.
|
||||
if messages and messages[-1].get("role") == "assistant":
|
||||
messages = messages[:-1]
|
||||
|
||||
# Use "0-0" as last_message_id to replay the stream from the beginning.
|
||||
# Since we filtered out the cached assistant message, the client needs
|
||||
# the full stream to reconstruct the response.
|
||||
active_stream_info = ActiveStreamInfo(
|
||||
task_id=active_task.task_id,
|
||||
last_message_id="0-0",
|
||||
operation_id=active_task.operation_id,
|
||||
tool_name=active_task.tool_name,
|
||||
)
|
||||
|
||||
return SessionDetailResponse(
|
||||
id=session.session_id,
|
||||
@@ -232,7 +192,6 @@ async def get_session(
|
||||
updated_at=session.updated_at.isoformat(),
|
||||
user_id=session.user_id or None,
|
||||
messages=messages,
|
||||
active_stream=active_stream_info,
|
||||
)
|
||||
|
||||
|
||||
@@ -252,112 +211,49 @@ async def stream_chat_post(
|
||||
- Tool call UI elements (if invoked)
|
||||
- Tool execution results
|
||||
|
||||
The AI generation runs in a background task that continues even if the client disconnects.
|
||||
All chunks are written to Redis for reconnection support. If the client disconnects,
|
||||
they can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.
|
||||
|
||||
Args:
|
||||
session_id: The chat session identifier to associate with the streamed messages.
|
||||
request: Request body containing message, is_user_message, and optional context.
|
||||
user_id: Optional authenticated user ID.
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks. First chunk is a "start" event
|
||||
containing the task_id for reconnection.
|
||||
StreamingResponse: SSE-formatted response chunks.
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
session = await _validate_and_get_session(session_id, user_id)
|
||||
|
||||
# Create a task in the stream registry for reconnection support
|
||||
task_id = str(uuid_module.uuid4())
|
||||
operation_id = str(uuid_module.uuid4())
|
||||
await stream_registry.create_task(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id="chat_stream", # Not a tool call, but needed for the model
|
||||
tool_name="chat",
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Background task that runs the AI generation independently of SSE connection
|
||||
async def run_ai_generation():
|
||||
try:
|
||||
# Emit a start event with task_id for reconnection
|
||||
start_chunk = StreamStart(messageId=task_id, taskId=task_id)
|
||||
await stream_registry.publish_chunk(task_id, start_chunk)
|
||||
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
# Write to Redis (subscribers will receive via XREAD)
|
||||
await stream_registry.publish_chunk(task_id, chunk)
|
||||
|
||||
# Mark task as completed
|
||||
await stream_registry.mark_task_completed(task_id, "completed")
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error in background AI generation for session {session_id}: {e}"
|
||||
)
|
||||
await stream_registry.mark_task_completed(task_id, "failed")
|
||||
|
||||
# Start the AI generation in a background task
|
||||
bg_task = asyncio.create_task(run_ai_generation())
|
||||
await stream_registry.set_task_asyncio_task(task_id, bg_task)
|
||||
|
||||
# SSE endpoint that subscribes to the task's stream
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
subscriber_queue = None
|
||||
try:
|
||||
# Subscribe to the task stream (this replays existing messages + live updates)
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id="0-0", # Get all messages from the beginning
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
yield StreamFinish().to_sse()
|
||||
yield "data: [DONE]\n\n"
|
||||
return
|
||||
|
||||
# Read from the subscriber queue and yield to SSE
|
||||
while True:
|
||||
try:
|
||||
chunk = await asyncio.wait_for(subscriber_queue.get(), timeout=30.0)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
|
||||
except GeneratorExit:
|
||||
pass # Client disconnected - background task continues
|
||||
except Exception as e:
|
||||
logger.error(f"Error in SSE stream for task {task_id}: {e}")
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends to prevent resource leak
|
||||
if subscriber_queue is not None:
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(
|
||||
task_id, subscriber_queue
|
||||
)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
chunk_count = 0
|
||||
first_chunk_type: str | None = None
|
||||
async for chunk in chat_service.stream_chat_completion(
|
||||
session_id,
|
||||
request.message,
|
||||
is_user_message=request.is_user_message,
|
||||
user_id=user_id,
|
||||
session=session, # Pass pre-fetched session to avoid double-fetch
|
||||
context=request.context,
|
||||
):
|
||||
if chunk_count < 3:
|
||||
logger.info(
|
||||
"Chat stream chunk",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_type": str(chunk.type),
|
||||
},
|
||||
)
|
||||
if not first_chunk_type:
|
||||
first_chunk_type = str(chunk.type)
|
||||
chunk_count += 1
|
||||
yield chunk.to_sse()
|
||||
logger.info(
|
||||
"Chat stream completed",
|
||||
extra={
|
||||
"session_id": session_id,
|
||||
"chunk_count": chunk_count,
|
||||
"first_chunk_type": first_chunk_type,
|
||||
},
|
||||
)
|
||||
# AI SDK protocol termination
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
@@ -470,251 +366,6 @@ async def session_assign_user(
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
# ========== Task Streaming (SSE Reconnection) ==========
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}/stream",
|
||||
)
|
||||
async def stream_task(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
last_message_id: str = Query(
|
||||
default="0-0",
|
||||
description="Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Reconnect to a long-running task's SSE stream.
|
||||
|
||||
When a long-running operation (like agent generation) starts, the client
|
||||
receives a task_id. If the connection drops, the client can reconnect
|
||||
using this endpoint to resume receiving updates.
|
||||
|
||||
Args:
|
||||
task_id: The task ID from the operation_started response.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay).
|
||||
|
||||
Returns:
|
||||
StreamingResponse: SSE-formatted response chunks starting after last_message_id.
|
||||
|
||||
Raises:
|
||||
HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.
|
||||
"""
|
||||
# Check task existence and expiry before subscribing
|
||||
task, error_code = await stream_registry.get_task_with_expiry_info(task_id)
|
||||
|
||||
if error_code == "TASK_EXPIRED":
|
||||
raise HTTPException(
|
||||
status_code=410,
|
||||
detail={
|
||||
"code": "TASK_EXPIRED",
|
||||
"message": "This operation has expired. Please try again.",
|
||||
},
|
||||
)
|
||||
|
||||
if error_code == "TASK_NOT_FOUND":
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found.",
|
||||
},
|
||||
)
|
||||
|
||||
# Validate ownership if task has an owner
|
||||
if task and task.user_id and user_id != task.user_id:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"code": "ACCESS_DENIED",
|
||||
"message": "You do not have access to this task.",
|
||||
},
|
||||
)
|
||||
|
||||
# Get subscriber queue from stream registry
|
||||
subscriber_queue = await stream_registry.subscribe_to_task(
|
||||
task_id=task_id,
|
||||
user_id=user_id,
|
||||
last_message_id=last_message_id,
|
||||
)
|
||||
|
||||
if subscriber_queue is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "TASK_NOT_FOUND",
|
||||
"message": f"Task {task_id} not found or access denied.",
|
||||
},
|
||||
)
|
||||
|
||||
async def event_generator() -> AsyncGenerator[str, None]:
|
||||
import asyncio
|
||||
|
||||
heartbeat_interval = 15.0 # Send heartbeat every 15 seconds
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for next chunk with timeout for heartbeats
|
||||
chunk = await asyncio.wait_for(
|
||||
subscriber_queue.get(), timeout=heartbeat_interval
|
||||
)
|
||||
yield chunk.to_sse()
|
||||
|
||||
# Check for finish signal
|
||||
if isinstance(chunk, StreamFinish):
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
# Send heartbeat to keep connection alive
|
||||
yield StreamHeartbeat().to_sse()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in task stream {task_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# Unsubscribe when client disconnects or stream ends
|
||||
try:
|
||||
await stream_registry.unsubscribe_from_task(task_id, subscriber_queue)
|
||||
except Exception as unsub_err:
|
||||
logger.error(
|
||||
f"Error unsubscribing from task {task_id}: {unsub_err}",
|
||||
exc_info=True,
|
||||
)
|
||||
# AI SDK protocol termination - always yield even if unsubscribe fails
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
event_generator(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
"x-vercel-ai-ui-message-stream": "v1",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/tasks/{task_id}",
|
||||
)
|
||||
async def get_task_status(
|
||||
task_id: str,
|
||||
user_id: str | None = Depends(auth.get_user_id),
|
||||
) -> dict:
|
||||
"""
|
||||
Get the status of a long-running task.
|
||||
|
||||
Args:
|
||||
task_id: The task ID to check.
|
||||
user_id: Authenticated user ID for ownership validation.
|
||||
|
||||
Returns:
|
||||
dict: Task status including task_id, status, tool_name, and operation_id.
|
||||
|
||||
Raises:
|
||||
NotFoundError: If task_id is not found or user doesn't have access.
|
||||
"""
|
||||
task = await stream_registry.get_task(task_id)
|
||||
|
||||
if task is None:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task.user_id and user_id != task.user_id:
|
||||
raise NotFoundError(f"Task {task_id} not found.")
|
||||
|
||||
return {
|
||||
"task_id": task.task_id,
|
||||
"session_id": task.session_id,
|
||||
"status": task.status,
|
||||
"tool_name": task.tool_name,
|
||||
"operation_id": task.operation_id,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
# ========== External Completion Webhook ==========
|
||||
|
||||
|
||||
@router.post(
|
||||
"/operations/{operation_id}/complete",
|
||||
status_code=200,
|
||||
)
|
||||
async def complete_operation(
|
||||
operation_id: str,
|
||||
request: OperationCompleteRequest,
|
||||
x_api_key: str | None = Header(default=None),
|
||||
) -> dict:
|
||||
"""
|
||||
External completion webhook for long-running operations.
|
||||
|
||||
Called by Agent Generator (or other services) when an operation completes.
|
||||
This triggers the stream registry to publish completion and continue LLM generation.
|
||||
|
||||
Args:
|
||||
operation_id: The operation ID to complete.
|
||||
request: Completion payload with success status and result/error.
|
||||
x_api_key: Internal API key for authentication.
|
||||
|
||||
Returns:
|
||||
dict: Status of the completion.
|
||||
|
||||
Raises:
|
||||
HTTPException: If API key is invalid or operation not found.
|
||||
"""
|
||||
# Validate internal API key - reject if not configured or invalid
|
||||
if not config.internal_api_key:
|
||||
logger.error(
|
||||
"Operation complete webhook rejected: CHAT_INTERNAL_API_KEY not configured"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Webhook not available: internal API key not configured",
|
||||
)
|
||||
if x_api_key != config.internal_api_key:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
# Find task by operation_id
|
||||
task = await stream_registry.find_task_by_operation_id(operation_id)
|
||||
if task is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Operation {operation_id} not found",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Received completion webhook for operation {operation_id} "
|
||||
f"(task_id={task.task_id}, success={request.success})"
|
||||
)
|
||||
|
||||
if request.success:
|
||||
await process_operation_success(task, request.result)
|
||||
else:
|
||||
await process_operation_failure(task, request.error)
|
||||
|
||||
return {"status": "ok", "task_id": task.task_id}
|
||||
|
||||
|
||||
# ========== Configuration ==========
|
||||
|
||||
|
||||
@router.get("/config/ttl", status_code=200)
|
||||
async def get_ttl_config() -> dict:
|
||||
"""
|
||||
Get the stream TTL configuration.
|
||||
|
||||
Returns the Time-To-Live settings for chat streams, which determines
|
||||
how long clients can reconnect to an active stream.
|
||||
|
||||
Returns:
|
||||
dict: TTL configuration with seconds and milliseconds values.
|
||||
"""
|
||||
return {
|
||||
"stream_ttl_seconds": config.stream_ttl,
|
||||
"stream_ttl_ms": config.stream_ttl * 1000,
|
||||
}
|
||||
|
||||
|
||||
# ========== Health Check ==========
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,704 +0,0 @@
|
||||
"""Stream registry for managing reconnectable SSE streams.
|
||||
|
||||
This module provides a registry for tracking active streaming tasks and their
|
||||
messages. It uses Redis for all state management (no in-memory state), making
|
||||
pods stateless and horizontally scalable.
|
||||
|
||||
Architecture:
|
||||
- Redis Stream: Persists all messages for replay and real-time delivery
|
||||
- Redis Hash: Task metadata (status, session_id, etc.)
|
||||
|
||||
Subscribers:
|
||||
1. Replay missed messages from Redis Stream (XREAD)
|
||||
2. Listen for live updates via blocking XREAD
|
||||
3. No in-memory state required on the subscribing pod
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Literal
|
||||
|
||||
import orjson
|
||||
|
||||
from backend.data.redis_client import get_redis_async
|
||||
|
||||
from .config import ChatConfig
|
||||
from .response_model import StreamBaseResponse, StreamError, StreamFinish
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
# Track background tasks for this pod (just the asyncio.Task reference, not subscribers)
|
||||
_local_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
# Track listener tasks per subscriber queue for cleanup
|
||||
# Maps queue id() to (task_id, asyncio.Task) for proper cleanup on unsubscribe
|
||||
_listener_tasks: dict[int, tuple[str, asyncio.Task]] = {}
|
||||
|
||||
# Timeout for putting chunks into subscriber queues (seconds)
|
||||
# If the queue is full and doesn't drain within this time, send an overflow error
|
||||
QUEUE_PUT_TIMEOUT = 5.0
|
||||
|
||||
# Lua script for atomic compare-and-swap status update (idempotent completion)
|
||||
# Returns 1 if status was updated, 0 if already completed/failed
|
||||
COMPLETE_TASK_SCRIPT = """
|
||||
local current = redis.call("HGET", KEYS[1], "status")
|
||||
if current == "running" then
|
||||
redis.call("HSET", KEYS[1], "status", ARGV[1])
|
||||
return 1
|
||||
end
|
||||
return 0
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ActiveTask:
|
||||
"""Represents an active streaming task (metadata only, no in-memory queues)."""
|
||||
|
||||
task_id: str
|
||||
session_id: str
|
||||
user_id: str | None
|
||||
tool_call_id: str
|
||||
tool_name: str
|
||||
operation_id: str
|
||||
status: Literal["running", "completed", "failed"] = "running"
|
||||
created_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
|
||||
asyncio_task: asyncio.Task | None = None
|
||||
|
||||
|
||||
def _get_task_meta_key(task_id: str) -> str:
|
||||
"""Get Redis key for task metadata."""
|
||||
return f"{config.task_meta_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_task_stream_key(task_id: str) -> str:
|
||||
"""Get Redis key for task message stream."""
|
||||
return f"{config.task_stream_prefix}{task_id}"
|
||||
|
||||
|
||||
def _get_operation_mapping_key(operation_id: str) -> str:
|
||||
"""Get Redis key for operation_id to task_id mapping."""
|
||||
return f"{config.task_op_prefix}{operation_id}"
|
||||
|
||||
|
||||
async def create_task(
|
||||
task_id: str,
|
||||
session_id: str,
|
||||
user_id: str | None,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
operation_id: str,
|
||||
) -> ActiveTask:
|
||||
"""Create a new streaming task in Redis.
|
||||
|
||||
Args:
|
||||
task_id: Unique identifier for the task
|
||||
session_id: Chat session ID
|
||||
user_id: User ID (may be None for anonymous)
|
||||
tool_call_id: Tool call ID from the LLM
|
||||
tool_name: Name of the tool being executed
|
||||
operation_id: Operation ID for webhook callbacks
|
||||
|
||||
Returns:
|
||||
The created ActiveTask instance (metadata only)
|
||||
"""
|
||||
task = ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
operation_id=operation_id,
|
||||
)
|
||||
|
||||
# Store metadata in Redis
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
|
||||
await redis.hset( # type: ignore[misc]
|
||||
meta_key,
|
||||
mapping={
|
||||
"task_id": task_id,
|
||||
"session_id": session_id,
|
||||
"user_id": user_id or "",
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"operation_id": operation_id,
|
||||
"status": task.status,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
},
|
||||
)
|
||||
await redis.expire(meta_key, config.stream_ttl)
|
||||
|
||||
# Create operation_id -> task_id mapping for webhook lookups
|
||||
await redis.set(op_key, task_id, ex=config.stream_ttl)
|
||||
|
||||
logger.debug(f"Created task {task_id} for session {session_id}")
|
||||
|
||||
return task
|
||||
|
||||
|
||||
async def publish_chunk(
|
||||
task_id: str,
|
||||
chunk: StreamBaseResponse,
|
||||
) -> str:
|
||||
"""Publish a chunk to Redis Stream.
|
||||
|
||||
All delivery is via Redis Streams - no in-memory state.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to publish to
|
||||
chunk: The stream response chunk to publish
|
||||
|
||||
Returns:
|
||||
The Redis Stream message ID
|
||||
"""
|
||||
chunk_json = chunk.model_dump_json()
|
||||
message_id = "0-0"
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Write to Redis Stream for persistence and real-time delivery
|
||||
raw_id = await redis.xadd(
|
||||
stream_key,
|
||||
{"data": chunk_json},
|
||||
maxlen=config.stream_max_length,
|
||||
)
|
||||
message_id = raw_id if isinstance(raw_id, str) else raw_id.decode()
|
||||
|
||||
# Set TTL on stream to match task metadata TTL
|
||||
await redis.expire(stream_key, config.stream_ttl)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish chunk for task {task_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return message_id
|
||||
|
||||
|
||||
async def subscribe_to_task(
|
||||
task_id: str,
|
||||
user_id: str | None,
|
||||
last_message_id: str = "0-0",
|
||||
) -> asyncio.Queue[StreamBaseResponse] | None:
|
||||
"""Subscribe to a task's stream with replay of missed messages.
|
||||
|
||||
This is fully stateless - uses Redis Stream for replay and pub/sub for live updates.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to subscribe to
|
||||
user_id: User ID for ownership validation
|
||||
last_message_id: Last Redis Stream message ID received ("0-0" for full replay)
|
||||
|
||||
Returns:
|
||||
An asyncio Queue that will receive stream chunks, or None if task not found
|
||||
or user doesn't have access
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
logger.debug(f"Task {task_id} not found in Redis")
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task_user_id:
|
||||
if user_id != task_user_id:
|
||||
logger.warning(
|
||||
f"User {user_id} denied access to task {task_id} "
|
||||
f"owned by {task_user_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
# Step 1: Replay messages from Redis Stream
|
||||
messages = await redis.xread({stream_key: last_message_id}, block=0, count=1000)
|
||||
|
||||
replayed_count = 0
|
||||
replay_last_id = last_message_id
|
||||
if messages:
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
replay_last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
# Note: Redis client uses decode_responses=True, so keys are strings
|
||||
if "data" in msg_data:
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
await subscriber_queue.put(chunk)
|
||||
replayed_count += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to replay message: {e}")
|
||||
|
||||
logger.debug(f"Task {task_id}: replayed {replayed_count} messages")
|
||||
|
||||
# Step 2: If task is still running, start stream listener for live updates
|
||||
if task_status == "running":
|
||||
listener_task = asyncio.create_task(
|
||||
_stream_listener(task_id, subscriber_queue, replay_last_id)
|
||||
)
|
||||
# Track listener task for cleanup on unsubscribe
|
||||
_listener_tasks[id(subscriber_queue)] = (task_id, listener_task)
|
||||
else:
|
||||
# Task is completed/failed - add finish marker
|
||||
await subscriber_queue.put(StreamFinish())
|
||||
|
||||
return subscriber_queue
|
||||
|
||||
|
||||
async def _stream_listener(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
last_replayed_id: str,
|
||||
) -> None:
|
||||
"""Listen to Redis Stream for new messages using blocking XREAD.
|
||||
|
||||
This approach avoids the duplicate message issue that can occur with pub/sub
|
||||
when messages are published during the gap between replay and subscription.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to listen for
|
||||
subscriber_queue: Queue to deliver messages to
|
||||
last_replayed_id: Last message ID from replay (continue from here)
|
||||
"""
|
||||
queue_id = id(subscriber_queue)
|
||||
# Track the last successfully delivered message ID for recovery hints
|
||||
last_delivered_id = last_replayed_id
|
||||
|
||||
try:
|
||||
redis = await get_redis_async()
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
current_id = last_replayed_id
|
||||
|
||||
while True:
|
||||
# Block for up to 30 seconds waiting for new messages
|
||||
# This allows periodic checking if task is still running
|
||||
messages = await redis.xread(
|
||||
{stream_key: current_id}, block=30000, count=100
|
||||
)
|
||||
|
||||
if not messages:
|
||||
# Timeout - check if task is still running
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
status = await redis.hget(meta_key, "status") # type: ignore[misc]
|
||||
if status and status != "running":
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout delivering finish event for task {task_id}"
|
||||
)
|
||||
break
|
||||
continue
|
||||
|
||||
for _stream_name, stream_messages in messages:
|
||||
for msg_id, msg_data in stream_messages:
|
||||
current_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
|
||||
if "data" not in msg_data:
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_data = orjson.loads(msg_data["data"])
|
||||
chunk = _reconstruct_chunk(chunk_data)
|
||||
if chunk:
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(chunk),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
# Update last delivered ID on successful delivery
|
||||
last_delivered_id = current_id
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Subscriber queue full for task {task_id}, "
|
||||
f"message delivery timed out after {QUEUE_PUT_TIMEOUT}s"
|
||||
)
|
||||
# Send overflow error with recovery info
|
||||
try:
|
||||
overflow_error = StreamError(
|
||||
errorText="Message delivery timeout - some messages may have been missed",
|
||||
code="QUEUE_OVERFLOW",
|
||||
details={
|
||||
"last_delivered_id": last_delivered_id,
|
||||
"recovery_hint": f"Reconnect with last_message_id={last_delivered_id}",
|
||||
},
|
||||
)
|
||||
subscriber_queue.put_nowait(overflow_error)
|
||||
except asyncio.QueueFull:
|
||||
# Queue is completely stuck, nothing more we can do
|
||||
logger.error(
|
||||
f"Cannot deliver overflow error for task {task_id}, "
|
||||
"queue completely blocked"
|
||||
)
|
||||
|
||||
# Stop listening on finish
|
||||
if isinstance(chunk, StreamFinish):
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing stream message: {e}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(f"Stream listener cancelled for task {task_id}")
|
||||
raise # Re-raise to propagate cancellation
|
||||
except Exception as e:
|
||||
logger.error(f"Stream listener error for task {task_id}: {e}")
|
||||
# On error, send finish to unblock subscriber
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
subscriber_queue.put(StreamFinish()),
|
||||
timeout=QUEUE_PUT_TIMEOUT,
|
||||
)
|
||||
except (asyncio.TimeoutError, asyncio.QueueFull):
|
||||
logger.warning(
|
||||
f"Could not deliver finish event for task {task_id} after error"
|
||||
)
|
||||
finally:
|
||||
# Clean up listener task mapping on exit
|
||||
_listener_tasks.pop(queue_id, None)
|
||||
|
||||
|
||||
async def mark_task_completed(
|
||||
task_id: str,
|
||||
status: Literal["completed", "failed"] = "completed",
|
||||
) -> bool:
|
||||
"""Mark a task as completed and publish finish event.
|
||||
|
||||
This is idempotent - calling multiple times with the same task_id is safe.
|
||||
Uses atomic compare-and-swap via Lua script to prevent race conditions.
|
||||
Status is updated first (source of truth), then finish event is published (best-effort).
|
||||
|
||||
Args:
|
||||
task_id: Task ID to mark as completed
|
||||
status: Final status ("completed" or "failed")
|
||||
|
||||
Returns:
|
||||
True if task was newly marked completed, False if already completed/failed
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
|
||||
# Atomic compare-and-swap: only update if status is "running"
|
||||
# This prevents race conditions when multiple callers try to complete simultaneously
|
||||
result = await redis.eval(COMPLETE_TASK_SCRIPT, 1, meta_key, status) # type: ignore[misc]
|
||||
|
||||
if result == 0:
|
||||
logger.debug(f"Task {task_id} already completed/failed, skipping")
|
||||
return False
|
||||
|
||||
# THEN publish finish event (best-effort - listeners can detect via status polling)
|
||||
try:
|
||||
await publish_chunk(task_id, StreamFinish())
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to publish finish event for task {task_id}: {e}. "
|
||||
"Listeners will detect completion via status polling."
|
||||
)
|
||||
|
||||
# Clean up local task reference if exists
|
||||
_local_tasks.pop(task_id, None)
|
||||
return True
|
||||
|
||||
|
||||
async def find_task_by_operation_id(operation_id: str) -> ActiveTask | None:
|
||||
"""Find a task by its operation ID.
|
||||
|
||||
Used by webhook callbacks to locate the task to update.
|
||||
|
||||
Args:
|
||||
operation_id: Operation ID to search for
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
op_key = _get_operation_mapping_key(operation_id)
|
||||
task_id = await redis.get(op_key)
|
||||
|
||||
if not task_id:
|
||||
return None
|
||||
|
||||
task_id_str = task_id.decode() if isinstance(task_id, bytes) else task_id
|
||||
return await get_task(task_id_str)
|
||||
|
||||
|
||||
async def get_task(task_id: str) -> ActiveTask | None:
|
||||
"""Get a task by its ID from Redis.
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
ActiveTask if found, None otherwise
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
return None
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
return ActiveTask(
|
||||
task_id=meta.get("task_id", ""),
|
||||
session_id=meta.get("session_id", ""),
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
|
||||
async def get_task_with_expiry_info(
|
||||
task_id: str,
|
||||
) -> tuple[ActiveTask | None, str | None]:
|
||||
"""Get a task by its ID with expiration detection.
|
||||
|
||||
Returns (task, error_code) where error_code is:
|
||||
- None if task found
|
||||
- "TASK_EXPIRED" if stream exists but metadata is gone (TTL expired)
|
||||
- "TASK_NOT_FOUND" if neither exists
|
||||
|
||||
Args:
|
||||
task_id: Task ID to look up
|
||||
|
||||
Returns:
|
||||
Tuple of (ActiveTask or None, error_code or None)
|
||||
"""
|
||||
redis = await get_redis_async()
|
||||
meta_key = _get_task_meta_key(task_id)
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
|
||||
meta: dict[Any, Any] = await redis.hgetall(meta_key) # type: ignore[misc]
|
||||
|
||||
if not meta:
|
||||
# Check if stream still has data (metadata expired but stream hasn't)
|
||||
stream_len = await redis.xlen(stream_key)
|
||||
if stream_len > 0:
|
||||
return None, "TASK_EXPIRED"
|
||||
return None, "TASK_NOT_FOUND"
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
return (
|
||||
ActiveTask(
|
||||
task_id=meta.get("task_id", ""),
|
||||
session_id=meta.get("session_id", ""),
|
||||
user_id=meta.get("user_id", "") or None,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status=meta.get("status", "running"), # type: ignore[arg-type]
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
async def get_active_task_for_session(
|
||||
session_id: str,
|
||||
user_id: str | None = None,
|
||||
) -> tuple[ActiveTask | None, str]:
|
||||
"""Get the active (running) task for a session, if any.
|
||||
|
||||
Scans Redis for tasks matching the session_id with status="running".
|
||||
|
||||
Args:
|
||||
session_id: Session ID to look up
|
||||
user_id: User ID for ownership validation (optional)
|
||||
|
||||
Returns:
|
||||
Tuple of (ActiveTask if found and running, last_message_id from Redis Stream)
|
||||
"""
|
||||
|
||||
redis = await get_redis_async()
|
||||
|
||||
# Scan Redis for task metadata keys
|
||||
cursor = 0
|
||||
tasks_checked = 0
|
||||
|
||||
while True:
|
||||
cursor, keys = await redis.scan(
|
||||
cursor, match=f"{config.task_meta_prefix}*", count=100
|
||||
)
|
||||
|
||||
for key in keys:
|
||||
tasks_checked += 1
|
||||
meta: dict[Any, Any] = await redis.hgetall(key) # type: ignore[misc]
|
||||
if not meta:
|
||||
continue
|
||||
|
||||
# Note: Redis client uses decode_responses=True, so keys/values are strings
|
||||
task_session_id = meta.get("session_id", "")
|
||||
task_status = meta.get("status", "")
|
||||
task_user_id = meta.get("user_id", "") or None
|
||||
task_id = meta.get("task_id", "")
|
||||
|
||||
if task_session_id == session_id and task_status == "running":
|
||||
# Validate ownership - if task has an owner, requester must match
|
||||
if task_user_id and user_id != task_user_id:
|
||||
continue
|
||||
|
||||
# Get the last message ID from Redis Stream
|
||||
stream_key = _get_task_stream_key(task_id)
|
||||
last_id = "0-0"
|
||||
try:
|
||||
messages = await redis.xrevrange(stream_key, count=1)
|
||||
if messages:
|
||||
msg_id = messages[0][0]
|
||||
last_id = msg_id if isinstance(msg_id, str) else msg_id.decode()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to get last message ID: {e}")
|
||||
|
||||
return (
|
||||
ActiveTask(
|
||||
task_id=task_id,
|
||||
session_id=task_session_id,
|
||||
user_id=task_user_id,
|
||||
tool_call_id=meta.get("tool_call_id", ""),
|
||||
tool_name=meta.get("tool_name", ""),
|
||||
operation_id=meta.get("operation_id", ""),
|
||||
status="running",
|
||||
),
|
||||
last_id,
|
||||
)
|
||||
|
||||
if cursor == 0:
|
||||
break
|
||||
|
||||
return None, "0-0"
|
||||
|
||||
|
||||
def _reconstruct_chunk(chunk_data: dict) -> StreamBaseResponse | None:
|
||||
"""Reconstruct a StreamBaseResponse from JSON data.
|
||||
|
||||
Args:
|
||||
chunk_data: Parsed JSON data from Redis
|
||||
|
||||
Returns:
|
||||
Reconstructed response object, or None if unknown type
|
||||
"""
|
||||
from .response_model import (
|
||||
ResponseType,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolInputStart,
|
||||
StreamToolOutputAvailable,
|
||||
StreamUsage,
|
||||
)
|
||||
|
||||
# Map response types to their corresponding classes
|
||||
type_to_class: dict[str, type[StreamBaseResponse]] = {
|
||||
ResponseType.START.value: StreamStart,
|
||||
ResponseType.FINISH.value: StreamFinish,
|
||||
ResponseType.TEXT_START.value: StreamTextStart,
|
||||
ResponseType.TEXT_DELTA.value: StreamTextDelta,
|
||||
ResponseType.TEXT_END.value: StreamTextEnd,
|
||||
ResponseType.TOOL_INPUT_START.value: StreamToolInputStart,
|
||||
ResponseType.TOOL_INPUT_AVAILABLE.value: StreamToolInputAvailable,
|
||||
ResponseType.TOOL_OUTPUT_AVAILABLE.value: StreamToolOutputAvailable,
|
||||
ResponseType.ERROR.value: StreamError,
|
||||
ResponseType.USAGE.value: StreamUsage,
|
||||
ResponseType.HEARTBEAT.value: StreamHeartbeat,
|
||||
}
|
||||
|
||||
chunk_type = chunk_data.get("type")
|
||||
chunk_class = type_to_class.get(chunk_type) # type: ignore[arg-type]
|
||||
|
||||
if chunk_class is None:
|
||||
logger.warning(f"Unknown chunk type: {chunk_type}")
|
||||
return None
|
||||
|
||||
try:
|
||||
return chunk_class(**chunk_data)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to reconstruct chunk of type {chunk_type}: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def set_task_asyncio_task(task_id: str, asyncio_task: asyncio.Task) -> None:
|
||||
"""Track the asyncio.Task for a task (local reference only).
|
||||
|
||||
This is just for cleanup purposes - the task state is in Redis.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
asyncio_task: The asyncio Task to track
|
||||
"""
|
||||
_local_tasks[task_id] = asyncio_task
|
||||
|
||||
|
||||
async def unsubscribe_from_task(
|
||||
task_id: str,
|
||||
subscriber_queue: asyncio.Queue[StreamBaseResponse],
|
||||
) -> None:
|
||||
"""Clean up when a subscriber disconnects.
|
||||
|
||||
Cancels the XREAD-based listener task associated with this subscriber queue
|
||||
to prevent resource leaks.
|
||||
|
||||
Args:
|
||||
task_id: Task ID
|
||||
subscriber_queue: The subscriber's queue used to look up the listener task
|
||||
"""
|
||||
queue_id = id(subscriber_queue)
|
||||
listener_entry = _listener_tasks.pop(queue_id, None)
|
||||
|
||||
if listener_entry is None:
|
||||
logger.debug(
|
||||
f"No listener task found for task {task_id} queue {queue_id} "
|
||||
"(may have already completed)"
|
||||
)
|
||||
return
|
||||
|
||||
stored_task_id, listener_task = listener_entry
|
||||
|
||||
if stored_task_id != task_id:
|
||||
logger.warning(
|
||||
f"Task ID mismatch in unsubscribe: expected {task_id}, "
|
||||
f"found {stored_task_id}"
|
||||
)
|
||||
|
||||
if listener_task.done():
|
||||
logger.debug(f"Listener task for task {task_id} already completed")
|
||||
return
|
||||
|
||||
# Cancel the listener task
|
||||
listener_task.cancel()
|
||||
|
||||
try:
|
||||
# Wait for the task to be cancelled with a timeout
|
||||
await asyncio.wait_for(listener_task, timeout=5.0)
|
||||
except asyncio.CancelledError:
|
||||
# Expected - the task was successfully cancelled
|
||||
pass
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
f"Timeout waiting for listener task cancellation for task {task_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error during listener task cancellation for task {task_id}: {e}")
|
||||
|
||||
logger.debug(f"Successfully unsubscribed from task {task_id}")
|
||||
@@ -10,7 +10,6 @@ from .add_understanding import AddUnderstandingTool
|
||||
from .agent_output import AgentOutputTool
|
||||
from .base import BaseTool
|
||||
from .create_agent import CreateAgentTool
|
||||
from .customize_agent import CustomizeAgentTool
|
||||
from .edit_agent import EditAgentTool
|
||||
from .find_agent import FindAgentTool
|
||||
from .find_block import FindBlockTool
|
||||
@@ -35,7 +34,6 @@ logger = logging.getLogger(__name__)
|
||||
TOOL_REGISTRY: dict[str, BaseTool] = {
|
||||
"add_understanding": AddUnderstandingTool(),
|
||||
"create_agent": CreateAgentTool(),
|
||||
"customize_agent": CustomizeAgentTool(),
|
||||
"edit_agent": EditAgentTool(),
|
||||
"find_agent": FindAgentTool(),
|
||||
"find_block": FindBlockTool(),
|
||||
|
||||
@@ -2,58 +2,30 @@
|
||||
|
||||
from .core import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
AgentJsonValidationError,
|
||||
AgentSummary,
|
||||
DecompositionResult,
|
||||
DecompositionStep,
|
||||
LibraryAgentSummary,
|
||||
MarketplaceAgentSummary,
|
||||
customize_template,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
extract_search_terms_from_steps,
|
||||
extract_uuids_from_text,
|
||||
generate_agent,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_library_agent_by_graph_id,
|
||||
get_library_agent_by_id,
|
||||
get_library_agents_for_generation,
|
||||
graph_to_json,
|
||||
json_to_graph,
|
||||
save_agent_to_library,
|
||||
search_marketplace_agents_for_generation,
|
||||
)
|
||||
from .errors import get_user_message_for_error
|
||||
from .service import health_check as check_external_service_health
|
||||
from .service import is_external_service_configured
|
||||
|
||||
__all__ = [
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
"AgentJsonValidationError",
|
||||
"AgentSummary",
|
||||
"DecompositionResult",
|
||||
"DecompositionStep",
|
||||
"LibraryAgentSummary",
|
||||
"MarketplaceAgentSummary",
|
||||
"check_external_service_health",
|
||||
"customize_template",
|
||||
# Core functions
|
||||
"decompose_goal",
|
||||
"enrich_library_agents_from_steps",
|
||||
"extract_search_terms_from_steps",
|
||||
"extract_uuids_from_text",
|
||||
"generate_agent",
|
||||
"generate_agent_patch",
|
||||
"get_agent_as_json",
|
||||
"get_all_relevant_agents_for_generation",
|
||||
"get_library_agent_by_graph_id",
|
||||
"get_library_agent_by_id",
|
||||
"get_library_agents_for_generation",
|
||||
"get_user_message_for_error",
|
||||
"graph_to_json",
|
||||
"is_external_service_configured",
|
||||
"json_to_graph",
|
||||
"save_agent_to_library",
|
||||
"search_marketplace_agents_for_generation",
|
||||
"get_agent_as_json",
|
||||
"json_to_graph",
|
||||
# Exceptions
|
||||
"AgentGeneratorNotConfiguredError",
|
||||
# Service
|
||||
"is_external_service_configured",
|
||||
"check_external_service_health",
|
||||
# Error handling
|
||||
"get_user_message_for_error",
|
||||
]
|
||||
|
||||
@@ -1,25 +1,13 @@
|
||||
"""Core agent generation functions."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any, NotRequired, TypedDict
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data.graph import (
|
||||
Graph,
|
||||
Link,
|
||||
Node,
|
||||
create_graph,
|
||||
get_graph,
|
||||
get_graph_all_versions,
|
||||
get_store_listed_graphs,
|
||||
)
|
||||
from backend.util.exceptions import DatabaseError, NotFoundError
|
||||
from backend.data.graph import Graph, Link, Node, create_graph
|
||||
|
||||
from .service import (
|
||||
customize_template_external,
|
||||
decompose_goal_external,
|
||||
generate_agent_external,
|
||||
generate_agent_patch_external,
|
||||
@@ -28,74 +16,6 @@ from .service import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"
|
||||
|
||||
|
||||
class ExecutionSummary(TypedDict):
|
||||
"""Summary of a single execution for quality assessment."""
|
||||
|
||||
status: str
|
||||
correctness_score: NotRequired[float]
|
||||
activity_summary: NotRequired[str]
|
||||
|
||||
|
||||
class LibraryAgentSummary(TypedDict):
|
||||
"""Summary of a library agent for sub-agent composition.
|
||||
|
||||
Includes recent executions to help the LLM decide whether to use this agent.
|
||||
Each execution shows status, correctness_score (0-1), and activity_summary.
|
||||
"""
|
||||
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
name: str
|
||||
description: str
|
||||
input_schema: dict[str, Any]
|
||||
output_schema: dict[str, Any]
|
||||
recent_executions: NotRequired[list[ExecutionSummary]]
|
||||
|
||||
|
||||
class MarketplaceAgentSummary(TypedDict):
|
||||
"""Summary of a marketplace agent for sub-agent composition."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
sub_heading: str
|
||||
creator: str
|
||||
is_marketplace_agent: bool
|
||||
|
||||
|
||||
class DecompositionStep(TypedDict, total=False):
|
||||
"""A single step in decomposed instructions."""
|
||||
|
||||
description: str
|
||||
action: str
|
||||
block_name: str
|
||||
tool: str
|
||||
name: str
|
||||
|
||||
|
||||
class DecompositionResult(TypedDict, total=False):
|
||||
"""Result from decompose_goal - can be instructions, questions, or error."""
|
||||
|
||||
type: str
|
||||
steps: list[DecompositionStep]
|
||||
questions: list[dict[str, Any]]
|
||||
error: str
|
||||
error_type: str
|
||||
|
||||
|
||||
AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]
|
||||
|
||||
|
||||
def _to_dict_list(
|
||||
agents: list[AgentSummary] | list[dict[str, Any]] | None,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Convert typed agent summaries to plain dicts for external service calls."""
|
||||
if agents is None:
|
||||
return None
|
||||
return [dict(a) for a in agents]
|
||||
|
||||
|
||||
class AgentGeneratorNotConfiguredError(Exception):
|
||||
"""Raised when the external Agent Generator service is not configured."""
|
||||
@@ -116,422 +36,15 @@ def _check_service_configured() -> None:
|
||||
)
|
||||
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_uuids_from_text(text: str) -> list[str]:
|
||||
"""Extract all UUID v4 strings from text.
|
||||
|
||||
Args:
|
||||
text: Text that may contain UUIDs (e.g., user's goal description)
|
||||
|
||||
Returns:
|
||||
List of unique UUIDs found in the text (lowercase)
|
||||
"""
|
||||
matches = _UUID_PATTERN.findall(text)
|
||||
return list({m.lower() for m in matches})
|
||||
|
||||
|
||||
async def get_library_agent_by_id(
|
||||
user_id: str, agent_id: str
|
||||
) -> LibraryAgentSummary | None:
|
||||
"""Fetch a specific library agent by its ID (library agent ID or graph_id).
|
||||
|
||||
This function tries multiple lookup strategies:
|
||||
1. First tries to find by graph_id (AgentGraph primary key)
|
||||
2. If not found, tries to find by library agent ID (LibraryAgent primary key)
|
||||
|
||||
This handles both cases:
|
||||
- User provides graph_id (e.g., from AgentExecutorBlock)
|
||||
- User provides library agent ID (e.g., from library URL)
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||
|
||||
Returns:
|
||||
LibraryAgentSummary if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return LibraryAgentSummary(
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
get_library_agent_by_graph_id = get_library_agent_by_id
|
||||
|
||||
|
||||
async def get_library_agents_for_generation(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
exclude_graph_id: str | None = None,
|
||||
max_results: int = 15,
|
||||
) -> list[LibraryAgentSummary]:
|
||||
"""Fetch user's library agents formatted for Agent Generator.
|
||||
|
||||
Uses search-based fetching to return relevant agents instead of all agents.
|
||||
This is more scalable for users with large libraries.
|
||||
|
||||
Includes recent_executions list to help the LLM assess agent quality:
|
||||
- Each execution has status, correctness_score (0-1), and activity_summary
|
||||
- This gives the LLM concrete examples of recent performance
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
search_query: Optional search term to find relevant agents (user's goal/description)
|
||||
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||
max_results: Maximum number of agents to return (default 15)
|
||||
|
||||
Returns:
|
||||
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
|
||||
"""
|
||||
try:
|
||||
response = await library_db.list_library_agents(
|
||||
user_id=user_id,
|
||||
search_term=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
include_executions=True,
|
||||
)
|
||||
|
||||
results: list[LibraryAgentSummary] = []
|
||||
for agent in response.agents:
|
||||
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
|
||||
continue
|
||||
|
||||
summary = LibraryAgentSummary(
|
||||
graph_id=agent.graph_id,
|
||||
graph_version=agent.graph_version,
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
input_schema=agent.input_schema,
|
||||
output_schema=agent.output_schema,
|
||||
)
|
||||
if agent.recent_executions:
|
||||
exec_summaries: list[ExecutionSummary] = []
|
||||
for ex in agent.recent_executions:
|
||||
exec_sum = ExecutionSummary(status=ex.status)
|
||||
if ex.correctness_score is not None:
|
||||
exec_sum["correctness_score"] = ex.correctness_score
|
||||
if ex.activity_summary:
|
||||
exec_sum["activity_summary"] = ex.activity_summary
|
||||
exec_summaries.append(exec_sum)
|
||||
summary["recent_executions"] = exec_summaries
|
||||
results.append(summary)
|
||||
return results
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def search_marketplace_agents_for_generation(
|
||||
search_query: str,
|
||||
max_results: int = 10,
|
||||
) -> list[LibraryAgentSummary]:
|
||||
"""Search marketplace agents formatted for Agent Generator.
|
||||
|
||||
Fetches marketplace agents and their full schemas so they can be used
|
||||
as sub-agents in generated workflows.
|
||||
|
||||
Args:
|
||||
search_query: Search term to find relevant public agents
|
||||
max_results: Maximum number of agents to return (default 10)
|
||||
|
||||
Returns:
|
||||
List of LibraryAgentSummary with full input/output schemas
|
||||
"""
|
||||
try:
|
||||
response = await store_db.get_store_agents(
|
||||
search_query=search_query,
|
||||
page=1,
|
||||
page_size=max_results,
|
||||
)
|
||||
|
||||
agents_with_graphs = [
|
||||
agent for agent in response.agents if agent.agent_graph_id
|
||||
]
|
||||
|
||||
if not agents_with_graphs:
|
||||
return []
|
||||
|
||||
graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
|
||||
graphs = await get_store_listed_graphs(*graph_ids)
|
||||
|
||||
results: list[LibraryAgentSummary] = []
|
||||
for agent in agents_with_graphs:
|
||||
graph_id = agent.agent_graph_id
|
||||
if graph_id and graph_id in graphs:
|
||||
graph = graphs[graph_id]
|
||||
results.append(
|
||||
LibraryAgentSummary(
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
name=agent.agent_name,
|
||||
description=agent.description,
|
||||
input_schema=graph.input_schema,
|
||||
output_schema=graph.output_schema,
|
||||
)
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to search marketplace agents: {e}")
|
||||
return []
|
||||
|
||||
|
||||
async def get_all_relevant_agents_for_generation(
|
||||
user_id: str,
|
||||
search_query: str | None = None,
|
||||
exclude_graph_id: str | None = None,
|
||||
include_library: bool = True,
|
||||
include_marketplace: bool = True,
|
||||
max_library_results: int = 15,
|
||||
max_marketplace_results: int = 10,
|
||||
) -> list[AgentSummary]:
|
||||
"""Fetch relevant agents from library and/or marketplace.
|
||||
|
||||
Searches both user's library and marketplace by default.
|
||||
Explicitly mentioned UUIDs in the search query are always looked up.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
search_query: Search term to find relevant agents (user's goal/description)
|
||||
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
|
||||
include_library: Whether to search user's library (default True)
|
||||
include_marketplace: Whether to also search marketplace (default True)
|
||||
max_library_results: Max library agents to return (default 15)
|
||||
max_marketplace_results: Max marketplace agents to return (default 10)
|
||||
|
||||
Returns:
|
||||
List of AgentSummary with full schemas (both library and marketplace agents)
|
||||
"""
|
||||
agents: list[AgentSummary] = []
|
||||
seen_graph_ids: set[str] = set()
|
||||
|
||||
if search_query:
|
||||
mentioned_uuids = extract_uuids_from_text(search_query)
|
||||
for graph_id in mentioned_uuids:
|
||||
if graph_id == exclude_graph_id:
|
||||
continue
|
||||
agent = await get_library_agent_by_graph_id(user_id, graph_id)
|
||||
agent_graph_id = agent.get("graph_id") if agent else None
|
||||
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
|
||||
agents.append(agent)
|
||||
seen_graph_ids.add(agent_graph_id)
|
||||
logger.debug(
|
||||
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
|
||||
)
|
||||
|
||||
if include_library:
|
||||
library_agents = await get_library_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=search_query,
|
||||
exclude_graph_id=exclude_graph_id,
|
||||
max_results=max_library_results,
|
||||
)
|
||||
for agent in library_agents:
|
||||
graph_id = agent.get("graph_id")
|
||||
if graph_id and graph_id not in seen_graph_ids:
|
||||
agents.append(agent)
|
||||
seen_graph_ids.add(graph_id)
|
||||
|
||||
if include_marketplace and search_query:
|
||||
marketplace_agents = await search_marketplace_agents_for_generation(
|
||||
search_query=search_query,
|
||||
max_results=max_marketplace_results,
|
||||
)
|
||||
for agent in marketplace_agents:
|
||||
graph_id = agent.get("graph_id")
|
||||
if graph_id and graph_id not in seen_graph_ids:
|
||||
agents.append(agent)
|
||||
seen_graph_ids.add(graph_id)
|
||||
|
||||
return agents
|
||||
|
||||
|
||||
def extract_search_terms_from_steps(
|
||||
decomposition_result: DecompositionResult | dict[str, Any],
|
||||
) -> list[str]:
|
||||
"""Extract search terms from decomposed instruction steps.
|
||||
|
||||
Analyzes the decomposition result to extract relevant keywords
|
||||
for additional library agent searches.
|
||||
|
||||
Args:
|
||||
decomposition_result: Result from decompose_goal containing steps
|
||||
|
||||
Returns:
|
||||
List of unique search terms extracted from steps
|
||||
"""
|
||||
search_terms: list[str] = []
|
||||
|
||||
if decomposition_result.get("type") != "instructions":
|
||||
return search_terms
|
||||
|
||||
steps = decomposition_result.get("steps", [])
|
||||
if not steps:
|
||||
return search_terms
|
||||
|
||||
step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]
|
||||
|
||||
for step in steps:
|
||||
for key in step_keys:
|
||||
value = step.get(key) # type: ignore[union-attr]
|
||||
if isinstance(value, str) and len(value) > 3:
|
||||
search_terms.append(value)
|
||||
|
||||
seen: set[str] = set()
|
||||
unique_terms: list[str] = []
|
||||
for term in search_terms:
|
||||
term_lower = term.lower()
|
||||
if term_lower not in seen:
|
||||
seen.add(term_lower)
|
||||
unique_terms.append(term)
|
||||
|
||||
return unique_terms
|
||||
|
||||
|
||||
async def enrich_library_agents_from_steps(
|
||||
user_id: str,
|
||||
decomposition_result: DecompositionResult | dict[str, Any],
|
||||
existing_agents: list[AgentSummary] | list[dict[str, Any]],
|
||||
exclude_graph_id: str | None = None,
|
||||
include_marketplace: bool = True,
|
||||
max_additional_results: int = 10,
|
||||
) -> list[AgentSummary] | list[dict[str, Any]]:
|
||||
"""Enrich library agents list with additional searches based on decomposed steps.
|
||||
|
||||
This implements two-phase search: after decomposition, we search for additional
|
||||
relevant agents based on the specific steps identified.
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
decomposition_result: Result from decompose_goal containing steps
|
||||
existing_agents: Already fetched library agents from initial search
|
||||
exclude_graph_id: Optional graph ID to exclude
|
||||
include_marketplace: Whether to also search marketplace
|
||||
max_additional_results: Max additional agents per search term (default 10)
|
||||
|
||||
Returns:
|
||||
Combined list of library agents (existing + newly discovered)
|
||||
"""
|
||||
search_terms = extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
if not search_terms:
|
||||
return existing_agents
|
||||
|
||||
existing_ids: set[str] = set()
|
||||
existing_names: set[str] = set()
|
||||
|
||||
for agent in existing_agents:
|
||||
agent_name = agent.get("name")
|
||||
if agent_name and isinstance(agent_name, str):
|
||||
existing_names.add(agent_name.lower())
|
||||
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||
if graph_id and isinstance(graph_id, str):
|
||||
existing_ids.add(graph_id)
|
||||
|
||||
all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)
|
||||
|
||||
for term in search_terms[:3]:
|
||||
try:
|
||||
additional_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=term,
|
||||
exclude_graph_id=exclude_graph_id,
|
||||
include_marketplace=include_marketplace,
|
||||
max_library_results=max_additional_results,
|
||||
max_marketplace_results=5,
|
||||
)
|
||||
|
||||
for agent in additional_agents:
|
||||
agent_name = agent.get("name")
|
||||
if not agent_name or not isinstance(agent_name, str):
|
||||
continue
|
||||
agent_name_lower = agent_name.lower()
|
||||
|
||||
if agent_name_lower in existing_names:
|
||||
continue
|
||||
|
||||
graph_id = agent.get("graph_id") # type: ignore[call-overload]
|
||||
if graph_id and graph_id in existing_ids:
|
||||
continue
|
||||
|
||||
all_agents.append(agent)
|
||||
existing_names.add(agent_name_lower)
|
||||
if graph_id and isinstance(graph_id, str):
|
||||
existing_ids.add(graph_id)
|
||||
|
||||
except DatabaseError:
|
||||
logger.error(f"Database error searching for agents with term '{term}'")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to search for additional agents with term '{term}': {e}"
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Enriched library agents: {len(existing_agents)} initial + "
|
||||
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
|
||||
)
|
||||
|
||||
return all_agents
|
||||
|
||||
|
||||
async def decompose_goal(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
) -> DecompositionResult | None:
|
||||
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
|
||||
"""Break down a goal into steps or return clarifying questions.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
DecompositionResult with either:
|
||||
Dict with either:
|
||||
- {"type": "clarifying_questions", "questions": [...]}
|
||||
- {"type": "instructions", "steps": [...]}
|
||||
Or None on error
|
||||
@@ -541,47 +54,29 @@ async def decompose_goal(
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for decompose_goal")
|
||||
result = await decompose_goal_external(
|
||||
description, context, _to_dict_list(library_agents)
|
||||
)
|
||||
return result # type: ignore[return-value]
|
||||
return await decompose_goal_external(description, context)
|
||||
|
||||
|
||||
async def generate_agent(
|
||||
instructions: DecompositionResult | dict[str, Any],
|
||||
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
|
||||
"""Generate agent JSON from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams
|
||||
completion notification)
|
||||
task_id: Task ID for async processing (enables Redis Streams persistence
|
||||
and SSE delivery)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
Agent JSON dict, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent")
|
||||
result = await generate_agent_external(
|
||||
dict(instructions), _to_dict_list(library_agents), operation_id, task_id
|
||||
)
|
||||
|
||||
# Don't modify async response
|
||||
if result and result.get("status") == "accepted":
|
||||
return result
|
||||
|
||||
result = await generate_agent_external(instructions)
|
||||
if result:
|
||||
# Check if it's an error response - pass through as-is
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
return result
|
||||
# Ensure required fields for successful agent generation
|
||||
if "id" not in result:
|
||||
result["id"] = str(uuid.uuid4())
|
||||
if "version" not in result:
|
||||
@@ -591,12 +86,6 @@ async def generate_agent(
|
||||
return result
|
||||
|
||||
|
||||
class AgentJsonValidationError(Exception):
|
||||
"""Raised when agent JSON is invalid or missing required fields."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
"""Convert agent JSON dict to Graph model.
|
||||
|
||||
@@ -605,55 +94,25 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:
|
||||
|
||||
Returns:
|
||||
Graph ready for saving
|
||||
|
||||
Raises:
|
||||
AgentJsonValidationError: If required fields are missing from nodes or links
|
||||
"""
|
||||
nodes = []
|
||||
for idx, n in enumerate(agent_json.get("nodes", [])):
|
||||
block_id = n.get("block_id")
|
||||
if not block_id:
|
||||
node_id = n.get("id", f"index_{idx}")
|
||||
raise AgentJsonValidationError(
|
||||
f"Node '{node_id}' is missing required field 'block_id'"
|
||||
)
|
||||
for n in agent_json.get("nodes", []):
|
||||
node = Node(
|
||||
id=n.get("id", str(uuid.uuid4())),
|
||||
block_id=block_id,
|
||||
block_id=n["block_id"],
|
||||
input_default=n.get("input_default", {}),
|
||||
metadata=n.get("metadata", {}),
|
||||
)
|
||||
nodes.append(node)
|
||||
|
||||
links = []
|
||||
for idx, link_data in enumerate(agent_json.get("links", [])):
|
||||
source_id = link_data.get("source_id")
|
||||
sink_id = link_data.get("sink_id")
|
||||
source_name = link_data.get("source_name")
|
||||
sink_name = link_data.get("sink_name")
|
||||
|
||||
missing_fields = []
|
||||
if not source_id:
|
||||
missing_fields.append("source_id")
|
||||
if not sink_id:
|
||||
missing_fields.append("sink_id")
|
||||
if not source_name:
|
||||
missing_fields.append("source_name")
|
||||
if not sink_name:
|
||||
missing_fields.append("sink_name")
|
||||
|
||||
if missing_fields:
|
||||
link_id = link_data.get("id", f"index_{idx}")
|
||||
raise AgentJsonValidationError(
|
||||
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
|
||||
)
|
||||
|
||||
for link_data in agent_json.get("links", []):
|
||||
link = Link(
|
||||
id=link_data.get("id", str(uuid.uuid4())),
|
||||
source_id=source_id,
|
||||
sink_id=sink_id,
|
||||
source_name=source_name,
|
||||
sink_name=sink_name,
|
||||
source_id=link_data["source_id"],
|
||||
sink_id=link_data["sink_id"],
|
||||
source_name=link_data["source_name"],
|
||||
sink_name=link_data["sink_name"],
|
||||
is_static=link_data.get("is_static", False),
|
||||
)
|
||||
links.append(link)
|
||||
@@ -674,40 +133,22 @@ def _reassign_node_ids(graph: Graph) -> None:
|
||||
|
||||
This is needed when creating a new version to avoid unique constraint violations.
|
||||
"""
|
||||
# Create mapping from old node IDs to new UUIDs
|
||||
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}
|
||||
|
||||
# Reassign node IDs
|
||||
for node in graph.nodes:
|
||||
node.id = id_map[node.id]
|
||||
|
||||
# Update link references to use new node IDs
|
||||
for link in graph.links:
|
||||
link.id = str(uuid.uuid4())
|
||||
link.id = str(uuid.uuid4()) # Also give links new IDs
|
||||
if link.source_id in id_map:
|
||||
link.source_id = id_map[link.source_id]
|
||||
if link.sink_id in id_map:
|
||||
link.sink_id = id_map[link.sink_id]
|
||||
|
||||
|
||||
def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
|
||||
"""Populate user_id in AgentExecutorBlock nodes.
|
||||
|
||||
The external agent generator creates AgentExecutorBlock nodes with empty user_id.
|
||||
This function fills in the actual user_id so sub-agents run with correct permissions.
|
||||
|
||||
Args:
|
||||
agent_json: Agent JSON dict (modified in place)
|
||||
user_id: User ID to set
|
||||
"""
|
||||
for node in agent_json.get("nodes", []):
|
||||
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
|
||||
input_default = node.get("input_default") or {}
|
||||
if not input_default.get("user_id"):
|
||||
input_default["user_id"] = user_id
|
||||
node["input_default"] = input_default
|
||||
logger.debug(
|
||||
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
|
||||
)
|
||||
|
||||
|
||||
async def save_agent_to_library(
|
||||
agent_json: dict[str, Any], user_id: str, is_update: bool = False
|
||||
) -> tuple[Graph, Any]:
|
||||
@@ -721,27 +162,33 @@ async def save_agent_to_library(
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||
from backend.data.graph import get_graph_all_versions
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
# For updates, keep the same graph ID but increment version
|
||||
# and reassign node/link IDs to avoid conflicts
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
# Reassign node IDs (but keep graph ID the same)
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
# For new agents, always generate a fresh UUID to avoid collisions
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
# Reassign all node IDs as well
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
# Save to database
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
# Add to user's library (or update existing library agent)
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
@@ -752,15 +199,26 @@ async def save_agent_to_library(
|
||||
return created_graph, library_agents[0]
|
||||
|
||||
|
||||
def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
"""Convert a Graph object to JSON format for the agent generator.
|
||||
async def get_agent_as_json(
|
||||
graph_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
graph: Graph object to convert
|
||||
graph_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
# Try to get the graph (version=None gets the active version)
|
||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
# Convert to JSON format
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
@@ -797,41 +255,8 @@ def graph_to_json(graph: Graph) -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
agent_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
agent_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
return graph_to_json(graph)
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -843,57 +268,14 @@ async def generate_agent_patch(
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
{"status": "accepted"} for async, error dict {"type": "error", ...}, or None on error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request,
|
||||
current_agent,
|
||||
_to_dict_list(library_agents),
|
||||
operation_id,
|
||||
task_id,
|
||||
)
|
||||
|
||||
|
||||
async def customize_template(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Customize a template/marketplace agent using natural language.
|
||||
|
||||
This is used when users want to modify a template or marketplace agent
|
||||
to fit their specific needs before adding it to their library.
|
||||
|
||||
The external Agent Generator service handles:
|
||||
- Understanding the modification request
|
||||
- Applying changes to the template
|
||||
- Fixing and validating the result
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
error dict {"type": "error", ...}, or None on unexpected error
|
||||
|
||||
Raises:
|
||||
AgentGeneratorNotConfiguredError: If the external service is not configured.
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for customize_template")
|
||||
return await customize_template_external(
|
||||
template_agent, modification_request, context
|
||||
)
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(update_request, current_agent)
|
||||
|
||||
@@ -1,43 +1,11 @@
|
||||
"""Error handling utilities for agent generator."""
|
||||
|
||||
import re
|
||||
|
||||
|
||||
def _sanitize_error_details(details: str) -> str:
|
||||
"""Sanitize error details to remove sensitive information.
|
||||
|
||||
Strips common patterns that could expose internal system info:
|
||||
- File paths (Unix and Windows)
|
||||
- Database connection strings
|
||||
- URLs with credentials
|
||||
- Stack trace internals
|
||||
|
||||
Args:
|
||||
details: Raw error details string
|
||||
|
||||
Returns:
|
||||
Sanitized error details safe for user display
|
||||
"""
|
||||
sanitized = re.sub(
|
||||
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
|
||||
)
|
||||
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
|
||||
sanitized = re.sub(
|
||||
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
|
||||
)
|
||||
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
|
||||
sanitized = re.sub(r", line \d+", "", sanitized)
|
||||
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)
|
||||
|
||||
return sanitized.strip()
|
||||
|
||||
|
||||
def get_user_message_for_error(
|
||||
error_type: str,
|
||||
operation: str = "process the request",
|
||||
llm_parse_message: str | None = None,
|
||||
validation_message: str | None = None,
|
||||
error_details: str | None = None,
|
||||
) -> str:
|
||||
"""Get a user-friendly error message based on error type.
|
||||
|
||||
@@ -51,45 +19,25 @@ def get_user_message_for_error(
|
||||
message (e.g., "analyze the goal", "generate the agent")
|
||||
llm_parse_message: Custom message for llm_parse_error type
|
||||
validation_message: Custom message for validation_error type
|
||||
error_details: Optional additional details about the error
|
||||
|
||||
Returns:
|
||||
User-friendly error message suitable for display to the user
|
||||
"""
|
||||
base_message = ""
|
||||
|
||||
if error_type == "llm_parse_error":
|
||||
base_message = (
|
||||
return (
|
||||
llm_parse_message
|
||||
or "The AI had trouble processing this request. Please try again."
|
||||
)
|
||||
elif error_type == "validation_error":
|
||||
base_message = (
|
||||
return (
|
||||
validation_message
|
||||
or "The generated agent failed validation. "
|
||||
"This usually happens when the agent structure doesn't match "
|
||||
"what the platform expects. Please try simplifying your goal "
|
||||
"or breaking it into smaller parts."
|
||||
or "The request failed validation. Please try rephrasing."
|
||||
)
|
||||
elif error_type == "patch_error":
|
||||
base_message = (
|
||||
"Failed to apply the changes. The modification couldn't be "
|
||||
"validated. Please try a different approach or simplify the change."
|
||||
)
|
||||
return "Failed to apply the changes. Please try a different approach."
|
||||
elif error_type in ("timeout", "llm_timeout"):
|
||||
base_message = (
|
||||
"The request took too long to process. This can happen with "
|
||||
"complex agents. Please try again or simplify your goal."
|
||||
)
|
||||
return "The request took too long. Please try again."
|
||||
elif error_type in ("rate_limit", "llm_rate_limit"):
|
||||
base_message = "The service is currently busy. Please try again in a moment."
|
||||
return "The service is currently busy. Please try again in a moment."
|
||||
else:
|
||||
base_message = f"Failed to {operation}. Please try again."
|
||||
|
||||
if error_details:
|
||||
details = _sanitize_error_details(error_details)
|
||||
if len(details) > 200:
|
||||
details = details[:200] + "..."
|
||||
base_message += f"\n\nTechnical details: {details}"
|
||||
|
||||
return base_message
|
||||
return f"Failed to {operation}. Please try again."
|
||||
|
||||
@@ -117,16 +117,13 @@ def _get_client() -> httpx.AsyncClient:
|
||||
|
||||
|
||||
async def decompose_goal_external(
|
||||
description: str,
|
||||
context: str = "",
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
description: str, context: str = ""
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to decompose a goal.
|
||||
|
||||
Args:
|
||||
description: Natural language goal description
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Dict with either:
|
||||
@@ -139,12 +136,11 @@ async def decompose_goal_external(
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
if context:
|
||||
description = f"{description}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
# Build the request payload
|
||||
payload: dict[str, Any] = {"description": description}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if context:
|
||||
# The external service uses user_instruction for additional context
|
||||
payload["user_instruction"] = context
|
||||
|
||||
try:
|
||||
response = await client.post("/api/decompose-description", json=payload)
|
||||
@@ -211,46 +207,21 @@ async def decompose_goal_external(
|
||||
|
||||
async def generate_agent_external(
|
||||
instructions: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate an agent from instructions.
|
||||
|
||||
Args:
|
||||
instructions: Structured instructions from decompose_goal
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Agent JSON dict, {"status": "accepted"} for async, or error dict {"type": "error", ...} on error
|
||||
Agent JSON dict on success, or error dict {"type": "error", ...} on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {"instructions": instructions}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/generate-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"/api/generate-agent", json={"instructions": instructions}
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -258,7 +229,8 @@ async def generate_agent_external(
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
|
||||
f"Agent Generator generation failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
@@ -279,52 +251,27 @@ async def generate_agent_external(
|
||||
|
||||
|
||||
async def generate_agent_patch_external(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[dict[str, Any]] | None = None,
|
||||
operation_id: str | None = None,
|
||||
task_id: str | None = None,
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to generate a patch for an existing agent.
|
||||
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
operation_id: Operation ID for async processing (enables Redis Streams callback)
|
||||
task_id: Task ID for async processing (enables Redis Streams callback)
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict, {"status": "accepted"} for async, or error dict on error
|
||||
Updated agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
# Build request payload
|
||||
payload: dict[str, Any] = {
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
}
|
||||
if library_agents:
|
||||
payload["library_agents"] = library_agents
|
||||
if operation_id and task_id:
|
||||
payload["operation_id"] = operation_id
|
||||
payload["task_id"] = task_id
|
||||
|
||||
try:
|
||||
response = await client.post("/api/update-agent", json=payload)
|
||||
|
||||
# Handle 202 Accepted for async processing
|
||||
if response.status_code == 202:
|
||||
logger.info(
|
||||
f"Agent Generator accepted async update request "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return {
|
||||
"status": "accepted",
|
||||
"operation_id": operation_id,
|
||||
"task_id": task_id,
|
||||
}
|
||||
|
||||
response = await client.post(
|
||||
"/api/update-agent",
|
||||
json={
|
||||
"update_request": update_request,
|
||||
"current_agent_json": current_agent,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
@@ -368,77 +315,6 @@ async def generate_agent_patch_external(
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def customize_template_external(
|
||||
template_agent: dict[str, Any],
|
||||
modification_request: str,
|
||||
context: str = "",
|
||||
) -> dict[str, Any] | None:
|
||||
"""Call the external service to customize a template/marketplace agent.
|
||||
|
||||
Args:
|
||||
template_agent: The template agent JSON to customize
|
||||
modification_request: Natural language description of customizations
|
||||
context: Additional context (e.g., answers to previous questions)
|
||||
|
||||
Returns:
|
||||
Customized agent JSON, clarifying questions dict, or error dict on error
|
||||
"""
|
||||
client = _get_client()
|
||||
|
||||
request = modification_request
|
||||
if context:
|
||||
request = f"{modification_request}\n\nAdditional context from user:\n{context}"
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"template_agent_json": template_agent,
|
||||
"modification_request": request,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await client.post("/api/template-modification", json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
if not data.get("success"):
|
||||
error_msg = data.get("error", "Unknown error from Agent Generator")
|
||||
error_type = data.get("error_type", "unknown")
|
||||
logger.error(
|
||||
f"Agent Generator template customization failed: {error_msg} "
|
||||
f"(type: {error_type})"
|
||||
)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
|
||||
# Check if it's clarifying questions
|
||||
if data.get("type") == "clarifying_questions":
|
||||
return {
|
||||
"type": "clarifying_questions",
|
||||
"questions": data.get("questions", []),
|
||||
}
|
||||
|
||||
# Check if it's an error passed through
|
||||
if data.get("type") == "error":
|
||||
return _create_error_response(
|
||||
data.get("error", "Unknown error"),
|
||||
data.get("error_type", "unknown"),
|
||||
)
|
||||
|
||||
# Otherwise return the customized agent JSON
|
||||
return data.get("agent_json")
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
error_type, error_msg = _classify_http_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except httpx.RequestError as e:
|
||||
error_type, error_msg = _classify_request_error(e)
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, error_type)
|
||||
except Exception as e:
|
||||
error_msg = f"Unexpected error calling Agent Generator: {e}"
|
||||
logger.error(error_msg)
|
||||
return _create_error_response(error_msg, "unexpected_error")
|
||||
|
||||
|
||||
async def get_blocks_external() -> list[dict[str, Any]] | None:
|
||||
"""Get available blocks from the external service.
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Shared agent search functionality for find_agent and find_library_agent tools."""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Literal
|
||||
|
||||
from backend.api.features.library import db as library_db
|
||||
@@ -20,85 +19,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
SearchSource = Literal["marketplace", "library"]
|
||||
|
||||
_UUID_PATTERN = re.compile(
|
||||
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(text: str) -> bool:
|
||||
"""Check if text is a valid UUID v4."""
|
||||
return bool(_UUID_PATTERN.match(text.strip()))
|
||||
|
||||
|
||||
async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
|
||||
"""Fetch a library agent by ID (library agent ID or graph_id).
|
||||
|
||||
Tries multiple lookup strategies:
|
||||
1. First by graph_id (AgentGraph primary key)
|
||||
2. Then by library agent ID (LibraryAgent primary key)
|
||||
|
||||
Args:
|
||||
user_id: The user ID
|
||||
agent_id: The ID to look up (can be graph_id or library agent ID)
|
||||
|
||||
Returns:
|
||||
AgentInfo if found, None otherwise
|
||||
"""
|
||||
try:
|
||||
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by graph_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by graph_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
try:
|
||||
agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
if agent:
|
||||
logger.debug(f"Found library agent by library_id: {agent.name}")
|
||||
return AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
except NotFoundError:
|
||||
logger.debug(f"Library agent not found by library_id: {agent_id}")
|
||||
except DatabaseError:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Could not fetch library agent by library_id {agent_id}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def search_agents(
|
||||
query: str,
|
||||
@@ -149,37 +69,29 @@ async def search_agents(
|
||||
is_featured=False,
|
||||
)
|
||||
)
|
||||
else:
|
||||
if _is_uuid(query):
|
||||
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
|
||||
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
|
||||
if agent:
|
||||
agents.append(agent)
|
||||
logger.info(f"Found agent by direct ID lookup: {agent.name}")
|
||||
|
||||
if not agents:
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
else: # library
|
||||
logger.info(f"Searching user library for: {query}")
|
||||
results = await library_db.list_library_agents(
|
||||
user_id=user_id, # type: ignore[arg-type]
|
||||
search_term=query,
|
||||
page_size=10,
|
||||
)
|
||||
for agent in results.agents:
|
||||
agents.append(
|
||||
AgentInfo(
|
||||
id=agent.id,
|
||||
name=agent.name,
|
||||
description=agent.description or "",
|
||||
source="library",
|
||||
in_library=True,
|
||||
creator=agent.creator_name,
|
||||
status=agent.status.value,
|
||||
can_access_graph=agent.can_access_graph,
|
||||
has_external_trigger=agent.has_external_trigger,
|
||||
new_output=agent.new_output,
|
||||
graph_id=agent.graph_id,
|
||||
)
|
||||
)
|
||||
logger.info(f"Found {len(agents)} agents in {source}")
|
||||
except NotFoundError:
|
||||
pass
|
||||
|
||||
@@ -8,9 +8,7 @@ from backend.api.features.chat.model import ChatSession
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
decompose_goal,
|
||||
enrich_library_agents_from_steps,
|
||||
generate_agent,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -18,7 +16,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -99,10 +96,6 @@ class CreateAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not description:
|
||||
return ErrorResponse(
|
||||
message="Please provide a description of what the agent should do.",
|
||||
@@ -110,24 +103,9 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
library_agents = None
|
||||
if user_id:
|
||||
try:
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=description,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
|
||||
# Step 1: Decompose goal into steps
|
||||
try:
|
||||
decomposition_result = await decompose_goal(
|
||||
description, context, library_agents
|
||||
)
|
||||
decomposition_result = await decompose_goal(description, context)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -146,6 +124,7 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if decomposition_result.get("type") == "error":
|
||||
error_msg = decomposition_result.get("error", "Unknown error")
|
||||
error_type = decomposition_result.get("error_type", "unknown")
|
||||
@@ -165,6 +144,7 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if decomposition_result.get("type") == "clarifying_questions":
|
||||
questions = decomposition_result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
@@ -183,6 +163,7 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check for unachievable/vague goals
|
||||
if decomposition_result.get("type") == "unachievable_goal":
|
||||
suggested = decomposition_result.get("suggested_goal", "")
|
||||
reason = decomposition_result.get("reason", "")
|
||||
@@ -209,27 +190,9 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if user_id and library_agents is not None:
|
||||
try:
|
||||
library_agents = await enrich_library_agents_from_steps(
|
||||
user_id=user_id,
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=library_agents,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to enrich library agents from steps: {e}")
|
||||
|
||||
# Step 2: Generate agent JSON (external service handles fixing and validation)
|
||||
try:
|
||||
agent_json = await generate_agent(
|
||||
decomposition_result,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
agent_json = await generate_agent(decomposition_result)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -248,6 +211,7 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
|
||||
error_msg = agent_json.get("error", "Unknown error")
|
||||
error_type = agent_json.get("error_type", "unknown")
|
||||
@@ -255,12 +219,7 @@ class CreateAgentTool(BaseTool):
|
||||
error_type,
|
||||
operation="generate the agent",
|
||||
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
|
||||
validation_message=(
|
||||
"I wasn't able to create a valid agent for this request. "
|
||||
"The generated workflow had some structural issues. "
|
||||
"Please try simplifying your goal or breaking it into smaller steps."
|
||||
),
|
||||
error_details=error_msg,
|
||||
validation_message="The generated agent failed validation. Please try rephrasing your goal.",
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
@@ -273,24 +232,12 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if agent_json.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent generation delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent generation started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
agent_name = agent_json.get("name", "Generated Agent")
|
||||
agent_description = agent_json.get("description", "")
|
||||
node_count = len(agent_json.get("nodes", []))
|
||||
link_count = len(agent_json.get("links", []))
|
||||
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
@@ -305,6 +252,7 @@ class CreateAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
@@ -322,7 +270,7 @@ class CreateAgentTool(BaseTool):
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -1,337 +0,0 @@
|
||||
"""CustomizeAgentTool - Customizes marketplace/template agents using natural language."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.api.features.store.exceptions import AgentNotFoundError
|
||||
|
||||
from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
customize_template,
|
||||
get_user_message_for_error,
|
||||
graph_to_json,
|
||||
save_agent_to_library,
|
||||
)
|
||||
from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
ToolResponseBase,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CustomizeAgentTool(BaseTool):
|
||||
"""Tool for customizing marketplace/template agents using natural language."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "customize_agent"
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return (
|
||||
"Customize a marketplace or template agent using natural language. "
|
||||
"Takes an existing agent from the marketplace and modifies it based on "
|
||||
"the user's requirements before adding to their library."
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_auth(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_long_running(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Any]:
|
||||
return {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"agent_id": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"The marketplace agent ID in format 'creator/slug' "
|
||||
"(e.g., 'autogpt/newsletter-writer'). "
|
||||
"Get this from find_agent results."
|
||||
),
|
||||
},
|
||||
"modifications": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Natural language description of how to customize the agent. "
|
||||
"Be specific about what changes you want to make."
|
||||
),
|
||||
},
|
||||
"context": {
|
||||
"type": "string",
|
||||
"description": (
|
||||
"Additional context or answers to previous clarifying questions."
|
||||
),
|
||||
},
|
||||
"save": {
|
||||
"type": "boolean",
|
||||
"description": (
|
||||
"Whether to save the customized agent to the user's library. "
|
||||
"Default is true. Set to false for preview only."
|
||||
),
|
||||
"default": True,
|
||||
},
|
||||
},
|
||||
"required": ["agent_id", "modifications"],
|
||||
}
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
user_id: str | None,
|
||||
session: ChatSession,
|
||||
**kwargs,
|
||||
) -> ToolResponseBase:
|
||||
"""Execute the customize_agent tool.
|
||||
|
||||
Flow:
|
||||
1. Parse the agent ID to get creator/slug
|
||||
2. Fetch the template agent from the marketplace
|
||||
3. Call customize_template with the modification request
|
||||
4. Preview or save based on the save parameter
|
||||
"""
|
||||
agent_id = kwargs.get("agent_id", "").strip()
|
||||
modifications = kwargs.get("modifications", "").strip()
|
||||
context = kwargs.get("context", "")
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the marketplace agent ID (e.g., 'creator/agent-name').",
|
||||
error="missing_agent_id",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not modifications:
|
||||
return ErrorResponse(
|
||||
message="Please describe how you want to customize this agent.",
|
||||
error="missing_modifications",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Parse agent_id in format "creator/slug"
|
||||
parts = [p.strip() for p in agent_id.split("/")]
|
||||
if len(parts) != 2 or not parts[0] or not parts[1]:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Invalid agent ID format: '{agent_id}'. "
|
||||
"Expected format is 'creator/agent-name' "
|
||||
"(e.g., 'autogpt/newsletter-writer')."
|
||||
),
|
||||
error="invalid_agent_id_format",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
creator_username, agent_slug = parts
|
||||
|
||||
# Fetch the marketplace agent details
|
||||
try:
|
||||
agent_details = await store_db.get_store_agent_details(
|
||||
username=creator_username, agent_name=agent_slug
|
||||
)
|
||||
except AgentNotFoundError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"Could not find marketplace agent '{agent_id}'. "
|
||||
"Please check the agent ID and try again."
|
||||
),
|
||||
error="agent_not_found",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching marketplace agent {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the marketplace agent. Please try again.",
|
||||
error="fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not agent_details.store_listing_version_id:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
f"The agent '{agent_id}' does not have an available version. "
|
||||
"Please try a different agent."
|
||||
),
|
||||
error="no_version_available",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Get the full agent graph
|
||||
try:
|
||||
graph = await store_db.get_agent(agent_details.store_listing_version_id)
|
||||
template_agent = graph_to_json(graph)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching agent graph for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to fetch the agent configuration. Please try again.",
|
||||
error="graph_fetch_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Call customize_template
|
||||
try:
|
||||
result = await customize_template(
|
||||
template_agent=template_agent,
|
||||
modification_request=modifications,
|
||||
context=context,
|
||||
)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Agent customization is not available. "
|
||||
"The Agent Generator service is not configured."
|
||||
),
|
||||
error="service_not_configured",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error calling customize_template for {agent_id}: {e}")
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent due to a service error. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_service_error",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if result is None:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
"Failed to customize the agent. "
|
||||
"The agent generation service may be unavailable or timed out. "
|
||||
"Please try again."
|
||||
),
|
||||
error="customization_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle error response
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
error_type = result.get("error_type", "unknown")
|
||||
user_message = get_user_message_for_error(
|
||||
error_type,
|
||||
operation="customize the agent",
|
||||
llm_parse_message=(
|
||||
"The AI had trouble customizing the agent. "
|
||||
"Please try again or simplify your request."
|
||||
),
|
||||
validation_message=(
|
||||
"The customized agent failed validation. "
|
||||
"Please try rephrasing your request."
|
||||
),
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
error=f"customization_failed:{error_type}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Handle clarifying questions
|
||||
if isinstance(result, dict) and result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions") or []
|
||||
if not isinstance(questions, list):
|
||||
logger.error(
|
||||
f"Unexpected clarifying questions format: {type(questions)}"
|
||||
)
|
||||
questions = []
|
||||
return ClarificationNeededResponse(
|
||||
message=(
|
||||
"I need some more information to customize this agent. "
|
||||
"Please answer the following questions:"
|
||||
),
|
||||
questions=[
|
||||
ClarifyingQuestion(
|
||||
question=q.get("question", ""),
|
||||
keyword=q.get("keyword", ""),
|
||||
example=q.get("example"),
|
||||
)
|
||||
for q in questions
|
||||
if isinstance(q, dict)
|
||||
],
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result should be the customized agent JSON
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"Unexpected customize_template response type: {type(result)}")
|
||||
return ErrorResponse(
|
||||
message="Failed to customize the agent due to an unexpected response.",
|
||||
error="unexpected_response_type",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
customized_agent = result
|
||||
|
||||
agent_name = customized_agent.get(
|
||||
"name", f"Customized {agent_details.agent_name}"
|
||||
)
|
||||
agent_description = customized_agent.get("description", "")
|
||||
nodes = customized_agent.get("nodes")
|
||||
links = customized_agent.get("links")
|
||||
node_count = len(nodes) if isinstance(nodes, list) else 0
|
||||
link_count = len(links) if isinstance(links, list) else 0
|
||||
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
f"I've customized the agent '{agent_details.agent_name}'. "
|
||||
f"The customized agent has {node_count} blocks. "
|
||||
f"Review it and call customize_agent with save=true to save it."
|
||||
),
|
||||
agent_json=customized_agent,
|
||||
agent_name=agent_name,
|
||||
description=agent_description,
|
||||
node_count=node_count,
|
||||
link_count=link_count,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
error="auth_required",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to user's library
|
||||
try:
|
||||
created_graph, library_agent = await save_agent_to_library(
|
||||
customized_agent, user_id, is_update=False
|
||||
)
|
||||
|
||||
return AgentSavedResponse(
|
||||
message=(
|
||||
f"Customized agent '{created_graph.name}' "
|
||||
f"(based on '{agent_details.agent_name}') "
|
||||
f"has been saved to your library!"
|
||||
),
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving customized agent: {e}")
|
||||
return ErrorResponse(
|
||||
message="Failed to save the customized agent. Please try again.",
|
||||
error="save_failed",
|
||||
session_id=session_id,
|
||||
)
|
||||
@@ -9,7 +9,6 @@ from .agent_generator import (
|
||||
AgentGeneratorNotConfiguredError,
|
||||
generate_agent_patch,
|
||||
get_agent_as_json,
|
||||
get_all_relevant_agents_for_generation,
|
||||
get_user_message_for_error,
|
||||
save_agent_to_library,
|
||||
)
|
||||
@@ -17,7 +16,6 @@ from .base import BaseTool
|
||||
from .models import (
|
||||
AgentPreviewResponse,
|
||||
AgentSavedResponse,
|
||||
AsyncProcessingResponse,
|
||||
ClarificationNeededResponse,
|
||||
ClarifyingQuestion,
|
||||
ErrorResponse,
|
||||
@@ -105,10 +103,6 @@ class EditAgentTool(BaseTool):
|
||||
save = kwargs.get("save", True)
|
||||
session_id = session.session_id if session else None
|
||||
|
||||
# Extract async processing params (passed by long-running tool handler)
|
||||
operation_id = kwargs.get("_operation_id")
|
||||
task_id = kwargs.get("_task_id")
|
||||
|
||||
if not agent_id:
|
||||
return ErrorResponse(
|
||||
message="Please provide the agent ID to edit.",
|
||||
@@ -123,6 +117,7 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Step 1: Fetch current agent
|
||||
current_agent = await get_agent_as_json(agent_id, user_id)
|
||||
|
||||
if current_agent is None:
|
||||
@@ -132,34 +127,14 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
library_agents = None
|
||||
if user_id:
|
||||
try:
|
||||
graph_id = current_agent.get("id")
|
||||
library_agents = await get_all_relevant_agents_for_generation(
|
||||
user_id=user_id,
|
||||
search_query=changes,
|
||||
exclude_graph_id=graph_id,
|
||||
include_marketplace=True,
|
||||
)
|
||||
logger.debug(
|
||||
f"Found {len(library_agents)} relevant agents for sub-agent composition"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch library agents: {e}")
|
||||
|
||||
# Build the update request with context
|
||||
update_request = changes
|
||||
if context:
|
||||
update_request = f"{changes}\n\nAdditional context:\n{context}"
|
||||
|
||||
# Step 2: Generate updated agent (external service handles fixing and validation)
|
||||
try:
|
||||
result = await generate_agent_patch(
|
||||
update_request,
|
||||
current_agent,
|
||||
library_agents,
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
result = await generate_agent_patch(update_request, current_agent)
|
||||
except AgentGeneratorNotConfiguredError:
|
||||
return ErrorResponse(
|
||||
message=(
|
||||
@@ -178,19 +153,6 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if Agent Generator accepted for async processing
|
||||
if result.get("status") == "accepted":
|
||||
logger.info(
|
||||
f"Agent edit delegated to async processing "
|
||||
f"(operation_id={operation_id}, task_id={task_id})"
|
||||
)
|
||||
return AsyncProcessingResponse(
|
||||
message="Agent edit started. You'll be notified when it's complete.",
|
||||
operation_id=operation_id,
|
||||
task_id=task_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if the result is an error from the external service
|
||||
if isinstance(result, dict) and result.get("type") == "error":
|
||||
error_msg = result.get("error", "Unknown error")
|
||||
@@ -200,7 +162,6 @@ class EditAgentTool(BaseTool):
|
||||
operation="generate the changes",
|
||||
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
|
||||
validation_message="The generated changes failed validation. Please try rephrasing your request.",
|
||||
error_details=error_msg,
|
||||
)
|
||||
return ErrorResponse(
|
||||
message=user_message,
|
||||
@@ -214,6 +175,7 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Check if LLM returned clarifying questions
|
||||
if result.get("type") == "clarifying_questions":
|
||||
questions = result.get("questions", [])
|
||||
return ClarificationNeededResponse(
|
||||
@@ -232,6 +194,7 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Result is the updated agent JSON
|
||||
updated_agent = result
|
||||
|
||||
agent_name = updated_agent.get("name", "Updated Agent")
|
||||
@@ -239,6 +202,7 @@ class EditAgentTool(BaseTool):
|
||||
node_count = len(updated_agent.get("nodes", []))
|
||||
link_count = len(updated_agent.get("links", []))
|
||||
|
||||
# Step 3: Preview or save
|
||||
if not save:
|
||||
return AgentPreviewResponse(
|
||||
message=(
|
||||
@@ -254,6 +218,7 @@ class EditAgentTool(BaseTool):
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
# Save to library (creates a new version)
|
||||
if not user_id:
|
||||
return ErrorResponse(
|
||||
message="You must be logged in to save agents.",
|
||||
@@ -271,7 +236,7 @@ class EditAgentTool(BaseTool):
|
||||
agent_id=created_graph.id,
|
||||
agent_name=created_graph.name,
|
||||
library_agent_id=library_agent.id,
|
||||
library_agent_link=f"/library/agents/{library_agent.id}",
|
||||
library_agent_link=f"/library/{library_agent.id}",
|
||||
agent_page_link=f"/build?flowID={created_graph.id}",
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
@@ -38,8 +38,6 @@ class ResponseType(str, Enum):
|
||||
OPERATION_STARTED = "operation_started"
|
||||
OPERATION_PENDING = "operation_pending"
|
||||
OPERATION_IN_PROGRESS = "operation_in_progress"
|
||||
# Input validation
|
||||
INPUT_VALIDATION_ERROR = "input_validation_error"
|
||||
|
||||
|
||||
# Base response model
|
||||
@@ -70,10 +68,6 @@ class AgentInfo(BaseModel):
|
||||
has_external_trigger: bool | None = None
|
||||
new_output: bool | None = None
|
||||
graph_id: str | None = None
|
||||
inputs: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
description="Input schema for the agent, including field names, types, and defaults",
|
||||
)
|
||||
|
||||
|
||||
class AgentsFoundResponse(ToolResponseBase):
|
||||
@@ -200,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
|
||||
details: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class InputValidationErrorResponse(ToolResponseBase):
|
||||
"""Response when run_agent receives unknown input fields."""
|
||||
|
||||
type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
|
||||
unrecognized_fields: list[str] = Field(
|
||||
description="List of input field names that were not recognized"
|
||||
)
|
||||
inputs: dict[str, Any] = Field(
|
||||
description="The agent's valid input schema for reference"
|
||||
)
|
||||
graph_id: str | None = None
|
||||
graph_version: int | None = None
|
||||
|
||||
|
||||
# Agent output models
|
||||
class ExecutionOutputInfo(BaseModel):
|
||||
"""Summary of a single execution's outputs."""
|
||||
@@ -372,15 +352,11 @@ class OperationStartedResponse(ToolResponseBase):
|
||||
|
||||
This is returned immediately to the client while the operation continues
|
||||
to execute. The user can close the tab and check back later.
|
||||
|
||||
The task_id can be used to reconnect to the SSE stream via
|
||||
GET /chat/tasks/{task_id}/stream?last_idx=0
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
operation_id: str
|
||||
tool_name: str
|
||||
task_id: str | None = None # For SSE reconnection
|
||||
|
||||
|
||||
class OperationPendingResponse(ToolResponseBase):
|
||||
@@ -404,20 +380,3 @@ class OperationInProgressResponse(ToolResponseBase):
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_IN_PROGRESS
|
||||
tool_call_id: str
|
||||
|
||||
|
||||
class AsyncProcessingResponse(ToolResponseBase):
|
||||
"""Response when an operation has been delegated to async processing.
|
||||
|
||||
This is returned by tools when the external service accepts the request
|
||||
for async processing (HTTP 202 Accepted). The Redis Streams completion
|
||||
consumer will handle the result when the external service completes.
|
||||
|
||||
The status field is specifically "accepted" to allow the long-running tool
|
||||
handler to detect this response and skip LLM continuation.
|
||||
"""
|
||||
|
||||
type: ResponseType = ResponseType.OPERATION_STARTED
|
||||
status: str = "accepted" # Must be "accepted" for detection
|
||||
operation_id: str | None = None
|
||||
task_id: str | None = None
|
||||
|
||||
@@ -30,7 +30,6 @@ from .models import (
|
||||
ErrorResponse,
|
||||
ExecutionOptions,
|
||||
ExecutionStartedResponse,
|
||||
InputValidationErrorResponse,
|
||||
SetupInfo,
|
||||
SetupRequirementsResponse,
|
||||
ToolResponseBase,
|
||||
@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
|
||||
input_properties = graph.input_schema.get("properties", {})
|
||||
required_fields = set(graph.input_schema.get("required", []))
|
||||
provided_inputs = set(params.inputs.keys())
|
||||
valid_fields = set(input_properties.keys())
|
||||
|
||||
# Check for unknown input fields
|
||||
unrecognized_fields = provided_inputs - valid_fields
|
||||
if unrecognized_fields:
|
||||
return InputValidationErrorResponse(
|
||||
message=(
|
||||
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
|
||||
f"Agent was not executed. Please use the correct field names from the schema."
|
||||
),
|
||||
session_id=session_id,
|
||||
unrecognized_fields=sorted(unrecognized_fields),
|
||||
inputs=graph.input_schema,
|
||||
graph_id=graph.id,
|
||||
graph_version=graph.version,
|
||||
)
|
||||
|
||||
# If agent has inputs but none were provided AND use_defaults is not set,
|
||||
# always show what's available first so user can decide
|
||||
|
||||
@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
|
||||
# Should return error about missing schedule_name
|
||||
assert result_data.get("type") == "error"
|
||||
assert "schedule_name" in result_data["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="session")
|
||||
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
|
||||
"""Test that run_agent returns input_validation_error for unknown input fields."""
|
||||
user = setup_test_data["user"]
|
||||
store_submission = setup_test_data["store_submission"]
|
||||
|
||||
tool = RunAgentTool()
|
||||
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
|
||||
session = make_session(user_id=user.id)
|
||||
|
||||
# Execute with unknown input field names
|
||||
response = await tool.execute(
|
||||
user_id=user.id,
|
||||
session_id=str(uuid.uuid4()),
|
||||
tool_call_id=str(uuid.uuid4()),
|
||||
username_agent_slug=agent_marketplace_id,
|
||||
inputs={
|
||||
"unknown_field": "some value",
|
||||
"another_unknown": "another value",
|
||||
},
|
||||
session=session,
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert hasattr(response, "output")
|
||||
assert isinstance(response.output, str)
|
||||
result_data = orjson.loads(response.output)
|
||||
|
||||
# Should return input_validation_error type with unrecognized fields
|
||||
assert result_data.get("type") == "input_validation_error"
|
||||
assert "unrecognized_fields" in result_data
|
||||
assert set(result_data["unrecognized_fields"]) == {
|
||||
"another_unknown",
|
||||
"unknown_field",
|
||||
}
|
||||
assert "inputs" in result_data # Contains the valid schema
|
||||
assert "Agent was not executed" in result_data["message"]
|
||||
|
||||
@@ -5,8 +5,6 @@ import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from backend.api.features.chat.model import ChatSession
|
||||
from backend.data.block import get_block
|
||||
from backend.data.execution import ExecutionContext
|
||||
@@ -77,22 +75,15 @@ class RunBlockTool(BaseTool):
|
||||
self,
|
||||
user_id: str,
|
||||
block: Any,
|
||||
input_data: dict[str, Any] | None = None,
|
||||
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
|
||||
"""
|
||||
Check if user has required credentials for a block.
|
||||
|
||||
Args:
|
||||
user_id: User ID
|
||||
block: Block to check credentials for
|
||||
input_data: Input data for the block (used to determine provider via discriminator)
|
||||
|
||||
Returns:
|
||||
tuple[matched_credentials, missing_credentials]
|
||||
"""
|
||||
matched_credentials: dict[str, CredentialsMetaInput] = {}
|
||||
missing_credentials: list[CredentialsMetaInput] = []
|
||||
input_data = input_data or {}
|
||||
|
||||
# Get credential field info from block's input schema
|
||||
credentials_fields_info = block.input_schema.get_credentials_fields_info()
|
||||
@@ -105,33 +96,14 @@ class RunBlockTool(BaseTool):
|
||||
available_creds = await creds_manager.store.get_all_creds(user_id)
|
||||
|
||||
for field_name, field_info in credentials_fields_info.items():
|
||||
effective_field_info = field_info
|
||||
if field_info.discriminator and field_info.discriminator_mapping:
|
||||
# Get discriminator from input, falling back to schema default
|
||||
discriminator_value = input_data.get(field_info.discriminator)
|
||||
if discriminator_value is None:
|
||||
field = block.input_schema.model_fields.get(
|
||||
field_info.discriminator
|
||||
)
|
||||
if field and field.default is not PydanticUndefined:
|
||||
discriminator_value = field.default
|
||||
|
||||
if (
|
||||
discriminator_value
|
||||
and discriminator_value in field_info.discriminator_mapping
|
||||
):
|
||||
effective_field_info = field_info.discriminate(discriminator_value)
|
||||
logger.debug(
|
||||
f"Discriminated provider for {field_name}: "
|
||||
f"{discriminator_value} -> {effective_field_info.provider}"
|
||||
)
|
||||
|
||||
# field_info.provider is a frozenset of acceptable providers
|
||||
# field_info.supported_types is a frozenset of acceptable types
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in effective_field_info.provider
|
||||
and cred.type in effective_field_info.supported_types
|
||||
if cred.provider in field_info.provider
|
||||
and cred.type in field_info.supported_types
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -145,8 +117,8 @@ class RunBlockTool(BaseTool):
|
||||
)
|
||||
else:
|
||||
# Create a placeholder for the missing credential
|
||||
provider = next(iter(effective_field_info.provider), "unknown")
|
||||
cred_type = next(iter(effective_field_info.supported_types), "api_key")
|
||||
provider = next(iter(field_info.provider), "unknown")
|
||||
cred_type = next(iter(field_info.supported_types), "api_key")
|
||||
missing_credentials.append(
|
||||
CredentialsMetaInput(
|
||||
id=field_name,
|
||||
@@ -214,9 +186,10 @@ class RunBlockTool(BaseTool):
|
||||
|
||||
logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")
|
||||
|
||||
# Check credentials
|
||||
creds_manager = IntegrationCredentialsManager()
|
||||
matched_credentials, missing_credentials = await self._check_block_credentials(
|
||||
user_id, block, input_data
|
||||
user_id, block
|
||||
)
|
||||
|
||||
if missing_credentials:
|
||||
|
||||
@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
|
||||
from backend.api.features.store import db as store_db
|
||||
from backend.data import graph as graph_db
|
||||
from backend.data.graph import GraphModel
|
||||
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
|
||||
from backend.integrations.creds_manager import IntegrationCredentialsManager
|
||||
from backend.util.exceptions import NotFoundError
|
||||
|
||||
@@ -266,14 +266,13 @@ async def match_user_credentials_to_graph(
|
||||
credential_requirements,
|
||||
_node_fields,
|
||||
) in aggregated_creds.items():
|
||||
# Find first matching credential by provider, type, and scopes
|
||||
# Find first matching credential by provider and type
|
||||
matching_cred = next(
|
||||
(
|
||||
cred
|
||||
for cred in available_creds
|
||||
if cred.provider in credential_requirements.provider
|
||||
and cred.type in credential_requirements.supported_types
|
||||
and _credential_has_required_scopes(cred, credential_requirements)
|
||||
),
|
||||
None,
|
||||
)
|
||||
@@ -297,17 +296,10 @@ async def match_user_credentials_to_graph(
|
||||
f"{credential_field_name} (validation failed: {e})"
|
||||
)
|
||||
else:
|
||||
# Build a helpful error message including scope requirements
|
||||
error_parts = [
|
||||
f"provider in {list(credential_requirements.provider)}",
|
||||
f"type in {list(credential_requirements.supported_types)}",
|
||||
]
|
||||
if credential_requirements.required_scopes:
|
||||
error_parts.append(
|
||||
f"scopes including {list(credential_requirements.required_scopes)}"
|
||||
)
|
||||
missing_creds.append(
|
||||
f"{credential_field_name} (requires {', '.join(error_parts)})"
|
||||
f"{credential_field_name} "
|
||||
f"(requires provider in {list(credential_requirements.provider)}, "
|
||||
f"type in {list(credential_requirements.supported_types)})"
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -317,28 +309,6 @@ async def match_user_credentials_to_graph(
|
||||
return graph_credentials_inputs, missing_creds
|
||||
|
||||
|
||||
def _credential_has_required_scopes(
|
||||
credential: Credentials,
|
||||
requirements: CredentialsFieldInfo,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a credential has all the scopes required by the block.
|
||||
|
||||
For OAuth2 credentials, verifies that the credential's scopes are a superset
|
||||
of the required scopes. For other credential types, returns True (no scope check).
|
||||
"""
|
||||
# Only OAuth2 credentials have scopes to check
|
||||
if credential.type != "oauth2":
|
||||
return True
|
||||
|
||||
# If no scopes are required, any credential matches
|
||||
if not requirements.required_scopes:
|
||||
return True
|
||||
|
||||
# Check that credential scopes are a superset of required scopes
|
||||
return set(credential.scopes).issuperset(requirements.required_scopes)
|
||||
|
||||
|
||||
async def check_user_has_required_credentials(
|
||||
user_id: str,
|
||||
required_credentials: list[CredentialsMetaInput],
|
||||
|
||||
@@ -39,7 +39,6 @@ async def list_library_agents(
|
||||
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
include_executions: bool = False,
|
||||
) -> library_model.LibraryAgentResponse:
|
||||
"""
|
||||
Retrieves a paginated list of LibraryAgent records for a given user.
|
||||
@@ -50,9 +49,6 @@ async def list_library_agents(
|
||||
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
|
||||
page: Current page (1-indexed).
|
||||
page_size: Number of items per page.
|
||||
include_executions: Whether to include execution data for status calculation.
|
||||
Defaults to False for performance (UI fetches status separately).
|
||||
Set to True when accurate status/metrics are needed (e.g., agent generator).
|
||||
|
||||
Returns:
|
||||
A LibraryAgentResponse containing the list of agents and pagination details.
|
||||
@@ -80,6 +76,7 @@ async def list_library_agents(
|
||||
"isArchived": False,
|
||||
}
|
||||
|
||||
# Build search filter if applicable
|
||||
if search_term:
|
||||
where_clause["OR"] = [
|
||||
{
|
||||
@@ -96,6 +93,7 @@ async def list_library_agents(
|
||||
},
|
||||
]
|
||||
|
||||
# Determine sorting
|
||||
order_by: prisma.types.LibraryAgentOrderByInput | None = None
|
||||
|
||||
if sort_by == library_model.LibraryAgentSort.CREATED_AT:
|
||||
@@ -107,7 +105,7 @@ async def list_library_agents(
|
||||
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
|
||||
where=where_clause,
|
||||
include=library_agent_include(
|
||||
user_id, include_nodes=False, include_executions=include_executions
|
||||
user_id, include_nodes=False, include_executions=False
|
||||
),
|
||||
order=order_by,
|
||||
skip=(page - 1) * page_size,
|
||||
|
||||
@@ -9,7 +9,6 @@ import pydantic
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
|
||||
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
|
||||
from backend.util.json import loads as json_loads
|
||||
from backend.util.models import Pagination
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -17,10 +16,10 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class LibraryAgentStatus(str, Enum):
|
||||
COMPLETED = "COMPLETED"
|
||||
HEALTHY = "HEALTHY"
|
||||
WAITING = "WAITING"
|
||||
ERROR = "ERROR"
|
||||
COMPLETED = "COMPLETED" # All runs completed
|
||||
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
|
||||
WAITING = "WAITING" # Agent is queued or waiting to start
|
||||
ERROR = "ERROR" # Agent is in an error state
|
||||
|
||||
|
||||
class MarketplaceListingCreator(pydantic.BaseModel):
|
||||
@@ -40,30 +39,6 @@ class MarketplaceListing(pydantic.BaseModel):
|
||||
creator: MarketplaceListingCreator
|
||||
|
||||
|
||||
class RecentExecution(pydantic.BaseModel):
|
||||
"""Summary of a recent execution for quality assessment.
|
||||
|
||||
Used by the LLM to understand the agent's recent performance with specific examples
|
||||
rather than just aggregate statistics.
|
||||
"""
|
||||
|
||||
status: str
|
||||
correctness_score: float | None = None
|
||||
activity_summary: str | None = None
|
||||
|
||||
|
||||
def _parse_settings(settings: dict | str | None) -> GraphSettings:
|
||||
"""Parse settings from database, handling both dict and string formats."""
|
||||
if settings is None:
|
||||
return GraphSettings()
|
||||
try:
|
||||
if isinstance(settings, str):
|
||||
settings = json_loads(settings)
|
||||
return GraphSettings.model_validate(settings)
|
||||
except Exception:
|
||||
return GraphSettings()
|
||||
|
||||
|
||||
class LibraryAgent(pydantic.BaseModel):
|
||||
"""
|
||||
Represents an agent in the library, including metadata for display and
|
||||
@@ -73,7 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
id: str
|
||||
graph_id: str
|
||||
graph_version: int
|
||||
owner_user_id: str
|
||||
owner_user_id: str # ID of user who owns/created this agent graph
|
||||
|
||||
image_url: str | None
|
||||
|
||||
@@ -89,7 +64,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
description: str
|
||||
instructions: str | None = None
|
||||
|
||||
input_schema: dict[str, Any]
|
||||
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
|
||||
output_schema: dict[str, Any]
|
||||
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
|
||||
description="Input schema for credentials required by the agent",
|
||||
@@ -106,19 +81,25 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
)
|
||||
trigger_setup_info: Optional[GraphTriggerInfo] = None
|
||||
|
||||
# Indicates whether there's a new output (based on recent runs)
|
||||
new_output: bool
|
||||
execution_count: int = 0
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
recent_executions: list[RecentExecution] = pydantic.Field(
|
||||
default_factory=list,
|
||||
description="List of recent executions with status, score, and summary",
|
||||
)
|
||||
|
||||
# Whether the user can access the underlying graph
|
||||
can_access_graph: bool
|
||||
|
||||
# Indicates if this agent is the latest version
|
||||
is_latest_version: bool
|
||||
|
||||
# Whether the agent is marked as favorite by the user
|
||||
is_favorite: bool
|
||||
|
||||
# Recommended schedule cron (from marketplace agents)
|
||||
recommended_schedule_cron: str | None = None
|
||||
|
||||
# User-specific settings for this library agent
|
||||
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)
|
||||
|
||||
# Marketplace listing information if the agent has been published
|
||||
marketplace_listing: Optional["MarketplaceListing"] = None
|
||||
|
||||
@staticmethod
|
||||
@@ -142,6 +123,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
agent_updated_at = agent.AgentGraph.updatedAt
|
||||
lib_agent_updated_at = agent.updatedAt
|
||||
|
||||
# Compute updated_at as the latest between library agent and graph
|
||||
updated_at = (
|
||||
max(agent_updated_at, lib_agent_updated_at)
|
||||
if agent_updated_at
|
||||
@@ -154,6 +136,7 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
creator_name = agent.Creator.name or "Unknown"
|
||||
creator_image_url = agent.Creator.avatarUrl or ""
|
||||
|
||||
# Logic to calculate status and new_output
|
||||
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
|
||||
days=7
|
||||
)
|
||||
@@ -162,55 +145,13 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
status = status_result.status
|
||||
new_output = status_result.new_output
|
||||
|
||||
execution_count = len(executions)
|
||||
success_rate: float | None = None
|
||||
avg_correctness_score: float | None = None
|
||||
if execution_count > 0:
|
||||
success_count = sum(
|
||||
1
|
||||
for e in executions
|
||||
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
|
||||
)
|
||||
success_rate = (success_count / execution_count) * 100
|
||||
|
||||
correctness_scores = []
|
||||
for e in executions:
|
||||
if e.stats and isinstance(e.stats, dict):
|
||||
score = e.stats.get("correctness_score")
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
correctness_scores.append(float(score))
|
||||
if correctness_scores:
|
||||
avg_correctness_score = sum(correctness_scores) / len(
|
||||
correctness_scores
|
||||
)
|
||||
|
||||
recent_executions: list[RecentExecution] = []
|
||||
for e in executions:
|
||||
exec_score: float | None = None
|
||||
exec_summary: str | None = None
|
||||
if e.stats and isinstance(e.stats, dict):
|
||||
score = e.stats.get("correctness_score")
|
||||
if score is not None and isinstance(score, (int, float)):
|
||||
exec_score = float(score)
|
||||
summary = e.stats.get("activity_status")
|
||||
if summary is not None and isinstance(summary, str):
|
||||
exec_summary = summary
|
||||
exec_status = (
|
||||
e.executionStatus.value
|
||||
if hasattr(e.executionStatus, "value")
|
||||
else str(e.executionStatus)
|
||||
)
|
||||
recent_executions.append(
|
||||
RecentExecution(
|
||||
status=exec_status,
|
||||
correctness_score=exec_score,
|
||||
activity_summary=exec_summary,
|
||||
)
|
||||
)
|
||||
|
||||
# Check if user can access the graph
|
||||
can_access_graph = agent.AgentGraph.userId == agent.userId
|
||||
|
||||
# Hard-coded to True until a method to check is implemented
|
||||
is_latest_version = True
|
||||
|
||||
# Build marketplace_listing if available
|
||||
marketplace_listing_data = None
|
||||
if store_listing and store_listing.ActiveVersion and profile:
|
||||
creator_data = MarketplaceListingCreator(
|
||||
@@ -249,15 +190,11 @@ class LibraryAgent(pydantic.BaseModel):
|
||||
has_sensitive_action=graph.has_sensitive_action,
|
||||
trigger_setup_info=graph.trigger_setup_info,
|
||||
new_output=new_output,
|
||||
execution_count=execution_count,
|
||||
success_rate=success_rate,
|
||||
avg_correctness_score=avg_correctness_score,
|
||||
recent_executions=recent_executions,
|
||||
can_access_graph=can_access_graph,
|
||||
is_latest_version=is_latest_version,
|
||||
is_favorite=agent.isFavorite,
|
||||
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
|
||||
settings=_parse_settings(agent.settings),
|
||||
settings=GraphSettings.model_validate(agent.settings),
|
||||
marketplace_listing=marketplace_listing_data,
|
||||
)
|
||||
|
||||
@@ -283,15 +220,18 @@ def _calculate_agent_status(
|
||||
if not executions:
|
||||
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)
|
||||
|
||||
# Track how many times each execution status appears
|
||||
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
|
||||
new_output = False
|
||||
|
||||
for execution in executions:
|
||||
# Check if there's a completed run more recent than `recent_threshold`
|
||||
if execution.createdAt >= recent_threshold:
|
||||
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
|
||||
new_output = True
|
||||
status_counts[execution.executionStatus] += 1
|
||||
|
||||
# Determine the final status based on counts
|
||||
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
|
||||
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
|
||||
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:
|
||||
|
||||
@@ -112,7 +112,6 @@ async def get_store_agents(
|
||||
description=agent["description"],
|
||||
runs=agent["runs"],
|
||||
rating=agent["rating"],
|
||||
agent_graph_id=agent.get("agentGraphId", ""),
|
||||
)
|
||||
store_agents.append(store_agent)
|
||||
except Exception as e:
|
||||
@@ -171,7 +170,6 @@ async def get_store_agents(
|
||||
description=agent.description,
|
||||
runs=agent.runs,
|
||||
rating=agent.rating,
|
||||
agent_graph_id=agent.agentGraphId,
|
||||
)
|
||||
# Add to the list only if creation was successful
|
||||
store_agents.append(store_agent)
|
||||
|
||||
@@ -454,9 +454,6 @@ async def test_unified_hybrid_search_pagination(
|
||||
cleanup_embeddings: list,
|
||||
):
|
||||
"""Test unified search pagination works correctly."""
|
||||
# Use a unique search term to avoid matching other test data
|
||||
unique_term = f"xyzpagtest{uuid.uuid4().hex[:8]}"
|
||||
|
||||
# Create multiple items
|
||||
content_ids = []
|
||||
for i in range(5):
|
||||
@@ -468,14 +465,14 @@ async def test_unified_hybrid_search_pagination(
|
||||
content_type=ContentType.BLOCK,
|
||||
content_id=content_id,
|
||||
embedding=mock_embedding,
|
||||
searchable_text=f"{unique_term} item number {i}",
|
||||
searchable_text=f"pagination test item number {i}",
|
||||
metadata={"index": i},
|
||||
user_id=None,
|
||||
)
|
||||
|
||||
# Get first page
|
||||
page1_results, total1 = await unified_hybrid_search(
|
||||
query=unique_term,
|
||||
query="pagination test",
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=1,
|
||||
page_size=2,
|
||||
@@ -483,7 +480,7 @@ async def test_unified_hybrid_search_pagination(
|
||||
|
||||
# Get second page
|
||||
page2_results, total2 = await unified_hybrid_search(
|
||||
query=unique_term,
|
||||
query="pagination test",
|
||||
content_types=[ContentType.BLOCK],
|
||||
page=2,
|
||||
page_size=2,
|
||||
|
||||
@@ -600,7 +600,6 @@ async def hybrid_search(
|
||||
sa.featured,
|
||||
sa.is_available,
|
||||
sa.updated_at,
|
||||
sa."agentGraphId",
|
||||
-- Searchable text for BM25 reranking
|
||||
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
|
||||
-- Semantic score
|
||||
@@ -660,7 +659,6 @@ async def hybrid_search(
|
||||
featured,
|
||||
is_available,
|
||||
updated_at,
|
||||
"agentGraphId",
|
||||
searchable_text,
|
||||
semantic_score,
|
||||
lexical_score,
|
||||
|
||||
@@ -38,7 +38,6 @@ class StoreAgent(pydantic.BaseModel):
|
||||
description: str
|
||||
runs: int
|
||||
rating: float
|
||||
agent_graph_id: str
|
||||
|
||||
|
||||
class StoreAgentsResponse(pydantic.BaseModel):
|
||||
|
||||
@@ -26,13 +26,11 @@ def test_store_agent():
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
assert agent.slug == "test-agent"
|
||||
assert agent.agent_name == "Test Agent"
|
||||
assert agent.runs == 50
|
||||
assert agent.rating == 4.5
|
||||
assert agent.agent_graph_id == "test-graph-id"
|
||||
|
||||
|
||||
def test_store_agents_response():
|
||||
@@ -48,7 +46,6 @@ def test_store_agents_response():
|
||||
description="Test description",
|
||||
runs=50,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
|
||||
@@ -82,7 +82,6 @@ def test_get_agents_featured(
|
||||
description="Featured agent description",
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-1",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
@@ -128,7 +127,6 @@ def test_get_agents_by_creator(
|
||||
description="Creator agent description",
|
||||
runs=50,
|
||||
rating=4.0,
|
||||
agent_graph_id="test-graph-2",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
@@ -174,7 +172,6 @@ def test_get_agents_sorted(
|
||||
description="Top agent description",
|
||||
runs=1000,
|
||||
rating=5.0,
|
||||
agent_graph_id="test-graph-3",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
@@ -220,7 +217,6 @@ def test_get_agents_search(
|
||||
description="Specific search term description",
|
||||
runs=75,
|
||||
rating=4.2,
|
||||
agent_graph_id="test-graph-search",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
@@ -266,7 +262,6 @@ def test_get_agents_category(
|
||||
description="Category agent description",
|
||||
runs=60,
|
||||
rating=4.1,
|
||||
agent_graph_id="test-graph-category",
|
||||
)
|
||||
],
|
||||
pagination=store_model.Pagination(
|
||||
@@ -311,7 +306,6 @@ def test_get_agents_pagination(
|
||||
description=f"Agent {i} description",
|
||||
runs=i * 10,
|
||||
rating=4.0,
|
||||
agent_graph_id="test-graph-2",
|
||||
)
|
||||
for i in range(5)
|
||||
],
|
||||
|
||||
@@ -33,7 +33,6 @@ class TestCacheDeletion:
|
||||
description="Test description",
|
||||
runs=100,
|
||||
rating=4.5,
|
||||
agent_graph_id="test-graph-id",
|
||||
)
|
||||
],
|
||||
pagination=Pagination(
|
||||
|
||||
@@ -40,10 +40,6 @@ import backend.data.user
|
||||
import backend.integrations.webhooks.utils
|
||||
import backend.util.service
|
||||
import backend.util.settings
|
||||
from backend.api.features.chat.completion_consumer import (
|
||||
start_completion_consumer,
|
||||
stop_completion_consumer,
|
||||
)
|
||||
from backend.blocks.llm import DEFAULT_LLM_MODEL
|
||||
from backend.data.model import Credentials
|
||||
from backend.integrations.providers import ProviderName
|
||||
@@ -122,21 +118,9 @@ async def lifespan_context(app: fastapi.FastAPI):
|
||||
await backend.data.graph.migrate_llm_models(DEFAULT_LLM_MODEL)
|
||||
await backend.integrations.webhooks.utils.migrate_legacy_triggered_graphs()
|
||||
|
||||
# Start chat completion consumer for Redis Streams notifications
|
||||
try:
|
||||
await start_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not start chat completion consumer: {e}")
|
||||
|
||||
with launch_darkly_context():
|
||||
yield
|
||||
|
||||
# Stop chat completion consumer
|
||||
try:
|
||||
await stop_completion_consumer()
|
||||
except Exception as e:
|
||||
logger.warning(f"Error stopping chat completion consumer: {e}")
|
||||
|
||||
try:
|
||||
await shutdown_cloud_storage_handler()
|
||||
except Exception as e:
|
||||
|
||||
@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
|
||||
execution_bus = AsyncRedisExecutionEventBus()
|
||||
notification_bus = AsyncRedisNotificationEventBus()
|
||||
|
||||
try:
|
||||
async def execution_worker():
|
||||
async for event in execution_bus.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
|
||||
async def execution_worker():
|
||||
async for event in execution_bus.listen("*"):
|
||||
await manager.send_execution_update(event)
|
||||
async def notification_worker():
|
||||
async for notification in notification_bus.listen("*"):
|
||||
await manager.send_notification(
|
||||
user_id=notification.user_id,
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
async def notification_worker():
|
||||
async for notification in notification_bus.listen("*"):
|
||||
await manager.send_notification(
|
||||
user_id=notification.user_id,
|
||||
payload=notification.payload,
|
||||
)
|
||||
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
finally:
|
||||
# Ensure PubSub connections are closed on any exit to prevent leaks
|
||||
await execution_bus.close()
|
||||
await notification_bus.close()
|
||||
await asyncio.gather(execution_worker(), notification_worker())
|
||||
|
||||
|
||||
async def authenticate_websocket(websocket: WebSocket) -> str:
|
||||
|
||||
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""ElevenLabs integration blocks - test credentials and shared utilities."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import SecretStr
|
||||
|
||||
from backend.data.model import APIKeyCredentials, CredentialsMetaInput
|
||||
from backend.integrations.providers import ProviderName
|
||||
|
||||
TEST_CREDENTIALS = APIKeyCredentials(
|
||||
id="01234567-89ab-cdef-0123-456789abcdef",
|
||||
provider="elevenlabs",
|
||||
api_key=SecretStr("mock-elevenlabs-api-key"),
|
||||
title="Mock ElevenLabs API key",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
TEST_CREDENTIALS_INPUT = {
|
||||
"provider": TEST_CREDENTIALS.provider,
|
||||
"id": TEST_CREDENTIALS.id,
|
||||
"type": TEST_CREDENTIALS.type,
|
||||
"title": TEST_CREDENTIALS.title,
|
||||
}
|
||||
|
||||
ElevenLabsCredentials = APIKeyCredentials
|
||||
ElevenLabsCredentialsInput = CredentialsMetaInput[
|
||||
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
|
||||
]
|
||||
@@ -32,7 +32,7 @@ from backend.data.model import (
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.util import json
|
||||
from backend.util.logging import TruncatedLogger
|
||||
from backend.util.prompt import compress_context, estimate_token_count
|
||||
from backend.util.prompt import compress_prompt, estimate_token_count
|
||||
from backend.util.text import TextFormatter
|
||||
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
|
||||
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
|
||||
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
|
||||
# AI/ML API models
|
||||
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
|
||||
@@ -279,6 +280,9 @@ MODEL_METADATA = {
|
||||
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
|
||||
), # claude-haiku-4-5-20251001
|
||||
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
|
||||
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
|
||||
), # claude-3-7-sonnet-20250219
|
||||
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
|
||||
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
|
||||
), # claude-3-haiku-20240307
|
||||
@@ -634,18 +638,11 @@ async def llm_call(
|
||||
context_window = llm_model.context_window
|
||||
|
||||
if compress_prompt_to_fit:
|
||||
result = await compress_context(
|
||||
prompt = compress_prompt(
|
||||
messages=prompt,
|
||||
target_tokens=llm_model.context_window // 2,
|
||||
client=None, # Truncation-only, no LLM summarization
|
||||
reserve=0, # Caller handles response token budget separately
|
||||
lossy_ok=True,
|
||||
)
|
||||
if result.error:
|
||||
logger.warning(
|
||||
f"Prompt compression did not meet target: {result.error}. "
|
||||
f"Proceeding with {result.token_count} tokens."
|
||||
)
|
||||
prompt = result.messages
|
||||
|
||||
# Calculate available tokens based on context window and input length
|
||||
estimated_input_tokens = estimate_token_count(prompt)
|
||||
|
||||
@@ -1,246 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from typing import Optional
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.fx.Loop import Loop
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class MediaDurationBlock(Block):
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
media_in: MediaFileType = SchemaField(
|
||||
description="Media input (URL, data URI, or local path)."
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video (True) or audio (False).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
duration: float = SchemaField(
|
||||
description="Duration of the media file (in seconds)."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||
description="Block to get the duration of a media file.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=MediaDurationBlock.Input,
|
||||
output_schema=MediaDurationBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input media locally
|
||||
local_media_path = await store_media_file(
|
||||
file=input_data.media_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert execution_context.graph_exec_id is not None
|
||||
media_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_media_path
|
||||
)
|
||||
|
||||
# 2) Load the clip
|
||||
if input_data.is_video:
|
||||
clip = VideoFileClip(media_abspath)
|
||||
else:
|
||||
clip = AudioFileClip(media_abspath)
|
||||
|
||||
yield "duration", clip.duration
|
||||
|
||||
|
||||
class LoopVideoBlock(Block):
|
||||
"""
|
||||
Block for looping (repeating) a video clip until a given duration or number of loops.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="The input video (can be a URL, data URI, or local path)."
|
||||
)
|
||||
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
|
||||
duration: Optional[float] = SchemaField(
|
||||
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
||||
default=None,
|
||||
ge=0.0,
|
||||
)
|
||||
n_loops: Optional[int] = SchemaField(
|
||||
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
||||
default=None,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: str = SchemaField(
|
||||
description="Looped video returned either as a relative path or a data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||
description="Block to loop a video to a given duration or number of repeats.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=LoopVideoBlock.Input,
|
||||
output_schema=LoopVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the input video locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
# 2) Load the clip
|
||||
clip = VideoFileClip(input_abspath)
|
||||
|
||||
# 3) Apply the loop effect
|
||||
looped_clip = clip
|
||||
if input_data.duration:
|
||||
# Loop until we reach the specified duration
|
||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||
elif input_data.n_loops:
|
||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||
else:
|
||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
|
||||
|
||||
class AddAudioToVideoBlock(Block):
|
||||
"""
|
||||
Block that adds (attaches) an audio track to an existing video.
|
||||
Optionally scale the volume of the new track.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||
description="Block to attach an audio file to a video file using moviepy.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=AddAudioToVideoBlock.Input,
|
||||
output_schema=AddAudioToVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
file=input_data.audio_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||
|
||||
# 2) Load video + audio with moviepy
|
||||
video_clip = VideoFileClip(video_abspath)
|
||||
audio_clip = AudioFileClip(audio_abspath)
|
||||
# Optionally scale volume
|
||||
if input_data.volume != 1.0:
|
||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||
|
||||
# 3) Attach the new audio track
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
|
||||
GPT41_MINI = "gpt-4.1-mini-2025-04-14"
|
||||
|
||||
# Anthropic
|
||||
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
|
||||
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
|
||||
|
||||
@property
|
||||
def provider_name(self) -> str:
|
||||
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -182,7 +182,10 @@ class StagehandObserveBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.debug(f"OBSERVE: Using model provider {model_credentials.provider}")
|
||||
logger.info(f"OBSERVE: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"OBSERVE: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
@@ -227,7 +230,7 @@ class StagehandActBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -279,7 +282,10 @@ class StagehandActBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.debug(f"ACT: Using model provider {model_credentials.provider}")
|
||||
logger.info(f"ACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"ACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
@@ -324,7 +330,7 @@ class StagehandExtractBlock(Block):
|
||||
model: StagehandRecommendedLlmModel = SchemaField(
|
||||
title="LLM Model",
|
||||
description="LLM to use for Stagehand (provider is inferred)",
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
|
||||
default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
|
||||
advanced=False,
|
||||
)
|
||||
model_credentials: AICredentials = AICredentialsField()
|
||||
@@ -364,7 +370,10 @@ class StagehandExtractBlock(Block):
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
|
||||
logger.debug(f"EXTRACT: Using model provider {model_credentials.provider}")
|
||||
logger.info(f"EXTRACT: Stagehand credentials: {stagehand_credentials}")
|
||||
logger.info(
|
||||
f"EXTRACT: Model credentials: {model_credentials} for provider {model_credentials.provider} secret: {model_credentials.api_key.get_secret_value()}"
|
||||
)
|
||||
|
||||
with disable_signal_handling():
|
||||
stagehand = Stagehand(
|
||||
|
||||
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
37
autogpt_platform/backend/backend/blocks/video/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Video editing blocks for AutoGPT Platform.
|
||||
|
||||
This module provides blocks for:
|
||||
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
|
||||
- Clipping/trimming video segments
|
||||
- Concatenating multiple videos
|
||||
- Adding text overlays
|
||||
- Adding AI-generated narration
|
||||
- Getting media duration
|
||||
- Looping videos
|
||||
- Adding audio to videos
|
||||
|
||||
Dependencies:
|
||||
- yt-dlp: For video downloading
|
||||
- moviepy: For video editing operations
|
||||
- elevenlabs: For AI narration (optional)
|
||||
"""
|
||||
|
||||
from backend.blocks.video.add_audio import AddAudioToVideoBlock
|
||||
from backend.blocks.video.clip import VideoClipBlock
|
||||
from backend.blocks.video.concat import VideoConcatBlock
|
||||
from backend.blocks.video.download import VideoDownloadBlock
|
||||
from backend.blocks.video.duration import MediaDurationBlock
|
||||
from backend.blocks.video.loop import LoopVideoBlock
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
from backend.blocks.video.text_overlay import VideoTextOverlayBlock
|
||||
|
||||
__all__ = [
|
||||
"AddAudioToVideoBlock",
|
||||
"LoopVideoBlock",
|
||||
"MediaDurationBlock",
|
||||
"VideoClipBlock",
|
||||
"VideoConcatBlock",
|
||||
"VideoDownloadBlock",
|
||||
"VideoNarrationBlock",
|
||||
"VideoTextOverlayBlock",
|
||||
]
|
||||
34
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
34
autogpt_platform/backend/backend/blocks/video/_utils.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Shared utilities for video blocks."""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_video_codecs(output_path: str) -> tuple[str, str]:
|
||||
"""Get appropriate video and audio codecs based on output file extension.
|
||||
|
||||
Args:
|
||||
output_path: Path to the output file (used to determine extension)
|
||||
|
||||
Returns:
|
||||
Tuple of (video_codec, audio_codec)
|
||||
|
||||
Codec mappings:
|
||||
- .mp4: H.264 + AAC (universal compatibility)
|
||||
- .webm: VP8 + Vorbis (web streaming)
|
||||
- .mkv: H.264 + AAC (container supports many codecs)
|
||||
- .mov: H.264 + AAC (Apple QuickTime, widely compatible)
|
||||
- .m4v: H.264 + AAC (Apple iTunes/devices)
|
||||
- .avi: MPEG-4 + MP3 (legacy Windows)
|
||||
"""
|
||||
ext = os.path.splitext(output_path)[1].lower()
|
||||
|
||||
codec_map: dict[str, tuple[str, str]] = {
|
||||
".mp4": ("libx264", "aac"),
|
||||
".webm": ("libvpx", "libvorbis"),
|
||||
".mkv": ("libx264", "aac"),
|
||||
".mov": ("libx264", "aac"),
|
||||
".m4v": ("libx264", "aac"),
|
||||
".avi": ("mpeg4", "libmp3lame"),
|
||||
}
|
||||
|
||||
return codec_map.get(ext, ("libx264", "aac"))
|
||||
102
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
102
autogpt_platform/backend/backend/blocks/video/add_audio.py
Normal file
@@ -0,0 +1,102 @@
|
||||
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
|
||||
class AddAudioToVideoBlock(Block):
|
||||
"""Add (attach) an audio track to an existing video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||
description="Block to attach an audio file to a video file using moviepy.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=AddAudioToVideoBlock.Input,
|
||||
output_schema=AddAudioToVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
file=input_data.audio_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||
|
||||
# 2) Load video + audio with moviepy
|
||||
video_clip = VideoFileClip(video_abspath)
|
||||
audio_clip = AudioFileClip(audio_abspath)
|
||||
# Optionally scale volume
|
||||
if input_data.volume != 1.0:
|
||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||
|
||||
# 3) Attach the new audio track
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
165
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
165
autogpt_platform/backend/backend/blocks/video/clip.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""VideoClipBlock - Extract a segment from a video file."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoClipBlock(Block):
|
||||
"""Extract a time segment from a video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
||||
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||
description="Output format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Clipped video file (path or data URI)"
|
||||
)
|
||||
duration: float = SchemaField(description="Clip duration in seconds")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
||||
description="Extract a time segment from a video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"video_in": "/tmp/test.mp4",
|
||||
"start_time": 0.0,
|
||||
"end_time": 10.0,
|
||||
},
|
||||
test_output=[("video_out", str), ("duration", float)],
|
||||
test_mock={
|
||||
"_clip_video": lambda *args: 10.0,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _clip_video(
|
||||
self,
|
||||
video_abspath: str,
|
||||
output_abspath: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> float:
|
||||
"""Extract a clip from a video. Extracted for testability."""
|
||||
clip = None
|
||||
subclip = None
|
||||
try:
|
||||
clip = VideoFileClip(video_abspath)
|
||||
subclip = clip.subclipped(start_time, end_time)
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
subclip.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
return subclip.duration
|
||||
finally:
|
||||
if subclip:
|
||||
subclip.close()
|
||||
if clip:
|
||||
clip.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate time range
|
||||
if input_data.end_time <= input_data.start_time:
|
||||
raise BlockExecutionError(
|
||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Build output path
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_clip_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
# Ensure correct extension
|
||||
base, _ = os.path.splitext(output_filename)
|
||||
output_filename = MediaFileType(f"{base}.{input_data.output_format}")
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
duration = self._clip_video(
|
||||
video_abspath,
|
||||
output_abspath,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "duration", duration
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to clip video: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
197
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
197
autogpt_platform/backend/backend/blocks/video/concat.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from moviepy import concatenate_videoclips
|
||||
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoConcatBlock(Block):
|
||||
"""Merge multiple video clips into one continuous video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
videos: list[MediaFileType] = SchemaField(
|
||||
description="List of video files to concatenate (in order)"
|
||||
)
|
||||
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
||||
description="Transition between clips", default="none"
|
||||
)
|
||||
transition_duration: int = SchemaField(
|
||||
description="Transition duration in seconds",
|
||||
default=1,
|
||||
ge=0,
|
||||
advanced=True,
|
||||
)
|
||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||
description="Output format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Concatenated video file (path or data URI)"
|
||||
)
|
||||
total_duration: float = SchemaField(description="Total duration in seconds")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
||||
description="Merge multiple video clips into one continuous video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={"videos": ["/tmp/a.mp4", "/tmp/b.mp4"]},
|
||||
test_output=[("video_out", str), ("total_duration", float)],
|
||||
test_mock={
|
||||
"_concat_videos": lambda *args: 20.0,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _concat_videos(
|
||||
self,
|
||||
video_abspaths: list[str],
|
||||
output_abspath: str,
|
||||
transition: str,
|
||||
transition_duration: int,
|
||||
) -> float:
|
||||
"""Concatenate videos. Extracted for testability."""
|
||||
clips = []
|
||||
faded_clips = []
|
||||
final = None
|
||||
try:
|
||||
# Load clips
|
||||
for v in video_abspaths:
|
||||
clips.append(VideoFileClip(v))
|
||||
|
||||
if transition == "crossfade":
|
||||
for i, clip in enumerate(clips):
|
||||
effects = []
|
||||
if i > 0:
|
||||
effects.append(CrossFadeIn(transition_duration))
|
||||
if i < len(clips) - 1:
|
||||
effects.append(CrossFadeOut(transition_duration))
|
||||
if effects:
|
||||
clip = clip.with_effects(effects)
|
||||
faded_clips.append(clip)
|
||||
final = concatenate_videoclips(
|
||||
faded_clips,
|
||||
method="compose",
|
||||
padding=-transition_duration,
|
||||
)
|
||||
elif transition == "fade_black":
|
||||
for clip in clips:
|
||||
faded = clip.with_effects(
|
||||
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
||||
)
|
||||
faded_clips.append(faded)
|
||||
final = concatenate_videoclips(faded_clips)
|
||||
else:
|
||||
final = concatenate_videoclips(clips)
|
||||
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
return final.duration
|
||||
finally:
|
||||
if final:
|
||||
final.close()
|
||||
for clip in faded_clips:
|
||||
clip.close()
|
||||
for clip in clips:
|
||||
clip.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate minimum clips
|
||||
if len(input_data.videos) < 2:
|
||||
raise BlockExecutionError(
|
||||
message="At least 2 videos are required for concatenation",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store all input videos locally
|
||||
video_abspaths = []
|
||||
for video in input_data.videos:
|
||||
local_path = await self._store_input_video(execution_context, video)
|
||||
video_abspaths.append(
|
||||
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||
)
|
||||
|
||||
# Build output path
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_concat.{input_data.output_format}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
total_duration = self._concat_videos(
|
||||
video_abspaths,
|
||||
output_abspath,
|
||||
input_data.transition,
|
||||
input_data.transition_duration,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "total_duration", total_duration
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to concatenate videos: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
167
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
167
autogpt_platform/backend/backend/blocks/video/download.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
||||
|
||||
import os
|
||||
import typing
|
||||
from typing import Literal
|
||||
|
||||
import yt_dlp
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from yt_dlp import _Params
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoDownloadBlock(Block):
|
||||
"""Download video from URL using yt-dlp."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
url: str = SchemaField(
|
||||
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
||||
placeholder="https://www.youtube.com/watch?v=...",
|
||||
)
|
||||
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
||||
description="Video quality preference", default="720p"
|
||||
)
|
||||
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
||||
description="Output video format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_file: MediaFileType = SchemaField(
|
||||
description="Downloaded video (path or data URI)"
|
||||
)
|
||||
duration: float = SchemaField(description="Video duration in seconds")
|
||||
title: str = SchemaField(description="Video title from source")
|
||||
source_url: str = SchemaField(description="Original source URL")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
||||
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||
"quality": "480p",
|
||||
},
|
||||
test_output=[
|
||||
("video_file", str),
|
||||
("duration", float),
|
||||
("title", str),
|
||||
("source_url", str),
|
||||
],
|
||||
test_mock={
|
||||
"_download_video": lambda *args: ("video.mp4", 212.0, "Test Video"),
|
||||
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _get_format_string(self, quality: str) -> str:
|
||||
formats = {
|
||||
"best": "bestvideo+bestaudio/best",
|
||||
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
||||
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
||||
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
||||
"audio_only": "bestaudio/best",
|
||||
}
|
||||
return formats.get(quality, formats["720p"])
|
||||
|
||||
def _download_video(
|
||||
self,
|
||||
url: str,
|
||||
quality: str,
|
||||
output_format: str,
|
||||
output_dir: str,
|
||||
node_exec_id: str,
|
||||
) -> tuple[str, float, str]:
|
||||
"""Download video. Extracted for testability."""
|
||||
output_template = os.path.join(
|
||||
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
||||
)
|
||||
|
||||
ydl_opts: "_Params" = {
|
||||
"format": self._get_format_string(quality),
|
||||
"outtmpl": output_template,
|
||||
"merge_output_format": output_format,
|
||||
"quiet": True,
|
||||
"no_warnings": True,
|
||||
}
|
||||
|
||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||
info = ydl.extract_info(url, download=True)
|
||||
video_path = ydl.prepare_filename(info)
|
||||
|
||||
# Handle format conversion in filename
|
||||
if not video_path.endswith(f".{output_format}"):
|
||||
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
||||
|
||||
# Return just the filename, not the full path
|
||||
filename = os.path.basename(video_path)
|
||||
|
||||
return (
|
||||
filename,
|
||||
info.get("duration") or 0.0,
|
||||
info.get("title") or "Unknown",
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Get the exec file directory
|
||||
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
filename, duration, title = self._download_video(
|
||||
input_data.url,
|
||||
input_data.quality,
|
||||
input_data.output_format,
|
||||
output_dir,
|
||||
node_exec_id,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, MediaFileType(filename)
|
||||
)
|
||||
|
||||
yield "video_file", video_out
|
||||
yield "duration", duration
|
||||
yield "title", title
|
||||
yield "source_url", input_data.url
|
||||
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to download video: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
68
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
68
autogpt_platform/backend/backend/blocks/video/duration.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""MediaDurationBlock - Get the duration of a media file."""
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class MediaDurationBlock(Block):
|
||||
"""Get the duration of a media file (video or audio)."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
media_in: MediaFileType = SchemaField(
|
||||
description="Media input (URL, data URI, or local path)."
|
||||
)
|
||||
is_video: bool = SchemaField(
|
||||
description="Whether the media is a video (True) or audio (False).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
duration: float = SchemaField(
|
||||
description="Duration of the media file (in seconds)."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
|
||||
description="Block to get the duration of a media file.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=MediaDurationBlock.Input,
|
||||
output_schema=MediaDurationBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# 1) Store the input media locally
|
||||
local_media_path = await store_media_file(
|
||||
file=input_data.media_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
assert execution_context.graph_exec_id is not None
|
||||
media_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_media_path
|
||||
)
|
||||
|
||||
# 2) Load the clip
|
||||
if input_data.is_video:
|
||||
clip = VideoFileClip(media_abspath)
|
||||
else:
|
||||
clip = AudioFileClip(media_abspath)
|
||||
|
||||
yield "duration", clip.duration
|
||||
104
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
104
autogpt_platform/backend/backend/blocks/video/loop.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from moviepy.video.fx.Loop import Loop
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class LoopVideoBlock(Block):
|
||||
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="The input video (can be a URL, data URI, or local path)."
|
||||
)
|
||||
duration: Optional[float] = SchemaField(
|
||||
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
||||
default=None,
|
||||
ge=0.0,
|
||||
)
|
||||
n_loops: Optional[int] = SchemaField(
|
||||
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
||||
default=None,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: str = SchemaField(
|
||||
description="Looped video returned either as a relative path or a data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||
description="Block to loop a video to a given duration or number of repeats.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=LoopVideoBlock.Input,
|
||||
output_schema=LoopVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the input video locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
# 2) Load the clip
|
||||
clip = VideoFileClip(input_abspath)
|
||||
|
||||
# 3) Apply the loop effect
|
||||
looped_clip = clip
|
||||
if input_data.duration:
|
||||
# Loop until we reach the specified duration
|
||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||
elif input_data.n_loops:
|
||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||
else:
|
||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
263
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
263
autogpt_platform/backend/backend/blocks/video/narration.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from elevenlabs import ElevenLabs
|
||||
from moviepy import CompositeAudioClip
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.elevenlabs._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ElevenLabsCredentials,
|
||||
ElevenLabsCredentialsInput,
|
||||
)
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoNarrationBlock(Block):
|
||||
"""Generate AI narration and add to video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
||||
description="ElevenLabs API key for voice synthesis"
|
||||
)
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
script: str = SchemaField(description="Narration script text")
|
||||
voice_id: str = SchemaField(
|
||||
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
)
|
||||
model_id: Literal[
|
||||
"eleven_multilingual_v2",
|
||||
"eleven_flash_v2_5",
|
||||
"eleven_turbo_v2_5",
|
||||
"eleven_turbo_v2",
|
||||
] = SchemaField(
|
||||
description="ElevenLabs TTS model",
|
||||
default="eleven_multilingual_v2",
|
||||
)
|
||||
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
||||
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
||||
default="ducking",
|
||||
)
|
||||
narration_volume: float = SchemaField(
|
||||
description="Narration volume (0.0 to 2.0)",
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
advanced=True,
|
||||
)
|
||||
original_volume: float = SchemaField(
|
||||
description="Original audio volume when mixing (0.0 to 1.0)",
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Video with narration (path or data URI)"
|
||||
)
|
||||
audio_file: MediaFileType = SchemaField(
|
||||
description="Generated audio file (path or data URI)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
||||
description="Generate AI narration and add to video",
|
||||
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"video_in": "/tmp/test.mp4",
|
||||
"script": "Hello world",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("video_out", str), ("audio_file", str)],
|
||||
test_mock={
|
||||
"_generate_narration_audio": lambda *args: b"mock audio content",
|
||||
"_add_narration_to_video": lambda *args: None,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _generate_narration_audio(
|
||||
self, api_key: str, script: str, voice_id: str, model_id: str
|
||||
) -> bytes:
|
||||
"""Generate narration audio via ElevenLabs API."""
|
||||
client = ElevenLabs(api_key=api_key)
|
||||
audio_generator = client.text_to_speech.convert(
|
||||
voice_id=voice_id,
|
||||
text=script,
|
||||
model_id=model_id,
|
||||
)
|
||||
# The SDK returns a generator, collect all chunks
|
||||
return b"".join(audio_generator)
|
||||
|
||||
def _add_narration_to_video(
|
||||
self,
|
||||
video_abspath: str,
|
||||
audio_abspath: str,
|
||||
output_abspath: str,
|
||||
mix_mode: str,
|
||||
narration_volume: float,
|
||||
original_volume: float,
|
||||
) -> None:
|
||||
"""Add narration audio to video. Extracted for testability."""
|
||||
video = None
|
||||
final = None
|
||||
narration_original = None
|
||||
narration_scaled = None
|
||||
original = None
|
||||
|
||||
try:
|
||||
video = VideoFileClip(video_abspath)
|
||||
narration_original = AudioFileClip(audio_abspath)
|
||||
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
||||
narration = narration_scaled
|
||||
|
||||
if mix_mode == "replace":
|
||||
final_audio = narration
|
||||
elif mix_mode == "mix":
|
||||
if video.audio:
|
||||
original = video.audio.with_volume_scaled(original_volume)
|
||||
final_audio = CompositeAudioClip([original, narration])
|
||||
else:
|
||||
final_audio = narration
|
||||
else: # ducking - apply stronger attenuation
|
||||
if video.audio:
|
||||
# Ducking uses a much lower volume for original audio
|
||||
ducking_volume = original_volume * 0.3
|
||||
original = video.audio.with_volume_scaled(ducking_volume)
|
||||
final_audio = CompositeAudioClip([original, narration])
|
||||
else:
|
||||
final_audio = narration
|
||||
|
||||
final = video.with_audio(final_audio)
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
finally:
|
||||
if original:
|
||||
original.close()
|
||||
if narration_scaled:
|
||||
narration_scaled.close()
|
||||
if narration_original:
|
||||
narration_original.close()
|
||||
if final:
|
||||
final.close()
|
||||
if video:
|
||||
video.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ElevenLabsCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Generate narration audio via ElevenLabs
|
||||
audio_content = self._generate_narration_audio(
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.script,
|
||||
input_data.voice_id,
|
||||
input_data.model_id,
|
||||
)
|
||||
|
||||
# Save audio to exec file path
|
||||
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
||||
audio_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, audio_filename
|
||||
)
|
||||
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
||||
with open(audio_abspath, "wb") as f:
|
||||
f.write(audio_content)
|
||||
|
||||
# Add narration to video
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_narrated_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
self._add_narration_to_video(
|
||||
video_abspath,
|
||||
audio_abspath,
|
||||
output_abspath,
|
||||
input_data.mix_mode,
|
||||
input_data.narration_volume,
|
||||
input_data.original_volume,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
audio_out = await self._store_output_video(
|
||||
execution_context, audio_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "audio_file", audio_out
|
||||
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to add narration: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
227
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
227
autogpt_platform/backend/backend/blocks/video/text_overlay.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""VideoTextOverlayBlock - Add text overlay to video."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from moviepy import CompositeVideoClip, TextClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoTextOverlayBlock(Block):
|
||||
"""Add text overlay/caption to video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
text: str = SchemaField(description="Text to overlay on video")
|
||||
position: Literal[
|
||||
"top",
|
||||
"center",
|
||||
"bottom",
|
||||
"top-left",
|
||||
"top-right",
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
] = SchemaField(description="Position of text on screen", default="bottom")
|
||||
start_time: float | None = SchemaField(
|
||||
description="When to show text (seconds). None = entire video",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
end_time: float | None = SchemaField(
|
||||
description="When to hide text (seconds). None = until end",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size", default=48, ge=12, le=200, advanced=True
|
||||
)
|
||||
font_color: str = SchemaField(
|
||||
description="Font color (hex or name)", default="white", advanced=True
|
||||
)
|
||||
bg_color: str | None = SchemaField(
|
||||
description="Background color behind text (None for transparent)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Video with text overlay (path or data URI)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
|
||||
description="Add text overlay/caption to video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
|
||||
test_output=[("video_out", str)],
|
||||
test_mock={
|
||||
"_add_text_overlay": lambda *args: None,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _add_text_overlay(
|
||||
self,
|
||||
video_abspath: str,
|
||||
output_abspath: str,
|
||||
text: str,
|
||||
position: str,
|
||||
start_time: float | None,
|
||||
end_time: float | None,
|
||||
font_size: int,
|
||||
font_color: str,
|
||||
bg_color: str | None,
|
||||
) -> None:
|
||||
"""Add text overlay to video. Extracted for testability."""
|
||||
video = None
|
||||
final = None
|
||||
txt_clip = None
|
||||
try:
|
||||
video = VideoFileClip(video_abspath)
|
||||
|
||||
txt_clip = TextClip(
|
||||
text=text,
|
||||
font_size=font_size,
|
||||
color=font_color,
|
||||
bg_color=bg_color,
|
||||
)
|
||||
|
||||
# Position mapping
|
||||
pos_map = {
|
||||
"top": ("center", "top"),
|
||||
"center": ("center", "center"),
|
||||
"bottom": ("center", "bottom"),
|
||||
"top-left": ("left", "top"),
|
||||
"top-right": ("right", "top"),
|
||||
"bottom-left": ("left", "bottom"),
|
||||
"bottom-right": ("right", "bottom"),
|
||||
}
|
||||
|
||||
txt_clip = txt_clip.with_position(pos_map[position])
|
||||
|
||||
# Set timing
|
||||
start = start_time or 0
|
||||
end = end_time or video.duration
|
||||
duration = max(0, end - start)
|
||||
txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)
|
||||
|
||||
final = CompositeVideoClip([video, txt_clip])
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
finally:
|
||||
if txt_clip:
|
||||
txt_clip.close()
|
||||
if final:
|
||||
final.close()
|
||||
if video:
|
||||
video.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate time range if both are provided
|
||||
if (
|
||||
input_data.start_time is not None
|
||||
and input_data.end_time is not None
|
||||
and input_data.end_time <= input_data.start_time
|
||||
):
|
||||
raise BlockExecutionError(
|
||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Build output path
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_overlay_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
self._add_text_overlay(
|
||||
video_abspath,
|
||||
output_abspath,
|
||||
input_data.text,
|
||||
input_data.position,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
input_data.font_size,
|
||||
input_data.font_color,
|
||||
input_data.bg_color,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to add text overlay: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
@@ -873,13 +873,14 @@ def is_block_auth_configured(
|
||||
|
||||
|
||||
async def initialize_blocks() -> None:
|
||||
# First, sync all provider costs to blocks
|
||||
# Imported here to avoid circular import
|
||||
from backend.sdk.cost_integration import sync_all_provider_costs
|
||||
from backend.util.retry import func_retry
|
||||
|
||||
sync_all_provider_costs()
|
||||
|
||||
@func_retry
|
||||
async def sync_block_to_db(block: Block) -> None:
|
||||
for cls in get_blocks().values():
|
||||
block = cls()
|
||||
existing_block = await AgentBlock.prisma().find_first(
|
||||
where={"OR": [{"id": block.id}, {"name": block.name}]}
|
||||
)
|
||||
@@ -892,7 +893,7 @@ async def initialize_blocks() -> None:
|
||||
outputSchema=json.dumps(block.output_schema.jsonschema()),
|
||||
)
|
||||
)
|
||||
return
|
||||
continue
|
||||
|
||||
input_schema = json.dumps(block.input_schema.jsonschema())
|
||||
output_schema = json.dumps(block.output_schema.jsonschema())
|
||||
@@ -912,25 +913,6 @@ async def initialize_blocks() -> None:
|
||||
},
|
||||
)
|
||||
|
||||
failed_blocks: list[str] = []
|
||||
for cls in get_blocks().values():
|
||||
block = cls()
|
||||
try:
|
||||
await sync_block_to_db(block)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to sync block {block.name} to database: {e}. "
|
||||
"Block is still available in memory.",
|
||||
exc_info=True,
|
||||
)
|
||||
failed_blocks.append(block.name)
|
||||
|
||||
if failed_blocks:
|
||||
logger.error(
|
||||
f"Failed to sync {len(failed_blocks)} block(s) to database: "
|
||||
f"{', '.join(failed_blocks)}. These blocks are still available in memory."
|
||||
)
|
||||
|
||||
|
||||
# Note on the return type annotation: https://github.com/microsoft/pyright/issues/10281
|
||||
def get_block(block_id: str) -> AnyBlockSchema | None:
|
||||
|
||||
@@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
|
||||
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
|
||||
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
|
||||
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
|
||||
from backend.blocks.video.narration import VideoNarrationBlock
|
||||
from backend.data.block import Block, BlockCost, BlockCostType
|
||||
from backend.integrations.credentials_store import (
|
||||
aiml_api_credentials,
|
||||
anthropic_credentials,
|
||||
apollo_credentials,
|
||||
did_credentials,
|
||||
elevenlabs_credentials,
|
||||
enrichlayer_credentials,
|
||||
groq_credentials,
|
||||
ideogram_credentials,
|
||||
@@ -81,6 +83,7 @@ MODEL_COST: dict[LlmModel, int] = {
|
||||
LlmModel.CLAUDE_4_5_HAIKU: 4,
|
||||
LlmModel.CLAUDE_4_5_OPUS: 14,
|
||||
LlmModel.CLAUDE_4_5_SONNET: 9,
|
||||
LlmModel.CLAUDE_3_7_SONNET: 5,
|
||||
LlmModel.CLAUDE_3_HAIKU: 1,
|
||||
LlmModel.AIML_API_QWEN2_5_72B: 1,
|
||||
LlmModel.AIML_API_LLAMA3_1_70B: 1,
|
||||
@@ -639,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
|
||||
},
|
||||
),
|
||||
],
|
||||
VideoNarrationBlock: [
|
||||
BlockCost(
|
||||
cost_amount=5, # ElevenLabs TTS cost
|
||||
cost_filter={
|
||||
"credentials": {
|
||||
"id": elevenlabs_credentials.id,
|
||||
"provider": elevenlabs_credentials.provider,
|
||||
"type": elevenlabs_credentials.type,
|
||||
}
|
||||
},
|
||||
)
|
||||
],
|
||||
}
|
||||
|
||||
@@ -133,23 +133,10 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
|
||||
|
||||
class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
def __init__(self):
|
||||
self._pubsub: AsyncPubSub | None = None
|
||||
|
||||
@property
|
||||
async def connection(self) -> redis.AsyncRedis:
|
||||
return await redis.get_redis_async()
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the PubSub connection if it exists."""
|
||||
if self._pubsub is not None:
|
||||
try:
|
||||
await self._pubsub.close()
|
||||
except Exception:
|
||||
logger.warning("Failed to close PubSub connection", exc_info=True)
|
||||
finally:
|
||||
self._pubsub = None
|
||||
|
||||
async def publish_event(self, event: M, channel_key: str):
|
||||
"""
|
||||
Publish an event to Redis. Gracefully handles connection failures
|
||||
@@ -170,7 +157,6 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
|
||||
await self.connection, channel_key
|
||||
)
|
||||
assert isinstance(pubsub, AsyncPubSub)
|
||||
self._pubsub = pubsub
|
||||
|
||||
if "*" in channel_key:
|
||||
await pubsub.psubscribe(full_channel_name)
|
||||
|
||||
@@ -1028,39 +1028,6 @@ async def get_graph(
|
||||
return GraphModel.from_db(graph, for_export)
|
||||
|
||||
|
||||
async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
|
||||
"""Batch-fetch multiple store-listed graphs by their IDs.
|
||||
|
||||
Only returns graphs that have approved store listings (publicly available).
|
||||
Does not require permission checks since store-listed graphs are public.
|
||||
|
||||
Args:
|
||||
*graph_ids: Variable number of graph IDs to fetch
|
||||
|
||||
Returns:
|
||||
Dict mapping graph_id to GraphModel for graphs with approved store listings
|
||||
"""
|
||||
if not graph_ids:
|
||||
return {}
|
||||
|
||||
store_listings = await StoreListingVersion.prisma().find_many(
|
||||
where={
|
||||
"agentGraphId": {"in": list(graph_ids)},
|
||||
"submissionStatus": SubmissionStatus.APPROVED,
|
||||
"isDeleted": False,
|
||||
},
|
||||
include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
|
||||
distinct=["agentGraphId"],
|
||||
order={"agentGraphVersion": "desc"},
|
||||
)
|
||||
|
||||
return {
|
||||
listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
|
||||
for listing in store_listings
|
||||
if listing.AgentGraph
|
||||
}
|
||||
|
||||
|
||||
async def get_graph_as_admin(
|
||||
graph_id: str,
|
||||
version: int | None = None,
|
||||
|
||||
@@ -666,16 +666,10 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
|
||||
if not (self.discriminator and self.discriminator_mapping):
|
||||
return self
|
||||
|
||||
try:
|
||||
provider = self.discriminator_mapping[discriminator_value]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
f"Model '{discriminator_value}' is not supported. "
|
||||
"It may have been deprecated. Please update your agent configuration."
|
||||
)
|
||||
|
||||
return CredentialsFieldInfo(
|
||||
credentials_provider=frozenset([provider]),
|
||||
credentials_provider=frozenset(
|
||||
[self.discriminator_mapping[discriminator_value]]
|
||||
),
|
||||
credentials_types=self.supported_types,
|
||||
credentials_scopes=self.required_scopes,
|
||||
discriminator=self.discriminator,
|
||||
|
||||
@@ -17,7 +17,6 @@ from backend.data.analytics import (
|
||||
get_accuracy_trends_and_alerts,
|
||||
get_marketplace_graphs_for_monitoring,
|
||||
)
|
||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
|
||||
from backend.data.execution import (
|
||||
create_graph_execution,
|
||||
@@ -220,9 +219,6 @@ class DatabaseManager(AppService):
|
||||
# Onboarding
|
||||
increment_onboarding_runs = _(increment_onboarding_runs)
|
||||
|
||||
# OAuth
|
||||
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
|
||||
|
||||
# Store
|
||||
get_store_agents = _(get_store_agents)
|
||||
get_store_agent_details = _(get_store_agent_details)
|
||||
@@ -353,9 +349,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
|
||||
# Onboarding
|
||||
increment_onboarding_runs = d.increment_onboarding_runs
|
||||
|
||||
# OAuth
|
||||
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
|
||||
|
||||
# Store
|
||||
get_store_agents = d.get_store_agents
|
||||
get_store_agent_details = d.get_store_agent_details
|
||||
|
||||
@@ -24,9 +24,11 @@ from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from sqlalchemy import MetaData, create_engine
|
||||
|
||||
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
|
||||
from backend.data.block import BlockInput
|
||||
from backend.data.execution import GraphExecutionWithNodes
|
||||
from backend.data.model import CredentialsMetaInput
|
||||
from backend.data.onboarding import increment_onboarding_runs
|
||||
from backend.executor import utils as execution_utils
|
||||
from backend.monitoring import (
|
||||
NotificationJobArgs,
|
||||
@@ -36,11 +38,7 @@ from backend.monitoring import (
|
||||
report_execution_accuracy_alerts,
|
||||
report_late_executions,
|
||||
)
|
||||
from backend.util.clients import (
|
||||
get_database_manager_async_client,
|
||||
get_database_manager_client,
|
||||
get_scheduler_client,
|
||||
)
|
||||
from backend.util.clients import get_database_manager_client, get_scheduler_client
|
||||
from backend.util.cloud_storage import cleanup_expired_files_async
|
||||
from backend.util.exceptions import (
|
||||
GraphNotFoundError,
|
||||
@@ -150,7 +148,6 @@ def execute_graph(**kwargs):
|
||||
async def _execute_graph(**kwargs):
|
||||
args = GraphExecutionJobArgs(**kwargs)
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
db = get_database_manager_async_client()
|
||||
try:
|
||||
logger.info(f"Executing recurring job for graph #{args.graph_id}")
|
||||
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
|
||||
@@ -160,7 +157,7 @@ async def _execute_graph(**kwargs):
|
||||
inputs=args.input_data,
|
||||
graph_credentials_inputs=args.input_credentials,
|
||||
)
|
||||
await db.increment_onboarding_runs(args.user_id)
|
||||
await increment_onboarding_runs(args.user_id)
|
||||
elapsed = asyncio.get_event_loop().time() - start_time
|
||||
logger.info(
|
||||
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
|
||||
@@ -249,13 +246,8 @@ def cleanup_expired_files():
|
||||
|
||||
def cleanup_oauth_tokens():
|
||||
"""Clean up expired OAuth tokens from the database."""
|
||||
|
||||
# Wait for completion
|
||||
async def _cleanup():
|
||||
db = get_database_manager_async_client()
|
||||
return await db.cleanup_expired_oauth_tokens()
|
||||
|
||||
run_async(_cleanup())
|
||||
run_async(cleanup_expired_oauth_tokens())
|
||||
|
||||
|
||||
def execution_accuracy_alerts():
|
||||
|
||||
@@ -224,6 +224,14 @@ openweathermap_credentials = APIKeyCredentials(
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
elevenlabs_credentials = APIKeyCredentials(
|
||||
id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
|
||||
provider="elevenlabs",
|
||||
api_key=SecretStr(settings.secrets.elevenlabs_api_key),
|
||||
title="Use Credits for ElevenLabs",
|
||||
expires_at=None,
|
||||
)
|
||||
|
||||
DEFAULT_CREDENTIALS = [
|
||||
ollama_credentials,
|
||||
revid_credentials,
|
||||
@@ -252,6 +260,7 @@ DEFAULT_CREDENTIALS = [
|
||||
v0_credentials,
|
||||
webshare_proxy_credentials,
|
||||
openweathermap_credentials,
|
||||
elevenlabs_credentials,
|
||||
]
|
||||
|
||||
SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
|
||||
@@ -366,6 +375,8 @@ class IntegrationCredentialsStore:
|
||||
all_credentials.append(webshare_proxy_credentials)
|
||||
if settings.secrets.openweathermap_api_key:
|
||||
all_credentials.append(openweathermap_credentials)
|
||||
if settings.secrets.elevenlabs_api_key:
|
||||
all_credentials.append(elevenlabs_credentials)
|
||||
return all_credentials
|
||||
|
||||
async def get_creds_by_id(
|
||||
|
||||
@@ -18,6 +18,7 @@ class ProviderName(str, Enum):
|
||||
DISCORD = "discord"
|
||||
D_ID = "d_id"
|
||||
E2B = "e2b"
|
||||
ELEVENLABS = "elevenlabs"
|
||||
FAL = "fal"
|
||||
GITHUB = "github"
|
||||
GOOGLE = "google"
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import fastapi
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from backend.api.features.integrations.router import router as integrations_router
|
||||
from backend.integrations.providers import ProviderName
|
||||
from backend.integrations.webhooks import utils as webhooks_utils
|
||||
|
||||
|
||||
def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
|
||||
app = fastapi.FastAPI()
|
||||
app.include_router(integrations_router, prefix="/api/integrations")
|
||||
|
||||
provider = ProviderName.GITHUB
|
||||
webhook_id = "webhook_123"
|
||||
base_url = "https://example.com"
|
||||
|
||||
monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)
|
||||
|
||||
route = next(
|
||||
route
|
||||
for route in integrations_router.routes
|
||||
if isinstance(route, APIRoute)
|
||||
and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
|
||||
and "POST" in route.methods
|
||||
)
|
||||
expected_path = f"/api/integrations{route.path}".format(
|
||||
provider=provider.value,
|
||||
webhook_id=webhook_id,
|
||||
)
|
||||
actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
|
||||
expected_base = urlparse(base_url)
|
||||
|
||||
assert (actual_url.scheme, actual_url.netloc) == (
|
||||
expected_base.scheme,
|
||||
expected_base.netloc,
|
||||
)
|
||||
assert actual_url.path == expected_path
|
||||
@@ -1,19 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# CONSTANTS #
|
||||
# ---------------------------------------------------------------------------#
|
||||
@@ -109,17 +100,9 @@ def _is_objective_message(msg: dict) -> bool:
|
||||
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
|
||||
"""
|
||||
Carefully truncate tool message content while preserving tool structure.
|
||||
Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
|
||||
Only truncates tool_result content, leaves tool_use intact.
|
||||
"""
|
||||
content = msg.get("content")
|
||||
|
||||
# OpenAI-style tool message: role="tool" with string content
|
||||
if msg.get("role") == "tool" and isinstance(content, str):
|
||||
if _tok_len(content, enc) > max_tokens:
|
||||
msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
|
||||
return
|
||||
|
||||
# Anthropic-style: list content with tool_result items
|
||||
if not isinstance(content, list):
|
||||
return
|
||||
|
||||
@@ -157,6 +140,141 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
|
||||
def compress_prompt(
|
||||
messages: list[dict],
|
||||
target_tokens: int,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
reserve: int = 2_048,
|
||||
start_cap: int = 8_192,
|
||||
floor_cap: int = 128,
|
||||
lossy_ok: bool = True,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Shrink *messages* so that::
|
||||
|
||||
token_count(prompt) + reserve ≤ target_tokens
|
||||
|
||||
Strategy
|
||||
--------
|
||||
1. **Token-aware truncation** – progressively halve a per-message cap
|
||||
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
|
||||
*content* of every message except the first and last. Tool shells
|
||||
are included: we keep the envelope but shorten huge payloads.
|
||||
2. **Middle-out deletion** – if still over the limit, delete whole
|
||||
messages working outward from the centre, **skipping** any message
|
||||
that contains ``tool_calls`` or has ``role == "tool"``.
|
||||
3. **Last-chance trim** – if still too big, truncate the *first* and
|
||||
*last* message bodies down to `floor_cap` tokens.
|
||||
4. If the prompt is *still* too large:
|
||||
• raise ``ValueError`` when ``lossy_ok == False`` (default)
|
||||
• return the partially-trimmed prompt when ``lossy_ok == True``
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history (will be deep-copied).
|
||||
model Model name; passed to tiktoken to pick the right
|
||||
tokenizer (gpt-4o → 'o200k_base', others fallback).
|
||||
target_tokens Hard ceiling for prompt size **excluding** the model's
|
||||
forthcoming answer.
|
||||
reserve How many tokens you want to leave available for that
|
||||
answer (`max_tokens` in your subsequent completion call).
|
||||
start_cap Initial per-message truncation ceiling (tokens).
|
||||
floor_cap Lowest cap we'll accept before moving to deletions.
|
||||
lossy_ok If *True* return best-effort prompt instead of raising
|
||||
after all trim passes have been exhausted.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list[dict] – A *new* messages list that abides by the rules above.
|
||||
"""
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
msgs = deepcopy(messages) # never mutate caller
|
||||
|
||||
def total_tokens() -> int:
|
||||
"""Current size of *msgs* in tokens."""
|
||||
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||
|
||||
original_token_count = total_tokens()
|
||||
|
||||
if original_token_count + reserve <= target_tokens:
|
||||
return msgs
|
||||
|
||||
# ---- STEP 0 : normalise content --------------------------------------
|
||||
# Convert non-string payloads to strings so token counting is coherent.
|
||||
for i, m in enumerate(msgs):
|
||||
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||
if _is_tool_message(m):
|
||||
continue
|
||||
|
||||
# Keep first and last messages intact (unless they're tool messages)
|
||||
if i == 0 or i == len(msgs) - 1:
|
||||
continue
|
||||
|
||||
# Reasonable 20k-char ceiling prevents pathological blobs
|
||||
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||
if len(content_str) > 20_000:
|
||||
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||
m["content"] = content_str
|
||||
|
||||
# ---- STEP 1 : token-aware truncation ---------------------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for m in msgs[1:-1]: # keep first & last intact
|
||||
if _is_tool_message(m):
|
||||
# For tool messages, only truncate tool result content, preserve structure
|
||||
_truncate_tool_message_content(m, enc, cap)
|
||||
continue
|
||||
|
||||
if _is_objective_message(m):
|
||||
# Never truncate objective messages - they contain the core task
|
||||
continue
|
||||
|
||||
content = m.get("content") or ""
|
||||
if _tok_len(content, enc) > cap:
|
||||
m["content"] = _truncate_middle_tokens(content, enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 2 : middle-out deletion -----------------------------------
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
|
||||
deletable_indices = []
|
||||
for i in range(1, len(msgs) - 1): # Skip first and last
|
||||
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
|
||||
deletable_indices.append(i)
|
||||
|
||||
if not deletable_indices:
|
||||
break # nothing more we can drop
|
||||
|
||||
# Delete from center outward - find the index closest to center
|
||||
centre = len(msgs) // 2
|
||||
to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
|
||||
del msgs[to_delete]
|
||||
|
||||
# ---- STEP 3 : final safety-net trim on first & last ------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for idx in (0, -1): # first and last
|
||||
if _is_tool_message(msgs[idx]):
|
||||
# For tool messages at first/last position, truncate tool result content only
|
||||
_truncate_tool_message_content(msgs[idx], enc, cap)
|
||||
continue
|
||||
|
||||
text = msgs[idx].get("content") or ""
|
||||
if _tok_len(text, enc) > cap:
|
||||
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||
cap //= 2 # tighten the screw
|
||||
|
||||
# ---- STEP 4 : success or fail-gracefully -----------------------------
|
||||
if total_tokens() + reserve > target_tokens and not lossy_ok:
|
||||
raise ValueError(
|
||||
"compress_prompt: prompt still exceeds budget "
|
||||
f"({total_tokens() + reserve} > {target_tokens})."
|
||||
)
|
||||
|
||||
return msgs
|
||||
|
||||
|
||||
def estimate_token_count(
|
||||
messages: list[dict],
|
||||
*,
|
||||
@@ -175,8 +293,7 @@ def estimate_token_count(
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
return sum(_msg_tokens(m, enc) for m in messages)
|
||||
|
||||
|
||||
@@ -198,543 +315,6 @@ def estimate_token_count_str(
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
text = json.dumps(text) if not isinstance(text, str) else text
|
||||
return _tok_len(text, enc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# UNIFIED CONTEXT COMPRESSION #
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
# Default thresholds
|
||||
DEFAULT_TOKEN_THRESHOLD = 120_000
|
||||
DEFAULT_KEEP_RECENT = 15
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressResult:
|
||||
"""Result of context compression."""
|
||||
|
||||
messages: list[dict]
|
||||
token_count: int
|
||||
was_compacted: bool
|
||||
error: str | None = None
|
||||
original_token_count: int = 0
|
||||
messages_summarized: int = 0
|
||||
messages_dropped: int = 0
|
||||
|
||||
|
||||
def _normalize_model_for_tokenizer(model: str) -> str:
|
||||
"""Normalize model name for tiktoken tokenizer selection."""
|
||||
if "/" in model:
|
||||
model = model.split("/")[-1]
|
||||
if "claude" in model.lower() or not any(
|
||||
known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
|
||||
):
|
||||
return "gpt-4o"
|
||||
return model
|
||||
|
||||
|
||||
def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
|
||||
"""
|
||||
Extract tool_call IDs from an assistant message.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
|
||||
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
|
||||
|
||||
Returns:
|
||||
Set of tool_call IDs found in the message.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
if msg.get("role") != "assistant":
|
||||
return ids
|
||||
|
||||
# OpenAI format: tool_calls array
|
||||
if msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
# Anthropic format: content list with tool_use blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
tc_id = block.get("id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
|
||||
"""
|
||||
Extract tool_call IDs that this message is responding to.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: {"role": "tool", "tool_call_id": "..."}
|
||||
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
|
||||
|
||||
Returns:
|
||||
Set of tool_call IDs this message responds to.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
|
||||
# OpenAI format: role=tool with tool_call_id
|
||||
if msg.get("role") == "tool":
|
||||
tc_id = msg.get("tool_call_id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
# Anthropic format: content list with tool_result blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
tc_id = block.get("tool_use_id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def _is_tool_response_message(msg: dict) -> bool:
|
||||
"""Check if message is a tool response (OpenAI or Anthropic format)."""
|
||||
# OpenAI format
|
||||
if msg.get("role") == "tool":
|
||||
return True
|
||||
# Anthropic format
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _remove_orphan_tool_responses(
|
||||
messages: list[dict], orphan_ids: set[str]
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Remove tool response messages/blocks that reference orphan tool_call IDs.
|
||||
|
||||
Supports both OpenAI and Anthropic formats.
|
||||
For Anthropic messages with mixed valid/orphan tool_result blocks,
|
||||
filters out only the orphan blocks instead of dropping the entire message.
|
||||
"""
|
||||
result = []
|
||||
for msg in messages:
|
||||
# OpenAI format: role=tool - drop entire message if orphan
|
||||
if msg.get("role") == "tool":
|
||||
tc_id = msg.get("tool_call_id")
|
||||
if tc_id and tc_id in orphan_ids:
|
||||
continue
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
# Anthropic format: content list may have mixed tool_result blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_results = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result" for b in content
|
||||
)
|
||||
if has_tool_results:
|
||||
# Filter out orphan tool_result blocks, keep valid ones
|
||||
filtered_content = [
|
||||
block
|
||||
for block in content
|
||||
if not (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "tool_result"
|
||||
and block.get("tool_use_id") in orphan_ids
|
||||
)
|
||||
]
|
||||
# Only keep message if it has remaining content
|
||||
if filtered_content:
|
||||
msg = msg.copy()
|
||||
msg["content"] = filtered_content
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_tool_pairs_intact(
|
||||
recent_messages: list[dict],
|
||||
all_messages: list[dict],
|
||||
start_index: int,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Ensure tool_call/tool_response pairs stay together after slicing.
|
||||
|
||||
When slicing messages for context compaction, a naive slice can separate
|
||||
an assistant message containing tool_calls from its corresponding tool
|
||||
response messages. This causes API validation errors (e.g., Anthropic's
|
||||
"unexpected tool_use_id found in tool_result blocks").
|
||||
|
||||
This function checks for orphan tool responses in the slice and extends
|
||||
backwards to include their corresponding assistant messages.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: tool_calls array + role="tool" responses
|
||||
- Anthropic: tool_use blocks + tool_result blocks
|
||||
|
||||
Args:
|
||||
recent_messages: The sliced messages to validate
|
||||
all_messages: The complete message list (for looking up missing assistants)
|
||||
start_index: The index in all_messages where recent_messages begins
|
||||
|
||||
Returns:
|
||||
A potentially extended list of messages with tool pairs intact
|
||||
"""
|
||||
if not recent_messages:
|
||||
return recent_messages
|
||||
|
||||
# Collect all tool_call_ids from assistant messages in the slice
|
||||
available_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
|
||||
|
||||
# Find orphan tool responses (responses whose tool_call_id is missing)
|
||||
orphan_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
response_ids = _extract_tool_response_ids_from_message(msg)
|
||||
for tc_id in response_ids:
|
||||
if tc_id not in available_tool_call_ids:
|
||||
orphan_tool_call_ids.add(tc_id)
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# No orphans, slice is valid
|
||||
return recent_messages
|
||||
|
||||
# Find the assistant messages that contain the orphan tool_call_ids
|
||||
# Search backwards from start_index in all_messages
|
||||
messages_to_prepend: list[dict] = []
|
||||
for i in range(start_index - 1, -1, -1):
|
||||
msg = all_messages[i]
|
||||
msg_tool_ids = _extract_tool_call_ids_from_message(msg)
|
||||
if msg_tool_ids & orphan_tool_call_ids:
|
||||
# This assistant message has tool_calls we need
|
||||
# Also collect its contiguous tool responses that follow it
|
||||
assistant_and_responses: list[dict] = [msg]
|
||||
|
||||
# Scan forward from this assistant to collect tool responses
|
||||
for j in range(i + 1, start_index):
|
||||
following_msg = all_messages[j]
|
||||
following_response_ids = _extract_tool_response_ids_from_message(
|
||||
following_msg
|
||||
)
|
||||
if following_response_ids and following_response_ids & msg_tool_ids:
|
||||
assistant_and_responses.append(following_msg)
|
||||
elif not _is_tool_response_message(following_msg):
|
||||
# Stop at first non-tool-response message
|
||||
break
|
||||
|
||||
# Prepend the assistant and its tool responses (maintain order)
|
||||
messages_to_prepend = assistant_and_responses + messages_to_prepend
|
||||
# Mark these as found
|
||||
orphan_tool_call_ids -= msg_tool_ids
|
||||
# Also add this assistant's tool_call_ids to available set
|
||||
available_tool_call_ids |= msg_tool_ids
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# Found all missing assistants
|
||||
break
|
||||
|
||||
if orphan_tool_call_ids:
|
||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||
# This shouldn't happen in normal operation but handles edge cases
|
||||
logger.warning(
|
||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||
"Removing orphan tool responses."
|
||||
)
|
||||
recent_messages = _remove_orphan_tool_responses(
|
||||
recent_messages, orphan_tool_call_ids
|
||||
)
|
||||
|
||||
if messages_to_prepend:
|
||||
logger.info(
|
||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||
f"tool_call/tool_response pairs"
|
||||
)
|
||||
return messages_to_prepend + recent_messages
|
||||
|
||||
return recent_messages
|
||||
|
||||
|
||||
async def _summarize_messages_llm(
|
||||
messages: list[dict],
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
timeout: float = 30.0,
|
||||
) -> str:
|
||||
"""Summarize messages using an LLM."""
|
||||
conversation = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
if content and role in ("user", "assistant", "tool"):
|
||||
conversation.append(f"{role.upper()}: {content}")
|
||||
|
||||
conversation_text = "\n\n".join(conversation)
|
||||
|
||||
if not conversation_text:
|
||||
return "No conversation history available."
|
||||
|
||||
# Limit to ~100k chars for safety
|
||||
MAX_CHARS = 100_000
|
||||
if len(conversation_text) > MAX_CHARS:
|
||||
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
|
||||
|
||||
response = await client.with_options(timeout=timeout).chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Create a detailed summary of the conversation so far. "
|
||||
"This summary will be used as context when continuing the conversation.\n\n"
|
||||
"Before writing the summary, analyze each message chronologically to identify:\n"
|
||||
"- User requests and their explicit goals\n"
|
||||
"- Your approach and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"You MUST include ALL of the following sections:\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
"The user's explicit goals and what they are trying to accomplish.\n\n"
|
||||
"## 2. Key Technical Concepts\n"
|
||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions. "
|
||||
"Include any user feedback on fixes.\n\n"
|
||||
"## 5. Problem Solving\n"
|
||||
"Issues that have been resolved and how they were addressed.\n\n"
|
||||
"## 6. All User Messages\n"
|
||||
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
|
||||
"## 7. Pending Tasks\n"
|
||||
"Work items the user explicitly requested that have not yet been completed.\n\n"
|
||||
"## 8. Current Work\n"
|
||||
"Precise description of what was being worked on most recently, including relevant context.\n\n"
|
||||
"## 9. Next Steps\n"
|
||||
"What should happen next, aligned with the user's most recent requests. "
|
||||
"Include verbatim quotes of recent instructions if relevant."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
],
|
||||
max_tokens=1500,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content or "No summary available."
|
||||
|
||||
|
||||
async def compress_context(
|
||||
messages: list[dict],
|
||||
target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
client: AsyncOpenAI | None = None,
|
||||
keep_recent: int = DEFAULT_KEEP_RECENT,
|
||||
reserve: int = 2_048,
|
||||
start_cap: int = 8_192,
|
||||
floor_cap: int = 128,
|
||||
) -> CompressResult:
|
||||
"""
|
||||
Unified context compression that combines summarization and truncation strategies.
|
||||
|
||||
Strategy (in order):
|
||||
1. **LLM summarization** – If client provided, summarize old messages into a
|
||||
single context message while keeping recent messages intact. This is the
|
||||
primary strategy for chat service.
|
||||
2. **Content truncation** – Progressively halve a per-message cap and truncate
|
||||
bloated message content (tool outputs, large pastes). Preserves all messages
|
||||
but shortens their content. Primary strategy when client=None (LLM blocks).
|
||||
3. **Middle-out deletion** – Delete whole messages one at a time from the center
|
||||
outward, skipping tool messages and objective messages.
|
||||
4. **First/last trim** – Truncate first and last message content as last resort.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history (will be deep-copied).
|
||||
target_tokens Hard ceiling for prompt size.
|
||||
model Model name for tokenization and summarization.
|
||||
client AsyncOpenAI client. If provided, enables LLM summarization
|
||||
as the first strategy. If None, skips to truncation strategies.
|
||||
keep_recent Number of recent messages to preserve during summarization.
|
||||
reserve Tokens to reserve for model response.
|
||||
start_cap Initial per-message truncation ceiling (tokens).
|
||||
floor_cap Lowest cap before moving to deletions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
CompressResult with compressed messages and metadata.
|
||||
"""
|
||||
# Guard clause for empty messages
|
||||
if not messages:
|
||||
return CompressResult(
|
||||
messages=[],
|
||||
token_count=0,
|
||||
was_compacted=False,
|
||||
original_token_count=0,
|
||||
)
|
||||
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
msgs = deepcopy(messages)
|
||||
|
||||
def total_tokens() -> int:
|
||||
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||
|
||||
original_count = total_tokens()
|
||||
|
||||
# Already under limit
|
||||
if original_count + reserve <= target_tokens:
|
||||
return CompressResult(
|
||||
messages=msgs,
|
||||
token_count=original_count,
|
||||
was_compacted=False,
|
||||
original_token_count=original_count,
|
||||
)
|
||||
|
||||
messages_summarized = 0
|
||||
messages_dropped = 0
|
||||
|
||||
# ---- STEP 1: LLM summarization (if client provided) -------------------
|
||||
# This is the primary compression strategy for chat service.
|
||||
# Summarize old messages while keeping recent ones intact.
|
||||
if client is not None:
|
||||
has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
|
||||
system_msg = msgs[0] if has_system else None
|
||||
|
||||
# Calculate old vs recent messages
|
||||
if has_system:
|
||||
if len(msgs) > keep_recent + 1:
|
||||
old_msgs = msgs[1:-keep_recent]
|
||||
recent_msgs = msgs[-keep_recent:]
|
||||
else:
|
||||
old_msgs = []
|
||||
recent_msgs = msgs[1:] if len(msgs) > 1 else []
|
||||
else:
|
||||
if len(msgs) > keep_recent:
|
||||
old_msgs = msgs[:-keep_recent]
|
||||
recent_msgs = msgs[-keep_recent:]
|
||||
else:
|
||||
old_msgs = []
|
||||
recent_msgs = msgs
|
||||
|
||||
# Ensure tool pairs stay intact
|
||||
slice_start = max(0, len(msgs) - keep_recent)
|
||||
recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
|
||||
|
||||
if old_msgs:
|
||||
try:
|
||||
summary_text = await _summarize_messages_llm(old_msgs, client, model)
|
||||
summary_msg = {
|
||||
"role": "assistant",
|
||||
"content": f"[Previous conversation summary — for context only]: {summary_text}",
|
||||
}
|
||||
messages_summarized = len(old_msgs)
|
||||
|
||||
if has_system:
|
||||
msgs = [system_msg, summary_msg] + recent_msgs
|
||||
else:
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
# Convert non-string payloads to strings so token counting is coherent.
|
||||
# Always run this before truncation to ensure consistent token counting.
|
||||
for i, m in enumerate(msgs):
|
||||
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||
if _is_tool_message(m):
|
||||
continue
|
||||
if i == 0 or i == len(msgs) - 1:
|
||||
continue
|
||||
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||
if len(content_str) > 20_000:
|
||||
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||
m["content"] = content_str
|
||||
|
||||
# ---- STEP 3: Token-aware content truncation ---------------------------
|
||||
# Progressively halve per-message cap and truncate bloated content.
|
||||
# This preserves all messages but shortens their content.
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for m in msgs[1:-1]:
|
||||
if _is_tool_message(m):
|
||||
_truncate_tool_message_content(m, enc, cap)
|
||||
continue
|
||||
if _is_objective_message(m):
|
||||
continue
|
||||
content = m.get("content") or ""
|
||||
if _tok_len(content, enc) > cap:
|
||||
m["content"] = _truncate_middle_tokens(content, enc, cap)
|
||||
cap //= 2
|
||||
|
||||
# ---- STEP 4: Middle-out deletion --------------------------------------
|
||||
# Delete messages one at a time from the center outward.
|
||||
# This is more granular than dropping all old messages at once.
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
deletable: list[int] = []
|
||||
for i in range(1, len(msgs) - 1):
|
||||
msg = msgs[i]
|
||||
if (
|
||||
msg is not None
|
||||
and not _is_tool_message(msg)
|
||||
and not _is_objective_message(msg)
|
||||
):
|
||||
deletable.append(i)
|
||||
if not deletable:
|
||||
break
|
||||
centre = len(msgs) // 2
|
||||
to_delete = min(deletable, key=lambda i: abs(i - centre))
|
||||
del msgs[to_delete]
|
||||
messages_dropped += 1
|
||||
|
||||
# ---- STEP 5: Final trim on first/last ---------------------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for idx in (0, -1):
|
||||
msg = msgs[idx]
|
||||
if msg is None:
|
||||
continue
|
||||
if _is_tool_message(msg):
|
||||
_truncate_tool_message_content(msg, enc, cap)
|
||||
continue
|
||||
text = msg.get("content") or ""
|
||||
if _tok_len(text, enc) > cap:
|
||||
msg["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||
cap //= 2
|
||||
|
||||
# Filter out any None values that may have been introduced
|
||||
final_msgs: list[dict] = [m for m in msgs if m is not None]
|
||||
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
|
||||
error = None
|
||||
if final_count + reserve > target_tokens:
|
||||
error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
|
||||
logger.warning(error)
|
||||
|
||||
return CompressResult(
|
||||
messages=final_msgs,
|
||||
token_count=final_count,
|
||||
was_compacted=True,
|
||||
error=error,
|
||||
original_token_count=original_count,
|
||||
messages_summarized=messages_summarized,
|
||||
messages_dropped=messages_dropped,
|
||||
)
|
||||
|
||||
@@ -1,21 +1,10 @@
|
||||
"""Tests for prompt utility functions, especially tool call token counting."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import (
|
||||
CompressResult,
|
||||
_ensure_tool_pairs_intact,
|
||||
_msg_tokens,
|
||||
_normalize_model_for_tokenizer,
|
||||
_truncate_middle_tokens,
|
||||
_truncate_tool_message_content,
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
)
|
||||
from backend.util.prompt import _msg_tokens, estimate_token_count
|
||||
|
||||
|
||||
class TestMsgTokens:
|
||||
@@ -287,690 +276,3 @@ class TestEstimateTokenCount:
|
||||
|
||||
assert total_tokens == expected_total
|
||||
assert total_tokens > 20 # Should be substantial
|
||||
|
||||
|
||||
class TestNormalizeModelForTokenizer:
|
||||
"""Test model name normalization for tiktoken."""
|
||||
|
||||
def test_openai_models_unchanged(self):
|
||||
"""Test that OpenAI models are returned as-is."""
|
||||
assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
|
||||
assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
|
||||
|
||||
def test_claude_models_normalized(self):
|
||||
"""Test that Claude models are normalized to gpt-4o."""
|
||||
assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
|
||||
|
||||
def test_openrouter_paths_extracted(self):
|
||||
"""Test that OpenRouter model paths are handled."""
|
||||
assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
|
||||
|
||||
def test_unknown_models_default_to_gpt4o(self):
|
||||
"""Test that unknown models default to gpt-4o."""
|
||||
assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
|
||||
|
||||
|
||||
class TestTruncateToolMessageContent:
|
||||
"""Test tool message content truncation."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_truncate_openai_tool_message(self, enc):
|
||||
"""Test truncation of OpenAI-style tool message with string content."""
|
||||
long_content = "x" * 10000
|
||||
msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
|
||||
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=100)
|
||||
|
||||
# Content should be truncated
|
||||
assert len(msg["content"]) < len(long_content)
|
||||
assert "…" in msg["content"] # Has ellipsis marker
|
||||
|
||||
def test_truncate_anthropic_tool_result(self, enc):
|
||||
"""Test truncation of Anthropic-style tool_result."""
|
||||
long_content = "y" * 10000
|
||||
msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": long_content,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=100)
|
||||
|
||||
# Content should be truncated
|
||||
result_content = msg["content"][0]["content"]
|
||||
assert len(result_content) < len(long_content)
|
||||
assert "…" in result_content
|
||||
|
||||
def test_preserve_tool_use_blocks(self, enc):
|
||||
"""Test that tool_use blocks are not truncated."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "some_function",
|
||||
"input": {"key": "value" * 1000}, # Large input
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
original = json.dumps(msg["content"][0]["input"])
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=10)
|
||||
|
||||
# tool_use should be unchanged
|
||||
assert json.dumps(msg["content"][0]["input"]) == original
|
||||
|
||||
def test_no_truncation_when_under_limit(self, enc):
|
||||
"""Test that short content is not modified."""
|
||||
msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
|
||||
|
||||
original = msg["content"]
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=1000)
|
||||
|
||||
assert msg["content"] == original
|
||||
|
||||
|
||||
class TestTruncateMiddleTokens:
|
||||
"""Test middle truncation of text."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_truncates_long_text(self, enc):
|
||||
"""Test that long text is truncated with ellipsis in middle."""
|
||||
long_text = "word " * 1000
|
||||
result = _truncate_middle_tokens(long_text, enc, max_tok=50)
|
||||
|
||||
assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis
|
||||
assert "…" in result
|
||||
assert result.startswith("word") # Head preserved
|
||||
assert result.endswith("word ") # Tail preserved
|
||||
|
||||
def test_preserves_short_text(self, enc):
|
||||
"""Test that short text is not modified."""
|
||||
short_text = "Hello world"
|
||||
result = _truncate_middle_tokens(short_text, enc, max_tok=100)
|
||||
|
||||
assert result == short_text
|
||||
|
||||
|
||||
class TestEnsureToolPairsIntact:
|
||||
"""Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
|
||||
|
||||
# ---- OpenAI Format Tests ----
|
||||
|
||||
def test_openai_adds_missing_tool_call(self):
|
||||
"""Test that orphaned OpenAI tool_response gets its tool_call prepended."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool response)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_call message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "tool_calls" in result[0]
|
||||
|
||||
def test_openai_keeps_complete_pairs(self):
|
||||
"""Test that complete OpenAI pairs are unchanged."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
||||
]
|
||||
recent = all_msgs[1:] # Include both tool_call and response
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert len(result) == 2 # No messages added
|
||||
|
||||
def test_openai_multiple_tool_calls(self):
|
||||
"""Test multiple OpenAI tool calls in one assistant message."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}},
|
||||
{"id": "call_2", "type": "function", "function": {"name": "f2"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (first tool response)
|
||||
recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with both tool_calls
|
||||
assert len(result) == 4
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert len(result[0]["tool_calls"]) == 2
|
||||
|
||||
# ---- Anthropic Format Tests ----
|
||||
|
||||
def test_anthropic_adds_missing_tool_use(self):
|
||||
"""Test that orphaned Anthropic tool_result gets its tool_use prepended."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "SF"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": "22°C and sunny",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool_result)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_use message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert result[0]["content"][0]["type"] == "tool_use"
|
||||
|
||||
def test_anthropic_keeps_complete_pairs(self):
|
||||
"""Test that complete Anthropic pairs are unchanged."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_456",
|
||||
"name": "calculator",
|
||||
"input": {"expr": "2+2"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_456",
|
||||
"content": "4",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
recent = all_msgs[1:] # Include both tool_use and result
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert len(result) == 2 # No messages added
|
||||
|
||||
def test_anthropic_multiple_tool_uses(self):
|
||||
"""Test multiple Anthropic tool_use blocks in one message."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me check both..."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_1",
|
||||
"name": "get_weather",
|
||||
"input": {"city": "NYC"},
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_2",
|
||||
"name": "get_weather",
|
||||
"input": {"city": "LA"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_1",
|
||||
"content": "Cold",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_2",
|
||||
"content": "Warm",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (tool_result)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with both tool_uses
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
tool_use_count = sum(
|
||||
1 for b in result[0]["content"] if b.get("type") == "tool_use"
|
||||
)
|
||||
assert tool_use_count == 2
|
||||
|
||||
# ---- Mixed/Edge Case Tests ----
|
||||
|
||||
def test_anthropic_with_type_message_field(self):
|
||||
"""Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_abc",
|
||||
"name": "search",
|
||||
"input": {"q": "test"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message", # Extra field from smart_decision_maker
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_abc",
|
||||
"content": "Found results",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool_result with 'type': 'message')
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_use message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert result[0]["content"][0]["type"] == "tool_use"
|
||||
|
||||
def test_handles_no_tool_messages(self):
|
||||
"""Test messages without tool calls."""
|
||||
all_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
recent = all_msgs
|
||||
start_index = 0
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert result == all_msgs
|
||||
|
||||
def test_handles_empty_messages(self):
|
||||
"""Test empty message list."""
|
||||
result = _ensure_tool_pairs_intact([], [], 0)
|
||||
assert result == []
|
||||
|
||||
def test_mixed_text_and_tool_content(self):
|
||||
"""Test Anthropic message with mixed text and tool_use content."""
|
||||
all_msgs = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I'll help you with that."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_mixed",
|
||||
"name": "search",
|
||||
"input": {"q": "test"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_mixed",
|
||||
"content": "Found results",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": "Here are the results..."},
|
||||
]
|
||||
# Start from tool_result
|
||||
recent = [all_msgs[1], all_msgs[2]]
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with tool_use
|
||||
assert len(result) == 3
|
||||
assert result[0]["content"][0]["type"] == "text"
|
||||
assert result[0]["content"][1]["type"] == "tool_use"
|
||||
|
||||
|
||||
class TestCompressContext:
|
||||
"""Test the async compress_context function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_compression_needed(self):
|
||||
"""Test messages under limit return without compression."""
|
||||
messages = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{"role": "user", "content": "Hello!"},
|
||||
]
|
||||
|
||||
result = await compress_context(messages, target_tokens=100000)
|
||||
|
||||
assert isinstance(result, CompressResult)
|
||||
assert result.was_compacted is False
|
||||
assert len(result.messages) == 2
|
||||
assert result.error is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_truncation_without_client(self):
|
||||
"""Test that truncation works without LLM client."""
|
||||
long_content = "x" * 50000
|
||||
messages = [
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": long_content},
|
||||
{"role": "assistant", "content": "Response"},
|
||||
]
|
||||
|
||||
result = await compress_context(
|
||||
messages, target_tokens=1000, client=None, reserve=100
|
||||
)
|
||||
|
||||
assert result.was_compacted is True
|
||||
# Should have truncated without summarization
|
||||
assert result.messages_summarized == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_mocked_llm_client(self):
|
||||
"""Test summarization with mocked LLM client."""
|
||||
# Create many messages to trigger summarization
|
||||
messages = [{"role": "system", "content": "System prompt"}]
|
||||
for i in range(30):
|
||||
messages.append({"role": "user", "content": f"User message {i} " * 100})
|
||||
messages.append(
|
||||
{"role": "assistant", "content": f"Assistant response {i} " * 100}
|
||||
)
|
||||
|
||||
# Mock the AsyncOpenAI client
|
||||
mock_client = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.choices = [MagicMock()]
|
||||
mock_response.choices[0].message.content = "Summary of conversation"
|
||||
mock_client.with_options.return_value.chat.completions.create = AsyncMock(
|
||||
return_value=mock_response
|
||||
)
|
||||
|
||||
result = await compress_context(
|
||||
messages,
|
||||
target_tokens=5000,
|
||||
client=mock_client,
|
||||
keep_recent=5,
|
||||
reserve=500,
|
||||
)
|
||||
|
||||
assert result.was_compacted is True
|
||||
# Should have attempted summarization
|
||||
assert mock_client.with_options.called or result.messages_summarized > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preserves_tool_pairs(self):
|
||||
"""Test that tool call/response pairs stay together."""
|
||||
messages = [
|
||||
{"role": "system", "content": "System"},
|
||||
{"role": "user", "content": "Do something"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "func"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
|
||||
{"role": "assistant", "content": "Done!"},
|
||||
]
|
||||
|
||||
result = await compress_context(
|
||||
messages, target_tokens=500, client=None, reserve=50
|
||||
)
|
||||
|
||||
# Check that if tool response exists, its call exists too
|
||||
tool_call_ids = set()
|
||||
tool_response_ids = set()
|
||||
for msg in result.messages:
|
||||
if "tool_calls" in msg:
|
||||
for tc in msg["tool_calls"]:
|
||||
tool_call_ids.add(tc["id"])
|
||||
if msg.get("role") == "tool":
|
||||
tool_response_ids.add(msg.get("tool_call_id"))
|
||||
|
||||
# All tool responses should have their calls
|
||||
assert tool_response_ids <= tool_call_ids
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_error_when_cannot_compress(self):
|
||||
"""Test that error is returned when compression fails."""
|
||||
# Single huge message that can't be compressed enough
|
||||
messages = [
|
||||
{"role": "user", "content": "x" * 100000},
|
||||
]
|
||||
|
||||
result = await compress_context(
|
||||
messages, target_tokens=100, client=None, reserve=50
|
||||
)
|
||||
|
||||
# Should have an error since we can't get below 100 tokens
|
||||
assert result.error is not None
|
||||
assert result.was_compacted is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_messages(self):
|
||||
"""Test that empty messages list returns early without error."""
|
||||
result = await compress_context([], target_tokens=1000)
|
||||
|
||||
assert result.messages == []
|
||||
assert result.token_count == 0
|
||||
assert result.was_compacted is False
|
||||
assert result.error is None
|
||||
|
||||
|
||||
class TestRemoveOrphanToolResponses:
|
||||
"""Test _remove_orphan_tool_responses helper function."""
|
||||
|
||||
def test_removes_openai_orphan(self):
|
||||
"""Test removal of orphan OpenAI tool response."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
|
||||
{"role": "user", "content": "Hello"},
|
||||
]
|
||||
orphan_ids = {"call_orphan"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
|
||||
def test_keeps_valid_openai_tool(self):
|
||||
"""Test that valid OpenAI tool responses are kept."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{"role": "tool", "tool_call_id": "call_valid", "content": "result"},
|
||||
]
|
||||
orphan_ids = {"call_other"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["tool_call_id"] == "call_valid"
|
||||
|
||||
def test_filters_anthropic_mixed_blocks(self):
|
||||
"""Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_valid",
|
||||
"content": "valid result",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_orphan",
|
||||
"content": "orphan result",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
orphan_ids = {"toolu_orphan"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert len(result) == 1
|
||||
# Should only have the valid tool_result, orphan filtered out
|
||||
assert len(result[0]["content"]) == 1
|
||||
assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
|
||||
|
||||
def test_removes_anthropic_all_orphan(self):
|
||||
"""Test removal of Anthropic message when all tool_results are orphans."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_orphan1",
|
||||
"content": "result1",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_orphan2",
|
||||
"content": "result2",
|
||||
},
|
||||
],
|
||||
},
|
||||
]
|
||||
orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
# Message should be completely removed since no content left
|
||||
assert len(result) == 0
|
||||
|
||||
def test_preserves_non_tool_messages(self):
|
||||
"""Test that non-tool messages are preserved."""
|
||||
from backend.util.prompt import _remove_orphan_tool_responses
|
||||
|
||||
messages = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
orphan_ids = {"some_id"}
|
||||
|
||||
result = _remove_orphan_tool_responses(messages, orphan_ids)
|
||||
|
||||
assert result == messages
|
||||
|
||||
|
||||
class TestCompressResultDataclass:
|
||||
"""Test CompressResult dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""Test default values are set correctly."""
|
||||
result = CompressResult(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
token_count=10,
|
||||
was_compacted=False,
|
||||
)
|
||||
|
||||
assert result.error is None
|
||||
assert result.original_token_count == 0 # Defaults to 0, not None
|
||||
assert result.messages_summarized == 0
|
||||
assert result.messages_dropped == 0
|
||||
|
||||
def test_all_fields(self):
|
||||
"""Test all fields can be set."""
|
||||
result = CompressResult(
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
token_count=100,
|
||||
was_compacted=True,
|
||||
error="Some error",
|
||||
original_token_count=500,
|
||||
messages_summarized=10,
|
||||
messages_dropped=5,
|
||||
)
|
||||
|
||||
assert result.token_count == 100
|
||||
assert result.was_compacted is True
|
||||
assert result.error == "Some error"
|
||||
assert result.original_token_count == 500
|
||||
assert result.messages_summarized == 10
|
||||
assert result.messages_dropped == 5
|
||||
|
||||
@@ -656,6 +656,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
|
||||
e2b_api_key: str = Field(default="", description="E2B API key")
|
||||
nvidia_api_key: str = Field(default="", description="Nvidia API key")
|
||||
mem0_api_key: str = Field(default="", description="Mem0 API key")
|
||||
elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")
|
||||
|
||||
linear_client_id: str = Field(default="", description="Linear client ID")
|
||||
linear_client_secret: str = Field(default="", description="Linear client secret")
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet
|
||||
-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model
|
||||
-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026
|
||||
|
||||
-- Update AgentNode constant inputs
|
||||
UPDATE "AgentNode"
|
||||
SET "constantInput" = JSONB_SET(
|
||||
"constantInput"::jsonb,
|
||||
'{model}',
|
||||
'"claude-sonnet-4-5-20250929"'::jsonb
|
||||
)
|
||||
WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
|
||||
|
||||
-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
|
||||
UPDATE "AgentNodeExecutionInputOutput"
|
||||
SET "data" = JSONB_SET(
|
||||
"data"::jsonb,
|
||||
'{model}',
|
||||
'"claude-sonnet-4-5-20250929"'::jsonb
|
||||
)
|
||||
WHERE "agentPresetId" IS NOT NULL
|
||||
AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
|
||||
47
autogpt_platform/backend/poetry.lock
generated
47
autogpt_platform/backend/poetry.lock
generated
@@ -1169,6 +1169,29 @@ attrs = ">=21.3.0"
|
||||
e2b = ">=1.5.4,<2.0.0"
|
||||
httpx = ">=0.20.0,<1.0.0"
|
||||
|
||||
[[package]]
|
||||
name = "elevenlabs"
|
||||
version = "1.59.0"
|
||||
description = ""
|
||||
optional = false
|
||||
python-versions = "<4.0,>=3.8"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "elevenlabs-1.59.0-py3-none-any.whl", hash = "sha256:468145db81a0bc867708b4a8619699f75583e9481b395ec1339d0b443da771ed"},
|
||||
{file = "elevenlabs-1.59.0.tar.gz", hash = "sha256:16e735bd594e86d415dd445d249c8cc28b09996cfd627fbc10102c0a84698859"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
httpx = ">=0.21.2"
|
||||
pydantic = ">=1.9.2"
|
||||
pydantic-core = ">=2.18.2,<3.0.0"
|
||||
requests = ">=2.20"
|
||||
typing_extensions = ">=4.0.0"
|
||||
websockets = ">=11.0"
|
||||
|
||||
[package.extras]
|
||||
pyaudio = ["pyaudio (>=0.2.14)"]
|
||||
|
||||
[[package]]
|
||||
name = "email-validator"
|
||||
version = "2.2.0"
|
||||
@@ -7361,6 +7384,28 @@ files = [
|
||||
defusedxml = ">=0.7.1,<0.8.0"
|
||||
requests = "*"
|
||||
|
||||
[[package]]
|
||||
name = "yt-dlp"
|
||||
version = "2025.12.8"
|
||||
description = "A feature-rich command-line audio/video downloader"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "yt_dlp-2025.12.8-py3-none-any.whl", hash = "sha256:36e2584342e409cfbfa0b5e61448a1c5189e345cf4564294456ee509e7d3e065"},
|
||||
{file = "yt_dlp-2025.12.8.tar.gz", hash = "sha256:b773c81bb6b71cb2c111cfb859f453c7a71cf2ef44eff234ff155877184c3e4f"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
build = ["build", "hatchling (>=1.27.0)", "pip", "setuptools (>=71.0.2)", "wheel"]
|
||||
curl-cffi = ["curl-cffi (>=0.5.10,<0.6.dev0 || >=0.10.dev0,<0.14) ; implementation_name == \"cpython\""]
|
||||
default = ["brotli ; implementation_name == \"cpython\"", "brotlicffi ; implementation_name != \"cpython\"", "certifi", "mutagen", "pycryptodomex", "requests (>=2.32.2,<3)", "urllib3 (>=2.0.2,<3)", "websockets (>=13.0)", "yt-dlp-ejs (==0.3.2)"]
|
||||
dev = ["autopep8 (>=2.0,<3.0)", "pre-commit", "pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)", "ruff (>=0.14.0,<0.15.0)"]
|
||||
pyinstaller = ["pyinstaller (>=6.17.0)"]
|
||||
secretstorage = ["cffi", "secretstorage"]
|
||||
static-analysis = ["autopep8 (>=2.0,<3.0)", "ruff (>=0.14.0,<0.15.0)"]
|
||||
test = ["pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "zerobouncesdk"
|
||||
version = "1.1.2"
|
||||
@@ -7512,4 +7557,4 @@ cffi = ["cffi (>=1.11)"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = ">=3.10,<3.14"
|
||||
content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf"
|
||||
content-hash = "8239323f9ae6713224dffd1fe8ba8b449fe88b6c3c7a90940294a74f43a0387a"
|
||||
|
||||
@@ -20,6 +20,7 @@ click = "^8.2.0"
|
||||
cryptography = "^45.0"
|
||||
discord-py = "^2.5.2"
|
||||
e2b-code-interpreter = "^1.5.2"
|
||||
elevenlabs = "^1.50.0"
|
||||
fastapi = "^0.116.1"
|
||||
feedparser = "^6.0.11"
|
||||
flake8 = "^7.3.0"
|
||||
@@ -71,6 +72,7 @@ tweepy = "^4.16.0"
|
||||
uvicorn = { extras = ["standard"], version = "^0.35.0" }
|
||||
websockets = "^15.0"
|
||||
youtube-transcript-api = "^1.2.1"
|
||||
yt-dlp = "2025.12.08"
|
||||
zerobouncesdk = "^1.1.2"
|
||||
# NOTE: please insert new dependencies in their alphabetical location
|
||||
pytest-snapshot = "^0.9.0"
|
||||
|
||||
@@ -9,8 +9,7 @@
|
||||
"sub_heading": "Creator agent subheading",
|
||||
"description": "Creator agent description",
|
||||
"runs": 50,
|
||||
"rating": 4.0,
|
||||
"agent_graph_id": "test-graph-2"
|
||||
"rating": 4.0
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -9,8 +9,7 @@
|
||||
"sub_heading": "Category agent subheading",
|
||||
"description": "Category agent description",
|
||||
"runs": 60,
|
||||
"rating": 4.1,
|
||||
"agent_graph_id": "test-graph-category"
|
||||
"rating": 4.1
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -9,8 +9,7 @@
|
||||
"sub_heading": "Agent 0 subheading",
|
||||
"description": "Agent 0 description",
|
||||
"runs": 0,
|
||||
"rating": 4.0,
|
||||
"agent_graph_id": "test-graph-2"
|
||||
"rating": 4.0
|
||||
},
|
||||
{
|
||||
"slug": "agent-1",
|
||||
@@ -21,8 +20,7 @@
|
||||
"sub_heading": "Agent 1 subheading",
|
||||
"description": "Agent 1 description",
|
||||
"runs": 10,
|
||||
"rating": 4.0,
|
||||
"agent_graph_id": "test-graph-2"
|
||||
"rating": 4.0
|
||||
},
|
||||
{
|
||||
"slug": "agent-2",
|
||||
@@ -33,8 +31,7 @@
|
||||
"sub_heading": "Agent 2 subheading",
|
||||
"description": "Agent 2 description",
|
||||
"runs": 20,
|
||||
"rating": 4.0,
|
||||
"agent_graph_id": "test-graph-2"
|
||||
"rating": 4.0
|
||||
},
|
||||
{
|
||||
"slug": "agent-3",
|
||||
@@ -45,8 +42,7 @@
|
||||
"sub_heading": "Agent 3 subheading",
|
||||
"description": "Agent 3 description",
|
||||
"runs": 30,
|
||||
"rating": 4.0,
|
||||
"agent_graph_id": "test-graph-2"
|
||||
"rating": 4.0
|
||||
},
|
||||
{
|
||||
"slug": "agent-4",
|
||||
@@ -57,8 +53,7 @@
|
||||
"sub_heading": "Agent 4 subheading",
|
||||
"description": "Agent 4 description",
|
||||
"runs": 40,
|
||||
"rating": 4.0,
|
||||
"agent_graph_id": "test-graph-2"
|
||||
"rating": 4.0
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -9,8 +9,7 @@
|
||||
"sub_heading": "Search agent subheading",
|
||||
"description": "Specific search term description",
|
||||
"runs": 75,
|
||||
"rating": 4.2,
|
||||
"agent_graph_id": "test-graph-search"
|
||||
"rating": 4.2
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -9,8 +9,7 @@
|
||||
"sub_heading": "Top agent subheading",
|
||||
"description": "Top agent description",
|
||||
"runs": 1000,
|
||||
"rating": 5.0,
|
||||
"agent_graph_id": "test-graph-3"
|
||||
"rating": 5.0
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -9,8 +9,7 @@
|
||||
"sub_heading": "Featured agent subheading",
|
||||
"description": "Featured agent description",
|
||||
"runs": 100,
|
||||
"rating": 4.5,
|
||||
"agent_graph_id": "test-graph-1"
|
||||
"rating": 4.5
|
||||
}
|
||||
],
|
||||
"pagination": {
|
||||
|
||||
@@ -31,10 +31,6 @@
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"execution_count": 0,
|
||||
"success_rate": null,
|
||||
"avg_correctness_score": null,
|
||||
"recent_executions": [],
|
||||
"can_access_graph": true,
|
||||
"is_latest_version": true,
|
||||
"is_favorite": false,
|
||||
@@ -76,10 +72,6 @@
|
||||
"has_sensitive_action": false,
|
||||
"trigger_setup_info": null,
|
||||
"new_output": false,
|
||||
"execution_count": 0,
|
||||
"success_rate": null,
|
||||
"avg_correctness_score": null,
|
||||
"recent_executions": [],
|
||||
"can_access_graph": false,
|
||||
"is_latest_version": true,
|
||||
"is_favorite": false,
|
||||
|
||||
@@ -57,8 +57,7 @@ class TestDecomposeGoal:
|
||||
|
||||
result = await core.decompose_goal("Build a chatbot")
|
||||
|
||||
# library_agents defaults to None
|
||||
mock_external.assert_called_once_with("Build a chatbot", "", None)
|
||||
mock_external.assert_called_once_with("Build a chatbot", "")
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -75,8 +74,7 @@ class TestDecomposeGoal:
|
||||
|
||||
await core.decompose_goal("Build a chatbot", "Use Python")
|
||||
|
||||
# library_agents defaults to None
|
||||
mock_external.assert_called_once_with("Build a chatbot", "Use Python", None)
|
||||
mock_external.assert_called_once_with("Build a chatbot", "Use Python")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_service_failure(self):
|
||||
@@ -111,7 +109,8 @@ class TestGenerateAgent:
|
||||
instructions = {"type": "instructions", "steps": ["Step 1"]}
|
||||
result = await core.generate_agent(instructions)
|
||||
|
||||
mock_external.assert_called_once_with(instructions, None, None, None)
|
||||
mock_external.assert_called_once_with(instructions)
|
||||
# Result should have id, version, is_active added if not present
|
||||
assert result is not None
|
||||
assert result["name"] == "Test Agent"
|
||||
assert "id" in result
|
||||
@@ -175,9 +174,7 @@ class TestGenerateAgentPatch:
|
||||
current_agent = {"nodes": [], "links": []}
|
||||
result = await core.generate_agent_patch("Add a node", current_agent)
|
||||
|
||||
mock_external.assert_called_once_with(
|
||||
"Add a node", current_agent, None, None, None
|
||||
)
|
||||
mock_external.assert_called_once_with("Add a node", current_agent)
|
||||
assert result == expected_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@@ -1,857 +0,0 @@
|
||||
"""
|
||||
Tests for library agent fetching functionality in agent generator.
|
||||
|
||||
This test suite verifies the search-based library agent fetching,
|
||||
including the combination of library and marketplace agents.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.api.features.chat.tools.agent_generator import core
|
||||
|
||||
|
||||
class TestGetLibraryAgentsForGeneration:
|
||||
"""Test get_library_agents_for_generation function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_agents_with_search_term(self):
|
||||
"""Test that search_term is passed to the library db."""
|
||||
# Create a mock agent with proper attribute values
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "agent-123"
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Email Agent"
|
||||
mock_agent.description = "Sends emails"
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
mock_agent.recent_executions = []
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [mock_agent]
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="send email",
|
||||
)
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term="send email",
|
||||
page=1,
|
||||
page_size=15,
|
||||
include_executions=True,
|
||||
)
|
||||
|
||||
# Verify result format
|
||||
assert len(result) == 1
|
||||
assert result[0]["graph_id"] == "agent-123"
|
||||
assert result[0]["name"] == "Email Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_excludes_specified_graph_id(self):
|
||||
"""Test that agents with excluded graph_id are filtered out."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [
|
||||
MagicMock(
|
||||
graph_id="agent-123",
|
||||
graph_version=1,
|
||||
name="Agent 1",
|
||||
description="First agent",
|
||||
input_schema={},
|
||||
output_schema={},
|
||||
recent_executions=[],
|
||||
),
|
||||
MagicMock(
|
||||
graph_id="agent-456",
|
||||
graph_version=1,
|
||||
name="Agent 2",
|
||||
description="Second agent",
|
||||
input_schema={},
|
||||
output_schema={},
|
||||
recent_executions=[],
|
||||
),
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
):
|
||||
result = await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
exclude_graph_id="agent-123",
|
||||
)
|
||||
|
||||
# Verify the excluded agent is not in results
|
||||
assert len(result) == 1
|
||||
assert result[0]["graph_id"] == "agent-456"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_respects_max_results(self):
|
||||
"""Test that max_results parameter limits the page_size."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_list:
|
||||
await core.get_library_agents_for_generation(
|
||||
user_id="user-123",
|
||||
max_results=5,
|
||||
)
|
||||
|
||||
mock_list.assert_called_once_with(
|
||||
user_id="user-123",
|
||||
search_term=None,
|
||||
page=1,
|
||||
page_size=5,
|
||||
include_executions=True,
|
||||
)
|
||||
|
||||
|
||||
class TestSearchMarketplaceAgentsForGeneration:
|
||||
"""Test search_marketplace_agents_for_generation function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_searches_marketplace_with_query(self):
|
||||
"""Test that marketplace is searched with the query."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = [
|
||||
MagicMock(
|
||||
agent_name="Public Agent",
|
||||
description="A public agent",
|
||||
sub_heading="Does something useful",
|
||||
creator="creator-1",
|
||||
agent_graph_id="graph-123",
|
||||
)
|
||||
]
|
||||
|
||||
mock_graph = MagicMock()
|
||||
mock_graph.id = "graph-123"
|
||||
mock_graph.version = 1
|
||||
mock_graph.input_schema = {"type": "object"}
|
||||
mock_graph.output_schema = {"type": "object"}
|
||||
|
||||
with (
|
||||
patch(
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
) as mock_search,
|
||||
patch(
|
||||
"backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"graph-123": mock_graph},
|
||||
),
|
||||
):
|
||||
result = await core.search_marketplace_agents_for_generation(
|
||||
search_query="automation",
|
||||
max_results=10,
|
||||
)
|
||||
|
||||
mock_search.assert_called_once_with(
|
||||
search_query="automation",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "Public Agent"
|
||||
assert result[0]["graph_id"] == "graph-123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handles_marketplace_error_gracefully(self):
|
||||
"""Test that marketplace errors don't crash the function."""
|
||||
with patch(
|
||||
"backend.api.features.store.db.get_store_agents",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Marketplace unavailable"),
|
||||
):
|
||||
result = await core.search_marketplace_agents_for_generation(
|
||||
search_query="test"
|
||||
)
|
||||
|
||||
# Should return empty list, not raise exception
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestGetAllRelevantAgentsForGeneration:
|
||||
"""Test get_all_relevant_agents_for_generation function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_combines_library_and_marketplace_agents(self):
|
||||
"""Test that agents from both sources are combined."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "lib-123",
|
||||
"graph_version": 1,
|
||||
"name": "Library Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
marketplace_agents = [
|
||||
{
|
||||
"graph_id": "market-456",
|
||||
"graph_version": 1,
|
||||
"name": "Market Agent",
|
||||
"description": "From marketplace",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=marketplace_agents,
|
||||
):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="test query",
|
||||
include_marketplace=True,
|
||||
)
|
||||
|
||||
# Library agents should come first
|
||||
assert len(result) == 2
|
||||
assert result[0]["name"] == "Library Agent"
|
||||
assert result[1]["name"] == "Market Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicates_by_graph_id(self):
|
||||
"""Test that marketplace agents with same graph_id as library are excluded."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "shared-123",
|
||||
"graph_version": 1,
|
||||
"name": "Shared Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
marketplace_agents = [
|
||||
{
|
||||
"graph_id": "shared-123", # Same graph_id, should be deduplicated
|
||||
"graph_version": 1,
|
||||
"name": "Shared Agent",
|
||||
"description": "From marketplace",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
},
|
||||
{
|
||||
"graph_id": "unique-456",
|
||||
"graph_version": 1,
|
||||
"name": "Unique Agent",
|
||||
"description": "Only in marketplace",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
},
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=marketplace_agents,
|
||||
):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="test",
|
||||
include_marketplace=True,
|
||||
)
|
||||
|
||||
# Shared Agent from marketplace should be excluded by graph_id
|
||||
assert len(result) == 2
|
||||
names = [a["name"] for a in result]
|
||||
assert "Shared Agent" in names
|
||||
assert "Unique Agent" in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_marketplace_when_disabled(self):
|
||||
"""Test that marketplace is not searched when include_marketplace=False."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "lib-123",
|
||||
"graph_version": 1,
|
||||
"name": "Library Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_marketplace:
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="test",
|
||||
include_marketplace=False,
|
||||
)
|
||||
|
||||
# Marketplace should not be called
|
||||
mock_marketplace.assert_not_called()
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_marketplace_when_no_search_query(self):
|
||||
"""Test that marketplace is not searched without a search query."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "lib-123",
|
||||
"graph_version": 1,
|
||||
"name": "Library Agent",
|
||||
"description": "From library",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_library_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=library_agents,
|
||||
):
|
||||
with patch.object(
|
||||
core,
|
||||
"search_marketplace_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_marketplace:
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query=None, # No search query
|
||||
include_marketplace=True,
|
||||
)
|
||||
|
||||
# Marketplace should not be called without search query
|
||||
mock_marketplace.assert_not_called()
|
||||
assert len(result) == 1
|
||||
|
||||
|
||||
class TestExtractSearchTermsFromSteps:
|
||||
"""Test extract_search_terms_from_steps function."""
|
||||
|
||||
def test_extracts_terms_from_instructions_type(self):
|
||||
"""Test extraction from valid instructions decomposition result."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{
|
||||
"description": "Send an email notification",
|
||||
"block_name": "GmailSendBlock",
|
||||
},
|
||||
{"description": "Fetch weather data", "action": "Get weather API"},
|
||||
],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert "Send an email notification" in result
|
||||
assert "GmailSendBlock" in result
|
||||
assert "Fetch weather data" in result
|
||||
assert "Get weather API" in result
|
||||
|
||||
def test_returns_empty_for_non_instructions_type(self):
|
||||
"""Test that non-instructions types return empty list."""
|
||||
decomposition_result = {
|
||||
"type": "clarifying_questions",
|
||||
"questions": [{"question": "What email?"}],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert result == []
|
||||
|
||||
def test_deduplicates_terms_case_insensitively(self):
|
||||
"""Test that duplicate terms are removed (case-insensitive)."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "Send Email", "name": "send email"},
|
||||
{"description": "Other task"},
|
||||
],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
# Should only have one "send email" variant
|
||||
email_terms = [t for t in result if "email" in t.lower()]
|
||||
assert len(email_terms) == 1
|
||||
|
||||
def test_filters_short_terms(self):
|
||||
"""Test that terms with 3 or fewer characters are filtered out."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "ab", "action": "xyz"}, # Both too short
|
||||
{"description": "Valid term here"},
|
||||
],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert "ab" not in result
|
||||
assert "xyz" not in result
|
||||
assert "Valid term here" in result
|
||||
|
||||
def test_handles_empty_steps(self):
|
||||
"""Test handling of empty steps list."""
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
result = core.extract_search_terms_from_steps(decomposition_result)
|
||||
|
||||
assert result == []
|
||||
|
||||
|
||||
class TestEnrichLibraryAgentsFromSteps:
|
||||
"""Test enrich_library_agents_from_steps function."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_enriches_with_additional_agents(self):
|
||||
"""Test that additional agents are found based on steps."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "existing-123",
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
additional_agents = [
|
||||
{
|
||||
"graph_id": "new-456",
|
||||
"graph_version": 1,
|
||||
"name": "Email Agent",
|
||||
"description": "For sending emails",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "Send email notification"},
|
||||
],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=additional_agents,
|
||||
):
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should have both existing and new agents
|
||||
assert len(result) == 2
|
||||
names = [a["name"] for a in result]
|
||||
assert "Existing Agent" in names
|
||||
assert "Email Agent" in names
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicates_by_graph_id(self):
|
||||
"""Test that agents with same graph_id are not duplicated."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "agent-123",
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
# Additional search returns same agent
|
||||
additional_agents = [
|
||||
{
|
||||
"graph_id": "agent-123", # Same ID
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent Copy",
|
||||
"description": "Same agent different name",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [{"description": "Some action"}],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=additional_agents,
|
||||
):
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should not duplicate
|
||||
assert len(result) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deduplicates_by_name(self):
|
||||
"""Test that agents with same name are not duplicated."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "agent-123",
|
||||
"graph_version": 1,
|
||||
"name": "Email Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
# Additional search returns agent with same name but different ID
|
||||
additional_agents = [
|
||||
{
|
||||
"graph_id": "agent-456", # Different ID
|
||||
"graph_version": 1,
|
||||
"name": "Email Agent", # Same name
|
||||
"description": "Different agent same name",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [{"description": "Send email"}],
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
new_callable=AsyncMock,
|
||||
return_value=additional_agents,
|
||||
):
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should not duplicate by name
|
||||
assert len(result) == 1
|
||||
assert result[0].get("graph_id") == "agent-123" # Original kept
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_existing_when_no_steps(self):
|
||||
"""Test that existing agents are returned when no search terms extracted."""
|
||||
existing_agents = [
|
||||
{
|
||||
"graph_id": "existing-123",
|
||||
"graph_version": 1,
|
||||
"name": "Existing Agent",
|
||||
"description": "Already fetched",
|
||||
"input_schema": {},
|
||||
"output_schema": {},
|
||||
}
|
||||
]
|
||||
|
||||
decomposition_result = {
|
||||
"type": "clarifying_questions", # Not instructions type
|
||||
"questions": [],
|
||||
}
|
||||
|
||||
result = await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should return existing unchanged
|
||||
assert result == existing_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_limits_search_terms_to_three(self):
|
||||
"""Test that only first 3 search terms are used."""
|
||||
existing_agents = []
|
||||
|
||||
decomposition_result = {
|
||||
"type": "instructions",
|
||||
"steps": [
|
||||
{"description": "First action"},
|
||||
{"description": "Second action"},
|
||||
{"description": "Third action"},
|
||||
{"description": "Fourth action"},
|
||||
{"description": "Fifth action"},
|
||||
],
|
||||
}
|
||||
|
||||
call_count = 0
|
||||
|
||||
async def mock_get_agents(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return []
|
||||
|
||||
with patch.object(
|
||||
core,
|
||||
"get_all_relevant_agents_for_generation",
|
||||
side_effect=mock_get_agents,
|
||||
):
|
||||
await core.enrich_library_agents_from_steps(
|
||||
user_id="user-123",
|
||||
decomposition_result=decomposition_result,
|
||||
existing_agents=existing_agents,
|
||||
)
|
||||
|
||||
# Should only make 3 calls (limited to first 3 terms)
|
||||
assert call_count == 3
|
||||
|
||||
|
||||
class TestExtractUuidsFromText:
|
||||
"""Test extract_uuids_from_text function."""
|
||||
|
||||
def test_extracts_single_uuid(self):
|
||||
"""Test extraction of a single UUID from text."""
|
||||
text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert len(result) == 1
|
||||
assert "46631191-e8a8-486f-ad90-84f89738321d" in result
|
||||
|
||||
def test_extracts_multiple_uuids(self):
|
||||
"""Test extraction of multiple UUIDs from text."""
|
||||
text = (
|
||||
"Combine agents 11111111-1111-4111-8111-111111111111 "
|
||||
"and 22222222-2222-4222-9222-222222222222"
|
||||
)
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert len(result) == 2
|
||||
assert "11111111-1111-4111-8111-111111111111" in result
|
||||
assert "22222222-2222-4222-9222-222222222222" in result
|
||||
|
||||
def test_deduplicates_uuids(self):
|
||||
"""Test that duplicate UUIDs are deduplicated."""
|
||||
text = (
|
||||
"Use 46631191-e8a8-486f-ad90-84f89738321d twice: "
|
||||
"46631191-e8a8-486f-ad90-84f89738321d"
|
||||
)
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_normalizes_to_lowercase(self):
|
||||
"""Test that UUIDs are normalized to lowercase."""
|
||||
text = "Use 46631191-E8A8-486F-AD90-84F89738321D"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d"
|
||||
|
||||
def test_returns_empty_for_no_uuids(self):
|
||||
"""Test that empty list is returned when no UUIDs found."""
|
||||
text = "Create an email agent that sends notifications"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
assert result == []
|
||||
|
||||
def test_ignores_invalid_uuids(self):
|
||||
"""Test that invalid UUID-like strings are ignored."""
|
||||
text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc"
|
||||
result = core.extract_uuids_from_text(text)
|
||||
# UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth)
|
||||
assert len(result) == 0
|
||||
|
||||
|
||||
class TestGetLibraryAgentById:
|
||||
"""Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_agent_when_found_by_graph_id(self):
|
||||
"""Test that agent is returned when found by graph_id."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "agent-123"
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Test Agent"
|
||||
mock_agent.description = "Test description"
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
with patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is not None
|
||||
assert result["graph_id"] == "agent-123"
|
||||
assert result["name"] == "Test Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_to_library_agent_id(self):
|
||||
"""Test that lookup falls back to library agent ID when graph_id not found."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "graph-456" # Different from the lookup ID
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Library Agent"
|
||||
mock_agent.description = "Found by library ID"
|
||||
mock_agent.input_schema = {"properties": {}}
|
||||
mock_agent.output_schema = {"properties": {}}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None, # Not found by graph_id
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent, # Found by library ID
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "library-id-123")
|
||||
|
||||
assert result is not None
|
||||
assert result["graph_id"] == "graph-456"
|
||||
assert result["name"] == "Library Agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_when_not_found_by_either_method(self):
|
||||
"""Test that None is returned when agent not found by either method."""
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=None,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=core.NotFoundError("Not found"),
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "nonexistent")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_none_on_exception(self):
|
||||
"""Test that None is returned when exception occurs in both lookups."""
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Exception("Database error"),
|
||||
),
|
||||
):
|
||||
result = await core.get_library_agent_by_id("user-123", "agent-123")
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_alias_works(self):
|
||||
"""Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
|
||||
assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id
|
||||
|
||||
|
||||
class TestGetAllRelevantAgentsWithUuids:
|
||||
"""Test UUID extraction in get_all_relevant_agents_for_generation."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetches_explicitly_mentioned_agents(self):
|
||||
"""Test that agents mentioned by UUID are fetched directly."""
|
||||
mock_agent = MagicMock()
|
||||
mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d"
|
||||
mock_agent.graph_version = 1
|
||||
mock_agent.name = "Mentioned Agent"
|
||||
mock_agent.description = "Explicitly mentioned"
|
||||
mock_agent.input_schema = {}
|
||||
mock_agent.output_schema = {}
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.agents = []
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"get_library_agent_by_graph_id",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_agent,
|
||||
),
|
||||
patch.object(
|
||||
core.library_db,
|
||||
"list_library_agents",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_response,
|
||||
),
|
||||
):
|
||||
result = await core.get_all_relevant_agents_for_generation(
|
||||
user_id="user-123",
|
||||
search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
|
||||
include_marketplace=False,
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d"
|
||||
assert result[0].get("name") == "Mentioned Agent"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_with_context(self):
|
||||
"""Test decomposition with additional context enriched into description."""
|
||||
"""Test decomposition with additional context."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
@@ -119,12 +119,9 @@ class TestDecomposeGoalExternal:
|
||||
"Build a chatbot", context="Use Python"
|
||||
)
|
||||
|
||||
expected_description = (
|
||||
"Build a chatbot\n\nAdditional context from user:\nUse Python"
|
||||
)
|
||||
mock_client.post.assert_called_once_with(
|
||||
"/api/decompose-description",
|
||||
json={"description": expected_description},
|
||||
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -436,139 +433,5 @@ class TestGetBlocksExternal:
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestLibraryAgentsPassthrough:
|
||||
"""Test that library_agents are passed correctly in all requests."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset client singleton before each test."""
|
||||
service._settings = None
|
||||
service._client = None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_passes_library_agents(self):
|
||||
"""Test that library_agents are included in decompose goal payload."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "agent-123",
|
||||
"graph_version": 1,
|
||||
"name": "Email Sender",
|
||||
"description": "Sends emails",
|
||||
"input_schema": {"properties": {"to": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"sent": {"type": "boolean"}}},
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.decompose_goal_external(
|
||||
"Send an email",
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_passes_library_agents(self):
|
||||
"""Test that library_agents are included in generate agent payload."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "agent-456",
|
||||
"graph_version": 2,
|
||||
"name": "Data Fetcher",
|
||||
"description": "Fetches data from API",
|
||||
"input_schema": {"properties": {"url": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"data": {"type": "object"}}},
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Test Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.generate_agent_external(
|
||||
{"steps": ["Step 1"]},
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_agent_patch_passes_library_agents(self):
|
||||
"""Test that library_agents are included in patch generation payload."""
|
||||
library_agents = [
|
||||
{
|
||||
"graph_id": "agent-789",
|
||||
"graph_version": 1,
|
||||
"name": "Slack Notifier",
|
||||
"description": "Sends Slack messages",
|
||||
"input_schema": {"properties": {"message": {"type": "string"}}},
|
||||
"output_schema": {"properties": {"success": {"type": "boolean"}}},
|
||||
},
|
||||
]
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"agent_json": {"name": "Updated Agent", "nodes": []},
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.generate_agent_patch_external(
|
||||
"Add error handling",
|
||||
{"name": "Original Agent", "nodes": []},
|
||||
library_agents=library_agents,
|
||||
)
|
||||
|
||||
# Verify library_agents was passed in the payload
|
||||
call_args = mock_client.post.call_args
|
||||
assert call_args[1]["json"]["library_agents"] == library_agents
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decompose_goal_without_library_agents(self):
|
||||
"""Test that decompose goal works without library_agents."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"success": True,
|
||||
"type": "instructions",
|
||||
"steps": ["Step 1"],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.post.return_value = mock_response
|
||||
|
||||
with patch.object(service, "_get_client", return_value=mock_client):
|
||||
await service.decompose_goal_external("Build a workflow")
|
||||
|
||||
# Verify library_agents was NOT passed when not provided
|
||||
call_args = mock_client.post.call_args
|
||||
assert "library_agents" not in call_args[1]["json"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@@ -43,24 +43,19 @@ faker = Faker()
|
||||
# Constants for data generation limits (reduced for E2E tests)
|
||||
NUM_USERS = 15
|
||||
NUM_AGENT_BLOCKS = 30
|
||||
MIN_GRAPHS_PER_USER = 25
|
||||
MAX_GRAPHS_PER_USER = 25
|
||||
MIN_GRAPHS_PER_USER = 15
|
||||
MAX_GRAPHS_PER_USER = 15
|
||||
MIN_NODES_PER_GRAPH = 3
|
||||
MAX_NODES_PER_GRAPH = 6
|
||||
MIN_PRESETS_PER_USER = 2
|
||||
MAX_PRESETS_PER_USER = 3
|
||||
MIN_AGENTS_PER_USER = 25
|
||||
MAX_AGENTS_PER_USER = 25
|
||||
MIN_AGENTS_PER_USER = 15
|
||||
MAX_AGENTS_PER_USER = 15
|
||||
MIN_EXECUTIONS_PER_GRAPH = 2
|
||||
MAX_EXECUTIONS_PER_GRAPH = 8
|
||||
MIN_REVIEWS_PER_VERSION = 2
|
||||
MAX_REVIEWS_PER_VERSION = 5
|
||||
|
||||
# Guaranteed minimums for marketplace tests (deterministic)
|
||||
GUARANTEED_FEATURED_AGENTS = 8
|
||||
GUARANTEED_FEATURED_CREATORS = 5
|
||||
GUARANTEED_TOP_AGENTS = 10
|
||||
|
||||
|
||||
def get_image():
|
||||
"""Generate a consistent image URL using picsum.photos service."""
|
||||
@@ -390,7 +385,7 @@ class TestDataCreator:
|
||||
|
||||
library_agents = []
|
||||
for user in self.users:
|
||||
num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
|
||||
num_agents = 10 # Create exactly 10 agents per user
|
||||
|
||||
# Get available graphs for this user
|
||||
user_graphs = [
|
||||
@@ -512,17 +507,14 @@ class TestDataCreator:
|
||||
existing_profiles, min(num_creators, len(existing_profiles))
|
||||
)
|
||||
|
||||
# Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators
|
||||
num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5))
|
||||
# Mark about 50% of creators as featured (more for testing)
|
||||
num_featured = max(2, int(num_creators * 0.5))
|
||||
num_featured = min(
|
||||
num_featured, len(selected_profiles)
|
||||
) # Don't exceed available profiles
|
||||
featured_profile_ids = set(
|
||||
random.sample([p.id for p in selected_profiles], num_featured)
|
||||
)
|
||||
print(
|
||||
f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})"
|
||||
)
|
||||
|
||||
for profile in selected_profiles:
|
||||
try:
|
||||
@@ -553,25 +545,21 @@ class TestDataCreator:
|
||||
return profiles
|
||||
|
||||
async def create_test_store_submissions(self) -> List[Dict[str, Any]]:
|
||||
"""Create test store submissions using the API function.
|
||||
|
||||
DETERMINISTIC: Guarantees minimum featured agents for E2E tests.
|
||||
"""
|
||||
"""Create test store submissions using the API function."""
|
||||
print("Creating test store submissions...")
|
||||
|
||||
submissions = []
|
||||
approved_submissions = []
|
||||
featured_count = 0
|
||||
submission_counter = 0
|
||||
|
||||
# Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
|
||||
# Create a special test submission for test123@gmail.com
|
||||
test_user = next(
|
||||
(user for user in self.users if user["email"] == "test123@gmail.com"), None
|
||||
)
|
||||
if test_user and self.agent_graphs:
|
||||
if test_user:
|
||||
# Special test data for consistent testing
|
||||
test_submission_data = {
|
||||
"user_id": test_user["id"],
|
||||
"agent_id": self.agent_graphs[0]["id"],
|
||||
"agent_id": self.agent_graphs[0]["id"], # Use first available graph
|
||||
"agent_version": 1,
|
||||
"slug": "test-agent-submission",
|
||||
"name": "Test Agent Submission",
|
||||
@@ -592,24 +580,37 @@ class TestDataCreator:
|
||||
submissions.append(test_submission.model_dump())
|
||||
print("✅ Created special test store submission for test123@gmail.com")
|
||||
|
||||
# ALWAYS approve and feature the test submission
|
||||
# Randomly approve, reject, or leave pending the test submission
|
||||
if test_submission.store_listing_version_id:
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.store_listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Test submission approved",
|
||||
internal_comments="Auto-approved test submission",
|
||||
reviewer_id=test_user["id"],
|
||||
)
|
||||
approved_submissions.append(approved_submission.model_dump())
|
||||
print("✅ Approved test store submission")
|
||||
random_value = random.random()
|
||||
if random_value < 0.4: # 40% chance to approve
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=test_submission.store_listing_version_id,
|
||||
is_approved=True,
|
||||
external_comments="Test submission approved",
|
||||
internal_comments="Auto-approved test submission",
|
||||
reviewer_id=test_user["id"],
|
||||
)
|
||||
approved_submissions.append(approved_submission.model_dump())
|
||||
print("✅ Approved test store submission")
|
||||
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.store_listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
print("🌟 Marked test agent as FEATURED")
|
||||
# Mark approved submission as featured
|
||||
await prisma.storelistingversion.update(
|
||||
where={"id": test_submission.store_listing_version_id},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
print("🌟 Marked test agent as FEATURED")
|
||||
elif random_value < 0.7: # 30% chance to reject (40% to 70%)
|
||||
await review_store_submission(
|
||||
store_listing_version_id=test_submission.store_listing_version_id,
|
||||
is_approved=False,
|
||||
external_comments="Test submission rejected - needs improvements",
|
||||
internal_comments="Auto-rejected test submission for E2E testing",
|
||||
reviewer_id=test_user["id"],
|
||||
)
|
||||
print("❌ Rejected test store submission")
|
||||
else: # 30% chance to leave pending (70% to 100%)
|
||||
print("⏳ Left test submission pending for review")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error creating test store submission: {e}")
|
||||
@@ -619,6 +620,7 @@ class TestDataCreator:
|
||||
|
||||
# Create regular submissions for all users
|
||||
for user in self.users:
|
||||
# Get available graphs for this specific user
|
||||
user_graphs = [
|
||||
g for g in self.agent_graphs if g.get("userId") == user["id"]
|
||||
]
|
||||
@@ -629,17 +631,18 @@ class TestDataCreator:
|
||||
)
|
||||
continue
|
||||
|
||||
# Create exactly 4 store submissions per user
|
||||
for submission_index in range(4):
|
||||
graph = random.choice(user_graphs)
|
||||
submission_counter += 1
|
||||
|
||||
try:
|
||||
print(
|
||||
f"Creating store submission for user {user['id']} with graph {graph['id']}"
|
||||
f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})"
|
||||
)
|
||||
|
||||
# Use the API function to create store submission with correct parameters
|
||||
submission = await create_store_submission(
|
||||
user_id=user["id"],
|
||||
user_id=user["id"], # Must match graph's userId
|
||||
agent_id=graph["id"],
|
||||
agent_version=graph.get("version", 1),
|
||||
slug=faker.slug(),
|
||||
@@ -648,24 +651,22 @@ class TestDataCreator:
|
||||
video_url=get_video_url() if random.random() < 0.3 else None,
|
||||
image_urls=[get_image() for _ in range(3)],
|
||||
description=faker.text(),
|
||||
categories=[get_category()],
|
||||
categories=[
|
||||
get_category()
|
||||
], # Single category from predefined list
|
||||
changes_summary="Initial E2E test submission",
|
||||
)
|
||||
submissions.append(submission.model_dump())
|
||||
print(f"✅ Created store submission: {submission.name}")
|
||||
|
||||
# Randomly approve, reject, or leave pending the submission
|
||||
if submission.store_listing_version_id:
|
||||
# DETERMINISTIC: First N submissions are always approved
|
||||
# First GUARANTEED_FEATURED_AGENTS of those are always featured
|
||||
should_approve = (
|
||||
submission_counter <= GUARANTEED_TOP_AGENTS
|
||||
or random.random() < 0.4
|
||||
)
|
||||
should_feature = featured_count < GUARANTEED_FEATURED_AGENTS
|
||||
|
||||
if should_approve:
|
||||
random_value = random.random()
|
||||
if random_value < 0.4: # 40% chance to approve
|
||||
try:
|
||||
# Pick a random user as the reviewer (admin)
|
||||
reviewer_id = random.choice(self.users)["id"]
|
||||
|
||||
approved_submission = await review_store_submission(
|
||||
store_listing_version_id=submission.store_listing_version_id,
|
||||
is_approved=True,
|
||||
@@ -680,7 +681,16 @@ class TestDataCreator:
|
||||
f"✅ Approved store submission: {submission.name}"
|
||||
)
|
||||
|
||||
if should_feature:
|
||||
# Mark some agents as featured during creation (30% chance)
|
||||
# More likely for creators and first submissions
|
||||
is_creator = user["id"] in [
|
||||
p.get("userId") for p in self.profiles
|
||||
]
|
||||
feature_chance = (
|
||||
0.5 if is_creator else 0.2
|
||||
) # 50% for creators, 20% for others
|
||||
|
||||
if random.random() < feature_chance:
|
||||
try:
|
||||
await prisma.storelistingversion.update(
|
||||
where={
|
||||
@@ -688,25 +698,8 @@ class TestDataCreator:
|
||||
},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
print(
|
||||
f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Warning: Could not mark submission as featured: {e}"
|
||||
)
|
||||
elif random.random() < 0.2:
|
||||
try:
|
||||
await prisma.storelistingversion.update(
|
||||
where={
|
||||
"id": submission.store_listing_version_id
|
||||
},
|
||||
data={"isFeatured": True},
|
||||
)
|
||||
featured_count += 1
|
||||
print(
|
||||
f"🌟 Marked agent as FEATURED (bonus): {submission.name}"
|
||||
f"🌟 Marked agent as FEATURED: {submission.name}"
|
||||
)
|
||||
except Exception as e:
|
||||
print(
|
||||
@@ -717,9 +710,11 @@ class TestDataCreator:
|
||||
print(
|
||||
f"Warning: Could not approve submission {submission.name}: {e}"
|
||||
)
|
||||
elif random.random() < 0.5:
|
||||
elif random_value < 0.7: # 30% chance to reject (40% to 70%)
|
||||
try:
|
||||
# Pick a random user as the reviewer (admin)
|
||||
reviewer_id = random.choice(self.users)["id"]
|
||||
|
||||
await review_store_submission(
|
||||
store_listing_version_id=submission.store_listing_version_id,
|
||||
is_approved=False,
|
||||
@@ -734,7 +729,7 @@ class TestDataCreator:
|
||||
print(
|
||||
f"Warning: Could not reject submission {submission.name}: {e}"
|
||||
)
|
||||
else:
|
||||
else: # 30% chance to leave pending (70% to 100%)
|
||||
print(
|
||||
f"⏳ Left submission pending for review: {submission.name}"
|
||||
)
|
||||
@@ -748,13 +743,9 @@ class TestDataCreator:
|
||||
traceback.print_exc()
|
||||
continue
|
||||
|
||||
print("\n📊 Store Submissions Summary:")
|
||||
print(f" Created: {len(submissions)}")
|
||||
print(f" Approved: {len(approved_submissions)}")
|
||||
print(
|
||||
f" Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})"
|
||||
f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}"
|
||||
)
|
||||
|
||||
self.store_submissions = submissions
|
||||
return submissions
|
||||
|
||||
@@ -834,15 +825,12 @@ class TestDataCreator:
|
||||
print(f"✅ Agent blocks available: {len(self.agent_blocks)}")
|
||||
print(f"✅ Agent graphs created: {len(self.agent_graphs)}")
|
||||
print(f"✅ Library agents created: {len(self.library_agents)}")
|
||||
print(f"✅ Creator profiles updated: {len(self.profiles)}")
|
||||
print(f"✅ Store submissions created: {len(self.store_submissions)}")
|
||||
print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)")
|
||||
print(
|
||||
f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)"
|
||||
)
|
||||
print(f"✅ API keys created: {len(self.api_keys)}")
|
||||
print(f"✅ Presets created: {len(self.presets)}")
|
||||
print("\n🎯 Deterministic Guarantees:")
|
||||
print(f" • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}")
|
||||
print(f" • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}")
|
||||
print(f" • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}")
|
||||
print(f" • Library agents per user: >= {MIN_AGENTS_PER_USER}")
|
||||
print("\n🚀 Your E2E test database is ready to use!")
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"use client";
|
||||
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||
import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
|
||||
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
|
||||
export default function OnboardingPage() {
|
||||
const router = useRouter();
|
||||
@@ -12,10 +13,12 @@ export default function OnboardingPage() {
|
||||
async function redirectToStep() {
|
||||
try {
|
||||
// Check if onboarding is enabled (also gets chat flag for redirect)
|
||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
||||
const { shouldShowOnboarding, isChatEnabled } =
|
||||
await getOnboardingStatus();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
if (!shouldShowOnboarding) {
|
||||
router.replace("/");
|
||||
router.replace(homepageRoute);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -23,7 +26,7 @@ export default function OnboardingPage() {
|
||||
|
||||
// Handle completed onboarding
|
||||
if (onboarding.completedSteps.includes("GET_RESULTS")) {
|
||||
router.replace("/");
|
||||
router.replace(homepageRoute);
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import { getOnboardingStatus } from "@/app/api/helpers";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { NextResponse } from "next/server";
|
||||
import { revalidatePath } from "next/cache";
|
||||
import { getOnboardingStatus } from "@/app/api/helpers";
|
||||
|
||||
// Handle the callback to complete the user session login
|
||||
export async function GET(request: Request) {
|
||||
@@ -26,12 +27,13 @@ export async function GET(request: Request) {
|
||||
await api.createUser();
|
||||
|
||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
||||
const { shouldShowOnboarding, isChatEnabled } =
|
||||
await getOnboardingStatus();
|
||||
if (shouldShowOnboarding) {
|
||||
next = "/onboarding";
|
||||
revalidatePath("/onboarding", "layout");
|
||||
} else {
|
||||
next = "/";
|
||||
next = getHomepageRoute(isChatEnabled);
|
||||
revalidatePath(next, "layout");
|
||||
}
|
||||
} catch (createUserError) {
|
||||
|
||||
@@ -1,17 +1,6 @@
|
||||
import { OAuthPopupResultMessage } from "./types";
|
||||
import { NextResponse } from "next/server";
|
||||
|
||||
/**
|
||||
* Safely encode a value as JSON for embedding in a script tag.
|
||||
* Escapes characters that could break out of the script context to prevent XSS.
|
||||
*/
|
||||
function safeJsonStringify(value: unknown): string {
|
||||
return JSON.stringify(value)
|
||||
.replace(/</g, "\\u003c")
|
||||
.replace(/>/g, "\\u003e")
|
||||
.replace(/&/g, "\\u0026");
|
||||
}
|
||||
|
||||
// This route is intended to be used as the callback for integration OAuth flows,
|
||||
// controlled by the CredentialsInput component. The CredentialsInput opens the login
|
||||
// page in a pop-up window, which then redirects to this route to close the loop.
|
||||
@@ -34,13 +23,12 @@ export async function GET(request: Request) {
|
||||
console.debug("Sending message to opener:", message);
|
||||
|
||||
// Return a response with the message as JSON and a script to close the window
|
||||
// Use safeJsonStringify to prevent XSS by escaping <, >, and & characters
|
||||
return new NextResponse(
|
||||
`
|
||||
<html>
|
||||
<body>
|
||||
<script>
|
||||
window.opener.postMessage(${safeJsonStringify(message)});
|
||||
window.opener.postMessage(${JSON.stringify(message)});
|
||||
window.close();
|
||||
</script>
|
||||
</body>
|
||||
|
||||
@@ -857,7 +857,7 @@ export const CustomNode = React.memo(
|
||||
})();
|
||||
|
||||
const hasAdvancedFields =
|
||||
data.inputSchema?.properties &&
|
||||
data.inputSchema &&
|
||||
Object.entries(data.inputSchema.properties).some(([key, value]) => {
|
||||
return (
|
||||
value.advanced === true && !data.inputSchema.required?.includes(key)
|
||||
|
||||
@@ -11,6 +11,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { usePathname, useSearchParams } from "next/navigation";
|
||||
import { useRef } from "react";
|
||||
import { useCopilotStore } from "../../copilot-page-store";
|
||||
import { useCopilotSessionId } from "../../useCopilotSessionId";
|
||||
import { useMobileDrawer } from "./components/MobileDrawer/useMobileDrawer";
|
||||
@@ -69,16 +70,41 @@ export function useCopilotShell() {
|
||||
});
|
||||
|
||||
const stopStream = useChatStore((s) => s.stopStream);
|
||||
const onStreamComplete = useChatStore((s) => s.onStreamComplete);
|
||||
const isStreaming = useCopilotStore((s) => s.isStreaming);
|
||||
const isCreatingSession = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsSwitchingSession = useCopilotStore((s) => s.setIsSwitchingSession);
|
||||
const openInterruptModal = useCopilotStore((s) => s.openInterruptModal);
|
||||
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
const pendingActionRef = useRef<(() => void) | null>(null);
|
||||
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
async function stopCurrentStream() {
|
||||
if (!currentSessionId) return;
|
||||
|
||||
setIsSwitchingSession(true);
|
||||
await new Promise<void>((resolve) => {
|
||||
const unsubscribe = onStreamComplete((completedId) => {
|
||||
if (completedId === currentSessionId) {
|
||||
clearTimeout(timeout);
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}
|
||||
});
|
||||
const timeout = setTimeout(() => {
|
||||
unsubscribe();
|
||||
resolve();
|
||||
}, 3000);
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
});
|
||||
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(currentSessionId),
|
||||
});
|
||||
setIsSwitchingSession(false);
|
||||
}
|
||||
|
||||
function selectSession(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
if (recentlyCreatedSessionsRef.current.has(sessionId)) {
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2GetSessionQueryKey(sessionId),
|
||||
@@ -88,12 +114,7 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
// Stop current stream - SSE reconnection allows resuming later
|
||||
if (currentSessionId) {
|
||||
stopStream(currentSessionId);
|
||||
}
|
||||
|
||||
function startNewChat() {
|
||||
resetPagination();
|
||||
queryClient.invalidateQueries({
|
||||
queryKey: getGetV2ListSessionsQueryKey(),
|
||||
@@ -102,6 +123,32 @@ export function useCopilotShell() {
|
||||
if (isMobile) handleCloseDrawer();
|
||||
}
|
||||
|
||||
function handleSessionClick(sessionId: string) {
|
||||
if (sessionId === currentSessionId) return;
|
||||
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
selectSession(sessionId);
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
selectSession(sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
function handleNewChatClick() {
|
||||
if (isStreaming) {
|
||||
pendingActionRef.current = async () => {
|
||||
await stopCurrentStream();
|
||||
startNewChat();
|
||||
};
|
||||
openInterruptModal(pendingActionRef.current);
|
||||
} else {
|
||||
startNewChat();
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isMobile,
|
||||
isDrawerOpen,
|
||||
|
||||
@@ -26,20 +26,8 @@ export function buildCopilotChatUrl(prompt: string): string {
|
||||
|
||||
export function getQuickActions(): string[] {
|
||||
return [
|
||||
"I don't know where to start, just ask me stuff",
|
||||
"I do the same thing every week and it's killing me",
|
||||
"Help me find where I'm wasting my time",
|
||||
"Show me what I can automate",
|
||||
"Design a custom workflow",
|
||||
"Help me with content creation",
|
||||
];
|
||||
}
|
||||
|
||||
export function getInputPlaceholder(width?: number) {
|
||||
if (!width) return "What's your role and what eats up most of your day?";
|
||||
|
||||
if (width < 500) {
|
||||
return "I'm a chef and I hate...";
|
||||
}
|
||||
if (width <= 1080) {
|
||||
return "What's your role and what eats up most of your day?";
|
||||
}
|
||||
return "What's your role and what eats up most of your day? e.g. 'I'm a recruiter and I hate...'";
|
||||
}
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
"use client";
|
||||
import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
|
||||
import { Flag } from "@/services/feature-flags/use-get-flag";
|
||||
import { type ReactNode } from "react";
|
||||
import type { ReactNode } from "react";
|
||||
import { CopilotShell } from "./components/CopilotShell/CopilotShell";
|
||||
|
||||
export default function CopilotLayout({ children }: { children: ReactNode }) {
|
||||
return (
|
||||
<FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
|
||||
<CopilotShell>{children}</CopilotShell>
|
||||
</FeatureFlagPage>
|
||||
);
|
||||
return <CopilotShell>{children}</CopilotShell>;
|
||||
}
|
||||
|
||||
@@ -6,9 +6,7 @@ import { Text } from "@/components/atoms/Text/Text";
|
||||
import { Chat } from "@/components/contextual/Chat/Chat";
|
||||
import { ChatInput } from "@/components/contextual/Chat/components/ChatInput/ChatInput";
|
||||
import { Dialog } from "@/components/molecules/Dialog/Dialog";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
import { getInputPlaceholder } from "./helpers";
|
||||
import { useCopilotPage } from "./useCopilotPage";
|
||||
|
||||
export default function CopilotPage() {
|
||||
@@ -16,25 +14,14 @@ export default function CopilotPage() {
|
||||
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
|
||||
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
|
||||
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
|
||||
|
||||
const [inputPlaceholder, setInputPlaceholder] = useState(
|
||||
getInputPlaceholder(),
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const handleResize = () => {
|
||||
setInputPlaceholder(getInputPlaceholder(window.innerWidth));
|
||||
};
|
||||
|
||||
handleResize();
|
||||
|
||||
window.addEventListener("resize", handleResize);
|
||||
return () => window.removeEventListener("resize", handleResize);
|
||||
}, []);
|
||||
|
||||
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
|
||||
state;
|
||||
|
||||
const {
|
||||
greetingName,
|
||||
quickActions,
|
||||
isLoading,
|
||||
hasSession,
|
||||
initialPrompt,
|
||||
isReady,
|
||||
} = state;
|
||||
const {
|
||||
handleQuickAction,
|
||||
startChatWithPrompt,
|
||||
@@ -42,6 +29,8 @@ export default function CopilotPage() {
|
||||
handleStreamingChange,
|
||||
} = handlers;
|
||||
|
||||
if (!isReady) return null;
|
||||
|
||||
if (hasSession) {
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
@@ -92,7 +81,7 @@ export default function CopilotPage() {
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-3 py-5 md:px-6 md:py-10">
|
||||
<div className="flex h-full flex-1 items-center justify-center overflow-y-auto bg-[#f8f8f9] px-6 py-10">
|
||||
<div className="w-full text-center">
|
||||
{isLoading ? (
|
||||
<div className="mx-auto max-w-2xl">
|
||||
@@ -109,25 +98,25 @@ export default function CopilotPage() {
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="mx-auto max-w-3xl">
|
||||
<div className="mx-auto max-w-2xl">
|
||||
<Text
|
||||
variant="h3"
|
||||
className="mb-1 !text-[1.375rem] text-zinc-700"
|
||||
className="mb-3 !text-[1.375rem] text-zinc-700"
|
||||
>
|
||||
Hey, <span className="text-violet-600">{greetingName}</span>
|
||||
</Text>
|
||||
<Text variant="h3" className="mb-8 !font-normal">
|
||||
Tell me about your work — I'll find what to automate.
|
||||
What do you want to automate?
|
||||
</Text>
|
||||
|
||||
<div className="mb-6">
|
||||
<ChatInput
|
||||
onSend={startChatWithPrompt}
|
||||
placeholder={inputPlaceholder}
|
||||
placeholder='You can search or just ask - e.g. "create a blog post outline"'
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-wrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
<div className="flex flex-nowrap items-center justify-center gap-3 overflow-x-auto [-ms-overflow-style:none] [scrollbar-width:none] [&::-webkit-scrollbar]:hidden">
|
||||
{quickActions.map((action) => (
|
||||
<Button
|
||||
key={action}
|
||||
@@ -135,7 +124,7 @@ export default function CopilotPage() {
|
||||
variant="outline"
|
||||
size="small"
|
||||
onClick={() => handleQuickAction(action)}
|
||||
className="h-auto shrink-0 border-zinc-300 px-3 py-2 text-[.9rem] text-zinc-600"
|
||||
className="h-auto shrink-0 border-zinc-600 !px-4 !py-2 text-[1rem] text-zinc-600"
|
||||
>
|
||||
{action}
|
||||
</Button>
|
||||
|
||||
@@ -3,11 +3,18 @@ import {
|
||||
postV2CreateSession,
|
||||
} from "@/app/api/__generated__/endpoints/chat/chat";
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
|
||||
import {
|
||||
Flag,
|
||||
type FlagValues,
|
||||
useGetFlag,
|
||||
} from "@/services/feature-flags/use-get-flag";
|
||||
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { useFlags } from "launchdarkly-react-client-sdk";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
import { useCopilotStore } from "./copilot-page-store";
|
||||
@@ -26,6 +33,22 @@ export function useCopilotPage() {
|
||||
const isCreating = useCopilotStore((s) => s.isCreatingSession);
|
||||
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);
|
||||
|
||||
// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
|
||||
useEffect(() => {
|
||||
if (isLoggedIn) {
|
||||
completeStep("VISIT_COPILOT");
|
||||
}
|
||||
}, [completeStep, isLoggedIn]);
|
||||
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const flags = useFlags<FlagValues>();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||
const isFlagReady =
|
||||
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;
|
||||
|
||||
const greetingName = getGreetingName(user);
|
||||
const quickActions = getQuickActions();
|
||||
|
||||
@@ -35,8 +58,11 @@ export function useCopilotPage() {
|
||||
: undefined;
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoggedIn) completeStep("VISIT_COPILOT");
|
||||
}, [completeStep, isLoggedIn]);
|
||||
if (!isFlagReady) return;
|
||||
if (isChatEnabled === false) {
|
||||
router.replace(homepageRoute);
|
||||
}
|
||||
}, [homepageRoute, isChatEnabled, isFlagReady, router]);
|
||||
|
||||
async function startChatWithPrompt(prompt: string) {
|
||||
if (!prompt?.trim()) return;
|
||||
@@ -90,6 +116,7 @@ export function useCopilotPage() {
|
||||
isLoading: isUserLoading,
|
||||
hasSession,
|
||||
initialPrompt,
|
||||
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
|
||||
},
|
||||
handlers: {
|
||||
handleQuickAction,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { Suspense } from "react";
|
||||
import { getErrorDetails } from "./helpers";
|
||||
@@ -9,6 +11,8 @@ function ErrorPageContent() {
|
||||
const searchParams = useSearchParams();
|
||||
const errorMessage = searchParams.get("message");
|
||||
const errorDetails = getErrorDetails(errorMessage);
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
function handleRetry() {
|
||||
// Auth-related errors should redirect to login
|
||||
@@ -26,7 +30,7 @@ function ErrorPageContent() {
|
||||
}, 2000);
|
||||
} else {
|
||||
// For server/network errors, go to home
|
||||
window.location.href = "/";
|
||||
window.location.href = homepageRoute;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use server";
|
||||
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import BackendAPI from "@/lib/autogpt-server-api";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { loginFormSchema } from "@/types/auth";
|
||||
@@ -37,8 +38,10 @@ export async function login(email: string, password: string) {
|
||||
await api.createUser();
|
||||
|
||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
||||
const next = shouldShowOnboarding ? "/onboarding" : "/";
|
||||
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||
const next = shouldShowOnboarding
|
||||
? "/onboarding"
|
||||
: getHomepageRoute(isChatEnabled);
|
||||
|
||||
return {
|
||||
success: true,
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { environment } from "@/services/environment";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { loginFormSchema, LoginProvider } from "@/types/auth";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
@@ -20,15 +22,17 @@ export function useLoginPage() {
|
||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||
const isCloudEnv = environment.isCloud();
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
// Get redirect destination from 'next' query parameter
|
||||
const nextUrl = searchParams.get("next");
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoggedIn && !isLoggingIn) {
|
||||
router.push(nextUrl || "/");
|
||||
router.push(nextUrl || homepageRoute);
|
||||
}
|
||||
}, [isLoggedIn, isLoggingIn, nextUrl, router]);
|
||||
}, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);
|
||||
|
||||
const form = useForm<z.infer<typeof loginFormSchema>>({
|
||||
resolver: zodResolver(loginFormSchema),
|
||||
@@ -94,7 +98,7 @@ export function useLoginPage() {
|
||||
}
|
||||
|
||||
// Prefer URL's next parameter, then use backend-determined route
|
||||
router.replace(nextUrl || result.next || "/");
|
||||
router.replace(nextUrl || result.next || homepageRoute);
|
||||
} catch (error) {
|
||||
toast({
|
||||
title:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"use server";
|
||||
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
|
||||
import { signupFormSchema } from "@/types/auth";
|
||||
import * as Sentry from "@sentry/nextjs";
|
||||
@@ -58,8 +59,10 @@ export async function signup(
|
||||
}
|
||||
|
||||
// Get onboarding status from backend (includes chat flag evaluated for this user)
|
||||
const { shouldShowOnboarding } = await getOnboardingStatus();
|
||||
const next = shouldShowOnboarding ? "/onboarding" : "/";
|
||||
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
|
||||
const next = shouldShowOnboarding
|
||||
? "/onboarding"
|
||||
: getHomepageRoute(isChatEnabled);
|
||||
|
||||
return { success: true, next };
|
||||
} catch (err) {
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import { useToast } from "@/components/molecules/Toast/use-toast";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { environment } from "@/services/environment";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { LoginProvider, signupFormSchema } from "@/types/auth";
|
||||
import { zodResolver } from "@hookform/resolvers/zod";
|
||||
import { useRouter, useSearchParams } from "next/navigation";
|
||||
@@ -20,15 +22,17 @@ export function useSignupPage() {
|
||||
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
|
||||
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
|
||||
const isCloudEnv = environment.isCloud();
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
|
||||
// Get redirect destination from 'next' query parameter
|
||||
const nextUrl = searchParams.get("next");
|
||||
|
||||
useEffect(() => {
|
||||
if (isLoggedIn && !isSigningUp) {
|
||||
router.push(nextUrl || "/");
|
||||
router.push(nextUrl || homepageRoute);
|
||||
}
|
||||
}, [isLoggedIn, isSigningUp, nextUrl, router]);
|
||||
}, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);
|
||||
|
||||
const form = useForm<z.infer<typeof signupFormSchema>>({
|
||||
resolver: zodResolver(signupFormSchema),
|
||||
@@ -129,7 +133,7 @@ export function useSignupPage() {
|
||||
}
|
||||
|
||||
// Prefer the URL's next parameter, then result.next (for onboarding), then default
|
||||
const redirectTo = nextUrl || result.next || "/";
|
||||
const redirectTo = nextUrl || result.next || homepageRoute;
|
||||
router.replace(redirectTo);
|
||||
} catch (error) {
|
||||
setIsLoading(false);
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import { environment } from "@/services/environment";
|
||||
import { getServerAuthToken } from "@/lib/autogpt-server-api/helpers";
|
||||
import { NextRequest } from "next/server";
|
||||
|
||||
/**
|
||||
* SSE Proxy for task stream reconnection.
|
||||
*
|
||||
* This endpoint allows clients to reconnect to an ongoing or recently completed
|
||||
* background task's stream. It replays missed messages from Redis Streams and
|
||||
* subscribes to live updates if the task is still running.
|
||||
*
|
||||
* Client contract:
|
||||
* 1. When receiving an operation_started event, store the task_id
|
||||
* 2. To reconnect: GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
* 3. Messages are replayed from the last_message_id position
|
||||
* 4. Stream ends when "finish" event is received
|
||||
*/
|
||||
export async function GET(
|
||||
request: NextRequest,
|
||||
{ params }: { params: Promise<{ taskId: string }> },
|
||||
) {
|
||||
const { taskId } = await params;
|
||||
const searchParams = request.nextUrl.searchParams;
|
||||
const lastMessageId = searchParams.get("last_message_id") || "0-0";
|
||||
|
||||
try {
|
||||
// Get auth token from server-side session
|
||||
const token = await getServerAuthToken();
|
||||
|
||||
// Build backend URL
|
||||
const backendUrl = environment.getAGPTServerBaseUrl();
|
||||
const streamUrl = new URL(`/api/chat/tasks/${taskId}/stream`, backendUrl);
|
||||
streamUrl.searchParams.set("last_message_id", lastMessageId);
|
||||
|
||||
// Forward request to backend with auth header
|
||||
const headers: Record<string, string> = {
|
||||
Accept: "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
Connection: "keep-alive",
|
||||
};
|
||||
|
||||
if (token) {
|
||||
headers["Authorization"] = `Bearer ${token}`;
|
||||
}
|
||||
|
||||
const response = await fetch(streamUrl.toString(), {
|
||||
method: "GET",
|
||||
headers,
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const error = await response.text();
|
||||
return new Response(error, {
|
||||
status: response.status,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
// Return the SSE stream directly
|
||||
return new Response(response.body, {
|
||||
headers: {
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache, no-transform",
|
||||
Connection: "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Task stream proxy error:", error);
|
||||
return new Response(
|
||||
JSON.stringify({
|
||||
error: "Failed to connect to task stream",
|
||||
detail: error instanceof Error ? error.message : String(error),
|
||||
}),
|
||||
{
|
||||
status: 500,
|
||||
headers: { "Content-Type": "application/json" },
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -181,5 +181,6 @@ export async function getOnboardingStatus() {
|
||||
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
|
||||
return {
|
||||
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
|
||||
isChatEnabled: status.is_chat_enabled,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -917,28 +917,6 @@
|
||||
"security": [{ "HTTPBearerJWT": [] }]
|
||||
}
|
||||
},
|
||||
"/api/chat/config/ttl": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Ttl Config",
|
||||
"description": "Get the stream TTL configuration.\n\nReturns the Time-To-Live settings for chat streams, which determines\nhow long clients can reconnect to an active stream.\n\nReturns:\n dict: TTL configuration with seconds and milliseconds values.",
|
||||
"operationId": "getV2GetTtlConfig",
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"additionalProperties": true,
|
||||
"type": "object",
|
||||
"title": "Response Getv2Getttlconfig"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/health": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -961,63 +939,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/operations/{operation_id}/complete": {
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Complete Operation",
|
||||
"description": "External completion webhook for long-running operations.\n\nCalled by Agent Generator (or other services) when an operation completes.\nThis triggers the stream registry to publish completion and continue LLM generation.\n\nArgs:\n operation_id: The operation ID to complete.\n request: Completion payload with success status and result/error.\n x_api_key: Internal API key for authentication.\n\nReturns:\n dict: Status of the completion.\n\nRaises:\n HTTPException: If API key is invalid or operation not found.",
|
||||
"operationId": "postV2CompleteOperation",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "operation_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Operation Id" }
|
||||
},
|
||||
{
|
||||
"name": "x-api-key",
|
||||
"in": "header",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "X-Api-Key"
|
||||
}
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": true,
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/OperationCompleteRequest"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Postv2Completeoperation"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/sessions": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
@@ -1101,7 +1022,7 @@
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Session",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\nIf there's an active stream for this session, returns the task_id for reconnection.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, including active_stream info if applicable.",
|
||||
"description": "Retrieve the details of a specific chat session.\n\nLooks up a chat session by ID for the given user (if authenticated) and returns all session data including messages.\n\nArgs:\n session_id: The unique identifier for the desired chat session.\n user_id: The optional authenticated user ID, or None for anonymous access.\n\nReturns:\n SessionDetailResponse: Details for the requested session, or None if not found.",
|
||||
"operationId": "getV2GetSession",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1236,7 +1157,7 @@
|
||||
"post": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Chat Post",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nThe AI generation runs in a background task that continues even if the client disconnects.\nAll chunks are written to Redis for reconnection support. If the client disconnects,\nthey can reconnect using GET /tasks/{task_id}/stream to resume from where they left off.\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks. First chunk is a \"start\" event\n containing the task_id for reconnection.",
|
||||
"description": "Stream chat responses for a session (POST with context support).\n\nStreams the AI/completion responses in real time over Server-Sent Events (SSE), including:\n - Text fragments as they are generated\n - Tool call UI elements (if invoked)\n - Tool execution results\n\nArgs:\n session_id: The chat session identifier to associate with the streamed messages.\n request: Request body containing message, is_user_message, and optional context.\n user_id: Optional authenticated user ID.\nReturns:\n StreamingResponse: SSE-formatted response chunks.",
|
||||
"operationId": "postV2StreamChatPost",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
@@ -1274,94 +1195,6 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Get Task Status",
|
||||
"description": "Get the status of a long-running task.\n\nArgs:\n task_id: The task ID to check.\n user_id: Authenticated user ID for ownership validation.\n\nReturns:\n dict: Task status including task_id, status, tool_name, and operation_id.\n\nRaises:\n NotFoundError: If task_id is not found or user doesn't have access.",
|
||||
"operationId": "getV2GetTaskStatus",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Response Getv2Gettaskstatus"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/chat/tasks/{task_id}/stream": {
|
||||
"get": {
|
||||
"tags": ["v2", "chat", "chat"],
|
||||
"summary": "Stream Task",
|
||||
"description": "Reconnect to a long-running task's SSE stream.\n\nWhen a long-running operation (like agent generation) starts, the client\nreceives a task_id. If the connection drops, the client can reconnect\nusing this endpoint to resume receiving updates.\n\nArgs:\n task_id: The task ID from the operation_started response.\n user_id: Authenticated user ID for ownership validation.\n last_message_id: Last Redis Stream message ID received (\"0-0\" for full replay).\n\nReturns:\n StreamingResponse: SSE-formatted response chunks starting after last_message_id.\n\nRaises:\n HTTPException: 404 if task not found, 410 if task expired, 403 if access denied.",
|
||||
"operationId": "getV2StreamTask",
|
||||
"security": [{ "HTTPBearerJWT": [] }],
|
||||
"parameters": [
|
||||
{
|
||||
"name": "task_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": { "type": "string", "title": "Task Id" }
|
||||
},
|
||||
{
|
||||
"name": "last_message_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay.",
|
||||
"default": "0-0",
|
||||
"title": "Last Message Id"
|
||||
},
|
||||
"description": "Last Redis Stream message ID received (e.g., '1706540123456-0'). Use '0-0' for full replay."
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": { "application/json": { "schema": {} } }
|
||||
},
|
||||
"401": {
|
||||
"$ref": "#/components/responses/HTTP401NotAuthenticatedError"
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": { "$ref": "#/components/schemas/HTTPValidationError" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/api/credits": {
|
||||
"get": {
|
||||
"tags": ["v1", "credits"],
|
||||
@@ -6335,18 +6168,6 @@
|
||||
"title": "AccuracyTrendsResponse",
|
||||
"description": "Response model for accuracy trends and alerts."
|
||||
},
|
||||
"ActiveStreamInfo": {
|
||||
"properties": {
|
||||
"task_id": { "type": "string", "title": "Task Id" },
|
||||
"last_message_id": { "type": "string", "title": "Last Message Id" },
|
||||
"operation_id": { "type": "string", "title": "Operation Id" },
|
||||
"tool_name": { "type": "string", "title": "Tool Name" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["task_id", "last_message_id", "operation_id", "tool_name"],
|
||||
"title": "ActiveStreamInfo",
|
||||
"description": "Information about an active stream for reconnection."
|
||||
},
|
||||
"AddUserCreditsResponse": {
|
||||
"properties": {
|
||||
"new_balance": { "type": "integer", "title": "New Balance" },
|
||||
@@ -8160,25 +7981,6 @@
|
||||
]
|
||||
},
|
||||
"new_output": { "type": "boolean", "title": "New Output" },
|
||||
"execution_count": {
|
||||
"type": "integer",
|
||||
"title": "Execution Count",
|
||||
"default": 0
|
||||
},
|
||||
"success_rate": {
|
||||
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||
"title": "Success Rate"
|
||||
},
|
||||
"avg_correctness_score": {
|
||||
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||
"title": "Avg Correctness Score"
|
||||
},
|
||||
"recent_executions": {
|
||||
"items": { "$ref": "#/components/schemas/RecentExecution" },
|
||||
"type": "array",
|
||||
"title": "Recent Executions",
|
||||
"description": "List of recent executions with status, score, and summary"
|
||||
},
|
||||
"can_access_graph": {
|
||||
"type": "boolean",
|
||||
"title": "Can Access Graph"
|
||||
@@ -9002,27 +8804,6 @@
|
||||
],
|
||||
"title": "OnboardingStep"
|
||||
},
|
||||
"OperationCompleteRequest": {
|
||||
"properties": {
|
||||
"success": { "type": "boolean", "title": "Success" },
|
||||
"result": {
|
||||
"anyOf": [
|
||||
{ "additionalProperties": true, "type": "object" },
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
],
|
||||
"title": "Result"
|
||||
},
|
||||
"error": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Error"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["success"],
|
||||
"title": "OperationCompleteRequest",
|
||||
"description": "Request model for external completion webhook."
|
||||
},
|
||||
"Pagination": {
|
||||
"properties": {
|
||||
"total_items": {
|
||||
@@ -9593,23 +9374,6 @@
|
||||
"required": ["providers", "pagination"],
|
||||
"title": "ProviderResponse"
|
||||
},
|
||||
"RecentExecution": {
|
||||
"properties": {
|
||||
"status": { "type": "string", "title": "Status" },
|
||||
"correctness_score": {
|
||||
"anyOf": [{ "type": "number" }, { "type": "null" }],
|
||||
"title": "Correctness Score"
|
||||
},
|
||||
"activity_summary": {
|
||||
"anyOf": [{ "type": "string" }, { "type": "null" }],
|
||||
"title": "Activity Summary"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"required": ["status"],
|
||||
"title": "RecentExecution",
|
||||
"description": "Summary of a recent execution for quality assessment.\n\nUsed by the LLM to understand the agent's recent performance with specific examples\nrather than just aggregate statistics."
|
||||
},
|
||||
"RefundRequest": {
|
||||
"properties": {
|
||||
"id": { "type": "string", "title": "Id" },
|
||||
@@ -9878,12 +9642,6 @@
|
||||
"items": { "additionalProperties": true, "type": "object" },
|
||||
"type": "array",
|
||||
"title": "Messages"
|
||||
},
|
||||
"active_stream": {
|
||||
"anyOf": [
|
||||
{ "$ref": "#/components/schemas/ActiveStreamInfo" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
@@ -10039,8 +9797,7 @@
|
||||
"sub_heading": { "type": "string", "title": "Sub Heading" },
|
||||
"description": { "type": "string", "title": "Description" },
|
||||
"runs": { "type": "integer", "title": "Runs" },
|
||||
"rating": { "type": "number", "title": "Rating" },
|
||||
"agent_graph_id": { "type": "string", "title": "Agent Graph Id" }
|
||||
"rating": { "type": "number", "title": "Rating" }
|
||||
},
|
||||
"type": "object",
|
||||
"required": [
|
||||
@@ -10052,8 +9809,7 @@
|
||||
"sub_heading",
|
||||
"description",
|
||||
"runs",
|
||||
"rating",
|
||||
"agent_graph_id"
|
||||
"rating"
|
||||
],
|
||||
"title": "StoreAgent"
|
||||
},
|
||||
|
||||
@@ -1,15 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { getHomepageRoute } from "@/lib/constants";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useEffect } from "react";
|
||||
|
||||
export default function Page() {
|
||||
const isChatEnabled = useGetFlag(Flag.CHAT);
|
||||
const router = useRouter();
|
||||
const homepageRoute = getHomepageRoute(isChatEnabled);
|
||||
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
|
||||
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
|
||||
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
|
||||
const isFlagReady =
|
||||
!isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";
|
||||
|
||||
useEffect(() => {
|
||||
router.replace("/copilot");
|
||||
}, [router]);
|
||||
useEffect(
|
||||
function redirectToHomepage() {
|
||||
if (!isFlagReady) return;
|
||||
router.replace(homepageRoute);
|
||||
},
|
||||
[homepageRoute, isFlagReady, router],
|
||||
);
|
||||
|
||||
return <LoadingSpinner size="large" cover />;
|
||||
return null;
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useCopilotSessionId } from "@/app/(platform)/copilot/useCopilotSessionId";
|
||||
import { useCopilotStore } from "@/app/(platform)/copilot/copilot-page-store";
|
||||
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
|
||||
import { Text } from "@/components/atoms/Text/Text";
|
||||
import { cn } from "@/lib/utils";
|
||||
@@ -24,8 +25,8 @@ export function Chat({
|
||||
}: ChatProps) {
|
||||
const { urlSessionId } = useCopilotSessionId();
|
||||
const hasHandledNotFoundRef = useRef(false);
|
||||
const isSwitchingSession = useCopilotStore((s) => s.isSwitchingSession);
|
||||
const {
|
||||
session,
|
||||
messages,
|
||||
isLoading,
|
||||
isCreating,
|
||||
@@ -37,18 +38,6 @@ export function Chat({
|
||||
startPollingForOperation,
|
||||
} = useChat({ urlSessionId });
|
||||
|
||||
// Extract active stream info for reconnection
|
||||
const activeStream = (
|
||||
session as {
|
||||
active_stream?: {
|
||||
task_id: string;
|
||||
last_message_id: string;
|
||||
operation_id: string;
|
||||
tool_name: string;
|
||||
};
|
||||
}
|
||||
)?.active_stream;
|
||||
|
||||
useEffect(() => {
|
||||
if (!onSessionNotFound) return;
|
||||
if (!urlSessionId) return;
|
||||
@@ -64,7 +53,8 @@ export function Chat({
|
||||
isCreating,
|
||||
]);
|
||||
|
||||
const shouldShowLoader = showLoader && (isLoading || isCreating);
|
||||
const shouldShowLoader =
|
||||
(showLoader && (isLoading || isCreating)) || isSwitchingSession;
|
||||
|
||||
return (
|
||||
<div className={cn("flex h-full flex-col", className)}>
|
||||
@@ -76,19 +66,21 @@ export function Chat({
|
||||
<div className="flex flex-col items-center gap-3">
|
||||
<LoadingSpinner size="large" className="text-neutral-400" />
|
||||
<Text variant="body" className="text-zinc-500">
|
||||
Loading your chat...
|
||||
{isSwitchingSession
|
||||
? "Switching chat..."
|
||||
: "Loading your chat..."}
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Error State */}
|
||||
{error && !isLoading && (
|
||||
{error && !isLoading && !isSwitchingSession && (
|
||||
<ChatErrorState error={error} onRetry={createSession} />
|
||||
)}
|
||||
|
||||
{/* Session Content */}
|
||||
{sessionId && !isLoading && !error && (
|
||||
{sessionId && !isLoading && !error && !isSwitchingSession && (
|
||||
<ChatContainer
|
||||
sessionId={sessionId}
|
||||
initialMessages={messages}
|
||||
@@ -96,16 +88,6 @@ export function Chat({
|
||||
className="flex-1"
|
||||
onStreamingChange={onStreamingChange}
|
||||
onOperationStarted={startPollingForOperation}
|
||||
activeStream={
|
||||
activeStream
|
||||
? {
|
||||
taskId: activeStream.task_id,
|
||||
lastMessageId: activeStream.last_message_id,
|
||||
operationId: activeStream.operation_id,
|
||||
toolName: activeStream.tool_name,
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</main>
|
||||
|
||||
@@ -1,159 +0,0 @@
|
||||
# SSE Reconnection Contract for Long-Running Operations
|
||||
|
||||
This document describes the client-side contract for handling SSE (Server-Sent Events) disconnections and reconnecting to long-running background tasks.
|
||||
|
||||
## Overview
|
||||
|
||||
When a user triggers a long-running operation (like agent generation), the backend:
|
||||
|
||||
1. Spawns a background task that survives SSE disconnections
|
||||
2. Returns an `operation_started` response with a `task_id`
|
||||
3. Stores stream messages in Redis Streams for replay
|
||||
|
||||
Clients can reconnect to the task stream at any time to receive missed messages.
|
||||
|
||||
## Client-Side Flow
|
||||
|
||||
### 1. Receiving Operation Started
|
||||
|
||||
When you receive an `operation_started` tool response:
|
||||
|
||||
```typescript
|
||||
// The response includes a task_id for reconnection
|
||||
{
|
||||
type: "operation_started",
|
||||
tool_name: "generate_agent",
|
||||
operation_id: "uuid-...",
|
||||
task_id: "task-uuid-...", // <-- Store this for reconnection
|
||||
message: "Operation started. You can close this tab."
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Storing Task Info
|
||||
|
||||
Use the chat store to track the active task:
|
||||
|
||||
```typescript
|
||||
import { useChatStore } from "./chat-store";
|
||||
|
||||
// When operation_started is received:
|
||||
useChatStore.getState().setActiveTask(sessionId, {
|
||||
taskId: response.task_id,
|
||||
operationId: response.operation_id,
|
||||
toolName: response.tool_name,
|
||||
lastMessageId: "0",
|
||||
});
|
||||
```
|
||||
|
||||
### 3. Reconnecting to a Task
|
||||
|
||||
To reconnect (e.g., after page refresh or tab reopen):
|
||||
|
||||
```typescript
|
||||
const { reconnectToTask, getActiveTask } = useChatStore.getState();
|
||||
|
||||
// Check if there's an active task for this session
|
||||
const activeTask = getActiveTask(sessionId);
|
||||
|
||||
if (activeTask) {
|
||||
// Reconnect to the task stream
|
||||
await reconnectToTask(
|
||||
sessionId,
|
||||
activeTask.taskId,
|
||||
activeTask.lastMessageId, // Resume from last position
|
||||
(chunk) => {
|
||||
// Handle incoming chunks
|
||||
console.log("Received chunk:", chunk);
|
||||
},
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
### 4. Tracking Message Position
|
||||
|
||||
To enable precise replay, update the last message ID as chunks arrive:
|
||||
|
||||
```typescript
|
||||
const { updateTaskLastMessageId } = useChatStore.getState();
|
||||
|
||||
function handleChunk(chunk: StreamChunk) {
|
||||
// If chunk has an index/id, track it
|
||||
if (chunk.idx !== undefined) {
|
||||
updateTaskLastMessageId(sessionId, String(chunk.idx));
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## API Endpoints
|
||||
|
||||
### Task Stream Reconnection
|
||||
|
||||
```
|
||||
GET /api/chat/tasks/{taskId}/stream?last_message_id={idx}
|
||||
```
|
||||
|
||||
- `taskId`: The task ID from `operation_started`
|
||||
- `last_message_id`: Last received message index (default: "0" for full replay)
|
||||
|
||||
Returns: SSE stream of missed messages + live updates
|
||||
|
||||
## Chunk Types
|
||||
|
||||
The reconnected stream follows the same Vercel AI SDK protocol:
|
||||
|
||||
| Type | Description |
|
||||
| ----------------------- | ----------------------- |
|
||||
| `start` | Message lifecycle start |
|
||||
| `text-delta` | Streaming text content |
|
||||
| `text-end` | Text block completed |
|
||||
| `tool-output-available` | Tool result available |
|
||||
| `finish` | Stream completed |
|
||||
| `error` | Error occurred |
|
||||
|
||||
## Error Handling
|
||||
|
||||
If reconnection fails:
|
||||
|
||||
1. Check if task still exists (may have expired - default TTL: 1 hour)
|
||||
2. Fall back to polling the session for final state
|
||||
3. Show appropriate UI message to user
|
||||
|
||||
## Persistence Considerations
|
||||
|
||||
For robust reconnection across browser restarts:
|
||||
|
||||
```typescript
|
||||
// Store in localStorage/sessionStorage
|
||||
const ACTIVE_TASKS_KEY = "chat_active_tasks";
|
||||
|
||||
function persistActiveTask(sessionId: string, task: ActiveTaskInfo) {
|
||||
const tasks = JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
tasks[sessionId] = task;
|
||||
localStorage.setItem(ACTIVE_TASKS_KEY, JSON.stringify(tasks));
|
||||
}
|
||||
|
||||
function loadPersistedTasks(): Record<string, ActiveTaskInfo> {
|
||||
return JSON.parse(localStorage.getItem(ACTIVE_TASKS_KEY) || "{}");
|
||||
}
|
||||
```
|
||||
|
||||
## Backend Configuration
|
||||
|
||||
The following backend settings affect reconnection behavior:
|
||||
|
||||
| Setting | Default | Description |
|
||||
| ------------------- | ------- | ---------------------------------- |
|
||||
| `stream_ttl` | 3600s | How long streams are kept in Redis |
|
||||
| `stream_max_length` | 1000 | Max messages per stream |
|
||||
|
||||
## Testing
|
||||
|
||||
To test reconnection locally:
|
||||
|
||||
1. Start a long-running operation (e.g., agent generation)
|
||||
2. Note the `task_id` from the `operation_started` response
|
||||
3. Close the browser tab
|
||||
4. Reopen and call `reconnectToTask` with the saved `task_id`
|
||||
5. Verify that missed messages are replayed
|
||||
|
||||
See the main README for full local development setup.
|
||||
@@ -1,16 +0,0 @@
|
||||
/**
|
||||
* Constants for the chat system.
|
||||
*
|
||||
* Centralizes magic strings and values used across chat components.
|
||||
*/
|
||||
|
||||
// LocalStorage keys
|
||||
export const STORAGE_KEY_ACTIVE_TASKS = "chat_active_tasks";
|
||||
|
||||
// Redis Stream IDs
|
||||
export const INITIAL_MESSAGE_ID = "0";
|
||||
export const INITIAL_STREAM_ID = "0-0";
|
||||
|
||||
// TTL values (in milliseconds)
|
||||
export const COMPLETED_STREAM_TTL_MS = 5 * 60 * 1000; // 5 minutes
|
||||
export const ACTIVE_TASK_TTL_MS = 60 * 60 * 1000; // 1 hour
|
||||
@@ -1,12 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { create } from "zustand";
|
||||
import {
|
||||
ACTIVE_TASK_TTL_MS,
|
||||
COMPLETED_STREAM_TTL_MS,
|
||||
INITIAL_STREAM_ID,
|
||||
STORAGE_KEY_ACTIVE_TASKS,
|
||||
} from "./chat-constants";
|
||||
import type {
|
||||
ActiveStream,
|
||||
StreamChunk,
|
||||
@@ -14,59 +8,15 @@ import type {
|
||||
StreamResult,
|
||||
StreamStatus,
|
||||
} from "./chat-types";
|
||||
import { executeStream, executeTaskReconnect } from "./stream-executor";
|
||||
import { executeStream } from "./stream-executor";
|
||||
|
||||
export interface ActiveTaskInfo {
|
||||
taskId: string;
|
||||
sessionId: string;
|
||||
operationId: string;
|
||||
toolName: string;
|
||||
lastMessageId: string;
|
||||
startedAt: number;
|
||||
}
|
||||
|
||||
/** Load active tasks from localStorage */
|
||||
function loadPersistedTasks(): Map<string, ActiveTaskInfo> {
|
||||
if (typeof window === "undefined") return new Map();
|
||||
try {
|
||||
const stored = localStorage.getItem(STORAGE_KEY_ACTIVE_TASKS);
|
||||
if (!stored) return new Map();
|
||||
const parsed = JSON.parse(stored) as Record<string, ActiveTaskInfo>;
|
||||
const now = Date.now();
|
||||
const tasks = new Map<string, ActiveTaskInfo>();
|
||||
// Filter out expired tasks
|
||||
for (const [sessionId, task] of Object.entries(parsed)) {
|
||||
if (now - task.startedAt < ACTIVE_TASK_TTL_MS) {
|
||||
tasks.set(sessionId, task);
|
||||
}
|
||||
}
|
||||
return tasks;
|
||||
} catch {
|
||||
return new Map();
|
||||
}
|
||||
}
|
||||
|
||||
/** Save active tasks to localStorage */
|
||||
function persistTasks(tasks: Map<string, ActiveTaskInfo>): void {
|
||||
if (typeof window === "undefined") return;
|
||||
try {
|
||||
const obj: Record<string, ActiveTaskInfo> = {};
|
||||
for (const [sessionId, task] of tasks) {
|
||||
obj[sessionId] = task;
|
||||
}
|
||||
localStorage.setItem(STORAGE_KEY_ACTIVE_TASKS, JSON.stringify(obj));
|
||||
} catch {
|
||||
// Ignore storage errors
|
||||
}
|
||||
}
|
||||
const COMPLETED_STREAM_TTL = 5 * 60 * 1000; // 5 minutes
|
||||
|
||||
interface ChatStoreState {
|
||||
activeStreams: Map<string, ActiveStream>;
|
||||
completedStreams: Map<string, StreamResult>;
|
||||
activeSessions: Set<string>;
|
||||
streamCompleteCallbacks: Set<StreamCompleteCallback>;
|
||||
/** Active tasks for SSE reconnection - keyed by sessionId */
|
||||
activeTasks: Map<string, ActiveTaskInfo>;
|
||||
}
|
||||
|
||||
interface ChatStoreActions {
|
||||
@@ -91,24 +41,6 @@ interface ChatStoreActions {
|
||||
unregisterActiveSession: (sessionId: string) => void;
|
||||
isSessionActive: (sessionId: string) => boolean;
|
||||
onStreamComplete: (callback: StreamCompleteCallback) => () => void;
|
||||
/** Track active task for SSE reconnection */
|
||||
setActiveTask: (
|
||||
sessionId: string,
|
||||
taskInfo: Omit<ActiveTaskInfo, "sessionId" | "startedAt">,
|
||||
) => void;
|
||||
/** Get active task for a session */
|
||||
getActiveTask: (sessionId: string) => ActiveTaskInfo | undefined;
|
||||
/** Clear active task when operation completes */
|
||||
clearActiveTask: (sessionId: string) => void;
|
||||
/** Reconnect to an existing task stream */
|
||||
reconnectToTask: (
|
||||
sessionId: string,
|
||||
taskId: string,
|
||||
lastMessageId?: string,
|
||||
onChunk?: (chunk: StreamChunk) => void,
|
||||
) => Promise<void>;
|
||||
/** Update last message ID for a task (for tracking replay position) */
|
||||
updateTaskLastMessageId: (sessionId: string, lastMessageId: string) => void;
|
||||
}
|
||||
|
||||
type ChatStore = ChatStoreState & ChatStoreActions;
|
||||
@@ -132,126 +64,18 @@ function cleanupExpiredStreams(
|
||||
const now = Date.now();
|
||||
const cleaned = new Map(completedStreams);
|
||||
for (const [sessionId, result] of cleaned) {
|
||||
if (now - result.completedAt > COMPLETED_STREAM_TTL_MS) {
|
||||
if (now - result.completedAt > COMPLETED_STREAM_TTL) {
|
||||
cleaned.delete(sessionId);
|
||||
}
|
||||
}
|
||||
return cleaned;
|
||||
}
|
||||
|
||||
/**
|
||||
* Finalize a stream by moving it from activeStreams to completedStreams.
|
||||
* Also handles cleanup and notifications.
|
||||
*/
|
||||
function finalizeStream(
|
||||
sessionId: string,
|
||||
stream: ActiveStream,
|
||||
onChunk: ((chunk: StreamChunk) => void) | undefined,
|
||||
get: () => ChatStoreState & ChatStoreActions,
|
||||
set: (state: Partial<ChatStoreState>) => void,
|
||||
): void {
|
||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||
|
||||
if (stream.status !== "streaming") {
|
||||
const currentState = get();
|
||||
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||
|
||||
const storedStream = finalActiveStreams.get(sessionId);
|
||||
if (storedStream === stream) {
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: stream.status,
|
||||
chunks: stream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: stream.error,
|
||||
};
|
||||
finalCompletedStreams.set(sessionId, result);
|
||||
finalActiveStreams.delete(sessionId);
|
||||
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||
set({
|
||||
activeStreams: finalActiveStreams,
|
||||
completedStreams: finalCompletedStreams,
|
||||
});
|
||||
|
||||
if (stream.status === "completed" || stream.status === "error") {
|
||||
notifyStreamComplete(currentState.streamCompleteCallbacks, sessionId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up an existing stream for a session and move it to completed streams.
|
||||
* Returns updated maps for both active and completed streams.
|
||||
*/
|
||||
function cleanupExistingStream(
|
||||
sessionId: string,
|
||||
activeStreams: Map<string, ActiveStream>,
|
||||
completedStreams: Map<string, StreamResult>,
|
||||
callbacks: Set<StreamCompleteCallback>,
|
||||
): {
|
||||
activeStreams: Map<string, ActiveStream>;
|
||||
completedStreams: Map<string, StreamResult>;
|
||||
} {
|
||||
const newActiveStreams = new Map(activeStreams);
|
||||
let newCompletedStreams = new Map(completedStreams);
|
||||
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming"
|
||||
? "completed"
|
||||
: existingStream.status;
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||
notifyStreamComplete(callbacks, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
activeStreams: newActiveStreams,
|
||||
completedStreams: newCompletedStreams,
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a new active stream with initial state.
|
||||
*/
|
||||
function createActiveStream(
|
||||
sessionId: string,
|
||||
onChunk?: (chunk: StreamChunk) => void,
|
||||
): ActiveStream {
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
return {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
}
|
||||
|
||||
export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
activeStreams: new Map(),
|
||||
completedStreams: new Map(),
|
||||
activeSessions: new Set(),
|
||||
streamCompleteCallbacks: new Set(),
|
||||
activeTasks: loadPersistedTasks(),
|
||||
|
||||
startStream: async function startStream(
|
||||
sessionId,
|
||||
@@ -261,21 +85,45 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
onChunk,
|
||||
) {
|
||||
const state = get();
|
||||
const newActiveStreams = new Map(state.activeStreams);
|
||||
let newCompletedStreams = new Map(state.completedStreams);
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
// Clean up any existing stream for this session
|
||||
const {
|
||||
activeStreams: newActiveStreams,
|
||||
completedStreams: newCompletedStreams,
|
||||
} = cleanupExistingStream(
|
||||
sessionId,
|
||||
state.activeStreams,
|
||||
state.completedStreams,
|
||||
callbacks,
|
||||
);
|
||||
const existingStream = newActiveStreams.get(sessionId);
|
||||
if (existingStream) {
|
||||
existingStream.abortController.abort();
|
||||
const normalizedStatus =
|
||||
existingStream.status === "streaming"
|
||||
? "completed"
|
||||
: existingStream.status;
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: normalizedStatus,
|
||||
chunks: existingStream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: existingStream.error,
|
||||
};
|
||||
newCompletedStreams.set(sessionId, result);
|
||||
newActiveStreams.delete(sessionId);
|
||||
newCompletedStreams = cleanupExpiredStreams(newCompletedStreams);
|
||||
if (normalizedStatus === "completed" || normalizedStatus === "error") {
|
||||
notifyStreamComplete(callbacks, sessionId);
|
||||
}
|
||||
}
|
||||
|
||||
const abortController = new AbortController();
|
||||
const initialCallbacks = new Set<(chunk: StreamChunk) => void>();
|
||||
if (onChunk) initialCallbacks.add(onChunk);
|
||||
|
||||
const stream: ActiveStream = {
|
||||
sessionId,
|
||||
abortController,
|
||||
status: "streaming",
|
||||
startedAt: Date.now(),
|
||||
chunks: [],
|
||||
onChunkCallbacks: initialCallbacks,
|
||||
};
|
||||
|
||||
// Create new stream
|
||||
const stream = createActiveStream(sessionId, onChunk);
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
@@ -285,7 +133,36 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
try {
|
||||
await executeStream(stream, message, isUserMessage, context);
|
||||
} finally {
|
||||
finalizeStream(sessionId, stream, onChunk, get, set);
|
||||
if (onChunk) stream.onChunkCallbacks.delete(onChunk);
|
||||
if (stream.status !== "streaming") {
|
||||
const currentState = get();
|
||||
const finalActiveStreams = new Map(currentState.activeStreams);
|
||||
let finalCompletedStreams = new Map(currentState.completedStreams);
|
||||
|
||||
const storedStream = finalActiveStreams.get(sessionId);
|
||||
if (storedStream === stream) {
|
||||
const result: StreamResult = {
|
||||
sessionId,
|
||||
status: stream.status,
|
||||
chunks: stream.chunks,
|
||||
completedAt: Date.now(),
|
||||
error: stream.error,
|
||||
};
|
||||
finalCompletedStreams.set(sessionId, result);
|
||||
finalActiveStreams.delete(sessionId);
|
||||
finalCompletedStreams = cleanupExpiredStreams(finalCompletedStreams);
|
||||
set({
|
||||
activeStreams: finalActiveStreams,
|
||||
completedStreams: finalCompletedStreams,
|
||||
});
|
||||
if (stream.status === "completed" || stream.status === "error") {
|
||||
notifyStreamComplete(
|
||||
currentState.streamCompleteCallbacks,
|
||||
sessionId,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -409,93 +286,4 @@ export const useChatStore = create<ChatStore>((set, get) => ({
|
||||
set({ streamCompleteCallbacks: cleanedCallbacks });
|
||||
};
|
||||
},
|
||||
|
||||
setActiveTask: function setActiveTask(sessionId, taskInfo) {
|
||||
const state = get();
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...taskInfo,
|
||||
sessionId,
|
||||
startedAt: Date.now(),
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
getActiveTask: function getActiveTask(sessionId) {
|
||||
return get().activeTasks.get(sessionId);
|
||||
},
|
||||
|
||||
clearActiveTask: function clearActiveTask(sessionId) {
|
||||
const state = get();
|
||||
if (!state.activeTasks.has(sessionId)) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
|
||||
reconnectToTask: async function reconnectToTask(
|
||||
sessionId,
|
||||
taskId,
|
||||
lastMessageId = INITIAL_STREAM_ID,
|
||||
onChunk,
|
||||
) {
|
||||
const state = get();
|
||||
const callbacks = state.streamCompleteCallbacks;
|
||||
|
||||
// Clean up any existing stream for this session
|
||||
const {
|
||||
activeStreams: newActiveStreams,
|
||||
completedStreams: newCompletedStreams,
|
||||
} = cleanupExistingStream(
|
||||
sessionId,
|
||||
state.activeStreams,
|
||||
state.completedStreams,
|
||||
callbacks,
|
||||
);
|
||||
|
||||
// Create new stream for reconnection
|
||||
const stream = createActiveStream(sessionId, onChunk);
|
||||
newActiveStreams.set(sessionId, stream);
|
||||
set({
|
||||
activeStreams: newActiveStreams,
|
||||
completedStreams: newCompletedStreams,
|
||||
});
|
||||
|
||||
try {
|
||||
await executeTaskReconnect(stream, taskId, lastMessageId);
|
||||
} finally {
|
||||
finalizeStream(sessionId, stream, onChunk, get, set);
|
||||
|
||||
// Clear active task on completion
|
||||
if (stream.status === "completed" || stream.status === "error") {
|
||||
const taskState = get();
|
||||
if (taskState.activeTasks.has(sessionId)) {
|
||||
const newActiveTasks = new Map(taskState.activeTasks);
|
||||
newActiveTasks.delete(sessionId);
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
updateTaskLastMessageId: function updateTaskLastMessageId(
|
||||
sessionId,
|
||||
lastMessageId,
|
||||
) {
|
||||
const state = get();
|
||||
const task = state.activeTasks.get(sessionId);
|
||||
if (!task) return;
|
||||
|
||||
const newActiveTasks = new Map(state.activeTasks);
|
||||
newActiveTasks.set(sessionId, {
|
||||
...task,
|
||||
lastMessageId,
|
||||
});
|
||||
set({ activeTasks: newActiveTasks });
|
||||
persistTasks(newActiveTasks);
|
||||
},
|
||||
}));
|
||||
|
||||
@@ -4,7 +4,6 @@ export type StreamStatus = "idle" | "streaming" | "completed" | "error";
|
||||
|
||||
export interface StreamChunk {
|
||||
type:
|
||||
| "stream_start"
|
||||
| "text_chunk"
|
||||
| "text_ended"
|
||||
| "tool_call"
|
||||
@@ -16,7 +15,6 @@ export interface StreamChunk {
|
||||
| "error"
|
||||
| "usage"
|
||||
| "stream_end";
|
||||
taskId?: string;
|
||||
timestamp?: string;
|
||||
content?: string;
|
||||
message?: string;
|
||||
@@ -43,7 +41,7 @@ export interface StreamChunk {
|
||||
}
|
||||
|
||||
export type VercelStreamChunk =
|
||||
| { type: "start"; messageId: string; taskId?: string }
|
||||
| { type: "start"; messageId: string }
|
||||
| { type: "finish" }
|
||||
| { type: "text-start"; id: string }
|
||||
| { type: "text-delta"; id: string; delta: string }
|
||||
@@ -94,70 +92,3 @@ export interface StreamResult {
|
||||
}
|
||||
|
||||
export type StreamCompleteCallback = (sessionId: string) => void;
|
||||
|
||||
// Type guards for message types
|
||||
|
||||
/**
|
||||
* Check if a message has a toolId property.
|
||||
*/
|
||||
export function hasToolId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { toolId: string } {
|
||||
return (
|
||||
"toolId" in msg &&
|
||||
typeof (msg as Record<string, unknown>).toolId === "string"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message has an operationId property.
|
||||
*/
|
||||
export function hasOperationId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { operationId: string } {
|
||||
return (
|
||||
"operationId" in msg &&
|
||||
typeof (msg as Record<string, unknown>).operationId === "string"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message has a toolCallId property.
|
||||
*/
|
||||
export function hasToolCallId<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & { toolCallId: string } {
|
||||
return (
|
||||
"toolCallId" in msg &&
|
||||
typeof (msg as Record<string, unknown>).toolCallId === "string"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if a message is an operation message type.
|
||||
*/
|
||||
export function isOperationMessage<T extends { type: string }>(
|
||||
msg: T,
|
||||
): msg is T & {
|
||||
type: "operation_started" | "operation_pending" | "operation_in_progress";
|
||||
} {
|
||||
return (
|
||||
msg.type === "operation_started" ||
|
||||
msg.type === "operation_pending" ||
|
||||
msg.type === "operation_in_progress"
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the tool ID from a message if available.
|
||||
* Checks toolId, operationId, and toolCallId properties.
|
||||
*/
|
||||
export function getToolIdFromMessage<T extends { type: string }>(
|
||||
msg: T,
|
||||
): string | undefined {
|
||||
const record = msg as Record<string, unknown>;
|
||||
if (typeof record.toolId === "string") return record.toolId;
|
||||
if (typeof record.operationId === "string") return record.operationId;
|
||||
if (typeof record.toolCallId === "string") return record.toolCallId;
|
||||
return undefined;
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user