Mirror of https://github.com/Significant-Gravitas/AutoGPT.git (synced 2026-02-03 11:24:57 -05:00)
Compare commits: fix/schedu ... fix/copilo (4 commits)

| Author | SHA1 | Date |
|---|---|---|
| | efd1e96235 | |
| | 14cee1670a | |
| | d81d1ce024 | |
| | 2dd341c369 | |
@@ -3,7 +3,8 @@ import logging
import time
from asyncio import CancelledError
from collections.abc import AsyncGenerator
from typing import Any
from dataclasses import dataclass
from typing import Any, cast

import openai
import orjson
@@ -15,7 +16,14 @@ from openai import (
PermissionDeniedError,
RateLimitError,
)
from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionChunk,
ChatCompletionMessageParam,
ChatCompletionStreamOptionsParam,
ChatCompletionSystemMessageParam,
ChatCompletionToolParam,
)

from backend.data.redis_client import get_redis_async
from backend.data.understanding import (
@@ -23,6 +31,7 @@ from backend.data.understanding import (
get_business_understanding,
)
from backend.util.exceptions import NotFoundError
from backend.util.prompt import estimate_token_count
from backend.util.settings import Settings

from . import db as chat_db
@@ -794,6 +803,201 @@ def _is_region_blocked_error(error: Exception) -> bool:
return "not available in your region" in str(error).lower()


# Context window management constants
TOKEN_THRESHOLD = 120_000
KEEP_RECENT_MESSAGES = 15


@dataclass
class ContextWindowResult:
"""Result of context window management."""

messages: list[dict[str, Any]]
token_count: int
was_compacted: bool
error: str | None = None


def _messages_to_dicts(messages: list) -> list[dict[str, Any]]:
"""Convert message objects to dicts, filtering None values.

Handles both TypedDict (dict-like) and other message formats.
"""
result = []
for msg in messages:
if msg is None:
continue
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
msg_dict = dict(msg)
result.append(msg_dict)
return result

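The slicing logic in `_manage_context_window` below leans on a helper, `_ensure_tool_pairs_intact`, whose definition is not part of this diff. As a rough sketch of the invariant it appears to enforce (keeping `tool` results next to the assistant message that issued the matching `tool_calls`), assuming OpenAI-style message dicts:

```python
# Sketch only: the real implementation lives elsewhere in this module and may differ.
from typing import Any


def _ensure_tool_pairs_intact_sketch(
    recent: list[dict[str, Any]],
    all_messages: list[dict[str, Any]],
    slice_start: int,
) -> list[dict[str, Any]]:
    # If the slice begins with orphan tool results, walk backwards until the
    # assistant message that produced the matching tool_calls is included again.
    start = slice_start
    while 0 < start < len(all_messages) and all_messages[start].get("role") == "tool":
        start -= 1
    return all_messages[start:] if start != slice_start else recent
```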
async def _manage_context_window(
messages: list,
model: str,
api_key: str | None = None,
base_url: str | None = None,
) -> ContextWindowResult:
"""
Manage context window by summarizing old messages if token count exceeds threshold.

This function handles context compaction for LLM calls by:
1. Counting tokens in the message list
2. If over threshold, summarizing old messages while keeping recent ones
3. Ensuring tool_call/tool_response pairs stay intact
4. Progressively reducing message count if still over limit

Args:
messages: List of messages in OpenAI format (with system prompt if present)
model: Model name for token counting
api_key: API key for summarization calls
base_url: Base URL for summarization calls

Returns:
ContextWindowResult with compacted messages and metadata
"""
if not messages:
return ContextWindowResult([], 0, False, "No messages to compact")

messages_dict = _messages_to_dicts(messages)

# Normalize model name for token counting (tiktoken only supports OpenAI models)
token_count_model = model.split("/")[-1] if "/" in model else model
if "claude" in token_count_model.lower() or not any(
known in token_count_model.lower()
for known in ["gpt", "o1", "chatgpt", "text-"]
):
token_count_model = "gpt-4o"

try:
token_count = estimate_token_count(messages_dict, model=token_count_model)
except Exception as e:
logger.warning(f"Token counting failed: {e}. Using gpt-4o approximation.")
token_count_model = "gpt-4o"
token_count = estimate_token_count(messages_dict, model=token_count_model)

if token_count <= TOKEN_THRESHOLD:
return ContextWindowResult(messages, token_count, False)

has_system_prompt = messages[0].get("role") == "system"
slice_start = max(0, len(messages_dict) - KEEP_RECENT_MESSAGES)
recent_messages = _ensure_tool_pairs_intact(
messages_dict[-KEEP_RECENT_MESSAGES:], messages_dict, slice_start
)

# Determine old messages to summarize (explicit bounds to avoid slice edge cases)
system_msg = messages[0] if has_system_prompt else None
if has_system_prompt:
old_messages_dict = (
messages_dict[1:-KEEP_RECENT_MESSAGES]
if len(messages_dict) > KEEP_RECENT_MESSAGES + 1
else []
)
else:
old_messages_dict = (
messages_dict[:-KEEP_RECENT_MESSAGES]
if len(messages_dict) > KEEP_RECENT_MESSAGES
else []
)

# Try to summarize old messages, fall back to truncation on failure
summary_msg = None
if old_messages_dict:
try:
summary_text = await _summarize_messages(
old_messages_dict, model=model, api_key=api_key, base_url=base_url
)
summary_msg = ChatCompletionAssistantMessageParam(
role="assistant",
content=f"[Previous conversation summary — for context only]: {summary_text}",
)
base = [system_msg, summary_msg] if has_system_prompt else [summary_msg]
messages = base + recent_messages
logger.info(
f"Context summarized: {token_count} tokens, "
f"summarized {len(old_messages_dict)} msgs, kept {KEEP_RECENT_MESSAGES}"
)
except Exception as e:
logger.warning(f"Summarization failed, falling back to truncation: {e}")
messages = (
[system_msg] + recent_messages if has_system_prompt else recent_messages
)
else:
logger.warning(
f"Token count {token_count} exceeds threshold but no old messages to summarize"
)

new_token_count = estimate_token_count(
_messages_to_dicts(messages), model=token_count_model
)

# Progressive truncation if still over limit
if new_token_count > TOKEN_THRESHOLD:
logger.warning(
f"Still over limit: {new_token_count} tokens. Reducing messages."
)
base_msgs = (
recent_messages
if old_messages_dict
else (messages_dict[1:] if has_system_prompt else messages_dict)
)

def build_messages(recent: list) -> list:
"""Build message list with optional system prompt and summary."""
prefix = []
if has_system_prompt and system_msg:
prefix.append(system_msg)
if summary_msg:
prefix.append(summary_msg)
return prefix + recent

for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
if keep_count == 0:
messages = build_messages([])
if not messages:
continue
elif len(base_msgs) < keep_count:
continue
else:
reduced = _ensure_tool_pairs_intact(
base_msgs[-keep_count:],
base_msgs,
max(0, len(base_msgs) - keep_count),
)
messages = build_messages(reduced)

new_token_count = estimate_token_count(
_messages_to_dicts(messages), model=token_count_model
)
if new_token_count <= TOKEN_THRESHOLD:
logger.info(
f"Reduced to {keep_count} messages, {new_token_count} tokens"
)
break
else:
logger.error(
f"Cannot reduce below threshold. Final: {new_token_count} tokens"
)
if has_system_prompt and len(messages) > 1:
messages = messages[1:]
logger.critical("Dropped system prompt as last resort")
return ContextWindowResult(
messages, new_token_count, True, "System prompt dropped"
)
# No system prompt to drop - return error so callers don't proceed with oversized context
return ContextWindowResult(
messages,
new_token_count,
True,
"Unable to reduce context below token limit",
)

return ContextWindowResult(messages, new_token_count, True)
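For orientation, a minimal sketch of how a caller might consume `ContextWindowResult`. It mirrors the two call sites added later in this diff, but the wrapper name and the hard-failure behaviour here are illustrative assumptions:

```python
import logging

logger = logging.getLogger(__name__)


async def prepare_messages(messages: list, model: str) -> list:
    # Hypothetical wrapper around the new helper above.
    result = await _manage_context_window(messages=messages, model=model)
    if result.error and "System prompt dropped" not in result.error:
        # Anything other than the "system prompt dropped" warning is treated as
        # fatal, so an oversized request is never sent downstream.
        raise RuntimeError(f"Context window management failed: {result.error}")
    if result.was_compacted:
        logger.info("Context compacted to %s tokens", result.token_count)
    return result.messages
```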


async def _summarize_messages(
messages: list,
model: str,
@@ -1022,11 +1226,8 @@ async def _stream_chat_chunks(

logger.info("Starting pure chat stream")

# Build messages with system prompt prepended
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam

system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
@@ -1034,314 +1235,38 @@ async def _stream_chat_chunks(
messages = [system_message] + messages

# Apply context window management
token_count = 0  # Initialize for exception handler
try:
from backend.util.prompt import estimate_token_count
context_result = await _manage_context_window(
messages=messages,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)

# Convert to dict for token counting
# OpenAI message types are TypedDicts, so they're already dict-like
messages_dict = []
for msg in messages:
# TypedDict objects are already dicts, just filter None values
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
# Fallback for unexpected types
msg_dict = dict(msg)
messages_dict.append(msg_dict)

# Estimate tokens using appropriate tokenizer
# Normalize model name for token counting (tiktoken only supports OpenAI models)
token_count_model = model
if "/" in model:
# Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5")
token_count_model = model.split("/")[-1]

# For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer
# Most modern LLMs have similar tokenization (~1 token per 4 chars)
if "claude" in token_count_model.lower() or not any(
known in token_count_model.lower()
for known in ["gpt", "o1", "chatgpt", "text-"]
):
token_count_model = "gpt-4o"

# Attempt token counting with error handling
try:
token_count = estimate_token_count(messages_dict, model=token_count_model)
except Exception as token_error:
# If token counting fails, use gpt-4o as fallback approximation
logger.warning(
f"Token counting failed for model {token_count_model}: {token_error}. "
"Using gpt-4o approximation."
)
token_count = estimate_token_count(messages_dict, model="gpt-4o")

# If over threshold, summarize old messages
if token_count > 120_000:
KEEP_RECENT = 15

# Check if we have a system prompt at the start
has_system_prompt = (
len(messages) > 0 and messages[0].get("role") == "system"
)

# Always attempt mitigation when over limit, even with few messages
if messages:
# Split messages based on whether system prompt exists
# Calculate start index for the slice
slice_start = max(0, len(messages_dict) - KEEP_RECENT)
recent_messages = messages_dict[-KEEP_RECENT:]

# Ensure tool_call/tool_response pairs stay together
# This prevents API errors from orphan tool responses
recent_messages = _ensure_tool_pairs_intact(
recent_messages, messages_dict, slice_start
)

if has_system_prompt:
# Keep system prompt separate, summarize everything between system and recent
system_msg = messages[0]
old_messages_dict = messages_dict[1:-KEEP_RECENT]
else:
# No system prompt, summarize everything except recent
system_msg = None
old_messages_dict = messages_dict[:-KEEP_RECENT]

# Summarize any non-empty old messages (no minimum threshold)
# If we're over the token limit, we need to compress whatever we can
if old_messages_dict:
# Summarize old messages using the same model as chat
summary_text = await _summarize_messages(
old_messages_dict,
model=model,
api_key=config.api_key,
base_url=config.base_url,
)

# Build new message list
# Use assistant role (not system) to prevent privilege escalation
# of user-influenced content to instruction-level authority
from openai.types.chat import ChatCompletionAssistantMessageParam

summary_msg = ChatCompletionAssistantMessageParam(
role="assistant",
content=(
"[Previous conversation summary — for context only]: "
f"{summary_text}"
),
)

# Rebuild messages based on whether we have a system prompt
if has_system_prompt:
# system_prompt + summary + recent_messages
messages = [system_msg, summary_msg] + recent_messages
else:
# summary + recent_messages (no original system prompt)
messages = [summary_msg] + recent_messages

logger.info(
f"Context summarized: {token_count} tokens, "
f"summarized {len(old_messages_dict)} old messages, "
f"kept last {KEEP_RECENT} messages"
)

# Fallback: If still over limit after summarization, progressively drop recent messages
# This handles edge cases where recent messages are extremely large
new_messages_dict = []
for msg in messages:
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
msg_dict = dict(msg)
new_messages_dict.append(msg_dict)

new_token_count = estimate_token_count(
new_messages_dict, model=token_count_model
)

if new_token_count > 120_000:
# Still over limit - progressively reduce KEEP_RECENT
logger.warning(
f"Still over limit after summarization: {new_token_count} tokens. "
"Reducing number of recent messages kept."
)

for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
if keep_count == 0:
# Try with just system prompt + summary (no recent messages)
if has_system_prompt:
messages = [system_msg, summary_msg]
else:
messages = [summary_msg]
logger.info(
"Trying with 0 recent messages (system + summary only)"
)
else:
# Slice from ORIGINAL recent_messages to avoid duplicating summary
reduced_recent = (
recent_messages[-keep_count:]
if len(recent_messages) >= keep_count
else recent_messages
)
# Ensure tool pairs stay intact in the reduced slice
reduced_slice_start = max(
0, len(recent_messages) - keep_count
)
reduced_recent = _ensure_tool_pairs_intact(
reduced_recent, recent_messages, reduced_slice_start
)
if has_system_prompt:
messages = [
system_msg,
summary_msg,
] + reduced_recent
else:
messages = [summary_msg] + reduced_recent

new_messages_dict = []
for msg in messages:
if isinstance(msg, dict):
msg_dict = {
k: v for k, v in msg.items() if v is not None
}
else:
msg_dict = dict(msg)
new_messages_dict.append(msg_dict)

new_token_count = estimate_token_count(
new_messages_dict, model=token_count_model
)

if new_token_count <= 120_000:
logger.info(
f"Reduced to {keep_count} recent messages, "
f"now {new_token_count} tokens"
)
break
else:
logger.error(
f"Unable to reduce token count below threshold even with 0 messages. "
f"Final count: {new_token_count} tokens"
)
# ABSOLUTE LAST RESORT: Drop system prompt
# This should only happen if summary itself is massive
if has_system_prompt and len(messages) > 1:
messages = messages[1:]  # Drop system prompt
logger.critical(
"CRITICAL: Dropped system prompt as absolute last resort. "
"Behavioral consistency may be affected."
)
# Yield error to user
yield StreamError(
errorText=(
"Warning: System prompt dropped due to size constraints. "
"Assistant behavior may be affected."
)
)
else:
# No old messages to summarize - all messages are "recent"
# Apply progressive truncation to reduce token count
logger.warning(
f"Token count {token_count} exceeds threshold but no old messages to summarize. "
f"Applying progressive truncation to recent messages."
)

# Create a base list excluding system prompt to avoid duplication
# This is the pool of messages we'll slice from in the loop
# Use messages_dict for type consistency with _ensure_tool_pairs_intact
base_msgs = (
messages_dict[1:] if has_system_prompt else messages_dict
)

# Try progressively smaller keep counts
new_token_count = token_count  # Initialize with current count
for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
if keep_count == 0:
# Try with just system prompt (no recent messages)
if has_system_prompt:
messages = [system_msg]
logger.info(
"Trying with 0 recent messages (system prompt only)"
)
else:
# No system prompt and no recent messages = empty messages list
# This is invalid, skip this iteration
continue
else:
if len(base_msgs) < keep_count:
continue  # Skip if we don't have enough messages

# Slice from base_msgs to get recent messages (without system prompt)
recent_messages = base_msgs[-keep_count:]

# Ensure tool pairs stay intact in the reduced slice
reduced_slice_start = max(0, len(base_msgs) - keep_count)
recent_messages = _ensure_tool_pairs_intact(
recent_messages, base_msgs, reduced_slice_start
)

if has_system_prompt:
messages = [system_msg] + recent_messages
else:
messages = recent_messages

new_messages_dict = []
for msg in messages:
if msg is None:
continue  # Skip None messages (type safety)
if isinstance(msg, dict):
msg_dict = {
k: v for k, v in msg.items() if v is not None
}
else:
msg_dict = dict(msg)
new_messages_dict.append(msg_dict)

new_token_count = estimate_token_count(
new_messages_dict, model=token_count_model
)

if new_token_count <= 120_000:
logger.info(
f"Reduced to {keep_count} recent messages, "
f"now {new_token_count} tokens"
)
break
else:
# Even with 0 messages still over limit
logger.error(
f"Unable to reduce token count below threshold even with 0 messages. "
f"Final count: {new_token_count} tokens. Messages may be extremely large."
)
# ABSOLUTE LAST RESORT: Drop system prompt
if has_system_prompt and len(messages) > 1:
messages = messages[1:]  # Drop system prompt
logger.critical(
"CRITICAL: Dropped system prompt as absolute last resort. "
"Behavioral consistency may be affected."
)
# Yield error to user
yield StreamError(
errorText=(
"Warning: System prompt dropped due to size constraints. "
"Assistant behavior may be affected."
)
)

except Exception as e:
logger.error(f"Context summarization failed: {e}", exc_info=True)
# If we were over the token limit, yield error to user
# Don't silently continue with oversized messages that will fail
if token_count > 120_000:
if context_result.error:
if "System prompt dropped" in context_result.error:
# Warning only - continue with reduced context
yield StreamError(
errorText=(
f"Unable to manage context window (token limit exceeded: {token_count} tokens). "
"Context summarization failed. Please start a new conversation."
"Warning: System prompt dropped due to size constraints. "
"Assistant behavior may be affected."
)
)
else:
# Any other error - abort to prevent failed LLM calls
yield StreamError(
errorText=(
f"Context window management failed: {context_result.error}. "
"Please start a new conversation."
)
)
yield StreamFinish()
return
# Otherwise, continue with original messages (under limit)

messages = context_result.messages
if context_result.was_compacted:
logger.info(
f"Context compacted for streaming: {context_result.token_count} tokens"
)

# Loop to handle tool calls and continue conversation
while True:
@@ -1369,14 +1294,6 @@ async def _stream_chat_chunks(
:128
]  # OpenRouter limit

# Create the stream with proper types
from typing import cast

from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionStreamOptionsParam,
)

stream = await client.chat.completions.create(
model=model,
messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1502,6 +1419,7 @@ async def _stream_chat_chunks(
return
except Exception as e:
last_error = e

if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1
# Calculate delay with exponential backoff
@@ -1517,12 +1435,24 @@ async def _stream_chat_chunks(
continue  # Retry the stream
else:
# Non-retryable error or max retries exceeded
logger.error(
f"Error in stream (not retrying): {e!s}",
exc_info=True,
_log_api_error(
error=e,
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=retry_count,
)
error_code = None
error_text = str(e)

error_details = _extract_api_error_details(e)
if error_details.get("response_body"):
body = error_details["response_body"]
if isinstance(body, dict) and body.get("error", {}).get(
"message"
):
error_text = body["error"]["message"]

if _is_region_blocked_error(e):
error_code = "MODEL_NOT_AVAILABLE_REGION"
error_text = (
@@ -1539,9 +1469,12 @@ async def _stream_chat_chunks(

# If we exit the retry loop without returning, it means we exhausted retries
if last_error:
logger.error(
f"Max retries ({MAX_RETRIES}) exceeded. Last error: {last_error!s}",
exc_info=True,
_log_api_error(
error=last_error,
session_id=session.session_id if session else None,
message_count=len(messages) if messages else None,
model=model,
retry_count=MAX_RETRIES,
)
yield StreamError(errorText=f"Max retries exceeded: {last_error!s}")
yield StreamFinish()
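Both this retry loop and the continuation retry loop later in the diff compute their sleep time as `min(BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), MAX_DELAY_SECONDS)`. A quick worked example of that schedule, with the constants set to assumed values (the real ones are defined elsewhere in the module and are not shown here):

```python
# Assumed values for illustration only.
BASE_DELAY_SECONDS = 1.0
MAX_DELAY_SECONDS = 30.0
MAX_RETRIES = 5

for retry_count in range(1, MAX_RETRIES + 1):
    delay = min(BASE_DELAY_SECONDS * (2 ** (retry_count - 1)), MAX_DELAY_SECONDS)
    print(retry_count, delay)  # 1 -> 1.0, 2 -> 2.0, 3 -> 4.0, 4 -> 8.0, 5 -> 16.0
```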
@@ -1900,17 +1833,36 @@ async def _generate_llm_continuation(
# Build system prompt
system_prompt, _ = await _build_system_prompt(user_id)

# Build messages in OpenAI format
messages = session.to_openai_messages()
if system_prompt:
from openai.types.chat import ChatCompletionSystemMessageParam

system_message = ChatCompletionSystemMessageParam(
role="system",
content=system_prompt,
)
messages = [system_message] + messages

# Apply context window management to prevent oversized requests
context_result = await _manage_context_window(
messages=messages,
model=config.model,
api_key=config.api_key,
base_url=config.base_url,
)

if context_result.error and "System prompt dropped" not in context_result.error:
logger.error(
f"Context window management failed for session {session_id}: "
f"{context_result.error} (tokens={context_result.token_count})"
)
return

messages = context_result.messages
if context_result.was_compacted:
logger.info(
f"Context compacted for LLM continuation: "
f"{context_result.token_count} tokens"
)

# Build extra_body for tracing
extra_body: dict[str, Any] = {
"posthogProperties": {
@@ -1923,19 +1875,61 @@ async def _generate_llm_continuation(
if session_id:
extra_body["session_id"] = session_id[:128]

# Make non-streaming LLM call (no tools - just text response)
from typing import cast
retry_count = 0
last_error: Exception | None = None
response = None

from openai.types.chat import ChatCompletionMessageParam
while retry_count <= MAX_RETRIES:
try:
logger.info(
f"Generating LLM continuation for session {session_id}"
f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
)

# No tools parameter = text-only response (no tool calls)
response = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
)
response = await client.chat.completions.create(
model=config.model,
messages=cast(list[ChatCompletionMessageParam], messages),
extra_body=extra_body,
)
last_error = None  # Clear any previous error on success
break  # Success, exit retry loop
except Exception as e:
last_error = e

if response.choices and response.choices[0].message.content:
if _is_retryable_error(e) and retry_count < MAX_RETRIES:
retry_count += 1
delay = min(
BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
MAX_DELAY_SECONDS,
)
logger.warning(
f"Retryable error in LLM continuation: {e!s}. "
f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
)
await asyncio.sleep(delay)
continue
else:
# Non-retryable error - log details and exit gracefully
_log_api_error(
error=e,
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=retry_count,
)
return

if last_error:
_log_api_error(
error=last_error,
session_id=session_id,
message_count=len(messages) if messages else None,
model=config.model,
retry_count=MAX_RETRIES,
)
return

if response and response.choices and response.choices[0].message.content:
assistant_content = response.choices[0].message.content

# Reload session from DB to avoid race condition with user messages
@@ -1969,3 +1963,78 @@ async def _generate_llm_continuation(

except Exception as e:
logger.error(f"Failed to generate LLM continuation: {e}", exc_info=True)


def _log_api_error(
error: Exception,
session_id: str | None = None,
message_count: int | None = None,
model: str | None = None,
retry_count: int = 0,
) -> None:
"""Log detailed API error information for debugging."""
details = _extract_api_error_details(error)
details["session_id"] = session_id
details["message_count"] = message_count
details["model"] = model
details["retry_count"] = retry_count

if isinstance(error, RateLimitError):
logger.warning(f"Rate limit error: {details}")
elif isinstance(error, APIConnectionError):
logger.warning(f"API connection error: {details}")
elif isinstance(error, APIStatusError) and error.status_code >= 500:
logger.error(f"API server error (5xx): {details}")
else:
logger.error(f"API error: {details}")


def _extract_api_error_details(error: Exception) -> dict[str, Any]:
"""Extract detailed information from OpenAI/OpenRouter API errors."""
error_msg = str(error)
details: dict[str, Any] = {
"error_type": type(error).__name__,
"error_message": error_msg[:500] + "..." if len(error_msg) > 500 else error_msg,
}

if hasattr(error, "code"):
details["code"] = error.code
if hasattr(error, "param"):
details["param"] = error.param

if isinstance(error, APIStatusError):
details["status_code"] = error.status_code
details["request_id"] = getattr(error, "request_id", None)

if hasattr(error, "body") and error.body:
details["response_body"] = _sanitize_error_body(error.body)

if hasattr(error, "response") and error.response:
headers = error.response.headers
details["openrouter_provider"] = headers.get("x-openrouter-provider")
details["openrouter_model"] = headers.get("x-openrouter-model")
details["retry_after"] = headers.get("retry-after")
details["rate_limit_remaining"] = headers.get("x-ratelimit-remaining")

return details


def _sanitize_error_body(body: Any, max_length: int = 2000) -> dict[str, Any] | None:
"""Extract only safe fields from error response body to avoid logging sensitive data."""
if not isinstance(body, dict):
return None

safe_fields = ("message", "type", "code", "param", "error")
sanitized: dict[str, Any] = {}

for field in safe_fields:
if field in body:
value = body[field]
if field == "error" and isinstance(value, dict):
sanitized[field] = _sanitize_error_body(value, max_length)
elif isinstance(value, str) and len(value) > max_length:
sanitized[field] = value[:max_length] + "...[truncated]"
else:
sanitized[field] = value

return sanitized if sanitized else None
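To make the sanitization rules concrete, here is what `_sanitize_error_body` would keep and drop for a typical OpenAI-style error body (the sample input is assumed, not taken from the codebase):

```python
body = {
    "error": {"message": "Rate limit exceeded", "type": "rate_limit_error", "code": 429},
    "headers": {"authorization": "Bearer sk-..."},  # not in safe_fields, so dropped
}
print(_sanitize_error_body(body))
# -> {'error': {'message': 'Rate limit exceeded', 'type': 'rate_limit_error', 'code': 429}}
```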
@@ -139,11 +139,10 @@ async def decompose_goal_external(
"""
client = _get_client()

# Build the request payload
payload: dict[str, Any] = {"description": description}
if context:
# The external service uses user_instruction for additional context
payload["user_instruction"] = context
description = f"{description}\n\nAdditional context from user:\n{context}"

payload: dict[str, Any] = {"description": description}
if library_agents:
payload["library_agents"] = library_agents


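With this change, extra context is folded into the description text instead of being sent as a separate `user_instruction` field. For example, a call like the one in the updated test at the end of this diff would produce roughly this request (sketch, assuming no `library_agents`):

```python
result = await decompose_goal_external("Build a chatbot", context="Use Python")
# POST /api/decompose-description with:
# {"description": "Build a chatbot\n\nAdditional context from user:\nUse Python"}
```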
@@ -66,18 +66,24 @@ async def event_broadcaster(manager: ConnectionManager):
execution_bus = AsyncRedisExecutionEventBus()
notification_bus = AsyncRedisNotificationEventBus()

async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)
try:

async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)
async def execution_worker():
async for event in execution_bus.listen("*"):
await manager.send_execution_update(event)

await asyncio.gather(execution_worker(), notification_worker())
async def notification_worker():
async for notification in notification_bus.listen("*"):
await manager.send_notification(
user_id=notification.user_id,
payload=notification.payload,
)

await asyncio.gather(execution_worker(), notification_worker())
finally:
# Ensure PubSub connections are closed on any exit to prevent leaks
await execution_bus.close()
await notification_bus.close()

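The broadcaster change above pairs every `listen()` consumer with the new `close()` method in a `try`/`finally`. Condensed, the lifecycle it guarantees looks like this sketch (the bus class and manager call come from this diff; the standalone function is illustrative):

```python
async def consume_execution_events(manager) -> None:
    bus = AsyncRedisExecutionEventBus()
    try:
        async for event in bus.listen("*"):
            await manager.send_execution_update(event)
    finally:
        # Always release the underlying PubSub connection, even if the
        # listener is cancelled or raises.
        await bus.close()
```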
async def authenticate_websocket(websocket: WebSocket) -> str:

@@ -133,10 +133,23 @@ class RedisEventBus(BaseRedisEventBus[M], ABC):


class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
def __init__(self):
self._pubsub: AsyncPubSub | None = None

@property
async def connection(self) -> redis.AsyncRedis:
return await redis.get_redis_async()

async def close(self) -> None:
"""Close the PubSub connection if it exists."""
if self._pubsub is not None:
try:
await self._pubsub.close()
except Exception:
logger.warning("Failed to close PubSub connection", exc_info=True)
finally:
self._pubsub = None

async def publish_event(self, event: M, channel_key: str):
"""
Publish an event to Redis. Gracefully handles connection failures
@@ -157,6 +170,7 @@ class AsyncRedisEventBus(BaseRedisEventBus[M], ABC):
await self.connection, channel_key
)
assert isinstance(pubsub, AsyncPubSub)
self._pubsub = pubsub

if "*" in channel_key:
await pubsub.psubscribe(full_channel_name)

@@ -193,11 +193,9 @@ async def _handle_graph_validation_error(args: "GraphExecutionJobArgs") -> None:
user_id=args.user_id,
)
else:
logger.warning(
f"Old scheduled job for graph {args.graph_id} (user {args.user_id}) "
f"has no schedule_id, attempting targeted cleanup"
logger.error(
f"Unable to unschedule graph: {args.graph_id} as this is an old job with no associated schedule_id please remove manually"
)
await _cleanup_old_schedules_without_id(args.graph_id, args.user_id)


async def _handle_graph_not_available(
@@ -240,35 +238,6 @@ async def _cleanup_orphaned_schedules_for_graph(graph_id: str, user_id: str) ->
)


async def _cleanup_old_schedules_without_id(graph_id: str, user_id: str) -> None:
"""Remove only schedules that have no schedule_id in their job args.

Unlike _cleanup_orphaned_schedules_for_graph (which removes ALL schedules
for a graph), this only targets legacy jobs created before schedule_id was
added to GraphExecutionJobArgs, preserving any valid newer schedules.
"""
scheduler_client = get_scheduler_client()
schedules = await scheduler_client.get_execution_schedules(
graph_id=graph_id, user_id=user_id
)

for schedule in schedules:
if schedule.schedule_id is not None:
continue
try:
await scheduler_client.delete_schedule(
schedule_id=schedule.id, user_id=user_id
)
logger.info(
f"Cleaned up old schedule {schedule.id} (no schedule_id) "
f"for graph {graph_id}"
)
except Exception:
logger.exception(
f"Failed to delete old schedule {schedule.id} for graph {graph_id}"
)


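The targeted cleanup above deletes only legacy schedules whose job args carry no `schedule_id`. A toy illustration of that selection rule (the dict shape here is an assumption; the real objects are scheduler client models):

```python
schedules = [
    {"id": "a", "schedule_id": None},         # legacy job -> deleted
    {"id": "b", "schedule_id": "sched-123"},  # has schedule_id -> preserved
]
to_delete = [s["id"] for s in schedules if s["schedule_id"] is None]
assert to_delete == ["a"]
```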
def cleanup_expired_files():
"""Clean up expired files from cloud storage."""
# Wait for completion

@@ -102,7 +102,7 @@ class TestDecomposeGoalExternal:

@pytest.mark.asyncio
async def test_decompose_goal_with_context(self):
"""Test decomposition with additional context."""
"""Test decomposition with additional context enriched into description."""
mock_response = MagicMock()
mock_response.json.return_value = {
"success": True,
@@ -119,9 +119,12 @@ class TestDecomposeGoalExternal:
"Build a chatbot", context="Use Python"
)

expected_description = (
"Build a chatbot\n\nAdditional context from user:\nUse Python"
)
mock_client.post.assert_called_once_with(
"/api/decompose-description",
json={"description": "Build a chatbot", "user_instruction": "Use Python"},
json={"description": expected_description},
)

@pytest.mark.asyncio
