diff --git a/autogpt_platform/backend/backend/api/features/chat/service.py b/autogpt_platform/backend/backend/api/features/chat/service.py
index bcd6856503..f1f3156713 100644
--- a/autogpt_platform/backend/backend/api/features/chat/service.py
+++ b/autogpt_platform/backend/backend/api/features/chat/service.py
@@ -3,7 +3,8 @@ import logging
 import time
 from asyncio import CancelledError
 from collections.abc import AsyncGenerator
-from typing import Any
+from dataclasses import dataclass
+from typing import Any, cast
 
 import openai
 import orjson
@@ -15,7 +16,14 @@ from openai import (
     PermissionDeniedError,
     RateLimitError,
 )
-from openai.types.chat import ChatCompletionChunk, ChatCompletionToolParam
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionChunk,
+    ChatCompletionMessageParam,
+    ChatCompletionStreamOptionsParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolParam,
+)
 
 from backend.data.redis_client import get_redis_async
 from backend.data.understanding import (
@@ -23,6 +31,7 @@ from backend.data.understanding import (
     get_business_understanding,
 )
 from backend.util.exceptions import NotFoundError
+from backend.util.prompt import estimate_token_count
 from backend.util.settings import Settings
 
 from . import db as chat_db
@@ -794,6 +803,201 @@ def _is_region_blocked_error(error: Exception) -> bool:
     return "not available in your region" in str(error).lower()
 
 
+# Context window management constants
+TOKEN_THRESHOLD = 120_000
+KEEP_RECENT_MESSAGES = 15
+
+
+@dataclass
+class ContextWindowResult:
+    """Result of context window management."""
+
+    messages: list[dict[str, Any]]
+    token_count: int
+    was_compacted: bool
+    error: str | None = None
+
+
+def _messages_to_dicts(messages: list) -> list[dict[str, Any]]:
+    """Convert message objects to dicts, filtering None values.
+
+    Handles both TypedDict (dict-like) and other message formats.
+    """
+    result = []
+    for msg in messages:
+        if msg is None:
+            continue
+        if isinstance(msg, dict):
+            msg_dict = {k: v for k, v in msg.items() if v is not None}
+        else:
+            msg_dict = dict(msg)
+        result.append(msg_dict)
+    return result
+
+
+async def _manage_context_window(
+    messages: list,
+    model: str,
+    api_key: str | None = None,
+    base_url: str | None = None,
+) -> ContextWindowResult:
+    """
+    Manage context window by summarizing old messages if token count exceeds threshold.
+
+    This function handles context compaction for LLM calls by:
+    1. Counting tokens in the message list
+    2. If over threshold, summarizing old messages while keeping recent ones
+    3. Ensuring tool_call/tool_response pairs stay intact
+    4. Progressively reducing message count if still over limit
+
+    Args:
+        messages: List of messages in OpenAI format (with system prompt if present)
+        model: Model name for token counting
+        api_key: API key for summarization calls
+        base_url: Base URL for summarization calls
+
+    Returns:
+        ContextWindowResult with compacted messages and metadata
+    """
+    if not messages:
+        return ContextWindowResult([], 0, False, "No messages to compact")
+
+    messages_dict = _messages_to_dicts(messages)
+
+    # Normalize model name for token counting (tiktoken only supports OpenAI models)
+    token_count_model = model.split("/")[-1] if "/" in model else model
+    if "claude" in token_count_model.lower() or not any(
+        known in token_count_model.lower()
+        for known in ["gpt", "o1", "chatgpt", "text-"]
+    ):
+        token_count_model = "gpt-4o"
+
+    try:
+        token_count = estimate_token_count(messages_dict, model=token_count_model)
+    except Exception as e:
+        logger.warning(f"Token counting failed: {e}. Using gpt-4o approximation.")
+        token_count_model = "gpt-4o"
+        token_count = estimate_token_count(messages_dict, model=token_count_model)
+
+    if token_count <= TOKEN_THRESHOLD:
+        return ContextWindowResult(messages, token_count, False)
+
+    has_system_prompt = messages[0].get("role") == "system"
+    slice_start = max(0, len(messages_dict) - KEEP_RECENT_MESSAGES)
+    recent_messages = _ensure_tool_pairs_intact(
+        messages_dict[-KEEP_RECENT_MESSAGES:], messages_dict, slice_start
+    )
+
+    # Determine old messages to summarize (explicit bounds to avoid slice edge cases)
+    system_msg = messages[0] if has_system_prompt else None
+    if has_system_prompt:
+        old_messages_dict = (
+            messages_dict[1:-KEEP_RECENT_MESSAGES]
+            if len(messages_dict) > KEEP_RECENT_MESSAGES + 1
+            else []
+        )
+    else:
+        old_messages_dict = (
+            messages_dict[:-KEEP_RECENT_MESSAGES]
+            if len(messages_dict) > KEEP_RECENT_MESSAGES
+            else []
+        )
+
+    # Try to summarize old messages, fall back to truncation on failure
+    summary_msg = None
+    if old_messages_dict:
+        try:
+            summary_text = await _summarize_messages(
+                old_messages_dict, model=model, api_key=api_key, base_url=base_url
+            )
+            summary_msg = ChatCompletionAssistantMessageParam(
+                role="assistant",
+                content=f"[Previous conversation summary — for context only]: {summary_text}",
+            )
+            base = [system_msg, summary_msg] if has_system_prompt else [summary_msg]
+            messages = base + recent_messages
+            logger.info(
+                f"Context summarized: {token_count} tokens, "
+                f"summarized {len(old_messages_dict)} msgs, kept {KEEP_RECENT_MESSAGES}"
+            )
+        except Exception as e:
+            logger.warning(f"Summarization failed, falling back to truncation: {e}")
+            messages = (
+                [system_msg] + recent_messages if has_system_prompt else recent_messages
+            )
+    else:
+        logger.warning(
+            f"Token count {token_count} exceeds threshold but no old messages to summarize"
+        )
+
+    new_token_count = estimate_token_count(
+        _messages_to_dicts(messages), model=token_count_model
+    )
+
+    # Progressive truncation if still over limit
+    if new_token_count > TOKEN_THRESHOLD:
+        logger.warning(
+            f"Still over limit: {new_token_count} tokens. Reducing messages."
+        )
+        base_msgs = (
+            recent_messages
+            if old_messages_dict
+            else (messages_dict[1:] if has_system_prompt else messages_dict)
+        )
+
+        def build_messages(recent: list) -> list:
+            """Build message list with optional system prompt and summary."""
+            prefix = []
+            if has_system_prompt and system_msg:
+                prefix.append(system_msg)
+            if summary_msg:
+                prefix.append(summary_msg)
+            return prefix + recent
+
+        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
+            if keep_count == 0:
+                messages = build_messages([])
+                if not messages:
+                    continue
+            elif len(base_msgs) < keep_count:
+                continue
+            else:
+                reduced = _ensure_tool_pairs_intact(
+                    base_msgs[-keep_count:],
+                    base_msgs,
+                    max(0, len(base_msgs) - keep_count),
+                )
+                messages = build_messages(reduced)
+
+            new_token_count = estimate_token_count(
+                _messages_to_dicts(messages), model=token_count_model
+            )
+            if new_token_count <= TOKEN_THRESHOLD:
+                logger.info(
+                    f"Reduced to {keep_count} messages, {new_token_count} tokens"
+                )
+                break
+        else:
+            logger.error(
+                f"Cannot reduce below threshold. Final: {new_token_count} tokens"
+            )
+            if has_system_prompt and len(messages) > 1:
+                messages = messages[1:]
+                logger.critical("Dropped system prompt as last resort")
+                return ContextWindowResult(
+                    messages, new_token_count, True, "System prompt dropped"
+                )
+            # No system prompt to drop - return error so callers don't proceed with oversized context
+            return ContextWindowResult(
+                messages,
+                new_token_count,
+                True,
+                "Unable to reduce context below token limit",
+            )
+
+    return ContextWindowResult(messages, new_token_count, True)
+
+
 async def _summarize_messages(
     messages: list,
     model: str,
@@ -1022,11 +1226,8 @@ async def _stream_chat_chunks(
 
     logger.info("Starting pure chat stream")
 
-    # Build messages with system prompt prepended
     messages = session.to_openai_messages()
     if system_prompt:
-        from openai.types.chat import ChatCompletionSystemMessageParam
-
         system_message = ChatCompletionSystemMessageParam(
             role="system",
             content=system_prompt,
@@ -1034,314 +1235,38 @@
         messages = [system_message] + messages
 
     # Apply context window management
-    token_count = 0  # Initialize for exception handler
-    try:
-        from backend.util.prompt import estimate_token_count
+    context_result = await _manage_context_window(
+        messages=messages,
+        model=model,
+        api_key=config.api_key,
+        base_url=config.base_url,
+    )
 
-        # Convert to dict for token counting
-        # OpenAI message types are TypedDicts, so they're already dict-like
-        messages_dict = []
-        for msg in messages:
-            # TypedDict objects are already dicts, just filter None values
-            if isinstance(msg, dict):
-                msg_dict = {k: v for k, v in msg.items() if v is not None}
-            else:
-                # Fallback for unexpected types
-                msg_dict = dict(msg)
-            messages_dict.append(msg_dict)
-
-        # Estimate tokens using appropriate tokenizer
-        # Normalize model name for token counting (tiktoken only supports OpenAI models)
-        token_count_model = model
-        if "/" in model:
-            # Strip provider prefix (e.g., "anthropic/claude-opus-4.5" -> "claude-opus-4.5")
-            token_count_model = model.split("/")[-1]
-
-        # For Claude and other non-OpenAI models, approximate with gpt-4o tokenizer
-        # Most modern LLMs have similar tokenization (~1 token per 4 chars)
-        if "claude" in token_count_model.lower() or not any(
-            known in token_count_model.lower()
-            for known in ["gpt", "o1", "chatgpt", "text-"]
-        ):
-            token_count_model = "gpt-4o"
-
-        # Attempt token counting with error handling
-        try:
-            token_count = estimate_token_count(messages_dict, model=token_count_model)
-        except Exception as token_error:
-            # If token counting fails, use gpt-4o as fallback approximation
-            logger.warning(
-                f"Token counting failed for model {token_count_model}: {token_error}. "
-                "Using gpt-4o approximation."
-            )
-            token_count = estimate_token_count(messages_dict, model="gpt-4o")
-
-        # If over threshold, summarize old messages
-        if token_count > 120_000:
-            KEEP_RECENT = 15
-
-            # Check if we have a system prompt at the start
-            has_system_prompt = (
-                len(messages) > 0 and messages[0].get("role") == "system"
-            )
-
-            # Always attempt mitigation when over limit, even with few messages
-            if messages:
-                # Split messages based on whether system prompt exists
-                # Calculate start index for the slice
-                slice_start = max(0, len(messages_dict) - KEEP_RECENT)
-                recent_messages = messages_dict[-KEEP_RECENT:]
-
-                # Ensure tool_call/tool_response pairs stay together
-                # This prevents API errors from orphan tool responses
-                recent_messages = _ensure_tool_pairs_intact(
-                    recent_messages, messages_dict, slice_start
-                )
-
-                if has_system_prompt:
-                    # Keep system prompt separate, summarize everything between system and recent
-                    system_msg = messages[0]
-                    old_messages_dict = messages_dict[1:-KEEP_RECENT]
-                else:
-                    # No system prompt, summarize everything except recent
-                    system_msg = None
-                    old_messages_dict = messages_dict[:-KEEP_RECENT]
-
-                # Summarize any non-empty old messages (no minimum threshold)
-                # If we're over the token limit, we need to compress whatever we can
-                if old_messages_dict:
-                    # Summarize old messages using the same model as chat
-                    summary_text = await _summarize_messages(
-                        old_messages_dict,
-                        model=model,
-                        api_key=config.api_key,
-                        base_url=config.base_url,
-                    )
-
-                    # Build new message list
-                    # Use assistant role (not system) to prevent privilege escalation
-                    # of user-influenced content to instruction-level authority
-                    from openai.types.chat import ChatCompletionAssistantMessageParam
-
-                    summary_msg = ChatCompletionAssistantMessageParam(
-                        role="assistant",
-                        content=(
-                            "[Previous conversation summary — for context only]: "
-                            f"{summary_text}"
-                        ),
-                    )
-
-                    # Rebuild messages based on whether we have a system prompt
-                    if has_system_prompt:
-                        # system_prompt + summary + recent_messages
-                        messages = [system_msg, summary_msg] + recent_messages
-                    else:
-                        # summary + recent_messages (no original system prompt)
-                        messages = [summary_msg] + recent_messages
-
-                    logger.info(
-                        f"Context summarized: {token_count} tokens, "
-                        f"summarized {len(old_messages_dict)} old messages, "
-                        f"kept last {KEEP_RECENT} messages"
-                    )
-
-                    # Fallback: If still over limit after summarization, progressively drop recent messages
-                    # This handles edge cases where recent messages are extremely large
-                    new_messages_dict = []
-                    for msg in messages:
-                        if isinstance(msg, dict):
-                            msg_dict = {k: v for k, v in msg.items() if v is not None}
-                        else:
-                            msg_dict = dict(msg)
-                        new_messages_dict.append(msg_dict)
-
-                    new_token_count = estimate_token_count(
-                        new_messages_dict, model=token_count_model
-                    )
-
-                    if new_token_count > 120_000:
-                        # Still over limit - progressively reduce KEEP_RECENT
-                        logger.warning(
-                            f"Still over limit after summarization: {new_token_count} tokens. "
-                            "Reducing number of recent messages kept."
-                        )
-
-                        for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
-                            if keep_count == 0:
-                                # Try with just system prompt + summary (no recent messages)
-                                if has_system_prompt:
-                                    messages = [system_msg, summary_msg]
-                                else:
-                                    messages = [summary_msg]
-                                logger.info(
-                                    "Trying with 0 recent messages (system + summary only)"
-                                )
-                            else:
-                                # Slice from ORIGINAL recent_messages to avoid duplicating summary
-                                reduced_recent = (
-                                    recent_messages[-keep_count:]
-                                    if len(recent_messages) >= keep_count
-                                    else recent_messages
-                                )
-                                # Ensure tool pairs stay intact in the reduced slice
-                                reduced_slice_start = max(
-                                    0, len(recent_messages) - keep_count
-                                )
-                                reduced_recent = _ensure_tool_pairs_intact(
-                                    reduced_recent, recent_messages, reduced_slice_start
-                                )
-                                if has_system_prompt:
-                                    messages = [
-                                        system_msg,
-                                        summary_msg,
-                                    ] + reduced_recent
-                                else:
-                                    messages = [summary_msg] + reduced_recent
-
-                            new_messages_dict = []
-                            for msg in messages:
-                                if isinstance(msg, dict):
-                                    msg_dict = {
-                                        k: v for k, v in msg.items() if v is not None
-                                    }
-                                else:
-                                    msg_dict = dict(msg)
-                                new_messages_dict.append(msg_dict)
-
-                            new_token_count = estimate_token_count(
-                                new_messages_dict, model=token_count_model
-                            )
-
-                            if new_token_count <= 120_000:
-                                logger.info(
-                                    f"Reduced to {keep_count} recent messages, "
-                                    f"now {new_token_count} tokens"
-                                )
-                                break
-                        else:
-                            logger.error(
-                                f"Unable to reduce token count below threshold even with 0 messages. "
-                                f"Final count: {new_token_count} tokens"
-                            )
-                            # ABSOLUTE LAST RESORT: Drop system prompt
-                            # This should only happen if summary itself is massive
-                            if has_system_prompt and len(messages) > 1:
-                                messages = messages[1:]  # Drop system prompt
-                                logger.critical(
-                                    "CRITICAL: Dropped system prompt as absolute last resort. "
-                                    "Behavioral consistency may be affected."
-                                )
-                                # Yield error to user
-                                yield StreamError(
-                                    errorText=(
-                                        "Warning: System prompt dropped due to size constraints. "
-                                        "Assistant behavior may be affected."
-                                    )
-                                )
-                else:
-                    # No old messages to summarize - all messages are "recent"
-                    # Apply progressive truncation to reduce token count
-                    logger.warning(
-                        f"Token count {token_count} exceeds threshold but no old messages to summarize. "
-                        f"Applying progressive truncation to recent messages."
-                    )
-
-                    # Create a base list excluding system prompt to avoid duplication
-                    # This is the pool of messages we'll slice from in the loop
-                    # Use messages_dict for type consistency with _ensure_tool_pairs_intact
-                    base_msgs = (
-                        messages_dict[1:] if has_system_prompt else messages_dict
-                    )
-
-                    # Try progressively smaller keep counts
-                    new_token_count = token_count  # Initialize with current count
-                    for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
-                        if keep_count == 0:
-                            # Try with just system prompt (no recent messages)
-                            if has_system_prompt:
-                                messages = [system_msg]
-                                logger.info(
-                                    "Trying with 0 recent messages (system prompt only)"
-                                )
-                            else:
-                                # No system prompt and no recent messages = empty messages list
-                                # This is invalid, skip this iteration
-                                continue
-                        else:
-                            if len(base_msgs) < keep_count:
-                                continue  # Skip if we don't have enough messages
-
-                            # Slice from base_msgs to get recent messages (without system prompt)
-                            recent_messages = base_msgs[-keep_count:]
-
-                            # Ensure tool pairs stay intact in the reduced slice
-                            reduced_slice_start = max(0, len(base_msgs) - keep_count)
-                            recent_messages = _ensure_tool_pairs_intact(
-                                recent_messages, base_msgs, reduced_slice_start
-                            )
-
-                            if has_system_prompt:
-                                messages = [system_msg] + recent_messages
-                            else:
-                                messages = recent_messages
-
-                        new_messages_dict = []
-                        for msg in messages:
-                            if msg is None:
-                                continue  # Skip None messages (type safety)
-                            if isinstance(msg, dict):
-                                msg_dict = {
-                                    k: v for k, v in msg.items() if v is not None
-                                }
-                            else:
-                                msg_dict = dict(msg)
-                            new_messages_dict.append(msg_dict)
-
-                        new_token_count = estimate_token_count(
-                            new_messages_dict, model=token_count_model
-                        )
-
-                        if new_token_count <= 120_000:
-                            logger.info(
-                                f"Reduced to {keep_count} recent messages, "
-                                f"now {new_token_count} tokens"
-                            )
-                            break
-                    else:
-                        # Even with 0 messages still over limit
-                        logger.error(
-                            f"Unable to reduce token count below threshold even with 0 messages. "
-                            f"Final count: {new_token_count} tokens. Messages may be extremely large."
-                        )
-                        # ABSOLUTE LAST RESORT: Drop system prompt
-                        if has_system_prompt and len(messages) > 1:
-                            messages = messages[1:]  # Drop system prompt
-                            logger.critical(
-                                "CRITICAL: Dropped system prompt as absolute last resort. "
-                                "Behavioral consistency may be affected."
-                            )
-                            # Yield error to user
-                            yield StreamError(
-                                errorText=(
-                                    "Warning: System prompt dropped due to size constraints. "
-                                    "Assistant behavior may be affected."
-                                )
-                            )
-
-    except Exception as e:
-        logger.error(f"Context summarization failed: {e}", exc_info=True)
-        # If we were over the token limit, yield error to user
-        # Don't silently continue with oversized messages that will fail
-        if token_count > 120_000:
+    if context_result.error:
+        if "System prompt dropped" in context_result.error:
+            # Warning only - continue with reduced context
             yield StreamError(
                 errorText=(
-                    f"Unable to manage context window (token limit exceeded: {token_count} tokens). "
-                    "Context summarization failed. Please start a new conversation."
+                    "Warning: System prompt dropped due to size constraints. "
+                    "Assistant behavior may be affected."
+                )
+            )
+        else:
+            # Any other error - abort to prevent failed LLM calls
+            yield StreamError(
+                errorText=(
+                    f"Context window management failed: {context_result.error}. "
+                    "Please start a new conversation."
                 )
             )
             yield StreamFinish()
             return
-        # Otherwise, continue with original messages (under limit)
+
+    messages = context_result.messages
+    if context_result.was_compacted:
+        logger.info(
+            f"Context compacted for streaming: {context_result.token_count} tokens"
+        )
 
     # Loop to handle tool calls and continue conversation
     while True:
@@ -1369,14 +1294,6 @@
                 :128
             ]  # OpenRouter limit
 
-        # Create the stream with proper types
-        from typing import cast
-
-        from openai.types.chat import (
-            ChatCompletionMessageParam,
-            ChatCompletionStreamOptionsParam,
-        )
-
         stream = await client.chat.completions.create(
             model=model,
             messages=cast(list[ChatCompletionMessageParam], messages),
@@ -1900,17 +1817,36 @@
     # Build system prompt
     system_prompt, _ = await _build_system_prompt(user_id)
 
-    # Build messages in OpenAI format
     messages = session.to_openai_messages()
    if system_prompt:
-        from openai.types.chat import ChatCompletionSystemMessageParam
-
         system_message = ChatCompletionSystemMessageParam(
             role="system",
             content=system_prompt,
         )
         messages = [system_message] + messages
 
+    # Apply context window management to prevent oversized requests
+    context_result = await _manage_context_window(
+        messages=messages,
+        model=config.model,
+        api_key=config.api_key,
+        base_url=config.base_url,
+    )
+
+    if context_result.error and "System prompt dropped" not in context_result.error:
+        logger.error(
+            f"Context window management failed for session {session_id}: "
+            f"{context_result.error} (tokens={context_result.token_count})"
+        )
+        return
+
+    messages = context_result.messages
+    if context_result.was_compacted:
+        logger.info(
+            f"Context compacted for LLM continuation: "
+            f"{context_result.token_count} tokens"
+        )
+
     # Build extra_body for tracing
     extra_body: dict[str, Any] = {
         "posthogProperties": {
@@ -1923,19 +1859,54 @@
     if session_id:
         extra_body["session_id"] = session_id[:128]
 
-    # Make non-streaming LLM call (no tools - just text response)
-    from typing import cast
+    retry_count = 0
+    last_error: Exception | None = None
+    response = None
 
-    from openai.types.chat import ChatCompletionMessageParam
+    while retry_count <= MAX_RETRIES:
+        try:
+            logger.info(
+                f"Generating LLM continuation for session {session_id}"
+                f"{f' (retry {retry_count}/{MAX_RETRIES})' if retry_count > 0 else ''}"
+            )
 
-    # No tools parameter = text-only response (no tool calls)
-    response = await client.chat.completions.create(
-        model=config.model,
-        messages=cast(list[ChatCompletionMessageParam], messages),
-        extra_body=extra_body,
-    )
+            response = await client.chat.completions.create(
+                model=config.model,
+                messages=cast(list[ChatCompletionMessageParam], messages),
+                extra_body=extra_body,
+            )
+            last_error = None  # Clear any previous error on success
+            break  # Success, exit retry loop
+        except Exception as e:
+            last_error = e
+            if _is_retryable_error(e) and retry_count < MAX_RETRIES:
+                retry_count += 1
+                delay = min(
+                    BASE_DELAY_SECONDS * (2 ** (retry_count - 1)),
+                    MAX_DELAY_SECONDS,
+                )
+                logger.warning(
+                    f"Retryable error in LLM continuation: {e!s}. "
+                    f"Retrying in {delay:.1f}s (attempt {retry_count}/{MAX_RETRIES})"
+                )
+                await asyncio.sleep(delay)
+                continue
+            else:
+                # Non-retryable error - log and exit gracefully
+                logger.error(
+                    f"Non-retryable error in LLM continuation: {e!s}",
+                    exc_info=True,
+                )
+                return
 
-    if response.choices and response.choices[0].message.content:
+    if last_error:
+        logger.error(
+            f"Max retries ({MAX_RETRIES}) exceeded for LLM continuation. "
+            f"Last error: {last_error!s}"
+        )
+        return
+
+    if response and response.choices and response.choices[0].message.content:
         assistant_content = response.choices[0].message.content
 
        # Reload session from DB to avoid race condition with user messages