Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-03 11:24:57 -05:00)

Compare commits (5 commits): fix/copilo ... otto/copil
| Author | SHA1 | Date |
|---|---|---|
| | 9e604528ea | |
| | c3ec7c2880 | |
| | 7d9380a793 | |
| | 678ddde751 | |
| | aef6f57cfd | |
```diff
@@ -17,6 +17,14 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_
 config = ChatConfig()
 
+# SSE response headers for streaming
+SSE_RESPONSE_HEADERS = {
+    "Cache-Control": "no-cache",
+    "Connection": "keep-alive",
+    "X-Accel-Buffering": "no",
+    "x-vercel-ai-ui-message-stream": "v1",
+}
+
 logger = logging.getLogger(__name__)
 
```
```diff
@@ -32,6 +40,60 @@ async def _validate_and_get_session(
     return session
 
 
+async def _create_stream_generator(
+    session_id: str,
+    message: str,
+    user_id: str | None,
+    session: ChatSession,
+    is_user_message: bool = True,
+    context: dict[str, str] | None = None,
+) -> AsyncGenerator[str, None]:
+    """Create SSE event generator for chat streaming.
+
+    Args:
+        session_id: Chat session ID
+        message: User message to process
+        user_id: Optional authenticated user ID
+        session: Pre-fetched chat session
+        is_user_message: Whether the message is from a user
+        context: Optional context dict with url and content
+
+    Yields:
+        SSE-formatted chunks from the chat completion stream
+    """
+    chunk_count = 0
+    first_chunk_type: str | None = None
+    async for chunk in chat_service.stream_chat_completion(
+        session_id,
+        message,
+        is_user_message=is_user_message,
+        user_id=user_id,
+        session=session,
+        context=context,
+    ):
+        if chunk_count < 3:
+            logger.info(
+                "Chat stream chunk",
+                extra={
+                    "session_id": session_id,
+                    "chunk_type": str(chunk.type),
+                },
+            )
+            if not first_chunk_type:
+                first_chunk_type = str(chunk.type)
+        chunk_count += 1
+        yield chunk.to_sse()
+    logger.info(
+        "Chat stream completed",
+        extra={
+            "session_id": session_id,
+            "chunk_count": chunk_count,
+            "first_chunk_type": first_chunk_type,
+        },
+    )
+    yield "data: [DONE]\n\n"
+
+
 router = APIRouter(
     tags=["chat"],
 )
```
```diff
@@ -221,49 +283,17 @@ async def stream_chat_post(
     """
     session = await _validate_and_get_session(session_id, user_id)
 
-    async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            request.message,
-            is_user_message=request.is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-            context=request.context,
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-                if not first_chunk_type:
-                    first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
-
     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=request.message,
+            user_id=user_id,
+            session=session,
+            is_user_message=request.is_user_message,
+            context=request.context,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
     )
 
 
```
```diff
@@ -295,48 +325,16 @@ async def stream_chat_get(
     """
     session = await _validate_and_get_session(session_id, user_id)
 
-    async def event_generator() -> AsyncGenerator[str, None]:
-        chunk_count = 0
-        first_chunk_type: str | None = None
-        async for chunk in chat_service.stream_chat_completion(
-            session_id,
-            message,
-            is_user_message=is_user_message,
-            user_id=user_id,
-            session=session,  # Pass pre-fetched session to avoid double-fetch
-        ):
-            if chunk_count < 3:
-                logger.info(
-                    "Chat stream chunk",
-                    extra={
-                        "session_id": session_id,
-                        "chunk_type": str(chunk.type),
-                    },
-                )
-                if not first_chunk_type:
-                    first_chunk_type = str(chunk.type)
-            chunk_count += 1
-            yield chunk.to_sse()
-        logger.info(
-            "Chat stream completed",
-            extra={
-                "session_id": session_id,
-                "chunk_count": chunk_count,
-                "first_chunk_type": first_chunk_type,
-            },
-        )
-        # AI SDK protocol termination
-        yield "data: [DONE]\n\n"
-
     return StreamingResponse(
-        event_generator(),
+        _create_stream_generator(
+            session_id=session_id,
+            message=message,
+            user_id=user_id,
+            session=session,
+            is_user_message=is_user_message,
+        ),
         media_type="text/event-stream",
-        headers={
-            "Cache-Control": "no-cache",
-            "Connection": "keep-alive",
-            "X-Accel-Buffering": "no",  # Disable nginx buffering
-            "x-vercel-ai-ui-message-stream": "v1",  # AI SDK protocol header
-        },
+        headers=SSE_RESPONSE_HEADERS,
    )
 
 
```
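Both streaming endpoints now return an AI SDK compatible `text/event-stream` that ends with `data: [DONE]`. The sketch below shows one way a client might consume that stream; the URL path and request payload are assumptions for illustration only and are not taken from this diff.

```python
# Minimal consumer sketch for the SSE chat stream (hypothetical endpoint path).
import asyncio

import httpx


async def consume_chat_stream(session_id: str, message: str) -> None:
    url = f"http://localhost:8000/chat/{session_id}/stream"  # assumed path
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", url, json={"message": message, "is_user_message": True}
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue  # skip blank keep-alive lines
                data = line.removeprefix("data: ")
                if data == "[DONE]":  # AI SDK protocol termination
                    break
                print(data)  # one SSE-encoded chunk produced by chunk.to_sse()


if __name__ == "__main__":
    asyncio.run(consume_chat_stream("some-session-id", "Hello"))
```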
```diff
@@ -3,10 +3,13 @@ import logging
 import time
 from asyncio import CancelledError
 from collections.abc import AsyncGenerator
-from dataclasses import dataclass
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast
 
 import openai
+
+if TYPE_CHECKING:
+    from backend.util.prompt import CompressResult
+
 import orjson
 from langfuse import get_client
 from openai import (
```
```diff
@@ -17,7 +20,6 @@ from openai import (
     RateLimitError,
 )
 from openai.types.chat import (
-    ChatCompletionAssistantMessageParam,
     ChatCompletionChunk,
     ChatCompletionMessageParam,
     ChatCompletionStreamOptionsParam,
```
```diff
@@ -31,7 +33,6 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
-from backend.util.prompt import estimate_token_count
 from backend.util.settings import Settings
 
 from . import db as chat_db
```
```diff
@@ -803,402 +804,58 @@ def _is_region_blocked_error(error: Exception) -> bool:
     return "not available in your region" in str(error).lower()
 
 
-# Context window management constants
-TOKEN_THRESHOLD = 120_000
-KEEP_RECENT_MESSAGES = 15
-
-
-@dataclass
-class ContextWindowResult:
-    """Result of context window management."""
-
-    messages: list[dict[str, Any]]
-    token_count: int
-    was_compacted: bool
-    error: str | None = None
-
-
-def _messages_to_dicts(messages: list) -> list[dict[str, Any]]:
-    """Convert message objects to dicts, filtering None values.
-
-    Handles both TypedDict (dict-like) and other message formats.
-    """
-    result = []
-    for msg in messages:
-        if msg is None:
-            continue
-        if isinstance(msg, dict):
-            msg_dict = {k: v for k, v in msg.items() if v is not None}
-        else:
-            msg_dict = dict(msg)
-        result.append(msg_dict)
-    return result
-
-
 async def _manage_context_window(
     messages: list,
     model: str,
     api_key: str | None = None,
     base_url: str | None = None,
-) -> ContextWindowResult:
+) -> "CompressResult":
     """
-    Manage context window by summarizing old messages if token count exceeds threshold.
+    Manage context window using the unified compress_context function.
 
-    This function handles context compaction for LLM calls by:
-    1. Counting tokens in the message list
-    2. If over threshold, summarizing old messages while keeping recent ones
-    3. Ensuring tool_call/tool_response pairs stay intact
-    4. Progressively reducing message count if still over limit
+    This is a thin wrapper that creates an OpenAI client for summarization
+    and delegates to the shared compression logic in prompt.py.
 
     Args:
-        messages: List of messages in OpenAI format (with system prompt if present)
-        model: Model name for token counting
+        messages: List of messages in OpenAI format
+        model: Model name for token counting and summarization
         api_key: API key for summarization calls
         base_url: Base URL for summarization calls
 
     Returns:
-        ContextWindowResult with compacted messages and metadata
+        CompressResult with compacted messages and metadata
     """
-    if not messages:
-        return ContextWindowResult([], 0, False, "No messages to compact")
-
-    messages_dict = _messages_to_dicts(messages)
-
-    # Normalize model name for token counting (tiktoken only supports OpenAI models)
-    token_count_model = model.split("/")[-1] if "/" in model else model
-    if "claude" in token_count_model.lower() or not any(
-        known in token_count_model.lower()
-        for known in ["gpt", "o1", "chatgpt", "text-"]
-    ):
-        token_count_model = "gpt-4o"
-
-    try:
-        token_count = estimate_token_count(messages_dict, model=token_count_model)
-    except Exception as e:
-        logger.warning(f"Token counting failed: {e}. Using gpt-4o approximation.")
-        token_count_model = "gpt-4o"
-        token_count = estimate_token_count(messages_dict, model=token_count_model)
-
-    if token_count <= TOKEN_THRESHOLD:
-        return ContextWindowResult(messages, token_count, False)
-
-    has_system_prompt = messages[0].get("role") == "system"
-    slice_start = max(0, len(messages_dict) - KEEP_RECENT_MESSAGES)
-    recent_messages = _ensure_tool_pairs_intact(
-        messages_dict[-KEEP_RECENT_MESSAGES:], messages_dict, slice_start
-    )
-
-    # Determine old messages to summarize (explicit bounds to avoid slice edge cases)
-    system_msg = messages[0] if has_system_prompt else None
-    if has_system_prompt:
-        old_messages_dict = (
-            messages_dict[1:-KEEP_RECENT_MESSAGES]
-            if len(messages_dict) > KEEP_RECENT_MESSAGES + 1
-            else []
-        )
-    else:
-        old_messages_dict = (
-            messages_dict[:-KEEP_RECENT_MESSAGES]
-            if len(messages_dict) > KEEP_RECENT_MESSAGES
-            else []
-        )
-
-    # Try to summarize old messages, fall back to truncation on failure
-    summary_msg = None
-    if old_messages_dict:
-        try:
-            summary_text = await _summarize_messages(
-                old_messages_dict, model=model, api_key=api_key, base_url=base_url
-            )
-            summary_msg = ChatCompletionAssistantMessageParam(
-                role="assistant",
-                content=f"[Previous conversation summary — for context only]: {summary_text}",
-            )
-            base = [system_msg, summary_msg] if has_system_prompt else [summary_msg]
-            messages = base + recent_messages
-            logger.info(
-                f"Context summarized: {token_count} tokens, "
-                f"summarized {len(old_messages_dict)} msgs, kept {KEEP_RECENT_MESSAGES}"
-            )
-        except Exception as e:
-            logger.warning(f"Summarization failed, falling back to truncation: {e}")
-            messages = (
-                [system_msg] + recent_messages if has_system_prompt else recent_messages
-            )
-    else:
-        logger.warning(
-            f"Token count {token_count} exceeds threshold but no old messages to summarize"
-        )
-
-    new_token_count = estimate_token_count(
-        _messages_to_dicts(messages), model=token_count_model
-    )
-
-    # Progressive truncation if still over limit
-    if new_token_count > TOKEN_THRESHOLD:
-        logger.warning(
-            f"Still over limit: {new_token_count} tokens. Reducing messages."
-        )
-        base_msgs = (
-            recent_messages
-            if old_messages_dict
-            else (messages_dict[1:] if has_system_prompt else messages_dict)
-        )
-
-        def build_messages(recent: list) -> list:
-            """Build message list with optional system prompt and summary."""
-            prefix = []
-            if has_system_prompt and system_msg:
-                prefix.append(system_msg)
-            if summary_msg:
-                prefix.append(summary_msg)
-            return prefix + recent
-
-        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
-            if keep_count == 0:
-                messages = build_messages([])
-                if not messages:
-                    continue
-            elif len(base_msgs) < keep_count:
-                continue
-            else:
-                reduced = _ensure_tool_pairs_intact(
-                    base_msgs[-keep_count:],
-                    base_msgs,
-                    max(0, len(base_msgs) - keep_count),
-                )
-                messages = build_messages(reduced)
-
-            new_token_count = estimate_token_count(
-                _messages_to_dicts(messages), model=token_count_model
-            )
-            if new_token_count <= TOKEN_THRESHOLD:
-                logger.info(
-                    f"Reduced to {keep_count} messages, {new_token_count} tokens"
-                )
-                break
-        else:
-            logger.error(
-                f"Cannot reduce below threshold. Final: {new_token_count} tokens"
-            )
-            if has_system_prompt and len(messages) > 1:
-                messages = messages[1:]
-                logger.critical("Dropped system prompt as last resort")
-                return ContextWindowResult(
-                    messages, new_token_count, True, "System prompt dropped"
-                )
-            # No system prompt to drop - return error so callers don't proceed with oversized context
-            return ContextWindowResult(
-                messages,
-                new_token_count,
-                True,
-                "Unable to reduce context below token limit",
-            )
-
-    return ContextWindowResult(messages, new_token_count, True)
-
-
-async def _summarize_messages(
-    messages: list,
-    model: str,
-    api_key: str | None = None,
-    base_url: str | None = None,
-    timeout: float = 30.0,
-) -> str:
-    """Summarize a list of messages into concise context.
-
-    Uses the same model as the chat for higher quality summaries.
-
-    Args:
-        messages: List of message dicts to summarize
-        model: Model to use for summarization (same as chat model)
-        api_key: API key for OpenAI client
-        base_url: Base URL for OpenAI client
-        timeout: Request timeout in seconds (default: 30.0)
-
-    Returns:
-        Summarized text
-    """
-    # Format messages for summarization
-    conversation = []
-    for msg in messages:
-        role = msg.get("role", "")
-        content = msg.get("content", "")
-        # Include user, assistant, and tool messages (tool outputs are important context)
-        if content and role in ("user", "assistant", "tool"):
-            conversation.append(f"{role.upper()}: {content}")
-
-    conversation_text = "\n\n".join(conversation)
-
-    # Handle empty conversation
-    if not conversation_text:
-        return "No conversation history available."
-
-    # Truncate conversation to fit within summarization model's context
-    # gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety
-    MAX_CHARS = 100_000
-    if len(conversation_text) > MAX_CHARS:
-        conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
-
-    # Call LLM to summarize
-    import openai
-
-    summarization_client = openai.AsyncOpenAI(
-        api_key=api_key, base_url=base_url, timeout=timeout
-    )
-
-    response = await summarization_client.chat.completions.create(
-        model=model,
-        messages=[
-            {
-                "role": "system",
-                "content": (
-                    "Create a detailed summary of the conversation so far. "
-                    "This summary will be used as context when continuing the conversation.\n\n"
-                    "Before writing the summary, analyze each message chronologically to identify:\n"
-                    "- User requests and their explicit goals\n"
-                    "- Your approach and key decisions made\n"
-                    "- Technical specifics (file names, tool outputs, function signatures)\n"
-                    "- Errors encountered and resolutions applied\n\n"
-                    "You MUST include ALL of the following sections:\n\n"
-                    "## 1. Primary Request and Intent\n"
-                    "The user's explicit goals and what they are trying to accomplish.\n\n"
-                    "## 2. Key Technical Concepts\n"
-                    "Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
-                    "## 3. Files and Resources Involved\n"
-                    "Specific files examined or modified, with relevant snippets and identifiers.\n\n"
-                    "## 4. Errors and Fixes\n"
-                    "Problems encountered, error messages, and their resolutions. "
-                    "Include any user feedback on fixes.\n\n"
-                    "## 5. Problem Solving\n"
-                    "Issues that have been resolved and how they were addressed.\n\n"
-                    "## 6. All User Messages\n"
-                    "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
-                    "## 7. Pending Tasks\n"
-                    "Work items the user explicitly requested that have not yet been completed.\n\n"
-                    "## 8. Current Work\n"
-                    "Precise description of what was being worked on most recently, including relevant context.\n\n"
-                    "## 9. Next Steps\n"
-                    "What should happen next, aligned with the user's most recent requests. "
-                    "Include verbatim quotes of recent instructions if relevant."
-                ),
-            },
-            {"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
-        ],
-        max_tokens=1500,
-        temperature=0.3,
-    )
-
-    summary = response.choices[0].message.content
-    return summary or "No summary available."
-
-
-def _ensure_tool_pairs_intact(
-    recent_messages: list[dict],
-    all_messages: list[dict],
-    start_index: int,
-) -> list[dict]:
-    """
-    Ensure tool_call/tool_response pairs stay together after slicing.
-
-    When slicing messages for context compaction, a naive slice can separate
-    an assistant message containing tool_calls from its corresponding tool
-    response messages. This causes API validation errors (e.g., Anthropic's
-    "unexpected tool_use_id found in tool_result blocks").
-
-    This function checks for orphan tool responses in the slice and extends
-    backwards to include their corresponding assistant messages.
-
-    Args:
-        recent_messages: The sliced messages to validate
-        all_messages: The complete message list (for looking up missing assistants)
-        start_index: The index in all_messages where recent_messages begins
-
-    Returns:
-        A potentially extended list of messages with tool pairs intact
-    """
-    if not recent_messages:
-        return recent_messages
-
-    # Collect all tool_call_ids from assistant messages in the slice
-    available_tool_call_ids: set[str] = set()
-    for msg in recent_messages:
-        if msg.get("role") == "assistant" and msg.get("tool_calls"):
-            for tc in msg["tool_calls"]:
-                tc_id = tc.get("id")
-                if tc_id:
-                    available_tool_call_ids.add(tc_id)
-
-    # Find orphan tool responses (tool messages whose tool_call_id is missing)
-    orphan_tool_call_ids: set[str] = set()
-    for msg in recent_messages:
-        if msg.get("role") == "tool":
-            tc_id = msg.get("tool_call_id")
-            if tc_id and tc_id not in available_tool_call_ids:
-                orphan_tool_call_ids.add(tc_id)
-
-    if not orphan_tool_call_ids:
-        # No orphans, slice is valid
-        return recent_messages
-
-    # Find the assistant messages that contain the orphan tool_call_ids
-    # Search backwards from start_index in all_messages
-    messages_to_prepend: list[dict] = []
-    for i in range(start_index - 1, -1, -1):
-        msg = all_messages[i]
-        if msg.get("role") == "assistant" and msg.get("tool_calls"):
-            msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
-            if msg_tool_ids & orphan_tool_call_ids:
-                # This assistant message has tool_calls we need
-                # Also collect its contiguous tool responses that follow it
-                assistant_and_responses: list[dict] = [msg]
-
-                # Scan forward from this assistant to collect tool responses
-                for j in range(i + 1, start_index):
-                    following_msg = all_messages[j]
-                    if following_msg.get("role") == "tool":
-                        tool_id = following_msg.get("tool_call_id")
-                        if tool_id and tool_id in msg_tool_ids:
-                            assistant_and_responses.append(following_msg)
-                    else:
-                        # Stop at first non-tool message
-                        break
-
-                # Prepend the assistant and its tool responses (maintain order)
-                messages_to_prepend = assistant_and_responses + messages_to_prepend
-                # Mark these as found
-                orphan_tool_call_ids -= msg_tool_ids
-                # Also add this assistant's tool_call_ids to available set
-                available_tool_call_ids |= msg_tool_ids
-
-        if not orphan_tool_call_ids:
-            # Found all missing assistants
-            break
-
-    if orphan_tool_call_ids:
-        # Some tool_call_ids couldn't be resolved - remove those tool responses
-        # This shouldn't happen in normal operation but handles edge cases
-        logger.warning(
-            f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
-            "Removing orphan tool responses."
-        )
-        recent_messages = [
-            msg
-            for msg in recent_messages
-            if not (
-                msg.get("role") == "tool"
-                and msg.get("tool_call_id") in orphan_tool_call_ids
-            )
-        ]
-
-    if messages_to_prepend:
-        logger.info(
-            f"Extended recent messages by {len(messages_to_prepend)} to preserve "
-            f"tool_call/tool_response pairs"
-        )
-        return messages_to_prepend + recent_messages
-
-    return recent_messages
+    import openai
+
+    from backend.util.prompt import compress_context
+
+    # Convert messages to dict format
+    messages_dict = []
+    for msg in messages:
+        if isinstance(msg, dict):
+            msg_dict = {k: v for k, v in msg.items() if v is not None}
+        else:
+            msg_dict = dict(msg)
+        messages_dict.append(msg_dict)
+
+    # Only create client if api_key is provided (enables summarization)
+    # Use context manager to avoid socket leaks
+    if api_key:
+        async with openai.AsyncOpenAI(
+            api_key=api_key, base_url=base_url, timeout=30.0
+        ) as client:
+            return await compress_context(
+                messages=messages_dict,
+                model=model,
+                client=client,
+            )
+    else:
+        # No API key - use truncation-only mode
+        return await compress_context(
+            messages=messages_dict,
+            model=model,
+            client=None,
+        )
 
 
 async def _stream_chat_chunks(
```
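The rewritten `_manage_context_window` delegates all compaction to `compress_context`. Based on the fields this diff uses elsewhere (`messages`, `token_count`, `error` on the result in the llm.py hunk further down), a caller might look roughly like the sketch below; the wrapper function name `build_prompt_for_model` is hypothetical.

```python
# Hedged caller sketch: only the result fields shown in this diff are assumed to exist.
async def build_prompt_for_model(raw_messages: list, model: str, api_key: str | None):
    result = await _manage_context_window(raw_messages, model, api_key=api_key)
    if result.error:
        # Compression could not get under the limit; log and decide whether to proceed.
        logger.warning(
            f"Context still oversized ({result.token_count} tokens): {result.error}"
        )
    return result.messages
```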
```diff
@@ -0,0 +1,77 @@
+"""Shared helpers for chat tools."""
+
+from typing import Any
+
+from .models import ErrorResponse
+
+
+def error_response(
+    message: str, session_id: str | None, **kwargs: Any
+) -> ErrorResponse:
+    """Create standardized error response.
+
+    Args:
+        message: Error message to display
+        session_id: Current session ID
+        **kwargs: Additional fields to pass to ErrorResponse
+
+    Returns:
+        ErrorResponse with the given message and session_id
+    """
+    return ErrorResponse(message=message, session_id=session_id, **kwargs)
+
+
+def get_inputs_from_schema(
+    input_schema: dict[str, Any],
+    exclude_fields: set[str] | None = None,
+) -> list[dict[str, Any]]:
+    """Extract input field info from JSON schema.
+
+    Args:
+        input_schema: JSON schema dict with 'properties' and 'required'
+        exclude_fields: Set of field names to exclude (e.g., credential fields)
+
+    Returns:
+        List of dicts with field info (name, title, type, description, required, default)
+    """
+    exclude = exclude_fields or set()
+    properties = input_schema.get("properties", {})
+    required = set(input_schema.get("required", []))
+
+    return [
+        {
+            "name": name,
+            "title": schema.get("title", name),
+            "type": schema.get("type", "string"),
+            "description": schema.get("description", ""),
+            "required": name in required,
+            "default": schema.get("default"),
+        }
+        for name, schema in properties.items()
+        if name not in exclude
+    ]
+
+
+def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str:
+    """Format input fields as a readable markdown list.
+
+    Args:
+        inputs: List of input dicts from get_inputs_from_schema
+
+    Returns:
+        Markdown-formatted string listing the inputs
+    """
+    if not inputs:
+        return "No inputs required."
+
+    lines = []
+    for inp in inputs:
+        required_marker = " (required)" if inp.get("required") else ""
+        default = inp.get("default")
+        default_info = f" [default: {default}]" if default is not None else ""
+        description = inp.get("description", "")
+        desc_info = f" - {description}" if description else ""
+
+        lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}")
+
+    return "\n".join(lines)
```
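As a usage illustration of the two schema helpers added above: the snippet below runs against a small hand-written JSON schema. The import path is an assumption inferred from the relative imports (`from .helpers import ...`) elsewhere in this diff.

```python
# Illustrative only; the module path is assumed, the schema is example data.
from backend.api.features.chat.tools.helpers import (
    format_inputs_as_markdown,
    get_inputs_from_schema,
)

schema = {
    "properties": {
        "query": {"title": "Query", "type": "string", "description": "Search text"},
        "limit": {"title": "Limit", "type": "integer", "default": 10},
        "credentials": {"title": "Credentials", "type": "object"},
    },
    "required": ["query"],
}

inputs = get_inputs_from_schema(schema, exclude_fields={"credentials"})
print(format_inputs_as_markdown(inputs))
# - **query** (required) - Search text
# - **limit** [default: 10]
```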
```diff
@@ -24,6 +24,7 @@ from backend.util.timezone_utils import (
 )
 
 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     AgentDetails,
     AgentDetailsResponse,

@@ -354,19 +355,7 @@ class RunAgentTool(BaseTool):
 
     def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
         """Extract inputs list from schema."""
-        inputs_list = []
-        if isinstance(input_schema, dict) and "properties" in input_schema:
-            for field_name, field_schema in input_schema["properties"].items():
-                inputs_list.append(
-                    {
-                        "name": field_name,
-                        "title": field_schema.get("title", field_name),
-                        "type": field_schema.get("type", "string"),
-                        "description": field_schema.get("description", ""),
-                        "required": field_name in input_schema.get("required", []),
-                    }
-                )
-        return inputs_list
+        return get_inputs_from_schema(input_schema)
 
     def _get_execution_modes(self, graph: GraphModel) -> list[str]:
         """Get available execution modes for the graph."""
```
```diff
@@ -8,12 +8,13 @@ from typing import Any
 from backend.api.features.chat.model import ChatSession
 from backend.data.block import get_block
 from backend.data.execution import ExecutionContext
-from backend.data.model import CredentialsMetaInput
+from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput
 from backend.data.workspace import get_or_create_workspace
 from backend.integrations.creds_manager import IntegrationCredentialsManager
 from backend.util.exceptions import BlockError
 
 from .base import BaseTool
+from .helpers import get_inputs_from_schema
 from .models import (
     BlockOutputResponse,
     ErrorResponse,

@@ -22,7 +23,10 @@ from .models import (
     ToolResponseBase,
     UserReadiness,
 )
-from .utils import build_missing_credentials_from_field_info
+from .utils import (
+    build_missing_credentials_from_field_info,
+    match_credentials_to_requirements,
+)
 
 logger = logging.getLogger(__name__)
 

@@ -71,6 +75,22 @@ class RunBlockTool(BaseTool):
     def requires_auth(self) -> bool:
         return True
 
+    def _get_credentials_requirements(
+        self,
+        block: Any,
+    ) -> dict[str, CredentialsFieldInfo]:
+        """
+        Get credential requirements from block's input schema.
+
+        Args:
+            block: Block to get credentials for
+
+        Returns:
+            Dict mapping field names to CredentialsFieldInfo
+        """
+        credentials_fields_info = block.input_schema.get_credentials_fields_info()
+        return credentials_fields_info if credentials_fields_info else {}
+
     async def _check_block_credentials(
         self,
         user_id: str,

@@ -82,53 +102,12 @@ class RunBlockTool(BaseTool):
         Returns:
             tuple[matched_credentials, missing_credentials]
         """
-        matched_credentials: dict[str, CredentialsMetaInput] = {}
-        missing_credentials: list[CredentialsMetaInput] = []
-
-        # Get credential field info from block's input schema
-        credentials_fields_info = block.input_schema.get_credentials_fields_info()
-
-        if not credentials_fields_info:
-            return matched_credentials, missing_credentials
-
-        # Get user's available credentials
-        creds_manager = IntegrationCredentialsManager()
-        available_creds = await creds_manager.store.get_all_creds(user_id)
-
-        for field_name, field_info in credentials_fields_info.items():
-            # field_info.provider is a frozenset of acceptable providers
-            # field_info.supported_types is a frozenset of acceptable types
-            matching_cred = next(
-                (
-                    cred
-                    for cred in available_creds
-                    if cred.provider in field_info.provider
-                    and cred.type in field_info.supported_types
-                ),
-                None,
-            )
-
-            if matching_cred:
-                matched_credentials[field_name] = CredentialsMetaInput(
-                    id=matching_cred.id,
-                    provider=matching_cred.provider,  # type: ignore
-                    type=matching_cred.type,
-                    title=matching_cred.title,
-                )
-            else:
-                # Create a placeholder for the missing credential
-                provider = next(iter(field_info.provider), "unknown")
-                cred_type = next(iter(field_info.supported_types), "api_key")
-                missing_credentials.append(
-                    CredentialsMetaInput(
-                        id=field_name,
-                        provider=provider,  # type: ignore
-                        type=cred_type,  # type: ignore
-                        title=field_name.replace("_", " ").title(),
-                    )
-                )
-
-        return matched_credentials, missing_credentials
+        requirements = self._get_credentials_requirements(block)
+
+        if not requirements:
+            return {}, []
+
+        return await match_credentials_to_requirements(user_id, requirements)
 
     async def _execute(
         self,

@@ -320,27 +299,7 @@ class RunBlockTool(BaseTool):
 
     def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
         """Extract non-credential inputs from block schema."""
-        inputs_list = []
         schema = block.input_schema.jsonschema()
-        properties = schema.get("properties", {})
-        required_fields = set(schema.get("required", []))
-
         # Get credential field names to exclude
         credentials_fields = set(block.input_schema.get_credentials_fields().keys())
-
-        for field_name, field_schema in properties.items():
-            # Skip credential fields
-            if field_name in credentials_fields:
-                continue
-
-            inputs_list.append(
-                {
-                    "name": field_name,
-                    "title": field_schema.get("title", field_name),
-                    "type": field_schema.get("type", "string"),
-                    "description": field_schema.get("description", ""),
-                    "required": field_name in required_fields,
-                }
-            )
-
-        return inputs_list
+        return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
```
```diff
@@ -225,6 +225,127 @@ async def get_or_create_library_agent(
     return library_agents[0]
 
 
+async def get_user_credentials(user_id: str) -> list:
+    """
+    Get all available credentials for a user.
+
+    Args:
+        user_id: The user's ID
+
+    Returns:
+        List of user's credentials
+    """
+    creds_manager = IntegrationCredentialsManager()
+    return await creds_manager.store.get_all_creds(user_id)
+
+
+def find_matching_credential(
+    available_creds: list,
+    field_info: CredentialsFieldInfo,
+):
+    """
+    Find a credential that matches the required provider, type, and scopes.
+
+    Args:
+        available_creds: List of user's available credentials
+        field_info: CredentialsFieldInfo with provider, type, and scope requirements
+
+    Returns:
+        Matching credential or None
+    """
+    for cred in available_creds:
+        if cred.provider not in field_info.provider:
+            continue
+        if cred.type not in field_info.supported_types:
+            continue
+        if not _credential_has_required_scopes(cred, field_info):
+            continue
+        return cred
+    return None
+
+
+def create_credential_meta_from_match(
+    matching_cred,
+) -> CredentialsMetaInput:
+    """
+    Create a CredentialsMetaInput from a matched credential.
+
+    Args:
+        matching_cred: The matched credential object
+
+    Returns:
+        CredentialsMetaInput instance
+    """
+    return CredentialsMetaInput(
+        id=matching_cred.id,
+        provider=matching_cred.provider,  # type: ignore
+        type=matching_cred.type,
+        title=matching_cred.title,
+    )
+
+
+async def match_credentials_to_requirements(
+    user_id: str,
+    requirements: dict[str, CredentialsFieldInfo],
+) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
+    """
+    Match user's credentials against a dictionary of credential requirements.
+
+    This is the core matching logic shared by both graph and block credential matching.
+
+    Args:
+        user_id: The user's ID
+        requirements: Dict mapping field names to CredentialsFieldInfo
+
+    Returns:
+        tuple[matched_credentials dict, missing_credentials list]
+    """
+    matched: dict[str, CredentialsMetaInput] = {}
+    missing: list[CredentialsMetaInput] = []
+
+    if not requirements:
+        return matched, missing
+
+    available_creds = await get_user_credentials(user_id)
+
+    for field_name, field_info in requirements.items():
+        matching_cred = find_matching_credential(available_creds, field_info)
+
+        if matching_cred:
+            try:
+                matched[field_name] = create_credential_meta_from_match(matching_cred)
+            except Exception as e:
+                logger.error(
+                    f"Failed to create CredentialsMetaInput for field '{field_name}': "
+                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
+                    f"credential_id={matching_cred.id}",
+                    exc_info=True,
+                )
+                provider = next(iter(field_info.provider), "unknown")
+                cred_type = next(iter(field_info.supported_types), "api_key")
+                missing.append(
+                    CredentialsMetaInput(
+                        id=field_name,
+                        provider=provider,  # type: ignore
+                        type=cred_type,  # type: ignore
+                        title=f"{field_name} (validation failed: {e})",
+                    )
+                )
+        else:
+            provider = next(iter(field_info.provider), "unknown")
+            cred_type = next(iter(field_info.supported_types), "api_key")
+            missing.append(
+                CredentialsMetaInput(
+                    id=field_name,
+                    provider=provider,  # type: ignore
+                    type=cred_type,  # type: ignore
+                    title=field_name.replace("_", " ").title(),
+                )
+            )
+
+    return matched, missing
+
+
 async def match_user_credentials_to_graph(
     user_id: str,
     graph: GraphModel,
```
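The new `match_credentials_to_requirements` is the shared entry point that both the graph path and the block path now call. A hedged sketch of the block-side flow, mirroring what `RunBlockTool._check_block_credentials` does in this diff (the wrapper function name is hypothetical):

```python
# Sketch only: assumes a block object with the input_schema API shown in this diff.
async def block_is_ready(user_id: str, block) -> bool:
    requirements = block.input_schema.get_credentials_fields_info()
    if not requirements:
        return True  # nothing to match
    matched, missing = await match_credentials_to_requirements(user_id, requirements)
    return not missing  # ready only when every requirement found a credential
```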
```diff
@@ -242,9 +363,6 @@ async def match_user_credentials_to_graph(
     Returns:
         tuple[matched_credentials dict, missing_credential_descriptions list]
     """
-    graph_credentials_inputs: dict[str, CredentialsMetaInput] = {}
-    missing_creds: list[str] = []
-
     # Get aggregated credentials requirements from the graph
     aggregated_creds = graph.aggregate_credentials_inputs()
     logger.debug(

@@ -252,69 +370,30 @@ async def match_user_credentials_to_graph(
     )
 
     if not aggregated_creds:
-        return graph_credentials_inputs, missing_creds
+        return {}, []
 
-    # Get all available credentials for the user
-    creds_manager = IntegrationCredentialsManager()
-    available_creds = await creds_manager.store.get_all_creds(user_id)
-
-    # For each required credential field, find a matching user credential
-    # field_info.provider is a frozenset because aggregate_credentials_inputs()
-    # combines requirements from multiple nodes. A credential matches if its
-    # provider is in the set of acceptable providers.
-    for credential_field_name, (
-        credential_requirements,
-        _node_fields,
-    ) in aggregated_creds.items():
-        # Find first matching credential by provider, type, and scopes
-        matching_cred = next(
-            (
-                cred
-                for cred in available_creds
-                if cred.provider in credential_requirements.provider
-                and cred.type in credential_requirements.supported_types
-                and _credential_has_required_scopes(cred, credential_requirements)
-            ),
-            None,
-        )
-
-        if matching_cred:
-            try:
-                graph_credentials_inputs[credential_field_name] = CredentialsMetaInput(
-                    id=matching_cred.id,
-                    provider=matching_cred.provider,  # type: ignore
-                    type=matching_cred.type,
-                    title=matching_cred.title,
-                )
-            except Exception as e:
-                logger.error(
-                    f"Failed to create CredentialsMetaInput for field '{credential_field_name}': "
-                    f"provider={matching_cred.provider}, type={matching_cred.type}, "
-                    f"credential_id={matching_cred.id}",
-                    exc_info=True,
-                )
-                missing_creds.append(
-                    f"{credential_field_name} (validation failed: {e})"
-                )
-        else:
-            # Build a helpful error message including scope requirements
-            error_parts = [
-                f"provider in {list(credential_requirements.provider)}",
-                f"type in {list(credential_requirements.supported_types)}",
-            ]
-            if credential_requirements.required_scopes:
-                error_parts.append(
-                    f"scopes including {list(credential_requirements.required_scopes)}"
-                )
-            missing_creds.append(
-                f"{credential_field_name} (requires {', '.join(error_parts)})"
-            )
+    # Convert aggregated format to simple requirements dict
+    requirements = {
+        field_name: field_info
+        for field_name, (field_info, _node_fields) in aggregated_creds.items()
+    }
+
+    # Use shared matching logic
+    matched, missing_list = await match_credentials_to_requirements(
+        user_id, requirements
+    )
+
+    # Convert missing list to string descriptions for backward compatibility
+    missing_descriptions = [
+        f"{cred.id} (requires provider={cred.provider}, type={cred.type})"
+        for cred in missing_list
+    ]
 
     logger.info(
-        f"Credential matching complete: {len(graph_credentials_inputs)}/{len(aggregated_creds)} matched"
+        f"Credential matching complete: {len(matched)}/{len(aggregated_creds)} matched"
     )
 
-    return graph_credentials_inputs, missing_creds
+    return matched, missing_descriptions
 
 
 def _credential_has_required_scopes(
```
```diff
@@ -32,7 +32,7 @@ from backend.data.model import (
 from backend.integrations.providers import ProviderName
 from backend.util import json
 from backend.util.logging import TruncatedLogger
-from backend.util.prompt import compress_prompt, estimate_token_count
+from backend.util.prompt import compress_context, estimate_token_count
 from backend.util.text import TextFormatter
 
 logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")

@@ -634,11 +634,18 @@ async def llm_call(
     context_window = llm_model.context_window
 
     if compress_prompt_to_fit:
-        prompt = compress_prompt(
+        result = await compress_context(
             messages=prompt,
             target_tokens=llm_model.context_window // 2,
-            lossy_ok=True,
+            client=None,  # Truncation-only, no LLM summarization
+            reserve=0,  # Caller handles response token budget separately
         )
+        if result.error:
+            logger.warning(
+                f"Prompt compression did not meet target: {result.error}. "
+                f"Proceeding with {result.token_count} tokens."
+            )
+        prompt = result.messages
 
     # Calculate available tokens based on context window and input length
     estimated_input_tokens = estimate_token_count(prompt)
```
```diff
@@ -17,6 +17,7 @@ from backend.data.analytics import (
     get_accuracy_trends_and_alerts,
     get_marketplace_graphs_for_monitoring,
 )
+from backend.data.auth.oauth import cleanup_expired_oauth_tokens
 from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
 from backend.data.execution import (
     create_graph_execution,

@@ -219,6 +220,9 @@ class DatabaseManager(AppService):
     # Onboarding
     increment_onboarding_runs = _(increment_onboarding_runs)
 
+    # OAuth
+    cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
+
     # Store
     get_store_agents = _(get_store_agents)
     get_store_agent_details = _(get_store_agent_details)

@@ -349,6 +353,9 @@ class DatabaseManagerAsyncClient(AppServiceClient):
     # Onboarding
     increment_onboarding_runs = d.increment_onboarding_runs
 
+    # OAuth
+    cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
+
     # Store
     get_store_agents = d.get_store_agents
     get_store_agent_details = d.get_store_agent_details
```
```diff
@@ -24,11 +24,9 @@ from dotenv import load_dotenv
 from pydantic import BaseModel, Field, ValidationError
 from sqlalchemy import MetaData, create_engine
 
-from backend.data.auth.oauth import cleanup_expired_oauth_tokens
 from backend.data.block import BlockInput
 from backend.data.execution import GraphExecutionWithNodes
 from backend.data.model import CredentialsMetaInput
-from backend.data.onboarding import increment_onboarding_runs
 from backend.executor import utils as execution_utils
 from backend.monitoring import (
     NotificationJobArgs,

@@ -38,7 +36,11 @@ from backend.monitoring import (
     report_execution_accuracy_alerts,
     report_late_executions,
 )
-from backend.util.clients import get_database_manager_client, get_scheduler_client
+from backend.util.clients import (
+    get_database_manager_async_client,
+    get_database_manager_client,
+    get_scheduler_client,
+)
 from backend.util.cloud_storage import cleanup_expired_files_async
 from backend.util.exceptions import (
     GraphNotFoundError,

@@ -148,6 +150,7 @@ def execute_graph(**kwargs):
 async def _execute_graph(**kwargs):
     args = GraphExecutionJobArgs(**kwargs)
     start_time = asyncio.get_event_loop().time()
+    db = get_database_manager_async_client()
     try:
         logger.info(f"Executing recurring job for graph #{args.graph_id}")
         graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(

@@ -157,7 +160,7 @@ async def _execute_graph(**kwargs):
             inputs=args.input_data,
             graph_credentials_inputs=args.input_credentials,
         )
-        await increment_onboarding_runs(args.user_id)
+        await db.increment_onboarding_runs(args.user_id)
         elapsed = asyncio.get_event_loop().time() - start_time
         logger.info(
             f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "

@@ -246,8 +249,13 @@ def cleanup_expired_files():
 
 def cleanup_oauth_tokens():
     """Clean up expired OAuth tokens from the database."""
 
+    async def _cleanup():
+        db = get_database_manager_async_client()
+        return await db.cleanup_expired_oauth_tokens()
+
     # Wait for completion
-    run_async(cleanup_expired_oauth_tokens())
+    run_async(_cleanup())
 
 
 def execution_accuracy_alerts():
```
```diff
@@ -1,10 +1,19 @@
+from __future__ import annotations
+
+import logging
 from copy import deepcopy
-from typing import Any
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any
 
 from tiktoken import encoding_for_model
 
 from backend.util import json
 
+if TYPE_CHECKING:
+    from openai import AsyncOpenAI
+
+logger = logging.getLogger(__name__)
+
 # ---------------------------------------------------------------------------#
 # CONSTANTS                                                                   #
 # ---------------------------------------------------------------------------#

@@ -100,9 +109,17 @@ def _is_objective_message(msg: dict) -> bool:
 def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
     """
     Carefully truncate tool message content while preserving tool structure.
-    Only truncates tool_result content, leaves tool_use intact.
+    Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages.
     """
     content = msg.get("content")
+
+    # OpenAI-style tool message: role="tool" with string content
+    if msg.get("role") == "tool" and isinstance(content, str):
+        if _tok_len(content, enc) > max_tokens:
+            msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
+        return
+
+    # Anthropic-style: list content with tool_result items
     if not isinstance(content, list):
         return
 
```
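To make the two message shapes concrete, here is an illustrative sketch of inputs the updated `_truncate_tool_message_content` now accepts. The literal values are examples only; the Anthropic-style layout is inferred from the `tool_result`/`tool_use_id` wording in this diff rather than copied from it.

```python
# Example inputs (not taken from the diff). `enc` would be a tiktoken encoding,
# as used elsewhere in this module.
openai_style = {
    "role": "tool",
    "tool_call_id": "call_123",
    "content": "very long tool output ..." * 1000,  # string content: middle-truncated
}

anthropic_style = {
    "role": "user",
    "content": [
        {"type": "tool_result", "tool_use_id": "toolu_123", "content": "long output ..."},
    ],  # list content: only the tool_result payloads are shortened
}

# Hypothetical calls:
# _truncate_tool_message_content(openai_style, enc, max_tokens=512)
# _truncate_tool_message_content(anthropic_style, enc, max_tokens=512)
```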
@@ -140,141 +157,6 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
 # ---------------------------------------------------------------------------#


-def compress_prompt(
-    messages: list[dict],
-    target_tokens: int,
-    *,
-    model: str = "gpt-4o",
-    reserve: int = 2_048,
-    start_cap: int = 8_192,
-    floor_cap: int = 128,
-    lossy_ok: bool = True,
-) -> list[dict]:
-    """
-    Shrink *messages* so that::
-
-        token_count(prompt) + reserve ≤ target_tokens
-
-    Strategy
-    --------
-    1. **Token-aware truncation** – progressively halve a per-message cap
-       (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
-       *content* of every message except the first and last. Tool shells
-       are included: we keep the envelope but shorten huge payloads.
-    2. **Middle-out deletion** – if still over the limit, delete whole
-       messages working outward from the centre, **skipping** any message
-       that contains ``tool_calls`` or has ``role == "tool"``.
-    3. **Last-chance trim** – if still too big, truncate the *first* and
-       *last* message bodies down to `floor_cap` tokens.
-    4. If the prompt is *still* too large:
-       • raise ``ValueError`` when ``lossy_ok == False`` (default)
-       • return the partially-trimmed prompt when ``lossy_ok == True``
-
-    Parameters
-    ----------
-    messages        Complete chat history (will be deep-copied).
-    model           Model name; passed to tiktoken to pick the right
-                    tokenizer (gpt-4o → 'o200k_base', others fallback).
-    target_tokens   Hard ceiling for prompt size **excluding** the model's
-                    forthcoming answer.
-    reserve         How many tokens you want to leave available for that
-                    answer (`max_tokens` in your subsequent completion call).
-    start_cap       Initial per-message truncation ceiling (tokens).
-    floor_cap       Lowest cap we'll accept before moving to deletions.
-    lossy_ok        If *True* return best-effort prompt instead of raising
-                    after all trim passes have been exhausted.
-
-    Returns
-    -------
-    list[dict] – A *new* messages list that abides by the rules above.
-    """
-    enc = encoding_for_model(model)  # best-match tokenizer
-    msgs = deepcopy(messages)  # never mutate caller
-
-    def total_tokens() -> int:
-        """Current size of *msgs* in tokens."""
-        return sum(_msg_tokens(m, enc) for m in msgs)
-
-    original_token_count = total_tokens()
-
-    if original_token_count + reserve <= target_tokens:
-        return msgs
-
-    # ---- STEP 0 : normalise content --------------------------------------
-    # Convert non-string payloads to strings so token counting is coherent.
-    for i, m in enumerate(msgs):
-        if not isinstance(m.get("content"), str) and m.get("content") is not None:
-            if _is_tool_message(m):
-                continue
-
-            # Keep first and last messages intact (unless they're tool messages)
-            if i == 0 or i == len(msgs) - 1:
-                continue
-
-            # Reasonable 20k-char ceiling prevents pathological blobs
-            content_str = json.dumps(m["content"], separators=(",", ":"))
-            if len(content_str) > 20_000:
-                content_str = _truncate_middle_tokens(content_str, enc, 20_000)
-            m["content"] = content_str
-
-    # ---- STEP 1 : token-aware truncation ---------------------------------
-    cap = start_cap
-    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
-        for m in msgs[1:-1]:  # keep first & last intact
-            if _is_tool_message(m):
-                # For tool messages, only truncate tool result content, preserve structure
-                _truncate_tool_message_content(m, enc, cap)
-                continue
-
-            if _is_objective_message(m):
-                # Never truncate objective messages - they contain the core task
-                continue
-
-            content = m.get("content") or ""
-            if _tok_len(content, enc) > cap:
-                m["content"] = _truncate_middle_tokens(content, enc, cap)
-        cap //= 2  # tighten the screw
-
-    # ---- STEP 2 : middle-out deletion -----------------------------------
-    while total_tokens() + reserve > target_tokens and len(msgs) > 2:
-        # Identify all deletable messages (not first/last, not tool messages, not objective messages)
-        deletable_indices = []
-        for i in range(1, len(msgs) - 1):  # Skip first and last
-            if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
-                deletable_indices.append(i)
-
-        if not deletable_indices:
-            break  # nothing more we can drop
-
-        # Delete from center outward - find the index closest to center
-        centre = len(msgs) // 2
-        to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
-        del msgs[to_delete]
-
-    # ---- STEP 3 : final safety-net trim on first & last ------------------
-    cap = start_cap
-    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
-        for idx in (0, -1):  # first and last
-            if _is_tool_message(msgs[idx]):
-                # For tool messages at first/last position, truncate tool result content only
-                _truncate_tool_message_content(msgs[idx], enc, cap)
-                continue
-
-            text = msgs[idx].get("content") or ""
-            if _tok_len(text, enc) > cap:
-                msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
-        cap //= 2  # tighten the screw
-
-    # ---- STEP 4 : success or fail-gracefully -----------------------------
-    if total_tokens() + reserve > target_tokens and not lossy_ok:
-        raise ValueError(
-            "compress_prompt: prompt still exceeds budget "
-            f"({total_tokens() + reserve} > {target_tokens})."
-        )
-
-    return msgs
-
-
 def estimate_token_count(
     messages: list[dict],
     *,
@@ -293,7 +175,8 @@ def estimate_token_count(
     -------
     int – Token count.
     """
-    enc = encoding_for_model(model)  # best-match tokenizer
+    token_model = _normalize_model_for_tokenizer(model)
+    enc = encoding_for_model(token_model)
     return sum(_msg_tokens(m, enc) for m in messages)


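Aside (not part of the diff): with the normalization above, non-OpenAI model names no longer have to match a tiktoken encoding. A short usage sketch, assuming the keyword-only model parameter shown in this hunk:

from backend.util.prompt import estimate_token_count

messages = [{"role": "user", "content": "Hello there"}]

# Claude / OpenRouter-style names fall back to the gpt-4o tokenizer
print(estimate_token_count(messages, model="anthropic/claude-3-opus"))
print(estimate_token_count(messages, model="gpt-4o"))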
@@ -315,6 +198,543 @@ def estimate_token_count_str(
     -------
     int – Token count.
     """
-    enc = encoding_for_model(model)  # best-match tokenizer
+    token_model = _normalize_model_for_tokenizer(model)
+    enc = encoding_for_model(token_model)
     text = json.dumps(text) if not isinstance(text, str) else text
     return _tok_len(text, enc)
+
+
+# ---------------------------------------------------------------------------#
+# UNIFIED CONTEXT COMPRESSION #
+# ---------------------------------------------------------------------------#
+
+# Default thresholds
+DEFAULT_TOKEN_THRESHOLD = 120_000
+DEFAULT_KEEP_RECENT = 15
+
+
+@dataclass
+class CompressResult:
+    """Result of context compression."""
+
+    messages: list[dict]
+    token_count: int
+    was_compacted: bool
+    error: str | None = None
+    original_token_count: int = 0
+    messages_summarized: int = 0
+    messages_dropped: int = 0
+
+
+def _normalize_model_for_tokenizer(model: str) -> str:
+    """Normalize model name for tiktoken tokenizer selection."""
+    if "/" in model:
+        model = model.split("/")[-1]
+    if "claude" in model.lower() or not any(
+        known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
+    ):
+        return "gpt-4o"
+    return model
+
+
+def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
+    """
+    Extract tool_call IDs from an assistant message.
+
+    Supports both formats:
+    - OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
+    - Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
+
+    Returns:
+        Set of tool_call IDs found in the message.
+    """
+    ids: set[str] = set()
+    if msg.get("role") != "assistant":
+        return ids
+
+    # OpenAI format: tool_calls array
+    if msg.get("tool_calls"):
+        for tc in msg["tool_calls"]:
+            tc_id = tc.get("id")
+            if tc_id:
+                ids.add(tc_id)
+
+    # Anthropic format: content list with tool_use blocks
+    content = msg.get("content")
+    if isinstance(content, list):
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "tool_use":
+                tc_id = block.get("id")
+                if tc_id:
+                    ids.add(tc_id)
+
+    return ids
+
+
+def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
+    """
+    Extract tool_call IDs that this message is responding to.
+
+    Supports both formats:
+    - OpenAI: {"role": "tool", "tool_call_id": "..."}
+    - Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
+
+    Returns:
+        Set of tool_call IDs this message responds to.
+    """
+    ids: set[str] = set()
+
+    # OpenAI format: role=tool with tool_call_id
+    if msg.get("role") == "tool":
+        tc_id = msg.get("tool_call_id")
+        if tc_id:
+            ids.add(tc_id)
+
+    # Anthropic format: content list with tool_result blocks
+    content = msg.get("content")
+    if isinstance(content, list):
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "tool_result":
+                tc_id = block.get("tool_use_id")
+                if tc_id:
+                    ids.add(tc_id)
+
+    return ids
+
+
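Aside (not part of the diff): a brief sketch of what the two extraction helpers above return for each provider format, importing the private names the same way the test file in this diff does:

from backend.util.prompt import (
    _extract_tool_call_ids_from_message,
    _extract_tool_response_ids_from_message,
)

openai_call = {"role": "assistant", "tool_calls": [{"id": "call_1"}]}
anthropic_call = {"role": "assistant", "content": [{"type": "tool_use", "id": "toolu_1"}]}
assert _extract_tool_call_ids_from_message(openai_call) == {"call_1"}
assert _extract_tool_call_ids_from_message(anthropic_call) == {"toolu_1"}

openai_result = {"role": "tool", "tool_call_id": "call_1", "content": "ok"}
anthropic_result = {
    "role": "user",
    "content": [{"type": "tool_result", "tool_use_id": "toolu_1", "content": "ok"}],
}
assert _extract_tool_response_ids_from_message(openai_result) == {"call_1"}
assert _extract_tool_response_ids_from_message(anthropic_result) == {"toolu_1"}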
+def _is_tool_response_message(msg: dict) -> bool:
+    """Check if message is a tool response (OpenAI or Anthropic format)."""
+    # OpenAI format
+    if msg.get("role") == "tool":
+        return True
+    # Anthropic format
+    content = msg.get("content")
+    if isinstance(content, list):
+        for block in content:
+            if isinstance(block, dict) and block.get("type") == "tool_result":
+                return True
+    return False
+
+
+def _remove_orphan_tool_responses(
+    messages: list[dict], orphan_ids: set[str]
+) -> list[dict]:
+    """
+    Remove tool response messages/blocks that reference orphan tool_call IDs.
+
+    Supports both OpenAI and Anthropic formats.
+    For Anthropic messages with mixed valid/orphan tool_result blocks,
+    filters out only the orphan blocks instead of dropping the entire message.
+    """
+    result = []
+    for msg in messages:
+        # OpenAI format: role=tool - drop entire message if orphan
+        if msg.get("role") == "tool":
+            tc_id = msg.get("tool_call_id")
+            if tc_id and tc_id in orphan_ids:
+                continue
+            result.append(msg)
+            continue
+
+        # Anthropic format: content list may have mixed tool_result blocks
+        content = msg.get("content")
+        if isinstance(content, list):
+            has_tool_results = any(
+                isinstance(b, dict) and b.get("type") == "tool_result" for b in content
+            )
+            if has_tool_results:
+                # Filter out orphan tool_result blocks, keep valid ones
+                filtered_content = [
+                    block
+                    for block in content
+                    if not (
+                        isinstance(block, dict)
+                        and block.get("type") == "tool_result"
+                        and block.get("tool_use_id") in orphan_ids
+                    )
+                ]
+                # Only keep message if it has remaining content
+                if filtered_content:
+                    msg = msg.copy()
+                    msg["content"] = filtered_content
+                    result.append(msg)
+                continue
+
+        result.append(msg)
+    return result
+
+
+def _ensure_tool_pairs_intact(
+    recent_messages: list[dict],
+    all_messages: list[dict],
+    start_index: int,
+) -> list[dict]:
+    """
+    Ensure tool_call/tool_response pairs stay together after slicing.
+
+    When slicing messages for context compaction, a naive slice can separate
+    an assistant message containing tool_calls from its corresponding tool
+    response messages. This causes API validation errors (e.g., Anthropic's
+    "unexpected tool_use_id found in tool_result blocks").
+
+    This function checks for orphan tool responses in the slice and extends
+    backwards to include their corresponding assistant messages.
+
+    Supports both formats:
+    - OpenAI: tool_calls array + role="tool" responses
+    - Anthropic: tool_use blocks + tool_result blocks
+
+    Args:
+        recent_messages: The sliced messages to validate
+        all_messages: The complete message list (for looking up missing assistants)
+        start_index: The index in all_messages where recent_messages begins
+
+    Returns:
+        A potentially extended list of messages with tool pairs intact
+    """
+    if not recent_messages:
+        return recent_messages
+
+    # Collect all tool_call_ids from assistant messages in the slice
+    available_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
+
+    # Find orphan tool responses (responses whose tool_call_id is missing)
+    orphan_tool_call_ids: set[str] = set()
+    for msg in recent_messages:
+        response_ids = _extract_tool_response_ids_from_message(msg)
+        for tc_id in response_ids:
+            if tc_id not in available_tool_call_ids:
+                orphan_tool_call_ids.add(tc_id)
+
+    if not orphan_tool_call_ids:
+        # No orphans, slice is valid
+        return recent_messages
+
+    # Find the assistant messages that contain the orphan tool_call_ids
+    # Search backwards from start_index in all_messages
+    messages_to_prepend: list[dict] = []
+    for i in range(start_index - 1, -1, -1):
+        msg = all_messages[i]
+        msg_tool_ids = _extract_tool_call_ids_from_message(msg)
+        if msg_tool_ids & orphan_tool_call_ids:
+            # This assistant message has tool_calls we need
+            # Also collect its contiguous tool responses that follow it
+            assistant_and_responses: list[dict] = [msg]
+
+            # Scan forward from this assistant to collect tool responses
+            for j in range(i + 1, start_index):
+                following_msg = all_messages[j]
+                following_response_ids = _extract_tool_response_ids_from_message(
+                    following_msg
+                )
+                if following_response_ids and following_response_ids & msg_tool_ids:
+                    assistant_and_responses.append(following_msg)
+                elif not _is_tool_response_message(following_msg):
+                    # Stop at first non-tool-response message
+                    break
+
+            # Prepend the assistant and its tool responses (maintain order)
+            messages_to_prepend = assistant_and_responses + messages_to_prepend
+            # Mark these as found
+            orphan_tool_call_ids -= msg_tool_ids
+            # Also add this assistant's tool_call_ids to available set
+            available_tool_call_ids |= msg_tool_ids
+
+            if not orphan_tool_call_ids:
+                # Found all missing assistants
+                break
+
+    if orphan_tool_call_ids:
+        # Some tool_call_ids couldn't be resolved - remove those tool responses
+        # This shouldn't happen in normal operation but handles edge cases
+        logger.warning(
+            f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
+            "Removing orphan tool responses."
+        )
+        recent_messages = _remove_orphan_tool_responses(
+            recent_messages, orphan_tool_call_ids
+        )
+
+    if messages_to_prepend:
+        logger.info(
+            f"Extended recent messages by {len(messages_to_prepend)} to preserve "
+            f"tool_call/tool_response pairs"
+        )
+        return messages_to_prepend + recent_messages
+
+    return recent_messages
+
+
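Aside (not part of the diff): to make the slicing problem concrete, a small sketch, mirroring the tests later in this diff, of how _ensure_tool_pairs_intact repairs a slice that starts on a tool response:

from backend.util.prompt import _ensure_tool_pairs_intact

all_msgs = [
    {"role": "system", "content": "You are helpful."},
    {
        "role": "assistant",
        "tool_calls": [{"id": "call_1", "type": "function", "function": {"name": "f1"}}],
    },
    {"role": "tool", "tool_call_id": "call_1", "content": "result"},
    {"role": "user", "content": "Thanks!"},
]

# A naive slice that drops the assistant tool_call message...
recent = all_msgs[2:]

# ...gets the matching assistant message prepended again.
fixed = _ensure_tool_pairs_intact(recent, all_msgs, start_index=2)
assert fixed[0]["role"] == "assistant"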
+async def _summarize_messages_llm(
+    messages: list[dict],
+    client: AsyncOpenAI,
+    model: str,
+    timeout: float = 30.0,
+) -> str:
+    """Summarize messages using an LLM."""
+    conversation = []
+    for msg in messages:
+        role = msg.get("role", "")
+        content = msg.get("content", "")
+        if content and role in ("user", "assistant", "tool"):
+            conversation.append(f"{role.upper()}: {content}")
+
+    conversation_text = "\n\n".join(conversation)
+
+    if not conversation_text:
+        return "No conversation history available."
+
+    # Limit to ~100k chars for safety
+    MAX_CHARS = 100_000
+    if len(conversation_text) > MAX_CHARS:
+        conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
+
+    response = await client.with_options(timeout=timeout).chat.completions.create(
+        model=model,
+        messages=[
+            {
+                "role": "system",
+                "content": (
+                    "Create a detailed summary of the conversation so far. "
+                    "This summary will be used as context when continuing the conversation.\n\n"
+                    "Before writing the summary, analyze each message chronologically to identify:\n"
+                    "- User requests and their explicit goals\n"
+                    "- Your approach and key decisions made\n"
+                    "- Technical specifics (file names, tool outputs, function signatures)\n"
+                    "- Errors encountered and resolutions applied\n\n"
+                    "You MUST include ALL of the following sections:\n\n"
+                    "## 1. Primary Request and Intent\n"
+                    "The user's explicit goals and what they are trying to accomplish.\n\n"
+                    "## 2. Key Technical Concepts\n"
+                    "Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
+                    "## 3. Files and Resources Involved\n"
+                    "Specific files examined or modified, with relevant snippets and identifiers.\n\n"
+                    "## 4. Errors and Fixes\n"
+                    "Problems encountered, error messages, and their resolutions. "
+                    "Include any user feedback on fixes.\n\n"
+                    "## 5. Problem Solving\n"
+                    "Issues that have been resolved and how they were addressed.\n\n"
+                    "## 6. All User Messages\n"
+                    "A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
+                    "## 7. Pending Tasks\n"
+                    "Work items the user explicitly requested that have not yet been completed.\n\n"
+                    "## 8. Current Work\n"
+                    "Precise description of what was being worked on most recently, including relevant context.\n\n"
+                    "## 9. Next Steps\n"
+                    "What should happen next, aligned with the user's most recent requests. "
+                    "Include verbatim quotes of recent instructions if relevant."
+                ),
+            },
+            {"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
+        ],
+        max_tokens=1500,
+        temperature=0.3,
+    )
+
+    return response.choices[0].message.content or "No summary available."
+
+
+async def compress_context(
+    messages: list[dict],
+    target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
+    *,
+    model: str = "gpt-4o",
+    client: AsyncOpenAI | None = None,
+    keep_recent: int = DEFAULT_KEEP_RECENT,
+    reserve: int = 2_048,
+    start_cap: int = 8_192,
+    floor_cap: int = 128,
+) -> CompressResult:
+    """
+    Unified context compression that combines summarization and truncation strategies.
+
+    Strategy (in order):
+    1. **LLM summarization** – If client provided, summarize old messages into a
+       single context message while keeping recent messages intact. This is the
+       primary strategy for chat service.
+    2. **Content truncation** – Progressively halve a per-message cap and truncate
+       bloated message content (tool outputs, large pastes). Preserves all messages
+       but shortens their content. Primary strategy when client=None (LLM blocks).
+    3. **Middle-out deletion** – Delete whole messages one at a time from the center
+       outward, skipping tool messages and objective messages.
+    4. **First/last trim** – Truncate first and last message content as last resort.
+
+    Parameters
+    ----------
+    messages        Complete chat history (will be deep-copied).
+    target_tokens   Hard ceiling for prompt size.
+    model           Model name for tokenization and summarization.
+    client          AsyncOpenAI client. If provided, enables LLM summarization
+                    as the first strategy. If None, skips to truncation strategies.
+    keep_recent     Number of recent messages to preserve during summarization.
+    reserve         Tokens to reserve for model response.
+    start_cap       Initial per-message truncation ceiling (tokens).
+    floor_cap       Lowest cap before moving to deletions.
+
+    Returns
+    -------
+    CompressResult with compressed messages and metadata.
+    """
+    # Guard clause for empty messages
+    if not messages:
+        return CompressResult(
+            messages=[],
+            token_count=0,
+            was_compacted=False,
+            original_token_count=0,
+        )
+
+    token_model = _normalize_model_for_tokenizer(model)
+    enc = encoding_for_model(token_model)
+    msgs = deepcopy(messages)
+
+    def total_tokens() -> int:
+        return sum(_msg_tokens(m, enc) for m in msgs)
+
+    original_count = total_tokens()
+
+    # Already under limit
+    if original_count + reserve <= target_tokens:
+        return CompressResult(
+            messages=msgs,
+            token_count=original_count,
+            was_compacted=False,
+            original_token_count=original_count,
+        )
+
+    messages_summarized = 0
+    messages_dropped = 0
+
+    # ---- STEP 1: LLM summarization (if client provided) -------------------
+    # This is the primary compression strategy for chat service.
+    # Summarize old messages while keeping recent ones intact.
+    if client is not None:
+        has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
+        system_msg = msgs[0] if has_system else None
+
+        # Calculate old vs recent messages
+        if has_system:
+            if len(msgs) > keep_recent + 1:
+                old_msgs = msgs[1:-keep_recent]
+                recent_msgs = msgs[-keep_recent:]
+            else:
+                old_msgs = []
+                recent_msgs = msgs[1:] if len(msgs) > 1 else []
+        else:
+            if len(msgs) > keep_recent:
+                old_msgs = msgs[:-keep_recent]
+                recent_msgs = msgs[-keep_recent:]
+            else:
+                old_msgs = []
+                recent_msgs = msgs
+
+        # Ensure tool pairs stay intact
+        slice_start = max(0, len(msgs) - keep_recent)
+        recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
+
+        if old_msgs:
+            try:
+                summary_text = await _summarize_messages_llm(old_msgs, client, model)
+                summary_msg = {
+                    "role": "assistant",
+                    "content": f"[Previous conversation summary — for context only]: {summary_text}",
+                }
+                messages_summarized = len(old_msgs)
+
+                if has_system:
+                    msgs = [system_msg, summary_msg] + recent_msgs
+                else:
+                    msgs = [summary_msg] + recent_msgs
+
+                logger.info(
+                    f"Context summarized: {original_count} -> {total_tokens()} tokens, "
+                    f"summarized {messages_summarized} messages"
+                )
+            except Exception as e:
+                logger.warning(f"Summarization failed, continuing with truncation: {e}")
+                # Fall through to content truncation
+
+    # ---- STEP 2: Normalize content ----------------------------------------
+    # Convert non-string payloads to strings so token counting is coherent.
+    # Always run this before truncation to ensure consistent token counting.
+    for i, m in enumerate(msgs):
+        if not isinstance(m.get("content"), str) and m.get("content") is not None:
+            if _is_tool_message(m):
+                continue
+            if i == 0 or i == len(msgs) - 1:
+                continue
+            content_str = json.dumps(m["content"], separators=(",", ":"))
+            if len(content_str) > 20_000:
+                content_str = _truncate_middle_tokens(content_str, enc, 20_000)
+            m["content"] = content_str
+
+    # ---- STEP 3: Token-aware content truncation ---------------------------
+    # Progressively halve per-message cap and truncate bloated content.
+    # This preserves all messages but shortens their content.
+    cap = start_cap
+    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
+        for m in msgs[1:-1]:
+            if _is_tool_message(m):
+                _truncate_tool_message_content(m, enc, cap)
+                continue
+            if _is_objective_message(m):
+                continue
+            content = m.get("content") or ""
+            if _tok_len(content, enc) > cap:
+                m["content"] = _truncate_middle_tokens(content, enc, cap)
+        cap //= 2
+
+    # ---- STEP 4: Middle-out deletion --------------------------------------
+    # Delete messages one at a time from the center outward.
+    # This is more granular than dropping all old messages at once.
+    while total_tokens() + reserve > target_tokens and len(msgs) > 2:
+        deletable: list[int] = []
+        for i in range(1, len(msgs) - 1):
+            msg = msgs[i]
+            if (
+                msg is not None
+                and not _is_tool_message(msg)
+                and not _is_objective_message(msg)
+            ):
+                deletable.append(i)
+        if not deletable:
+            break
+        centre = len(msgs) // 2
+        to_delete = min(deletable, key=lambda i: abs(i - centre))
+        del msgs[to_delete]
+        messages_dropped += 1
+
+    # ---- STEP 5: Final trim on first/last ---------------------------------
+    cap = start_cap
+    while total_tokens() + reserve > target_tokens and cap >= floor_cap:
+        for idx in (0, -1):
+            msg = msgs[idx]
+            if msg is None:
+                continue
+            if _is_tool_message(msg):
+                _truncate_tool_message_content(msg, enc, cap)
+                continue
+            text = msg.get("content") or ""
+            if _tok_len(text, enc) > cap:
+                msg["content"] = _truncate_middle_tokens(text, enc, cap)
+        cap //= 2
+
+    # Filter out any None values that may have been introduced
+    final_msgs: list[dict] = [m for m in msgs if m is not None]
+    final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
+    error = None
+    if final_count + reserve > target_tokens:
+        error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
+        logger.warning(error)
+
+    return CompressResult(
+        messages=final_msgs,
+        token_count=final_count,
+        was_compacted=True,
+        error=error,
+        original_token_count=original_count,
+        messages_summarized=messages_summarized,
+        messages_dropped=messages_dropped,
+    )
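Aside (not part of the diff): a hedged usage sketch of the new entry point. Without a client, compress_context falls back to pure truncation; the client shown in the comment is only an assumption about how a caller would enable summarization.

import asyncio

from backend.util.prompt import compress_context


async def main() -> None:
    messages = [{"role": "user", "content": "word " * 50_000}]

    # No LLM client: truncation / deletion strategies only.
    result = await compress_context(messages, target_tokens=2_000, client=None)
    print(result.was_compacted, result.token_count, result.messages_dropped)

    # With a client, old messages would be summarized first, e.g. (assumption):
    # from openai import AsyncOpenAI
    # result = await compress_context(messages, target_tokens=2_000, client=AsyncOpenAI())


asyncio.run(main())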
@@ -1,10 +1,21 @@
 """Tests for prompt utility functions, especially tool call token counting."""

+from unittest.mock import AsyncMock, MagicMock
+
 import pytest
 from tiktoken import encoding_for_model

 from backend.util import json
-from backend.util.prompt import _msg_tokens, estimate_token_count
+from backend.util.prompt import (
+    CompressResult,
+    _ensure_tool_pairs_intact,
+    _msg_tokens,
+    _normalize_model_for_tokenizer,
+    _truncate_middle_tokens,
+    _truncate_tool_message_content,
+    compress_context,
+    estimate_token_count,
+)


 class TestMsgTokens:
@@ -276,3 +287,690 @@ class TestEstimateTokenCount:

         assert total_tokens == expected_total
         assert total_tokens > 20  # Should be substantial
+
+
+class TestNormalizeModelForTokenizer:
+    """Test model name normalization for tiktoken."""
+
+    def test_openai_models_unchanged(self):
+        """Test that OpenAI models are returned as-is."""
+        assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
+        assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
+
+    def test_claude_models_normalized(self):
+        """Test that Claude models are normalized to gpt-4o."""
+        assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
+
+    def test_openrouter_paths_extracted(self):
+        """Test that OpenRouter model paths are handled."""
+        assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
+
+    def test_unknown_models_default_to_gpt4o(self):
+        """Test that unknown models default to gpt-4o."""
+        assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
+        assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
+
+
+class TestTruncateToolMessageContent:
+    """Test tool message content truncation."""
+
+    @pytest.fixture
+    def enc(self):
+        return encoding_for_model("gpt-4o")
+
+    def test_truncate_openai_tool_message(self, enc):
+        """Test truncation of OpenAI-style tool message with string content."""
+        long_content = "x" * 10000
+        msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
+
+        _truncate_tool_message_content(msg, enc, max_tokens=100)
+
+        # Content should be truncated
+        assert len(msg["content"]) < len(long_content)
+        assert "…" in msg["content"]  # Has ellipsis marker
+
+    def test_truncate_anthropic_tool_result(self, enc):
+        """Test truncation of Anthropic-style tool_result."""
+        long_content = "y" * 10000
+        msg = {
+            "role": "user",
+            "content": [
+                {
+                    "type": "tool_result",
+                    "tool_use_id": "toolu_123",
+                    "content": long_content,
+                }
+            ],
+        }
+
+        _truncate_tool_message_content(msg, enc, max_tokens=100)
+
+        # Content should be truncated
+        result_content = msg["content"][0]["content"]
+        assert len(result_content) < len(long_content)
+        assert "…" in result_content
+
+    def test_preserve_tool_use_blocks(self, enc):
+        """Test that tool_use blocks are not truncated."""
+        msg = {
+            "role": "assistant",
+            "content": [
+                {
+                    "type": "tool_use",
+                    "id": "toolu_123",
+                    "name": "some_function",
+                    "input": {"key": "value" * 1000},  # Large input
+                }
+            ],
+        }
+
+        original = json.dumps(msg["content"][0]["input"])
+        _truncate_tool_message_content(msg, enc, max_tokens=10)
+
+        # tool_use should be unchanged
+        assert json.dumps(msg["content"][0]["input"]) == original
+
+    def test_no_truncation_when_under_limit(self, enc):
+        """Test that short content is not modified."""
+        msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
+
+        original = msg["content"]
+        _truncate_tool_message_content(msg, enc, max_tokens=1000)
+
+        assert msg["content"] == original
+
+
+class TestTruncateMiddleTokens:
+    """Test middle truncation of text."""
+
+    @pytest.fixture
+    def enc(self):
+        return encoding_for_model("gpt-4o")
+
+    def test_truncates_long_text(self, enc):
+        """Test that long text is truncated with ellipsis in middle."""
+        long_text = "word " * 1000
+        result = _truncate_middle_tokens(long_text, enc, max_tok=50)
+
+        assert len(enc.encode(result)) <= 52  # Allow some slack for ellipsis
+        assert "…" in result
+        assert result.startswith("word")  # Head preserved
+        assert result.endswith("word ")  # Tail preserved
+
+    def test_preserves_short_text(self, enc):
+        """Test that short text is not modified."""
+        short_text = "Hello world"
+        result = _truncate_middle_tokens(short_text, enc, max_tok=100)
+
+        assert result == short_text
+
+
+class TestEnsureToolPairsIntact:
+    """Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
+
+    # ---- OpenAI Format Tests ----
+
+    def test_openai_adds_missing_tool_call(self):
+        """Test that orphaned OpenAI tool_response gets its tool_call prepended."""
+        all_msgs = [
+            {"role": "system", "content": "You are helpful."},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "f1"}}
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "result"},
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (the tool response)
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the tool_call message
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        assert "tool_calls" in result[0]
+
+    def test_openai_keeps_complete_pairs(self):
+        """Test that complete OpenAI pairs are unchanged."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "f1"}}
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "result"},
+        ]
+        recent = all_msgs[1:]  # Include both tool_call and response
+        start_index = 1
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        assert len(result) == 2  # No messages added
+
+    def test_openai_multiple_tool_calls(self):
+        """Test multiple OpenAI tool calls in one assistant message."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "f1"}},
+                    {"id": "call_2", "type": "function", "function": {"name": "f2"}},
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "result1"},
+            {"role": "tool", "tool_call_id": "call_2", "content": "result2"},
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (first tool response)
+        recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the assistant message with both tool_calls
+        assert len(result) == 4
+        assert result[0]["role"] == "assistant"
+        assert len(result[0]["tool_calls"]) == 2
+
+    # ---- Anthropic Format Tests ----
+
+    def test_anthropic_adds_missing_tool_use(self):
+        """Test that orphaned Anthropic tool_result gets its tool_use prepended."""
+        all_msgs = [
+            {"role": "system", "content": "You are helpful."},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_123",
+                        "name": "get_weather",
+                        "input": {"location": "SF"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_123",
+                        "content": "22°C and sunny",
+                    }
+                ],
+            },
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (the tool_result)
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the tool_use message
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        assert result[0]["content"][0]["type"] == "tool_use"
+
+    def test_anthropic_keeps_complete_pairs(self):
+        """Test that complete Anthropic pairs are unchanged."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_456",
+                        "name": "calculator",
+                        "input": {"expr": "2+2"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_456",
+                        "content": "4",
+                    }
+                ],
+            },
+        ]
+        recent = all_msgs[1:]  # Include both tool_use and result
+        start_index = 1
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        assert len(result) == 2  # No messages added
+
+    def test_anthropic_multiple_tool_uses(self):
+        """Test multiple Anthropic tool_use blocks in one message."""
+        all_msgs = [
+            {"role": "system", "content": "System"},
+            {
+                "role": "assistant",
+                "content": [
+                    {"type": "text", "text": "Let me check both..."},
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_1",
+                        "name": "get_weather",
+                        "input": {"city": "NYC"},
+                    },
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_2",
+                        "name": "get_weather",
+                        "input": {"city": "LA"},
+                    },
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_1",
+                        "content": "Cold",
+                    },
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_2",
+                        "content": "Warm",
+                    },
+                ],
+            },
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (tool_result)
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the assistant message with both tool_uses
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        tool_use_count = sum(
+            1 for b in result[0]["content"] if b.get("type") == "tool_use"
+        )
+        assert tool_use_count == 2
+
+    # ---- Mixed/Edge Case Tests ----
+
+    def test_anthropic_with_type_message_field(self):
+        """Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
+        all_msgs = [
+            {"role": "system", "content": "You are helpful."},
+            {
+                "role": "assistant",
+                "content": [
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_abc",
+                        "name": "search",
+                        "input": {"q": "test"},
+                    }
+                ],
+            },
+            {
+                "role": "user",
+                "type": "message",  # Extra field from smart_decision_maker
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_abc",
+                        "content": "Found results",
+                    }
+                ],
+            },
+            {"role": "user", "content": "Thanks!"},
+        ]
+        # Recent messages start at index 2 (the tool_result with 'type': 'message')
+        recent = [all_msgs[2], all_msgs[3]]
+        start_index = 2
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the tool_use message
+        assert len(result) == 3
+        assert result[0]["role"] == "assistant"
+        assert result[0]["content"][0]["type"] == "tool_use"
+
+    def test_handles_no_tool_messages(self):
+        """Test messages without tool calls."""
+        all_msgs = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        recent = all_msgs
+        start_index = 0
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        assert result == all_msgs
+
+    def test_handles_empty_messages(self):
+        """Test empty message list."""
+        result = _ensure_tool_pairs_intact([], [], 0)
+        assert result == []
+
+    def test_mixed_text_and_tool_content(self):
+        """Test Anthropic message with mixed text and tool_use content."""
+        all_msgs = [
+            {
+                "role": "assistant",
+                "content": [
+                    {"type": "text", "text": "I'll help you with that."},
+                    {
+                        "type": "tool_use",
+                        "id": "toolu_mixed",
+                        "name": "search",
+                        "input": {"q": "test"},
+                    },
+                ],
+            },
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_mixed",
+                        "content": "Found results",
+                    }
+                ],
+            },
+            {"role": "assistant", "content": "Here are the results..."},
+        ]
+        # Start from tool_result
+        recent = [all_msgs[1], all_msgs[2]]
+        start_index = 1
+
+        result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
+
+        # Should prepend the assistant message with tool_use
+        assert len(result) == 3
+        assert result[0]["content"][0]["type"] == "text"
+        assert result[0]["content"][1]["type"] == "tool_use"
+
+
+class TestCompressContext:
+    """Test the async compress_context function."""
+
+    @pytest.mark.asyncio
+    async def test_no_compression_needed(self):
+        """Test messages under limit return without compression."""
+        messages = [
+            {"role": "system", "content": "You are helpful."},
+            {"role": "user", "content": "Hello!"},
+        ]
+
+        result = await compress_context(messages, target_tokens=100000)
+
+        assert isinstance(result, CompressResult)
+        assert result.was_compacted is False
+        assert len(result.messages) == 2
+        assert result.error is None
+
+    @pytest.mark.asyncio
+    async def test_truncation_without_client(self):
+        """Test that truncation works without LLM client."""
+        long_content = "x" * 50000
+        messages = [
+            {"role": "system", "content": "System"},
+            {"role": "user", "content": long_content},
+            {"role": "assistant", "content": "Response"},
+        ]
+
+        result = await compress_context(
+            messages, target_tokens=1000, client=None, reserve=100
+        )
+
+        assert result.was_compacted is True
+        # Should have truncated without summarization
+        assert result.messages_summarized == 0
+
+    @pytest.mark.asyncio
+    async def test_with_mocked_llm_client(self):
+        """Test summarization with mocked LLM client."""
+        # Create many messages to trigger summarization
+        messages = [{"role": "system", "content": "System prompt"}]
+        for i in range(30):
+            messages.append({"role": "user", "content": f"User message {i} " * 100})
+            messages.append(
+                {"role": "assistant", "content": f"Assistant response {i} " * 100}
+            )
+
+        # Mock the AsyncOpenAI client
+        mock_client = AsyncMock()
+        mock_response = MagicMock()
+        mock_response.choices = [MagicMock()]
+        mock_response.choices[0].message.content = "Summary of conversation"
+        mock_client.with_options.return_value.chat.completions.create = AsyncMock(
+            return_value=mock_response
+        )
+
+        result = await compress_context(
+            messages,
+            target_tokens=5000,
+            client=mock_client,
+            keep_recent=5,
+            reserve=500,
+        )
+
+        assert result.was_compacted is True
+        # Should have attempted summarization
+        assert mock_client.with_options.called or result.messages_summarized > 0
+
+    @pytest.mark.asyncio
+    async def test_preserves_tool_pairs(self):
+        """Test that tool call/response pairs stay together."""
+        messages = [
+            {"role": "system", "content": "System"},
+            {"role": "user", "content": "Do something"},
+            {
+                "role": "assistant",
+                "tool_calls": [
+                    {"id": "call_1", "type": "function", "function": {"name": "func"}}
+                ],
+            },
+            {"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
+            {"role": "assistant", "content": "Done!"},
+        ]
+
+        result = await compress_context(
+            messages, target_tokens=500, client=None, reserve=50
+        )
+
+        # Check that if tool response exists, its call exists too
+        tool_call_ids = set()
+        tool_response_ids = set()
+        for msg in result.messages:
+            if "tool_calls" in msg:
+                for tc in msg["tool_calls"]:
+                    tool_call_ids.add(tc["id"])
+            if msg.get("role") == "tool":
+                tool_response_ids.add(msg.get("tool_call_id"))
+
+        # All tool responses should have their calls
+        assert tool_response_ids <= tool_call_ids
+
+    @pytest.mark.asyncio
+    async def test_returns_error_when_cannot_compress(self):
+        """Test that error is returned when compression fails."""
+        # Single huge message that can't be compressed enough
+        messages = [
+            {"role": "user", "content": "x" * 100000},
+        ]
+
+        result = await compress_context(
+            messages, target_tokens=100, client=None, reserve=50
+        )
+
+        # Should have an error since we can't get below 100 tokens
+        assert result.error is not None
+        assert result.was_compacted is True
+
+    @pytest.mark.asyncio
+    async def test_empty_messages(self):
+        """Test that empty messages list returns early without error."""
+        result = await compress_context([], target_tokens=1000)
+
+        assert result.messages == []
+        assert result.token_count == 0
+        assert result.was_compacted is False
+        assert result.error is None
+
+
+class TestRemoveOrphanToolResponses:
+    """Test _remove_orphan_tool_responses helper function."""
+
+    def test_removes_openai_orphan(self):
+        """Test removal of orphan OpenAI tool response."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
+            {"role": "user", "content": "Hello"},
+        ]
+        orphan_ids = {"call_orphan"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert len(result) == 1
+        assert result[0]["role"] == "user"
+
+    def test_keeps_valid_openai_tool(self):
+        """Test that valid OpenAI tool responses are kept."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {"role": "tool", "tool_call_id": "call_valid", "content": "result"},
+        ]
+        orphan_ids = {"call_other"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert len(result) == 1
+        assert result[0]["tool_call_id"] == "call_valid"
+
+    def test_filters_anthropic_mixed_blocks(self):
+        """Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_valid",
+                        "content": "valid result",
+                    },
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_orphan",
+                        "content": "orphan result",
+                    },
+                ],
+            },
+        ]
+        orphan_ids = {"toolu_orphan"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert len(result) == 1
+        # Should only have the valid tool_result, orphan filtered out
+        assert len(result[0]["content"]) == 1
+        assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
+
+    def test_removes_anthropic_all_orphan(self):
+        """Test removal of Anthropic message when all tool_results are orphans."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_orphan1",
+                        "content": "result1",
+                    },
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": "toolu_orphan2",
+                        "content": "result2",
+                    },
+                ],
+            },
+        ]
+        orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        # Message should be completely removed since no content left
+        assert len(result) == 0
+
+    def test_preserves_non_tool_messages(self):
+        """Test that non-tool messages are preserved."""
+        from backend.util.prompt import _remove_orphan_tool_responses
+
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+        ]
+        orphan_ids = {"some_id"}
+
+        result = _remove_orphan_tool_responses(messages, orphan_ids)
+
+        assert result == messages
+
+
+class TestCompressResultDataclass:
+    """Test CompressResult dataclass."""
+
+    def test_default_values(self):
+        """Test default values are set correctly."""
+        result = CompressResult(
+            messages=[{"role": "user", "content": "test"}],
+            token_count=10,
+            was_compacted=False,
+        )
+
+        assert result.error is None
+        assert result.original_token_count == 0  # Defaults to 0, not None
+        assert result.messages_summarized == 0
+        assert result.messages_dropped == 0
+
+    def test_all_fields(self):
+        """Test all fields can be set."""
+        result = CompressResult(
+            messages=[{"role": "user", "content": "test"}],
+            token_count=100,
+            was_compacted=True,
+            error="Some error",
+            original_token_count=500,
+            messages_summarized=10,
+            messages_dropped=5,
+        )
+
+        assert result.token_count == 100
+        assert result.was_compacted is True
+        assert result.error == "Some error"
+        assert result.original_token_count == 500
+        assert result.messages_summarized == 10
+        assert result.messages_dropped == 5
autogpt_platform/backend/backend/util/validation.py (new file, +32)
@@ -0,0 +1,32 @@
+"""Validation utilities."""
+
+import re
+
+_UUID_V4_PATTERN = re.compile(
+    r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
+    re.IGNORECASE,
+)
+
+
+def is_uuid_v4(text: str) -> bool:
+    """Check if text is a valid UUID v4.
+
+    Args:
+        text: String to validate
+
+    Returns:
+        True if the text is a valid UUID v4, False otherwise
+    """
+    return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))
+
+
+def extract_uuids(text: str) -> list[str]:
+    """Extract all UUID v4 strings from text.
+
+    Args:
+        text: String to search for UUIDs
+
+    Returns:
+        List of unique UUIDs found (lowercase)
+    """
+    return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})
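Aside (not part of the diff): a short, illustrative usage sketch for the new validation helpers:

from backend.util.validation import extract_uuids, is_uuid_v4

assert is_uuid_v4("123e4567-e89b-42d3-a456-426614174000")
assert not is_uuid_v4("not-a-uuid")

text = "run 123e4567-e89b-42d3-a456-426614174000 finished"
assert extract_uuids(text) == ["123e4567-e89b-42d3-a456-426614174000"]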