Mirror of https://github.com/Significant-Gravitas/AutoGPT.git

Compare commits: dev...ntindle/fi (9 commits)
| SHA1 |
|---|
| b7968a560d |
| 56788a0226 |
| 6119900f44 |
| 9688a507a3 |
| 3a38bb2338 |
| 683fff4fab |
| ccb537696f |
| 7838855da9 |
| 3ec8fd912f |
@@ -3,8 +3,7 @@ import logging
 import time
 from asyncio import CancelledError
 from collections.abc import AsyncGenerator
-from dataclasses import dataclass
-from typing import Any, cast
+from typing import Any
 
 import openai
 import orjson
@@ -16,14 +15,7 @@ from openai import (
     PermissionDeniedError,
     RateLimitError,
 )
-from openai.types.chat import (
-    ChatCompletionAssistantMessageParam,
-    ChatCompletionChunk,
-    ChatCompletionMessageParam,
-    ChatCompletionStreamOptionsParam,
-    ChatCompletionSystemMessageParam,
-    ChatCompletionToolParam,
-)
+from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
 
 from backend.data.redis_client import get_redis_async
 from backend.data.understanding import (
@@ -31,7 +23,6 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
-from backend.util.prompt import estimate_token_count
 from backend.util.settings import Settings
 
 from . import db as chat_db
@@ -803,201 +794,6 @@ def _is_region_blocked_error(error: Exception) -> bool:
     return "not available in your region" in str(error).lower()
 
 
-# Context window management constants
-TOKEN_THRESHOLD = 120_000
-KEEP_RECENT_MESSAGES = 15
-
-
-@dataclass
-class ContextWindowResult:
-    """Result of context window management."""
-
-    messages: list[dict[str, Any]]
-    token_count: int
-    was_compacted: bool
-    error: str | None = None
-
-
-def _messages_to_dicts(messages: list) -> list[dict[str, Any]]:
-    """Convert message objects to dicts, filtering None values.
-
-    Handles both TypedDict (dict-like) and other message formats.
-    """
-    result = []
-    for msg in messages:
-        if msg is None:
-            continue
-        if isinstance(msg, dict):
-            msg_dict = {k: v for k, v in msg.items() if v is not None}
-        else:
-            msg_dict = dict(msg)
-        result.append(msg_dict)
-    return result
-
-
-async def _manage_context_window(
-    messages: list,
-    model: str,
-    api_key: str | None = None,
-    base_url: str | None = None,
-) -> ContextWindowResult:
-    """
-    Manage context window by summarizing old messages if token count exceeds threshold.
-
-    This function handles context compaction for LLM calls by:
-    1. Counting tokens in the message list
-    2. If over threshold, summarizing old messages while keeping recent ones
-    3. Ensuring tool_call/tool_response pairs stay intact
-    4. Progressively reducing message count if still over limit
-
-    Args:
-        messages: List of messages in OpenAI format (with system prompt if present)
-        model: Model name for token counting
-        api_key: API key for summarization calls
-        base_url: Base URL for summarization calls
-
-    Returns:
-        ContextWindowResult with compacted messages and metadata
-    """
-    if not messages:
-        return ContextWindowResult([], 0, False, "No messages to compact")
-
-    messages_dict = _messages_to_dicts(messages)
-
-    # Normalize model name for token counting (tiktoken only supports OpenAI models)
-    token_count_model = model.split("/")[-1] if "/" in model else model
-    if "claude" in token_count_model.lower() or not any(
-        known in token_count_model.lower()
-        for known in ["gpt", "o1", "chatgpt", "text-"]
-    ):
-        token_count_model = "gpt-4o"
-
-    try:
-        token_count = estimate_token_count(messages_dict, model=token_count_model)
-    except Exception as e:
-        logger.warning(f"Token counting failed: {e}. Using gpt-4o approximation.")
-        token_count_model = "gpt-4o"
-        token_count = estimate_token_count(messages_dict, model=token_count_model)
-
-    if token_count <= TOKEN_THRESHOLD:
-        return ContextWindowResult(messages, token_count, False)
-
-    has_system_prompt = messages[0].get("role") == "system"
-    slice_start = max(0, len(messages_dict) - KEEP_RECENT_MESSAGES)
-    recent_messages = _ensure_tool_pairs_intact(
-        messages_dict[-KEEP_RECENT_MESSAGES:], messages_dict, slice_start
-    )
-
-    # Determine old messages to summarize (explicit bounds to avoid slice edge cases)
-    system_msg = messages[0] if has_system_prompt else None
-    if has_system_prompt:
-        old_messages_dict = (
-            messages_dict[1:-KEEP_RECENT_MESSAGES]
-            if len(messages_dict) > KEEP_RECENT_MESSAGES + 1
-            else []
-        )
-    else:
-        old_messages_dict = (
-            messages_dict[:-KEEP_RECENT_MESSAGES]
-            if len(messages_dict) > KEEP_RECENT_MESSAGES
-            else []
-        )
-
-    # Try to summarize old messages, fall back to truncation on failure
-    summary_msg = None
-    if old_messages_dict:
-        try:
-            summary_text = await _summarize_messages(
-                old_messages_dict, model=model, api_key=api_key, base_url=base_url
-            )
-            summary_msg = ChatCompletionAssistantMessageParam(
-                role="assistant",
-                content=f"[Previous conversation summary — for context only]: {summary_text}",
-            )
-            base = [system_msg, summary_msg] if has_system_prompt else [summary_msg]
-            messages = base + recent_messages
-            logger.info(
-                f"Context summarized: {token_count} tokens, "
-                f"summarized {len(old_messages_dict)} msgs, kept {KEEP_RECENT_MESSAGES}"
-            )
-        except Exception as e:
-            logger.warning(f"Summarization failed, falling back to truncation: {e}")
-            messages = (
-                [system_msg] + recent_messages if has_system_prompt else recent_messages
-            )
-    else:
-        logger.warning(
-            f"Token count {token_count} exceeds threshold but no old messages to summarize"
-        )
-
-    new_token_count = estimate_token_count(
-        _messages_to_dicts(messages), model=token_count_model
-    )
-
-    # Progressive truncation if still over limit
-    if new_token_count > TOKEN_THRESHOLD:
-        logger.warning(
-            f"Still over limit: {new_token_count} tokens. Reducing messages."
-        )
-        base_msgs = (
-            recent_messages
-            if old_messages_dict
-            else (messages_dict[1:] if has_system_prompt else messages_dict)
-        )
-
-        def build_messages(recent: list) -> list:
-            """Build message list with optional system prompt and summary."""
-            prefix = []
-            if has_system_prompt and system_msg:
-                prefix.append(system_msg)
-            if summary_msg:
-                prefix.append(summary_msg)
-            return prefix + recent
-
-        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
-            if keep_count == 0:
-                messages = build_messages([])
-                if not messages:
-                    continue
-            elif len(base_msgs) < keep_count:
-                continue
-            else:
-                reduced = _ensure_tool_pairs_intact(
-                    base_msgs[-keep_count:],
-                    base_msgs,
-                    max(0, len(base_msgs) - keep_count),
-                )
-                messages = build_messages(reduced)
-
-            new_token_count = estimate_token_count(
-                _messages_to_dicts(messages), model=token_count_model
-            )
-            if new_token_count <= TOKEN_THRESHOLD:
-                logger.info(
-                    f"Reduced to {keep_count} messages, {new_token_count} tokens"
-                )
-                break
-        else:
-            logger.error(
-                f"Cannot reduce below threshold. Final: {new_token_count} tokens"
-            )
-            if has_system_prompt and len(messages) > 1:
-                messages = messages[1:]
-                logger.critical("Dropped system prompt as last resort")
-                return ContextWindowResult(
-                    messages, new_token_count, True, "System prompt dropped"
-                )
-            # No system prompt to drop - return error so callers don't proceed with oversized context
-            return ContextWindowResult(
-                messages,
-                new_token_count,
-                True,
-                "Unable to reduce context below token limit",
-            )
-
-    return ContextWindowResult(messages, new_token_count, True)
-
-
 async def _summarize_messages(
     messages: list,
     model: str,
@@ -1226,8 +1022,11 @@ async def _stream_chat_chunks(
 
     logger.info("Starting pure chat stream")
 
+    # Build messages with system prompt prepended
     messages = session.to_openai_messages()
     if system_prompt:
+        from openai.types.chat import ChatCompletionSystemMessageParam
+
         system_message = ChatCompletionSystemMessageParam(
             role="system",
             content=system_prompt,
@@ -1235,16 +1034,204 @@ async def _stream_chat_chunks(
         messages = [system_message] + messages
 
     # Apply context window management
-    context_result = await _manage_context_window(
-        messages=messages,
+    token_count = 0  # Initialize for exception handler
+    try:
+        from backend.util.prompt import estimate_token_count
+
+        # Convert to dict for token counting
+        # OpenAI message types are TypedDicts, so they're already dict-like
+        messages_dict = []
+        for msg in messages:
+            # TypedDict objects are already dicts, just filter None values
+            if isinstance(msg, dict):
+                msg_dict = {k: v for k, v in msg.items() if v is not None}
+            else:
+                # Fallback for unexpected types
+                msg_dict = dict(msg)
+            messages_dict.append(msg_dict)
+
+        # Estimate tokens using appropriate tokenizer
+        # Normalize model name for token counting (tiktoken only supports OpenAI models)
+        token_count_model = model
+        if "/" in model:
+            # Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5")
+            token_count_model = model.split("/")[-1]
+
+        # For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer
+        # Most modern LLMs have similar tokenization (~1 token per 4 chars)
+        if "claude" in token_count_model.lower() or not any(
+            known in token_count_model.lower()
+            for known in ["gpt", "o1", "chatgpt", "text-"]
+        ):
+            token_count_model = "gpt-4o"
+
+        # Attempt token counting with error handling
+        try:
+            token_count = estimate_token_count(messages_dict, model=token_count_model)
+        except Exception as token_error:
+            # If token counting fails, use gpt-4o as fallback approximation
+            logger.warning(
+                f"Token counting failed for model {token_count_model}: {token_error}. "
+                "Using gpt-4o approximation."
+            )
+            token_count = estimate_token_count(messages_dict, model="gpt-4o")
+
+        # If over threshold, summarize old messages
+        if token_count > 120_000:
+            KEEP_RECENT = 15
+
+            # Check if we have a system prompt at the start
+            has_system_prompt = (
+                len(messages) > 0 and messages[0].get("role") == "system"
+            )
+
+            # Always attempt mitigation when over limit, even with few messages
+            if messages:
+                # Split messages based on whether system prompt exists
+                # Calculate start index for the slice
+                slice_start = max(0, len(messages_dict) - KEEP_RECENT)
+                recent_messages = messages_dict[-KEEP_RECENT:]
+
+                # Ensure tool_call/tool_response pairs stay together
+                # This prevents API errors from orphan tool responses
+                recent_messages = _ensure_tool_pairs_intact(
+                    recent_messages, messages_dict, slice_start
+                )
+
+                if has_system_prompt:
+                    # Keep system prompt separate, summarize everything between system and recent
+                    system_msg = messages[0]
+                    old_messages_dict = messages_dict[1:-KEEP_RECENT]
+                else:
+                    # No system prompt, summarize everything except recent
+                    system_msg = None
+                    old_messages_dict = messages_dict[:-KEEP_RECENT]
+
+                # Summarize any non-empty old messages (no minimum threshold)
+                # If we're over the token limit, we need to compress whatever we can
+                if old_messages_dict:
+                    # Summarize old messages using the same model as chat
+                    summary_text = await _summarize_messages(
+                        old_messages_dict,
                         model=model,
                         api_key=config.api_key,
                         base_url=config.base_url,
                     )
 
-    if context_result.error:
-        if "System prompt dropped" in context_result.error:
-            # Warning only - continue with reduced context
+                    # Build new message list
+                    # Use assistant role (not system) to prevent privilege escalation
+                    # of user-influenced content to instruction-level authority
+                    from openai.types.chat import ChatCompletionAssistantMessageParam
+
+                    summary_msg = ChatCompletionAssistantMessageParam(
+                        role="assistant",
+                        content=(
+                            "[Previous conversation summary — for context only]: "
+                            f"{summary_text}"
+                        ),
+                    )
+
+                    # Rebuild messages based on whether we have a system prompt
+                    if has_system_prompt:
+                        # system_prompt + summary + recent_messages
+                        messages = [system_msg, summary_msg] + recent_messages
+                    else:
+                        # summary + recent_messages (no original system prompt)
+                        messages = [summary_msg] + recent_messages
+
+                    logger.info(
+                        f"Context summarized: {token_count} tokens, "
+                        f"summarized {len(old_messages_dict)} old messages, "
+                        f"kept last {KEEP_RECENT} messages"
+                    )
+
+                    # Fallback: If still over limit after summarization, progressively drop recent messages
+                    # This handles edge cases where recent messages are extremely large
+                    new_messages_dict = []
+                    for msg in messages:
+                        if isinstance(msg, dict):
+                            msg_dict = {k: v for k, v in msg.items() if v is not None}
+                        else:
+                            msg_dict = dict(msg)
+                        new_messages_dict.append(msg_dict)
+
+                    new_token_count = estimate_token_count(
+                        new_messages_dict, model=token_count_model
+                    )
+
+                    if new_token_count > 120_000:
+                        # Still over limit - progressively reduce KEEP_RECENT
+                        logger.warning(
+                            f"Still over limit after summarization: {new_token_count} tokens. "
+                            "Reducing number of recent messages kept."
+                        )
+
+                        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
+                            if keep_count == 0:
+                                # Try with just system prompt + summary (no recent messages)
+                                if has_system_prompt:
+                                    messages = [system_msg, summary_msg]
+                                else:
+                                    messages = [summary_msg]
+                                logger.info(
+                                    "Trying with 0 recent messages (system + summary only)"
+                                )
+                            else:
+                                # Slice from ORIGINAL recent_messages to avoid duplicating summary
+                                reduced_recent = (
+                                    recent_messages[-keep_count:]
+                                    if len(recent_messages) >= keep_count
+                                    else recent_messages
+                                )
+                                # Ensure tool pairs stay intact in the reduced slice
+                                reduced_slice_start = max(
+                                    0, len(recent_messages) - keep_count
+                                )
+                                reduced_recent = _ensure_tool_pairs_intact(
+                                    reduced_recent, recent_messages, reduced_slice_start
+                                )
+                                if has_system_prompt:
+                                    messages = [
+                                        system_msg,
+                                        summary_msg,
+                                    ] + reduced_recent
+                                else:
+                                    messages = [summary_msg] + reduced_recent
+
+                            new_messages_dict = []
+                            for msg in messages:
+                                if isinstance(msg, dict):
+                                    msg_dict = {
+                                        k: v for k, v in msg.items() if v is not None
+                                    }
+                                else:
+                                    msg_dict = dict(msg)
+                                new_messages_dict.append(msg_dict)
+
+                            new_token_count = estimate_token_count(
+                                new_messages_dict, model=token_count_model
+                            )
+
+                            if new_token_count <= 120_000:
+                                logger.info(
+                                    f"Reduced to {keep_count} recent messages, "
+                                    f"now {new_token_count} tokens"
+                                )
+                                break
+                        else:
+                            logger.error(
+                                f"Unable to reduce token count below threshold even with 0 messages. "
+                                f"Final count: {new_token_count} tokens"
+                            )
+                            # ABSOLUTE LAST RESORT: Drop system prompt
+                            # This should only happen if summary itself is massive
+                            if has_system_prompt and len(messages) > 1:
+                                messages = messages[1:]  # Drop system prompt
+                                logger.critical(
+                                    "CRITICAL: Dropped system prompt as absolute last resort. "
+                                    "Behavioral consistency may be affected."
+                                )
+                                # Yield error to user
                                 yield StreamError(
                                     errorText=(
                                         "Warning: System prompt dropped due to size constraints. "
@@ -1252,21 +1239,109 @@ async def _stream_chat_chunks(
                                     )
                                 )
                 else:
-                    # Any other error - abort to prevent failed LLM calls
+                    # No old messages to summarize - all messages are "recent"
+                    # Apply progressive truncation to reduce token count
+                    logger.warning(
+                        f"Token count {token_count} exceeds threshold but no old messages to summarize. "
+                        f"Applying progressive truncation to recent messages."
+                    )
+
+                    # Create a base list excluding system prompt to avoid duplication
+                    # This is the pool of messages we'll slice from in the loop
+                    # Use messages_dict for type consistency with _ensure_tool_pairs_intact
+                    base_msgs = (
+                        messages_dict[1:] if has_system_prompt else messages_dict
+                    )
+
+                    # Try progressively smaller keep counts
+                    new_token_count = token_count  # Initialize with current count
+                    for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
+                        if keep_count == 0:
+                            # Try with just system prompt (no recent messages)
+                            if has_system_prompt:
+                                messages = [system_msg]
+                                logger.info(
+                                    "Trying with 0 recent messages (system prompt only)"
+                                )
+                            else:
+                                # No system prompt and no recent messages = empty messages list
+                                # This is invalid, skip this iteration
+                                continue
+                        else:
+                            if len(base_msgs) < keep_count:
+                                continue  # Skip if we don't have enough messages
+
+                            # Slice from base_msgs to get recent messages (without system prompt)
+                            recent_messages = base_msgs[-keep_count:]
+
+                            # Ensure tool pairs stay intact in the reduced slice
+                            reduced_slice_start = max(0, len(base_msgs) - keep_count)
+                            recent_messages = _ensure_tool_pairs_intact(
+                                recent_messages, base_msgs, reduced_slice_start
+                            )
+
+                            if has_system_prompt:
+                                messages = [system_msg] + recent_messages
+                            else:
+                                messages = recent_messages
+
+                        new_messages_dict = []
+                        for msg in messages:
+                            if msg is None:
+                                continue  # Skip None messages (type safety)
+                            if isinstance(msg, dict):
+                                msg_dict = {
+                                    k: v for k, v in msg.items() if v is not None
+                                }
+                            else:
+                                msg_dict = dict(msg)
+                            new_messages_dict.append(msg_dict)
+
+                        new_token_count = estimate_token_count(
+                            new_messages_dict, model=token_count_model
+                        )
+
+                        if new_token_count <= 120_000:
+                            logger.info(
+                                f"Reduced to {keep_count} recent messages, "
+                                f"now {new_token_count} tokens"
+                            )
+                            break
+                    else:
+                        # Even with 0 messages still over limit
+                        logger.error(
+                            f"Unable to reduce token count below threshold even with 0 messages. "
+                            f"Final count: {new_token_count} tokens. Messages may be extremely large."
+                        )
+                        # ABSOLUTE LAST RESORT: Drop system prompt
+                        if has_system_prompt and len(messages) > 1:
+                            messages = messages[1:]  # Drop system prompt
+                            logger.critical(
+                                "CRITICAL: Dropped system prompt as absolute last resort. "
+                                "Behavioral consistency may be affected."
+                            )
+                            # Yield error to user
                             yield StreamError(
                                 errorText=(
-                                    f"Context window management failed: {context_result.error}. "
-                                    "Please start a new conversation."
+                                    "Warning: System prompt dropped due to size constraints. "
+                                    "Assistant behavior may be affected."
+                                )
+                            )
+
+    except Exception as e:
+        logger.error(f"Context summarization failed: {e}", exc_info=True)
+        # If we were over the token limit, yield error to user
+        # Don't silently continue with oversized messages that will fail
+        if token_count > 120_000:
+            yield StreamError(
+                errorText=(
+                    f"Unable to manage context window (token limit exceeded: {token_count} tokens). "
+                    "Context summarization failed. Please start a new conversation."
                 )
             )
             yield StreamFinish()
             return
-    messages = context_result.messages
-    if context_result.was_compacted:
-        logger.info(
-            f"Context compacted for streaming: {context_result.token_count} tokens"
-        )
+    # Otherwise, continue with original messages (under limit)
 
     # Loop to handle tool calls and continue conversation
     while True:
@@ -1294,6 +1369,14 @@ async def _stream_chat_chunks(
                 :128
             ]  # OpenRouter limit
 
+        # Create the stream with proper types
+        from typing import cast
+
+        from openai.types.chat import (
+            ChatCompletionMessageParam,
+            ChatCompletionStreamOptionsParam,
+        )
+
         stream = await client.chat.completions.create(
             model=model,
             messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1817,36 +1900,17 @@ async def _generate_llm_continuation(
     # Build system prompt
     system_prompt, _ = await _build_system_prompt(user_id)
 
+    # Build messages in OpenAI format
     messages = session.to_openai_messages()
     if system_prompt:
+        from openai.types.chat import ChatCompletionSystemMessageParam
+
         system_message = ChatCompletionSystemMessageParam(
             role="system",
             content=system_prompt,
         )
         messages = [system_message] + messages
 
-    # Apply context window management to prevent oversized requests
-    context_result = await _manage_context_window(
-        messages=messages,
-        model=config.model,
-        api_key=config.api_key,
-        base_url=config.base_url,
-    )
-
-    if context_result.error and "System prompt dropped" not in context_result.error:
-        logger.error(
-            f"Context window management failed for session {session_id}: "
-            f"{context_result.error} (tokens={context_result.token_count})"
-        )
-        return
-
-    messages = context_result.messages
-    if context_result.was_compacted:
-        logger.info(
-            f"Context compacted for LLM continuation: "
-            f"{context_result.token_count} tokens"
-        )
-
     # Build extra_body for tracing
     extra_body: dict[str, Any] = {
         "posthogProperties": {
@@ -1859,54 +1923,19 @@ async def _generate_llm_continuation(
     if session_id:
         extra_body["session_id"] = session_id[:128]
 
-    retry_count = 0
-    last_error: Exception | None = None
-    response = None
-
-    while retry_count <= MAX_RETRIES:
-        try:
-            logger.info(
-                f"Generating LLM continuation for session {session_id}"
-                f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
-            )
-
+    # Make non-streaming LLM call (no tools - just text response)
+    from typing import cast
+
+    from openai.types.chat import ChatCompletionMessageParam
+
+    # No tools parameter = text-only response (no tool calls)
     response = await client.chat.completions.create(
         model=config.model,
         messages=cast(list[ChatCompletionMessageParam], messages),
         extra_body=extra_body,
     )
-            last_error = None  # Clear any previous error on success
-            break  # Success, exit retry loop
-        except Exception as e:
-            last_error = e
-            if _is_retryable_error(e) and retry_count < MAX_RETRIES:
-                retry_count += 1
-                delay = min(
-                    BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
-                    MAX_DELAY_SECONDS,
-                )
-                logger.warning(
-                    f"Retryable error in LLM continuation: {e!s}. "
-                    f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
-                )
-                await asyncio.sleep(delay)
-                continue
-            else:
-                # Non-retryable error - log and exit gracefully
-                logger.error(
-                    f"Non-retryable error in LLM continuation: {e!s}",
-                    exc_info=True,
-                )
-                return
-
-    if last_error:
-        logger.error(
-            f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
-            f"Last error: {last_error!s}"
-        )
-        return
-
-    if response and response.choices and response.choices[0].message.content:
+    if response.choices and response.choices[0].message.content:
         assistant_content = response.choices[0].message.content
 
         # Reload session from DB to avoid race condition with user messages
@@ -139,10 +139,11 @@ async def decompose_goal_external(
     """
    client = _get_client()
 
-    if context:
-        description = f"{description}\n\nAdditional context from user:\n{context}"
-
+    # Build the request payload
     payload: dict[str, Any] = {"description": description}
+    if context:
+        # The external service uses user_instruction for additional context
+        payload["user_instruction"] = context
     if library_agents:
         payload["library_agents"] = library_agents
 
@@ -104,10 +104,52 @@ async def list_library_agents(
         order_by = {"updatedAt": "desc"}
 
     try:
+        # For LAST_EXECUTED sorting, we need to fetch execution data and sort in Python
+        # since Prisma doesn't support sorting by nested relations
+        if sort_by == library_model.LibraryAgentSort.LAST_EXECUTED:
+            # TODO: This fetches all agents into memory for sorting, which may cause
+            # performance issues for users with many agents. Prisma doesn't support
+            # sorting by nested relations, so a dedicated lastExecutedAt column or
+            # raw SQL query would be needed for database-level pagination.
             library_agents = await prisma.models.LibraryAgent.prisma().find_many(
                 where=where_clause,
                 include=library_agent_include(
-                    user_id, include_nodes=False, include_executions=include_executions
+                    user_id,
+                    include_nodes=False,
+                    include_executions=True,
+                    execution_limit=1,
+                ),
+            )
+
+            def get_sort_key(
+                agent: prisma.models.LibraryAgent,
+            ) -> tuple[int, float]:
+                """
+                Returns a tuple for sorting: (has_no_executions, -timestamp).
+
+                Agents WITH executions come first (sorted by most recent execution),
+                agents WITHOUT executions come last (sorted by creation date).
+                """
+                graph = agent.AgentGraph
+                if graph and graph.Executions and len(graph.Executions) > 0:
+                    execution = graph.Executions[0]
+                    timestamp = execution.updatedAt or execution.createdAt
+                    return (0, -timestamp.timestamp())
+                return (1, -agent.createdAt.timestamp())
+
+            library_agents.sort(key=get_sort_key)
+
+            # Apply pagination after sorting
+            agent_count = len(library_agents)
+            start_idx = (page - 1) * page_size
+            end_idx = start_idx + page_size
+            library_agents = library_agents[start_idx:end_idx]
+        else:
+            # Standard sorting via database
+            library_agents = await prisma.models.LibraryAgent.prisma().find_many(
+                where=where_clause,
+                include=library_agent_include(
+                    user_id, include_nodes=False, include_executions=False
                 ),
                 order=order_by,
                 skip=(page - 1) * page_size,
@@ -345,6 +387,20 @@ async def get_library_agent_by_graph_id(
     graph_id: str,
     graph_version: Optional[int] = None,
 ) -> library_model.LibraryAgent | None:
+    """
+    Retrieves a library agent by its graph ID for a given user.
+
+    Args:
+        user_id: The ID of the user who owns the library agent.
+        graph_id: The ID of the agent graph to look up.
+        graph_version: Optional specific version of the graph to retrieve.
+
+    Returns:
+        The LibraryAgent if found, otherwise None.
+
+    Raises:
+        DatabaseError: If there's an error during retrieval.
+    """
     try:
         filter: prisma.types.LibraryAgentWhereInput = {
             "agentGraphId": graph_id,
@@ -628,6 +684,17 @@ async def update_library_agent(
 async def delete_library_agent(
     library_agent_id: str, user_id: str, soft_delete: bool = True
 ) -> None:
+    """
+    Deletes a library agent and cleans up associated schedules and webhooks.
+
+    Args:
+        library_agent_id: The ID of the library agent to delete.
+        user_id: The ID of the user who owns the library agent.
+        soft_delete: If True, marks the agent as deleted; if False, permanently removes it.
+
+    Raises:
+        NotFoundError: If the library agent is not found or doesn't belong to the user.
+    """
     # First get the agent to find the graph_id for cleanup
     library_agent = await prisma.models.LibraryAgent.prisma().find_unique(
         where={"id": library_agent_id}, include={"AgentGraph": True}
@@ -1121,6 +1188,20 @@ async def update_preset(
 async def set_preset_webhook(
     user_id: str, preset_id: str, webhook_id: str | None
 ) -> library_model.LibraryAgentPreset:
+    """
+    Sets or removes a webhook connection for a preset.
+
+    Args:
+        user_id: The ID of the user who owns the preset.
+        preset_id: The ID of the preset to update.
+        webhook_id: The ID of the webhook to connect, or None to disconnect.
+
+    Returns:
+        The updated LibraryAgentPreset.
+
+    Raises:
+        NotFoundError: If the preset is not found or doesn't belong to the user.
+    """
     current = await prisma.models.AgentPreset.prisma().find_unique(
         where={"id": preset_id},
         include=AGENT_PRESET_INCLUDE,
@@ -1,4 +1,4 @@
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
 
 import prisma.enums
 import prisma.models
@@ -9,6 +9,7 @@ from backend.data.db import connect
 from backend.data.includes import library_agent_include
 
 from . import db
+from . import model as library_model
 
 
 @pytest.mark.asyncio
@@ -225,3 +226,183 @@ async def test_add_agent_to_library_not_found(mocker):
     mock_store_listing_version.return_value.find_unique.assert_called_once_with(
         where={"id": "version123"}, include={"AgentGraph": True}
     )
+
+
+@pytest.mark.asyncio
+async def test_list_library_agents_sort_by_last_executed(mocker):
+    """
+    Test LAST_EXECUTED sorting behavior:
+    - Agents WITH executions come first, sorted by most recent execution (updatedAt)
+    - Agents WITHOUT executions come last, sorted by creation date
+    """
+    now = datetime.now(timezone.utc)
+
+    # Agent 1: Has execution that finished 1 hour ago
+    agent1_execution = prisma.models.AgentGraphExecution(
+        id="exec1",
+        agentGraphId="agent1",
+        agentGraphVersion=1,
+        userId="test-user",
+        createdAt=now - timedelta(hours=2),
+        updatedAt=now - timedelta(hours=1),  # Finished 1 hour ago
+        executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
+        isDeleted=False,
+        isShared=False,
+    )
+    agent1_graph = prisma.models.AgentGraph(
+        id="agent1",
+        version=1,
+        name="Agent With Recent Execution",
+        description="Has execution finished 1 hour ago",
+        userId="test-user",
+        isActive=True,
+        createdAt=now - timedelta(days=5),
+        Executions=[agent1_execution],
+    )
+    library_agent1 = prisma.models.LibraryAgent(
+        id="lib1",
+        userId="test-user",
+        agentGraphId="agent1",
+        agentGraphVersion=1,
+        settings="{}",  # type: ignore
+        isCreatedByUser=True,
+        isDeleted=False,
+        isArchived=False,
+        createdAt=now - timedelta(days=5),
+        updatedAt=now - timedelta(days=5),
+        isFavorite=False,
+        useGraphIsActiveVersion=True,
+        AgentGraph=agent1_graph,
+    )
+
+    # Agent 2: Has execution that finished 3 hours ago
+    agent2_execution = prisma.models.AgentGraphExecution(
+        id="exec2",
+        agentGraphId="agent2",
+        agentGraphVersion=1,
+        userId="test-user",
+        createdAt=now - timedelta(hours=5),
+        updatedAt=now - timedelta(hours=3),  # Finished 3 hours ago
+        executionStatus=prisma.enums.AgentExecutionStatus.COMPLETED,
+        isDeleted=False,
+        isShared=False,
+    )
+    agent2_graph = prisma.models.AgentGraph(
+        id="agent2",
+        version=1,
+        name="Agent With Older Execution",
+        description="Has execution finished 3 hours ago",
+        userId="test-user",
+        isActive=True,
+        createdAt=now - timedelta(days=3),
+        Executions=[agent2_execution],
+    )
+    library_agent2 = prisma.models.LibraryAgent(
+        id="lib2",
+        userId="test-user",
+        agentGraphId="agent2",
+        agentGraphVersion=1,
+        settings="{}",  # type: ignore
+        isCreatedByUser=True,
+        isDeleted=False,
+        isArchived=False,
+        createdAt=now - timedelta(days=3),
+        updatedAt=now - timedelta(days=3),
+        isFavorite=False,
+        useGraphIsActiveVersion=True,
+        AgentGraph=agent2_graph,
+    )
+
+    # Agent 3: No executions, created 1 day ago (should come after agents with executions)
+    agent3_graph = prisma.models.AgentGraph(
+        id="agent3",
+        version=1,
+        name="Agent Without Executions (Newer)",
+        description="No executions, created 1 day ago",
+        userId="test-user",
+        isActive=True,
+        createdAt=now - timedelta(days=1),
+        Executions=[],
+    )
+    library_agent3 = prisma.models.LibraryAgent(
+        id="lib3",
+        userId="test-user",
+        agentGraphId="agent3",
+        agentGraphVersion=1,
+        settings="{}",  # type: ignore
+        isCreatedByUser=True,
+        isDeleted=False,
+        isArchived=False,
+        createdAt=now - timedelta(days=1),
+        updatedAt=now - timedelta(days=1),
+        isFavorite=False,
+        useGraphIsActiveVersion=True,
+        AgentGraph=agent3_graph,
+    )
+
+    # Agent 4: No executions, created 2 days ago
+    agent4_graph = prisma.models.AgentGraph(
+        id="agent4",
+        version=1,
+        name="Agent Without Executions (Older)",
+        description="No executions, created 2 days ago",
+        userId="test-user",
+        isActive=True,
+        createdAt=now - timedelta(days=2),
+        Executions=[],
+    )
+    library_agent4 = prisma.models.LibraryAgent(
+        id="lib4",
+        userId="test-user",
+        agentGraphId="agent4",
+        agentGraphVersion=1,
+        settings="{}",  # type: ignore
+        isCreatedByUser=True,
+        isDeleted=False,
+        isArchived=False,
+        createdAt=now - timedelta(days=2),
+        updatedAt=now - timedelta(days=2),
+        isFavorite=False,
+        useGraphIsActiveVersion=True,
+        AgentGraph=agent4_graph,
+    )
+
+    # Return agents in random order to verify sorting works
+    mock_library_agents = [
+        library_agent3,
+        library_agent1,
+        library_agent4,
+        library_agent2,
+    ]
+
+    # Mock prisma calls
+    mock_agent_graph = mocker.patch("prisma.models.AgentGraph.prisma")
+    mock_agent_graph.return_value.find_many = mocker.AsyncMock(return_value=[])
+
+    mock_library_agent = mocker.patch("prisma.models.LibraryAgent.prisma")
+    mock_library_agent.return_value.find_many = mocker.AsyncMock(
+        return_value=mock_library_agents
+    )
+
+    # Call function with LAST_EXECUTED sort
+    result = await db.list_library_agents(
+        "test-user",
+        sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
+    )
+
+    # Verify sorting order:
+    # 1. Agent 1 (execution finished 1 hour ago) - most recent execution
+    # 2. Agent 2 (execution finished 3 hours ago) - older execution
+    # 3. Agent 3 (no executions, created 1 day ago) - newer creation
+    # 4. Agent 4 (no executions, created 2 days ago) - older creation
+    assert len(result.agents) == 4
+    assert (
+        result.agents[0].id == "lib1"
+    ), "Agent with most recent execution should be first"
+    assert result.agents[1].id == "lib2", "Agent with older execution should be second"
+    assert (
+        result.agents[2].id == "lib3"
+    ), "Agent without executions (newer) should be third"
+    assert (
+        result.agents[3].id == "lib4"
+    ), "Agent without executions (older) should be last"
@@ -442,6 +442,7 @@ class LibraryAgentSort(str, Enum):
 
     CREATED_AT = "createdAt"
     UPDATED_AT = "updatedAt"
+    LAST_EXECUTED = "lastExecuted"
 
 
 class LibraryAgentUpdateRequest(pydantic.BaseModel):
@@ -28,7 +28,7 @@ async def list_library_agents(
         None, description="Search term to filter agents"
     ),
     sort_by: library_model.LibraryAgentSort = Query(
-        library_model.LibraryAgentSort.UPDATED_AT,
+        library_model.LibraryAgentSort.LAST_EXECUTED,
         description="Criteria to sort results by",
     ),
     page: int = Query(
@@ -112,7 +112,7 @@ async def test_get_library_agents_success(
     mock_db_call.assert_called_once_with(
         user_id=test_user_id,
         search_term="test",
-        sort_by=library_model.LibraryAgentSort.UPDATED_AT,
+        sort_by=library_model.LibraryAgentSort.LAST_EXECUTED,
         page=1,
         page_size=15,
     )
@@ -119,7 +119,9 @@ def library_agent_include(
     if include_executions:
         agent_graph_include["Executions"] = {
             "where": {"userId": user_id},
-            "order_by": {"createdAt": "desc"},
+            "order_by": {
+                "updatedAt": "desc"
+            },  # Uses updatedAt because it reflects when the execution completed or last progressed
             "take": execution_limit,
         }
 
@@ -1,39 +0,0 @@
-from urllib.parse import urlparse
-
-import fastapi
-from fastapi.routing import APIRoute
-
-from backend.api.features.integrations.router import router as integrations_router
-from backend.integrations.providers import ProviderName
-from backend.integrations.webhooks import utils as webhooks_utils
-
-
-def test_webhook_ingress_url_matches_route(monkeypatch) -> None:
-    app = fastapi.FastAPI()
-    app.include_router(integrations_router, prefix="/api/integrations")
-
-    provider = ProviderName.GITHUB
-    webhook_id = "webhook_123"
-    base_url = "https://example.com"
-
-    monkeypatch.setattr(webhooks_utils.app_config, "platform_base_url", base_url)
-
-    route = next(
-        route
-        for route in integrations_router.routes
-        if isinstance(route, APIRoute)
-        and route.path == "/{provider}/webhooks/{webhook_id}/ingress"
-        and "POST" in route.methods
-    )
-    expected_path = f"/api/integrations{route.path}".format(
-        provider=provider.value,
-        webhook_id=webhook_id,
-    )
-    actual_url = urlparse(webhooks_utils.webhook_ingress_url(provider, webhook_id))
-    expected_base = urlparse(base_url)
-
-    assert (actual_url.scheme, actual_url.netloc) == (
-        expected_base.scheme,
-        expected_base.netloc,
-    )
-    assert actual_url.path == expected_path
@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:
 
     @pytest.mark.asyncio
     async def test_decompose_goal_with_context(self):
-        """Test decomposition with additional context enriched into description."""
+        """Test decomposition with additional context."""
         mock_response = MagicMock()
         mock_response.json.return_value = {
             "success": True,
@@ -119,12 +119,9 @@ class TestDecomposeGoalExternal:
             "Build a chatbot", context="Use Python"
         )
 
-        expected_description = (
-            "Build a chatbot\n\nAdditional context from user:\nUse Python"
-        )
         mock_client.post.assert_called_once_with(
             "/api/decompose-description",
-            json={"description": expected_description},
+            json={"description": "Build a chatbot", "user_instruction": "Use Python"},
         )
 
     @pytest.mark.asyncio
@@ -23,10 +23,13 @@ export function LibrarySortMenu({ setLibrarySort }: Props) {
     <Select onValueChange={handleSortChange}>
       <SelectTrigger className="ml-1 w-fit space-x-1 border-none px-0 text-base underline underline-offset-4 shadow-none">
         <ArrowDownNarrowWideIcon className="h-4 w-4 sm:hidden" />
-        <SelectValue placeholder="Last Modified" />
+        <SelectValue placeholder="Last Executed" />
       </SelectTrigger>
       <SelectContent>
         <SelectGroup>
+          <SelectItem value={LibraryAgentSort.lastExecuted}>
+            Last Executed
+          </SelectItem>
           <SelectItem value={LibraryAgentSort.createdAt}>
             Creation Date
           </SelectItem>
@@ -11,12 +11,14 @@ export function useLibrarySortMenu({ setLibrarySort }: Props) {
 
   const getSortLabel = (sort: LibraryAgentSort) => {
     switch (sort) {
+      case LibraryAgentSort.lastExecuted:
+        return "Last Executed";
       case LibraryAgentSort.createdAt:
         return "Creation Date";
       case LibraryAgentSort.updatedAt:
         return "Last Modified";
       default:
-        return "Last Modified";
+        return "Last Executed";
     }
   };
 
@@ -2,7 +2,7 @@
 
 import { LibraryAgentSort } from "@/app/api/__generated__/models/libraryAgentSort";
 import { parseAsStringEnum, useQueryState } from "nuqs";
-import { useCallback, useEffect, useMemo, useState } from "react";
+import { useCallback, useMemo, useState } from "react";
 
 const sortParser = parseAsStringEnum(Object.values(LibraryAgentSort));
 
@@ -11,14 +11,7 @@ export function useLibraryListPage() {
   const [uploadedFile, setUploadedFile] = useState<File | null>(null);
   const [librarySortRaw, setLibrarySortRaw] = useQueryState("sort", sortParser);
 
-  // Ensure sort param is always present in URL (even if default)
-  useEffect(() => {
-    if (!librarySortRaw) {
-      setLibrarySortRaw(LibraryAgentSort.updatedAt, { shallow: false });
-    }
-  }, [librarySortRaw, setLibrarySortRaw]);
-
-  const librarySort = librarySortRaw || LibraryAgentSort.updatedAt;
+  const librarySort = librarySortRaw || LibraryAgentSort.lastExecuted;
 
   const setLibrarySort = useCallback(
     (value: LibraryAgentSort) => {
@@ -3361,7 +3361,7 @@
           "schema": {
             "$ref": "#/components/schemas/LibraryAgentSort",
             "description": "Criteria to sort results by",
-            "default": "updatedAt"
+            "default": "lastExecuted"
           },
           "description": "Criteria to sort results by"
         },
@@ -8239,7 +8239,7 @@
       },
       "LibraryAgentSort": {
        "type": "string",
-        "enum": ["createdAt", "updatedAt"],
+        "enum": ["createdAt", "updatedAt", "lastExecuted"],
        "title": "LibraryAgentSort",
        "description": "Possible sort options for sorting library agents."
      },
|
|||||||
export enum LibraryAgentSortEnum {
|
export enum LibraryAgentSortEnum {
|
||||||
CREATED_AT = "createdAt",
|
CREATED_AT = "createdAt",
|
||||||
UPDATED_AT = "updatedAt",
|
UPDATED_AT = "updatedAt",
|
||||||
|
LAST_EXECUTED = "lastExecuted",
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *** CREDENTIALS *** */
|
/* *** CREDENTIALS *** */
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ export class LibraryPage extends BasePage {
 
   async selectSortOption(
     page: Page,
-    sortOption: "Creation Date" | "Last Modified",
+    sortOption: "Last Executed" | "Creation Date" | "Last Modified",
   ): Promise<void> {
     const { getRole } = getSelectors(page);
     await getRole("combobox").click();
@@ -182,7 +182,7 @@ test("logged in user is redirected from /login to /library", async ({
   await hasUrl(page, "/marketplace");
 
   await page.goto("/login");
-  await hasUrl(page, "/library?sort=updatedAt");
+  await hasUrl(page, "/library");
 });
 
 test("logged in user is redirected from /signup to /library", async ({
@@ -195,5 +195,5 @@ test("logged in user is redirected from /signup to /library", async ({
   await hasUrl(page, "/marketplace");
 
   await page.goto("/signup");
-  await hasUrl(page, "/library?sort=updatedAt");
+  await hasUrl(page, "/library");
 });