Mirror of https://github.com/Significant-Gravitas/AutoGPT.git
Synced 2026-02-04 11:55:11 -05:00

Compare commits: fix/copilo...docker/opt (3 commits)

Commits in this comparison (author and date columns were not captured):
- 1ed748a356
- 9c28639c32
- 4f37a12743
@@ -37,13 +37,15 @@ ENV POETRY_VIRTUALENVS_CREATE=true
 ENV POETRY_VIRTUALENVS_IN_PROJECT=true
 ENV PATH=/opt/poetry/bin:$PATH
 
-RUN pip3 install poetry --break-system-packages
+RUN pip3 install --no-cache-dir poetry --break-system-packages
 
 # Copy and install dependencies
 COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
 COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
 WORKDIR /app/autogpt_platform/backend
-RUN poetry install --no-ansi --no-root
+# Production image only needs runtime deps; dev deps (pytest, black, ruff, etc.)
+# are installed locally via `poetry install --with dev` per the development docs
+RUN poetry install --no-ansi --no-root --only main
 
 # Generate Prisma client
 COPY autogpt_platform/backend/schema.prisma ./
@@ -51,6 +53,15 @@ COPY autogpt_platform/backend/backend/data/partial_types.py ./backend/data/parti
 COPY autogpt_platform/backend/gen_prisma_types_stub.py ./
 RUN poetry run prisma generate && poetry run gen-prisma-stub
 
+# Clean up build artifacts and caches to reduce layer size
+# Note: setuptools is kept as it's a direct dependency (used by aioclamd via pkg_resources)
+RUN find /app -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true; \
+    find /app -type d -name tests -exec rm -rf {} + 2>/dev/null || true; \
+    find /app -type d -name test -exec rm -rf {} + 2>/dev/null || true; \
+    rm -rf /app/autogpt_platform/backend/.venv/lib/python*/site-packages/pip* \
+        /root/.cache/pip \
+        /root/.cache/pypoetry
+
 FROM debian:13-slim AS server_dependencies
 
 WORKDIR /app
@@ -68,7 +79,7 @@ RUN apt-get update && apt-get install -y \
     python3-pip \
     && rm -rf /var/lib/apt/lists/*
 
-# Copy only necessary files from builder
+# Copy built artifacts from builder (cleaned of caches, __pycache__, and test dirs)
 COPY --from=builder /app /app
 COPY --from=builder /usr/local/lib/python3* /usr/local/lib/python3*
 COPY --from=builder /usr/local/bin/poetry /usr/local/bin/poetry
@@ -81,9 +92,7 @@ COPY --from=builder /root/.cache/prisma-python/binaries /root/.cache/prisma-pyth
 
 ENV PATH="/app/autogpt_platform/backend/.venv/bin:$PATH"
 
-RUN mkdir -p /app/autogpt_platform/autogpt_libs
-RUN mkdir -p /app/autogpt_platform/backend
-
+# Copy fresh source from context (overwrites builder's copy with latest source)
 COPY autogpt_platform/autogpt_libs /app/autogpt_platform/autogpt_libs
 
 COPY autogpt_platform/backend/poetry.lock autogpt_platform/backend/pyproject.toml /app/autogpt_platform/backend/
@@ -3,8 +3,7 @@ import logging
 import time
 from asyncio import CancelledError
 from collections.abc import AsyncGenerator
-from dataclasses import dataclass
-from typing import Any, cast
+from typing import Any
 
 import openai
 import orjson
@@ -16,14 +15,7 @@ from openai import (
     PermissionDeniedError,
     RateLimitError,
 )
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionChunk,
-    ChatCompletionMessageParam,
-    ChatCompletionStreamOptionsParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolParam,
-)
+from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
 
 from backend.data.redis_client import get_redis_async
 from backend.data.understanding import (
@@ -31,7 +23,6 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
-from backend.util.prompt import estimate_token_count
 from backend.util.settings import Settings
 
 from . import db as chat_db
@@ -803,201 +794,6 @@ def _is_region_blocked_error(error: Exception) -> bool:
     return "not available in your region" in str(error).lower()
 
 
-# Context window management constants
-TOKEN_THRESHOLD = 120_000
-KEEP_RECENT_MESSAGES = 15
-
-
-@dataclass
-class ContextWindowResult:
-    """Result of context window management."""
-
-    messages: list[dict[str, Any]]
-    token_count: int
-    was_compacted: bool
-    error: str | None = None
-
-
-def _messages_to_dicts(messages: list) -> list[dict[str, Any]]:
-    """Convert message objects to dicts, filtering None values.
-
-    Handles both TypedDict (dict-like) and other message formats.
-    """
-    result = []
-    for msg in messages:
-        if msg is None:
-            continue
-        if isinstance(msg, dict):
-            msg_dict = {k: v for k, v in msg.items() if v is not None}
-        else:
-            msg_dict = dict(msg)
-        result.append(msg_dict)
-    return result
-
-
-async def _manage_context_window(
-    messages: list,
-    model: str,
-    api_key: str | None = None,
-    base_url: str | None = None,
-) -> ContextWindowResult:
-    """
-    Manage context window by summarizing old messages if token count exceeds threshold.
-
-    This function handles context compaction for LLM calls by:
-    1. Counting tokens in the message list
-    2. If over threshold, summarizing old messages while keeping recent ones
-    3. Ensuring tool_call/tool_response pairs stay intact
-    4. Progressively reducing message count if still over limit
-
-    Args:
-        messages: List of messages in OpenAI format (with system prompt if present)
-        model: Model name for token counting
-        api_key: API key for summarization calls
-        base_url: Base URL for summarization calls
-
-    Returns:
-        ContextWindowResult with compacted messages and metadata
-    """
-    if not messages:
-        return ContextWindowResult([], 0, False, "No messages to compact")
-
-    messages_dict = _messages_to_dicts(messages)
-
-    # Normalize model name for token counting (tiktoken only supports OpenAI models)
-    token_count_model = model.split("/")[-1] if "/" in model else model
-    if "claude" in token_count_model.lower() or not any(
-        known in token_count_model.lower()
-        for known in ["gpt", "o1", "chatgpt", "text-"]
-    ):
-        token_count_model = "gpt-4o"
-
-    try:
-        token_count = estimate_token_count(messages_dict, model=token_count_model)
-    except Exception as e:
-        logger.warning(f"Token counting failed: {e}. Using gpt-4o approximation.")
-        token_count_model = "gpt-4o"
-        token_count = estimate_token_count(messages_dict, model=token_count_model)
-
-    if token_count <= TOKEN_THRESHOLD:
-        return ContextWindowResult(messages, token_count, False)
-
-    has_system_prompt = messages[0].get("role") == "system"
-    slice_start = max(0, len(messages_dict) - KEEP_RECENT_MESSAGES)
-    recent_messages = _ensure_tool_pairs_intact(
-        messages_dict[-KEEP_RECENT_MESSAGES:], messages_dict, slice_start
-    )
-
-    # Determine old messages to summarize (explicit bounds to avoid slice edge cases)
-    system_msg = messages[0] if has_system_prompt else None
-    if has_system_prompt:
-        old_messages_dict = (
-            messages_dict[1:-KEEP_RECENT_MESSAGES]
-            if len(messages_dict) > KEEP_RECENT_MESSAGES + 1
-            else []
-        )
-    else:
-        old_messages_dict = (
-            messages_dict[:-KEEP_RECENT_MESSAGES]
-            if len(messages_dict) > KEEP_RECENT_MESSAGES
-            else []
-        )
-
-    # Try to summarize old messages, fall back to truncation on failure
-    summary_msg = None
-    if old_messages_dict:
-        try:
-            summary_text = await _summarize_messages(
-                old_messages_dict, model=model, api_key=api_key, base_url=base_url
-            )
-            summary_msg = ChatCompletionAssistantMessageParam(
-                role="assistant",
-                content=f"[Previous conversation summary — for context only]: {summary_text}",
-            )
-            base = [system_msg, summary_msg] if has_system_prompt else [summary_msg]
-            messages = base + recent_messages
-            logger.info(
-                f"Context summarized: {token_count} tokens, "
-                f"summarized {len(old_messages_dict)} msgs, kept {KEEP_RECENT_MESSAGES}"
-            )
-        except Exception as e:
-            logger.warning(f"Summarization failed, falling back to truncation: {e}")
-            messages = (
-                [system_msg] + recent_messages if has_system_prompt else recent_messages
-            )
-    else:
-        logger.warning(
-            f"Token count {token_count} exceeds threshold but no old messages to summarize"
-        )
-
-    new_token_count = estimate_token_count(
-        _messages_to_dicts(messages), model=token_count_model
-    )
-
-    # Progressive truncation if still over limit
-    if new_token_count > TOKEN_THRESHOLD:
-        logger.warning(
-            f"Still over limit: {new_token_count} tokens. Reducing messages."
-        )
-        base_msgs = (
-            recent_messages
-            if old_messages_dict
-            else (messages_dict[1:] if has_system_prompt else messages_dict)
-        )
-
-        def build_messages(recent: list) -> list:
-            """Build message list with optional system prompt and summary."""
-            prefix = []
-            if has_system_prompt and system_msg:
-                prefix.append(system_msg)
-            if summary_msg:
-                prefix.append(summary_msg)
-            return prefix + recent
-
-        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
-            if keep_count == 0:
-                messages = build_messages([])
-                if not messages:
-                    continue
-            elif len(base_msgs) < keep_count:
-                continue
-            else:
-                reduced = _ensure_tool_pairs_intact(
-                    base_msgs[-keep_count:],
-                    base_msgs,
-                    max(0, len(base_msgs) - keep_count),
-                )
-                messages = build_messages(reduced)
-
-            new_token_count = estimate_token_count(
-                _messages_to_dicts(messages), model=token_count_model
-            )
-            if new_token_count <= TOKEN_THRESHOLD:
-                logger.info(
-                    f"Reduced to {keep_count} messages, {new_token_count} tokens"
-                )
-                break
-        else:
-            logger.error(
-                f"Cannot reduce below threshold. Final: {new_token_count} tokens"
-            )
-            if has_system_prompt and len(messages) > 1:
-                messages = messages[1:]
-                logger.critical("Dropped system prompt as last resort")
-                return ContextWindowResult(
-                    messages, new_token_count, True, "System prompt dropped"
-                )
-            # No system prompt to drop - return error so callers don't proceed with oversized context
-            return ContextWindowResult(
-                messages,
-                new_token_count,
-                True,
-                "Unable to reduce context below token limit",
-            )
-
-    return ContextWindowResult(messages, new_token_count, True)
-
-
 async def _summarize_messages(
     messages: list,
     model: str,
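`_ensure_tool_pairs_intact` is called throughout this hunk but its definition is not part of the diff. A minimal, hypothetical sketch of the invariant it is expected to enforce: a slice of OpenAI-style message dicts must not begin with orphan `tool` responses whose originating assistant `tool_calls` message was sliced away (names and behavior here are assumptions, not the project's actual implementation):

```python
from typing import Any


def ensure_tool_pairs_intact_sketch(
    recent: list[dict[str, Any]],
    full: list[dict[str, Any]],
    slice_start: int,
) -> list[dict[str, Any]]:
    """Hypothetical sketch: widen the slice so tool responses keep their caller.

    If the slice begins with role="tool" messages, walk backwards until the
    assistant message that issued the matching tool_calls is included again.
    """
    start = slice_start
    while start > 0 and full[start].get("role") == "tool":
        start -= 1  # pull the preceding assistant/tool_calls message back into the slice
    return full[start:]
```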
@@ -1226,8 +1022,11 @@ async def _stream_chat_chunks(
 
     logger.info("Starting pure chat stream")
 
+    # Build messages with system prompt prepended
     messages = session.to_openai_messages()
     if system_prompt:
+        from openai.types.chat import ChatCompletionSystemMessageParam
+
         system_message = ChatCompletionSystemMessageParam(
             role="system",
             content=system_prompt,
@@ -1235,38 +1034,314 @@ async def _stream_chat_chunks(
         messages = [system_message] + messages
 
     # Apply context window management
-    context_result = await _manage_context_window(
-        messages=messages,
-        model=model,
-        api_key=config.api_key,
-        base_url=config.base_url,
-    )
-
-    if context_result.error:
-        if "System prompt dropped" in context_result.error:
-            # Warning only - continue with reduced context
-            yield StreamError(
-                errorText=(
-                    "Warning: System prompt dropped due to size constraints. "
-                    "Assistant behavior may be affected."
-                )
-            )
-        else:
-            # Any other error - abort to prevent failed LLM calls
-            yield StreamError(
-                errorText=(
-                    f"Context window management failed: {context_result.error}. "
-                    "Please start a new conversation."
-                )
-            )
-            yield StreamFinish()
-            return
-
-    messages = context_result.messages
-    if context_result.was_compacted:
-        logger.info(
-            f"Context compacted for streaming: {context_result.token_count} tokens"
-        )
+    token_count = 0 # Initialize for exception handler
+    try:
+        from backend.util.prompt import estimate_token_count
+
+        # Convert to dict for token counting
+        # OpenAI message types are TypedDicts, so they're already dict-like
+        messages_dict = []
+        for msg in messages:
+            # TypedDict objects are already dicts, just filter None values
+            if isinstance(msg, dict):
+                msg_dict = {k: v for k, v in msg.items() if v is not None}
+            else:
+                # Fallback for unexpected types
+                msg_dict = dict(msg)
+            messages_dict.append(msg_dict)
+
+        # Estimate tokens using appropriate tokenizer
+        # Normalize model name for token counting (tiktoken only supports OpenAI models)
+        token_count_model = model
+        if "/" in model:
+            # Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5")
+            token_count_model = model.split("/")[-1]
+
+        # For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer
+        # Most modern LLMs have similar tokenization (~1 token per 4 chars)
+        if "claude" in token_count_model.lower() or not any(
+            known in token_count_model.lower()
+            for known in ["gpt", "o1", "chatgpt", "text-"]
+        ):
+            token_count_model = "gpt-4o"
+
+        # Attempt token counting with error handling
+        try:
+            token_count = estimate_token_count(messages_dict, model=token_count_model)
+        except Exception as token_error:
+            # If token counting fails, use gpt-4o as fallback approximation
+            logger.warning(
+                f"Token counting failed for model {token_count_model}: {token_error}. "
+                "Using gpt-4o approximation."
+            )
+            token_count = estimate_token_count(messages_dict, model="gpt-4o")
+
+        # If over threshold, summarize old messages
+        if token_count > 120_000:
+            KEEP_RECENT = 15
+
+            # Check if we have a system prompt at the start
+            has_system_prompt = (
+                len(messages) > 0 and messages[0].get("role") == "system"
+            )
+
+            # Always attempt mitigation when over limit, even with few messages
+            if messages:
+                # Split messages based on whether system prompt exists
+                # Calculate start index for the slice
+                slice_start = max(0, len(messages_dict) - KEEP_RECENT)
+                recent_messages = messages_dict[-KEEP_RECENT:]
+
+                # Ensure tool_call/tool_response pairs stay together
+                # This prevents API errors from orphan tool responses
+                recent_messages = _ensure_tool_pairs_intact(
+                    recent_messages, messages_dict, slice_start
+                )
+
+                if has_system_prompt:
+                    # Keep system prompt separate, summarize everything between system and recent
+                    system_msg = messages[0]
+                    old_messages_dict = messages_dict[1:-KEEP_RECENT]
+                else:
+                    # No system prompt, summarize everything except recent
+                    system_msg = None
+                    old_messages_dict = messages_dict[:-KEEP_RECENT]
+
+                # Summarize any non-empty old messages (no minimum threshold)
+                # If we're over the token limit, we need to compress whatever we can
+                if old_messages_dict:
+                    # Summarize old messages using the same model as chat
+                    summary_text = await _summarize_messages(
+                        old_messages_dict,
+                        model=model,
+                        api_key=config.api_key,
+                        base_url=config.base_url,
+                    )
+
+                    # Build new message list
+                    # Use assistant role (not system) to prevent privilege escalation
+                    # of user-influenced content to instruction-level authority
+                    from openai.types.chat import ChatCompletionAssistantMessageParam
+
+                    summary_msg = ChatCompletionAssistantMessageParam(
+                        role="assistant",
+                        content=(
+                            "[Previous conversation summary — for context only]: "
+                            f"{summary_text}"
+                        ),
+                    )
+
+                    # Rebuild messages based on whether we have a system prompt
+                    if has_system_prompt:
+                        # system_prompt + summary + recent_messages
+                        messages = [system_msg, summary_msg] + recent_messages
+                    else:
+                        # summary + recent_messages (no original system prompt)
+                        messages = [summary_msg] + recent_messages
+
+                    logger.info(
+                        f"Context summarized: {token_count} tokens, "
+                        f"summarized {len(old_messages_dict)} old messages, "
+                        f"kept last {KEEP_RECENT} messages"
+                    )
+
+                    # Fallback: If still over limit after summarization, progressively drop recent messages
+                    # This handles edge cases where recent messages are extremely large
+                    new_messages_dict = []
+                    for msg in messages:
+                        if isinstance(msg, dict):
+                            msg_dict = {k: v for k, v in msg.items() if v is not None}
+                        else:
+                            msg_dict = dict(msg)
+                        new_messages_dict.append(msg_dict)
+
+                    new_token_count = estimate_token_count(
+                        new_messages_dict, model=token_count_model
+                    )
+
+                    if new_token_count > 120_000:
+                        # Still over limit - progressively reduce KEEP_RECENT
+                        logger.warning(
+                            f"Still over limit after summarization: {new_token_count} tokens. "
+                            "Reducing number of recent messages kept."
+                        )
+
+                        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
+                            if keep_count == 0:
+                                # Try with just system prompt + summary (no recent messages)
+                                if has_system_prompt:
+                                    messages = [system_msg, summary_msg]
+                                else:
+                                    messages = [summary_msg]
+                                logger.info(
+                                    "Trying with 0 recent messages (system + summary only)"
+                                )
+                            else:
+                                # Slice from ORIGINAL recent_messages to avoid duplicating summary
+                                reduced_recent = (
+                                    recent_messages[-keep_count:]
+                                    if len(recent_messages) >= keep_count
+                                    else recent_messages
+                                )
+                                # Ensure tool pairs stay intact in the reduced slice
+                                reduced_slice_start = max(
+                                    0, len(recent_messages) - keep_count
+                                )
+                                reduced_recent = _ensure_tool_pairs_intact(
+                                    reduced_recent, recent_messages, reduced_slice_start
+                                )
+                                if has_system_prompt:
+                                    messages = [
+                                        system_msg,
+                                        summary_msg,
+                                    ] + reduced_recent
+                                else:
+                                    messages = [summary_msg] + reduced_recent
+
+                            new_messages_dict = []
+                            for msg in messages:
+                                if isinstance(msg, dict):
+                                    msg_dict = {
+                                        k: v for k, v in msg.items() if v is not None
+                                    }
+                                else:
+                                    msg_dict = dict(msg)
+                                new_messages_dict.append(msg_dict)
+
+                            new_token_count = estimate_token_count(
+                                new_messages_dict, model=token_count_model
+                            )
+
+                            if new_token_count <= 120_000:
+                                logger.info(
+                                    f"Reduced to {keep_count} recent messages, "
+                                    f"now {new_token_count} tokens"
+                                )
+                                break
+                        else:
+                            logger.error(
+                                f"Unable to reduce token count below threshold even with 0 messages. "
+                                f"Final count: {new_token_count} tokens"
+                            )
+                            # ABSOLUTE LAST RESORT: Drop system prompt
+                            # This should only happen if summary itself is massive
+                            if has_system_prompt and len(messages) > 1:
+                                messages = messages[1:] # Drop system prompt
+                                logger.critical(
+                                    "CRITICAL: Dropped system prompt as absolute last resort. "
+                                    "Behavioral consistency may be affected."
+                                )
+                                # Yield error to user
+                                yield StreamError(
+                                    errorText=(
+                                        "Warning: System prompt dropped due to size constraints. "
+                                        "Assistant behavior may be affected."
+                                    )
+                                )
+                else:
+                    # No old messages to summarize - all messages are "recent"
+                    # Apply progressive truncation to reduce token count
+                    logger.warning(
+                        f"Token count {token_count} exceeds threshold but no old messages to summarize. "
+                        f"Applying progressive truncation to recent messages."
+                    )
+
+                    # Create a base list excluding system prompt to avoid duplication
+                    # This is the pool of messages we'll slice from in the loop
+                    # Use messages_dict for type consistency with _ensure_tool_pairs_intact
+                    base_msgs = (
+                        messages_dict[1:] if has_system_prompt else messages_dict
+                    )
+
+                    # Try progressively smaller keep counts
+                    new_token_count = token_count # Initialize with current count
+                    for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
+                        if keep_count == 0:
+                            # Try with just system prompt (no recent messages)
+                            if has_system_prompt:
+                                messages = [system_msg]
+                                logger.info(
+                                    "Trying with 0 recent messages (system prompt only)"
+                                )
+                            else:
+                                # No system prompt and no recent messages = empty messages list
+                                # This is invalid, skip this iteration
+                                continue
+                        else:
+                            if len(base_msgs) < keep_count:
+                                continue # Skip if we don't have enough messages
+
+                            # Slice from base_msgs to get recent messages (without system prompt)
+                            recent_messages = base_msgs[-keep_count:]
+
+                            # Ensure tool pairs stay intact in the reduced slice
+                            reduced_slice_start = max(0, len(base_msgs) - keep_count)
+                            recent_messages = _ensure_tool_pairs_intact(
+                                recent_messages, base_msgs, reduced_slice_start
+                            )
+
+                            if has_system_prompt:
+                                messages = [system_msg] + recent_messages
+                            else:
+                                messages = recent_messages
+
+                        new_messages_dict = []
+                        for msg in messages:
+                            if msg is None:
+                                continue # Skip None messages (type safety)
+                            if isinstance(msg, dict):
+                                msg_dict = {
+                                    k: v for k, v in msg.items() if v is not None
+                                }
+                            else:
+                                msg_dict = dict(msg)
+                            new_messages_dict.append(msg_dict)
+
+                        new_token_count = estimate_token_count(
+                            new_messages_dict, model=token_count_model
+                        )
+
+                        if new_token_count <= 120_000:
+                            logger.info(
+                                f"Reduced to {keep_count} recent messages, "
+                                f"now {new_token_count} tokens"
+                            )
+                            break
+                    else:
+                        # Even with 0 messages still over limit
+                        logger.error(
+                            f"Unable to reduce token count below threshold even with 0 messages. "
+                            f"Final count: {new_token_count} tokens. Messages may be extremely large."
+                        )
+                        # ABSOLUTE LAST RESORT: Drop system prompt
+                        if has_system_prompt and len(messages) > 1:
+                            messages = messages[1:] # Drop system prompt
+                            logger.critical(
+                                "CRITICAL: Dropped system prompt as absolute last resort. "
+                                "Behavioral consistency may be affected."
+                            )
+                            # Yield error to user
+                            yield StreamError(
+                                errorText=(
+                                    "Warning: System prompt dropped due to size constraints. "
+                                    "Assistant behavior may be affected."
+                                )
+                            )
+
+    except Exception as e:
+        logger.error(f"Context summarization failed: {e}", exc_info=True)
+        # If we were over the token limit, yield error to user
+        # Don't silently continue with oversized messages that will fail
+        if token_count > 120_000:
+            yield StreamError(
+                errorText=(
+                    f"Unable to manage context window (token limit exceeded: {token_count} tokens). "
+                    "Context summarization failed. Please start a new conversation."
+                )
+            )
+            yield StreamFinish()
+            return
+        # Otherwise, continue with original messages (under limit)
 
     # Loop to handle tool calls and continue conversation
     while True:
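Both truncation loops above rely on Python's `for ... else` construct: the `else` branch runs only when the loop finishes without hitting `break`, which is how the code detects that no `keep_count` brought the estimate under the 120k threshold. A small self-contained illustration with made-up numbers:

```python
def first_fit(token_counts: list[int], threshold: int = 120_000) -> None:
    # token_counts[i] is a made-up estimate when keeping keep_counts[i] messages.
    keep_counts = [12, 10, 8, 5, 3, 2, 1, 0]
    for keep, tokens in zip(keep_counts, token_counts):
        if tokens <= threshold:
            print(f"fits with {keep} recent messages ({tokens} tokens)")
            break
    else:
        # Reached only if the loop never hit `break` -- the "last resort" path above.
        print("could not get under the threshold")


first_fit([150_000, 140_000, 119_000])  # fits with 8 recent messages (119000 tokens)
first_fit([150_000, 140_000, 130_000])  # could not get under the threshold
```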
@@ -1294,6 +1369,14 @@ async def _stream_chat_chunks(
                    :128
                ] # OpenRouter limit
 
+            # Create the stream with proper types
+            from typing import cast
+
+            from openai.types.chat import (
+                ChatCompletionMessageParam,
+                ChatCompletionStreamOptionsParam,
+            )
+
            stream = await client.chat.completions.create(
                model=model,
                messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1419,7 +1502,6 @@ async def _stream_chat_chunks(
                return
        except Exception as e:
            last_error = e
-
            if _is_retryable_error(e) and retry_count < MAX_RETRIES:
                retry_count += 1
                # Calculate delay with exponential backoff
@@ -1435,24 +1517,12 @@ async def _stream_chat_chunks(
                continue # Retry the stream
            else:
                # Non-retryable error or max retries exceeded
-                _log_api_error(
-                    error=e,
-                    session_id=session.session_id if session else None,
-                    message_count=len(messages) if messages else None,
-                    model=model,
-                    retry_count=retry_count,
+                logger.error(
+                    f"Error in stream (not retrying): {e!s}",
+                    exc_info=True,
                )
                error_code = None
                error_text = str(e)
-
-                error_details = _extract_api_error_details(e)
-                if error_details.get("response_body"):
-                    body = error_details["response_body"]
-                    if isinstance(body, dict) and body.get("error", {}).get(
-                        "message"
-                    ):
-                        error_text = body["error"]["message"]
-
                if _is_region_blocked_error(e):
                    error_code = "MODEL_NOT_AVAILABLE_REGION"
                    error_text = (
@@ -1469,12 +1539,9 @@ async def _stream_chat_chunks(
 
    # If we exit the retry loop without returning, it means we exhausted retries
    if last_error:
-        _log_api_error(
-            error=last_error,
-            session_id=session.session_id if session else None,
-            message_count=len(messages) if messages else None,
-            model=model,
-            retry_count=MAX_RETRIES,
+        logger.error(
+            f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
+            exc_info=True,
        )
        yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
        yield StreamFinish()
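The retry path above computes its sleep time as `min(BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), MAX_DELAY_SECONDS)`. Those constants are defined elsewhere in the module and are not visible in this diff; with assumed values of a 1 s base, a 30 s cap and 5 retries, the schedule works out as follows:

```python
# Assumed values for illustration only; the real constants live outside this diff.
BASE_DELAY_SECONDS = 1.0
MAX_DELAY_SECONDS = 30.0
MAX_RETRIES = 5

for retry_count in range(1, MAX_RETRIES + 1):
    # Exponential backoff, capped at MAX_DELAY_SECONDS.
    delay = min(BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), MAX_DELAY_SECONDS)
    print(f"attempt {retry_count}: sleep {delay:.1f}s")
# attempt 1: 1.0s, attempt 2: 2.0s, attempt 3: 4.0s, attempt 4: 8.0s, attempt 5: 16.0s
```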
@@ -1833,36 +1900,17 @@ async def _generate_llm_continuation(
    # Build system prompt
    system_prompt, _ = await _build_system_prompt(user_id)

+    # Build messages in OpenAI format
    messages = session.to_openai_messages()
    if system_prompt:
+        from openai.types.chat import ChatCompletionSystemMessageParam
+
        system_message = ChatCompletionSystemMessageParam(
            role="system",
            content=system_prompt,
        )
        messages = [system_message] + messages

-    # Apply context window management to prevent oversized requests
-    context_result = await _manage_context_window(
-        messages=messages,
-        model=config.model,
-        api_key=config.api_key,
-        base_url=config.base_url,
-    )
-
-    if context_result.error and "System prompt dropped" not in context_result.error:
-        logger.error(
-            f"Context window management failed for session {session_id}: "
-            f"{context_result.error} (tokens={context_result.token_count})"
-        )
-        return
-
-    messages = context_result.messages
-    if context_result.was_compacted:
-        logger.info(
-            f"Context compacted for LLM continuation: "
-            f"{context_result.token_count} tokens"
-        )
-
    # Build extra_body for tracing
    extra_body: dict[str, Any] = {
        "posthogProperties": {
@@ -1875,61 +1923,19 @@ async def _generate_llm_continuation(
    if session_id:
        extra_body["session_id"] = session_id[:128]

-    retry_count = 0
-    last_error: Exception | None = None
-    response = None
-
-    while retry_count <= MAX_RETRIES:
-        try:
-            logger.info(
-                f"Generating LLM continuation for session {session_id}"
-                f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
-            )
-
-            response = await client.chat.completions.create(
-                model=config.model,
-                messages=cast(list[ChatCompletionMessageParam], messages),
-                extra_body=extra_body,
-            )
-            last_error = None # Clear any previous error on success
-            break # Success, exit retry loop
-        except Exception as e:
-            last_error = e
-
-            if _is_retryable_error(e) and retry_count < MAX_RETRIES:
-                retry_count += 1
-                delay = min(
-                    BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
-                    MAX_DELAY_SECONDS,
-                )
-                logger.warning(
-                    f"Retryable error in LLM continuation: {e!s}. "
-                    f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
-                )
-                await asyncio.sleep(delay)
-                continue
-            else:
-                # Non-retryable error - log details and exit gracefully
-                _log_api_error(
-                    error=e,
-                    session_id=session_id,
-                    message_count=len(messages) if messages else None,
-                    model=config.model,
-                    retry_count=retry_count,
-                )
-                return
-
-    if last_error:
-        _log_api_error(
-            error=last_error,
-            session_id=session_id,
-            message_count=len(messages) if messages else None,
-            model=config.model,
-            retry_count=MAX_RETRIES,
-        )
-        return
-
-    if response and response.choices and response.choices[0].message.content:
+    # Make non-streaming LLM call (no tools - just text response)
+    from typing import cast
+
+    from openai.types.chat import ChatCompletionMessageParam
+
+    # No tools parameter = text-only response (no tool calls)
+    response = await client.chat.completions.create(
+        model=config.model,
+        messages=cast(list[ChatCompletionMessageParam], messages),
+        extra_body=extra_body,
+    )
+
+    if response.choices and response.choices[0].message.content:
        assistant_content = response.choices[0].message.content

        # Reload session from DB to avoid race condition with user messages
@@ -1963,78 +1969,3 @@ async def _generate_llm_continuation(
 
    except Exception as e:
        logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)
-
-
-def _log_api_error(
-    error: Exception,
-    session_id: str | None = None,
-    message_count: int | None = None,
-    model: str | None = None,
-    retry_count: int = 0,
-) -> None:
-    """Log detailed API error information for debugging."""
-    details = _extract_api_error_details(error)
-    details["session_id"] = session_id
-    details["message_count"] = message_count
-    details["model"] = model
-    details["retry_count"] = retry_count
-
-    if isinstance(error, RateLimitError):
-        logger.warning(f"Rate limit error: {details}")
-    elif isinstance(error, APIConnectionError):
-        logger.warning(f"API connection error: {details}")
-    elif isinstance(error, APIStatusError) and error.status_code >= 500:
-        logger.error(f"API server error (5xx): {details}")
-    else:
-        logger.error(f"API error: {details}")
-
-
-def _extract_api_error_details(error: Exception) -> dict[str, Any]:
-    """Extract detailed information from OpenAI/OpenRouter API errors."""
-    error_msg = str(error)
-    details: dict[str, Any] = {
-        "error_type": type(error).__name__,
-        "error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg,
-    }
-
-    if hasattr(error, "code"):
-        details["code"] = error.code
-    if hasattr(error, "param"):
-        details["param"] = error.param
-
-    if isinstance(error, APIStatusError):
-        details["status_code"] = error.status_code
-        details["request_id"] = getattr(error, "request_id", None)
-
-    if hasattr(error, "body") and error.body:
-        details["response_body"] = _sanitize_error_body(error.body)
-
-    if hasattr(error, "response") and error.response:
-        headers = error.response.headers
-        details["openrouter_provider"] = headers.get("x-openrouter-provider")
-        details["openrouter_model"] = headers.get("x-openrouter-model")
-        details["retry_after"] = headers.get("retry-after")
-        details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining")
-
-    return details
-
-
-def _sanitize_error_body(body: Any, max_length: int = 2000) -> dict[str, Any] | None:
-    """Extract only safe fields from error response body to avoid logging sensitive data."""
-    if not isinstance(body, dict):
-        return None
-
-    safe_fields = ("message", "type", "code", "param", "error")
-    sanitized: dict[str, Any] = {}
-
-    for field in safe_fields:
-        if field in body:
-            value = body[field]
-            if field == "error" and isinstance(value, dict):
-                sanitized[field] = _sanitize_error_body(value, max_length)
-            elif isinstance(value, str) and len(value) > max_length:
-                sanitized[field] = value[:max_length] + "...[truncated]"
-            else:
-                sanitized[field] = value
-
-    return sanitized if sanitized else None
@@ -14,7 +14,6 @@ from backend.data.graph import (
    create_graph,
    get_graph,
    get_graph_all_versions,
-    get_store_listed_graphs,
)
from backend.util.exceptions import DatabaseError, NotFoundError

@@ -267,18 +266,18 @@ async def get_library_agents_for_generation(
async def search_marketplace_agents_for_generation(
    search_query: str,
    max_results: int = 10,
-) -> list[LibraryAgentSummary]:
+) -> list[MarketplaceAgentSummary]:
    """Search marketplace agents formatted for Agent Generator.

-    Fetches marketplace agents and their full schemas so they can be used
-    as sub-agents in generated workflows.
+    Note: This returns basic agent info. Full input/output schemas would require
+    additional graph fetches and is a potential future enhancement.

    Args:
        search_query: Search term to find relevant public agents
        max_results: Maximum number of agents to return (default 10)

    Returns:
-        List of LibraryAgentSummary with full input/output schemas
+        List of MarketplaceAgentSummary (without detailed schemas for now)
    """
    try:
        response = await store_db.get_store_agents(
@@ -287,31 +286,17 @@ async def search_marketplace_agents_for_generation(
            page_size=max_results,
        )

-        agents_with_graphs = [
-            agent for agent in response.agents if agent.agent_graph_id
-        ]
-
-        if not agents_with_graphs:
-            return []
-
-        graph_ids = [agent.agent_graph_id for agent in agents_with_graphs]
-        graphs = await get_store_listed_graphs(*graph_ids)
-
-        results: list[LibraryAgentSummary] = []
-        for agent in agents_with_graphs:
-            graph_id = agent.agent_graph_id
-            if graph_id and graph_id in graphs:
-                graph = graphs[graph_id]
-                results.append(
-                    LibraryAgentSummary(
-                        graph_id=graph.id,
-                        graph_version=graph.version,
-                        name=agent.agent_name,
-                        description=agent.description,
-                        input_schema=graph.input_schema,
-                        output_schema=graph.output_schema,
-                    )
-                )
+        results: list[MarketplaceAgentSummary] = []
+        for agent in response.agents:
+            results.append(
+                MarketplaceAgentSummary(
+                    name=agent.agent_name,
+                    description=agent.description,
+                    sub_heading=agent.sub_heading,
+                    creator=agent.creator,
+                    is_marketplace_agent=True,
+                )
+            )
        return results
    except Exception as e:
        logger.warning(f"Failed to search marketplace agents: {e}")
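`MarketplaceAgentSummary` is referenced here but not defined in this diff. Judging by the keyword arguments used above and the dict-style `.get("name")` access in the next hunk, it is presumably a TypedDict-like summary record; the following is a hypothetical sketch only, with the field set inferred from this hunk rather than taken from the actual project code:

```python
from typing import TypedDict


class MarketplaceAgentSummary(TypedDict):
    """Hypothetical shape; a TypedDict instance is a plain dict at runtime,
    which would explain the `.get("name")` access later in this file."""

    name: str
    description: str
    sub_heading: str
    creator: str
    is_marketplace_agent: bool
```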
@@ -342,7 +327,8 @@ async def get_all_relevant_agents_for_generation(
        max_marketplace_results: Max marketplace agents to return (default 10)

    Returns:
-        List of AgentSummary with full schemas (both library and marketplace agents)
+        List of AgentSummary, library agents first (with full schemas),
+        then marketplace agents (basic info only)
    """
    agents: list[AgentSummary] = []
    seen_graph_ids: set[str] = set()
@@ -379,11 +365,16 @@ async def get_all_relevant_agents_for_generation(
        search_query=search_query,
        max_results=max_marketplace_results,
    )
+    library_names: set[str] = set()
+    for a in agents:
+        name = a.get("name")
+        if name and isinstance(name, str):
+            library_names.add(name.lower())
    for agent in marketplace_agents:
-        graph_id = agent.get("graph_id")
-        if graph_id and graph_id not in seen_graph_ids:
-            agents.append(agent)
-            seen_graph_ids.add(graph_id)
+        agent_name = agent.get("name")
+        if agent_name and isinstance(agent_name, str):
+            if agent_name.lower() not in library_names:
+                agents.append(agent)

    return agents

@@ -139,10 +139,11 @@ async def decompose_goal_external(
    """
    client = _get_client()

-    if context:
-        description = f"{description}\n\nAdditional context from user:\n{context}"
-
+    # Build the request payload
    payload: dict[str, Any] = {"description": description}
+    if context:
+        # The external service uses user_instruction for additional context
+        payload["user_instruction"] = context
    if library_agents:
        payload["library_agents"] = library_agents

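The change above moves the extra user context out of the `description` text and into a dedicated `user_instruction` field of the request payload. A short sketch of the difference, using made-up example strings:

```python
description = "Build a weekly newsletter agent"
context = "Pull articles from my RSS feeds"

# Old behaviour: context appended to the description text
old_payload = {
    "description": f"{description}\n\nAdditional context from user:\n{context}"
}

# New behaviour: context sent separately as user_instruction
new_payload = {"description": description, "user_instruction": context}
```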
@@ -112,7 +112,6 @@ async def get_store_agents(
                description=agent["description"],
                runs=agent["runs"],
                rating=agent["rating"],
-                agent_graph_id=agent.get("agentGraphId", ""),
            )
            store_agents.append(store_agent)
        except Exception as e:
@@ -171,7 +170,6 @@ async def get_store_agents(
                description=agent.description,
                runs=agent.runs,
                rating=agent.rating,
-                agent_graph_id=agent.agentGraphId,
            )
            # Add to the list only if creation was successful
            store_agents.append(store_agent)
@@ -600,7 +600,6 @@ async def hybrid_search(
                sa.featured,
                sa.is_available,
                sa.updated_at,
-                sa."agentGraphId",
                -- Searchable text for BM25 reranking
                COALESCE(sa.agent_name, '') || ' ' || COALESCE(sa.sub_heading, '') || ' ' || COALESCE(sa.description, '') as searchable_text,
                -- Semantic score
@@ -660,7 +659,6 @@ async def hybrid_search(
                featured,
                is_available,
                updated_at,
-                "agentGraphId",
                searchable_text,
                semantic_score,
                lexical_score,
@@ -38,7 +38,6 @@ class StoreAgent(pydantic.BaseModel):
    description: str
    runs: int
    rating: float
-    agent_graph_id: str


class StoreAgentsResponse(pydantic.BaseModel):
@@ -26,13 +26,11 @@ def test_store_agent():
        description="Test description",
        runs=50,
        rating=4.5,
-        agent_graph_id="test-graph-id",
    )
    assert agent.slug == "test-agent"
    assert agent.agent_name == "Test Agent"
    assert agent.runs == 50
    assert agent.rating == 4.5
-    assert agent.agent_graph_id == "test-graph-id"


def test_store_agents_response():
@@ -48,7 +46,6 @@ def test_store_agents_response():
                description="Test description",
                runs=50,
                rating=4.5,
-                agent_graph_id="test-graph-id",
            )
        ],
        pagination=store_model.Pagination(
@@ -82,7 +82,6 @@ def test_get_agents_featured(
                description="Featured agent description",
                runs=100,
                rating=4.5,
-                agent_graph_id="test-graph-1",
            )
        ],
        pagination=store_model.Pagination(
@@ -128,7 +127,6 @@ def test_get_agents_by_creator(
                description="Creator agent description",
                runs=50,
                rating=4.0,
-                agent_graph_id="test-graph-2",
            )
        ],
        pagination=store_model.Pagination(
@@ -174,7 +172,6 @@ def test_get_agents_sorted(
                description="Top agent description",
                runs=1000,
                rating=5.0,
-                agent_graph_id="test-graph-3",
            )
        ],
        pagination=store_model.Pagination(
@@ -220,7 +217,6 @@ def test_get_agents_search(
                description="Specific search term description",
                runs=75,
                rating=4.2,
-                agent_graph_id="test-graph-search",
            )
        ],
        pagination=store_model.Pagination(
@@ -266,7 +262,6 @@ def test_get_agents_category(
                description="Category agent description",
                runs=60,
                rating=4.1,
-                agent_graph_id="test-graph-category",
            )
        ],
        pagination=store_model.Pagination(
@@ -311,7 +306,6 @@ def test_get_agents_pagination(
                description=f"Agent {i} description",
                runs=i * 10,
                rating=4.0,
-                agent_graph_id="test-graph-2",
            )
            for i in range(5)
        ],
@@ -33,7 +33,6 @@ class TestCacheDeletion:
                description="Test description",
                runs=100,
                rating=4.5,
-                agent_graph_id="test-graph-id",
            )
        ],
        pagination=Pagination(
@@ -66,24 +66,18 @@ async def event_broadcaster(manager: ConnectionManager):
    execution_bus = AsyncRedisExecutionEventBus()
    notification_bus = AsyncRedisNotificationEventBus()

-    try:
-        async def execution_worker():
-            async for event in execution_bus.listen("*"):
-                await manager.send_execution_update(event)
-
-        async def notification_worker():
-            async for notification in notification_bus.listen("*"):
-                await manager.send_notification(
-                    user_id=notification.user_id,
-                    payload=notification.payload,
-                )
-
-        await asyncio.gather(execution_worker(), notification_worker())
-    finally:
-        # Ensure PubSub connections are closed on any exit to prevent leaks
-        await execution_bus.close()
-        await notification_bus.close()
+    async def execution_worker():
+        async for event in execution_bus.listen("*"):
+            await manager.send_execution_update(event)
+
+    async def notification_worker():
+        async for notification in notification_bus.listen("*"):
+            await manager.send_notification(
+                user_id=notification.user_id,
+                payload=notification.payload,
+            )
+
+    await asyncio.gather(execution_worker(), notification_worker())


async def authenticate_websocket(websocket: WebSocket) -> str:
@@ -133,23 +133,10 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):


class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
-    def __init__(self):
-        self._pubsub: AsyncPubSub | None = None
-
    @property
    async def connection(self) -> redis.AsyncRedis:
        return await redis.get_redis_async()

-    async def close(self) -> None:
-        """Close the PubSub connection if it exists."""
-        if self._pubsub is not None:
-            try:
-                await self._pubsub.close()
-            except Exception:
-                logger.warning("Failed to close PubSub connection", exc_info=True)
-            finally:
-                self._pubsub = None
-
    async def publish_event(self, event: M, channel_key: str):
        """
        Publish an event to Redis. Gracefully handles connection failures
@@ -170,7 +157,6 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
            await self.connection, channel_key
        )
        assert isinstance(pubsub, AsyncPubSub)
-        self._pubsub = pubsub

        if "*" in channel_key:
            await pubsub.psubscribe(full_channel_name)
@@ -1028,39 +1028,6 @@ async def get_graph(
    return GraphModel.from_db(graph, for_export)


-async def get_store_listed_graphs(*graph_ids: str) -> dict[str, GraphModel]:
-    """Batch-fetch multiple store-listed graphs by their IDs.
-
-    Only returns graphs that have approved store listings (publicly available).
-    Does not require permission checks since store-listed graphs are public.
-
-    Args:
-        *graph_ids: Variable number of graph IDs to fetch
-
-    Returns:
-        Dict mapping graph_id to GraphModel for graphs with approved store listings
-    """
-    if not graph_ids:
-        return {}
-
-    store_listings = await StoreListingVersion.prisma().find_many(
-        where={
-            "agentGraphId": {"in": list(graph_ids)},
-            "submissionStatus": SubmissionStatus.APPROVED,
-            "isDeleted": False,
-        },
-        include={"AgentGraph": {"include": AGENT_GRAPH_INCLUDE}},
-        distinct=["agentGraphId"],
-        order={"agentGraphVersion": "desc"},
-    )
-
-    return {
-        listing.agentGraphId: GraphModel.from_db(listing.AgentGraph)
-        for listing in store_listings
-        if listing.AgentGraph
-    }
-
-
async def get_graph_as_admin(
    graph_id: str,
    version: int | None = None,
@@ -1,39 +0,0 @@
-from urllib.parse import urlparse
-
-import fastapi
-from fastapi.routing import APIRoute
-
-from backend.api.features.integrations.router import router as integrations_router
-from backend.integrations.providers import ProviderName
-from backend.integrations.webhooks import utils as webhooks_utils
-
-
-def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
-    app = fastapi.FastAPI()
-    app.include_router(integrations_router, prefix="/api/integrations")
-
-    provider = ProviderName.GITHUB
-    webhook_id = "webhook_123"
-    base_url = "https://example.com"
-
-    monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)
-
-    route = next(
-        route
-        for route in integrations_router.routes
-        if isinstance(route, APIRoute)
-        and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
-        and "POST" in route.methods
-    )
-    expected_path = f"/api/integrations{route.path}".format(
-        provider=provider.value,
-        webhook_id=webhook_id,
-    )
-    actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
-    expected_base = urlparse(base_url)
-
-    assert (actual_url.scheme, actual_url.netloc) == (
-        expected_base.scheme,
-        expected_base.netloc,
-    )
-    assert actual_url.path == expected_path
@@ -9,8 +9,7 @@
       "sub_heading": "Creator agent subheading",
       "description": "Creator agent description",
       "runs": 50,
-      "rating": 4.0,
-      "agent_graph_id": "test-graph-2"
+      "rating": 4.0
     }
   ],
   "pagination": {
@@ -9,8 +9,7 @@
       "sub_heading": "Category agent subheading",
       "description": "Category agent description",
       "runs": 60,
-      "rating": 4.1,
-      "agent_graph_id": "test-graph-category"
+      "rating": 4.1
     }
   ],
   "pagination": {
@@ -9,8 +9,7 @@
       "sub_heading": "Agent 0 subheading",
       "description": "Agent 0 description",
       "runs": 0,
-      "rating": 4.0,
-      "agent_graph_id": "test-graph-2"
+      "rating": 4.0
     },
     {
       "slug": "agent-1",
@@ -21,8 +20,7 @@
       "sub_heading": "Agent 1 subheading",
       "description": "Agent 1 description",
       "runs": 10,
-      "rating": 4.0,
-      "agent_graph_id": "test-graph-2"
+      "rating": 4.0
     },
     {
       "slug": "agent-2",
@@ -33,8 +31,7 @@
       "sub_heading": "Agent 2 subheading",
       "description": "Agent 2 description",
       "runs": 20,
-      "rating": 4.0,
-      "agent_graph_id": "test-graph-2"
+      "rating": 4.0
     },
     {
       "slug": "agent-3",
@@ -45,8 +42,7 @@
       "sub_heading": "Agent 3 subheading",
       "description": "Agent 3 description",
       "runs": 30,
-      "rating": 4.0,
-      "agent_graph_id": "test-graph-2"
+      "rating": 4.0
     },
     {
       "slug": "agent-4",
@@ -57,8 +53,7 @@
       "sub_heading": "Agent 4 subheading",
       "description": "Agent 4 description",
       "runs": 40,
-      "rating": 4.0,
-      "agent_graph_id": "test-graph-2"
+      "rating": 4.0
     }
   ],
   "pagination": {
@@ -9,8 +9,7 @@
       "sub_heading": "Search agent subheading",
       "description": "Specific search term description",
       "runs": 75,
-      "rating": 4.2,
-      "agent_graph_id": "test-graph-search"
+      "rating": 4.2
     }
   ],
   "pagination": {
@@ -9,8 +9,7 @@
       "sub_heading": "Top agent subheading",
       "description": "Top agent description",
       "runs": 1000,
-      "rating": 5.0,
-      "agent_graph_id": "test-graph-3"
+      "rating": 5.0
     }
   ],
   "pagination": {
@@ -9,8 +9,7 @@
       "sub_heading": "Featured agent subheading",
       "description": "Featured agent description",
       "runs": 100,
-      "rating": 4.5,
-      "agent_graph_id": "test-graph-1"
+      "rating": 4.5
     }
   ],
   "pagination": {
@@ -134,28 +134,15 @@ class TestSearchMarketplaceAgentsForGeneration:
                 description="A public agent",
                 sub_heading="Does something useful",
                 creator="creator-1",
-                agent_graph_id="graph-123",
             )
         ]

-        mock_graph = MagicMock()
-        mock_graph.id = "graph-123"
-        mock_graph.version = 1
-        mock_graph.input_schema = {"type": "object"}
-        mock_graph.output_schema = {"type": "object"}
-
-        with (
-            patch(
-                "backend.api.features.store.db.get_store_agents",
-                new_callable=AsyncMock,
-                return_value=mock_response,
-            ) as mock_search,
-            patch(
-                "backend.api.features.chat.tools.agent_generator.core.get_store_listed_graphs",
-                new_callable=AsyncMock,
-                return_value={"graph-123": mock_graph},
-            ),
-        ):
+        # The store_db is dynamically imported, so patch the import path
+        with patch(
+            "backend.api.features.store.db.get_store_agents",
+            new_callable=AsyncMock,
+            return_value=mock_response,
+        ) as mock_search:
             result = await core.search_marketplace_agents_for_generation(
                 search_query="automation",
                 max_results=10,
@@ -169,7 +156,7 @@ class TestSearchMarketplaceAgentsForGeneration:

         assert len(result) == 1
         assert result[0]["name"] == "Public Agent"
-        assert result[0]["graph_id"] == "graph-123"
+        assert result[0]["is_marketplace_agent"] is True

     @pytest.mark.asyncio
     async def test_handles_marketplace_error_gracefully(self):
@@ -206,12 +193,11 @@ class TestGetAllRelevantAgentsForGeneration:

         marketplace_agents = [
             {
-                "graph_id": "market-456",
-                "graph_version": 1,
                 "name": "Market Agent",
                 "description": "From marketplace",
-                "input_schema": {},
-                "output_schema": {},
+                "sub_heading": "Sub heading",
+                "creator": "creator-1",
+                "is_marketplace_agent": True,
             }
         ]

@@ -239,11 +225,11 @@ class TestGetAllRelevantAgentsForGeneration:
         assert result[1]["name"] == "Market Agent"

     @pytest.mark.asyncio
-    async def test_deduplicates_by_graph_id(self):
-        """Test that marketplace agents with same graph_id as library are excluded."""
+    async def test_deduplicates_by_name(self):
+        """Test that marketplace agents with same name as library are excluded."""
         library_agents = [
             {
-                "graph_id": "shared-123",
+                "graph_id": "lib-123",
                 "graph_version": 1,
                 "name": "Shared Agent",
                 "description": "From library",
@@ -254,20 +240,18 @@ class TestGetAllRelevantAgentsForGeneration:

         marketplace_agents = [
             {
-                "graph_id": "shared-123",  # Same graph_id, should be deduplicated
-                "graph_version": 1,
-                "name": "Shared Agent",
+                "name": "Shared Agent",  # Same name, should be deduplicated
                 "description": "From marketplace",
-                "input_schema": {},
-                "output_schema": {},
+                "sub_heading": "Sub heading",
+                "creator": "creator-1",
+                "is_marketplace_agent": True,
             },
             {
-                "graph_id": "unique-456",
-                "graph_version": 1,
                 "name": "Unique Agent",
                 "description": "Only in marketplace",
-                "input_schema": {},
-                "output_schema": {},
+                "sub_heading": "Sub heading",
+                "creator": "creator-2",
+                "is_marketplace_agent": True,
             },
         ]

@@ -289,7 +273,7 @@ class TestGetAllRelevantAgentsForGeneration:
             include_marketplace=True,
         )

-        # Shared Agent from marketplace should be excluded by graph_id
+        # Shared Agent from marketplace should be excluded
         assert len(result) == 2
         names = [a["name"] for a in result]
         assert "Shared Agent" in names
@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:

     @pytest.mark.asyncio
     async def test_decompose_goal_with_context(self):
-        """Test decomposition with additional context enriched into description."""
+        """Test decomposition with additional context."""
         mock_response = MagicMock()
         mock_response.json.return_value = {
             "success": True,
@@ -119,12 +119,9 @@ class TestDecomposeGoalExternal:
             "Build a chatbot", context="Use Python"
         )

-        expected_description = (
-            "Build a chatbot\n\nAdditional context from user:\nUse Python"
-        )
         mock_client.post.assert_called_once_with(
             "/api/decompose-description",
-            json={"description": expected_description},
+            json={"description": "Build a chatbot", "user_instruction": "Use Python"},
         )

     @pytest.mark.asyncio
@@ -9833,8 +9833,7 @@
           "sub_heading": { "type": "string", "title": "Sub Heading" },
           "description": { "type": "string", "title": "Description" },
           "runs": { "type": "integer", "title": "Runs" },
-          "rating": { "type": "number", "title": "Rating" },
-          "agent_graph_id": { "type": "string", "title": "Agent Graph Id" }
+          "rating": { "type": "number", "title": "Rating" }
         },
         "type": "object",
         "required": [
@@ -9846,8 +9845,7 @@
           "sub_heading",
           "description",
           "runs",
-          "rating",
-          "agent_graph_id"
+          "rating"
         ],
         "title": "StoreAgent"
       },