Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-03 11:24:57 -05:00)

Compare commits: classic-fr...feature/vi (31 commits; author and date columns were empty in the source view)

059c94afac, 3ee7c9bfa8, 0fde14bf23, e8b33f9dbe, 6d6d3b820e, 8b5c018032,
b5611b00b3, 6cd62c4d50, 9f4c33a695, b0debe9488, b20767bde9, b9a9481381,
d2d2a0c0c9, 521f69220d, 368adc985d, 8c3216f0a2, 94063616e5, 2433a86cb1,
0ede203f8e, dc751316c5, e7fb54e6af, 7b76f4d1e4, 3cc56de0fa, d2bead0f7a,
f8d3893c16, 1cfbc0dd08, ff84643b48, c19c3c834a, d0f7ba8cfd, 2a855f4bd0,
b93bb3b9f8
.gitignore (vendored), 1 change

@@ -180,4 +180,3 @@ autogpt_platform/backend/settings.py
.claude/settings.local.json
CLAUDE.local.md
/autogpt_platform/backend/logs
.next
@@ -54,7 +54,7 @@ Before proceeding with the installation, ensure your system meets the following

### Updated Setup Instructions:
We've moved to a fully maintained and regularly updated documentation site.

👉 [Follow the official self-hosting guide here](https://agpt.co/docs/platform/getting-started/getting-started)
👉 [Follow the official self-hosting guide here](https://docs.agpt.co/platform/getting-started/)

This tutorial assumes you have Docker, VSCode, git and npm installed.
@@ -152,6 +152,7 @@ REPLICATE_API_KEY=
REVID_API_KEY=
SCREENSHOTONE_API_KEY=
UNREAL_SPEECH_API_KEY=
ELEVENLABS_API_KEY=

# Data & Search Services
E2B_API_KEY=
@@ -62,10 +62,11 @@ ENV POETRY_HOME=/opt/poetry \
DEBIAN_FRONTEND=noninteractive
ENV PATH=/opt/poetry/bin:$PATH

# Install Python without upgrading system-managed packages
# Install Python and FFmpeg (required for video processing blocks)
RUN apt-get update && apt-get install -y \
python3.13 \
python3-pip \
ffmpeg \
&& rm -rf /var/lib/apt/lists/*

# Copy only necessary files from builder
@@ -3,13 +3,9 @@ import logging
import time
from asyncio import CancelledError
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast
from typing import Any

import openai

if TYPE_CHECKING:
from backend.util.prompt import CompressResult

import orjson
from langfuse import get_client
from openai import (

@@ -19,13 +15,7 @@ from openai import (
PermissionDeniedError,
RateLimitError,
)
from openai.types.chat import (
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionStreamOptionsParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolParam,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam

from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
@@ -804,58 +794,207 @@ def _is_region_blocked_error(error: Exception) -> bool:
return "not available in your region" in str(error).lower()


async def _manage_context_window(
async def _summarize_messages(
messages: list,
model: str,
api_key: str | None = None,
base_url: str | None = None,
) -> "CompressResult":
"""
Manage context window using the unified compress_context function.
timeout: float = 30.0,
) -> str:
"""Summarize a list of messages into concise context.

This is a thin wrapper that creates an OpenAI client for summarization
and delegates to the shared compression logic in prompt.py.
Uses the same model as the chat for higher quality summaries.

Args:
messages: List of messages in OpenAI format
model: Model name for token counting and summarization
api_key: API key for summarization calls
base_url: Base URL for summarization calls
messages: List of message dicts to summarize
model: Model to use for summarization (same as chat model)
api_key: API key for OpenAI client
base_url: Base URL for OpenAI client
timeout: Request timeout in seconds (default: 30.0)

Returns:
CompressResult with compacted messages and metadata
Summarized text
"""
# Format messages for summarization
conversation = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
# Include user, assistant, and tool messages (tool outputs are important context)
if content and role in ("user", "assistant", "tool"):
conversation.append(f"{role.upper()}: {content}")

conversation_text = "\n\n".join(conversation)

# Handle empty conversation
if not conversation_text:
return "No conversation history available."

# Truncate conversation to fit within summarization model's context
# gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety
MAX_CHARS = 100_000
if len(conversation_text) > MAX_CHARS:
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"

# Call LLM to summarize
import openai

from backend.util.prompt import compress_context
summarization_client = openai.AsyncOpenAI(
api_key=api_key, base_url=base_url, timeout=timeout
)

# Convert messages to dict format
messages_dict = []
for msg in messages:
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
msg_dict = dict(msg)
messages_dict.append(msg_dict)
response = await summarization_client.chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": (
"Create a detailed summary of the conversation so far. "
"This summary will be used as context when continuing the conversation.\n\n"
"Before writing the summary, analyze each message chronologically to identify:\n"
"- User requests and their explicit goals\n"
"- Your approach and key decisions made\n"
"- Technical specifics (file names, tool outputs, function signatures)\n"
"- Errors encountered and resolutions applied\n\n"
"You MUST include ALL of the following sections:\n\n"
"## 1. Primary Request and Intent\n"
"The user's explicit goals and what they are trying to accomplish.\n\n"
"## 2. Key Technical Concepts\n"
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
"## 3. Files and Resources Involved\n"
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
"## 4. Errors and Fixes\n"
"Problems encountered, error messages, and their resolutions. "
"Include any user feedback on fixes.\n\n"
"## 5. Problem Solving\n"
"Issues that have been resolved and how they were addressed.\n\n"
"## 6. All User Messages\n"
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
"## 7. Pending Tasks\n"
"Work items the user explicitly requested that have not yet been completed.\n\n"
"## 8. Current Work\n"
"Precise description of what was being worked on most recently, including relevant context.\n\n"
"## 9. Next Steps\n"
"What should happen next, aligned with the user's most recent requests. "
"Include verbatim quotes of recent instructions if relevant."
),
},
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
],
max_tokens=1500,
temperature=0.3,
)

# Only create client if api_key is provided (enables summarization)
# Use context manager to avoid socket leaks
if api_key:
async with openai.AsyncOpenAI(
api_key=api_key, base_url=base_url, timeout=30.0
) as client:
return await compress_context(
messages=messages_dict,
model=model,
client=client,
)
else:
# No API key - use truncation-only mode
return await compress_context(
messages=messages_dict,
model=model,
client=None,
summary = response.choices[0].message.content
return summary or "No summary available."

def _ensure_tool_pairs_intact(
recent_messages: list[dict],
all_messages: list[dict],
start_index: int,
) -> list[dict]:
"""
Ensure tool_call/tool_response pairs stay together after slicing.

When slicing messages for context compaction, a naive slice can separate
an assistant message containing tool_calls from its corresponding tool
response messages. This causes API validation errors (e.g., Anthropic's
"unexpected tool_use_id found in tool_result blocks").

This function checks for orphan tool responses in the slice and extends
backwards to include their corresponding assistant messages.

Args:
recent_messages: The sliced messages to validate
all_messages: The complete message list (for looking up missing assistants)
start_index: The index in all_messages where recent_messages begins

Returns:
A potentially extended list of messages with tool pairs intact
"""
if not recent_messages:
return recent_messages

# Collect all tool_call_ids from assistant messages in the slice
available_tool_call_ids: set[str] = set()
for msg in recent_messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id")
if tc_id:
available_tool_call_ids.add(tc_id)

# Find orphan tool responses (tool messages whose tool_call_id is missing)
orphan_tool_call_ids: set[str] = set()
for msg in recent_messages:
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id and tc_id not in available_tool_call_ids:
orphan_tool_call_ids.add(tc_id)

if not orphan_tool_call_ids:
# No orphans, slice is valid
return recent_messages

# Find the assistant messages that contain the orphan tool_call_ids
# Search backwards from start_index in all_messages
messages_to_prepend: list[dict] = []
for i in range(start_index - 1, -1, -1):
msg = all_messages[i]
if msg.get("role") == "assistant" and msg.get("tool_calls"):
msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
if msg_tool_ids & orphan_tool_call_ids:
# This assistant message has tool_calls we need
# Also collect its contiguous tool responses that follow it
assistant_and_responses: list[dict] = [msg]

# Scan forward from this assistant to collect tool responses
for j in range(i + 1, start_index):
following_msg = all_messages[j]
if following_msg.get("role") == "tool":
tool_id = following_msg.get("tool_call_id")
if tool_id and tool_id in msg_tool_ids:
assistant_and_responses.append(following_msg)
else:
# Stop at first non-tool message
break

# Prepend the assistant and its tool responses (maintain order)
messages_to_prepend = assistant_and_responses + messages_to_prepend
# Mark these as found
orphan_tool_call_ids -= msg_tool_ids
# Also add this assistant's tool_call_ids to available set
available_tool_call_ids |= msg_tool_ids

if not orphan_tool_call_ids:
# Found all missing assistants
break

if orphan_tool_call_ids:
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = [
msg
for msg in recent_messages
if not (
msg.get("role") == "tool"
and msg.get("tool_call_id") in orphan_tool_call_ids
)
]

if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages

return recent_messages


async def _stream_chat_chunks(
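A minimal standalone sketch (hypothetical helper name and data, simplified from the function above) of the orphan-tool-response problem and the backwards-extension repair:

def repair_tool_pairs(window: list[dict], full: list[dict], start: int) -> list[dict]:
    # tool_call ids produced by assistant messages that survived the slice
    have = {tc["id"] for m in window if m.get("tool_calls") for tc in m["tool_calls"]}
    # tool results whose originating call is no longer in the window are orphans
    orphans = {m["tool_call_id"] for m in window
               if m.get("role") == "tool" and m.get("tool_call_id") not in have}
    prefix: list[dict] = []
    # walk backwards through the dropped prefix and re-attach the missing assistants
    for m in reversed(full[:start]):
        ids = {tc["id"] for tc in m.get("tool_calls") or []}
        if ids & orphans:
            prefix.insert(0, m)
            orphans -= ids
        if not orphans:
            break
    return prefix + window

history = [
    {"role": "assistant", "tool_calls": [{"id": "call_1"}]},
    {"role": "tool", "tool_call_id": "call_1", "content": "result"},
    {"role": "user", "content": "thanks"},
]
window = history[-2:]                          # naive slice orphans the tool result
print(repair_tool_pairs(window, history, 1))   # assistant message is prepended again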
@@ -883,8 +1022,11 @@ async def _stream_chat_chunks(

logger.info("Starting pure chat stream")

# Build messages with system prompt prepended
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam

system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,

@@ -892,38 +1034,314 @@ async def _stream_chat_chunks(
messages = [system_message] + messages

# Apply context window management
context_result = await _manage_context_window(
messages=messages,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)
token_count = 0 # Initialize for exception handler
try:
from backend.util.prompt import estimate_token_count

if context_result.error:
if "System prompt dropped" in context_result.error:
# Warning only - continue with reduced context
yield StreamError(
errorText=(
"Warning: System prompt dropped due to size constraints. "
"Assistant behavior may be affected."
)
# Convert to dict for token counting
# OpenAI message types are TypedDicts, so they're already dict-like
messages_dict = []
for msg in messages:
# TypedDict objects are already dicts, just filter None values
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
# Fallback for unexpected types
msg_dict = dict(msg)
messages_dict.append(msg_dict)

# Estimate tokens using appropriate tokenizer
# Normalize model name for token counting (tiktoken only supports OpenAI models)
token_count_model = model
if "/" in model:
# Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5")
token_count_model = model.split("/")[-1]

# For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer
# Most modern LLMs have similar tokenization (~1 token per 4 chars)
if "claude" in token_count_model.lower() or not any(
known in token_count_model.lower()
for known in ["gpt", "o1", "chatgpt", "text-"]
):
token_count_model = "gpt-4o"

# Attempt token counting with error handling
try:
token_count = estimate_token_count(messages_dict, model=token_count_model)
except Exception as token_error:
# If token counting fails, use gpt-4o as fallback approximation
logger.warning(
f"Token counting failed for model {token_count_model}: {token_error}. "
"Using gpt-4o approximation."
)
else:
# Any other error - abort to prevent failed LLM calls
token_count = estimate_token_count(messages_dict, model="gpt-4o")
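A compact sketch of the model-name normalization described in the comments above (hypothetical helper name; behaviour assumed from this diff, not from backend.util.prompt): strip the provider prefix and fall back to the gpt-4o tokenizer for anything tiktoken would not recognise.

def normalize_for_token_counting(model: str) -> str:
    # "anthropic/claude-opus-4.5" -> "claude-opus-4.5"
    name = model.split("/")[-1]
    known_openai = ("gpt", "o1", "chatgpt", "text-")
    if "claude" in name.lower() or not any(k in name.lower() for k in known_openai):
        return "gpt-4o"  # approximation: most modern tokenizers are roughly 1 token per 4 chars
    return name

assert normalize_for_token_counting("anthropic/claude-opus-4.5") == "gpt-4o"
assert normalize_for_token_counting("openai/gpt-4o-mini") == "gpt-4o-mini"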
# If over threshold, summarize old messages
if token_count > 120_000:
KEEP_RECENT = 15

# Check if we have a system prompt at the start
has_system_prompt = (
len(messages) > 0 and messages[0].get("role") == "system"
)

# Always attempt mitigation when over limit, even with few messages
if messages:
# Split messages based on whether system prompt exists
# Calculate start index for the slice
slice_start = max(0, len(messages_dict) - KEEP_RECENT)
recent_messages = messages_dict[-KEEP_RECENT:]

# Ensure tool_call/tool_response pairs stay together
# This prevents API errors from orphan tool responses
recent_messages = _ensure_tool_pairs_intact(
recent_messages, messages_dict, slice_start
)

if has_system_prompt:
# Keep system prompt separate, summarize everything between system and recent
system_msg = messages[0]
old_messages_dict = messages_dict[1:-KEEP_RECENT]
else:
# No system prompt, summarize everything except recent
system_msg = None
old_messages_dict = messages_dict[:-KEEP_RECENT]

# Summarize any non-empty old messages (no minimum threshold)
# If we're over the token limit, we need to compress whatever we can
if old_messages_dict:
# Summarize old messages using the same model as chat
summary_text = await _summarize_messages(
old_messages_dict,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)

# Build new message list
# Use assistant role (not system) to prevent privilege escalation
# of user-influenced content to instruction-level authority
from openai.types.chat import ChatCompletionAssistantMessageParam

summary_msg = ChatCompletionAssistantMessageParam(
role="assistant",
content=(
"[Previous conversation summary — for context only]: "
f"{summary_text}"
),
)

# Rebuild messages based on whether we have a system prompt
if has_system_prompt:
# system_prompt + summary + recent_messages
messages = [system_msg, summary_msg] + recent_messages
else:
# summary + recent_messages (no original system prompt)
messages = [summary_msg] + recent_messages

logger.info(
f"Context summarized: {token_count} tokens, "
f"summarized {len(old_messages_dict)} old messages, "
f"kept last {KEEP_RECENT} messages"
)

# Fallback: If still over limit after summarization, progressively drop recent messages
# This handles edge cases where recent messages are extremely large
new_messages_dict = []
for msg in messages:
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
msg_dict = dict(msg)
new_messages_dict.append(msg_dict)

new_token_count = estimate_token_count(
new_messages_dict, model=token_count_model
)

if new_token_count > 120_000:
# Still over limit - progressively reduce KEEP_RECENT
logger.warning(
f"Still over limit after summarization: {new_token_count} tokens. "
"Reducing number of recent messages kept."
)

for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
if keep_count == 0:
# Try with just system prompt + summary (no recent messages)
if has_system_prompt:
messages = [system_msg, summary_msg]
else:
messages = [summary_msg]
logger.info(
"Trying with 0 recent messages (system + summary only)"
)
else:
# Slice from ORIGINAL recent_messages to avoid duplicating summary
reduced_recent = (
recent_messages[-keep_count:]
if len(recent_messages) >= keep_count
else recent_messages
)
# Ensure tool pairs stay intact in the reduced slice
reduced_slice_start = max(
0, len(recent_messages) - keep_count
)
reduced_recent = _ensure_tool_pairs_intact(
reduced_recent, recent_messages, reduced_slice_start
)
if has_system_prompt:
messages = [
system_msg,
summary_msg,
] + reduced_recent
else:
messages = [summary_msg] + reduced_recent

new_messages_dict = []
for msg in messages:
if isinstance(msg, dict):
msg_dict = {
k: v for k, v in msg.items() if v is not None
}
else:
msg_dict = dict(msg)
new_messages_dict.append(msg_dict)

new_token_count = estimate_token_count(
new_messages_dict, model=token_count_model
)

if new_token_count <= 120_000:
logger.info(
f"Reduced to {keep_count} recent messages, "
f"now {new_token_count} tokens"
)
break
else:
logger.error(
f"Unable to reduce token count below threshold even with 0 messages. "
f"Final count: {new_token_count} tokens"
)
# ABSOLUTE LAST RESORT: Drop system prompt
# This should only happen if summary itself is massive
if has_system_prompt and len(messages) > 1:
messages = messages[1:] # Drop system prompt
logger.critical(
"CRITICAL: Dropped system prompt as absolute last resort. "
"Behavioral consistency may be affected."
)
# Yield error to user
yield StreamError(
errorText=(
"Warning: System prompt dropped due to size constraints. "
"Assistant behavior may be affected."
)
)
else:
# No old messages to summarize - all messages are "recent"
# Apply progressive truncation to reduce token count
logger.warning(
f"Token count {token_count} exceeds threshold but no old messages to summarize. "
f"Applying progressive truncation to recent messages."
)

# Create a base list excluding system prompt to avoid duplication
# This is the pool of messages we'll slice from in the loop
# Use messages_dict for type consistency with _ensure_tool_pairs_intact
base_msgs = (
messages_dict[1:] if has_system_prompt else messages_dict
)

# Try progressively smaller keep counts
new_token_count = token_count # Initialize with current count
for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
if keep_count == 0:
# Try with just system prompt (no recent messages)
if has_system_prompt:
messages = [system_msg]
logger.info(
"Trying with 0 recent messages (system prompt only)"
)
else:
# No system prompt and no recent messages = empty messages list
# This is invalid, skip this iteration
continue
else:
if len(base_msgs) < keep_count:
continue # Skip if we don't have enough messages

# Slice from base_msgs to get recent messages (without system prompt)
recent_messages = base_msgs[-keep_count:]

# Ensure tool pairs stay intact in the reduced slice
reduced_slice_start = max(0, len(base_msgs) - keep_count)
recent_messages = _ensure_tool_pairs_intact(
recent_messages, base_msgs, reduced_slice_start
)

if has_system_prompt:
messages = [system_msg] + recent_messages
else:
messages = recent_messages

new_messages_dict = []
for msg in messages:
if msg is None:
continue # Skip None messages (type safety)
if isinstance(msg, dict):
msg_dict = {
k: v for k, v in msg.items() if v is not None
}
else:
msg_dict = dict(msg)
new_messages_dict.append(msg_dict)

new_token_count = estimate_token_count(
new_messages_dict, model=token_count_model
)

if new_token_count <= 120_000:
logger.info(
f"Reduced to {keep_count} recent messages, "
f"now {new_token_count} tokens"
)
break
else:
# Even with 0 messages still over limit
logger.error(
f"Unable to reduce token count below threshold even with 0 messages. "
f"Final count: {new_token_count} tokens. Messages may be extremely large."
)
# ABSOLUTE LAST RESORT: Drop system prompt
if has_system_prompt and len(messages) > 1:
messages = messages[1:] # Drop system prompt
logger.critical(
"CRITICAL: Dropped system prompt as absolute last resort. "
"Behavioral consistency may be affected."
)
# Yield error to user
yield StreamError(
errorText=(
"Warning: System prompt dropped due to size constraints. "
"Assistant behavior may be affected."
)
)

except Exception as e:
logger.error(f"Context summarization failed: {e}", exc_info=True)
# If we were over the token limit, yield error to user
# Don't silently continue with oversized messages that will fail
if token_count > 120_000:
yield StreamError(
errorText=(
f"Context window management failed: {context_result.error}. "
"Please start a new conversation."
f"Unable to manage context window (token limit exceeded: {token_count} tokens). "
"Context summarization failed. Please start a new conversation."
)
)
yield StreamFinish()
return

messages = context_result.messages
if context_result.was_compacted:
logger.info(
f"Context compacted for streaming: {context_result.token_count} tokens"
)
# Otherwise, continue with original messages (under limit)

# Loop to handle tool calls and continue conversation
while True:
@@ -951,6 +1369,14 @@ async def _stream_chat_chunks(
:128
] # OpenRouter limit

# Create the stream with proper types
from typing import cast

from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionStreamOptionsParam,
)

stream = await client.chat.completions.create(
model=model,
messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1408,11 +1834,6 @@ async def _execute_long_running_tool(
tool_call_id=tool_call_id,
result=error_response.model_dump_json(),
)
# Generate LLM continuation so user sees explanation even for errors
try:
await _generate_llm_continuation(session_id=session_id, user_id=user_id)
except Exception as llm_err:
logger.warning(f"Failed to generate LLM continuation for error: {llm_err}")
finally:
await _mark_operation_completed(tool_call_id)

@@ -1474,36 +1895,17 @@ async def _generate_llm_continuation(
# Build system prompt
system_prompt, _ = await _build_system_prompt(user_id)

# Build messages in OpenAI format
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam

system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
)
messages = [system_message] + messages

# Apply context window management to prevent oversized requests
context_result = await _manage_context_window(
messages=messages,
model=config.model,
api_key=config.api_key,
base_url=config.base_url,
)

if context_result.error and "System prompt dropped" not in context_result.error:
logger.error(
f"Context window management failed for session {session_id}: "
f"{context_result.error} (tokens={context_result.token_count})"
)
return

messages = context_result.messages
if context_result.was_compacted:
logger.info(
f"Context compacted for LLM continuation: "
f"{context_result.token_count} tokens"
)

# Build extra_body for tracing
extra_body: dict[str, Any] = {
"posthogProperties": {
@@ -1516,54 +1918,19 @@ async def _generate_llm_continuation(
if session_id:
extra_body["session_id"] = session_id[:128]

retry_count = 0
last_error: Exception | None = None
response = None
# Make non-streaming LLM call (no tools - just text response)
from typing import cast

while retry_count <= MAX_RETRIES:
try:
logger.info(
f"Generating LLM continuation for session {session_id}"
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
)
from openai.types.chat import ChatCompletionMessageParam

response = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
)
last_error = None # Clear any previous error on success
break # Success, exit retry loop
except Exception as e:
last_error = e
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1
delay = min(
BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
MAX_DELAY_SECONDS,
)
logger.warning(
f"Retryable error in LLM continuation: {e!s}. "
f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
)
await asyncio.sleep(delay)
continue
else:
# Non-retryable error - log and exit gracefully
logger.error(
f"Non-retryable error in LLM continuation: {e!s}",
exc_info=True,
)
return
# No tools parameter = text-only response (no tool calls)
response = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
)

if last_error:
logger.error(
f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
f"Last error: {last_error!s}"
)
return

if response and response.choices and response.choices[0].message.content:
if response.choices and response.choices[0].message.content:
assistant_content = response.choices[0].message.content

# Reload session from DB to avoid race condition with user messages

@@ -2,54 +2,30 @@

from .core import (
AgentGeneratorNotConfiguredError,
AgentJsonValidationError,
AgentSummary,
DecompositionResult,
DecompositionStep,
LibraryAgentSummary,
MarketplaceAgentSummary,
decompose_goal,
enrich_library_agents_from_steps,
extract_search_terms_from_steps,
extract_uuids_from_text,
generate_agent,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_library_agent_by_graph_id,
get_library_agent_by_id,
get_library_agents_for_generation,
json_to_graph,
save_agent_to_library,
search_marketplace_agents_for_generation,
)
from .errors import get_user_message_for_error
from .service import health_check as check_external_service_health
from .service import is_external_service_configured

__all__ = [
"AgentGeneratorNotConfiguredError",
"AgentJsonValidationError",
"AgentSummary",
"DecompositionResult",
"DecompositionStep",
"LibraryAgentSummary",
"MarketplaceAgentSummary",
"check_external_service_health",
# Core functions
"decompose_goal",
"enrich_library_agents_from_steps",
"extract_search_terms_from_steps",
"extract_uuids_from_text",
"generate_agent",
"generate_agent_patch",
"get_agent_as_json",
"get_all_relevant_agents_for_generation",
"get_library_agent_by_graph_id",
"get_library_agent_by_id",
"get_library_agents_for_generation",
"get_user_message_for_error",
"is_external_service_configured",
"json_to_graph",
"save_agent_to_library",
"search_marketplace_agents_for_generation",
"get_agent_as_json",
"json_to_graph",
# Exceptions
"AgentGeneratorNotConfiguredError",
# Service
"is_external_service_configured",
"check_external_service_health",
# Error handling
"get_user_message_for_error",
]

@@ -1,22 +1,11 @@
"""Core agent generation functions."""

import logging
import re
import uuid
from typing import Any, NotRequired, TypedDict
from typing import Any

from backend.api.features.library import db as library_db
from backend.api.features.store import db as store_db
from backend.data.graph import (
Graph,
Link,
Node,
create_graph,
get_graph,
get_graph_all_versions,
get_store_listed_graphs,
)
from backend.util.exceptions import DatabaseError, NotFoundError
from backend.data.graph import Graph, Link, Node, create_graph

from .service import (
decompose_goal_external,
@@ -27,74 +16,6 @@ from .service import (

logger = logging.getLogger(__name__)

AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"


class ExecutionSummary(TypedDict):
"""Summary of a single execution for quality assessment."""

status: str
correctness_score: NotRequired[float]
activity_summary: NotRequired[str]


class LibraryAgentSummary(TypedDict):
"""Summary of a library agent for sub-agent composition.

Includes recent executions to help the LLM decide whether to use this agent.
Each execution shows status, correctness_score (0-1), and activity_summary.
"""

graph_id: str
graph_version: int
name: str
description: str
input_schema: dict[str, Any]
output_schema: dict[str, Any]
recent_executions: NotRequired[list[ExecutionSummary]]


class MarketplaceAgentSummary(TypedDict):
"""Summary of a marketplace agent for sub-agent composition."""

name: str
description: str
sub_heading: str
creator: str
is_marketplace_agent: bool


class DecompositionStep(TypedDict, total=False):
"""A single step in decomposed instructions."""

description: str
action: str
block_name: str
tool: str
name: str


class DecompositionResult(TypedDict, total=False):
"""Result from decompose_goal - can be instructions, questions, or error."""

type: str
steps: list[DecompositionStep]
questions: list[dict[str, Any]]
error: str
error_type: str


AgentSummary = LibraryAgentSummary | MarketplaceAgentSummary | dict[str, Any]


def _to_dict_list(
agents: list[AgentSummary] | list[dict[str, Any]] | None,
) -> list[dict[str, Any]] | None:
"""Convert typed agent summaries to plain dicts for external service calls."""
if agents is None:
return None
return [dict(a) for a in agents]


class AgentGeneratorNotConfiguredError(Exception):
"""Raised when the external Agent Generator service is not configured."""
@@ -115,422 +36,15 @@ def _check_service_configured() -> None:
)


_UUID_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)


def extract_uuids_from_text(text: str) -> list[str]:
"""Extract all UUID v4 strings from text.

Args:
text: Text that may contain UUIDs (e.g., user's goal description)

Returns:
List of unique UUIDs found in the text (lowercase)
"""
matches = _UUID_PATTERN.findall(text)
return list({m.lower() for m in matches})

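A quick usage sketch of the UUID extraction above (the regex is copied from this diff; the goal text and UUID are made up for illustration):

import re

_UUID_PATTERN = re.compile(
    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
    re.IGNORECASE,
)

goal = "Reuse agent 1B4E28BA-2FA1-41D2-883F-0016D3CCA442 to summarize my emails"
print([m.lower() for m in _UUID_PATTERN.findall(goal)])
# ['1b4e28ba-2fa1-41d2-883f-0016d3cca442']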
async def get_library_agent_by_id(
user_id: str, agent_id: str
) -> LibraryAgentSummary | None:
"""Fetch a specific library agent by its ID (library agent ID or graph_id).

This function tries multiple lookup strategies:
1. First tries to find by graph_id (AgentGraph primary key)
2. If not found, tries to find by library agent ID (LibraryAgent primary key)

This handles both cases:
- User provides graph_id (e.g., from AgentExecutorBlock)
- User provides library agent ID (e.g., from library URL)

Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)

Returns:
LibraryAgentSummary if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except DatabaseError:
raise
except Exception as e:
logger.debug(f"Could not fetch library agent by graph_id {agent_id}: {e}")

try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)

return None


get_library_agent_by_graph_id = get_library_agent_by_id

async def get_library_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
max_results: int = 15,
) -> list[LibraryAgentSummary]:
"""Fetch user's library agents formatted for Agent Generator.

Uses search-based fetching to return relevant agents instead of all agents.
This is more scalable for users with large libraries.

Includes recent_executions list to help the LLM assess agent quality:
- Each execution has status, correctness_score (0-1), and activity_summary
- This gives the LLM concrete examples of recent performance

Args:
user_id: The user ID
search_query: Optional search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
max_results: Maximum number of agents to return (default 15)

Returns:
List of LibraryAgentSummary with schemas and recent executions for sub-agent composition
"""
try:
response = await library_db.list_library_agents(
user_id=user_id,
search_term=search_query,
page=1,
page_size=max_results,
include_executions=True,
)

results: list[LibraryAgentSummary] = []
for agent in response.agents:
if exclude_graph_id is not None and agent.graph_id == exclude_graph_id:
continue

summary = LibraryAgentSummary(
graph_id=agent.graph_id,
graph_version=agent.graph_version,
name=agent.name,
description=agent.description,
input_schema=agent.input_schema,
output_schema=agent.output_schema,
)
if agent.recent_executions:
exec_summaries: list[ExecutionSummary] = []
for ex in agent.recent_executions:
exec_sum = ExecutionSummary(status=ex.status)
if ex.correctness_score is not None:
exec_sum["correctness_score"] = ex.correctness_score
if ex.activity_summary:
exec_sum["activity_summary"] = ex.activity_summary
exec_summaries.append(exec_sum)
summary["recent_executions"] = exec_summaries
results.append(summary)
return results
except DatabaseError:
raise
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")
return []

async def search_marketplace_agents_for_generation(
search_query: str,
max_results: int = 10,
) -> list[LibraryAgentSummary]:
"""Search marketplace agents formatted for Agent Generator.

Fetches marketplace agents and their full schemas so they can be used
as sub-agents in generated workflows.

Args:
search_query: Search term to find relevant public agents
max_results: Maximum number of agents to return (default 10)

Returns:
List of LibraryAgentSummary with full input/output schemas
"""
try:
response = await store_db.get_store_agents(
search_query=search_query,
page=1,
page_size=max_results,
)

agents_with_graphs = [
agent for agent in response.agents if agent.agent_graph_id
]

if not agents_with_graphs:
return []

graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
graphs = await get_store_listed_graphs(*graph_ids)

results: list[LibraryAgentSummary] = []
for agent in agents_with_graphs:
graph_id = agent.agent_graph_id
if graph_id and graph_id in graphs:
graph = graphs[graph_id]
results.append(
LibraryAgentSummary(
graph_id=graph.id,
graph_version=graph.version,
name=agent.agent_name,
description=agent.description,
input_schema=graph.input_schema,
output_schema=graph.output_schema,
)
)
return results
except Exception as e:
logger.warning(f"Failed to search marketplace agents: {e}")
return []

async def get_all_relevant_agents_for_generation(
user_id: str,
search_query: str | None = None,
exclude_graph_id: str | None = None,
include_library: bool = True,
include_marketplace: bool = True,
max_library_results: int = 15,
max_marketplace_results: int = 10,
) -> list[AgentSummary]:
"""Fetch relevant agents from library and/or marketplace.

Searches both user's library and marketplace by default.
Explicitly mentioned UUIDs in the search query are always looked up.

Args:
user_id: The user ID
search_query: Search term to find relevant agents (user's goal/description)
exclude_graph_id: Optional graph ID to exclude (prevents circular references)
include_library: Whether to search user's library (default True)
include_marketplace: Whether to also search marketplace (default True)
max_library_results: Max library agents to return (default 15)
max_marketplace_results: Max marketplace agents to return (default 10)

Returns:
List of AgentSummary with full schemas (both library and marketplace agents)
"""
agents: list[AgentSummary] = []
seen_graph_ids: set[str] = set()

if search_query:
mentioned_uuids = extract_uuids_from_text(search_query)
for graph_id in mentioned_uuids:
if graph_id == exclude_graph_id:
continue
agent = await get_library_agent_by_graph_id(user_id, graph_id)
agent_graph_id = agent.get("graph_id") if agent else None
if agent and agent_graph_id and agent_graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(agent_graph_id)
logger.debug(
f"Found explicitly mentioned agent: {agent.get('name') or 'Unknown'}"
)

if include_library:
library_agents = await get_library_agents_for_generation(
user_id=user_id,
search_query=search_query,
exclude_graph_id=exclude_graph_id,
max_results=max_library_results,
)
for agent in library_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)

if include_marketplace and search_query:
marketplace_agents = await search_marketplace_agents_for_generation(
search_query=search_query,
max_results=max_marketplace_results,
)
for agent in marketplace_agents:
graph_id = agent.get("graph_id")
if graph_id and graph_id not in seen_graph_ids:
agents.append(agent)
seen_graph_ids.add(graph_id)

return agents

def extract_search_terms_from_steps(
decomposition_result: DecompositionResult | dict[str, Any],
) -> list[str]:
"""Extract search terms from decomposed instruction steps.

Analyzes the decomposition result to extract relevant keywords
for additional library agent searches.

Args:
decomposition_result: Result from decompose_goal containing steps

Returns:
List of unique search terms extracted from steps
"""
search_terms: list[str] = []

if decomposition_result.get("type") != "instructions":
return search_terms

steps = decomposition_result.get("steps", [])
if not steps:
return search_terms

step_keys: list[str] = ["description", "action", "block_name", "tool", "name"]

for step in steps:
for key in step_keys:
value = step.get(key) # type: ignore[union-attr]
if isinstance(value, str) and len(value) > 3:
search_terms.append(value)

seen: set[str] = set()
unique_terms: list[str] = []
for term in search_terms:
term_lower = term.lower()
if term_lower not in seen:
seen.add(term_lower)
unique_terms.append(term)

return unique_terms

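An illustrative input (entirely made up; the block and tool names are hypothetical) showing what the step-based term extraction above would yield:

result = {
    "type": "instructions",
    "steps": [
        {"description": "Fetch top Hacker News stories", "block_name": "GetHackerNewsTopStories"},
        {"description": "Send a summary email", "tool": "GmailSendBlock"},
    ],
}
# With the function above this would return, in step order and de-duplicated
# case-insensitively:
# ["Fetch top Hacker News stories", "GetHackerNewsTopStories",
#  "Send a summary email", "GmailSendBlock"]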
async def enrich_library_agents_from_steps(
user_id: str,
decomposition_result: DecompositionResult | dict[str, Any],
existing_agents: list[AgentSummary] | list[dict[str, Any]],
exclude_graph_id: str | None = None,
include_marketplace: bool = True,
max_additional_results: int = 10,
) -> list[AgentSummary] | list[dict[str, Any]]:
"""Enrich library agents list with additional searches based on decomposed steps.

This implements two-phase search: after decomposition, we search for additional
relevant agents based on the specific steps identified.

Args:
user_id: The user ID
decomposition_result: Result from decompose_goal containing steps
existing_agents: Already fetched library agents from initial search
exclude_graph_id: Optional graph ID to exclude
include_marketplace: Whether to also search marketplace
max_additional_results: Max additional agents per search term (default 10)

Returns:
Combined list of library agents (existing + newly discovered)
"""
search_terms = extract_search_terms_from_steps(decomposition_result)

if not search_terms:
return existing_agents

existing_ids: set[str] = set()
existing_names: set[str] = set()

for agent in existing_agents:
agent_name = agent.get("name")
if agent_name and isinstance(agent_name, str):
existing_names.add(agent_name.lower())
graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)

all_agents: list[AgentSummary] | list[dict[str, Any]] = list(existing_agents)

for term in search_terms[:3]:
try:
additional_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=term,
exclude_graph_id=exclude_graph_id,
include_marketplace=include_marketplace,
max_library_results=max_additional_results,
max_marketplace_results=5,
)

for agent in additional_agents:
agent_name = agent.get("name")
if not agent_name or not isinstance(agent_name, str):
continue
agent_name_lower = agent_name.lower()

if agent_name_lower in existing_names:
continue

graph_id = agent.get("graph_id") # type: ignore[call-overload]
if graph_id and graph_id in existing_ids:
continue

all_agents.append(agent)
existing_names.add(agent_name_lower)
if graph_id and isinstance(graph_id, str):
existing_ids.add(graph_id)

except DatabaseError:
logger.error(f"Database error searching for agents with term '{term}'")
raise
except Exception as e:
logger.warning(
f"Failed to search for additional agents with term '{term}': {e}"
)

logger.debug(
f"Enriched library agents: {len(existing_agents)} initial + "
f"{len(all_agents) - len(existing_agents)} additional = {len(all_agents)} total"
)

return all_agents

async def decompose_goal(
description: str,
context: str = "",
library_agents: list[AgentSummary] | None = None,
) -> DecompositionResult | None:
async def decompose_goal(description: str, context: str = "") -> dict[str, Any] | None:
"""Break down a goal into steps or return clarifying questions.

Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition

Returns:
DecompositionResult with either:
Dict with either:
- {"type": "clarifying_questions", "questions": [...]}
- {"type": "instructions", "steps": [...]}
Or None on error

@@ -540,21 +54,14 @@ async def decompose_goal(
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for decompose_goal")
result = await decompose_goal_external(
description, context, _to_dict_list(library_agents)
)
return result # type: ignore[return-value]
return await decompose_goal_external(description, context)

async def generate_agent(
instructions: DecompositionResult | dict[str, Any],
library_agents: list[AgentSummary] | list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
async def generate_agent(instructions: dict[str, Any]) -> dict[str, Any] | None:
"""Generate agent JSON from instructions.

Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition

Returns:
Agent JSON dict, error dict {"type": "error", ...}, or None on error

@@ -564,12 +71,12 @@ async def generate_agent(
"""
_check_service_configured()
logger.info("Calling external Agent Generator service for generate_agent")
result = await generate_agent_external(
dict(instructions), _to_dict_list(library_agents)
)
result = await generate_agent_external(instructions)
if result:
# Check if it's an error response - pass through as-is
if isinstance(result, dict) and result.get("type") == "error":
return result
# Ensure required fields for successful agent generation
if "id" not in result:
result["id"] = str(uuid.uuid4())
if "version" not in result:

@@ -579,12 +86,6 @@ async def generate_agent(
return result


class AgentJsonValidationError(Exception):
"""Raised when agent JSON is invalid or missing required fields."""

pass

def json_to_graph(agent_json: dict[str, Any]) -> Graph:
"""Convert agent JSON dict to Graph model.

@@ -593,55 +94,25 @@ def json_to_graph(agent_json: dict[str, Any]) -> Graph:

Returns:
Graph ready for saving

Raises:
AgentJsonValidationError: If required fields are missing from nodes or links
"""
nodes = []
for idx, n in enumerate(agent_json.get("nodes", [])):
block_id = n.get("block_id")
if not block_id:
node_id = n.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Node '{node_id}' is missing required field 'block_id'"
)
for n in agent_json.get("nodes", []):
node = Node(
id=n.get("id", str(uuid.uuid4())),
block_id=block_id,
block_id=n["block_id"],
input_default=n.get("input_default", {}),
metadata=n.get("metadata", {}),
)
nodes.append(node)

links = []
for idx, link_data in enumerate(agent_json.get("links", [])):
source_id = link_data.get("source_id")
sink_id = link_data.get("sink_id")
source_name = link_data.get("source_name")
sink_name = link_data.get("sink_name")

missing_fields = []
if not source_id:
missing_fields.append("source_id")
if not sink_id:
missing_fields.append("sink_id")
if not source_name:
missing_fields.append("source_name")
if not sink_name:
missing_fields.append("sink_name")

if missing_fields:
link_id = link_data.get("id", f"index_{idx}")
raise AgentJsonValidationError(
f"Link '{link_id}' is missing required fields: {', '.join(missing_fields)}"
)

for link_data in agent_json.get("links", []):
link = Link(
id=link_data.get("id", str(uuid.uuid4())),
source_id=source_id,
sink_id=sink_id,
source_name=source_name,
sink_name=sink_name,
source_id=link_data["source_id"],
sink_id=link_data["sink_id"],
source_name=link_data["source_name"],
sink_name=link_data["sink_name"],
is_static=link_data.get("is_static", False),
)
links.append(link)
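A minimal standalone re-statement (hypothetical helper name, not the module above) of the fail-fast validation the new code introduces: malformed nodes are rejected with a clear message instead of a KeyError deep inside graph construction.

def validate_node_block_ids(agent_json: dict) -> None:
    # Mirrors the node check in json_to_graph: every node needs a block_id
    for idx, n in enumerate(agent_json.get("nodes", [])):
        if not n.get("block_id"):
            node_id = n.get("id", f"index_{idx}")
            raise ValueError(f"Node '{node_id}' is missing required field 'block_id'")

try:
    validate_node_block_ids({"nodes": [{"id": "n1"}]})
except ValueError as e:
    print(e)  # Node 'n1' is missing required field 'block_id'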
@@ -662,40 +133,22 @@ def _reassign_node_ids(graph: Graph) -> None:

This is needed when creating a new version to avoid unique constraint violations.
"""
# Create mapping from old node IDs to new UUIDs
id_map = {node.id: str(uuid.uuid4()) for node in graph.nodes}

# Reassign node IDs
for node in graph.nodes:
node.id = id_map[node.id]

# Update link references to use new node IDs
for link in graph.links:
link.id = str(uuid.uuid4())
link.id = str(uuid.uuid4()) # Also give links new IDs
if link.source_id in id_map:
link.source_id = id_map[link.source_id]
if link.sink_id in id_map:
link.sink_id = id_map[link.sink_id]


def _populate_agent_executor_user_ids(agent_json: dict[str, Any], user_id: str) -> None:
"""Populate user_id in AgentExecutorBlock nodes.

The external agent generator creates AgentExecutorBlock nodes with empty user_id.
This function fills in the actual user_id so sub-agents run with correct permissions.

Args:
agent_json: Agent JSON dict (modified in place)
user_id: User ID to set
"""
for node in agent_json.get("nodes", []):
if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
input_default = node.get("input_default") or {}
if not input_default.get("user_id"):
input_default["user_id"] = user_id
node["input_default"] = input_default
logger.debug(
f"Set user_id for AgentExecutorBlock node {node.get('id')}"
)

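A small standalone illustration of the in-place user_id fill above (the block ID constant is taken from this diff; the helper name, node data, and user ID are made up):

AGENT_EXECUTOR_BLOCK_ID = "e189baac-8c20-45a1-94a7-55177ea42565"

def fill_executor_user_ids(agent_json: dict, user_id: str) -> None:
    # Same pattern as _populate_agent_executor_user_ids: only touch executor nodes
    # whose input_default has no user_id yet, and mutate the JSON in place.
    for node in agent_json.get("nodes", []):
        if node.get("block_id") == AGENT_EXECUTOR_BLOCK_ID:
            defaults = node.get("input_default") or {}
            if not defaults.get("user_id"):
                defaults["user_id"] = user_id
                node["input_default"] = defaults

agent = {"nodes": [{"id": "n1", "block_id": AGENT_EXECUTOR_BLOCK_ID, "input_default": {}}]}
fill_executor_user_ids(agent, "user-123")
print(agent["nodes"][0]["input_default"])  # {'user_id': 'user-123'}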
async def save_agent_to_library(
agent_json: dict[str, Any], user_id: str, is_update: bool = False
) -> tuple[Graph, Any]:
@@ -709,27 +162,33 @@ async def save_agent_to_library(
|
||||
Returns:
|
||||
Tuple of (created Graph, LibraryAgent)
|
||||
"""
|
||||
# Populate user_id in AgentExecutorBlock nodes before conversion
|
||||
_populate_agent_executor_user_ids(agent_json, user_id)
|
||||
from backend.data.graph import get_graph_all_versions
|
||||
|
||||
graph = json_to_graph(agent_json)
|
||||
|
||||
if is_update:
|
||||
# For updates, keep the same graph ID but increment version
|
||||
# and reassign node/link IDs to avoid conflicts
|
||||
if graph.id:
|
||||
existing_versions = await get_graph_all_versions(graph.id, user_id)
|
||||
if existing_versions:
|
||||
latest_version = max(v.version for v in existing_versions)
|
||||
graph.version = latest_version + 1
|
||||
# Reassign node IDs (but keep graph ID the same)
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Updating agent {graph.id} to version {graph.version}")
|
||||
else:
|
||||
# For new agents, always generate a fresh UUID to avoid collisions
|
||||
graph.id = str(uuid.uuid4())
|
||||
graph.version = 1
|
||||
# Reassign all node IDs as well
|
||||
_reassign_node_ids(graph)
|
||||
logger.info(f"Creating new agent with ID {graph.id}")
|
||||
|
||||
# Save to database
|
||||
created_graph = await create_graph(graph, user_id)
|
||||
|
||||
# Add to user's library (or update existing library agent)
|
||||
library_agents = await library_db.create_library_agent(
|
||||
graph=created_graph,
|
||||
user_id=user_id,
|
||||
@@ -741,31 +200,25 @@ async def save_agent_to_library(
|
||||
|
||||
|
||||
async def get_agent_as_json(
|
||||
agent_id: str, user_id: str | None
|
||||
graph_id: str, user_id: str | None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Fetch an agent and convert to JSON format for editing.
|
||||
|
||||
Args:
|
||||
agent_id: Graph ID or library agent ID
|
||||
graph_id: Graph ID or library agent ID
|
||||
user_id: User ID
|
||||
|
||||
Returns:
|
||||
Agent as JSON dict or None if not found
|
||||
"""
|
||||
graph = await get_graph(agent_id, version=None, user_id=user_id)
|
||||
|
||||
if not graph and user_id:
|
||||
try:
|
||||
library_agent = await library_db.get_library_agent(agent_id, user_id)
|
||||
graph = await get_graph(
|
||||
library_agent.graph_id, version=None, user_id=user_id
|
||||
)
|
||||
except NotFoundError:
|
||||
pass
|
||||
from backend.data.graph import get_graph
|
||||
|
||||
# Try to get the graph (version=None gets the active version)
|
||||
graph = await get_graph(graph_id, version=None, user_id=user_id)
|
||||
if not graph:
|
||||
return None
|
||||
|
||||
# Convert to JSON format
|
||||
nodes = []
|
||||
for node in graph.nodes:
|
||||
nodes.append(
|
||||
@@ -803,9 +256,7 @@ async def get_agent_as_json(
|
||||
|
||||
|
||||
async def generate_agent_patch(
|
||||
update_request: str,
|
||||
current_agent: dict[str, Any],
|
||||
library_agents: list[AgentSummary] | None = None,
|
||||
update_request: str, current_agent: dict[str, Any]
|
||||
) -> dict[str, Any] | None:
|
||||
"""Update an existing agent using natural language.
|
||||
|
||||
@@ -817,7 +268,6 @@ async def generate_agent_patch(
|
||||
Args:
|
||||
update_request: Natural language description of changes
|
||||
current_agent: Current agent JSON
|
||||
library_agents: User's library agents available for sub-agent composition
|
||||
|
||||
Returns:
|
||||
Updated agent JSON, clarifying questions dict {"type": "clarifying_questions", ...},
|
||||
@@ -828,6 +278,4 @@ async def generate_agent_patch(
|
||||
"""
|
||||
_check_service_configured()
|
||||
logger.info("Calling external Agent Generator service for generate_agent_patch")
|
||||
return await generate_agent_patch_external(
|
||||
update_request, current_agent, _to_dict_list(library_agents)
|
||||
)
|
||||
return await generate_agent_patch_external(update_request, current_agent)
|
||||
|
||||
@@ -1,43 +1,11 @@
"""Error handling utilities for agent generator."""

import re


def _sanitize_error_details(details: str) -> str:
"""Sanitize error details to remove sensitive information.

Strips common patterns that could expose internal system info:
- File paths (Unix and Windows)
- Database connection strings
- URLs with credentials
- Stack trace internals

Args:
details: Raw error details string

Returns:
Sanitized error details safe for user display
"""
sanitized = re.sub(
r"/[a-zA-Z0-9_./\-]+\.(py|js|ts|json|yaml|yml)", "[path]", details
)
sanitized = re.sub(r"[A-Z]:\\[a-zA-Z0-9_\\.\\-]+", "[path]", sanitized)
sanitized = re.sub(
r"(postgres|mysql|mongodb|redis)://[^\s]+", "[database_url]", sanitized
)
sanitized = re.sub(r"https?://[^:]+:[^@]+@[^\s]+", "[url]", sanitized)
sanitized = re.sub(r", line \d+", "", sanitized)
sanitized = re.sub(r'File "[^"]+",?', "", sanitized)

return sanitized.strip()


def get_user_message_for_error(
error_type: str,
operation: str = "process the request",
llm_parse_message: str | None = None,
validation_message: str | None = None,
error_details: str | None = None,
) -> str:
"""Get a user-friendly error message based on error type.

@@ -51,45 +19,25 @@ def get_user_message_for_error(
message (e.g., "analyze the goal", "generate the agent")
llm_parse_message: Custom message for llm_parse_error type
validation_message: Custom message for validation_error type
error_details: Optional additional details about the error

Returns:
User-friendly error message suitable for display to the user
"""
base_message = ""

if error_type == "llm_parse_error":
base_message = (
return (
llm_parse_message
or "The AI had trouble processing this request. Please try again."
)
elif error_type == "validation_error":
base_message = (
return (
validation_message
or "The generated agent failed validation. "
"This usually happens when the agent structure doesn't match "
"what the platform expects. Please try simplifying your goal "
"or breaking it into smaller parts."
or "The request failed validation. Please try rephrasing."
)
elif error_type == "patch_error":
base_message = (
"Failed to apply the changes. The modification couldn't be "
"validated. Please try a different approach or simplify the change."
)
return "Failed to apply the changes. Please try a different approach."
elif error_type in ("timeout", "llm_timeout"):
base_message = (
"The request took too long to process. This can happen with "
"complex agents. Please try again or simplify your goal."
)
return "The request took too long. Please try again."
elif error_type in ("rate_limit", "llm_rate_limit"):
base_message = "The service is currently busy. Please try again in a moment."
return "The service is currently busy. Please try again in a moment."
else:
base_message = f"Failed to {operation}. Please try again."

if error_details:
details = _sanitize_error_details(error_details)
if len(details) > 200:
details = details[:200] + "..."
base_message += f"\n\nTechnical details: {details}"

return base_message
return f"Failed to {operation}. Please try again."

@@ -117,16 +117,13 @@ def _get_client() -> httpx.AsyncClient:


async def decompose_goal_external(
description: str,
context: str = "",
library_agents: list[dict[str, Any]] | None = None,
description: str, context: str = ""
) -> dict[str, Any] | None:
"""Call the external service to decompose a goal.

Args:
description: Natural language goal description
context: Additional context (e.g., answers to previous questions)
library_agents: User's library agents available for sub-agent composition

Returns:
Dict with either:
@@ -139,12 +136,11 @@ async def decompose_goal_external(
"""
client = _get_client()

if context:
description = f"{description}\n\nAdditional context from user:\n{context}"

# Build the request payload
payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context

try:
response = await client.post("/api/decompose-description", json=payload)
@@ -211,25 +207,21 @@ async def decompose_goal_external(

async def generate_agent_external(
instructions: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
) -> dict[str, Any] | None:
"""Call the external service to generate an agent from instructions.

Args:
instructions: Structured instructions from decompose_goal
library_agents: User's library agents available for sub-agent composition

Returns:
Agent JSON dict on success, or error dict {"type": "error", ...} on error
"""
client = _get_client()

payload: dict[str, Any] = {"instructions": instructions}
if library_agents:
payload["library_agents"] = library_agents

try:
response = await client.post("/api/generate-agent", json=payload)
response = await client.post(
"/api/generate-agent", json={"instructions": instructions}
)
response.raise_for_status()
data = response.json()

@@ -237,7 +229,8 @@ async def generate_agent_external(
error_msg = data.get("error", "Unknown error from Agent Generator")
error_type = data.get("error_type", "unknown")
logger.error(
f"Agent Generator generation failed: {error_msg} (type: {error_type})"
f"Agent Generator generation failed: {error_msg} "
f"(type: {error_type})"
)
return _create_error_response(error_msg, error_type)

@@ -258,31 +251,27 @@ async def generate_agent_external(


async def generate_agent_patch_external(
update_request: str,
current_agent: dict[str, Any],
library_agents: list[dict[str, Any]] | None = None,
update_request: str, current_agent: dict[str, Any]
) -> dict[str, Any] | None:
"""Call the external service to generate a patch for an existing agent.

Args:
update_request: Natural language description of changes
current_agent: Current agent JSON
library_agents: User's library agents available for sub-agent composition

Returns:
Updated agent JSON, clarifying questions dict, or error dict on error
"""
client = _get_client()

payload: dict[str, Any] = {
"update_request": update_request,
"current_agent_json": current_agent,
}
if library_agents:
payload["library_agents"] = library_agents

try:
response = await client.post("/api/update-agent", json=payload)
response = await client.post(
"/api/update-agent",
json={
"update_request": update_request,
"current_agent_json": current_agent,
},
)
response.raise_for_status()
data = response.json()


@@ -1,7 +1,6 @@
"""Shared agent search functionality for find_agent and find_library_agent tools."""

import logging
import re
from typing import Literal

from backend.api.features.library import db as library_db
@@ -20,85 +19,6 @@ logger = logging.getLogger(__name__)

SearchSource = Literal["marketplace", "library"]

_UUID_PATTERN = re.compile(
r"^[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}$",
re.IGNORECASE,
)


def _is_uuid(text: str) -> bool:
"""Check if text is a valid UUID v4."""
return bool(_UUID_PATTERN.match(text.strip()))


async def _get_library_agent_by_id(user_id: str, agent_id: str) -> AgentInfo | None:
"""Fetch a library agent by ID (library agent ID or graph_id).

Tries multiple lookup strategies:
1. First by graph_id (AgentGraph primary key)
2. Then by library agent ID (LibraryAgent primary key)

Args:
user_id: The user ID
agent_id: The ID to look up (can be graph_id or library agent ID)

Returns:
AgentInfo if found, None otherwise
"""
try:
agent = await library_db.get_library_agent_by_graph_id(user_id, agent_id)
if agent:
logger.debug(f"Found library agent by graph_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by graph_id {agent_id}: {e}",
exc_info=True,
)

try:
agent = await library_db.get_library_agent(agent_id, user_id)
if agent:
logger.debug(f"Found library agent by library_id: {agent.name}")
return AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
except NotFoundError:
logger.debug(f"Library agent not found by library_id: {agent_id}")
except DatabaseError:
raise
except Exception as e:
logger.warning(
f"Could not fetch library agent by library_id {agent_id}: {e}",
exc_info=True,
)

return None


async def search_agents(
query: str,
@@ -149,37 +69,29 @@ async def search_agents(
is_featured=False,
)
)
else:
if _is_uuid(query):
logger.info(f"Query looks like UUID, trying direct lookup: {query}")
agent = await _get_library_agent_by_id(user_id, query) # type: ignore[arg-type]
if agent:
agents.append(agent)
logger.info(f"Found agent by direct ID lookup: {agent.name}")

if not agents:
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
else: # library
logger.info(f"Searching user library for: {query}")
results = await library_db.list_library_agents(
user_id=user_id, # type: ignore[arg-type]
search_term=query,
page_size=10,
)
for agent in results.agents:
agents.append(
AgentInfo(
id=agent.id,
name=agent.name,
description=agent.description or "",
source="library",
in_library=True,
creator=agent.creator_name,
status=agent.status.value,
can_access_graph=agent.can_access_graph,
has_external_trigger=agent.has_external_trigger,
new_output=agent.new_output,
graph_id=agent.graph_id,
)
)
logger.info(f"Found {len(agents)} agents in {source}")
except NotFoundError:
pass

@@ -8,9 +8,7 @@ from backend.api.features.chat.model import ChatSession
from .agent_generator import (
AgentGeneratorNotConfiguredError,
decompose_goal,
enrich_library_agents_from_steps,
generate_agent,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -105,24 +103,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)

library_agents = None
if user_id:
try:
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=description,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")

# Step 1: Decompose goal into steps
try:
decomposition_result = await decompose_goal(
description, context, library_agents
)
decomposition_result = await decompose_goal(description, context)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -141,6 +124,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)

# Check if the result is an error from the external service
if decomposition_result.get("type") == "error":
error_msg = decomposition_result.get("error", "Unknown error")
error_type = decomposition_result.get("error_type", "unknown")
@@ -160,6 +144,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)

# Check if LLM returned clarifying questions
if decomposition_result.get("type") == "clarifying_questions":
questions = decomposition_result.get("questions", [])
return ClarificationNeededResponse(
@@ -178,6 +163,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)

# Check for unachievable/vague goals
if decomposition_result.get("type") == "unachievable_goal":
suggested = decomposition_result.get("suggested_goal", "")
reason = decomposition_result.get("reason", "")
@@ -204,22 +190,9 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)

if user_id and library_agents is not None:
try:
library_agents = await enrich_library_agents_from_steps(
user_id=user_id,
decomposition_result=decomposition_result,
existing_agents=library_agents,
include_marketplace=True,
)
logger.debug(
f"After enrichment: {len(library_agents)} total agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to enrich library agents from steps: {e}")

# Step 2: Generate agent JSON (external service handles fixing and validation)
try:
agent_json = await generate_agent(decomposition_result, library_agents)
agent_json = await generate_agent(decomposition_result)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -238,6 +211,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)

# Check if the result is an error from the external service
if isinstance(agent_json, dict) and agent_json.get("type") == "error":
error_msg = agent_json.get("error", "Unknown error")
error_type = agent_json.get("error_type", "unknown")
@@ -245,12 +219,7 @@ class CreateAgentTool(BaseTool):
error_type,
operation="generate the agent",
llm_parse_message="The AI had trouble generating the agent. Please try again or simplify your goal.",
validation_message=(
"I wasn't able to create a valid agent for this request. "
"The generated workflow had some structural issues. "
"Please try simplifying your goal or breaking it into smaller steps."
),
error_details=error_msg,
validation_message="The generated agent failed validation. Please try rephrasing your goal.",
)
return ErrorResponse(
message=user_message,
@@ -268,6 +237,7 @@ class CreateAgentTool(BaseTool):
node_count = len(agent_json.get("nodes", []))
link_count = len(agent_json.get("links", []))

# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -282,6 +252,7 @@ class CreateAgentTool(BaseTool):
session_id=session_id,
)

# Save to library
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -299,7 +270,7 @@ class CreateAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

@@ -9,7 +9,6 @@ from .agent_generator import (
AgentGeneratorNotConfiguredError,
generate_agent_patch,
get_agent_as_json,
get_all_relevant_agents_for_generation,
get_user_message_for_error,
save_agent_to_library,
)
@@ -118,6 +117,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)

# Step 1: Fetch current agent
current_agent = await get_agent_as_json(agent_id, user_id)

if current_agent is None:
@@ -127,30 +127,14 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)

library_agents = None
if user_id:
try:
graph_id = current_agent.get("id")
library_agents = await get_all_relevant_agents_for_generation(
user_id=user_id,
search_query=changes,
exclude_graph_id=graph_id,
include_marketplace=True,
)
logger.debug(
f"Found {len(library_agents)} relevant agents for sub-agent composition"
)
except Exception as e:
logger.warning(f"Failed to fetch library agents: {e}")

# Build the update request with context
update_request = changes
if context:
update_request = f"{changes}\n\nAdditional context:\n{context}"

# Step 2: Generate updated agent (external service handles fixing and validation)
try:
result = await generate_agent_patch(
update_request, current_agent, library_agents
)
result = await generate_agent_patch(update_request, current_agent)
except AgentGeneratorNotConfiguredError:
return ErrorResponse(
message=(
@@ -169,6 +153,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)

# Check if the result is an error from the external service
if isinstance(result, dict) and result.get("type") == "error":
error_msg = result.get("error", "Unknown error")
error_type = result.get("error_type", "unknown")
@@ -177,7 +162,6 @@ class EditAgentTool(BaseTool):
operation="generate the changes",
llm_parse_message="The AI had trouble generating the changes. Please try again or simplify your request.",
validation_message="The generated changes failed validation. Please try rephrasing your request.",
error_details=error_msg,
)
return ErrorResponse(
message=user_message,
@@ -191,6 +175,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)

# Check if LLM returned clarifying questions
if result.get("type") == "clarifying_questions":
questions = result.get("questions", [])
return ClarificationNeededResponse(
@@ -209,6 +194,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)

# Result is the updated agent JSON
updated_agent = result

agent_name = updated_agent.get("name", "Updated Agent")
@@ -216,6 +202,7 @@ class EditAgentTool(BaseTool):
node_count = len(updated_agent.get("nodes", []))
link_count = len(updated_agent.get("links", []))

# Step 3: Preview or save
if not save:
return AgentPreviewResponse(
message=(
@@ -231,6 +218,7 @@ class EditAgentTool(BaseTool):
session_id=session_id,
)

# Save to library (creates a new version)
if not user_id:
return ErrorResponse(
message="You must be logged in to save agents.",
@@ -248,7 +236,7 @@ class EditAgentTool(BaseTool):
agent_id=created_graph.id,
agent_name=created_graph.name,
library_agent_id=library_agent.id,
library_agent_link=f"/library/agents/{library_agent.id}",
library_agent_link=f"/library/{library_agent.id}",
agent_page_link=f"/build?flowID={created_graph.id}",
session_id=session_id,
)

@@ -38,8 +38,6 @@ class ResponseType(str, Enum):
OPERATION_STARTED = "operation_started"
OPERATION_PENDING = "operation_pending"
OPERATION_IN_PROGRESS = "operation_in_progress"
# Input validation
INPUT_VALIDATION_ERROR = "input_validation_error"


# Base response model
@@ -70,10 +68,6 @@ class AgentInfo(BaseModel):
has_external_trigger: bool | None = None
new_output: bool | None = None
graph_id: str | None = None
inputs: dict[str, Any] | None = Field(
default=None,
description="Input schema for the agent, including field names, types, and defaults",
)


class AgentsFoundResponse(ToolResponseBase):
@@ -200,20 +194,6 @@ class ErrorResponse(ToolResponseBase):
details: dict[str, Any] | None = None


class InputValidationErrorResponse(ToolResponseBase):
"""Response when run_agent receives unknown input fields."""

type: ResponseType = ResponseType.INPUT_VALIDATION_ERROR
unrecognized_fields: list[str] = Field(
description="List of input field names that were not recognized"
)
inputs: dict[str, Any] = Field(
description="The agent's valid input schema for reference"
)
graph_id: str | None = None
graph_version: int | None = None


# Agent output models
class ExecutionOutputInfo(BaseModel):
"""Summary of a single execution's outputs."""

@@ -30,7 +30,6 @@ from .models import (
ErrorResponse,
ExecutionOptions,
ExecutionStartedResponse,
InputValidationErrorResponse,
SetupInfo,
SetupRequirementsResponse,
ToolResponseBase,
@@ -274,22 +273,6 @@ class RunAgentTool(BaseTool):
input_properties = graph.input_schema.get("properties", {})
required_fields = set(graph.input_schema.get("required", []))
provided_inputs = set(params.inputs.keys())
valid_fields = set(input_properties.keys())

# Check for unknown input fields
unrecognized_fields = provided_inputs - valid_fields
if unrecognized_fields:
return InputValidationErrorResponse(
message=(
f"Unknown input field(s) provided: {', '.join(sorted(unrecognized_fields))}. "
f"Agent was not executed. Please use the correct field names from the schema."
),
session_id=session_id,
unrecognized_fields=sorted(unrecognized_fields),
inputs=graph.input_schema,
graph_id=graph.id,
graph_version=graph.version,
)

# If agent has inputs but none were provided AND use_defaults is not set,
# always show what's available first so user can decide

@@ -402,42 +402,3 @@ async def test_run_agent_schedule_without_name(setup_test_data):
# Should return error about missing schedule_name
assert result_data.get("type") == "error"
assert "schedule_name" in result_data["message"].lower()


@pytest.mark.asyncio(loop_scope="session")
async def test_run_agent_rejects_unknown_input_fields(setup_test_data):
"""Test that run_agent returns input_validation_error for unknown input fields."""
user = setup_test_data["user"]
store_submission = setup_test_data["store_submission"]

tool = RunAgentTool()
agent_marketplace_id = f"{user.email.split('@')[0]}/{store_submission.slug}"
session = make_session(user_id=user.id)

# Execute with unknown input field names
response = await tool.execute(
user_id=user.id,
session_id=str(uuid.uuid4()),
tool_call_id=str(uuid.uuid4()),
username_agent_slug=agent_marketplace_id,
inputs={
"unknown_field": "some value",
"another_unknown": "another value",
},
session=session,
)

assert response is not None
assert hasattr(response, "output")
assert isinstance(response.output, str)
result_data = orjson.loads(response.output)

# Should return input_validation_error type with unrecognized fields
assert result_data.get("type") == "input_validation_error"
assert "unrecognized_fields" in result_data
assert set(result_data["unrecognized_fields"]) == {
"another_unknown",
"unknown_field",
}
assert "inputs" in result_data # Contains the valid schema
assert "Agent was not executed" in result_data["message"]

@@ -5,8 +5,6 @@ import uuid
from collections import defaultdict
from typing import Any

from pydantic_core import PydanticUndefined

from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block
from backend.data.execution import ExecutionContext
@@ -77,22 +75,15 @@ class RunBlockTool(BaseTool):
self,
user_id: str,
block: Any,
input_data: dict[str, Any] | None = None,
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Check if user has required credentials for a block.

Args:
user_id: User ID
block: Block to check credentials for
input_data: Input data for the block (used to determine provider via discriminator)

Returns:
tuple[matched_credentials, missing_credentials]
"""
matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
input_data = input_data or {}

# Get credential field info from block's input schema
credentials_fields_info = block.input_schema.get_credentials_fields_info()
@@ -105,33 +96,14 @@ class RunBlockTool(BaseTool):
available_creds = await creds_manager.store.get_all_creds(user_id)

for field_name, field_info in credentials_fields_info.items():
effective_field_info = field_info
if field_info.discriminator and field_info.discriminator_mapping:
# Get discriminator from input, falling back to schema default
discriminator_value = input_data.get(field_info.discriminator)
if discriminator_value is None:
field = block.input_schema.model_fields.get(
field_info.discriminator
)
if field and field.default is not PydanticUndefined:
discriminator_value = field.default

if (
discriminator_value
and discriminator_value in field_info.discriminator_mapping
):
effective_field_info = field_info.discriminate(discriminator_value)
logger.debug(
f"Discriminated provider for {field_name}: "
f"{discriminator_value} -> {effective_field_info.provider}"
)

# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in effective_field_info.provider
and cred.type in effective_field_info.supported_types
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
@@ -145,8 +117,8 @@ class RunBlockTool(BaseTool):
)
else:
# Create a placeholder for the missing credential
provider = next(iter(effective_field_info.provider), "unknown")
cred_type = next(iter(effective_field_info.supported_types), "api_key")
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
@@ -214,9 +186,10 @@ class RunBlockTool(BaseTool):

logger.info(f"Executing block {block.name} ({block_id}) for user {user_id}")

# Check credentials
creds_manager = IntegrationCredentialsManager()
matched_credentials, missing_credentials = await self._check_block_credentials(
user_id, block, input_data
user_id, block
)

if missing_credentials:

@@ -8,7 +8,7 @@ from backend.api.features.library import model as library_model
from backend.api.features.store import db as store_db
from backend.data import graph as graph_db
from backend.data.graph import GraphModel
from backend.data.model import Credentials, CredentialsFieldInfo, CredentialsMetaInput
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import NotFoundError

@@ -266,14 +266,13 @@ async def match_user_credentials_to_graph(
credential_requirements,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes
# Find first matching credential by provider and type
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements)
),
None,
)
@@ -297,17 +296,10 @@ async def match_user_credentials_to_graph(
f"{credential_field_name} (validation failed: {e})"
)
else:
# Build a helpful error message including scope requirements
error_parts = [
f"provider in {list(credential_requirements.provider)}",
f"type in {list(credential_requirements.supported_types)}",
]
if credential_requirements.required_scopes:
error_parts.append(
f"scopes including {list(credential_requirements.required_scopes)}"
)
missing_creds.append(
f"{credential_field_name} (requires {', '.join(error_parts)})"
f"{credential_field_name} "
f"(requires provider in {list(credential_requirements.provider)}, "
f"type in {list(credential_requirements.supported_types)})"
)

logger.info(
@@ -317,28 +309,6 @@ async def match_user_credentials_to_graph(
return graph_credentials_inputs, missing_creds


def _credential_has_required_scopes(
credential: Credentials,
requirements: CredentialsFieldInfo,
) -> bool:
"""
Check if a credential has all the scopes required by the block.

For OAuth2 credentials, verifies that the credential's scopes are a superset
of the required scopes. For other credential types, returns True (no scope check).
"""
# Only OAuth2 credentials have scopes to check
if credential.type != "oauth2":
return True

# If no scopes are required, any credential matches
if not requirements.required_scopes:
return True

# Check that credential scopes are a superset of required scopes
return set(credential.scopes).issuperset(requirements.required_scopes)


async def check_user_has_required_credentials(
user_id: str,
required_credentials: list[CredentialsMetaInput],

@@ -39,7 +39,6 @@ async def list_library_agents(
sort_by: library_model.LibraryAgentSort = library_model.LibraryAgentSort.UPDATED_AT,
page: int = 1,
page_size: int = 50,
include_executions: bool = False,
) -> library_model.LibraryAgentResponse:
"""
Retrieves a paginated list of LibraryAgent records for a given user.
@@ -50,9 +49,6 @@ async def list_library_agents(
sort_by: Sorting field (createdAt, updatedAt, isFavorite, isCreatedByUser).
page: Current page (1-indexed).
page_size: Number of items per page.
include_executions: Whether to include execution data for status calculation.
Defaults to False for performance (UI fetches status separately).
Set to True when accurate status/metrics are needed (e.g., agent generator).

Returns:
A LibraryAgentResponse containing the list of agents and pagination details.
@@ -80,6 +76,7 @@ async def list_library_agents(
"isArchived": False,
}

# Build search filter if applicable
if search_term:
where_clause["OR"] = [
{
@@ -96,6 +93,7 @@ async def list_library_agents(
},
]

# Determine sorting
order_by: prisma.types.LibraryAgentOrderByInput | None = None

if sort_by == library_model.LibraryAgentSort.CREATED_AT:
@@ -107,7 +105,7 @@ async def list_library_agents(
library_agents = await prisma.models.LibraryAgent.prisma().find_many(
where=where_clause,
include=library_agent_include(
user_id, include_nodes=False, include_executions=include_executions
user_id, include_nodes=False, include_executions=False
),
order=order_by,
skip=(page - 1) * page_size,

@@ -9,7 +9,6 @@ import pydantic
from backend.data.block import BlockInput
from backend.data.graph import GraphModel, GraphSettings, GraphTriggerInfo
from backend.data.model import CredentialsMetaInput, is_credentials_field_name
from backend.util.json import loads as json_loads
from backend.util.models import Pagination

if TYPE_CHECKING:
@@ -17,10 +16,10 @@ if TYPE_CHECKING:


class LibraryAgentStatus(str, Enum):
COMPLETED = "COMPLETED"
HEALTHY = "HEALTHY"
WAITING = "WAITING"
ERROR = "ERROR"
COMPLETED = "COMPLETED" # All runs completed
HEALTHY = "HEALTHY" # Agent is running (not all runs have completed)
WAITING = "WAITING" # Agent is queued or waiting to start
ERROR = "ERROR" # Agent is in an error state


class MarketplaceListingCreator(pydantic.BaseModel):
@@ -40,30 +39,6 @@ class MarketplaceListing(pydantic.BaseModel):
creator: MarketplaceListingCreator


class RecentExecution(pydantic.BaseModel):
"""Summary of a recent execution for quality assessment.

Used by the LLM to understand the agent's recent performance with specific examples
rather than just aggregate statistics.
"""

status: str
correctness_score: float | None = None
activity_summary: str | None = None


def _parse_settings(settings: dict | str | None) -> GraphSettings:
"""Parse settings from database, handling both dict and string formats."""
if settings is None:
return GraphSettings()
try:
if isinstance(settings, str):
settings = json_loads(settings)
return GraphSettings.model_validate(settings)
except Exception:
return GraphSettings()


class LibraryAgent(pydantic.BaseModel):
"""
Represents an agent in the library, including metadata for display and
@@ -73,7 +48,7 @@ class LibraryAgent(pydantic.BaseModel):
id: str
graph_id: str
graph_version: int
owner_user_id: str
owner_user_id: str # ID of user who owns/created this agent graph

image_url: str | None

@@ -89,7 +64,7 @@ class LibraryAgent(pydantic.BaseModel):
description: str
instructions: str | None = None

input_schema: dict[str, Any]
input_schema: dict[str, Any] # Should be BlockIOObjectSubSchema in frontend
output_schema: dict[str, Any]
credentials_input_schema: dict[str, Any] | None = pydantic.Field(
description="Input schema for credentials required by the agent",
@@ -106,19 +81,25 @@ class LibraryAgent(pydantic.BaseModel):
)
trigger_setup_info: Optional[GraphTriggerInfo] = None

# Indicates whether there's a new output (based on recent runs)
new_output: bool
execution_count: int = 0
success_rate: float | None = None
avg_correctness_score: float | None = None
recent_executions: list[RecentExecution] = pydantic.Field(
default_factory=list,
description="List of recent executions with status, score, and summary",
)

# Whether the user can access the underlying graph
can_access_graph: bool

# Indicates if this agent is the latest version
is_latest_version: bool

# Whether the agent is marked as favorite by the user
is_favorite: bool

# Recommended schedule cron (from marketplace agents)
recommended_schedule_cron: str | None = None

# User-specific settings for this library agent
settings: GraphSettings = pydantic.Field(default_factory=GraphSettings)

# Marketplace listing information if the agent has been published
marketplace_listing: Optional["MarketplaceListing"] = None

@staticmethod
@@ -142,6 +123,7 @@ class LibraryAgent(pydantic.BaseModel):
agent_updated_at = agent.AgentGraph.updatedAt
lib_agent_updated_at = agent.updatedAt

# Compute updated_at as the latest between library agent and graph
updated_at = (
max(agent_updated_at, lib_agent_updated_at)
if agent_updated_at
@@ -154,6 +136,7 @@ class LibraryAgent(pydantic.BaseModel):
creator_name = agent.Creator.name or "Unknown"
creator_image_url = agent.Creator.avatarUrl or ""

# Logic to calculate status and new_output
week_ago = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(
days=7
)
@@ -162,55 +145,13 @@ class LibraryAgent(pydantic.BaseModel):
status = status_result.status
new_output = status_result.new_output

execution_count = len(executions)
success_rate: float | None = None
avg_correctness_score: float | None = None
if execution_count > 0:
success_count = sum(
1
for e in executions
if e.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED
)
success_rate = (success_count / execution_count) * 100

correctness_scores = []
for e in executions:
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
correctness_scores.append(float(score))
if correctness_scores:
avg_correctness_score = sum(correctness_scores) / len(
correctness_scores
)

recent_executions: list[RecentExecution] = []
for e in executions:
exec_score: float | None = None
exec_summary: str | None = None
if e.stats and isinstance(e.stats, dict):
score = e.stats.get("correctness_score")
if score is not None and isinstance(score, (int, float)):
exec_score = float(score)
summary = e.stats.get("activity_status")
if summary is not None and isinstance(summary, str):
exec_summary = summary
exec_status = (
e.executionStatus.value
if hasattr(e.executionStatus, "value")
else str(e.executionStatus)
)
recent_executions.append(
RecentExecution(
status=exec_status,
correctness_score=exec_score,
activity_summary=exec_summary,
)
)

# Check if user can access the graph
can_access_graph = agent.AgentGraph.userId == agent.userId

# Hard-coded to True until a method to check is implemented
is_latest_version = True

# Build marketplace_listing if available
marketplace_listing_data = None
if store_listing and store_listing.ActiveVersion and profile:
creator_data = MarketplaceListingCreator(
@@ -249,15 +190,11 @@ class LibraryAgent(pydantic.BaseModel):
has_sensitive_action=graph.has_sensitive_action,
trigger_setup_info=graph.trigger_setup_info,
new_output=new_output,
execution_count=execution_count,
success_rate=success_rate,
avg_correctness_score=avg_correctness_score,
recent_executions=recent_executions,
can_access_graph=can_access_graph,
is_latest_version=is_latest_version,
is_favorite=agent.isFavorite,
recommended_schedule_cron=agent.AgentGraph.recommendedScheduleCron,
settings=_parse_settings(agent.settings),
settings=GraphSettings.model_validate(agent.settings),
marketplace_listing=marketplace_listing_data,
)

@@ -283,15 +220,18 @@ def _calculate_agent_status(
if not executions:
return AgentStatusResult(status=LibraryAgentStatus.COMPLETED, new_output=False)

# Track how many times each execution status appears
status_counts = {status: 0 for status in prisma.enums.AgentExecutionStatus}
new_output = False

for execution in executions:
# Check if there's a completed run more recent than `recent_threshold`
if execution.createdAt >= recent_threshold:
if execution.executionStatus == prisma.enums.AgentExecutionStatus.COMPLETED:
new_output = True
status_counts[execution.executionStatus] += 1

# Determine the final status based on counts
if status_counts[prisma.enums.AgentExecutionStatus.FAILED] > 0:
return AgentStatusResult(status=LibraryAgentStatus.ERROR, new_output=new_output)
elif status_counts[prisma.enums.AgentExecutionStatus.QUEUED] > 0:

@@ -112,7 +112,6 @@ async def get_store_agents(
description=agent["description"],
runs=agent["runs"],
rating=agent["rating"],
agent_graph_id=agent.get("agentGraphId", ""),
)
store_agents.append(store_agent)
except Exception as e:
@@ -171,7 +170,6 @@ async def get_store_agents(
description=agent.description,
runs=agent.runs,
rating=agent.rating,
agent_graph_id=agent.agentGraphId,
)
# Add to the list only if creation was successful
store_agents.append(store_agent)

@@ -600,7 +600,6 @@ async def hybrid_search(
sa.featured,
sa.is_available,
sa.updated_at,
sa."agentGraphId",
-- Searchable text for BM25 reranking
COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
-- Semantic score
@@ -660,7 +659,6 @@ async def hybrid_search(
featured,
is_available,
updated_at,
"agentGraphId",
searchable_text,
semantic_score,
lexical_score,

@@ -38,7 +38,6 @@ class StoreAgent(pydantic.BaseModel):
description: str
runs: int
rating: float
agent_graph_id: str


class StoreAgentsResponse(pydantic.BaseModel):

@@ -26,13 +26,11 @@ def test_store_agent():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
assert agent.slug == "test-agent"
assert agent.agent_name == "Test Agent"
assert agent.runs == 50
assert agent.rating == 4.5
assert agent.agent_graph_id == "test-graph-id"


def test_store_agents_response():
@@ -48,7 +46,6 @@ def test_store_agents_response():
description="Test description",
runs=50,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=store_model.Pagination(

@@ -82,7 +82,6 @@ def test_get_agents_featured(
description="Featured agent description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-1",
)
],
pagination=store_model.Pagination(
@@ -128,7 +127,6 @@ def test_get_agents_by_creator(
description="Creator agent description",
runs=50,
rating=4.0,
agent_graph_id="test-graph-2",
)
],
pagination=store_model.Pagination(
@@ -174,7 +172,6 @@ def test_get_agents_sorted(
description="Top agent description",
runs=1000,
rating=5.0,
agent_graph_id="test-graph-3",
)
],
pagination=store_model.Pagination(
@@ -220,7 +217,6 @@ def test_get_agents_search(
description="Specific search term description",
runs=75,
rating=4.2,
agent_graph_id="test-graph-search",
)
],
pagination=store_model.Pagination(
@@ -266,7 +262,6 @@ def test_get_agents_category(
description="Category agent description",
runs=60,
rating=4.1,
agent_graph_id="test-graph-category",
)
],
pagination=store_model.Pagination(
@@ -311,7 +306,6 @@ def test_get_agents_pagination(
description=f"Agent {i} description",
runs=i * 10,
rating=4.0,
agent_graph_id="test-graph-2",
)
for i in range(5)
],

@@ -33,7 +33,6 @@ class TestCacheDeletion:
description="Test description",
runs=100,
rating=4.5,
agent_graph_id="test-graph-id",
)
],
pagination=Pagination(

@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()

try:
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)

async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)

async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)

await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()
await asyncio.gather(execution_worker(), notification_worker())


async def authenticate_websocket(websocket: WebSocket) -> str:

28
autogpt_platform/backend/backend/blocks/elevenlabs/_auth.py
Normal file
@@ -0,0 +1,28 @@
"""ElevenLabs integration blocks - test credentials and shared utilities."""

from typing import Literal

from pydantic import SecretStr

from backend.data.model import APIKeyCredentials, CredentialsMetaInput
from backend.integrations.providers import ProviderName

TEST_CREDENTIALS = APIKeyCredentials(
id="01234567-89ab-cdef-0123-456789abcdef",
provider="elevenlabs",
api_key=SecretStr("mock-elevenlabs-api-key"),
title="Mock ElevenLabs API key",
expires_at=None,
)

TEST_CREDENTIALS_INPUT = {
"provider": TEST_CREDENTIALS.provider,
"id": TEST_CREDENTIALS.id,
"type": TEST_CREDENTIALS.type,
"title": TEST_CREDENTIALS.title,
}

ElevenLabsCredentials = APIKeyCredentials
ElevenLabsCredentialsInput = CredentialsMetaInput[
Literal[ProviderName.ELEVENLABS], Literal["api_key"]
]
@@ -32,7 +32,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter

logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -115,6 +115,7 @@ class LlmModel(str, Enum, metaclass=LlmModelMeta):
CLAUDE_4_5_OPUS = "claude-opus-4-5-20251101"
CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
CLAUDE_4_5_HAIKU = "claude-haiku-4-5-20251001"
CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"
CLAUDE_3_HAIKU = "claude-3-haiku-20240307"
# AI/ML API models
AIML_API_QWEN2_5_72B = "Qwen/Qwen2.5-72B-Instruct-Turbo"
@@ -279,6 +280,9 @@ MODEL_METADATA = {
LlmModel.CLAUDE_4_5_HAIKU: ModelMetadata(
"anthropic", 200000, 64000, "Claude Haiku 4.5", "Anthropic", "Anthropic", 2
), # claude-haiku-4-5-20251001
LlmModel.CLAUDE_3_7_SONNET: ModelMetadata(
"anthropic", 200000, 64000, "Claude 3.7 Sonnet", "Anthropic", "Anthropic", 2
), # claude-3-7-sonnet-20250219
LlmModel.CLAUDE_3_HAIKU: ModelMetadata(
"anthropic", 200000, 4096, "Claude 3 Haiku", "Anthropic", "Anthropic", 1
), # claude-3-haiku-20240307
@@ -634,18 +638,11 @@ async def llm_call(
context_window = llm_model.context_window

if compress_prompt_to_fit:
result = await compress_context(
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
client=None, # Truncation-only, no LLM summarization
reserve=0, # Caller handles response token budget separately
lossy_ok=True,
)
if result.error:
logger.warning(
f"Prompt compression did not meet target: {result.error}. "
f"Proceeding with {result.token_count} tokens."
)
prompt = result.messages

# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)

@@ -1,246 +0,0 @@
import os
import tempfile
from typing import Optional

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.fx.Loop import Loop
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class MediaDurationBlock(Block):

class Input(BlockSchemaInput):
media_in: MediaFileType = SchemaField(
description="Media input (URL, data URI, or local path)."
)
is_video: bool = SchemaField(
description="Whether the media is a video (True) or audio (False).",
default=True,
)

class Output(BlockSchemaOutput):
duration: float = SchemaField(
description="Duration of the media file (in seconds)."
)

def __init__(self):
super().__init__(
id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
description="Block to get the duration of a media file.",
categories={BlockCategory.MULTIMEDIA},
input_schema=MediaDurationBlock.Input,
output_schema=MediaDurationBlock.Output,
)

async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
# 1) Store the input media locally
local_media_path = await store_media_file(
file=input_data.media_in,
execution_context=execution_context,
return_format="for_local_processing",
)
assert execution_context.graph_exec_id is not None
media_abspath = get_exec_file_path(
execution_context.graph_exec_id, local_media_path
)

# 2) Load the clip
if input_data.is_video:
clip = VideoFileClip(media_abspath)
else:
clip = AudioFileClip(media_abspath)

yield "duration", clip.duration


class LoopVideoBlock(Block):
"""
Block for looping (repeating) a video clip until a given duration or number of loops.
"""

class Input(BlockSchemaInput):
video_in: MediaFileType = SchemaField(
description="The input video (can be a URL, data URI, or local path)."
)
# Provide EITHER a `duration` or `n_loops` or both. We'll demonstrate `duration`.
duration: Optional[float] = SchemaField(
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
default=None,
ge=0.0,
)
n_loops: Optional[int] = SchemaField(
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
default=None,
ge=1,
)

class Output(BlockSchemaOutput):
video_out: str = SchemaField(
description="Looped video returned either as a relative path or a data URI."
)

def __init__(self):
super().__init__(
id="8bf9eef6-5451-4213-b265-25306446e94b",
description="Block to loop a video to a given duration or number of repeats.",
categories={BlockCategory.MULTIMEDIA},
input_schema=LoopVideoBlock.Input,
output_schema=LoopVideoBlock.Output,
)

async def run(
self,
input_data: Input,
*,
execution_context: ExecutionContext,
**kwargs,
) -> BlockOutput:
assert execution_context.graph_exec_id is not None
assert execution_context.node_exec_id is not None
graph_exec_id = execution_context.graph_exec_id
node_exec_id = execution_context.node_exec_id

# 1) Store the input video locally
local_video_path = await store_media_file(
file=input_data.video_in,
execution_context=execution_context,
return_format="for_local_processing",
)
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
# 2) Load the clip
|
||||
clip = VideoFileClip(input_abspath)
|
||||
|
||||
# 3) Apply the loop effect
|
||||
looped_clip = clip
|
||||
if input_data.duration:
|
||||
# Loop until we reach the specified duration
|
||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||
elif input_data.n_loops:
|
||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||
else:
|
||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
|
||||
|
||||
class AddAudioToVideoBlock(Block):
|
||||
"""
|
||||
Block that adds (attaches) an audio track to an existing video.
|
||||
Optionally scale the volume of the new track.
|
||||
"""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||
description="Block to attach an audio file to a video file using moviepy.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=AddAudioToVideoBlock.Input,
|
||||
output_schema=AddAudioToVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
file=input_data.audio_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||
|
||||
# 2) Load video + audio with moviepy
|
||||
video_clip = VideoFileClip(video_abspath)
|
||||
audio_clip = AudioFileClip(audio_abspath)
|
||||
# Optionally scale volume
|
||||
if input_data.volume != 1.0:
|
||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||
|
||||
# 3) Attach the new audio track
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
@@ -83,7 +83,7 @@ class StagehandRecommendedLlmModel(str, Enum):
    GPT41_MINI = "gpt-4.1-mini-2025-04-14"

    # Anthropic
    CLAUDE_4_5_SONNET = "claude-sonnet-4-5-20250929"
    CLAUDE_3_7_SONNET = "claude-3-7-sonnet-20250219"

    @property
    def provider_name(self) -> str:
@@ -137,7 +137,7 @@ class StagehandObserveBlock(Block):
        model: StagehandRecommendedLlmModel = SchemaField(
            title="LLM Model",
            description="LLM to use for Stagehand (provider is inferred)",
            default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
            default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
            advanced=False,
        )
        model_credentials: AICredentials = AICredentialsField()
@@ -230,7 +230,7 @@ class StagehandActBlock(Block):
        model: StagehandRecommendedLlmModel = SchemaField(
            title="LLM Model",
            description="LLM to use for Stagehand (provider is inferred)",
            default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
            default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
            advanced=False,
        )
        model_credentials: AICredentials = AICredentialsField()
@@ -330,7 +330,7 @@ class StagehandExtractBlock(Block):
        model: StagehandRecommendedLlmModel = SchemaField(
            title="LLM Model",
            description="LLM to use for Stagehand (provider is inferred)",
            default=StagehandRecommendedLlmModel.CLAUDE_4_5_SONNET,
            default=StagehandRecommendedLlmModel.CLAUDE_3_7_SONNET,
            advanced=False,
        )
        model_credentials: AICredentials = AICredentialsField()

autogpt_platform/backend/backend/blocks/video/__init__.py (new file, 37 lines)
@@ -0,0 +1,37 @@
"""Video editing blocks for AutoGPT Platform.

This module provides blocks for:
- Downloading videos from URLs (YouTube, Vimeo, news sites, direct links)
- Clipping/trimming video segments
- Concatenating multiple videos
- Adding text overlays
- Adding AI-generated narration
- Getting media duration
- Looping videos
- Adding audio to videos

Dependencies:
- yt-dlp: For video downloading
- moviepy: For video editing operations
- elevenlabs: For AI narration (optional)
"""

from backend.blocks.video.add_audio import AddAudioToVideoBlock
from backend.blocks.video.clip import VideoClipBlock
from backend.blocks.video.concat import VideoConcatBlock
from backend.blocks.video.download import VideoDownloadBlock
from backend.blocks.video.duration import MediaDurationBlock
from backend.blocks.video.loop import LoopVideoBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.blocks.video.text_overlay import VideoTextOverlayBlock

__all__ = [
    "AddAudioToVideoBlock",
    "LoopVideoBlock",
    "MediaDurationBlock",
    "VideoClipBlock",
    "VideoConcatBlock",
    "VideoDownloadBlock",
    "VideoNarrationBlock",
    "VideoTextOverlayBlock",
]
autogpt_platform/backend/backend/blocks/video/_utils.py (new file, 34 lines)
@@ -0,0 +1,34 @@
"""Shared utilities for video blocks."""

import os


def get_video_codecs(output_path: str) -> tuple[str, str]:
    """Get appropriate video and audio codecs based on output file extension.

    Args:
        output_path: Path to the output file (used to determine extension)

    Returns:
        Tuple of (video_codec, audio_codec)

    Codec mappings:
    - .mp4: H.264 + AAC (universal compatibility)
    - .webm: VP8 + Vorbis (web streaming)
    - .mkv: H.264 + AAC (container supports many codecs)
    - .mov: H.264 + AAC (Apple QuickTime, widely compatible)
    - .m4v: H.264 + AAC (Apple iTunes/devices)
    - .avi: MPEG-4 + MP3 (legacy Windows)
    """
    ext = os.path.splitext(output_path)[1].lower()

    codec_map: dict[str, tuple[str, str]] = {
        ".mp4": ("libx264", "aac"),
        ".webm": ("libvpx", "libvorbis"),
        ".mkv": ("libx264", "aac"),
        ".mov": ("libx264", "aac"),
        ".m4v": ("libx264", "aac"),
        ".avi": ("mpeg4", "libmp3lame"),
    }

    return codec_map.get(ext, ("libx264", "aac"))
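A quick illustration of how get_video_codecs above picks a codec pair from the output extension (a sketch; the file paths are placeholders):

from backend.blocks.video._utils import get_video_codecs

# The extension drives the codec pair; anything unrecognized falls back to H.264 + AAC.
assert get_video_codecs("/tmp/out.webm") == ("libvpx", "libvorbis")
assert get_video_codecs("/tmp/out.avi") == ("mpeg4", "libmp3lame")
assert get_video_codecs("/tmp/out.xyz") == ("libx264", "aac")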
autogpt_platform/backend/backend/blocks/video/add_audio.py (new file, 102 lines)
@@ -0,0 +1,102 @@
|
||||
"""AddAudioToVideoBlock - Attach an audio track to a video file."""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, store_media_file
|
||||
|
||||
|
||||
class AddAudioToVideoBlock(Block):
|
||||
"""Add (attach) an audio track to an existing video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Video input (URL, data URI, or local path)."
|
||||
)
|
||||
audio_in: MediaFileType = SchemaField(
|
||||
description="Audio input (URL, data URI, or local path)."
|
||||
)
|
||||
volume: float = SchemaField(
|
||||
description="Volume scale for the newly attached audio track (1.0 = original).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Final video (with attached audio), as a path or data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3503748d-62b6-4425-91d6-725b064af509",
|
||||
description="Block to attach an audio file to a video file using moviepy.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=AddAudioToVideoBlock.Input,
|
||||
output_schema=AddAudioToVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the inputs locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
local_audio_path = await store_media_file(
|
||||
file=input_data.audio_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
abs_temp_dir = os.path.join(tempfile.gettempdir(), "exec_file", graph_exec_id)
|
||||
video_abspath = os.path.join(abs_temp_dir, local_video_path)
|
||||
audio_abspath = os.path.join(abs_temp_dir, local_audio_path)
|
||||
|
||||
# 2) Load video + audio with moviepy
|
||||
video_clip = VideoFileClip(video_abspath)
|
||||
audio_clip = AudioFileClip(audio_abspath)
|
||||
# Optionally scale volume
|
||||
if input_data.volume != 1.0:
|
||||
audio_clip = audio_clip.with_volume_scaled(input_data.volume)
|
||||
|
||||
# 3) Attach the new audio track
|
||||
final_clip = video_clip.with_audio(audio_clip)
|
||||
|
||||
# 4) Write to output file
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_audio_attached_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = os.path.join(abs_temp_dir, output_filename)
|
||||
final_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# 5) Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
autogpt_platform/backend/backend/blocks/video/clip.py (new file, 165 lines)
@@ -0,0 +1,165 @@
|
||||
"""VideoClipBlock - Extract a segment from a video file."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoClipBlock(Block):
|
||||
"""Extract a time segment from a video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
start_time: float = SchemaField(description="Start time in seconds", ge=0.0)
|
||||
end_time: float = SchemaField(description="End time in seconds", ge=0.0)
|
||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||
description="Output format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Clipped video file (path or data URI)"
|
||||
)
|
||||
duration: float = SchemaField(description="Clip duration in seconds")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8f539119-e580-4d86-ad41-86fbcb22abb1",
|
||||
description="Extract a time segment from a video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"video_in": "/tmp/test.mp4",
|
||||
"start_time": 0.0,
|
||||
"end_time": 10.0,
|
||||
},
|
||||
test_output=[("video_out", str), ("duration", float)],
|
||||
test_mock={
|
||||
"_clip_video": lambda *args: 10.0,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "clip_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _clip_video(
|
||||
self,
|
||||
video_abspath: str,
|
||||
output_abspath: str,
|
||||
start_time: float,
|
||||
end_time: float,
|
||||
) -> float:
|
||||
"""Extract a clip from a video. Extracted for testability."""
|
||||
clip = None
|
||||
subclip = None
|
||||
try:
|
||||
clip = VideoFileClip(video_abspath)
|
||||
subclip = clip.subclipped(start_time, end_time)
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
subclip.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
return subclip.duration
|
||||
finally:
|
||||
if subclip:
|
||||
subclip.close()
|
||||
if clip:
|
||||
clip.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate time range
|
||||
if input_data.end_time <= input_data.start_time:
|
||||
raise BlockExecutionError(
|
||||
message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Build output path
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_clip_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
# Ensure correct extension
|
||||
base, _ = os.path.splitext(output_filename)
|
||||
output_filename = MediaFileType(f"{base}.{input_data.output_format}")
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
duration = self._clip_video(
|
||||
video_abspath,
|
||||
output_abspath,
|
||||
input_data.start_time,
|
||||
input_data.end_time,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "duration", duration
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to clip video: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
autogpt_platform/backend/backend/blocks/video/concat.py (new file, 197 lines)
@@ -0,0 +1,197 @@
|
||||
"""VideoConcatBlock - Concatenate multiple video clips into one."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from moviepy import concatenate_videoclips
|
||||
from moviepy.video.fx import CrossFadeIn, CrossFadeOut, FadeIn, FadeOut
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoConcatBlock(Block):
|
||||
"""Merge multiple video clips into one continuous video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
videos: list[MediaFileType] = SchemaField(
|
||||
description="List of video files to concatenate (in order)"
|
||||
)
|
||||
transition: Literal["none", "crossfade", "fade_black"] = SchemaField(
|
||||
description="Transition between clips", default="none"
|
||||
)
|
||||
transition_duration: int = SchemaField(
|
||||
description="Transition duration in seconds",
|
||||
default=1,
|
||||
ge=0,
|
||||
advanced=True,
|
||||
)
|
||||
output_format: Literal["mp4", "webm", "mkv", "mov"] = SchemaField(
|
||||
description="Output format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Concatenated video file (path or data URI)"
|
||||
)
|
||||
total_duration: float = SchemaField(description="Total duration in seconds")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="9b0f531a-1118-487f-aeec-3fa63ea8900a",
|
||||
description="Merge multiple video clips into one continuous video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={"videos": ["/tmp/a.mp4", "/tmp/b.mp4"]},
|
||||
test_output=[("video_out", str), ("total_duration", float)],
|
||||
test_mock={
|
||||
"_concat_videos": lambda *args: 20.0,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "concat_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _concat_videos(
|
||||
self,
|
||||
video_abspaths: list[str],
|
||||
output_abspath: str,
|
||||
transition: str,
|
||||
transition_duration: int,
|
||||
) -> float:
|
||||
"""Concatenate videos. Extracted for testability."""
|
||||
clips = []
|
||||
faded_clips = []
|
||||
final = None
|
||||
try:
|
||||
# Load clips
|
||||
for v in video_abspaths:
|
||||
clips.append(VideoFileClip(v))
|
||||
|
||||
if transition == "crossfade":
|
||||
for i, clip in enumerate(clips):
|
||||
effects = []
|
||||
if i > 0:
|
||||
effects.append(CrossFadeIn(transition_duration))
|
||||
if i < len(clips) - 1:
|
||||
effects.append(CrossFadeOut(transition_duration))
|
||||
if effects:
|
||||
clip = clip.with_effects(effects)
|
||||
faded_clips.append(clip)
|
||||
final = concatenate_videoclips(
|
||||
faded_clips,
|
||||
method="compose",
|
||||
padding=-transition_duration,
|
||||
)
|
||||
elif transition == "fade_black":
|
||||
for clip in clips:
|
||||
faded = clip.with_effects(
|
||||
[FadeIn(transition_duration), FadeOut(transition_duration)]
|
||||
)
|
||||
faded_clips.append(faded)
|
||||
final = concatenate_videoclips(faded_clips)
|
||||
else:
|
||||
final = concatenate_videoclips(clips)
|
||||
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
return final.duration
|
||||
finally:
|
||||
if final:
|
||||
final.close()
|
||||
for clip in faded_clips:
|
||||
clip.close()
|
||||
for clip in clips:
|
||||
clip.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
# Validate minimum clips
|
||||
if len(input_data.videos) < 2:
|
||||
raise BlockExecutionError(
|
||||
message="At least 2 videos are required for concatenation",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
)
|
||||
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store all input videos locally
|
||||
video_abspaths = []
|
||||
for video in input_data.videos:
|
||||
local_path = await self._store_input_video(execution_context, video)
|
||||
video_abspaths.append(
|
||||
get_exec_file_path(execution_context.graph_exec_id, local_path)
|
||||
)
|
||||
|
||||
# Build output path
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_concat.{input_data.output_format}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
total_duration = self._concat_videos(
|
||||
video_abspaths,
|
||||
output_abspath,
|
||||
input_data.transition,
|
||||
input_data.transition_duration,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "total_duration", total_duration
|
||||
|
||||
except BlockExecutionError:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to concatenate videos: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
autogpt_platform/backend/backend/blocks/video/download.py (new file, 167 lines)
@@ -0,0 +1,167 @@
|
||||
"""VideoDownloadBlock - Download video from URL (YouTube, Vimeo, news sites, direct links)."""
|
||||
|
||||
import os
|
||||
import typing
|
||||
from typing import Literal
|
||||
|
||||
import yt_dlp
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from yt_dlp import _Params
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoDownloadBlock(Block):
|
||||
"""Download video from URL using yt-dlp."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
url: str = SchemaField(
|
||||
description="URL of the video to download (YouTube, Vimeo, direct link, etc.)",
|
||||
placeholder="https://www.youtube.com/watch?v=...",
|
||||
)
|
||||
quality: Literal["best", "1080p", "720p", "480p", "audio_only"] = SchemaField(
|
||||
description="Video quality preference", default="720p"
|
||||
)
|
||||
output_format: Literal["mp4", "webm", "mkv"] = SchemaField(
|
||||
description="Output video format", default="mp4", advanced=True
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_file: MediaFileType = SchemaField(
|
||||
description="Downloaded video (path or data URI)"
|
||||
)
|
||||
duration: float = SchemaField(description="Video duration in seconds")
|
||||
title: str = SchemaField(description="Video title from source")
|
||||
source_url: str = SchemaField(description="Original source URL")
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="c35daabb-cd60-493b-b9ad-51f1fe4b50c4",
|
||||
description="Download video from URL (YouTube, Vimeo, news sites, direct links)",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"url": "https://www.youtube.com/watch?v=dQw4w9WgXcQ",
|
||||
"quality": "480p",
|
||||
},
|
||||
test_output=[
|
||||
("video_file", str),
|
||||
("duration", float),
|
||||
("title", str),
|
||||
("source_url", str),
|
||||
],
|
||||
test_mock={
|
||||
"_download_video": lambda *args: ("video.mp4", 212.0, "Test Video"),
|
||||
"_store_output_video": lambda *args, **kwargs: "video.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _get_format_string(self, quality: str) -> str:
|
||||
formats = {
|
||||
"best": "bestvideo+bestaudio/best",
|
||||
"1080p": "bestvideo[height<=1080]+bestaudio/best[height<=1080]",
|
||||
"720p": "bestvideo[height<=720]+bestaudio/best[height<=720]",
|
||||
"480p": "bestvideo[height<=480]+bestaudio/best[height<=480]",
|
||||
"audio_only": "bestaudio/best",
|
||||
}
|
||||
return formats.get(quality, formats["720p"])
|
||||
|
||||
def _download_video(
|
||||
self,
|
||||
url: str,
|
||||
quality: str,
|
||||
output_format: str,
|
||||
output_dir: str,
|
||||
node_exec_id: str,
|
||||
) -> tuple[str, float, str]:
|
||||
"""Download video. Extracted for testability."""
|
||||
output_template = os.path.join(
|
||||
output_dir, f"{node_exec_id}_%(title).50s.%(ext)s"
|
||||
)
|
||||
|
||||
ydl_opts: "_Params" = {
|
||||
"format": self._get_format_string(quality),
|
||||
"outtmpl": output_template,
|
||||
"merge_output_format": output_format,
|
||||
"quiet": True,
|
||||
"no_warnings": True,
|
||||
}
|
||||
|
||||
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
|
||||
info = ydl.extract_info(url, download=True)
|
||||
video_path = ydl.prepare_filename(info)
|
||||
|
||||
# Handle format conversion in filename
|
||||
if not video_path.endswith(f".{output_format}"):
|
||||
video_path = video_path.rsplit(".", 1)[0] + f".{output_format}"
|
||||
|
||||
# Return just the filename, not the full path
|
||||
filename = os.path.basename(video_path)
|
||||
|
||||
return (
|
||||
filename,
|
||||
info.get("duration") or 0.0,
|
||||
info.get("title") or "Unknown",
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Get the exec file directory
|
||||
output_dir = get_exec_file_path(execution_context.graph_exec_id, "")
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
filename, duration, title = self._download_video(
|
||||
input_data.url,
|
||||
input_data.quality,
|
||||
input_data.output_format,
|
||||
output_dir,
|
||||
node_exec_id,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, MediaFileType(filename)
|
||||
)
|
||||
|
||||
yield "video_file", video_out
|
||||
yield "duration", duration
|
||||
yield "title", title
|
||||
yield "source_url", input_data.url
|
||||
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to download video: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
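For reference, a standalone sketch of the yt-dlp pattern used by _download_video above, hard-coding the "720p" entry from _get_format_string (the URL and output directory are placeholders, not values from this PR):

import os

import yt_dlp


def download_720p_mp4(url: str, output_dir: str) -> str:
    """Download a video at <=720p and return the local file path (sketch only)."""
    ydl_opts = {
        "format": "bestvideo[height<=720]+bestaudio/best[height<=720]",
        "outtmpl": os.path.join(output_dir, "%(title).50s.%(ext)s"),
        "merge_output_format": "mp4",
        "quiet": True,
        "no_warnings": True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info = ydl.extract_info(url, download=True)
        return ydl.prepare_filename(info)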
autogpt_platform/backend/backend/blocks/video/duration.py (new file, 68 lines)
@@ -0,0 +1,68 @@
"""MediaDurationBlock - Get the duration of a media file."""

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip

from backend.data.block import (
    Block,
    BlockCategory,
    BlockOutput,
    BlockSchemaInput,
    BlockSchemaOutput,
)
from backend.data.execution import ExecutionContext
from backend.data.model import SchemaField
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file


class MediaDurationBlock(Block):
    """Get the duration of a media file (video or audio)."""

    class Input(BlockSchemaInput):
        media_in: MediaFileType = SchemaField(
            description="Media input (URL, data URI, or local path)."
        )
        is_video: bool = SchemaField(
            description="Whether the media is a video (True) or audio (False).",
            default=True,
        )

    class Output(BlockSchemaOutput):
        duration: float = SchemaField(
            description="Duration of the media file (in seconds)."
        )

    def __init__(self):
        super().__init__(
            id="d8b91fd4-da26-42d4-8ecb-8b196c6d84b6",
            description="Block to get the duration of a media file.",
            categories={BlockCategory.MULTIMEDIA},
            input_schema=MediaDurationBlock.Input,
            output_schema=MediaDurationBlock.Output,
        )

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        **kwargs,
    ) -> BlockOutput:
        # 1) Store the input media locally
        local_media_path = await store_media_file(
            file=input_data.media_in,
            execution_context=execution_context,
            return_format="for_local_processing",
        )
        assert execution_context.graph_exec_id is not None
        media_abspath = get_exec_file_path(
            execution_context.graph_exec_id, local_media_path
        )

        # 2) Load the clip
        if input_data.is_video:
            clip = VideoFileClip(media_abspath)
        else:
            clip = AudioFileClip(media_abspath)

        yield "duration", clip.duration
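Since moviepy clips keep the underlying file handle open, a standalone duration probe would typically close the clip explicitly; a minimal sketch of the same idea outside the block framework (not part of this diff):

from moviepy.audio.io.AudioFileClip import AudioFileClip
from moviepy.video.io.VideoFileClip import VideoFileClip


def probe_duration(path: str, is_video: bool = True) -> float:
    """Return the duration of a local media file in seconds (sketch only)."""
    clip = VideoFileClip(path) if is_video else AudioFileClip(path)
    try:
        return float(clip.duration)
    finally:
        clip.close()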
autogpt_platform/backend/backend/blocks/video/loop.py (new file, 104 lines)
@@ -0,0 +1,104 @@
|
||||
"""LoopVideoBlock - Loop a video to a given duration or number of repeats."""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from moviepy.video.fx.Loop import Loop
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class LoopVideoBlock(Block):
|
||||
"""Loop (repeat) a video clip until a given duration or number of loops."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="The input video (can be a URL, data URI, or local path)."
|
||||
)
|
||||
duration: Optional[float] = SchemaField(
|
||||
description="Target duration (in seconds) to loop the video to. If omitted, defaults to no looping.",
|
||||
default=None,
|
||||
ge=0.0,
|
||||
)
|
||||
n_loops: Optional[int] = SchemaField(
|
||||
description="Number of times to repeat the video. If omitted, defaults to 1 (no repeat).",
|
||||
default=None,
|
||||
ge=1,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: str = SchemaField(
|
||||
description="Looped video returned either as a relative path or a data URI."
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8bf9eef6-5451-4213-b265-25306446e94b",
|
||||
description="Block to loop a video to a given duration or number of repeats.",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=LoopVideoBlock.Input,
|
||||
output_schema=LoopVideoBlock.Output,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
execution_context: ExecutionContext,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
assert execution_context.node_exec_id is not None
|
||||
graph_exec_id = execution_context.graph_exec_id
|
||||
node_exec_id = execution_context.node_exec_id
|
||||
|
||||
# 1) Store the input video locally
|
||||
local_video_path = await store_media_file(
|
||||
file=input_data.video_in,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
input_abspath = get_exec_file_path(graph_exec_id, local_video_path)
|
||||
|
||||
# 2) Load the clip
|
||||
clip = VideoFileClip(input_abspath)
|
||||
|
||||
# 3) Apply the loop effect
|
||||
looped_clip = clip
|
||||
if input_data.duration:
|
||||
# Loop until we reach the specified duration
|
||||
looped_clip = looped_clip.with_effects([Loop(duration=input_data.duration)])
|
||||
elif input_data.n_loops:
|
||||
looped_clip = looped_clip.with_effects([Loop(n=input_data.n_loops)])
|
||||
else:
|
||||
raise ValueError("Either 'duration' or 'n_loops' must be provided.")
|
||||
|
||||
assert isinstance(looped_clip, VideoFileClip)
|
||||
|
||||
# 4) Save the looped output
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_looped_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(graph_exec_id, output_filename)
|
||||
|
||||
looped_clip = looped_clip.with_audio(clip.audio)
|
||||
looped_clip.write_videofile(output_abspath, codec="libx264", audio_codec="aac")
|
||||
|
||||
# Return output - for_block_output returns workspace:// if available, else data URI
|
||||
video_out = await store_media_file(
|
||||
file=output_filename,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
autogpt_platform/backend/backend/blocks/video/narration.py (new file, 263 lines)
@@ -0,0 +1,263 @@
|
||||
"""VideoNarrationBlock - Generate AI voice narration and add to video."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from elevenlabs import ElevenLabs
|
||||
from moviepy import CompositeAudioClip
|
||||
from moviepy.audio.io.AudioFileClip import AudioFileClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.elevenlabs._auth import (
|
||||
TEST_CREDENTIALS,
|
||||
TEST_CREDENTIALS_INPUT,
|
||||
ElevenLabsCredentials,
|
||||
ElevenLabsCredentialsInput,
|
||||
)
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import CredentialsField, SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoNarrationBlock(Block):
|
||||
"""Generate AI narration and add to video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
credentials: ElevenLabsCredentialsInput = CredentialsField(
|
||||
description="ElevenLabs API key for voice synthesis"
|
||||
)
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
script: str = SchemaField(description="Narration script text")
|
||||
voice_id: str = SchemaField(
|
||||
description="ElevenLabs voice ID", default="21m00Tcm4TlvDq8ikWAM" # Rachel
|
||||
)
|
||||
model_id: Literal[
|
||||
"eleven_multilingual_v2",
|
||||
"eleven_flash_v2_5",
|
||||
"eleven_turbo_v2_5",
|
||||
"eleven_turbo_v2",
|
||||
] = SchemaField(
|
||||
description="ElevenLabs TTS model",
|
||||
default="eleven_multilingual_v2",
|
||||
)
|
||||
mix_mode: Literal["replace", "mix", "ducking"] = SchemaField(
|
||||
description="How to combine with original audio. 'ducking' applies stronger attenuation than 'mix'.",
|
||||
default="ducking",
|
||||
)
|
||||
narration_volume: float = SchemaField(
|
||||
description="Narration volume (0.0 to 2.0)",
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
le=2.0,
|
||||
advanced=True,
|
||||
)
|
||||
original_volume: float = SchemaField(
|
||||
description="Original audio volume when mixing (0.0 to 1.0)",
|
||||
default=0.3,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Video with narration (path or data URI)"
|
||||
)
|
||||
audio_file: MediaFileType = SchemaField(
|
||||
description="Generated audio file (path or data URI)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="3d036b53-859c-4b17-9826-ca340f736e0e",
|
||||
description="Generate AI narration and add to video",
|
||||
categories={BlockCategory.MULTIMEDIA, BlockCategory.AI},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={
|
||||
"video_in": "/tmp/test.mp4",
|
||||
"script": "Hello world",
|
||||
"credentials": TEST_CREDENTIALS_INPUT,
|
||||
},
|
||||
test_credentials=TEST_CREDENTIALS,
|
||||
test_output=[("video_out", str), ("audio_file", str)],
|
||||
test_mock={
|
||||
"_generate_narration_audio": lambda *args: b"mock audio content",
|
||||
"_add_narration_to_video": lambda *args: None,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "narrated_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store output video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_block_output",
|
||||
)
|
||||
|
||||
def _generate_narration_audio(
|
||||
self, api_key: str, script: str, voice_id: str, model_id: str
|
||||
) -> bytes:
|
||||
"""Generate narration audio via ElevenLabs API."""
|
||||
client = ElevenLabs(api_key=api_key)
|
||||
audio_generator = client.text_to_speech.convert(
|
||||
voice_id=voice_id,
|
||||
text=script,
|
||||
model_id=model_id,
|
||||
)
|
||||
# The SDK returns a generator, collect all chunks
|
||||
return b"".join(audio_generator)
|
||||
|
||||
def _add_narration_to_video(
|
||||
self,
|
||||
video_abspath: str,
|
||||
audio_abspath: str,
|
||||
output_abspath: str,
|
||||
mix_mode: str,
|
||||
narration_volume: float,
|
||||
original_volume: float,
|
||||
) -> None:
|
||||
"""Add narration audio to video. Extracted for testability."""
|
||||
video = None
|
||||
final = None
|
||||
narration_original = None
|
||||
narration_scaled = None
|
||||
original = None
|
||||
|
||||
try:
|
||||
video = VideoFileClip(video_abspath)
|
||||
narration_original = AudioFileClip(audio_abspath)
|
||||
narration_scaled = narration_original.with_volume_scaled(narration_volume)
|
||||
narration = narration_scaled
|
||||
|
||||
if mix_mode == "replace":
|
||||
final_audio = narration
|
||||
elif mix_mode == "mix":
|
||||
if video.audio:
|
||||
original = video.audio.with_volume_scaled(original_volume)
|
||||
final_audio = CompositeAudioClip([original, narration])
|
||||
else:
|
||||
final_audio = narration
|
||||
else: # ducking - apply stronger attenuation
|
||||
if video.audio:
|
||||
# Ducking uses a much lower volume for original audio
|
||||
ducking_volume = original_volume * 0.3
|
||||
original = video.audio.with_volume_scaled(ducking_volume)
|
||||
final_audio = CompositeAudioClip([original, narration])
|
||||
else:
|
||||
final_audio = narration
|
||||
|
||||
final = video.with_audio(final_audio)
|
||||
video_codec, audio_codec = get_video_codecs(output_abspath)
|
||||
final.write_videofile(
|
||||
output_abspath, codec=video_codec, audio_codec=audio_codec
|
||||
)
|
||||
|
||||
finally:
|
||||
if original:
|
||||
original.close()
|
||||
if narration_scaled:
|
||||
narration_scaled.close()
|
||||
if narration_original:
|
||||
narration_original.close()
|
||||
if final:
|
||||
final.close()
|
||||
if video:
|
||||
video.close()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
input_data: Input,
|
||||
*,
|
||||
credentials: ElevenLabsCredentials,
|
||||
execution_context: ExecutionContext,
|
||||
node_exec_id: str,
|
||||
**kwargs,
|
||||
) -> BlockOutput:
|
||||
try:
|
||||
assert execution_context.graph_exec_id is not None
|
||||
|
||||
# Store the input video locally
|
||||
local_video_path = await self._store_input_video(
|
||||
execution_context, input_data.video_in
|
||||
)
|
||||
video_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, local_video_path
|
||||
)
|
||||
|
||||
# Generate narration audio via ElevenLabs
|
||||
audio_content = self._generate_narration_audio(
|
||||
credentials.api_key.get_secret_value(),
|
||||
input_data.script,
|
||||
input_data.voice_id,
|
||||
input_data.model_id,
|
||||
)
|
||||
|
||||
# Save audio to exec file path
|
||||
audio_filename = MediaFileType(f"{node_exec_id}_narration.mp3")
|
||||
audio_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, audio_filename
|
||||
)
|
||||
os.makedirs(os.path.dirname(audio_abspath), exist_ok=True)
|
||||
with open(audio_abspath, "wb") as f:
|
||||
f.write(audio_content)
|
||||
|
||||
# Add narration to video
|
||||
output_filename = MediaFileType(
|
||||
f"{node_exec_id}_narrated_{os.path.basename(local_video_path)}"
|
||||
)
|
||||
output_abspath = get_exec_file_path(
|
||||
execution_context.graph_exec_id, output_filename
|
||||
)
|
||||
|
||||
self._add_narration_to_video(
|
||||
video_abspath,
|
||||
audio_abspath,
|
||||
output_abspath,
|
||||
input_data.mix_mode,
|
||||
input_data.narration_volume,
|
||||
input_data.original_volume,
|
||||
)
|
||||
|
||||
# Return as workspace path or data URI based on context
|
||||
video_out = await self._store_output_video(
|
||||
execution_context, output_filename
|
||||
)
|
||||
audio_out = await self._store_output_video(
|
||||
execution_context, audio_filename
|
||||
)
|
||||
|
||||
yield "video_out", video_out
|
||||
yield "audio_file", audio_out
|
||||
|
||||
except Exception as e:
|
||||
raise BlockExecutionError(
|
||||
message=f"Failed to add narration: {e}",
|
||||
block_name=self.name,
|
||||
block_id=str(self.id),
|
||||
) from e
|
||||
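The narration block's "ducking" mode attenuates the original track by a further 0.3x on top of original_volume before mixing in the narration. A compact sketch of the resulting gain levels, using the block's defaults (mixed_levels is a hypothetical helper for illustration):

def mixed_levels(mix_mode: str, narration_volume: float = 1.0, original_volume: float = 0.3):
    """Approximate (narration_gain, original_gain) as applied by _add_narration_to_video."""
    if mix_mode == "replace":
        return narration_volume, 0.0
    if mix_mode == "mix":
        return narration_volume, original_volume
    # "ducking": the original audio is scaled by an extra 0.3 factor
    return narration_volume, original_volume * 0.3


print(mixed_levels("ducking"))  # narration at 1.0, original at ~0.09 with the defaults above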
autogpt_platform/backend/backend/blocks/video/text_overlay.py (new file, 227 lines)
@@ -0,0 +1,227 @@
|
||||
"""VideoTextOverlayBlock - Add text overlay to video."""
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from moviepy import CompositeVideoClip, TextClip
|
||||
from moviepy.video.io.VideoFileClip import VideoFileClip
|
||||
|
||||
from backend.blocks.video._utils import get_video_codecs
|
||||
from backend.data.block import (
|
||||
Block,
|
||||
BlockCategory,
|
||||
BlockOutput,
|
||||
BlockSchemaInput,
|
||||
BlockSchemaOutput,
|
||||
)
|
||||
from backend.data.execution import ExecutionContext
|
||||
from backend.data.model import SchemaField
|
||||
from backend.util.exceptions import BlockExecutionError
|
||||
from backend.util.file import MediaFileType, get_exec_file_path, store_media_file
|
||||
|
||||
|
||||
class VideoTextOverlayBlock(Block):
|
||||
"""Add text overlay/caption to video."""
|
||||
|
||||
class Input(BlockSchemaInput):
|
||||
video_in: MediaFileType = SchemaField(
|
||||
description="Input video (URL, data URI, or local path)"
|
||||
)
|
||||
text: str = SchemaField(description="Text to overlay on video")
|
||||
position: Literal[
|
||||
"top",
|
||||
"center",
|
||||
"bottom",
|
||||
"top-left",
|
||||
"top-right",
|
||||
"bottom-left",
|
||||
"bottom-right",
|
||||
] = SchemaField(description="Position of text on screen", default="bottom")
|
||||
start_time: float | None = SchemaField(
|
||||
description="When to show text (seconds). None = entire video",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
end_time: float | None = SchemaField(
|
||||
description="When to hide text (seconds). None = until end",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
font_size: int = SchemaField(
|
||||
description="Font size", default=48, ge=12, le=200, advanced=True
|
||||
)
|
||||
font_color: str = SchemaField(
|
||||
description="Font color (hex or name)", default="white", advanced=True
|
||||
)
|
||||
bg_color: str | None = SchemaField(
|
||||
description="Background color behind text (None for transparent)",
|
||||
default=None,
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
class Output(BlockSchemaOutput):
|
||||
video_out: MediaFileType = SchemaField(
|
||||
description="Video with text overlay (path or data URI)"
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(
|
||||
id="8ef14de6-cc90-430a-8cfa-3a003be92454",
|
||||
description="Add text overlay/caption to video",
|
||||
categories={BlockCategory.MULTIMEDIA},
|
||||
input_schema=self.Input,
|
||||
output_schema=self.Output,
|
||||
test_input={"video_in": "/tmp/test.mp4", "text": "Hello World"},
|
||||
test_output=[("video_out", str)],
|
||||
test_mock={
|
||||
"_add_text_overlay": lambda *args: None,
|
||||
"_store_input_video": lambda *args, **kwargs: "test.mp4",
|
||||
"_store_output_video": lambda *args, **kwargs: "overlay_test.mp4",
|
||||
},
|
||||
)
|
||||
|
||||
async def _store_input_video(
|
||||
self, execution_context: ExecutionContext, file: MediaFileType
|
||||
) -> MediaFileType:
|
||||
"""Store input video. Extracted for testability."""
|
||||
return await store_media_file(
|
||||
file=file,
|
||||
execution_context=execution_context,
|
||||
return_format="for_local_processing",
|
||||
)
|
||||
|
||||
async def _store_output_video(
|
||||
        self, execution_context: ExecutionContext, file: MediaFileType
    ) -> MediaFileType:
        """Store output video. Extracted for testability."""
        return await store_media_file(
            file=file,
            execution_context=execution_context,
            return_format="for_block_output",
        )

    def _add_text_overlay(
        self,
        video_abspath: str,
        output_abspath: str,
        text: str,
        position: str,
        start_time: float | None,
        end_time: float | None,
        font_size: int,
        font_color: str,
        bg_color: str | None,
    ) -> None:
        """Add text overlay to video. Extracted for testability."""
        video = None
        final = None
        txt_clip = None
        try:
            video = VideoFileClip(video_abspath)

            txt_clip = TextClip(
                text=text,
                font_size=font_size,
                color=font_color,
                bg_color=bg_color,
            )

            # Position mapping
            pos_map = {
                "top": ("center", "top"),
                "center": ("center", "center"),
                "bottom": ("center", "bottom"),
                "top-left": ("left", "top"),
                "top-right": ("right", "top"),
                "bottom-left": ("left", "bottom"),
                "bottom-right": ("right", "bottom"),
            }

            txt_clip = txt_clip.with_position(pos_map[position])

            # Set timing
            start = start_time or 0
            end = end_time or video.duration
            duration = max(0, end - start)
            txt_clip = txt_clip.with_start(start).with_end(end).with_duration(duration)

            final = CompositeVideoClip([video, txt_clip])
            video_codec, audio_codec = get_video_codecs(output_abspath)
            final.write_videofile(
                output_abspath, codec=video_codec, audio_codec=audio_codec
            )

        finally:
            if txt_clip:
                txt_clip.close()
            if final:
                final.close()
            if video:
                video.close()

    async def run(
        self,
        input_data: Input,
        *,
        execution_context: ExecutionContext,
        node_exec_id: str,
        **kwargs,
    ) -> BlockOutput:
        # Validate time range if both are provided
        if (
            input_data.start_time is not None
            and input_data.end_time is not None
            and input_data.end_time <= input_data.start_time
        ):
            raise BlockExecutionError(
                message=f"end_time ({input_data.end_time}) must be greater than start_time ({input_data.start_time})",
                block_name=self.name,
                block_id=str(self.id),
            )

        try:
            assert execution_context.graph_exec_id is not None

            # Store the input video locally
            local_video_path = await self._store_input_video(
                execution_context, input_data.video_in
            )
            video_abspath = get_exec_file_path(
                execution_context.graph_exec_id, local_video_path
            )

            # Build output path
            output_filename = MediaFileType(
                f"{node_exec_id}_overlay_{os.path.basename(local_video_path)}"
            )
            output_abspath = get_exec_file_path(
                execution_context.graph_exec_id, output_filename
            )

            self._add_text_overlay(
                video_abspath,
                output_abspath,
                input_data.text,
                input_data.position,
                input_data.start_time,
                input_data.end_time,
                input_data.font_size,
                input_data.font_color,
                input_data.bg_color,
            )

            # Return as workspace path or data URI based on context
            video_out = await self._store_output_video(
                execution_context, output_filename
            )

            yield "video_out", video_out

        except BlockExecutionError:
            raise
        except Exception as e:
            raise BlockExecutionError(
                message=f"Failed to add text overlay: {e}",
                block_name=self.name,
                block_id=str(self.id),
            ) from e
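Because `_add_text_overlay` and `_store_output_video` are factored out of `run`, the MoviePy path can be exercised without the block-execution machinery. A minimal pytest-style sketch of that idea, assuming MoviePy 2.x with a usable default font for `TextClip`; `VideoTextOverlayBlock` is a stand-in name, not necessarily the real class:

import os

from moviepy import ColorClip


def test_add_text_overlay_writes_output(tmp_path):
    src = str(tmp_path / "in.mp4")
    dst = str(tmp_path / "out.mp4")
    # A one-second black clip stands in for a real input video.
    ColorClip(size=(64, 64), color=(0, 0, 0), duration=1.0).write_videofile(src, fps=24)

    block = VideoTextOverlayBlock()  # hypothetical class name
    block._add_text_overlay(
        video_abspath=src,
        output_abspath=dst,
        text="hello",
        position="bottom",
        start_time=None,
        end_time=None,
        font_size=24,
        font_color="white",
        bg_color=None,
    )

    assert os.path.getsize(dst) > 0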
@@ -36,12 +36,14 @@ from backend.blocks.replicate.replicate_block import ReplicateModelBlock
from backend.blocks.smart_decision_maker import SmartDecisionMakerBlock
from backend.blocks.talking_head import CreateTalkingAvatarVideoBlock
from backend.blocks.text_to_speech_block import UnrealTextToSpeechBlock
from backend.blocks.video.narration import VideoNarrationBlock
from backend.data.block import Block, BlockCost, BlockCostType
from backend.integrations.credentials_store import (
    aiml_api_credentials,
    anthropic_credentials,
    apollo_credentials,
    did_credentials,
    elevenlabs_credentials,
    enrichlayer_credentials,
    groq_credentials,
    ideogram_credentials,
@@ -81,6 +83,7 @@ MODEL_COST: dict[LlmModel, int] = {
    LlmModel.CLAUDE_4_5_HAIKU: 4,
    LlmModel.CLAUDE_4_5_OPUS: 14,
    LlmModel.CLAUDE_4_5_SONNET: 9,
    LlmModel.CLAUDE_3_7_SONNET: 5,
    LlmModel.CLAUDE_3_HAIKU: 1,
    LlmModel.AIML_API_QWEN2_5_72B: 1,
    LlmModel.AIML_API_LLAMA3_1_70B: 1,
@@ -639,4 +642,16 @@ BLOCK_COSTS: dict[Type[Block], list[BlockCost]] = {
            },
        ),
    ],
    VideoNarrationBlock: [
        BlockCost(
            cost_amount=5,  # ElevenLabs TTS cost
            cost_filter={
                "credentials": {
                    "id": elevenlabs_credentials.id,
                    "provider": elevenlabs_credentials.provider,
                    "type": elevenlabs_credentials.type,
                }
            },
        )
    ],
}
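For context on what the new `VideoNarrationBlock` entry does: each `BlockCost` pairs a credit amount with a `cost_filter` that is compared against the block's input data before charging. A rough sketch of that idea only, assuming (this is an assumption, not the platform's actual implementation) the filter behaves as a recursive subset match:

def filter_matches(cost_filter: dict, input_data: dict) -> bool:
    # Assumed semantics: every filter key must be present in the input with an
    # equal value, recursing into nested dicts such as "credentials".
    for key, expected in cost_filter.items():
        actual = input_data.get(key)
        if isinstance(expected, dict) and isinstance(actual, dict):
            if not filter_matches(expected, actual):
                return False
        elif actual != expected:
            return False
    return True

Under that reading, a narration run is charged 5 credits only when it executes with the platform-provided ElevenLabs credentials.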
@@ -133,23 +133,10 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):


class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
    def __init__(self):
        self._pubsub: AsyncPubSub | None = None

    @property
    async def connection(self) -> redis.AsyncRedis:
        return await redis.get_redis_async()

    async def close(self) -> None:
        """Close the PubSub connection if it exists."""
        if self._pubsub is not None:
            try:
                await self._pubsub.close()
            except Exception:
                logger.warning("Failed to close PubSub connection", exc_info=True)
            finally:
                self._pubsub = None

    async def publish_event(self, event: M, channel_key: str):
        """
        Publish an event to Redis. Gracefully handles connection failures
@@ -170,7 +157,6 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
            await self.connection, channel_key
        )
        assert isinstance(pubsub, AsyncPubSub)
        self._pubsub = pubsub

        if "*" in channel_key:
            await pubsub.psubscribe(full_channel_name)
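The new `close()` hook exists so long-lived consumers can release the PubSub connection they acquire when subscribing. A rough usage sketch, assuming a concrete `AsyncRedisEventBus` subclass instance (the function and variable names below are illustrative, not from this diff):

async def consume_and_cleanup(bus) -> None:
    # `bus` is some concrete AsyncRedisEventBus subclass instance.
    try:
        # ... subscribe and iterate events for one execution here ...
        pass
    finally:
        # Always release the PubSub connection; close() logs and swallows
        # failures, so it is safe to call even if nothing was subscribed.
        await bus.close()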
@@ -1028,39 +1028,6 @@ async def get_graph(
    return GraphModel.from_db(graph, for_export)


async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
    """Batch-fetch multiple store-listed graphs by their IDs.

    Only returns graphs that have approved store listings (publicly available).
    Does not require permission checks since store-listed graphs are public.

    Args:
        *graph_ids: Variable number of graph IDs to fetch

    Returns:
        Dict mapping graph_id to GraphModel for graphs with approved store listings
    """
    if not graph_ids:
        return {}

    store_listings = await StoreListingVersion.prisma().find_many(
        where={
            "agentGraphId": {"in": list(graph_ids)},
            "submissionStatus": SubmissionStatus.APPROVED,
            "isDeleted": False,
        },
        include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
        distinct=["agentGraphId"],
        order={"agentGraphVersion": "desc"},
    )

    return {
        listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
        for listing in store_listings
        if listing.AgentGraph
    }
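A short usage sketch of the new batch helper (the graph IDs are placeholders): it returns only graphs with an approved store listing, keyed by ID, so callers can detect missing or unlisted graphs with a simple membership check:

graphs = await get_store_listed_graphs("graph-id-1", "graph-id-2")

for graph_id in ("graph-id-1", "graph-id-2"):
    if graph := graphs.get(graph_id):
        print(graph_id, "-> listed, version", graph.version)
    else:
        print(graph_id, "has no approved store listing")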


async def get_graph_as_admin(
    graph_id: str,
    version: int | None = None,
@@ -666,16 +666,10 @@ class CredentialsFieldInfo(BaseModel, Generic[CP, CT]):
        if not (self.discriminator and self.discriminator_mapping):
            return self

        try:
            provider = self.discriminator_mapping[discriminator_value]
        except KeyError:
            raise ValueError(
                f"Model '{discriminator_value}' is not supported. "
                "It may have been deprecated. Please update your agent configuration."
            )

        return CredentialsFieldInfo(
            credentials_provider=frozenset([provider]),
            credentials_provider=frozenset(
                [self.discriminator_mapping[discriminator_value]]
            ),
            credentials_types=self.supported_types,
            credentials_scopes=self.required_scopes,
            discriminator=self.discriminator,

@@ -17,7 +17,6 @@ from backend.data.analytics import (
    get_accuracy_trends_and_alerts,
    get_marketplace_graphs_for_monitoring,
)
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import (
    create_graph_execution,
@@ -220,9 +219,6 @@ class DatabaseManager(AppService):
    # Onboarding
    increment_onboarding_runs = _(increment_onboarding_runs)

    # OAuth
    cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)

    # Store
    get_store_agents = _(get_store_agents)
    get_store_agent_details = _(get_store_agent_details)
@@ -353,9 +349,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
    # Onboarding
    increment_onboarding_runs = d.increment_onboarding_runs

    # OAuth
    cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens

    # Store
    get_store_agents = d.get_store_agents
    get_store_agent_details = d.get_store_agent_details

@@ -24,9 +24,11 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine

from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_onboarding_runs
from backend.executor import utils as execution_utils
from backend.monitoring import (
    NotificationJobArgs,
@@ -36,11 +38,7 @@ from backend.monitoring import (
    report_execution_accuracy_alerts,
    report_late_executions,
)
from backend.util.clients import (
    get_database_manager_async_client,
    get_database_manager_client,
    get_scheduler_client,
)
from backend.util.clients import get_database_manager_client, get_scheduler_client
from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import (
    GraphNotFoundError,
@@ -150,7 +148,6 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs):
    args = GraphExecutionJobArgs(**kwargs)
    start_time = asyncio.get_event_loop().time()
    db = get_database_manager_async_client()
    try:
        logger.info(f"Executing recurring job for graph #{args.graph_id}")
        graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -160,7 +157,7 @@ async def _execute_graph(**kwargs):
            inputs=args.input_data,
            graph_credentials_inputs=args.input_credentials,
        )
        await db.increment_onboarding_runs(args.user_id)
        await increment_onboarding_runs(args.user_id)
        elapsed = asyncio.get_event_loop().time() - start_time
        logger.info(
            f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -249,13 +246,8 @@ def cleanup_expired_files():

def cleanup_oauth_tokens():
    """Clean up expired OAuth tokens from the database."""

    # Wait for completion
    async def _cleanup():
        db = get_database_manager_async_client()
        return await db.cleanup_expired_oauth_tokens()

    run_async(_cleanup())
    run_async(cleanup_expired_oauth_tokens())


def execution_accuracy_alerts():

@@ -224,6 +224,14 @@ openweathermap_credentials = APIKeyCredentials(
    expires_at=None,
)

elevenlabs_credentials = APIKeyCredentials(
    id="f4a8b6c2-3d1e-4f5a-9b8c-7d6e5f4a3b2c",
    provider="elevenlabs",
    api_key=SecretStr(settings.secrets.elevenlabs_api_key),
    title="Use Credits for ElevenLabs",
    expires_at=None,
)

DEFAULT_CREDENTIALS = [
    ollama_credentials,
    revid_credentials,
@@ -252,6 +260,7 @@ DEFAULT_CREDENTIALS = [
    v0_credentials,
    webshare_proxy_credentials,
    openweathermap_credentials,
    elevenlabs_credentials,
]

SYSTEM_CREDENTIAL_IDS = {cred.id for cred in DEFAULT_CREDENTIALS}
@@ -366,6 +375,8 @@ class IntegrationCredentialsStore:
        all_credentials.append(webshare_proxy_credentials)
        if settings.secrets.openweathermap_api_key:
            all_credentials.append(openweathermap_credentials)
        if settings.secrets.elevenlabs_api_key:
            all_credentials.append(elevenlabs_credentials)
        return all_credentials

    async def get_creds_by_id(

@@ -18,6 +18,7 @@ class ProviderName(str, Enum):
    DISCORD = "discord"
    D_ID = "d_id"
    E2B = "e2b"
    ELEVENLABS = "elevenlabs"
    FAL = "fal"
    GITHUB = "github"
    GOOGLE = "google"

@@ -1,39 +0,0 @@
from urllib.parse import urlparse

import fastapi
from fastapi.routing import APIRoute

from backend.api.features.integrations.router import router as integrations_router
from backend.integrations.providers import ProviderName
from backend.integrations.webhooks import utils as webhooks_utils


def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
    app = fastapi.FastAPI()
    app.include_router(integrations_router, prefix="/api/integrations")

    provider = ProviderName.GITHUB
    webhook_id = "webhook_123"
    base_url = "https://example.com"

    monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)

    route = next(
        route
        for route in integrations_router.routes
        if isinstance(route, APIRoute)
        and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
        and "POST" in route.methods
    )
    expected_path = f"/api/integrations{route.path}".format(
        provider=provider.value,
        webhook_id=webhook_id,
    )
    actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
    expected_base = urlparse(base_url)

    assert (actual_url.scheme, actual_url.netloc) == (
        expected_base.scheme,
        expected_base.netloc,
    )
    assert actual_url.path == expected_path
@@ -1,19 +1,10 @@
from __future__ import annotations

import logging
from copy import deepcopy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import Any

from tiktoken import encoding_for_model

from backend.util import json

if TYPE_CHECKING:
    from openai import AsyncOpenAI

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------#
#                                  CONSTANTS                                  #
# ---------------------------------------------------------------------------#
@@ -109,17 +100,9 @@ def _is_objective_message(msg: dict) -> bool:
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
    """
    Carefully truncate tool message content while preserving tool structure.
    Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
    Only truncates tool_result content, leaves tool_use intact.
    """
    content = msg.get("content")

    # OpenAI-style tool message: role="tool" with string content
    if msg.get("role") == "tool" and isinstance(content, str):
        if _tok_len(content, enc) > max_tokens:
            msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
        return

    # Anthropic-style: list content with tool_result items
    if not isinstance(content, list):
        return
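For reference, the two message shapes this helper distinguishes look roughly like this (IDs and payloads are illustrative):

# OpenAI-style tool response: string content on a role="tool" message.
openai_tool_msg = {
    "role": "tool",
    "tool_call_id": "call_123",
    "content": "...a potentially very large tool output...",
}

# Anthropic-style: a list of content blocks; only tool_result blocks get
# truncated, tool_use blocks are left untouched.
anthropic_tool_msg = {
    "role": "user",
    "content": [
        {
            "type": "tool_result",
            "tool_use_id": "toolu_123",
            "content": "...a potentially very large tool output...",
        }
    ],
}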

@@ -157,6 +140,141 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
# ---------------------------------------------------------------------------#


def compress_prompt(
    messages: list[dict],
    target_tokens: int,
    *,
    model: str = "gpt-4o",
    reserve: int = 2_048,
    start_cap: int = 8_192,
    floor_cap: int = 128,
    lossy_ok: bool = True,
) -> list[dict]:
    """
    Shrink *messages* so that::

        token_count(prompt) + reserve ≤ target_tokens

    Strategy
    --------
    1. **Token-aware truncation** – progressively halve a per-message cap
       (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
       *content* of every message except the first and last. Tool shells
       are included: we keep the envelope but shorten huge payloads.
    2. **Middle-out deletion** – if still over the limit, delete whole
       messages working outward from the centre, **skipping** any message
       that contains ``tool_calls`` or has ``role == "tool"``.
    3. **Last-chance trim** – if still too big, truncate the *first* and
       *last* message bodies down to `floor_cap` tokens.
    4. If the prompt is *still* too large:
       • raise ``ValueError`` when ``lossy_ok == False``
       • return the partially-trimmed prompt when ``lossy_ok == True`` (default)

    Parameters
    ----------
    messages        Complete chat history (will be deep-copied).
    model           Model name; passed to tiktoken to pick the right
                    tokenizer (gpt-4o → 'o200k_base', others fallback).
    target_tokens   Hard ceiling for prompt size **excluding** the model's
                    forthcoming answer.
    reserve         How many tokens you want to leave available for that
                    answer (`max_tokens` in your subsequent completion call).
    start_cap       Initial per-message truncation ceiling (tokens).
    floor_cap       Lowest cap we'll accept before moving to deletions.
    lossy_ok        If *True* return best-effort prompt instead of raising
                    after all trim passes have been exhausted.

    Returns
    -------
    list[dict] – A *new* messages list that abides by the rules above.
    """
    enc = encoding_for_model(model)  # best-match tokenizer
    msgs = deepcopy(messages)  # never mutate caller

    def total_tokens() -> int:
        """Current size of *msgs* in tokens."""
        return sum(_msg_tokens(m, enc) for m in msgs)

    original_token_count = total_tokens()

    if original_token_count + reserve <= target_tokens:
        return msgs

    # ---- STEP 0 : normalise content --------------------------------------
    # Convert non-string payloads to strings so token counting is coherent.
    for i, m in enumerate(msgs):
        if not isinstance(m.get("content"), str) and m.get("content") is not None:
            if _is_tool_message(m):
                continue

            # Keep first and last messages intact (unless they're tool messages)
            if i == 0 or i == len(msgs) - 1:
                continue

            # Reasonable 20k-char ceiling prevents pathological blobs
            content_str = json.dumps(m["content"], separators=(",", ":"))
            if len(content_str) > 20_000:
                content_str = _truncate_middle_tokens(content_str, enc, 20_000)
            m["content"] = content_str

    # ---- STEP 1 : token-aware truncation ---------------------------------
    cap = start_cap
    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
        for m in msgs[1:-1]:  # keep first & last intact
            if _is_tool_message(m):
                # For tool messages, only truncate tool result content, preserve structure
                _truncate_tool_message_content(m, enc, cap)
                continue

            if _is_objective_message(m):
                # Never truncate objective messages - they contain the core task
                continue

            content = m.get("content") or ""
            if _tok_len(content, enc) > cap:
                m["content"] = _truncate_middle_tokens(content, enc, cap)
        cap //= 2  # tighten the screw

    # ---- STEP 2 : middle-out deletion -----------------------------------
    while total_tokens() + reserve > target_tokens and len(msgs) > 2:
        # Identify all deletable messages (not first/last, not tool messages, not objective messages)
        deletable_indices = []
        for i in range(1, len(msgs) - 1):  # Skip first and last
            if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
                deletable_indices.append(i)

        if not deletable_indices:
            break  # nothing more we can drop

        # Delete from center outward - find the index closest to center
        centre = len(msgs) // 2
        to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
        del msgs[to_delete]

    # ---- STEP 3 : final safety-net trim on first & last ------------------
    cap = start_cap
    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
        for idx in (0, -1):  # first and last
            if _is_tool_message(msgs[idx]):
                # For tool messages at first/last position, truncate tool result content only
                _truncate_tool_message_content(msgs[idx], enc, cap)
                continue

            text = msgs[idx].get("content") or ""
            if _tok_len(text, enc) > cap:
                msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
        cap //= 2  # tighten the screw

    # ---- STEP 4 : success or fail-gracefully -----------------------------
    if total_tokens() + reserve > target_tokens and not lossy_ok:
        raise ValueError(
            "compress_prompt: prompt still exceeds budget "
            f"({total_tokens() + reserve} > {target_tokens})."
        )

    return msgs
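A minimal usage sketch: trim an oversized history so it fits an 8k-token budget while leaving roughly 1k tokens for the model's reply (the message contents are illustrative):

history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarise this log:\n" + "line\n" * 5000},
    {"role": "assistant", "content": "Working on it."},
]

trimmed = compress_prompt(history, target_tokens=8_192, reserve=1_024)
# The original list is untouched; `trimmed` is a deep-copied, shrunken prompt.
assert len(trimmed) <= len(history)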
|
||||
|
||||
|
||||
def estimate_token_count(
|
||||
messages: list[dict],
|
||||
*,
|
||||
@@ -175,8 +293,7 @@ def estimate_token_count(
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
return sum(_msg_tokens(m, enc) for m in messages)
|
||||
|
||||
|
||||
@@ -198,543 +315,6 @@ def estimate_token_count_str(
|
||||
-------
|
||||
int – Token count.
|
||||
"""
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
enc = encoding_for_model(model) # best-match tokenizer
|
||||
text = json.dumps(text) if not isinstance(text, str) else text
|
||||
return _tok_len(text, enc)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------#
|
||||
# UNIFIED CONTEXT COMPRESSION #
|
||||
# ---------------------------------------------------------------------------#
|
||||
|
||||
# Default thresholds
|
||||
DEFAULT_TOKEN_THRESHOLD = 120_000
|
||||
DEFAULT_KEEP_RECENT = 15
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompressResult:
|
||||
"""Result of context compression."""
|
||||
|
||||
messages: list[dict]
|
||||
token_count: int
|
||||
was_compacted: bool
|
||||
error: str | None = None
|
||||
original_token_count: int = 0
|
||||
messages_summarized: int = 0
|
||||
messages_dropped: int = 0
|
||||
|
||||
|
||||
def _normalize_model_for_tokenizer(model: str) -> str:
|
||||
"""Normalize model name for tiktoken tokenizer selection."""
|
||||
if "/" in model:
|
||||
model = model.split("/")[-1]
|
||||
if "claude" in model.lower() or not any(
|
||||
known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
|
||||
):
|
||||
return "gpt-4o"
|
||||
return model
|
||||
|
||||
|
||||
def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
|
||||
"""
|
||||
Extract tool_call IDs from an assistant message.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
|
||||
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
|
||||
|
||||
Returns:
|
||||
Set of tool_call IDs found in the message.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
if msg.get("role") != "assistant":
|
||||
return ids
|
||||
|
||||
# OpenAI format: tool_calls array
|
||||
if msg.get("tool_calls"):
|
||||
for tc in msg["tool_calls"]:
|
||||
tc_id = tc.get("id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
# Anthropic format: content list with tool_use blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_use":
|
||||
tc_id = block.get("id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
|
||||
"""
|
||||
Extract tool_call IDs that this message is responding to.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: {"role": "tool", "tool_call_id": "..."}
|
||||
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
|
||||
|
||||
Returns:
|
||||
Set of tool_call IDs this message responds to.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
|
||||
# OpenAI format: role=tool with tool_call_id
|
||||
if msg.get("role") == "tool":
|
||||
tc_id = msg.get("tool_call_id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
# Anthropic format: content list with tool_result blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
tc_id = block.get("tool_use_id")
|
||||
if tc_id:
|
||||
ids.add(tc_id)
|
||||
|
||||
return ids
|
||||
|
||||
|
||||
def _is_tool_response_message(msg: dict) -> bool:
|
||||
"""Check if message is a tool response (OpenAI or Anthropic format)."""
|
||||
# OpenAI format
|
||||
if msg.get("role") == "tool":
|
||||
return True
|
||||
# Anthropic format
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _remove_orphan_tool_responses(
|
||||
messages: list[dict], orphan_ids: set[str]
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Remove tool response messages/blocks that reference orphan tool_call IDs.
|
||||
|
||||
Supports both OpenAI and Anthropic formats.
|
||||
For Anthropic messages with mixed valid/orphan tool_result blocks,
|
||||
filters out only the orphan blocks instead of dropping the entire message.
|
||||
"""
|
||||
result = []
|
||||
for msg in messages:
|
||||
# OpenAI format: role=tool - drop entire message if orphan
|
||||
if msg.get("role") == "tool":
|
||||
tc_id = msg.get("tool_call_id")
|
||||
if tc_id and tc_id in orphan_ids:
|
||||
continue
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
# Anthropic format: content list may have mixed tool_result blocks
|
||||
content = msg.get("content")
|
||||
if isinstance(content, list):
|
||||
has_tool_results = any(
|
||||
isinstance(b, dict) and b.get("type") == "tool_result" for b in content
|
||||
)
|
||||
if has_tool_results:
|
||||
# Filter out orphan tool_result blocks, keep valid ones
|
||||
filtered_content = [
|
||||
block
|
||||
for block in content
|
||||
if not (
|
||||
isinstance(block, dict)
|
||||
and block.get("type") == "tool_result"
|
||||
and block.get("tool_use_id") in orphan_ids
|
||||
)
|
||||
]
|
||||
# Only keep message if it has remaining content
|
||||
if filtered_content:
|
||||
msg = msg.copy()
|
||||
msg["content"] = filtered_content
|
||||
result.append(msg)
|
||||
continue
|
||||
|
||||
result.append(msg)
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_tool_pairs_intact(
|
||||
recent_messages: list[dict],
|
||||
all_messages: list[dict],
|
||||
start_index: int,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Ensure tool_call/tool_response pairs stay together after slicing.
|
||||
|
||||
When slicing messages for context compaction, a naive slice can separate
|
||||
an assistant message containing tool_calls from its corresponding tool
|
||||
response messages. This causes API validation errors (e.g., Anthropic's
|
||||
"unexpected tool_use_id found in tool_result blocks").
|
||||
|
||||
This function checks for orphan tool responses in the slice and extends
|
||||
backwards to include their corresponding assistant messages.
|
||||
|
||||
Supports both formats:
|
||||
- OpenAI: tool_calls array + role="tool" responses
|
||||
- Anthropic: tool_use blocks + tool_result blocks
|
||||
|
||||
Args:
|
||||
recent_messages: The sliced messages to validate
|
||||
all_messages: The complete message list (for looking up missing assistants)
|
||||
start_index: The index in all_messages where recent_messages begins
|
||||
|
||||
Returns:
|
||||
A potentially extended list of messages with tool pairs intact
|
||||
"""
|
||||
if not recent_messages:
|
||||
return recent_messages
|
||||
|
||||
# Collect all tool_call_ids from assistant messages in the slice
|
||||
available_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
|
||||
|
||||
# Find orphan tool responses (responses whose tool_call_id is missing)
|
||||
orphan_tool_call_ids: set[str] = set()
|
||||
for msg in recent_messages:
|
||||
response_ids = _extract_tool_response_ids_from_message(msg)
|
||||
for tc_id in response_ids:
|
||||
if tc_id not in available_tool_call_ids:
|
||||
orphan_tool_call_ids.add(tc_id)
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# No orphans, slice is valid
|
||||
return recent_messages
|
||||
|
||||
# Find the assistant messages that contain the orphan tool_call_ids
|
||||
# Search backwards from start_index in all_messages
|
||||
messages_to_prepend: list[dict] = []
|
||||
for i in range(start_index - 1, -1, -1):
|
||||
msg = all_messages[i]
|
||||
msg_tool_ids = _extract_tool_call_ids_from_message(msg)
|
||||
if msg_tool_ids & orphan_tool_call_ids:
|
||||
# This assistant message has tool_calls we need
|
||||
# Also collect its contiguous tool responses that follow it
|
||||
assistant_and_responses: list[dict] = [msg]
|
||||
|
||||
# Scan forward from this assistant to collect tool responses
|
||||
for j in range(i + 1, start_index):
|
||||
following_msg = all_messages[j]
|
||||
following_response_ids = _extract_tool_response_ids_from_message(
|
||||
following_msg
|
||||
)
|
||||
if following_response_ids and following_response_ids & msg_tool_ids:
|
||||
assistant_and_responses.append(following_msg)
|
||||
elif not _is_tool_response_message(following_msg):
|
||||
# Stop at first non-tool-response message
|
||||
break
|
||||
|
||||
# Prepend the assistant and its tool responses (maintain order)
|
||||
messages_to_prepend = assistant_and_responses + messages_to_prepend
|
||||
# Mark these as found
|
||||
orphan_tool_call_ids -= msg_tool_ids
|
||||
# Also add this assistant's tool_call_ids to available set
|
||||
available_tool_call_ids |= msg_tool_ids
|
||||
|
||||
if not orphan_tool_call_ids:
|
||||
# Found all missing assistants
|
||||
break
|
||||
|
||||
if orphan_tool_call_ids:
|
||||
# Some tool_call_ids couldn't be resolved - remove those tool responses
|
||||
# This shouldn't happen in normal operation but handles edge cases
|
||||
logger.warning(
|
||||
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
|
||||
"Removing orphan tool responses."
|
||||
)
|
||||
recent_messages = _remove_orphan_tool_responses(
|
||||
recent_messages, orphan_tool_call_ids
|
||||
)
|
||||
|
||||
if messages_to_prepend:
|
||||
logger.info(
|
||||
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
|
||||
f"tool_call/tool_response pairs"
|
||||
)
|
||||
return messages_to_prepend + recent_messages
|
||||
|
||||
return recent_messages
|
||||
|
||||
|
||||
async def _summarize_messages_llm(
|
||||
messages: list[dict],
|
||||
client: AsyncOpenAI,
|
||||
model: str,
|
||||
timeout: float = 30.0,
|
||||
) -> str:
|
||||
"""Summarize messages using an LLM."""
|
||||
conversation = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "")
|
||||
content = msg.get("content", "")
|
||||
if content and role in ("user", "assistant", "tool"):
|
||||
conversation.append(f"{role.upper()}: {content}")
|
||||
|
||||
conversation_text = "\n\n".join(conversation)
|
||||
|
||||
if not conversation_text:
|
||||
return "No conversation history available."
|
||||
|
||||
# Limit to ~100k chars for safety
|
||||
MAX_CHARS = 100_000
|
||||
if len(conversation_text) > MAX_CHARS:
|
||||
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
|
||||
|
||||
response = await client.with_options(timeout=timeout).chat.completions.create(
|
||||
model=model,
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": (
|
||||
"Create a detailed summary of the conversation so far. "
|
||||
"This summary will be used as context when continuing the conversation.\n\n"
|
||||
"Before writing the summary, analyze each message chronologically to identify:\n"
|
||||
"- User requests and their explicit goals\n"
|
||||
"- Your approach and key decisions made\n"
|
||||
"- Technical specifics (file names, tool outputs, function signatures)\n"
|
||||
"- Errors encountered and resolutions applied\n\n"
|
||||
"You MUST include ALL of the following sections:\n\n"
|
||||
"## 1. Primary Request and Intent\n"
|
||||
"The user's explicit goals and what they are trying to accomplish.\n\n"
|
||||
"## 2. Key Technical Concepts\n"
|
||||
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
|
||||
"## 3. Files and Resources Involved\n"
|
||||
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
|
||||
"## 4. Errors and Fixes\n"
|
||||
"Problems encountered, error messages, and their resolutions. "
|
||||
"Include any user feedback on fixes.\n\n"
|
||||
"## 5. Problem Solving\n"
|
||||
"Issues that have been resolved and how they were addressed.\n\n"
|
||||
"## 6. All User Messages\n"
|
||||
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
|
||||
"## 7. Pending Tasks\n"
|
||||
"Work items the user explicitly requested that have not yet been completed.\n\n"
|
||||
"## 8. Current Work\n"
|
||||
"Precise description of what was being worked on most recently, including relevant context.\n\n"
|
||||
"## 9. Next Steps\n"
|
||||
"What should happen next, aligned with the user's most recent requests. "
|
||||
"Include verbatim quotes of recent instructions if relevant."
|
||||
),
|
||||
},
|
||||
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
|
||||
],
|
||||
max_tokens=1500,
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
return response.choices[0].message.content or "No summary available."
|
||||
|
||||
|
||||
async def compress_context(
|
||||
messages: list[dict],
|
||||
target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
|
||||
*,
|
||||
model: str = "gpt-4o",
|
||||
client: AsyncOpenAI | None = None,
|
||||
keep_recent: int = DEFAULT_KEEP_RECENT,
|
||||
reserve: int = 2_048,
|
||||
start_cap: int = 8_192,
|
||||
floor_cap: int = 128,
|
||||
) -> CompressResult:
|
||||
"""
|
||||
Unified context compression that combines summarization and truncation strategies.
|
||||
|
||||
Strategy (in order):
|
||||
1. **LLM summarization** – If client provided, summarize old messages into a
|
||||
single context message while keeping recent messages intact. This is the
|
||||
primary strategy for chat service.
|
||||
2. **Content truncation** – Progressively halve a per-message cap and truncate
|
||||
bloated message content (tool outputs, large pastes). Preserves all messages
|
||||
but shortens their content. Primary strategy when client=None (LLM blocks).
|
||||
3. **Middle-out deletion** – Delete whole messages one at a time from the center
|
||||
outward, skipping tool messages and objective messages.
|
||||
4. **First/last trim** – Truncate first and last message content as last resort.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
messages Complete chat history (will be deep-copied).
|
||||
target_tokens Hard ceiling for prompt size.
|
||||
model Model name for tokenization and summarization.
|
||||
client AsyncOpenAI client. If provided, enables LLM summarization
|
||||
as the first strategy. If None, skips to truncation strategies.
|
||||
keep_recent Number of recent messages to preserve during summarization.
|
||||
reserve Tokens to reserve for model response.
|
||||
start_cap Initial per-message truncation ceiling (tokens).
|
||||
floor_cap Lowest cap before moving to deletions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
CompressResult with compressed messages and metadata.
|
||||
"""
|
||||
# Guard clause for empty messages
|
||||
if not messages:
|
||||
return CompressResult(
|
||||
messages=[],
|
||||
token_count=0,
|
||||
was_compacted=False,
|
||||
original_token_count=0,
|
||||
)
|
||||
|
||||
token_model = _normalize_model_for_tokenizer(model)
|
||||
enc = encoding_for_model(token_model)
|
||||
msgs = deepcopy(messages)
|
||||
|
||||
def total_tokens() -> int:
|
||||
return sum(_msg_tokens(m, enc) for m in msgs)
|
||||
|
||||
original_count = total_tokens()
|
||||
|
||||
# Already under limit
|
||||
if original_count + reserve <= target_tokens:
|
||||
return CompressResult(
|
||||
messages=msgs,
|
||||
token_count=original_count,
|
||||
was_compacted=False,
|
||||
original_token_count=original_count,
|
||||
)
|
||||
|
||||
messages_summarized = 0
|
||||
messages_dropped = 0
|
||||
|
||||
# ---- STEP 1: LLM summarization (if client provided) -------------------
|
||||
# This is the primary compression strategy for chat service.
|
||||
# Summarize old messages while keeping recent ones intact.
|
||||
if client is not None:
|
||||
has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
|
||||
system_msg = msgs[0] if has_system else None
|
||||
|
||||
# Calculate old vs recent messages
|
||||
if has_system:
|
||||
if len(msgs) > keep_recent + 1:
|
||||
old_msgs = msgs[1:-keep_recent]
|
||||
recent_msgs = msgs[-keep_recent:]
|
||||
else:
|
||||
old_msgs = []
|
||||
recent_msgs = msgs[1:] if len(msgs) > 1 else []
|
||||
else:
|
||||
if len(msgs) > keep_recent:
|
||||
old_msgs = msgs[:-keep_recent]
|
||||
recent_msgs = msgs[-keep_recent:]
|
||||
else:
|
||||
old_msgs = []
|
||||
recent_msgs = msgs
|
||||
|
||||
# Ensure tool pairs stay intact
|
||||
slice_start = max(0, len(msgs) - keep_recent)
|
||||
recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
|
||||
|
||||
if old_msgs:
|
||||
try:
|
||||
summary_text = await _summarize_messages_llm(old_msgs, client, model)
|
||||
summary_msg = {
|
||||
"role": "assistant",
|
||||
"content": f"[Previous conversation summary — for context only]: {summary_text}",
|
||||
}
|
||||
messages_summarized = len(old_msgs)
|
||||
|
||||
if has_system:
|
||||
msgs = [system_msg, summary_msg] + recent_msgs
|
||||
else:
|
||||
msgs = [summary_msg] + recent_msgs
|
||||
|
||||
logger.info(
|
||||
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
|
||||
f"summarized {messages_summarized} messages"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Summarization failed, continuing with truncation: {e}")
|
||||
# Fall through to content truncation
|
||||
|
||||
# ---- STEP 2: Normalize content ----------------------------------------
|
||||
# Convert non-string payloads to strings so token counting is coherent.
|
||||
# Always run this before truncation to ensure consistent token counting.
|
||||
for i, m in enumerate(msgs):
|
||||
if not isinstance(m.get("content"), str) and m.get("content") is not None:
|
||||
if _is_tool_message(m):
|
||||
continue
|
||||
if i == 0 or i == len(msgs) - 1:
|
||||
continue
|
||||
content_str = json.dumps(m["content"], separators=(",", ":"))
|
||||
if len(content_str) > 20_000:
|
||||
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
|
||||
m["content"] = content_str
|
||||
|
||||
# ---- STEP 3: Token-aware content truncation ---------------------------
|
||||
# Progressively halve per-message cap and truncate bloated content.
|
||||
# This preserves all messages but shortens their content.
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for m in msgs[1:-1]:
|
||||
if _is_tool_message(m):
|
||||
_truncate_tool_message_content(m, enc, cap)
|
||||
continue
|
||||
if _is_objective_message(m):
|
||||
continue
|
||||
content = m.get("content") or ""
|
||||
if _tok_len(content, enc) > cap:
|
||||
m["content"] = _truncate_middle_tokens(content, enc, cap)
|
||||
cap //= 2
|
||||
|
||||
# ---- STEP 4: Middle-out deletion --------------------------------------
|
||||
# Delete messages one at a time from the center outward.
|
||||
# This is more granular than dropping all old messages at once.
|
||||
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
|
||||
deletable: list[int] = []
|
||||
for i in range(1, len(msgs) - 1):
|
||||
msg = msgs[i]
|
||||
if (
|
||||
msg is not None
|
||||
and not _is_tool_message(msg)
|
||||
and not _is_objective_message(msg)
|
||||
):
|
||||
deletable.append(i)
|
||||
if not deletable:
|
||||
break
|
||||
centre = len(msgs) // 2
|
||||
to_delete = min(deletable, key=lambda i: abs(i - centre))
|
||||
del msgs[to_delete]
|
||||
messages_dropped += 1
|
||||
|
||||
# ---- STEP 5: Final trim on first/last ---------------------------------
|
||||
cap = start_cap
|
||||
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
|
||||
for idx in (0, -1):
|
||||
msg = msgs[idx]
|
||||
if msg is None:
|
||||
continue
|
||||
if _is_tool_message(msg):
|
||||
_truncate_tool_message_content(msg, enc, cap)
|
||||
continue
|
||||
text = msg.get("content") or ""
|
||||
if _tok_len(text, enc) > cap:
|
||||
msg["content"] = _truncate_middle_tokens(text, enc, cap)
|
||||
cap //= 2
|
||||
|
||||
# Filter out any None values that may have been introduced
|
||||
final_msgs: list[dict] = [m for m in msgs if m is not None]
|
||||
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
|
||||
error = None
|
||||
if final_count + reserve > target_tokens:
|
||||
error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
|
||||
logger.warning(error)
|
||||
|
||||
return CompressResult(
|
||||
messages=final_msgs,
|
||||
token_count=final_count,
|
||||
was_compacted=True,
|
||||
error=error,
|
||||
original_token_count=original_count,
|
||||
messages_summarized=messages_summarized,
|
||||
messages_dropped=messages_dropped,
|
||||
)
|
||||
|
||||
@@ -1,21 +1,10 @@
|
||||
"""Tests for prompt utility functions, especially tool call token counting."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from tiktoken import encoding_for_model
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import (
|
||||
CompressResult,
|
||||
_ensure_tool_pairs_intact,
|
||||
_msg_tokens,
|
||||
_normalize_model_for_tokenizer,
|
||||
_truncate_middle_tokens,
|
||||
_truncate_tool_message_content,
|
||||
compress_context,
|
||||
estimate_token_count,
|
||||
)
|
||||
from backend.util.prompt import _msg_tokens, estimate_token_count
|
||||
|
||||
|
||||
class TestMsgTokens:
|
||||
@@ -287,690 +276,3 @@ class TestEstimateTokenCount:
|
||||
|
||||
assert total_tokens == expected_total
|
||||
assert total_tokens > 20 # Should be substantial
|
||||
|
||||
|
||||
class TestNormalizeModelForTokenizer:
|
||||
"""Test model name normalization for tiktoken."""
|
||||
|
||||
def test_openai_models_unchanged(self):
|
||||
"""Test that OpenAI models are returned as-is."""
|
||||
assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
|
||||
assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
|
||||
|
||||
def test_claude_models_normalized(self):
|
||||
"""Test that Claude models are normalized to gpt-4o."""
|
||||
assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
|
||||
|
||||
def test_openrouter_paths_extracted(self):
|
||||
"""Test that OpenRouter model paths are handled."""
|
||||
assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
|
||||
|
||||
def test_unknown_models_default_to_gpt4o(self):
|
||||
"""Test that unknown models default to gpt-4o."""
|
||||
assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
|
||||
assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
|
||||
|
||||
|
||||
class TestTruncateToolMessageContent:
|
||||
"""Test tool message content truncation."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_truncate_openai_tool_message(self, enc):
|
||||
"""Test truncation of OpenAI-style tool message with string content."""
|
||||
long_content = "x" * 10000
|
||||
msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
|
||||
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=100)
|
||||
|
||||
# Content should be truncated
|
||||
assert len(msg["content"]) < len(long_content)
|
||||
assert "…" in msg["content"] # Has ellipsis marker
|
||||
|
||||
def test_truncate_anthropic_tool_result(self, enc):
|
||||
"""Test truncation of Anthropic-style tool_result."""
|
||||
long_content = "y" * 10000
|
||||
msg = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": long_content,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=100)
|
||||
|
||||
# Content should be truncated
|
||||
result_content = msg["content"][0]["content"]
|
||||
assert len(result_content) < len(long_content)
|
||||
assert "…" in result_content
|
||||
|
||||
def test_preserve_tool_use_blocks(self, enc):
|
||||
"""Test that tool_use blocks are not truncated."""
|
||||
msg = {
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "some_function",
|
||||
"input": {"key": "value" * 1000}, # Large input
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
original = json.dumps(msg["content"][0]["input"])
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=10)
|
||||
|
||||
# tool_use should be unchanged
|
||||
assert json.dumps(msg["content"][0]["input"]) == original
|
||||
|
||||
def test_no_truncation_when_under_limit(self, enc):
|
||||
"""Test that short content is not modified."""
|
||||
msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
|
||||
|
||||
original = msg["content"]
|
||||
_truncate_tool_message_content(msg, enc, max_tokens=1000)
|
||||
|
||||
assert msg["content"] == original
|
||||
|
||||
|
||||
class TestTruncateMiddleTokens:
|
||||
"""Test middle truncation of text."""
|
||||
|
||||
@pytest.fixture
|
||||
def enc(self):
|
||||
return encoding_for_model("gpt-4o")
|
||||
|
||||
def test_truncates_long_text(self, enc):
|
||||
"""Test that long text is truncated with ellipsis in middle."""
|
||||
long_text = "word " * 1000
|
||||
result = _truncate_middle_tokens(long_text, enc, max_tok=50)
|
||||
|
||||
assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis
|
||||
assert "…" in result
|
||||
assert result.startswith("word") # Head preserved
|
||||
assert result.endswith("word ") # Tail preserved
|
||||
|
||||
def test_preserves_short_text(self, enc):
|
||||
"""Test that short text is not modified."""
|
||||
short_text = "Hello world"
|
||||
result = _truncate_middle_tokens(short_text, enc, max_tok=100)
|
||||
|
||||
assert result == short_text
|
||||
|
||||
|
||||
class TestEnsureToolPairsIntact:
|
||||
"""Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
|
||||
|
||||
# ---- OpenAI Format Tests ----
|
||||
|
||||
def test_openai_adds_missing_tool_call(self):
|
||||
"""Test that orphaned OpenAI tool_response gets its tool_call prepended."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool response)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_call message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "tool_calls" in result[0]
|
||||
|
||||
def test_openai_keeps_complete_pairs(self):
|
||||
"""Test that complete OpenAI pairs are unchanged."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
|
||||
]
|
||||
recent = all_msgs[1:] # Include both tool_call and response
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert len(result) == 2 # No messages added
|
||||
|
||||
def test_openai_multiple_tool_calls(self):
|
||||
"""Test multiple OpenAI tool calls in one assistant message."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{"id": "call_1", "type": "function", "function": {"name": "f1"}},
|
||||
{"id": "call_2", "type": "function", "function": {"name": "f2"}},
|
||||
],
|
||||
},
|
||||
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
|
||||
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (first tool response)
|
||||
recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with both tool_calls
|
||||
assert len(result) == 4
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert len(result[0]["tool_calls"]) == 2
|
||||
|
||||
# ---- Anthropic Format Tests ----
|
||||
|
||||
def test_anthropic_adds_missing_tool_use(self):
|
||||
"""Test that orphaned Anthropic tool_result gets its tool_use prepended."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_123",
|
||||
"name": "get_weather",
|
||||
"input": {"location": "SF"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_123",
|
||||
"content": "22°C and sunny",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool_result)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_use message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert result[0]["content"][0]["type"] == "tool_use"
|
||||
|
||||
def test_anthropic_keeps_complete_pairs(self):
|
||||
"""Test that complete Anthropic pairs are unchanged."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_456",
|
||||
"name": "calculator",
|
||||
"input": {"expr": "2+2"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_456",
|
||||
"content": "4",
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
recent = all_msgs[1:] # Include both tool_use and result
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert len(result) == 2 # No messages added
|
||||
|
||||
def test_anthropic_multiple_tool_uses(self):
|
||||
"""Test multiple Anthropic tool_use blocks in one message."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "System"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me check both..."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_1",
|
||||
"name": "get_weather",
|
||||
"input": {"city": "NYC"},
|
||||
},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_2",
|
||||
"name": "get_weather",
|
||||
"input": {"city": "LA"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_1",
|
||||
"content": "Cold",
|
||||
},
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_2",
|
||||
"content": "Warm",
|
||||
},
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (tool_result)
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with both tool_uses
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
tool_use_count = sum(
|
||||
1 for b in result[0]["content"] if b.get("type") == "tool_use"
|
||||
)
|
||||
assert tool_use_count == 2
|
||||
|
||||
# ---- Mixed/Edge Case Tests ----
|
||||
|
||||
def test_anthropic_with_type_message_field(self):
|
||||
"""Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
|
||||
all_msgs = [
|
||||
{"role": "system", "content": "You are helpful."},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_abc",
|
||||
"name": "search",
|
||||
"input": {"q": "test"},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"type": "message", # Extra field from smart_decision_maker
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_abc",
|
||||
"content": "Found results",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "user", "content": "Thanks!"},
|
||||
]
|
||||
# Recent messages start at index 2 (the tool_result with 'type': 'message')
|
||||
recent = [all_msgs[2], all_msgs[3]]
|
||||
start_index = 2
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the tool_use message
|
||||
assert len(result) == 3
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert result[0]["content"][0]["type"] == "tool_use"
|
||||
|
||||
def test_handles_no_tool_messages(self):
|
||||
"""Test messages without tool calls."""
|
||||
all_msgs = [
|
||||
{"role": "user", "content": "Hello"},
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
recent = all_msgs
|
||||
start_index = 0
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
assert result == all_msgs
|
||||
|
||||
def test_handles_empty_messages(self):
|
||||
"""Test empty message list."""
|
||||
result = _ensure_tool_pairs_intact([], [], 0)
|
||||
assert result == []
|
||||
|
||||
def test_mixed_text_and_tool_content(self):
|
||||
"""Test Anthropic message with mixed text and tool_use content."""
|
||||
all_msgs = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I'll help you with that."},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "toolu_mixed",
|
||||
"name": "search",
|
||||
"input": {"q": "test"},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "toolu_mixed",
|
||||
"content": "Found results",
|
||||
}
|
||||
],
|
||||
},
|
||||
{"role": "assistant", "content": "Here are the results..."},
|
||||
]
|
||||
# Start from tool_result
|
||||
recent = [all_msgs[1], all_msgs[2]]
|
||||
start_index = 1
|
||||
|
||||
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
|
||||
|
||||
# Should prepend the assistant message with tool_use
|
||||
assert len(result) == 3
|
||||
assert result[0]["content"][0]["type"] == "text"
|
||||
assert result[0]["content"][1]["type"] == "tool_use"
|
||||
|
||||
|
||||
class TestCompressContext:
    """Test the async compress_context function."""

    @pytest.mark.asyncio
    async def test_no_compression_needed(self):
        """Test messages under limit return without compression."""
        messages = [
            {"role": "system", "content": "You are helpful."},
            {"role": "user", "content": "Hello!"},
        ]

        result = await compress_context(messages, target_tokens=100000)

        assert isinstance(result, CompressResult)
        assert result.was_compacted is False
        assert len(result.messages) == 2
        assert result.error is None

    @pytest.mark.asyncio
    async def test_truncation_without_client(self):
        """Test that truncation works without LLM client."""
        long_content = "x" * 50000
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": long_content},
            {"role": "assistant", "content": "Response"},
        ]

        result = await compress_context(
            messages, target_tokens=1000, client=None, reserve=100
        )

        assert result.was_compacted is True
        # Should have truncated without summarization
        assert result.messages_summarized == 0

    @pytest.mark.asyncio
    async def test_with_mocked_llm_client(self):
        """Test summarization with mocked LLM client."""
        # Create many messages to trigger summarization
        messages = [{"role": "system", "content": "System prompt"}]
        for i in range(30):
            messages.append({"role": "user", "content": f"User message {i} " * 100})
            messages.append(
                {"role": "assistant", "content": f"Assistant response {i} " * 100}
            )

        # Mock the AsyncOpenAI client
        mock_client = AsyncMock()
        mock_response = MagicMock()
        mock_response.choices = [MagicMock()]
        mock_response.choices[0].message.content = "Summary of conversation"
        mock_client.with_options.return_value.chat.completions.create = AsyncMock(
            return_value=mock_response
        )

        result = await compress_context(
            messages,
            target_tokens=5000,
            client=mock_client,
            keep_recent=5,
            reserve=500,
        )

        assert result.was_compacted is True
        # Should have attempted summarization
        assert mock_client.with_options.called or result.messages_summarized > 0

    @pytest.mark.asyncio
    async def test_preserves_tool_pairs(self):
        """Test that tool call/response pairs stay together."""
        messages = [
            {"role": "system", "content": "System"},
            {"role": "user", "content": "Do something"},
            {
                "role": "assistant",
                "tool_calls": [
                    {"id": "call_1", "type": "function", "function": {"name": "func"}}
                ],
            },
            {"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
            {"role": "assistant", "content": "Done!"},
        ]

        result = await compress_context(
            messages, target_tokens=500, client=None, reserve=50
        )

        # Check that if tool response exists, its call exists too
        tool_call_ids = set()
        tool_response_ids = set()
        for msg in result.messages:
            if "tool_calls" in msg:
                for tc in msg["tool_calls"]:
                    tool_call_ids.add(tc["id"])
            if msg.get("role") == "tool":
                tool_response_ids.add(msg.get("tool_call_id"))

        # All tool responses should have their calls
        assert tool_response_ids <= tool_call_ids

    @pytest.mark.asyncio
    async def test_returns_error_when_cannot_compress(self):
        """Test that error is returned when compression fails."""
        # Single huge message that can't be compressed enough
        messages = [
            {"role": "user", "content": "x" * 100000},
        ]

        result = await compress_context(
            messages, target_tokens=100, client=None, reserve=50
        )

        # Should have an error since we can't get below 100 tokens
        assert result.error is not None
        assert result.was_compacted is True

    @pytest.mark.asyncio
    async def test_empty_messages(self):
        """Test that empty messages list returns early without error."""
        result = await compress_context([], target_tokens=1000)

        assert result.messages == []
        assert result.token_count == 0
        assert result.was_compacted is False
        assert result.error is None


class TestRemoveOrphanToolResponses:
    """Test _remove_orphan_tool_responses helper function."""

    def test_removes_openai_orphan(self):
        """Test removal of orphan OpenAI tool response."""
        from backend.util.prompt import _remove_orphan_tool_responses

        messages = [
            {"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
            {"role": "user", "content": "Hello"},
        ]
        orphan_ids = {"call_orphan"}

        result = _remove_orphan_tool_responses(messages, orphan_ids)

        assert len(result) == 1
        assert result[0]["role"] == "user"

    def test_keeps_valid_openai_tool(self):
        """Test that valid OpenAI tool responses are kept."""
        from backend.util.prompt import _remove_orphan_tool_responses

        messages = [
            {"role": "tool", "tool_call_id": "call_valid", "content": "result"},
        ]
        orphan_ids = {"call_other"}

        result = _remove_orphan_tool_responses(messages, orphan_ids)

        assert len(result) == 1
        assert result[0]["tool_call_id"] == "call_valid"

    def test_filters_anthropic_mixed_blocks(self):
        """Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
        from backend.util.prompt import _remove_orphan_tool_responses

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": "toolu_valid",
                        "content": "valid result",
                    },
                    {
                        "type": "tool_result",
                        "tool_use_id": "toolu_orphan",
                        "content": "orphan result",
                    },
                ],
            },
        ]
        orphan_ids = {"toolu_orphan"}

        result = _remove_orphan_tool_responses(messages, orphan_ids)

        assert len(result) == 1
        # Should only have the valid tool_result, orphan filtered out
        assert len(result[0]["content"]) == 1
        assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"

    def test_removes_anthropic_all_orphan(self):
        """Test removal of Anthropic message when all tool_results are orphans."""
        from backend.util.prompt import _remove_orphan_tool_responses

        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": "toolu_orphan1",
                        "content": "result1",
                    },
                    {
                        "type": "tool_result",
                        "tool_use_id": "toolu_orphan2",
                        "content": "result2",
                    },
                ],
            },
        ]
        orphan_ids = {"toolu_orphan1", "toolu_orphan2"}

        result = _remove_orphan_tool_responses(messages, orphan_ids)

        # Message should be completely removed since no content left
        assert len(result) == 0

    def test_preserves_non_tool_messages(self):
        """Test that non-tool messages are preserved."""
        from backend.util.prompt import _remove_orphan_tool_responses

        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]
        orphan_ids = {"some_id"}

        result = _remove_orphan_tool_responses(messages, orphan_ids)

        assert result == messages


class TestCompressResultDataclass:
    """Test CompressResult dataclass."""

    def test_default_values(self):
        """Test default values are set correctly."""
        result = CompressResult(
            messages=[{"role": "user", "content": "test"}],
            token_count=10,
            was_compacted=False,
        )

        assert result.error is None
        assert result.original_token_count == 0  # Defaults to 0, not None
        assert result.messages_summarized == 0
        assert result.messages_dropped == 0

    def test_all_fields(self):
        """Test all fields can be set."""
        result = CompressResult(
            messages=[{"role": "user", "content": "test"}],
            token_count=100,
            was_compacted=True,
            error="Some error",
            original_token_count=500,
            messages_summarized=10,
            messages_dropped=5,
        )

        assert result.token_count == 100
        assert result.was_compacted is True
        assert result.error == "Some error"
        assert result.original_token_count == 500
        assert result.messages_summarized == 10
        assert result.messages_dropped == 5

@@ -656,6 +656,7 @@ class Secrets(UpdateTrackingModel["Secrets"], BaseSettings):
    e2b_api_key: str = Field(default="", description="E2B API key")
    nvidia_api_key: str = Field(default="", description="Nvidia API key")
    mem0_api_key: str = Field(default="", description="Mem0 API key")
    elevenlabs_api_key: str = Field(default="", description="ElevenLabs API key")

    linear_client_id: str = Field(default="", description="Linear client ID")
    linear_client_secret: str = Field(default="", description="Linear client secret")

@@ -1,22 +0,0 @@
-- Migrate Claude 3.7 Sonnet to Claude 4.5 Sonnet
-- This updates all AgentNode blocks that use the deprecated Claude 3.7 Sonnet model
-- Anthropic is retiring claude-3-7-sonnet-20250219 on February 19, 2026

-- Update AgentNode constant inputs
UPDATE "AgentNode"
SET "constantInput" = JSONB_SET(
    "constantInput"::jsonb,
    '{model}',
    '"claude-sonnet-4-5-20250929"'::jsonb
)
WHERE "constantInput"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';

-- Update AgentPreset input overrides (stored in AgentNodeExecutionInputOutput)
UPDATE "AgentNodeExecutionInputOutput"
SET "data" = JSONB_SET(
    "data"::jsonb,
    '{model}',
    '"claude-sonnet-4-5-20250929"'::jsonb
)
WHERE "agentPresetId" IS NOT NULL
  AND "data"::jsonb->>'model' = 'claude-3-7-sonnet-20250219';
47 autogpt_platform/backend/poetry.lock generated
@@ -1169,6 +1169,29 @@ attrs = ">=21.3.0"
e2b = ">=1.5.4,<2.0.0"
httpx = ">=0.20.0,<1.0.0"

[[package]]
name = "elevenlabs"
version = "1.59.0"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
groups = ["main"]
files = [
    {file = "elevenlabs-1.59.0-py3-none-any.whl", hash = "sha256:468145db81a0bc867708b4a8619699f75583e9481b395ec1339d0b443da771ed"},
    {file = "elevenlabs-1.59.0.tar.gz", hash = "sha256:16e735bd594e86d415dd445d249c8cc28b09996cfd627fbc10102c0a84698859"},
]

[package.dependencies]
httpx = ">=0.21.2"
pydantic = ">=1.9.2"
pydantic-core = ">=2.18.2,<3.0.0"
requests = ">=2.20"
typing_extensions = ">=4.0.0"
websockets = ">=11.0"

[package.extras]
pyaudio = ["pyaudio (>=0.2.14)"]

[[package]]
name = "email-validator"
version = "2.2.0"
@@ -7361,6 +7384,28 @@ files = [
defusedxml = ">=0.7.1,<0.8.0"
requests = "*"

[[package]]
name = "yt-dlp"
version = "2025.12.8"
description = "A feature-rich command-line audio/video downloader"
optional = false
python-versions = ">=3.10"
groups = ["main"]
files = [
    {file = "yt_dlp-2025.12.8-py3-none-any.whl", hash = "sha256:36e2584342e409cfbfa0b5e61448a1c5189e345cf4564294456ee509e7d3e065"},
    {file = "yt_dlp-2025.12.8.tar.gz", hash = "sha256:b773c81bb6b71cb2c111cfb859f453c7a71cf2ef44eff234ff155877184c3e4f"},
]

[package.extras]
build = ["build", "hatchling (>=1.27.0)", "pip", "setuptools (>=71.0.2)", "wheel"]
curl-cffi = ["curl-cffi (>=0.5.10,<0.6.dev0 || >=0.10.dev0,<0.14) ; implementation_name == \"cpython\""]
default = ["brotli ; implementation_name == \"cpython\"", "brotlicffi ; implementation_name != \"cpython\"", "certifi", "mutagen", "pycryptodomex", "requests (>=2.32.2,<3)", "urllib3 (>=2.0.2,<3)", "websockets (>=13.0)", "yt-dlp-ejs (==0.3.2)"]
dev = ["autopep8 (>=2.0,<3.0)", "pre-commit", "pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)", "ruff (>=0.14.0,<0.15.0)"]
pyinstaller = ["pyinstaller (>=6.17.0)"]
secretstorage = ["cffi", "secretstorage"]
static-analysis = ["autopep8 (>=2.0,<3.0)", "ruff (>=0.14.0,<0.15.0)"]
test = ["pytest (>=8.1,<9.0)", "pytest-rerunfailures (>=14.0,<15.0)"]

[[package]]
name = "zerobouncesdk"
version = "1.1.2"
@@ -7512,4 +7557,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.14"
content-hash = "ee5742dc1a9df50dfc06d4b26a1682cbb2b25cab6b79ce5625ec272f93e4f4bf"
content-hash = "8239323f9ae6713224dffd1fe8ba8b449fe88b6c3c7a90940294a74f43a0387a"

@@ -20,6 +20,7 @@ click = "^8.2.0"
cryptography = "^45.0"
discord-py = "^2.5.2"
e2b-code-interpreter = "^1.5.2"
elevenlabs = "^1.50.0"
fastapi = "^0.116.1"
feedparser = "^6.0.11"
flake8 = "^7.3.0"
@@ -71,6 +72,7 @@ tweepy = "^4.16.0"
uvicorn = { extras = ["standard"], version = "^0.35.0" }
websockets = "^15.0"
youtube-transcript-api = "^1.2.1"
yt-dlp = "2025.12.08"
zerobouncesdk = "^1.1.2"
# NOTE: please insert new dependencies in their alphabetical location
pytest-snapshot = "^0.9.0"

@@ -9,8 +9,7 @@
      "sub_heading": "Creator agent subheading",
      "description": "Creator agent description",
      "runs": 50,
      "rating": 4.0,
      "agent_graph_id": "test-graph-2"
      "rating": 4.0
    }
  ],
  "pagination": {

@@ -9,8 +9,7 @@
      "sub_heading": "Category agent subheading",
      "description": "Category agent description",
      "runs": 60,
      "rating": 4.1,
      "agent_graph_id": "test-graph-category"
      "rating": 4.1
    }
  ],
  "pagination": {

@@ -9,8 +9,7 @@
      "sub_heading": "Agent 0 subheading",
      "description": "Agent 0 description",
      "runs": 0,
      "rating": 4.0,
      "agent_graph_id": "test-graph-2"
      "rating": 4.0
    },
    {
      "slug": "agent-1",
@@ -21,8 +20,7 @@
      "sub_heading": "Agent 1 subheading",
      "description": "Agent 1 description",
      "runs": 10,
      "rating": 4.0,
      "agent_graph_id": "test-graph-2"
      "rating": 4.0
    },
    {
      "slug": "agent-2",
@@ -33,8 +31,7 @@
      "sub_heading": "Agent 2 subheading",
      "description": "Agent 2 description",
      "runs": 20,
      "rating": 4.0,
      "agent_graph_id": "test-graph-2"
      "rating": 4.0
    },
    {
      "slug": "agent-3",
@@ -45,8 +42,7 @@
      "sub_heading": "Agent 3 subheading",
      "description": "Agent 3 description",
      "runs": 30,
      "rating": 4.0,
      "agent_graph_id": "test-graph-2"
      "rating": 4.0
    },
    {
      "slug": "agent-4",
@@ -57,8 +53,7 @@
      "sub_heading": "Agent 4 subheading",
      "description": "Agent 4 description",
      "runs": 40,
      "rating": 4.0,
      "agent_graph_id": "test-graph-2"
      "rating": 4.0
    }
  ],
  "pagination": {

@@ -9,8 +9,7 @@
      "sub_heading": "Search agent subheading",
      "description": "Specific search term description",
      "runs": 75,
      "rating": 4.2,
      "agent_graph_id": "test-graph-search"
      "rating": 4.2
    }
  ],
  "pagination": {

@@ -9,8 +9,7 @@
      "sub_heading": "Top agent subheading",
      "description": "Top agent description",
      "runs": 1000,
      "rating": 5.0,
      "agent_graph_id": "test-graph-3"
      "rating": 5.0
    }
  ],
  "pagination": {

@@ -9,8 +9,7 @@
      "sub_heading": "Featured agent subheading",
      "description": "Featured agent description",
      "runs": 100,
      "rating": 4.5,
      "agent_graph_id": "test-graph-1"
      "rating": 4.5
    }
  ],
  "pagination": {

@@ -31,10 +31,6 @@
    "has_sensitive_action": false,
    "trigger_setup_info": null,
    "new_output": false,
    "execution_count": 0,
    "success_rate": null,
    "avg_correctness_score": null,
    "recent_executions": [],
    "can_access_graph": true,
    "is_latest_version": true,
    "is_favorite": false,
@@ -76,10 +72,6 @@
    "has_sensitive_action": false,
    "trigger_setup_info": null,
    "new_output": false,
    "execution_count": 0,
    "success_rate": null,
    "avg_correctness_score": null,
    "recent_executions": [],
    "can_access_graph": false,
    "is_latest_version": true,
    "is_favorite": false,

@@ -57,8 +57,7 @@ class TestDecomposeGoal:

        result = await core.decompose_goal("Build a chatbot")

        # library_agents defaults to None
        mock_external.assert_called_once_with("Build a chatbot", "", None)
        mock_external.assert_called_once_with("Build a chatbot", "")
        assert result == expected_result

    @pytest.mark.asyncio
@@ -75,8 +74,7 @@ class TestDecomposeGoal:

        await core.decompose_goal("Build a chatbot", "Use Python")

        # library_agents defaults to None
        mock_external.assert_called_once_with("Build a chatbot", "Use Python", None)
        mock_external.assert_called_once_with("Build a chatbot", "Use Python")

    @pytest.mark.asyncio
    async def test_returns_none_on_service_failure(self):
@@ -111,8 +109,7 @@ class TestGenerateAgent:
        instructions = {"type": "instructions", "steps": ["Step 1"]}
        result = await core.generate_agent(instructions)

        # library_agents defaults to None
        mock_external.assert_called_once_with(instructions, None)
        mock_external.assert_called_once_with(instructions)
        # Result should have id, version, is_active added if not present
        assert result is not None
        assert result["name"] == "Test Agent"
@@ -177,8 +174,7 @@ class TestGenerateAgentPatch:
        current_agent = {"nodes": [], "links": []}
        result = await core.generate_agent_patch("Add a node", current_agent)

        # library_agents defaults to None
        mock_external.assert_called_once_with("Add a node", current_agent, None)
        mock_external.assert_called_once_with("Add a node", current_agent)
        assert result == expected_result

    @pytest.mark.asyncio

@@ -1,857 +0,0 @@
"""
Tests for library agent fetching functionality in agent generator.

This test suite verifies the search-based library agent fetching,
including the combination of library and marketplace agents.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from backend.api.features.chat.tools.agent_generator import core


class TestGetLibraryAgentsForGeneration:
    """Test get_library_agents_for_generation function."""

    @pytest.mark.asyncio
    async def test_fetches_agents_with_search_term(self):
        """Test that search_term is passed to the library db."""
        # Create a mock agent with proper attribute values
        mock_agent = MagicMock()
        mock_agent.graph_id = "agent-123"
        mock_agent.graph_version = 1
        mock_agent.name = "Email Agent"
        mock_agent.description = "Sends emails"
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}
        mock_agent.recent_executions = []

        mock_response = MagicMock()
        mock_response.agents = [mock_agent]

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_list:
            result = await core.get_library_agents_for_generation(
                user_id="user-123",
                search_query="send email",
            )

            mock_list.assert_called_once_with(
                user_id="user-123",
                search_term="send email",
                page=1,
                page_size=15,
                include_executions=True,
            )

        # Verify result format
        assert len(result) == 1
        assert result[0]["graph_id"] == "agent-123"
        assert result[0]["name"] == "Email Agent"

    @pytest.mark.asyncio
    async def test_excludes_specified_graph_id(self):
        """Test that agents with excluded graph_id are filtered out."""
        mock_response = MagicMock()
        mock_response.agents = [
            MagicMock(
                graph_id="agent-123",
                graph_version=1,
                name="Agent 1",
                description="First agent",
                input_schema={},
                output_schema={},
                recent_executions=[],
            ),
            MagicMock(
                graph_id="agent-456",
                graph_version=1,
                name="Agent 2",
                description="Second agent",
                input_schema={},
                output_schema={},
                recent_executions=[],
            ),
        ]

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ):
            result = await core.get_library_agents_for_generation(
                user_id="user-123",
                exclude_graph_id="agent-123",
            )

        # Verify the excluded agent is not in results
        assert len(result) == 1
        assert result[0]["graph_id"] == "agent-456"

    @pytest.mark.asyncio
    async def test_respects_max_results(self):
        """Test that max_results parameter limits the page_size."""
        mock_response = MagicMock()
        mock_response.agents = []

        with patch.object(
            core.library_db,
            "list_library_agents",
            new_callable=AsyncMock,
            return_value=mock_response,
        ) as mock_list:
            await core.get_library_agents_for_generation(
                user_id="user-123",
                max_results=5,
            )

            mock_list.assert_called_once_with(
                user_id="user-123",
                search_term=None,
                page=1,
                page_size=5,
                include_executions=True,
            )


class TestSearchMarketplaceAgentsForGeneration:
    """Test search_marketplace_agents_for_generation function."""

    @pytest.mark.asyncio
    async def test_searches_marketplace_with_query(self):
        """Test that marketplace is searched with the query."""
        mock_response = MagicMock()
        mock_response.agents = [
            MagicMock(
                agent_name="Public Agent",
                description="A public agent",
                sub_heading="Does something useful",
                creator="creator-1",
                agent_graph_id="graph-123",
            )
        ]

        mock_graph = MagicMock()
        mock_graph.id = "graph-123"
        mock_graph.version = 1
        mock_graph.input_schema = {"type": "object"}
        mock_graph.output_schema = {"type": "object"}

        with (
            patch(
                "backend.api.features.store.db.get_store_agents",
                new_callable=AsyncMock,
                return_value=mock_response,
            ) as mock_search,
            patch(
                "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
                new_callable=AsyncMock,
                return_value={"graph-123": mock_graph},
            ),
        ):
            result = await core.search_marketplace_agents_for_generation(
                search_query="automation",
                max_results=10,
            )

            mock_search.assert_called_once_with(
                search_query="automation",
                page=1,
                page_size=10,
            )

        assert len(result) == 1
        assert result[0]["name"] == "Public Agent"
        assert result[0]["graph_id"] == "graph-123"

    @pytest.mark.asyncio
    async def test_handles_marketplace_error_gracefully(self):
        """Test that marketplace errors don't crash the function."""
        with patch(
            "backend.api.features.store.db.get_store_agents",
            new_callable=AsyncMock,
            side_effect=Exception("Marketplace unavailable"),
        ):
            result = await core.search_marketplace_agents_for_generation(
                search_query="test"
            )

        # Should return empty list, not raise exception
        assert result == []

class TestGetAllRelevantAgentsForGeneration:
    """Test get_all_relevant_agents_for_generation function."""

    @pytest.mark.asyncio
    async def test_combines_library_and_marketplace_agents(self):
        """Test that agents from both sources are combined."""
        library_agents = [
            {
                "graph_id": "lib-123",
                "graph_version": 1,
                "name": "Library Agent",
                "description": "From library",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        marketplace_agents = [
            {
                "graph_id": "market-456",
                "graph_version": 1,
                "name": "Market Agent",
                "description": "From marketplace",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        with patch.object(
            core,
            "get_library_agents_for_generation",
            new_callable=AsyncMock,
            return_value=library_agents,
        ):
            with patch.object(
                core,
                "search_marketplace_agents_for_generation",
                new_callable=AsyncMock,
                return_value=marketplace_agents,
            ):
                result = await core.get_all_relevant_agents_for_generation(
                    user_id="user-123",
                    search_query="test query",
                    include_marketplace=True,
                )

        # Library agents should come first
        assert len(result) == 2
        assert result[0]["name"] == "Library Agent"
        assert result[1]["name"] == "Market Agent"

    @pytest.mark.asyncio
    async def test_deduplicates_by_graph_id(self):
        """Test that marketplace agents with same graph_id as library are excluded."""
        library_agents = [
            {
                "graph_id": "shared-123",
                "graph_version": 1,
                "name": "Shared Agent",
                "description": "From library",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        marketplace_agents = [
            {
                "graph_id": "shared-123",  # Same graph_id, should be deduplicated
                "graph_version": 1,
                "name": "Shared Agent",
                "description": "From marketplace",
                "input_schema": {},
                "output_schema": {},
            },
            {
                "graph_id": "unique-456",
                "graph_version": 1,
                "name": "Unique Agent",
                "description": "Only in marketplace",
                "input_schema": {},
                "output_schema": {},
            },
        ]

        with patch.object(
            core,
            "get_library_agents_for_generation",
            new_callable=AsyncMock,
            return_value=library_agents,
        ):
            with patch.object(
                core,
                "search_marketplace_agents_for_generation",
                new_callable=AsyncMock,
                return_value=marketplace_agents,
            ):
                result = await core.get_all_relevant_agents_for_generation(
                    user_id="user-123",
                    search_query="test",
                    include_marketplace=True,
                )

        # Shared Agent from marketplace should be excluded by graph_id
        assert len(result) == 2
        names = [a["name"] for a in result]
        assert "Shared Agent" in names
        assert "Unique Agent" in names

    @pytest.mark.asyncio
    async def test_skips_marketplace_when_disabled(self):
        """Test that marketplace is not searched when include_marketplace=False."""
        library_agents = [
            {
                "graph_id": "lib-123",
                "graph_version": 1,
                "name": "Library Agent",
                "description": "From library",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        with patch.object(
            core,
            "get_library_agents_for_generation",
            new_callable=AsyncMock,
            return_value=library_agents,
        ):
            with patch.object(
                core,
                "search_marketplace_agents_for_generation",
                new_callable=AsyncMock,
            ) as mock_marketplace:
                result = await core.get_all_relevant_agents_for_generation(
                    user_id="user-123",
                    search_query="test",
                    include_marketplace=False,
                )

        # Marketplace should not be called
        mock_marketplace.assert_not_called()
        assert len(result) == 1

    @pytest.mark.asyncio
    async def test_skips_marketplace_when_no_search_query(self):
        """Test that marketplace is not searched without a search query."""
        library_agents = [
            {
                "graph_id": "lib-123",
                "graph_version": 1,
                "name": "Library Agent",
                "description": "From library",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        with patch.object(
            core,
            "get_library_agents_for_generation",
            new_callable=AsyncMock,
            return_value=library_agents,
        ):
            with patch.object(
                core,
                "search_marketplace_agents_for_generation",
                new_callable=AsyncMock,
            ) as mock_marketplace:
                result = await core.get_all_relevant_agents_for_generation(
                    user_id="user-123",
                    search_query=None,  # No search query
                    include_marketplace=True,
                )

        # Marketplace should not be called without search query
        mock_marketplace.assert_not_called()
        assert len(result) == 1


class TestExtractSearchTermsFromSteps:
    """Test extract_search_terms_from_steps function."""

    def test_extracts_terms_from_instructions_type(self):
        """Test extraction from valid instructions decomposition result."""
        decomposition_result = {
            "type": "instructions",
            "steps": [
                {
                    "description": "Send an email notification",
                    "block_name": "GmailSendBlock",
                },
                {"description": "Fetch weather data", "action": "Get weather API"},
            ],
        }

        result = core.extract_search_terms_from_steps(decomposition_result)

        assert "Send an email notification" in result
        assert "GmailSendBlock" in result
        assert "Fetch weather data" in result
        assert "Get weather API" in result

    def test_returns_empty_for_non_instructions_type(self):
        """Test that non-instructions types return empty list."""
        decomposition_result = {
            "type": "clarifying_questions",
            "questions": [{"question": "What email?"}],
        }

        result = core.extract_search_terms_from_steps(decomposition_result)

        assert result == []

    def test_deduplicates_terms_case_insensitively(self):
        """Test that duplicate terms are removed (case-insensitive)."""
        decomposition_result = {
            "type": "instructions",
            "steps": [
                {"description": "Send Email", "name": "send email"},
                {"description": "Other task"},
            ],
        }

        result = core.extract_search_terms_from_steps(decomposition_result)

        # Should only have one "send email" variant
        email_terms = [t for t in result if "email" in t.lower()]
        assert len(email_terms) == 1

    def test_filters_short_terms(self):
        """Test that terms with 3 or fewer characters are filtered out."""
        decomposition_result = {
            "type": "instructions",
            "steps": [
                {"description": "ab", "action": "xyz"},  # Both too short
                {"description": "Valid term here"},
            ],
        }

        result = core.extract_search_terms_from_steps(decomposition_result)

        assert "ab" not in result
        assert "xyz" not in result
        assert "Valid term here" in result

    def test_handles_empty_steps(self):
        """Test handling of empty steps list."""
        decomposition_result = {
            "type": "instructions",
            "steps": [],
        }

        result = core.extract_search_terms_from_steps(decomposition_result)

        assert result == []

class TestEnrichLibraryAgentsFromSteps:
    """Test enrich_library_agents_from_steps function."""

    @pytest.mark.asyncio
    async def test_enriches_with_additional_agents(self):
        """Test that additional agents are found based on steps."""
        existing_agents = [
            {
                "graph_id": "existing-123",
                "graph_version": 1,
                "name": "Existing Agent",
                "description": "Already fetched",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        additional_agents = [
            {
                "graph_id": "new-456",
                "graph_version": 1,
                "name": "Email Agent",
                "description": "For sending emails",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        decomposition_result = {
            "type": "instructions",
            "steps": [
                {"description": "Send email notification"},
            ],
        }

        with patch.object(
            core,
            "get_all_relevant_agents_for_generation",
            new_callable=AsyncMock,
            return_value=additional_agents,
        ):
            result = await core.enrich_library_agents_from_steps(
                user_id="user-123",
                decomposition_result=decomposition_result,
                existing_agents=existing_agents,
            )

        # Should have both existing and new agents
        assert len(result) == 2
        names = [a["name"] for a in result]
        assert "Existing Agent" in names
        assert "Email Agent" in names

    @pytest.mark.asyncio
    async def test_deduplicates_by_graph_id(self):
        """Test that agents with same graph_id are not duplicated."""
        existing_agents = [
            {
                "graph_id": "agent-123",
                "graph_version": 1,
                "name": "Existing Agent",
                "description": "Already fetched",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        # Additional search returns same agent
        additional_agents = [
            {
                "graph_id": "agent-123",  # Same ID
                "graph_version": 1,
                "name": "Existing Agent Copy",
                "description": "Same agent different name",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        decomposition_result = {
            "type": "instructions",
            "steps": [{"description": "Some action"}],
        }

        with patch.object(
            core,
            "get_all_relevant_agents_for_generation",
            new_callable=AsyncMock,
            return_value=additional_agents,
        ):
            result = await core.enrich_library_agents_from_steps(
                user_id="user-123",
                decomposition_result=decomposition_result,
                existing_agents=existing_agents,
            )

        # Should not duplicate
        assert len(result) == 1

    @pytest.mark.asyncio
    async def test_deduplicates_by_name(self):
        """Test that agents with same name are not duplicated."""
        existing_agents = [
            {
                "graph_id": "agent-123",
                "graph_version": 1,
                "name": "Email Agent",
                "description": "Already fetched",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        # Additional search returns agent with same name but different ID
        additional_agents = [
            {
                "graph_id": "agent-456",  # Different ID
                "graph_version": 1,
                "name": "Email Agent",  # Same name
                "description": "Different agent same name",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        decomposition_result = {
            "type": "instructions",
            "steps": [{"description": "Send email"}],
        }

        with patch.object(
            core,
            "get_all_relevant_agents_for_generation",
            new_callable=AsyncMock,
            return_value=additional_agents,
        ):
            result = await core.enrich_library_agents_from_steps(
                user_id="user-123",
                decomposition_result=decomposition_result,
                existing_agents=existing_agents,
            )

        # Should not duplicate by name
        assert len(result) == 1
        assert result[0].get("graph_id") == "agent-123"  # Original kept

    @pytest.mark.asyncio
    async def test_returns_existing_when_no_steps(self):
        """Test that existing agents are returned when no search terms extracted."""
        existing_agents = [
            {
                "graph_id": "existing-123",
                "graph_version": 1,
                "name": "Existing Agent",
                "description": "Already fetched",
                "input_schema": {},
                "output_schema": {},
            }
        ]

        decomposition_result = {
            "type": "clarifying_questions",  # Not instructions type
            "questions": [],
        }

        result = await core.enrich_library_agents_from_steps(
            user_id="user-123",
            decomposition_result=decomposition_result,
            existing_agents=existing_agents,
        )

        # Should return existing unchanged
        assert result == existing_agents

    @pytest.mark.asyncio
    async def test_limits_search_terms_to_three(self):
        """Test that only first 3 search terms are used."""
        existing_agents = []

        decomposition_result = {
            "type": "instructions",
            "steps": [
                {"description": "First action"},
                {"description": "Second action"},
                {"description": "Third action"},
                {"description": "Fourth action"},
                {"description": "Fifth action"},
            ],
        }

        call_count = 0

        async def mock_get_agents(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            return []

        with patch.object(
            core,
            "get_all_relevant_agents_for_generation",
            side_effect=mock_get_agents,
        ):
            await core.enrich_library_agents_from_steps(
                user_id="user-123",
                decomposition_result=decomposition_result,
                existing_agents=existing_agents,
            )

        # Should only make 3 calls (limited to first 3 terms)
        assert call_count == 3

class TestExtractUuidsFromText:
    """Test extract_uuids_from_text function."""

    def test_extracts_single_uuid(self):
        """Test extraction of a single UUID from text."""
        text = "Use my agent 46631191-e8a8-486f-ad90-84f89738321d for this task"
        result = core.extract_uuids_from_text(text)
        assert len(result) == 1
        assert "46631191-e8a8-486f-ad90-84f89738321d" in result

    def test_extracts_multiple_uuids(self):
        """Test extraction of multiple UUIDs from text."""
        text = (
            "Combine agents 11111111-1111-4111-8111-111111111111 "
            "and 22222222-2222-4222-9222-222222222222"
        )
        result = core.extract_uuids_from_text(text)
        assert len(result) == 2
        assert "11111111-1111-4111-8111-111111111111" in result
        assert "22222222-2222-4222-9222-222222222222" in result

    def test_deduplicates_uuids(self):
        """Test that duplicate UUIDs are deduplicated."""
        text = (
            "Use 46631191-e8a8-486f-ad90-84f89738321d twice: "
            "46631191-e8a8-486f-ad90-84f89738321d"
        )
        result = core.extract_uuids_from_text(text)
        assert len(result) == 1

    def test_normalizes_to_lowercase(self):
        """Test that UUIDs are normalized to lowercase."""
        text = "Use 46631191-E8A8-486F-AD90-84F89738321D"
        result = core.extract_uuids_from_text(text)
        assert result[0] == "46631191-e8a8-486f-ad90-84f89738321d"

    def test_returns_empty_for_no_uuids(self):
        """Test that empty list is returned when no UUIDs found."""
        text = "Create an email agent that sends notifications"
        result = core.extract_uuids_from_text(text)
        assert result == []

    def test_ignores_invalid_uuids(self):
        """Test that invalid UUID-like strings are ignored."""
        text = "Not a valid UUID: 12345678-1234-1234-1234-123456789abc"
        result = core.extract_uuids_from_text(text)
        # UUID v4 requires specific patterns (4 in third group, 8/9/a/b in fourth)
        assert len(result) == 0


class TestGetLibraryAgentById:
    """Test get_library_agent_by_id function (and its alias get_library_agent_by_graph_id)."""

    @pytest.mark.asyncio
    async def test_returns_agent_when_found_by_graph_id(self):
        """Test that agent is returned when found by graph_id."""
        mock_agent = MagicMock()
        mock_agent.graph_id = "agent-123"
        mock_agent.graph_version = 1
        mock_agent.name = "Test Agent"
        mock_agent.description = "Test description"
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}

        with patch.object(
            core.library_db,
            "get_library_agent_by_graph_id",
            new_callable=AsyncMock,
            return_value=mock_agent,
        ):
            result = await core.get_library_agent_by_id("user-123", "agent-123")

        assert result is not None
        assert result["graph_id"] == "agent-123"
        assert result["name"] == "Test Agent"

    @pytest.mark.asyncio
    async def test_falls_back_to_library_agent_id(self):
        """Test that lookup falls back to library agent ID when graph_id not found."""
        mock_agent = MagicMock()
        mock_agent.graph_id = "graph-456"  # Different from the lookup ID
        mock_agent.graph_version = 1
        mock_agent.name = "Library Agent"
        mock_agent.description = "Found by library ID"
        mock_agent.input_schema = {"properties": {}}
        mock_agent.output_schema = {"properties": {}}

        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=None,  # Not found by graph_id
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                return_value=mock_agent,  # Found by library ID
            ),
        ):
            result = await core.get_library_agent_by_id("user-123", "library-id-123")

        assert result is not None
        assert result["graph_id"] == "graph-456"
        assert result["name"] == "Library Agent"

    @pytest.mark.asyncio
    async def test_returns_none_when_not_found_by_either_method(self):
        """Test that None is returned when agent not found by either method."""
        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=None,
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                side_effect=core.NotFoundError("Not found"),
            ),
        ):
            result = await core.get_library_agent_by_id("user-123", "nonexistent")

        assert result is None

    @pytest.mark.asyncio
    async def test_returns_none_on_exception(self):
        """Test that None is returned when exception occurs in both lookups."""
        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                side_effect=Exception("Database error"),
            ),
            patch.object(
                core.library_db,
                "get_library_agent",
                new_callable=AsyncMock,
                side_effect=Exception("Database error"),
            ),
        ):
            result = await core.get_library_agent_by_id("user-123", "agent-123")

        assert result is None

    @pytest.mark.asyncio
    async def test_alias_works(self):
        """Test that get_library_agent_by_graph_id is an alias for get_library_agent_by_id."""
        assert core.get_library_agent_by_graph_id is core.get_library_agent_by_id


class TestGetAllRelevantAgentsWithUuids:
    """Test UUID extraction in get_all_relevant_agents_for_generation."""

    @pytest.mark.asyncio
    async def test_fetches_explicitly_mentioned_agents(self):
        """Test that agents mentioned by UUID are fetched directly."""
        mock_agent = MagicMock()
        mock_agent.graph_id = "46631191-e8a8-486f-ad90-84f89738321d"
        mock_agent.graph_version = 1
        mock_agent.name = "Mentioned Agent"
        mock_agent.description = "Explicitly mentioned"
        mock_agent.input_schema = {}
        mock_agent.output_schema = {}

        mock_response = MagicMock()
        mock_response.agents = []

        with (
            patch.object(
                core.library_db,
                "get_library_agent_by_graph_id",
                new_callable=AsyncMock,
                return_value=mock_agent,
            ),
            patch.object(
                core.library_db,
                "list_library_agents",
                new_callable=AsyncMock,
                return_value=mock_response,
            ),
        ):
            result = await core.get_all_relevant_agents_for_generation(
                user_id="user-123",
                search_query="Use agent 46631191-e8a8-486f-ad90-84f89738321d",
                include_marketplace=False,
            )

        assert len(result) == 1
        assert result[0].get("graph_id") == "46631191-e8a8-486f-ad90-84f89738321d"
        assert result[0].get("name") == "Mentioned Agent"


if __name__ == "__main__":
    pytest.main([__file__, "-v"])
@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:

    @pytest.mark.asyncio
    async def test_decompose_goal_with_context(self):
        """Test decomposition with additional context enriched into description."""
        """Test decomposition with additional context."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
@@ -119,12 +119,9 @@ class TestDecomposeGoalExternal:
            "Build a chatbot", context="Use Python"
        )

        expected_description = (
            "Build a chatbot\n\nAdditional context from user:\nUse Python"
        )
        mock_client.post.assert_called_once_with(
            "/api/decompose-description",
            json={"description": expected_description},
            json={"description": "Build a chatbot", "user_instruction": "Use Python"},
        )

    @pytest.mark.asyncio
@@ -436,139 +433,5 @@ class TestGetBlocksExternal:
        assert result is None


class TestLibraryAgentsPassthrough:
    """Test that library_agents are passed correctly in all requests."""

    def setup_method(self):
        """Reset client singleton before each test."""
        service._settings = None
        service._client = None

    @pytest.mark.asyncio
    async def test_decompose_goal_passes_library_agents(self):
        """Test that library_agents are included in decompose goal payload."""
        library_agents = [
            {
                "graph_id": "agent-123",
                "graph_version": 1,
                "name": "Email Sender",
                "description": "Sends emails",
                "input_schema": {"properties": {"to": {"type": "string"}}},
                "output_schema": {"properties": {"sent": {"type": "boolean"}}},
            },
        ]

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "type": "instructions",
            "steps": ["Step 1"],
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            await service.decompose_goal_external(
                "Send an email",
                library_agents=library_agents,
            )

        # Verify library_agents was passed in the payload
        call_args = mock_client.post.call_args
        assert call_args[1]["json"]["library_agents"] == library_agents

    @pytest.mark.asyncio
    async def test_generate_agent_passes_library_agents(self):
        """Test that library_agents are included in generate agent payload."""
        library_agents = [
            {
                "graph_id": "agent-456",
                "graph_version": 2,
                "name": "Data Fetcher",
                "description": "Fetches data from API",
                "input_schema": {"properties": {"url": {"type": "string"}}},
                "output_schema": {"properties": {"data": {"type": "object"}}},
            },
        ]

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "agent_json": {"name": "Test Agent", "nodes": []},
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            await service.generate_agent_external(
                {"steps": ["Step 1"]},
                library_agents=library_agents,
            )

        # Verify library_agents was passed in the payload
        call_args = mock_client.post.call_args
        assert call_args[1]["json"]["library_agents"] == library_agents

    @pytest.mark.asyncio
    async def test_generate_agent_patch_passes_library_agents(self):
        """Test that library_agents are included in patch generation payload."""
        library_agents = [
            {
                "graph_id": "agent-789",
                "graph_version": 1,
                "name": "Slack Notifier",
                "description": "Sends Slack messages",
                "input_schema": {"properties": {"message": {"type": "string"}}},
                "output_schema": {"properties": {"success": {"type": "boolean"}}},
            },
        ]

        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "agent_json": {"name": "Updated Agent", "nodes": []},
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            await service.generate_agent_patch_external(
                "Add error handling",
                {"name": "Original Agent", "nodes": []},
                library_agents=library_agents,
            )

        # Verify library_agents was passed in the payload
        call_args = mock_client.post.call_args
        assert call_args[1]["json"]["library_agents"] == library_agents

    @pytest.mark.asyncio
    async def test_decompose_goal_without_library_agents(self):
        """Test that decompose goal works without library_agents."""
        mock_response = MagicMock()
        mock_response.json.return_value = {
            "success": True,
            "type": "instructions",
            "steps": ["Step 1"],
        }
        mock_response.raise_for_status = MagicMock()

        mock_client = AsyncMock()
        mock_client.post.return_value = mock_response

        with patch.object(service, "_get_client", return_value=mock_client):
            await service.decompose_goal_external("Build a workflow")

        # Verify library_agents was NOT passed when not provided
        call_args = mock_client.post.call_args
        assert "library_agents" not in call_args[1]["json"]


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

@@ -43,24 +43,19 @@ faker = Faker()
# Constants for data generation limits (reduced for E2E tests)
NUM_USERS = 15
NUM_AGENT_BLOCKS = 30
MIN_GRAPHS_PER_USER = 25
MAX_GRAPHS_PER_USER = 25
MIN_GRAPHS_PER_USER = 15
MAX_GRAPHS_PER_USER = 15
MIN_NODES_PER_GRAPH = 3
MAX_NODES_PER_GRAPH = 6
MIN_PRESETS_PER_USER = 2
MAX_PRESETS_PER_USER = 3
MIN_AGENTS_PER_USER = 25
MAX_AGENTS_PER_USER = 25
MIN_AGENTS_PER_USER = 15
MAX_AGENTS_PER_USER = 15
MIN_EXECUTIONS_PER_GRAPH = 2
MAX_EXECUTIONS_PER_GRAPH = 8
MIN_REVIEWS_PER_VERSION = 2
MAX_REVIEWS_PER_VERSION = 5

# Guaranteed minimums for marketplace tests (deterministic)
GUARANTEED_FEATURED_AGENTS = 8
GUARANTEED_FEATURED_CREATORS = 5
GUARANTEED_TOP_AGENTS = 10


def get_image():
    """Generate a consistent image URL using picsum.photos service."""
@@ -390,7 +385,7 @@ class TestDataCreator:

        library_agents = []
        for user in self.users:
            num_agents = random.randint(MIN_AGENTS_PER_USER, MAX_AGENTS_PER_USER)
            num_agents = 10  # Create exactly 10 agents per user

            # Get available graphs for this user
            user_graphs = [
@@ -512,17 +507,14 @@ class TestDataCreator:
            existing_profiles, min(num_creators, len(existing_profiles))
        )

        # Guarantee at least GUARANTEED_FEATURED_CREATORS featured creators
        num_featured = max(GUARANTEED_FEATURED_CREATORS, int(num_creators * 0.5))
        # Mark about 50% of creators as featured (more for testing)
        num_featured = max(2, int(num_creators * 0.5))
        num_featured = min(
            num_featured, len(selected_profiles)
        )  # Don't exceed available profiles
        featured_profile_ids = set(
            random.sample([p.id for p in selected_profiles], num_featured)
        )
        print(
            f"🎯 Creating {num_featured} featured creators (min: {GUARANTEED_FEATURED_CREATORS})"
        )

        for profile in selected_profiles:
            try:
@@ -553,25 +545,21 @@ class TestDataCreator:
        return profiles

    async def create_test_store_submissions(self) -> List[Dict[str, Any]]:
        """Create test store submissions using the API function.

        DETERMINISTIC: Guarantees minimum featured agents for E2E tests.
        """
        """Create test store submissions using the API function."""
        print("Creating test store submissions...")

        submissions = []
        approved_submissions = []
        featured_count = 0
        submission_counter = 0

        # Create a special test submission for test123@gmail.com (ALWAYS approved + featured)
        # Create a special test submission for test123@gmail.com
        test_user = next(
            (user for user in self.users if user["email"] == "test123@gmail.com"), None
        )
        if test_user and self.agent_graphs:
        if test_user:
            # Special test data for consistent testing
            test_submission_data = {
                "user_id": test_user["id"],
                "agent_id": self.agent_graphs[0]["id"],
                "agent_id": self.agent_graphs[0]["id"],  # Use first available graph
                "agent_version": 1,
                "slug": "test-agent-submission",
                "name": "Test Agent Submission",
@@ -592,24 +580,37 @@ class TestDataCreator:
            submissions.append(test_submission.model_dump())
            print("✅ Created special test store submission for test123@gmail.com")

            # ALWAYS approve and feature the test submission
            # Randomly approve, reject, or leave pending the test submission
            if test_submission.store_listing_version_id:
                approved_submission = await review_store_submission(
                    store_listing_version_id=test_submission.store_listing_version_id,
                    is_approved=True,
                    external_comments="Test submission approved",
                    internal_comments="Auto-approved test submission",
                    reviewer_id=test_user["id"],
                )
                approved_submissions.append(approved_submission.model_dump())
                print("✅ Approved test store submission")
                random_value = random.random()
                if random_value < 0.4:  # 40% chance to approve
                    approved_submission = await review_store_submission(
                        store_listing_version_id=test_submission.store_listing_version_id,
                        is_approved=True,
                        external_comments="Test submission approved",
                        internal_comments="Auto-approved test submission",
                        reviewer_id=test_user["id"],
                    )
                    approved_submissions.append(approved_submission.model_dump())
                    print("✅ Approved test store submission")

                await prisma.storelistingversion.update(
                    where={"id": test_submission.store_listing_version_id},
                    data={"isFeatured": True},
                )
                featured_count += 1
                print("🌟 Marked test agent as FEATURED")
                    # Mark approved submission as featured
                    await prisma.storelistingversion.update(
                        where={"id": test_submission.store_listing_version_id},
                        data={"isFeatured": True},
                    )
                    print("🌟 Marked test agent as FEATURED")
                elif random_value < 0.7:  # 30% chance to reject (40% to 70%)
                    await review_store_submission(
                        store_listing_version_id=test_submission.store_listing_version_id,
                        is_approved=False,
                        external_comments="Test submission rejected - needs improvements",
                        internal_comments="Auto-rejected test submission for E2E testing",
                        reviewer_id=test_user["id"],
                    )
                    print("❌ Rejected test store submission")
                else:  # 30% chance to leave pending (70% to 100%)
                    print("⏳ Left test submission pending for review")

        except Exception as e:
            print(f"Error creating test store submission: {e}")
@@ -619,6 +620,7 @@ class TestDataCreator:

# Create regular submissions for all users
for user in self.users:
# Get available graphs for this specific user
user_graphs = [
g for g in self.agent_graphs if g.get("userId") == user["id"]
]
@@ -629,17 +631,18 @@ class TestDataCreator:
)
continue

# Create exactly 4 store submissions per user
for submission_index in range(4):
graph = random.choice(user_graphs)
submission_counter += 1

try:
print(
f"Creating store submission for user {user['id']} with graph {graph['id']}"
f"Creating store submission for user {user['id']} with graph {graph['id']} (owner: {graph.get('userId')})"
)

# Use the API function to create store submission with correct parameters
submission = await create_store_submission(
user_id=user["id"],
user_id=user["id"],  # Must match graph's userId
agent_id=graph["id"],
agent_version=graph.get("version", 1),
slug=faker.slug(),
@@ -648,24 +651,22 @@ class TestDataCreator:
video_url=get_video_url() if random.random() < 0.3 else None,
image_urls=[get_image() for _ in range(3)],
description=faker.text(),
categories=[get_category()],
categories=[
get_category()
],  # Single category from predefined list
changes_summary="Initial E2E test submission",
)
submissions.append(submission.model_dump())
print(f"✅ Created store submission: {submission.name}")

# Randomly approve, reject, or leave pending the submission
if submission.store_listing_version_id:
# DETERMINISTIC: First N submissions are always approved
# First GUARANTEED_FEATURED_AGENTS of those are always featured
should_approve = (
submission_counter <= GUARANTEED_TOP_AGENTS
or random.random() < 0.4
)
should_feature = featured_count < GUARANTEED_FEATURED_AGENTS

if should_approve:
random_value = random.random()
if random_value < 0.4: # 40% chance to approve
try:
# Pick a random user as the reviewer (admin)
reviewer_id = random.choice(self.users)["id"]

approved_submission = await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
is_approved=True,
@@ -680,7 +681,16 @@ class TestDataCreator:
f"✅ Approved store submission: {submission.name}"
)

if should_feature:
# Mark some agents as featured during creation (30% chance)
# More likely for creators and first submissions
is_creator = user["id"] in [
p.get("userId") for p in self.profiles
]
feature_chance = (
0.5 if is_creator else 0.2
) # 50% for creators, 20% for others

if random.random() < feature_chance:
try:
await prisma.storelistingversion.update(
where={
@@ -688,25 +698,8 @@ class TestDataCreator:
},
data={"isFeatured": True},
)
featured_count += 1
print(
f"🌟 Marked agent as FEATURED ({featured_count}/{GUARANTEED_FEATURED_AGENTS}): {submission.name}"
)
except Exception as e:
print(
f"Warning: Could not mark submission as featured: {e}"
)
elif random.random() < 0.2:
try:
await prisma.storelistingversion.update(
where={
"id": submission.store_listing_version_id
},
data={"isFeatured": True},
)
featured_count += 1
print(
f"🌟 Marked agent as FEATURED (bonus): {submission.name}"
f"🌟 Marked agent as FEATURED: {submission.name}"
)
except Exception as e:
print(
@@ -717,9 +710,11 @@ class TestDataCreator:
print(
f"Warning: Could not approve submission {submission.name}: {e}"
)
elif random.random() < 0.5:
elif random_value < 0.7: # 30% chance to reject (40% to 70%)
try:
# Pick a random user as the reviewer (admin)
reviewer_id = random.choice(self.users)["id"]

await review_store_submission(
store_listing_version_id=submission.store_listing_version_id,
is_approved=False,
@@ -734,7 +729,7 @@ class TestDataCreator:
print(
f"Warning: Could not reject submission {submission.name}: {e}"
)
else:
else: # 30% chance to leave pending (70% to 100%)
print(
f"⏳ Left submission pending for review: {submission.name}"
)
@@ -748,13 +743,9 @@ class TestDataCreator:
traceback.print_exc()
continue

print("\n📊 Store Submissions Summary:")
print(f" Created: {len(submissions)}")
print(f" Approved: {len(approved_submissions)}")
print(
f" Featured: {featured_count} (guaranteed min: {GUARANTEED_FEATURED_AGENTS})"
f"Created {len(submissions)} store submissions, approved {len(approved_submissions)}"
)

self.store_submissions = submissions
return submissions

@@ -834,15 +825,12 @@ class TestDataCreator:
print(f"✅ Agent blocks available: {len(self.agent_blocks)}")
print(f"✅ Agent graphs created: {len(self.agent_graphs)}")
print(f"✅ Library agents created: {len(self.library_agents)}")
print(f"✅ Creator profiles updated: {len(self.profiles)}")
print(f"✅ Store submissions created: {len(self.store_submissions)}")
print(f"✅ Creator profiles updated: {len(self.profiles)} (some featured)")
print(
f"✅ Store submissions created: {len(self.store_submissions)} (some marked as featured during creation)"
)
print(f"✅ API keys created: {len(self.api_keys)}")
print(f"✅ Presets created: {len(self.presets)}")
print("\n🎯 Deterministic Guarantees:")
print(f" • Featured agents: >= {GUARANTEED_FEATURED_AGENTS}")
print(f" • Featured creators: >= {GUARANTEED_FEATURED_CREATORS}")
print(f" • Top agents (approved): >= {GUARANTEED_TOP_AGENTS}")
print(f" • Library agents per user: >= {MIN_AGENTS_PER_USER}")
print("\n🚀 Your E2E test database is ready to use!")
@@ -1,9 +1,10 @@
"use client";
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { getOnboardingStatus, resolveResponse } from "@/app/api/helpers";
import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { resolveResponse, getOnboardingStatus } from "@/app/api/helpers";
import { getV1OnboardingState } from "@/app/api/__generated__/endpoints/onboarding/onboarding";
import { getHomepageRoute } from "@/lib/constants";

export default function OnboardingPage() {
const router = useRouter();
@@ -12,10 +13,12 @@ export default function OnboardingPage() {
async function redirectToStep() {
try {
// Check if onboarding is enabled (also gets chat flag for redirect)
const { shouldShowOnboarding } = await getOnboardingStatus();
const { shouldShowOnboarding, isChatEnabled } =
await getOnboardingStatus();
const homepageRoute = getHomepageRoute(isChatEnabled);

if (!shouldShowOnboarding) {
router.replace("/");
router.replace(homepageRoute);
return;
}

@@ -23,7 +26,7 @@ export default function OnboardingPage() {

// Handle completed onboarding
if (onboarding.completedSteps.includes("GET_RESULTS")) {
router.replace("/");
router.replace(homepageRoute);
return;
}
@@ -1,8 +1,9 @@
import { getOnboardingStatus } from "@/app/api/helpers";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { revalidatePath } from "next/cache";
import { getHomepageRoute } from "@/lib/constants";
import BackendAPI from "@/lib/autogpt-server-api";
import { NextResponse } from "next/server";
import { revalidatePath } from "next/cache";
import { getOnboardingStatus } from "@/app/api/helpers";

// Handle the callback to complete the user session login
export async function GET(request: Request) {
@@ -26,12 +27,13 @@ export async function GET(request: Request) {
await api.createUser();

// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding } = await getOnboardingStatus();
const { shouldShowOnboarding, isChatEnabled } =
await getOnboardingStatus();
if (shouldShowOnboarding) {
next = "/onboarding";
revalidatePath("/onboarding", "layout");
} else {
next = "/";
next = getHomepageRoute(isChatEnabled);
revalidatePath(next, "layout");
}
} catch (createUserError) {
@@ -857,7 +857,7 @@ export const CustomNode = React.memo(
})();

const hasAdvancedFields =
data.inputSchema?.properties &&
data.inputSchema &&
Object.entries(data.inputSchema.properties).some(([key, value]) => {
return (
value.advanced === true && !data.inputSchema.required?.includes(key)
@@ -1,13 +1,6 @@
"use client";
import { FeatureFlagPage } from "@/services/feature-flags/FeatureFlagPage";
import { Flag } from "@/services/feature-flags/use-get-flag";
import { type ReactNode } from "react";
import type { ReactNode } from "react";
import { CopilotShell } from "./components/CopilotShell/CopilotShell";

export default function CopilotLayout({ children }: { children: ReactNode }) {
return (
<FeatureFlagPage flag={Flag.CHAT} whenDisabled="/library">
<CopilotShell>{children}</CopilotShell>
</FeatureFlagPage>
);
return <CopilotShell>{children}</CopilotShell>;
}
@@ -14,8 +14,14 @@ export default function CopilotPage() {
const isInterruptModalOpen = useCopilotStore((s) => s.isInterruptModalOpen);
const confirmInterrupt = useCopilotStore((s) => s.confirmInterrupt);
const cancelInterrupt = useCopilotStore((s) => s.cancelInterrupt);
const { greetingName, quickActions, isLoading, hasSession, initialPrompt } =
state;
const {
greetingName,
quickActions,
isLoading,
hasSession,
initialPrompt,
isReady,
} = state;
const {
handleQuickAction,
startChatWithPrompt,
@@ -23,6 +29,8 @@ export default function CopilotPage() {
handleStreamingChange,
} = handlers;

if (!isReady) return null;

if (hasSession) {
return (
<div className="flex h-full flex-col">
@@ -3,11 +3,18 @@ import {
postV2CreateSession,
} from "@/app/api/__generated__/endpoints/chat/chat";
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useOnboarding } from "@/providers/onboarding/onboarding-provider";
import {
Flag,
type FlagValues,
useGetFlag,
} from "@/services/feature-flags/use-get-flag";
import { SessionKey, sessionStorage } from "@/services/storage/session-storage";
import * as Sentry from "@sentry/nextjs";
import { useQueryClient } from "@tanstack/react-query";
import { useFlags } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { useCopilotStore } from "./copilot-page-store";
@@ -26,6 +33,22 @@ export function useCopilotPage() {
const isCreating = useCopilotStore((s) => s.isCreatingSession);
const setIsCreating = useCopilotStore((s) => s.setIsCreatingSession);

// Complete VISIT_COPILOT onboarding step to grant $5 welcome bonus
useEffect(() => {
if (isLoggedIn) {
completeStep("VISIT_COPILOT");
}
}, [completeStep, isLoggedIn]);

const isChatEnabled = useGetFlag(Flag.CHAT);
const flags = useFlags<FlagValues>();
const homepageRoute = getHomepageRoute(isChatEnabled);
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
const isFlagReady =
!isLaunchDarklyConfigured || flags[Flag.CHAT] !== undefined;

const greetingName = getGreetingName(user);
const quickActions = getQuickActions();

@@ -35,8 +58,11 @@ export function useCopilotPage() {
: undefined;

useEffect(() => {
if (isLoggedIn) completeStep("VISIT_COPILOT");
}, [completeStep, isLoggedIn]);
if (!isFlagReady) return;
if (isChatEnabled === false) {
router.replace(homepageRoute);
}
}, [homepageRoute, isChatEnabled, isFlagReady, router]);

async function startChatWithPrompt(prompt: string) {
if (!prompt?.trim()) return;
@@ -90,6 +116,7 @@ export function useCopilotPage() {
isLoading: isUserLoading,
hasSession,
initialPrompt,
isReady: isFlagReady && isChatEnabled !== false && isLoggedIn,
},
handlers: {
handleQuickAction,
@@ -1,6 +1,8 @@
"use client";

import { ErrorCard } from "@/components/molecules/ErrorCard/ErrorCard";
import { getHomepageRoute } from "@/lib/constants";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useSearchParams } from "next/navigation";
import { Suspense } from "react";
import { getErrorDetails } from "./helpers";
@@ -9,6 +11,8 @@ function ErrorPageContent() {
const searchParams = useSearchParams();
const errorMessage = searchParams.get("message");
const errorDetails = getErrorDetails(errorMessage);
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);

function handleRetry() {
// Auth-related errors should redirect to login
@@ -26,7 +30,7 @@ function ErrorPageContent() {
}, 2000);
} else {
// For server/network errors, go to home
window.location.href = "/";
window.location.href = homepageRoute;
}
}
@@ -1,5 +1,6 @@
"use server";

import { getHomepageRoute } from "@/lib/constants";
import BackendAPI from "@/lib/autogpt-server-api";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { loginFormSchema } from "@/types/auth";
@@ -37,8 +38,10 @@ export async function login(email: string, password: string) {
await api.createUser();

// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding } = await getOnboardingStatus();
const next = shouldShowOnboarding ? "/onboarding" : "/";
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
const next = shouldShowOnboarding
? "/onboarding"
: getHomepageRoute(isChatEnabled);

return {
success: true,
@@ -1,6 +1,8 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { loginFormSchema, LoginProvider } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useLoginPage() {
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = environment.isCloud();
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);

// Get redirect destination from 'next' query parameter
const nextUrl = searchParams.get("next");

useEffect(() => {
if (isLoggedIn && !isLoggingIn) {
router.push(nextUrl || "/");
router.push(nextUrl || homepageRoute);
}
}, [isLoggedIn, isLoggingIn, nextUrl, router]);
}, [homepageRoute, isLoggedIn, isLoggingIn, nextUrl, router]);

const form = useForm<z.infer<typeof loginFormSchema>>({
resolver: zodResolver(loginFormSchema),
@@ -94,7 +98,7 @@ export function useLoginPage() {
}

// Prefer URL's next parameter, then use backend-determined route
router.replace(nextUrl || result.next || "/");
router.replace(nextUrl || result.next || homepageRoute);
} catch (error) {
toast({
title:
@@ -1,5 +1,6 @@
"use server";

import { getHomepageRoute } from "@/lib/constants";
import { getServerSupabase } from "@/lib/supabase/server/getServerSupabase";
import { signupFormSchema } from "@/types/auth";
import * as Sentry from "@sentry/nextjs";
@@ -58,8 +59,10 @@ export async function signup(
}

// Get onboarding status from backend (includes chat flag evaluated for this user)
const { shouldShowOnboarding } = await getOnboardingStatus();
const next = shouldShowOnboarding ? "/onboarding" : "/";
const { shouldShowOnboarding, isChatEnabled } = await getOnboardingStatus();
const next = shouldShowOnboarding
? "/onboarding"
: getHomepageRoute(isChatEnabled);

return { success: true, next };
} catch (err) {
@@ -1,6 +1,8 @@
import { useToast } from "@/components/molecules/Toast/use-toast";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { environment } from "@/services/environment";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { LoginProvider, signupFormSchema } from "@/types/auth";
import { zodResolver } from "@hookform/resolvers/zod";
import { useRouter, useSearchParams } from "next/navigation";
@@ -20,15 +22,17 @@ export function useSignupPage() {
const [isGoogleLoading, setIsGoogleLoading] = useState(false);
const [showNotAllowedModal, setShowNotAllowedModal] = useState(false);
const isCloudEnv = environment.isCloud();
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);

// Get redirect destination from 'next' query parameter
const nextUrl = searchParams.get("next");

useEffect(() => {
if (isLoggedIn && !isSigningUp) {
router.push(nextUrl || "/");
router.push(nextUrl || homepageRoute);
}
}, [isLoggedIn, isSigningUp, nextUrl, router]);
}, [homepageRoute, isLoggedIn, isSigningUp, nextUrl, router]);

const form = useForm<z.infer<typeof signupFormSchema>>({
resolver: zodResolver(signupFormSchema),
@@ -129,7 +133,7 @@ export function useSignupPage() {
}

// Prefer the URL's next parameter, then result.next (for onboarding), then default
const redirectTo = nextUrl || result.next || "/";
const redirectTo = nextUrl || result.next || homepageRoute;
router.replace(redirectTo);
} catch (error) {
setIsLoading(false);
@@ -181,5 +181,6 @@ export async function getOnboardingStatus() {
const isCompleted = onboarding.completedSteps.includes("CONGRATS");
return {
shouldShowOnboarding: status.is_onboarding_enabled && !isCompleted,
isChatEnabled: status.is_chat_enabled,
};
}
@@ -7981,25 +7981,6 @@
]
},
"new_output": { "type": "boolean", "title": "New Output" },
"execution_count": {
"type": "integer",
"title": "Execution Count",
"default": 0
},
"success_rate": {
"anyOf": [{ "type": "number" }, { "type": "null" }],
"title": "Success Rate"
},
"avg_correctness_score": {
"anyOf": [{ "type": "number" }, { "type": "null" }],
"title": "Avg Correctness Score"
},
"recent_executions": {
"items": { "$ref": "#/components/schemas/RecentExecution" },
"type": "array",
"title": "Recent Executions",
"description": "List of recent executions with status, score, and summary"
},
"can_access_graph": {
"type": "boolean",
"title": "Can Access Graph"
@@ -9393,23 +9374,6 @@
"required": ["providers", "pagination"],
"title": "ProviderResponse"
},
"RecentExecution": {
"properties": {
"status": { "type": "string", "title": "Status" },
"correctness_score": {
"anyOf": [{ "type": "number" }, { "type": "null" }],
"title": "Correctness Score"
},
"activity_summary": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Activity Summary"
}
},
"type": "object",
"required": ["status"],
"title": "RecentExecution",
"description": "Summary of a recent execution for quality assessment.\n\nUsed by the LLM to understand the agent's recent performance with specific examples\nrather than just aggregate statistics."
},
"RefundRequest": {
"properties": {
"id": { "type": "string", "title": "Id" },
@@ -9833,8 +9797,7 @@
"sub_heading": { "type": "string", "title": "Sub Heading" },
"description": { "type": "string", "title": "Description" },
"runs": { "type": "integer", "title": "Runs" },
"rating": { "type": "number", "title": "Rating" },
"agent_graph_id": { "type": "string", "title": "Agent Graph Id" }
"rating": { "type": "number", "title": "Rating" }
},
"type": "object",
"required": [
@@ -9846,8 +9809,7 @@
"sub_heading",
"description",
"runs",
"rating",
"agent_graph_id"
"rating"
],
"title": "StoreAgent"
},
@@ -1,15 +1,27 @@
"use client";

import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { getHomepageRoute } from "@/lib/constants";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useRouter } from "next/navigation";
import { useEffect } from "react";

export default function Page() {
const isChatEnabled = useGetFlag(Flag.CHAT);
const router = useRouter();
const homepageRoute = getHomepageRoute(isChatEnabled);
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);
const isFlagReady =
!isLaunchDarklyConfigured || typeof isChatEnabled === "boolean";

useEffect(() => {
router.replace("/copilot");
}, [router]);
useEffect(
function redirectToHomepage() {
if (!isFlagReady) return;
router.replace(homepageRoute);
},
[homepageRoute, isFlagReady, router],
);

return <LoadingSpinner size="large" cover />;
return null;
}
@@ -57,7 +57,6 @@ export function ChatInput({
isStreaming,
value,
baseHandleKeyDown,
inputId,
});

return (
@@ -15,7 +15,6 @@ interface Args {
isStreaming?: boolean;
value: string;
baseHandleKeyDown: (event: KeyboardEvent<HTMLTextAreaElement>) => void;
inputId?: string;
}

export function useVoiceRecording({
@@ -24,7 +23,6 @@ export function useVoiceRecording({
isStreaming = false,
value,
baseHandleKeyDown,
inputId,
}: Args) {
const [isRecording, setIsRecording] = useState(false);
const [isTranscribing, setIsTranscribing] = useState(false);
@@ -105,7 +103,7 @@ export function useVoiceRecording({
setIsTranscribing(false);
}
},
[handleTranscription, inputId],
[handleTranscription],
);

const stopRecording = useCallback(() => {
@@ -203,15 +201,6 @@ export function useVoiceRecording({
}
}, [error, toast]);

useEffect(() => {
if (!isTranscribing && inputId) {
const inputElement = document.getElementById(inputId);
if (inputElement) {
inputElement.focus();
}
}
}, [isTranscribing, inputId]);

const handleKeyDown = useCallback(
(event: KeyboardEvent<HTMLTextAreaElement>) => {
if (event.key === " " && !value.trim() && !isTranscribing) {
@@ -156,19 +156,11 @@ export function ChatMessage({
}

if (isClarificationNeeded && message.type === "clarification_needed") {
const hasUserReplyAfter =
index >= 0 &&
messages
.slice(index + 1)
.some((m) => m.type === "message" && m.role === "user");

return (
<ClarificationQuestionsWidget
questions={message.questions}
message={message.message}
sessionId={message.sessionId}
onSubmitAnswers={handleClarificationAnswers}
isAnswered={hasUserReplyAfter}
className={className}
/>
);
@@ -6,7 +6,7 @@ import { Input } from "@/components/atoms/Input/Input";
import { Text } from "@/components/atoms/Text/Text";
import { cn } from "@/lib/utils";
import { CheckCircleIcon, QuestionIcon } from "@phosphor-icons/react";
import { useState, useEffect, useRef } from "react";
import { useState } from "react";

export interface ClarifyingQuestion {
question: string;
@@ -17,96 +17,39 @@ export interface ClarifyingQuestion {
interface Props {
questions: ClarifyingQuestion[];
message: string;
sessionId?: string;
onSubmitAnswers: (answers: Record<string, string>) => void;
onCancel?: () => void;
isAnswered?: boolean;
className?: string;
}

function getStorageKey(sessionId?: string): string | null {
if (!sessionId) return null;
return `clarification_answers_${sessionId}`;
}

export function ClarificationQuestionsWidget({
questions,
message,
sessionId,
onSubmitAnswers,
onCancel,
isAnswered = false,
className,
}: Props) {
const [answers, setAnswers] = useState<Record<string, string>>({});
const [isSubmitted, setIsSubmitted] = useState(false);
const lastSessionIdRef = useRef<string | undefined>(undefined);

useEffect(() => {
const storageKey = getStorageKey(sessionId);
if (!storageKey) {
setAnswers({});
setIsSubmitted(false);
lastSessionIdRef.current = sessionId;
return;
}

try {
const saved = localStorage.getItem(storageKey);
if (saved) {
const parsed = JSON.parse(saved) as Record<string, string>;
setAnswers(parsed);
} else {
setAnswers({});
}
setIsSubmitted(false);
} catch {
setAnswers({});
setIsSubmitted(false);
}
lastSessionIdRef.current = sessionId;
}, [sessionId]);

useEffect(() => {
if (lastSessionIdRef.current !== sessionId) {
return;
}
const storageKey = getStorageKey(sessionId);
if (!storageKey) return;

const hasAnswers = Object.values(answers).some((v) => v.trim());
try {
if (hasAnswers) {
localStorage.setItem(storageKey, JSON.stringify(answers));
} else {
localStorage.removeItem(storageKey);
}
} catch {}
}, [answers, sessionId]);

function handleAnswerChange(keyword: string, value: string) {
setAnswers((prev) => ({ ...prev, [keyword]: value }));
}

function handleSubmit() {
// Check if all questions are answered
const allAnswered = questions.every((q) => answers[q.keyword]?.trim());
if (!allAnswered) {
return;
}
setIsSubmitted(true);
onSubmitAnswers(answers);

const storageKey = getStorageKey(sessionId);
try {
if (storageKey) {
localStorage.removeItem(storageKey);
}
} catch {}
}

const allAnswered = questions.every((q) => answers[q.keyword]?.trim());

if (isAnswered || isSubmitted) {
// Show submitted state after answers are submitted
if (isSubmitted) {
return (
<div
className={cn(
@@ -30,9 +30,9 @@ export function getErrorMessage(result: unknown): string {
}
if (typeof result === "object" && result !== null) {
const response = result as Record<string, unknown>;
if (response.error) return stripInternalReasoning(String(response.error));
if (response.message)
return stripInternalReasoning(String(response.message));
if (response.error) return stripInternalReasoning(String(response.error));
}
return "An error occurred";
}
@@ -363,8 +363,8 @@ export function formatToolResponse(result: unknown, toolName: string): string {

case "error":
const errorMsg =
(response.message as string) || response.error || "An error occurred";
return stripInternalReasoning(String(errorMsg));
(response.error as string) || response.message || "An error occurred";
return `Error: ${errorMsg}`;

case "no_results":
const suggestions = (response.suggestions as string[]) || [];
@@ -26,6 +26,7 @@ export const providerIcons: Partial<
nvidia: fallbackIcon,
discord: FaDiscord,
d_id: fallbackIcon,
elevenlabs: fallbackIcon,
google_maps: FaGoogle,
jina: fallbackIcon,
ideogram: fallbackIcon,
@@ -1,6 +1,7 @@
"use client";

import { IconLaptop } from "@/components/__legacy__/ui/icons";
import { getHomepageRoute } from "@/lib/constants";
import { cn } from "@/lib/utils";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { ListChecksIcon } from "@phosphor-icons/react/dist/ssr";
@@ -23,11 +24,11 @@ interface Props {
export function NavbarLink({ name, href }: Props) {
const pathname = usePathname();
const isChatEnabled = useGetFlag(Flag.CHAT);
const expectedHomeRoute = isChatEnabled ? "/copilot" : "/library";
const homepageRoute = getHomepageRoute(isChatEnabled);

const isActive =
href === expectedHomeRoute
? pathname === "/" || pathname.startsWith(expectedHomeRoute)
href === homepageRoute
? pathname === "/" || pathname.startsWith(homepageRoute)
: pathname.includes(href);

return (
@@ -66,7 +66,7 @@ export default function useAgentGraph(
>(null);
const [xyNodes, setXYNodes] = useState<CustomNode[]>([]);
const [xyEdges, setXYEdges] = useState<CustomEdge[]>([]);
const betaBlocks = useGetFlag(Flag.BETA_BLOCKS) as string[];
const betaBlocks = useGetFlag(Flag.BETA_BLOCKS);

// Filter blocks based on beta flags
const availableBlocks = useMemo(() => {
@@ -11,3 +11,10 @@ export const API_KEY_HEADER_NAME = "X-API-Key";

// Layout
export const NAVBAR_HEIGHT_PX = 60;

// Routes
export function getHomepageRoute(isChatEnabled?: boolean | null): string {
if (isChatEnabled === true) return "/copilot";
if (isChatEnabled === false) return "/library";
return "/";
}
@@ -1,3 +1,4 @@
import { getHomepageRoute } from "@/lib/constants";
import { environment } from "@/services/environment";
import { Key, storage } from "@/services/storage/local-storage";
import { type CookieOptions } from "@supabase/ssr";
@@ -70,7 +71,7 @@ export function getRedirectPath(
}

if (isAdminPage(path) && userRole !== "admin") {
return "/";
return getHomepageRoute();
}

return null;
@@ -1,3 +1,4 @@
import { getHomepageRoute } from "@/lib/constants";
import { environment } from "@/services/environment";
import { createServerClient } from "@supabase/ssr";
import { NextResponse, type NextRequest } from "next/server";
@@ -66,7 +67,7 @@ export async function updateSession(request: NextRequest) {

// 2. Check if user is authenticated but lacks admin role when accessing admin pages
if (user && userRole !== "admin" && isAdminPage(pathname)) {
url.pathname = "/";
url.pathname = getHomepageRoute();
return NextResponse.redirect(url);
}
@@ -23,7 +23,9 @@ import {
WebSocketNotification,
} from "@/lib/autogpt-server-api";
import { useBackendAPI } from "@/lib/autogpt-server-api/context";
import { getHomepageRoute } from "@/lib/constants";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import Link from "next/link";
import { usePathname, useRouter } from "next/navigation";
import {
@@ -102,6 +104,8 @@ export default function OnboardingProvider({
const pathname = usePathname();
const router = useRouter();
const { isLoggedIn } = useSupabase();
const isChatEnabled = useGetFlag(Flag.CHAT);
const homepageRoute = getHomepageRoute(isChatEnabled);

useOnboardingTimezoneDetection();

@@ -146,7 +150,7 @@ export default function OnboardingProvider({
if (isOnOnboardingRoute) {
const enabled = await resolveResponse(getV1IsOnboardingEnabled());
if (!enabled) {
router.push("/");
router.push(homepageRoute);
return;
}
}
@@ -158,7 +162,7 @@ export default function OnboardingProvider({
isOnOnboardingRoute &&
shouldRedirectFromOnboarding(onboarding.completedSteps, pathname)
) {
router.push("/");
router.push(homepageRoute);
}
} catch (error) {
console.error("Failed to initialize onboarding:", error);
@@ -173,7 +177,7 @@ export default function OnboardingProvider({
}

initializeOnboarding();
}, [api, isOnOnboardingRoute, router, isLoggedIn, pathname]);
}, [api, homepageRoute, isOnOnboardingRoute, router, isLoggedIn, pathname]);

const handleOnboardingNotification = useCallback(
(notification: WebSocketNotification) => {
@@ -83,10 +83,6 @@ function getPostHogCredentials() {
};
}

function getLaunchDarklyClientId() {
return process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
}

function isProductionBuild() {
return process.env.NODE_ENV === "production";
}
@@ -124,10 +120,7 @@ function isVercelPreview() {
}

function areFeatureFlagsEnabled() {
return (
process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true" &&
Boolean(process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID)
);
return process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "enabled";
}

function isPostHogEnabled() {
@@ -150,7 +143,6 @@ export const environment = {
getSupabaseAnonKey,
getPreviewStealingDev,
getPostHogCredentials,
getLaunchDarklyClientId,
// Assertions
isServerSide,
isClientSide,
@@ -1,59 +0,0 @@
"use client";

import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useLDClient } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { ReactNode, useEffect, useState } from "react";
import { environment } from "../environment";
import { Flag, useGetFlag } from "./use-get-flag";

interface FeatureFlagRedirectProps {
flag: Flag;
whenDisabled: string;
children: ReactNode;
}

export function FeatureFlagPage({
flag,
whenDisabled,
children,
}: FeatureFlagRedirectProps) {
const [isLoading, setIsLoading] = useState(true);
const router = useRouter();
const flagValue = useGetFlag(flag);
const ldClient = useLDClient();
const ldEnabled = environment.areFeatureFlagsEnabled();
const ldReady = Boolean(ldClient);
const flagEnabled = Boolean(flagValue);

useEffect(() => {
const initialize = async () => {
if (!ldEnabled) {
router.replace(whenDisabled);
setIsLoading(false);
return;
}

// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
if (ldEnabled && !ldReady) return;

try {
await ldClient?.waitForInitialization();
if (!flagEnabled) router.replace(whenDisabled);
} catch (error) {
console.error(error);
router.replace(whenDisabled);
} finally {
setIsLoading(false);
}
};

initialize();
}, [ldReady, flagEnabled]);

return isLoading || !flagEnabled ? (
<LoadingSpinner size="large" cover />
) : (
<>{children}</>
);
}
@@ -1,51 +0,0 @@
"use client";

import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useLDClient } from "launchdarkly-react-client-sdk";
import { useRouter } from "next/navigation";
import { useEffect } from "react";
import { environment } from "../environment";
import { Flag, useGetFlag } from "./use-get-flag";

interface FeatureFlagRedirectProps {
flag: Flag;
whenEnabled: string;
whenDisabled: string;
}

export function FeatureFlagRedirect({
flag,
whenEnabled,
whenDisabled,
}: FeatureFlagRedirectProps) {
const router = useRouter();
const flagValue = useGetFlag(flag);
const ldEnabled = environment.areFeatureFlagsEnabled();
const ldClient = useLDClient();
const ldReady = Boolean(ldClient);
const flagEnabled = Boolean(flagValue);

useEffect(() => {
const initialize = async () => {
if (!ldEnabled) {
router.replace(whenDisabled);
return;
}

// Wait for LaunchDarkly to initialize when enabled to prevent race conditions
if (ldEnabled && !ldReady) return;

try {
await ldClient?.waitForInitialization();
router.replace(flagEnabled ? whenEnabled : whenDisabled);
} catch (error) {
console.error(error);
router.replace(whenDisabled);
}
};

initialize();
}, [ldReady, flagEnabled]);

return <LoadingSpinner size="large" cover />;
}
@@ -1,6 +1,5 @@
"use client";

import { LoadingSpinner } from "@/components/atoms/LoadingSpinner/LoadingSpinner";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import * as Sentry from "@sentry/nextjs";
import { LDProvider } from "launchdarkly-react-client-sdk";
@@ -8,17 +7,17 @@ import type { ReactNode } from "react";
import { useMemo } from "react";
import { environment } from "../environment";

const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const LAUNCHDARKLY_INIT_TIMEOUT_MS = 5000;

export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
const { user, isUserLoading } = useSupabase();
const envEnabled = environment.areFeatureFlagsEnabled();
const clientId = environment.getLaunchDarklyClientId();
const isCloud = environment.isCloud();
const isLaunchDarklyConfigured = isCloud && envEnabled && clientId;

const context = useMemo(() => {
if (isUserLoading) return;

if (!user) {
if (isUserLoading || !user) {
return {
kind: "user" as const,
key: "anonymous",
@@ -37,17 +36,15 @@ export function LaunchDarklyProvider({ children }: { children: ReactNode }) {
};
}, [user, isUserLoading]);

if (!envEnabled) {
if (!isLaunchDarklyConfigured) {
return <>{children}</>;
}

if (isUserLoading) {
return <LoadingSpinner size="large" cover />;
}

return (
<LDProvider
clientSideID={clientId ?? ""}
// Add this key prop. It will be 'anonymous' when logged out,
key={context.key}
clientSideID={clientId}
context={context}
timeout={LAUNCHDARKLY_INIT_TIMEOUT_MS}
reactOptions={{ useCamelCaseFlagKeys: false }}
@@ -1,7 +1,6 @@
"use client";

import { DEFAULT_SEARCH_TERMS } from "@/app/(platform)/marketplace/components/HeroSection/helpers";
import { environment } from "@/services/environment";
import { useFlags } from "launchdarkly-react-client-sdk";

export enum Flag {
@@ -19,9 +18,24 @@ export enum Flag {
CHAT = "chat",
}

export type FlagValues = {
[Flag.BETA_BLOCKS]: string[];
[Flag.NEW_BLOCK_MENU]: boolean;
[Flag.NEW_AGENT_RUNS]: boolean;
[Flag.GRAPH_SEARCH]: boolean;
[Flag.ENABLE_ENHANCED_OUTPUT_HANDLING]: boolean;
[Flag.NEW_FLOW_EDITOR]: boolean;
[Flag.BUILDER_VIEW_SWITCH]: boolean;
[Flag.SHARE_EXECUTION_RESULTS]: boolean;
[Flag.AGENT_FAVORITING]: boolean;
[Flag.MARKETPLACE_SEARCH_TERMS]: string[];
[Flag.ENABLE_PLATFORM_PAYMENT]: boolean;
[Flag.CHAT]: boolean;
};

const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";

const defaultFlags = {
const mockFlags = {
[Flag.BETA_BLOCKS]: [],
[Flag.NEW_BLOCK_MENU]: false,
[Flag.NEW_AGENT_RUNS]: false,
@@ -36,16 +50,17 @@ const defaultFlags = {
[Flag.CHAT]: false,
};

type FlagValues = typeof defaultFlags;

export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] {
export function useGetFlag<T extends Flag>(flag: T): FlagValues[T] | null {
const currentFlags = useFlags<FlagValues>();
const flagValue = currentFlags[flag];
const areFlagsEnabled = environment.areFeatureFlagsEnabled();

if (!areFlagsEnabled || isPwMockEnabled) {
return defaultFlags[flag];
const envEnabled = process.env.NEXT_PUBLIC_LAUNCHDARKLY_ENABLED === "true";
const clientId = process.env.NEXT_PUBLIC_LAUNCHDARKLY_CLIENT_ID;
const isLaunchDarklyConfigured = envEnabled && Boolean(clientId);

if (!isLaunchDarklyConfigured || isPwMockEnabled) {
return mockFlags[flag];
}

return flagValue ?? defaultFlags[flag];
return flagValue ?? mockFlags[flag];
}
@@ -59,13 +59,12 @@ test.describe("Library", () => {
});

test("pagination works correctly", async ({ page }, testInfo) => {
test.setTimeout(testInfo.timeout * 3);
test.setTimeout(testInfo.timeout * 3); // Increase timeout for pagination operations
await page.goto("/library");

const PAGE_SIZE = 20;
const paginationResult = await libraryPage.testPagination();

if (paginationResult.initialCount >= PAGE_SIZE) {
if (paginationResult.initialCount >= 10) {
expect(paginationResult.finalCount).toBeGreaterThanOrEqual(
paginationResult.initialCount,
);
@@ -134,10 +133,7 @@ test.describe("Library", () => {
test.expect(clearedSearchValue).toBe("");
});

test("pagination while searching works correctly", async ({
page,
}, testInfo) => {
test.setTimeout(testInfo.timeout * 3);
test("pagination while searching works correctly", async ({ page }) => {
await page.goto("/library");

const allAgents = await libraryPage.getAgents();
@@ -156,10 +152,9 @@ test.describe("Library", () => {
);
expect(matchingResults.length).toEqual(initialSearchResults.length);

const PAGE_SIZE = 20;
const searchPaginationResult = await libraryPage.testPagination();

if (searchPaginationResult.initialCount >= PAGE_SIZE) {
if (searchPaginationResult.initialCount >= 10) {
expect(searchPaginationResult.finalCount).toBeGreaterThanOrEqual(
searchPaginationResult.initialCount,
);
@@ -69,12 +69,9 @@ test.describe("Marketplace Creator Page – Basic Functionality", () => {
await marketplacePage.getFirstCreatorProfile(page);
await firstCreatorProfile.click();
await page.waitForURL("**/marketplace/creator/**");
await page.waitForLoadState("networkidle").catch(() => {});

const firstAgent = page
.locator('[data-testid="store-card"]:visible')
.first();
await firstAgent.waitFor({ state: "visible", timeout: 30000 });

await firstAgent.click();
await page.waitForURL("**/marketplace/agent/**");

@@ -77,6 +77,7 @@ test.describe("Marketplace – Basic Functionality", () => {

const firstFeaturedAgent =
await marketplacePage.getFirstFeaturedAgent(page);
await firstFeaturedAgent.waitFor({ state: "visible" });
await firstFeaturedAgent.click();
await page.waitForURL("**/marketplace/agent/**");
await matchesUrl(page, /\/marketplace\/agent\/.+/);
@@ -115,15 +116,7 @@ test.describe("Marketplace – Basic Functionality", () => {
const searchTerm = page.getByText("DummyInput").first();
await isVisible(searchTerm);

await page.waitForLoadState("networkidle").catch(() => {});

await page
.waitForFunction(
() =>
document.querySelectorAll('[data-testid="store-card"]').length > 0,
{ timeout: 15000 },
)
.catch(() => console.log("No search results appeared within timeout"));
await page.waitForTimeout(10000);

const results = await marketplacePage.getSearchResultsCount(page);
expect(results).toBeGreaterThan(0);
Some files were not shown because too many files have changed in this diff.