Compare commits

..

10 Commits

Author SHA1 Message Date
Bently
2eaa13b141 Merge branch 'dev' into feat/text-encode-block 2026-02-03 10:07:00 +00:00
Nick Tindle
283b77e7e0 docs: sync block docs for TextEncoderBlock 2026-02-02 14:04:21 -06:00
Nicholas Tindle
ac5aa0a5f4 Merge branch 'dev' into feat/text-encode-block 2026-02-02 12:16:53 -06:00
Otto
0f300d7884 feat(blocks): add error handling and tests to TextEncoderBlock
- Add error output field and try/except handling for encoding failures
- Add comprehensive unit tests covering basic encoding, multiple escapes,
  unicode handling, empty strings, and error scenarios

Co-authored-by: lif <19658300+majiayu000@users.noreply.github.com>
Co-authored-by: Aryan Kaul <134673289+aryancodes1@users.noreply.github.com>
2026-02-02 17:52:53 +00:00
root
7dfc816280 style: fix Black formatting in encoder_block.py
- Add blank line after class docstring before nested class
- Reformat test_input dict for proper line length
2026-01-29 13:02:15 +00:00
root
378126e60f docs: fix block docs sync for Text Encoder
- Move Text Encoder docs from standalone encoder_block.md into text.md
  (matches CATEGORY_FILE_MAP for TEXT category blocks)
- Add Text Encoder to overview table in README.md
- Remove orphaned encoder_block.md
- Follows exact format expected by generate_block_docs.py
2026-01-29 12:53:09 +00:00
root
21a1d993b8 docs: address CodeRabbit review feedback for encoder_block.md
- Expand 'How it works' with technical details (unicode_escape encoding, validation, edge cases)
- Add blank lines around tables (MD058)
- Add language tags to fenced code blocks (MD040)
- Replace 'Possible use case' with 3 structured use cases per docs guidelines
- Separate example into its own section
2026-01-29 12:43:45 +00:00
Bently
e0862e8086 fix: add docs and fix error handling in TextEncoderBlock
- Add encoder_block.md documentation
- Remove error handling that yields undeclared output field
- Match pattern used by TextDecoderBlock
2026-01-28 11:05:03 +00:00
Bently
b1259e0bdd docs(blocks): Add docstrings and error handling to TextEncoderBlock
- Add module, class, and method docstrings for 80%+ coverage
- Add try/except error handling per CodeRabbit review
- Use inherited error field from BlockSchemaOutput
2026-01-27 15:58:19 +00:00
Bently
5244bd94fc feat(blocks): Implement Text Encode block (fixes #11111) 2026-01-27 15:48:10 +00:00
16 changed files with 939 additions and 1679 deletions

View File

@@ -17,14 +17,6 @@ from .model import ChatSession, create_chat_session, get_chat_session, get_user_
config = ChatConfig() config = ChatConfig()
# SSE response headers for streaming
SSE_RESPONSE_HEADERS = {
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
"x-vercel-ai-ui-message-stream": "v1",
}
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -40,60 +32,6 @@ async def _validate_and_get_session(
return session return session
async def _create_stream_generator(
session_id: str,
message: str,
user_id: str | None,
session: ChatSession,
is_user_message: bool = True,
context: dict[str, str] | None = None,
) -> AsyncGenerator[str, None]:
"""Create SSE event generator for chat streaming.
Args:
session_id: Chat session ID
message: User message to process
user_id: Optional authenticated user ID
session: Pre-fetched chat session
is_user_message: Whether the message is from a user
context: Optional context dict with url and content
Yields:
SSE-formatted chunks from the chat completion stream
"""
chunk_count = 0
first_chunk_type: str | None = None
async for chunk in chat_service.stream_chat_completion(
session_id,
message,
is_user_message=is_user_message,
user_id=user_id,
session=session,
context=context,
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
yield "data: [DONE]\n\n"
router = APIRouter( router = APIRouter(
tags=["chat"], tags=["chat"],
) )
@@ -283,17 +221,49 @@ async def stream_chat_post(
""" """
session = await _validate_and_get_session(session_id, user_id) session = await _validate_and_get_session(session_id, user_id)
return StreamingResponse( async def event_generator() -> AsyncGenerator[str, None]:
_create_stream_generator( chunk_count = 0
session_id=session_id, first_chunk_type: str | None = None
message=request.message, async for chunk in chat_service.stream_chat_completion(
user_id=user_id, session_id,
session=session, request.message,
is_user_message=request.is_user_message, is_user_message=request.is_user_message,
user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
context=request.context, context=request.context,
), ):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream", media_type="text/event-stream",
headers=SSE_RESPONSE_HEADERS, headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
) )
@@ -325,16 +295,48 @@ async def stream_chat_get(
""" """
session = await _validate_and_get_session(session_id, user_id) session = await _validate_and_get_session(session_id, user_id)
return StreamingResponse( async def event_generator() -> AsyncGenerator[str, None]:
_create_stream_generator( chunk_count = 0
session_id=session_id, first_chunk_type: str | None = None
message=message, async for chunk in chat_service.stream_chat_completion(
user_id=user_id, session_id,
session=session, message,
is_user_message=is_user_message, is_user_message=is_user_message,
), user_id=user_id,
session=session, # Pass pre-fetched session to avoid double-fetch
):
if chunk_count < 3:
logger.info(
"Chat stream chunk",
extra={
"session_id": session_id,
"chunk_type": str(chunk.type),
},
)
if not first_chunk_type:
first_chunk_type = str(chunk.type)
chunk_count += 1
yield chunk.to_sse()
logger.info(
"Chat stream completed",
extra={
"session_id": session_id,
"chunk_count": chunk_count,
"first_chunk_type": first_chunk_type,
},
)
# AI SDK protocol termination
yield "data: [DONE]\n\n"
return StreamingResponse(
event_generator(),
media_type="text/event-stream", media_type="text/event-stream",
headers=SSE_RESPONSE_HEADERS, headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no", # Disable nginx buffering
"x-vercel-ai-ui-message-stream": "v1", # AI SDK protocol header
},
) )

View File

@@ -3,13 +3,10 @@ import logging
import time import time
from asyncio import CancelledError from asyncio import CancelledError
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast from dataclasses import dataclass
from typing import Any, cast
import openai import openai
if TYPE_CHECKING:
from backend.util.prompt import CompressResult
import orjson import orjson
from langfuse import get_client from langfuse import get_client
from openai import ( from openai import (
@@ -20,6 +17,7 @@ from openai import (
RateLimitError, RateLimitError,
) )
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionAssistantMessageParam,
ChatCompletionChunk, ChatCompletionChunk,
ChatCompletionMessageParam, ChatCompletionMessageParam,
ChatCompletionStreamOptionsParam, ChatCompletionStreamOptionsParam,
@@ -33,6 +31,7 @@ from backend.data.understanding import (
get_business_understanding, get_business_understanding,
) )
from backend.util.exceptions import NotFoundError from backend.util.exceptions import NotFoundError
from backend.util.prompt import estimate_token_count
from backend.util.settings import Settings from backend.util.settings import Settings
from . import db as chat_db from . import db as chat_db
@@ -804,59 +803,403 @@ def _is_region_blocked_error(error: Exception) -> bool:
return "not available in your region" in str(error).lower() return "not available in your region" in str(error).lower()
# Context window management constants
TOKEN_THRESHOLD = 120_000
KEEP_RECENT_MESSAGES = 15
@dataclass
class ContextWindowResult:
"""Result of context window management."""
messages: list[dict[str, Any]]
token_count: int
was_compacted: bool
error: str | None = None
def _messages_to_dicts(messages: list) -> list[dict[str, Any]]:
"""Convert message objects to dicts, filtering None values.
Handles both TypedDict (dict-like) and other message formats.
"""
result = []
for msg in messages:
if msg is None:
continue
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
msg_dict = dict(msg)
result.append(msg_dict)
return result
async def _manage_context_window( async def _manage_context_window(
messages: list, messages: list,
model: str, model: str,
api_key: str | None = None, api_key: str | None = None,
base_url: str | None = None, base_url: str | None = None,
) -> "CompressResult": ) -> ContextWindowResult:
""" """
Manage context window using the unified compress_context function. Manage context window by summarizing old messages if token count exceeds threshold.
This is a thin wrapper that creates an OpenAI client for summarization This function handles context compaction for LLM calls by:
and delegates to the shared compression logic in prompt.py. 1. Counting tokens in the message list
2. If over threshold, summarizing old messages while keeping recent ones
3. Ensuring tool_call/tool_response pairs stay intact
4. Progressively reducing message count if still over limit
Args: Args:
messages: List of messages in OpenAI format messages: List of messages in OpenAI format (with system prompt if present)
model: Model name for token counting and summarization model: Model name for token counting
api_key: API key for summarization calls api_key: API key for summarization calls
base_url: Base URL for summarization calls base_url: Base URL for summarization calls
Returns: Returns:
CompressResult with compacted messages and metadata ContextWindowResult with compacted messages and metadata
""" """
if not messages:
return ContextWindowResult([], 0, False, "No messages to compact")
messages_dict = _messages_to_dicts(messages)
# Normalize model name for token counting (tiktoken only supports OpenAI models)
token_count_model = model.split("/")[-1] if "/" in model else model
if "claude" in token_count_model.lower() or not any(
known in token_count_model.lower()
for known in ["gpt", "o1", "chatgpt", "text-"]
):
token_count_model = "gpt-4o"
try:
token_count = estimate_token_count(messages_dict, model=token_count_model)
except Exception as e:
logger.warning(f"Token counting failed: {e}. Using gpt-4o approximation.")
token_count_model = "gpt-4o"
token_count = estimate_token_count(messages_dict, model=token_count_model)
if token_count <= TOKEN_THRESHOLD:
return ContextWindowResult(messages, token_count, False)
has_system_prompt = messages[0].get("role") == "system"
slice_start = max(0, len(messages_dict) - KEEP_RECENT_MESSAGES)
recent_messages = _ensure_tool_pairs_intact(
messages_dict[-KEEP_RECENT_MESSAGES:], messages_dict, slice_start
)
# Determine old messages to summarize (explicit bounds to avoid slice edge cases)
system_msg = messages[0] if has_system_prompt else None
if has_system_prompt:
old_messages_dict = (
messages_dict[1:-KEEP_RECENT_MESSAGES]
if len(messages_dict) > KEEP_RECENT_MESSAGES + 1
else []
)
else:
old_messages_dict = (
messages_dict[:-KEEP_RECENT_MESSAGES]
if len(messages_dict) > KEEP_RECENT_MESSAGES
else []
)
# Try to summarize old messages, fall back to truncation on failure
summary_msg = None
if old_messages_dict:
try:
summary_text = await _summarize_messages(
old_messages_dict, model=model, api_key=api_key, base_url=base_url
)
summary_msg = ChatCompletionAssistantMessageParam(
role="assistant",
content=f"[Previous conversation summary — for context only]: {summary_text}",
)
base = [system_msg, summary_msg] if has_system_prompt else [summary_msg]
messages = base + recent_messages
logger.info(
f"Context summarized: {token_count} tokens, "
f"summarized {len(old_messages_dict)} msgs, kept {KEEP_RECENT_MESSAGES}"
)
except Exception as e:
logger.warning(f"Summarization failed, falling back to truncation: {e}")
messages = (
[system_msg] + recent_messages if has_system_prompt else recent_messages
)
else:
logger.warning(
f"Token count {token_count} exceeds threshold but no old messages to summarize"
)
new_token_count = estimate_token_count(
_messages_to_dicts(messages), model=token_count_model
)
# Progressive truncation if still over limit
if new_token_count > TOKEN_THRESHOLD:
logger.warning(
f"Still over limit: {new_token_count} tokens. Reducing messages."
)
base_msgs = (
recent_messages
if old_messages_dict
else (messages_dict[1:] if has_system_prompt else messages_dict)
)
def build_messages(recent: list) -> list:
"""Build message list with optional system prompt and summary."""
prefix = []
if has_system_prompt and system_msg:
prefix.append(system_msg)
if summary_msg:
prefix.append(summary_msg)
return prefix + recent
for keep_count in [12, 10, 8, 5, 3, 2, 1, 0]:
if keep_count == 0:
messages = build_messages([])
if not messages:
continue
elif len(base_msgs) < keep_count:
continue
else:
reduced = _ensure_tool_pairs_intact(
base_msgs[-keep_count:],
base_msgs,
max(0, len(base_msgs) - keep_count),
)
messages = build_messages(reduced)
new_token_count = estimate_token_count(
_messages_to_dicts(messages), model=token_count_model
)
if new_token_count <= TOKEN_THRESHOLD:
logger.info(
f"Reduced to {keep_count} messages, {new_token_count} tokens"
)
break
else:
logger.error(
f"Cannot reduce below threshold. Final: {new_token_count} tokens"
)
if has_system_prompt and len(messages) > 1:
messages = messages[1:]
logger.critical("Dropped system prompt as last resort")
return ContextWindowResult(
messages, new_token_count, True, "System prompt dropped"
)
# No system prompt to drop - return error so callers don't proceed with oversized context
return ContextWindowResult(
messages,
new_token_count,
True,
"Unable to reduce context below token limit",
)
return ContextWindowResult(messages, new_token_count, True)
async def _summarize_messages(
messages: list,
model: str,
api_key: str | None = None,
base_url: str | None = None,
timeout: float = 30.0,
) -> str:
"""Summarize a list of messages into concise context.
Uses the same model as the chat for higher quality summaries.
Args:
messages: List of message dicts to summarize
model: Model to use for summarization (same as chat model)
api_key: API key for OpenAI client
base_url: Base URL for OpenAI client
timeout: Request timeout in seconds (default: 30.0)
Returns:
Summarized text
"""
# Format messages for summarization
conversation = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
# Include user, assistant, and tool messages (tool outputs are important context)
if content and role in ("user", "assistant", "tool"):
conversation.append(f"{role.upper()}: {content}")
conversation_text = "\n\n".join(conversation)
# Handle empty conversation
if not conversation_text:
return "No conversation history available."
# Truncate conversation to fit within summarization model's context
# gpt-4o-mini has 128k context, but we limit to ~25k tokens (~100k chars) for safety
MAX_CHARS = 100_000
if len(conversation_text) > MAX_CHARS:
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
# Call LLM to summarize
import openai import openai
from backend.util.prompt import compress_context summarization_client = openai.AsyncOpenAI(
api_key=api_key, base_url=base_url, timeout=timeout
# Convert messages to dict format
messages_dict = []
for msg in messages:
if isinstance(msg, dict):
msg_dict = {k: v for k, v in msg.items() if v is not None}
else:
msg_dict = dict(msg)
messages_dict.append(msg_dict)
# Only create client if api_key is provided (enables summarization)
# Use context manager to avoid socket leaks
if api_key:
async with openai.AsyncOpenAI(
api_key=api_key, base_url=base_url, timeout=30.0
) as client:
return await compress_context(
messages=messages_dict,
model=model,
client=client,
) )
else:
# No API key - use truncation-only mode response = await summarization_client.chat.completions.create(
return await compress_context(
messages=messages_dict,
model=model, model=model,
client=None, messages=[
{
"role": "system",
"content": (
"Create a detailed summary of the conversation so far. "
"This summary will be used as context when continuing the conversation.\n\n"
"Before writing the summary, analyze each message chronologically to identify:\n"
"- User requests and their explicit goals\n"
"- Your approach and key decisions made\n"
"- Technical specifics (file names, tool outputs, function signatures)\n"
"- Errors encountered and resolutions applied\n\n"
"You MUST include ALL of the following sections:\n\n"
"## 1. Primary Request and Intent\n"
"The user's explicit goals and what they are trying to accomplish.\n\n"
"## 2. Key Technical Concepts\n"
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
"## 3. Files and Resources Involved\n"
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
"## 4. Errors and Fixes\n"
"Problems encountered, error messages, and their resolutions. "
"Include any user feedback on fixes.\n\n"
"## 5. Problem Solving\n"
"Issues that have been resolved and how they were addressed.\n\n"
"## 6. All User Messages\n"
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
"## 7. Pending Tasks\n"
"Work items the user explicitly requested that have not yet been completed.\n\n"
"## 8. Current Work\n"
"Precise description of what was being worked on most recently, including relevant context.\n\n"
"## 9. Next Steps\n"
"What should happen next, aligned with the user's most recent requests. "
"Include verbatim quotes of recent instructions if relevant."
),
},
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
],
max_tokens=1500,
temperature=0.3,
) )
summary = response.choices[0].message.content
return summary or "No summary available."
def _ensure_tool_pairs_intact(
recent_messages: list[dict],
all_messages: list[dict],
start_index: int,
) -> list[dict]:
"""
Ensure tool_call/tool_response pairs stay together after slicing.
When slicing messages for context compaction, a naive slice can separate
an assistant message containing tool_calls from its corresponding tool
response messages. This causes API validation errors (e.g., Anthropic's
"unexpected tool_use_id found in tool_result blocks").
This function checks for orphan tool responses in the slice and extends
backwards to include their corresponding assistant messages.
Args:
recent_messages: The sliced messages to validate
all_messages: The complete message list (for looking up missing assistants)
start_index: The index in all_messages where recent_messages begins
Returns:
A potentially extended list of messages with tool pairs intact
"""
if not recent_messages:
return recent_messages
# Collect all tool_call_ids from assistant messages in the slice
available_tool_call_ids: set[str] = set()
for msg in recent_messages:
if msg.get("role") == "assistant" and msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id")
if tc_id:
available_tool_call_ids.add(tc_id)
# Find orphan tool responses (tool messages whose tool_call_id is missing)
orphan_tool_call_ids: set[str] = set()
for msg in recent_messages:
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id and tc_id not in available_tool_call_ids:
orphan_tool_call_ids.add(tc_id)
if not orphan_tool_call_ids:
# No orphans, slice is valid
return recent_messages
# Find the assistant messages that contain the orphan tool_call_ids
# Search backwards from start_index in all_messages
messages_to_prepend: list[dict] = []
for i in range(start_index - 1, -1, -1):
msg = all_messages[i]
if msg.get("role") == "assistant" and msg.get("tool_calls"):
msg_tool_ids = {tc.get("id") for tc in msg["tool_calls"] if tc.get("id")}
if msg_tool_ids & orphan_tool_call_ids:
# This assistant message has tool_calls we need
# Also collect its contiguous tool responses that follow it
assistant_and_responses: list[dict] = [msg]
# Scan forward from this assistant to collect tool responses
for j in range(i + 1, start_index):
following_msg = all_messages[j]
if following_msg.get("role") == "tool":
tool_id = following_msg.get("tool_call_id")
if tool_id and tool_id in msg_tool_ids:
assistant_and_responses.append(following_msg)
else:
# Stop at first non-tool message
break
# Prepend the assistant and its tool responses (maintain order)
messages_to_prepend = assistant_and_responses + messages_to_prepend
# Mark these as found
orphan_tool_call_ids -= msg_tool_ids
# Also add this assistant's tool_call_ids to available set
available_tool_call_ids |= msg_tool_ids
if not orphan_tool_call_ids:
# Found all missing assistants
break
if orphan_tool_call_ids:
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = [
msg
for msg in recent_messages
if not (
msg.get("role") == "tool"
and msg.get("tool_call_id") in orphan_tool_call_ids
)
]
if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages
return recent_messages
async def _stream_chat_chunks( async def _stream_chat_chunks(
session: ChatSession, session: ChatSession,

View File

@@ -1,77 +0,0 @@
"""Shared helpers for chat tools."""
from typing import Any
from .models import ErrorResponse
def error_response(
message: str, session_id: str | None, **kwargs: Any
) -> ErrorResponse:
"""Create standardized error response.
Args:
message: Error message to display
session_id: Current session ID
**kwargs: Additional fields to pass to ErrorResponse
Returns:
ErrorResponse with the given message and session_id
"""
return ErrorResponse(message=message, session_id=session_id, **kwargs)
def get_inputs_from_schema(
input_schema: dict[str, Any],
exclude_fields: set[str] | None = None,
) -> list[dict[str, Any]]:
"""Extract input field info from JSON schema.
Args:
input_schema: JSON schema dict with 'properties' and 'required'
exclude_fields: Set of field names to exclude (e.g., credential fields)
Returns:
List of dicts with field info (name, title, type, description, required, default)
"""
exclude = exclude_fields or set()
properties = input_schema.get("properties", {})
required = set(input_schema.get("required", []))
return [
{
"name": name,
"title": schema.get("title", name),
"type": schema.get("type", "string"),
"description": schema.get("description", ""),
"required": name in required,
"default": schema.get("default"),
}
for name, schema in properties.items()
if name not in exclude
]
def format_inputs_as_markdown(inputs: list[dict[str, Any]]) -> str:
"""Format input fields as a readable markdown list.
Args:
inputs: List of input dicts from get_inputs_from_schema
Returns:
Markdown-formatted string listing the inputs
"""
if not inputs:
return "No inputs required."
lines = []
for inp in inputs:
required_marker = " (required)" if inp.get("required") else ""
default = inp.get("default")
default_info = f" [default: {default}]" if default is not None else ""
description = inp.get("description", "")
desc_info = f" - {description}" if description else ""
lines.append(f"- **{inp['name']}**{required_marker}{default_info}{desc_info}")
return "\n".join(lines)

View File

@@ -24,7 +24,6 @@ from backend.util.timezone_utils import (
) )
from .base import BaseTool from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
AgentDetails, AgentDetails,
AgentDetailsResponse, AgentDetailsResponse,
@@ -355,7 +354,19 @@ class RunAgentTool(BaseTool):
def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]: def _get_inputs_list(self, input_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Extract inputs list from schema.""" """Extract inputs list from schema."""
return get_inputs_from_schema(input_schema) inputs_list = []
if isinstance(input_schema, dict) and "properties" in input_schema:
for field_name, field_schema in input_schema["properties"].items():
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in input_schema.get("required", []),
}
)
return inputs_list
def _get_execution_modes(self, graph: GraphModel) -> list[str]: def _get_execution_modes(self, graph: GraphModel) -> list[str]:
"""Get available execution modes for the graph.""" """Get available execution modes for the graph."""

View File

@@ -8,13 +8,12 @@ from typing import Any
from backend.api.features.chat.model import ChatSession from backend.api.features.chat.model import ChatSession
from backend.data.block import get_block from backend.data.block import get_block
from backend.data.execution import ExecutionContext from backend.data.execution import ExecutionContext
from backend.data.model import CredentialsFieldInfo, CredentialsMetaInput from backend.data.model import CredentialsMetaInput
from backend.data.workspace import get_or_create_workspace from backend.data.workspace import get_or_create_workspace
from backend.integrations.creds_manager import IntegrationCredentialsManager from backend.integrations.creds_manager import IntegrationCredentialsManager
from backend.util.exceptions import BlockError from backend.util.exceptions import BlockError
from .base import BaseTool from .base import BaseTool
from .helpers import get_inputs_from_schema
from .models import ( from .models import (
BlockOutputResponse, BlockOutputResponse,
ErrorResponse, ErrorResponse,
@@ -23,10 +22,7 @@ from .models import (
ToolResponseBase, ToolResponseBase,
UserReadiness, UserReadiness,
) )
from .utils import ( from .utils import build_missing_credentials_from_field_info
build_missing_credentials_from_field_info,
match_credentials_to_requirements,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -75,22 +71,6 @@ class RunBlockTool(BaseTool):
def requires_auth(self) -> bool: def requires_auth(self) -> bool:
return True return True
def _get_credentials_requirements(
self,
block: Any,
) -> dict[str, CredentialsFieldInfo]:
"""
Get credential requirements from block's input schema.
Args:
block: Block to get credentials for
Returns:
Dict mapping field names to CredentialsFieldInfo
"""
credentials_fields_info = block.input_schema.get_credentials_fields_info()
return credentials_fields_info if credentials_fields_info else {}
async def _check_block_credentials( async def _check_block_credentials(
self, self,
user_id: str, user_id: str,
@@ -102,12 +82,53 @@ class RunBlockTool(BaseTool):
Returns: Returns:
tuple[matched_credentials, missing_credentials] tuple[matched_credentials, missing_credentials]
""" """
requirements = self._get_credentials_requirements(block) matched_credentials: dict[str, CredentialsMetaInput] = {}
missing_credentials: list[CredentialsMetaInput] = []
if not requirements: # Get credential field info from block's input schema
return {}, [] credentials_fields_info = block.input_schema.get_credentials_fields_info()
return await match_credentials_to_requirements(user_id, requirements) if not credentials_fields_info:
return matched_credentials, missing_credentials
# Get user's available credentials
creds_manager = IntegrationCredentialsManager()
available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, field_info in credentials_fields_info.items():
# field_info.provider is a frozenset of acceptable providers
# field_info.supported_types is a frozenset of acceptable types
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in field_info.provider
and cred.type in field_info.supported_types
),
None,
)
if matching_cred:
matched_credentials[field_name] = CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
else:
# Create a placeholder for the missing credential
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing_credentials.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched_credentials, missing_credentials
async def _execute( async def _execute(
self, self,
@@ -299,7 +320,27 @@ class RunBlockTool(BaseTool):
def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]: def _get_inputs_list(self, block: Any) -> list[dict[str, Any]]:
"""Extract non-credential inputs from block schema.""" """Extract non-credential inputs from block schema."""
inputs_list = []
schema = block.input_schema.jsonschema() schema = block.input_schema.jsonschema()
properties = schema.get("properties", {})
required_fields = set(schema.get("required", []))
# Get credential field names to exclude # Get credential field names to exclude
credentials_fields = set(block.input_schema.get_credentials_fields().keys()) credentials_fields = set(block.input_schema.get_credentials_fields().keys())
return get_inputs_from_schema(schema, exclude_fields=credentials_fields)
for field_name, field_schema in properties.items():
# Skip credential fields
if field_name in credentials_fields:
continue
inputs_list.append(
{
"name": field_name,
"title": field_schema.get("title", field_name),
"type": field_schema.get("type", "string"),
"description": field_schema.get("description", ""),
"required": field_name in required_fields,
}
)
return inputs_list

View File

@@ -225,127 +225,6 @@ async def get_or_create_library_agent(
return library_agents[0] return library_agents[0]
async def get_user_credentials(user_id: str) -> list:
"""
Get all available credentials for a user.
Args:
user_id: The user's ID
Returns:
List of user's credentials
"""
creds_manager = IntegrationCredentialsManager()
return await creds_manager.store.get_all_creds(user_id)
def find_matching_credential(
available_creds: list,
field_info: CredentialsFieldInfo,
):
"""
Find a credential that matches the required provider, type, and scopes.
Args:
available_creds: List of user's available credentials
field_info: CredentialsFieldInfo with provider, type, and scope requirements
Returns:
Matching credential or None
"""
for cred in available_creds:
if cred.provider not in field_info.provider:
continue
if cred.type not in field_info.supported_types:
continue
if not _credential_has_required_scopes(cred, field_info):
continue
return cred
return None
def create_credential_meta_from_match(
matching_cred,
) -> CredentialsMetaInput:
"""
Create a CredentialsMetaInput from a matched credential.
Args:
matching_cred: The matched credential object
Returns:
CredentialsMetaInput instance
"""
return CredentialsMetaInput(
id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
async def match_credentials_to_requirements(
user_id: str,
requirements: dict[str, CredentialsFieldInfo],
) -> tuple[dict[str, CredentialsMetaInput], list[CredentialsMetaInput]]:
"""
Match user's credentials against a dictionary of credential requirements.
This is the core matching logic shared by both graph and block credential matching.
Args:
user_id: The user's ID
requirements: Dict mapping field names to CredentialsFieldInfo
Returns:
tuple[matched_credentials dict, missing_credentials list]
"""
matched: dict[str, CredentialsMetaInput] = {}
missing: list[CredentialsMetaInput] = []
if not requirements:
return matched, missing
available_creds = await get_user_credentials(user_id)
for field_name, field_info in requirements.items():
matching_cred = find_matching_credential(available_creds, field_info)
if matching_cred:
try:
matched[field_name] = create_credential_meta_from_match(matching_cred)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=f"{field_name} (validation failed: {e})",
)
)
else:
provider = next(iter(field_info.provider), "unknown")
cred_type = next(iter(field_info.supported_types), "api_key")
missing.append(
CredentialsMetaInput(
id=field_name,
provider=provider, # type: ignore
type=cred_type, # type: ignore
title=field_name.replace("_", " ").title(),
)
)
return matched, missing
async def match_user_credentials_to_graph( async def match_user_credentials_to_graph(
user_id: str, user_id: str,
graph: GraphModel, graph: GraphModel,
@@ -363,6 +242,9 @@ async def match_user_credentials_to_graph(
Returns: Returns:
tuple[matched_credentials dict, missing_credential_descriptions list] tuple[matched_credentials dict, missing_credential_descriptions list]
""" """
graph_credentials_inputs: dict[str, CredentialsMetaInput] = {}
missing_creds: list[str] = []
# Get aggregated credentials requirements from the graph # Get aggregated credentials requirements from the graph
aggregated_creds = graph.aggregate_credentials_inputs() aggregated_creds = graph.aggregate_credentials_inputs()
logger.debug( logger.debug(
@@ -370,30 +252,69 @@ async def match_user_credentials_to_graph(
) )
if not aggregated_creds: if not aggregated_creds:
return {}, [] return graph_credentials_inputs, missing_creds
# Convert aggregated format to simple requirements dict # Get all available credentials for the user
requirements = { creds_manager = IntegrationCredentialsManager()
field_name: field_info available_creds = await creds_manager.store.get_all_creds(user_id)
for field_name, (field_info, _node_fields) in aggregated_creds.items()
}
# Use shared matching logic # For each required credential field, find a matching user credential
matched, missing_list = await match_credentials_to_requirements( # field_info.provider is a frozenset because aggregate_credentials_inputs()
user_id, requirements # combines requirements from multiple nodes. A credential matches if its
# provider is in the set of acceptable providers.
for credential_field_name, (
credential_requirements,
_node_fields,
) in aggregated_creds.items():
# Find first matching credential by provider, type, and scopes
matching_cred = next(
(
cred
for cred in available_creds
if cred.provider in credential_requirements.provider
and cred.type in credential_requirements.supported_types
and _credential_has_required_scopes(cred, credential_requirements)
),
None,
) )
# Convert missing list to string descriptions for backward compatibility if matching_cred:
missing_descriptions = [ try:
f"{cred.id} (requires provider={cred.provider}, type={cred.type})" graph_credentials_inputs[credential_field_name] = CredentialsMetaInput(
for cred in missing_list id=matching_cred.id,
provider=matching_cred.provider, # type: ignore
type=matching_cred.type,
title=matching_cred.title,
)
except Exception as e:
logger.error(
f"Failed to create CredentialsMetaInput for field '{credential_field_name}': "
f"provider={matching_cred.provider}, type={matching_cred.type}, "
f"credential_id={matching_cred.id}",
exc_info=True,
)
missing_creds.append(
f"{credential_field_name} (validation failed: {e})"
)
else:
# Build a helpful error message including scope requirements
error_parts = [
f"provider in {list(credential_requirements.provider)}",
f"type in {list(credential_requirements.supported_types)}",
] ]
if credential_requirements.required_scopes:
error_parts.append(
f"scopes including {list(credential_requirements.required_scopes)}"
)
missing_creds.append(
f"{credential_field_name} (requires {', '.join(error_parts)})"
)
logger.info( logger.info(
f"Credential matching complete: {len(matched)}/{len(aggregated_creds)} matched" f"Credential matching complete: {len(graph_credentials_inputs)}/{len(aggregated_creds)} matched"
) )
return matched, missing_descriptions return graph_credentials_inputs, missing_creds
def _credential_has_required_scopes( def _credential_has_required_scopes(

View File

@@ -0,0 +1,77 @@
"""Text encoding block for converting special characters to escape sequences."""
import codecs
from backend.data.block import (
Block,
BlockCategory,
BlockOutput,
BlockSchemaInput,
BlockSchemaOutput,
)
from backend.data.model import SchemaField
class TextEncoderBlock(Block):
"""
Encodes a string by converting special characters into escape sequences.
This block is the inverse of TextDecoderBlock. It takes text containing
special characters (like newlines, tabs, etc.) and converts them into
their escape sequence representations (e.g., newline becomes \\n).
"""
class Input(BlockSchemaInput):
"""Input schema for TextEncoderBlock."""
text: str = SchemaField(
description="A string containing special characters to be encoded",
placeholder="Your text with newlines and quotes to encode",
)
class Output(BlockSchemaOutput):
"""Output schema for TextEncoderBlock."""
encoded_text: str = SchemaField(
description="The encoded text with special characters converted to escape sequences"
)
error: str = SchemaField(description="Error message if encoding fails")
def __init__(self):
super().__init__(
id="5185f32e-4b65-4ecf-8fbb-873f003f09d6",
description="Encodes a string by converting special characters into escape sequences",
categories={BlockCategory.TEXT},
input_schema=TextEncoderBlock.Input,
output_schema=TextEncoderBlock.Output,
test_input={
"text": """Hello
World!
This is a "quoted" string."""
},
test_output=[
(
"encoded_text",
"""Hello\\nWorld!\\nThis is a "quoted" string.""",
)
],
)
async def run(self, input_data: Input, **kwargs) -> BlockOutput:
"""
Encode the input text by converting special characters to escape sequences.
Args:
input_data: The input containing the text to encode.
**kwargs: Additional keyword arguments (unused).
Yields:
The encoded text with escape sequences, or an error message if encoding fails.
"""
try:
encoded_text = codecs.encode(input_data.text, "unicode_escape").decode(
"utf-8"
)
yield "encoded_text", encoded_text
except Exception as e:
yield "error", f"Encoding error: {str(e)}"

View File

@@ -32,7 +32,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName from backend.integrations.providers import ProviderName
from backend.util import json from backend.util import json
from backend.util.logging import TruncatedLogger from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_context, estimate_token_count from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]") logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -634,18 +634,11 @@ async def llm_call(
context_window = llm_model.context_window context_window = llm_model.context_window
if compress_prompt_to_fit: if compress_prompt_to_fit:
result = await compress_context( prompt = compress_prompt(
messages=prompt, messages=prompt,
target_tokens=llm_model.context_window // 2, target_tokens=llm_model.context_window // 2,
client=None, # Truncation-only, no LLM summarization lossy_ok=True,
reserve=0, # Caller handles response token budget separately
) )
if result.error:
logger.warning(
f"Prompt compression did not meet target: {result.error}. "
f"Proceeding with {result.token_count} tokens."
)
prompt = result.messages
# Calculate available tokens based on context window and input length # Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt) estimated_input_tokens = estimate_token_count(prompt)

View File

@@ -0,0 +1,77 @@
import pytest
from backend.blocks.encoder_block import TextEncoderBlock
@pytest.mark.asyncio
async def test_text_encoder_basic():
"""Test basic encoding of newlines and special characters."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="Hello\nWorld")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert result[0][1] == "Hello\\nWorld"
@pytest.mark.asyncio
async def test_text_encoder_multiple_escapes():
"""Test encoding of multiple escape sequences."""
block = TextEncoderBlock()
result = []
async for output in block.run(
TextEncoderBlock.Input(text="Line1\nLine2\tTabbed\rCarriage")
):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert "\\n" in result[0][1]
assert "\\t" in result[0][1]
assert "\\r" in result[0][1]
@pytest.mark.asyncio
async def test_text_encoder_unicode():
"""Test that unicode characters are handled correctly."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="Hello 世界\n")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
# Unicode characters should be escaped as \uXXXX sequences
assert "\\n" in result[0][1]
@pytest.mark.asyncio
async def test_text_encoder_empty_string():
"""Test encoding of an empty string."""
block = TextEncoderBlock()
result = []
async for output in block.run(TextEncoderBlock.Input(text="")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "encoded_text"
assert result[0][1] == ""
@pytest.mark.asyncio
async def test_text_encoder_error_handling():
"""Test that encoding errors are handled gracefully."""
from unittest.mock import patch
block = TextEncoderBlock()
result = []
with patch("codecs.encode", side_effect=Exception("Mocked encoding error")):
async for output in block.run(TextEncoderBlock.Input(text="test")):
result.append(output)
assert len(result) == 1
assert result[0][0] == "error"
assert "Mocked encoding error" in result[0][1]

View File

@@ -17,7 +17,6 @@ from backend.data.analytics import (
get_accuracy_trends_and_alerts, get_accuracy_trends_and_alerts,
get_marketplace_graphs_for_monitoring, get_marketplace_graphs_for_monitoring,
) )
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.credit import UsageTransactionMetadata, get_user_credit_model from backend.data.credit import UsageTransactionMetadata, get_user_credit_model
from backend.data.execution import ( from backend.data.execution import (
create_graph_execution, create_graph_execution,
@@ -220,9 +219,6 @@ class DatabaseManager(AppService):
# Onboarding # Onboarding
increment_onboarding_runs = _(increment_onboarding_runs) increment_onboarding_runs = _(increment_onboarding_runs)
# OAuth
cleanup_expired_oauth_tokens = _(cleanup_expired_oauth_tokens)
# Store # Store
get_store_agents = _(get_store_agents) get_store_agents = _(get_store_agents)
get_store_agent_details = _(get_store_agent_details) get_store_agent_details = _(get_store_agent_details)
@@ -353,9 +349,6 @@ class DatabaseManagerAsyncClient(AppServiceClient):
# Onboarding # Onboarding
increment_onboarding_runs = d.increment_onboarding_runs increment_onboarding_runs = d.increment_onboarding_runs
# OAuth
cleanup_expired_oauth_tokens = d.cleanup_expired_oauth_tokens
# Store # Store
get_store_agents = d.get_store_agents get_store_agents = d.get_store_agents
get_store_agent_details = d.get_store_agent_details get_store_agent_details = d.get_store_agent_details

View File

@@ -24,9 +24,11 @@ from dotenv import load_dotenv
from pydantic import BaseModel, Field, ValidationError from pydantic import BaseModel, Field, ValidationError
from sqlalchemy import MetaData, create_engine from sqlalchemy import MetaData, create_engine
from backend.data.auth.oauth import cleanup_expired_oauth_tokens
from backend.data.block import BlockInput from backend.data.block import BlockInput
from backend.data.execution import GraphExecutionWithNodes from backend.data.execution import GraphExecutionWithNodes
from backend.data.model import CredentialsMetaInput from backend.data.model import CredentialsMetaInput
from backend.data.onboarding import increment_onboarding_runs
from backend.executor import utils as execution_utils from backend.executor import utils as execution_utils
from backend.monitoring import ( from backend.monitoring import (
NotificationJobArgs, NotificationJobArgs,
@@ -36,11 +38,7 @@ from backend.monitoring import (
report_execution_accuracy_alerts, report_execution_accuracy_alerts,
report_late_executions, report_late_executions,
) )
from backend.util.clients import ( from backend.util.clients import get_database_manager_client, get_scheduler_client
get_database_manager_async_client,
get_database_manager_client,
get_scheduler_client,
)
from backend.util.cloud_storage import cleanup_expired_files_async from backend.util.cloud_storage import cleanup_expired_files_async
from backend.util.exceptions import ( from backend.util.exceptions import (
GraphNotFoundError, GraphNotFoundError,
@@ -150,7 +148,6 @@ def execute_graph(**kwargs):
async def _execute_graph(**kwargs): async def _execute_graph(**kwargs):
args = GraphExecutionJobArgs(**kwargs) args = GraphExecutionJobArgs(**kwargs)
start_time = asyncio.get_event_loop().time() start_time = asyncio.get_event_loop().time()
db = get_database_manager_async_client()
try: try:
logger.info(f"Executing recurring job for graph #{args.graph_id}") logger.info(f"Executing recurring job for graph #{args.graph_id}")
graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution( graph_exec: GraphExecutionWithNodes = await execution_utils.add_graph_execution(
@@ -160,7 +157,7 @@ async def _execute_graph(**kwargs):
inputs=args.input_data, inputs=args.input_data,
graph_credentials_inputs=args.input_credentials, graph_credentials_inputs=args.input_credentials,
) )
await db.increment_onboarding_runs(args.user_id) await increment_onboarding_runs(args.user_id)
elapsed = asyncio.get_event_loop().time() - start_time elapsed = asyncio.get_event_loop().time() - start_time
logger.info( logger.info(
f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} " f"Graph execution started with ID {graph_exec.id} for graph {args.graph_id} "
@@ -249,13 +246,8 @@ def cleanup_expired_files():
def cleanup_oauth_tokens(): def cleanup_oauth_tokens():
"""Clean up expired OAuth tokens from the database.""" """Clean up expired OAuth tokens from the database."""
# Wait for completion # Wait for completion
async def _cleanup(): run_async(cleanup_expired_oauth_tokens())
db = get_database_manager_async_client()
return await db.cleanup_expired_oauth_tokens()
run_async(_cleanup())
def execution_accuracy_alerts(): def execution_accuracy_alerts():

View File

@@ -1,19 +1,10 @@
from __future__ import annotations
import logging
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from typing import Any
from typing import TYPE_CHECKING, Any
from tiktoken import encoding_for_model from tiktoken import encoding_for_model
from backend.util import json from backend.util import json
if TYPE_CHECKING:
from openai import AsyncOpenAI
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------# # ---------------------------------------------------------------------------#
# CONSTANTS # # CONSTANTS #
# ---------------------------------------------------------------------------# # ---------------------------------------------------------------------------#
@@ -109,17 +100,9 @@ def _is_objective_message(msg: dict) -> bool:
def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None: def _truncate_tool_message_content(msg: dict, enc, max_tokens: int) -> None:
""" """
Carefully truncate tool message content while preserving tool structure. Carefully truncate tool message content while preserving tool structure.
Handles both Anthropic-style (list content) and OpenAI-style (string content) tool messages. Only truncates tool_result content, leaves tool_use intact.
""" """
content = msg.get("content") content = msg.get("content")
# OpenAI-style tool message: role="tool" with string content
if msg.get("role") == "tool" and isinstance(content, str):
if _tok_len(content, enc) > max_tokens:
msg["content"] = _truncate_middle_tokens(content, enc, max_tokens)
return
# Anthropic-style: list content with tool_result items
if not isinstance(content, list): if not isinstance(content, list):
return return
@@ -157,6 +140,141 @@ def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
# ---------------------------------------------------------------------------# # ---------------------------------------------------------------------------#
def compress_prompt(
messages: list[dict],
target_tokens: int,
*,
model: str = "gpt-4o",
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
lossy_ok: bool = True,
) -> list[dict]:
"""
Shrink *messages* so that::
token_count(prompt) + reserve ≤ target_tokens
Strategy
--------
1. **Token-aware truncation** progressively halve a per-message cap
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
*content* of every message except the first and last. Tool shells
are included: we keep the envelope but shorten huge payloads.
2. **Middle-out deletion** if still over the limit, delete whole
messages working outward from the centre, **skipping** any message
that contains ``tool_calls`` or has ``role == "tool"``.
3. **Last-chance trim** if still too big, truncate the *first* and
*last* message bodies down to `floor_cap` tokens.
4. If the prompt is *still* too large:
• raise ``ValueError`` when ``lossy_ok == False`` (default)
• return the partially-trimmed prompt when ``lossy_ok == True``
Parameters
----------
messages Complete chat history (will be deep-copied).
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
target_tokens Hard ceiling for prompt size **excluding** the model's
forthcoming answer.
reserve How many tokens you want to leave available for that
answer (`max_tokens` in your subsequent completion call).
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap we'll accept before moving to deletions.
lossy_ok If *True* return best-effort prompt instead of raising
after all trim passes have been exhausted.
Returns
-------
list[dict] A *new* messages list that abides by the rules above.
"""
enc = encoding_for_model(model) # best-match tokenizer
msgs = deepcopy(messages) # never mutate caller
def total_tokens() -> int:
"""Current size of *msgs* in tokens."""
return sum(_msg_tokens(m, enc) for m in msgs)
original_token_count = total_tokens()
if original_token_count + reserve <= target_tokens:
return msgs
# ---- STEP 0 : normalise content --------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
# Keep first and last messages intact (unless they're tool messages)
if i == 0 or i == len(msgs) - 1:
continue
# Reasonable 20k-char ceiling prevents pathological blobs
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 1 : token-aware truncation ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]: # keep first & last intact
if _is_tool_message(m):
# For tool messages, only truncate tool result content, preserve structure
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
# Never truncate objective messages - they contain the core task
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 2 : middle-out deletion -----------------------------------
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
# Identify all deletable messages (not first/last, not tool messages, not objective messages)
deletable_indices = []
for i in range(1, len(msgs) - 1): # Skip first and last
if not _is_tool_message(msgs[i]) and not _is_objective_message(msgs[i]):
deletable_indices.append(i)
if not deletable_indices:
break # nothing more we can drop
# Delete from center outward - find the index closest to center
centre = len(msgs) // 2
to_delete = min(deletable_indices, key=lambda i: abs(i - centre))
del msgs[to_delete]
# ---- STEP 3 : final safety-net trim on first & last ------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1): # first and last
if _is_tool_message(msgs[idx]):
# For tool messages at first/last position, truncate tool result content only
_truncate_tool_message_content(msgs[idx], enc, cap)
continue
text = msgs[idx].get("content") or ""
if _tok_len(text, enc) > cap:
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 4 : success or fail-gracefully -----------------------------
if total_tokens() + reserve > target_tokens and not lossy_ok:
raise ValueError(
"compress_prompt: prompt still exceeds budget "
f"({total_tokens() + reserve} > {target_tokens})."
)
return msgs
def estimate_token_count( def estimate_token_count(
messages: list[dict], messages: list[dict],
*, *,
@@ -175,8 +293,7 @@ def estimate_token_count(
------- -------
int Token count. int Token count.
""" """
token_model = _normalize_model_for_tokenizer(model) enc = encoding_for_model(model) # best-match tokenizer
enc = encoding_for_model(token_model)
return sum(_msg_tokens(m, enc) for m in messages) return sum(_msg_tokens(m, enc) for m in messages)
@@ -198,543 +315,6 @@ def estimate_token_count_str(
------- -------
int Token count. int Token count.
""" """
token_model = _normalize_model_for_tokenizer(model) enc = encoding_for_model(model) # best-match tokenizer
enc = encoding_for_model(token_model)
text = json.dumps(text) if not isinstance(text, str) else text text = json.dumps(text) if not isinstance(text, str) else text
return _tok_len(text, enc) return _tok_len(text, enc)
# ---------------------------------------------------------------------------#
# UNIFIED CONTEXT COMPRESSION #
# ---------------------------------------------------------------------------#
# Default thresholds
DEFAULT_TOKEN_THRESHOLD = 120_000
DEFAULT_KEEP_RECENT = 15
@dataclass
class CompressResult:
"""Result of context compression."""
messages: list[dict]
token_count: int
was_compacted: bool
error: str | None = None
original_token_count: int = 0
messages_summarized: int = 0
messages_dropped: int = 0
def _normalize_model_for_tokenizer(model: str) -> str:
"""Normalize model name for tiktoken tokenizer selection."""
if "/" in model:
model = model.split("/")[-1]
if "claude" in model.lower() or not any(
known in model.lower() for known in ["gpt", "o1", "chatgpt", "text-"]
):
return "gpt-4o"
return model
def _extract_tool_call_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs from an assistant message.
Supports both formats:
- OpenAI: {"role": "assistant", "tool_calls": [{"id": "..."}]}
- Anthropic: {"role": "assistant", "content": [{"type": "tool_use", "id": "..."}]}
Returns:
Set of tool_call IDs found in the message.
"""
ids: set[str] = set()
if msg.get("role") != "assistant":
return ids
# OpenAI format: tool_calls array
if msg.get("tool_calls"):
for tc in msg["tool_calls"]:
tc_id = tc.get("id")
if tc_id:
ids.add(tc_id)
# Anthropic format: content list with tool_use blocks
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_use":
tc_id = block.get("id")
if tc_id:
ids.add(tc_id)
return ids
def _extract_tool_response_ids_from_message(msg: dict) -> set[str]:
"""
Extract tool_call IDs that this message is responding to.
Supports both formats:
- OpenAI: {"role": "tool", "tool_call_id": "..."}
- Anthropic: {"role": "user", "content": [{"type": "tool_result", "tool_use_id": "..."}]}
Returns:
Set of tool_call IDs this message responds to.
"""
ids: set[str] = set()
# OpenAI format: role=tool with tool_call_id
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id:
ids.add(tc_id)
# Anthropic format: content list with tool_result blocks
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
tc_id = block.get("tool_use_id")
if tc_id:
ids.add(tc_id)
return ids
def _is_tool_response_message(msg: dict) -> bool:
"""Check if message is a tool response (OpenAI or Anthropic format)."""
# OpenAI format
if msg.get("role") == "tool":
return True
# Anthropic format
content = msg.get("content")
if isinstance(content, list):
for block in content:
if isinstance(block, dict) and block.get("type") == "tool_result":
return True
return False
def _remove_orphan_tool_responses(
messages: list[dict], orphan_ids: set[str]
) -> list[dict]:
"""
Remove tool response messages/blocks that reference orphan tool_call IDs.
Supports both OpenAI and Anthropic formats.
For Anthropic messages with mixed valid/orphan tool_result blocks,
filters out only the orphan blocks instead of dropping the entire message.
"""
result = []
for msg in messages:
# OpenAI format: role=tool - drop entire message if orphan
if msg.get("role") == "tool":
tc_id = msg.get("tool_call_id")
if tc_id and tc_id in orphan_ids:
continue
result.append(msg)
continue
# Anthropic format: content list may have mixed tool_result blocks
content = msg.get("content")
if isinstance(content, list):
has_tool_results = any(
isinstance(b, dict) and b.get("type") == "tool_result" for b in content
)
if has_tool_results:
# Filter out orphan tool_result blocks, keep valid ones
filtered_content = [
block
for block in content
if not (
isinstance(block, dict)
and block.get("type") == "tool_result"
and block.get("tool_use_id") in orphan_ids
)
]
# Only keep message if it has remaining content
if filtered_content:
msg = msg.copy()
msg["content"] = filtered_content
result.append(msg)
continue
result.append(msg)
return result
def _ensure_tool_pairs_intact(
recent_messages: list[dict],
all_messages: list[dict],
start_index: int,
) -> list[dict]:
"""
Ensure tool_call/tool_response pairs stay together after slicing.
When slicing messages for context compaction, a naive slice can separate
an assistant message containing tool_calls from its corresponding tool
response messages. This causes API validation errors (e.g., Anthropic's
"unexpected tool_use_id found in tool_result blocks").
This function checks for orphan tool responses in the slice and extends
backwards to include their corresponding assistant messages.
Supports both formats:
- OpenAI: tool_calls array + role="tool" responses
- Anthropic: tool_use blocks + tool_result blocks
Args:
recent_messages: The sliced messages to validate
all_messages: The complete message list (for looking up missing assistants)
start_index: The index in all_messages where recent_messages begins
Returns:
A potentially extended list of messages with tool pairs intact
"""
if not recent_messages:
return recent_messages
# Collect all tool_call_ids from assistant messages in the slice
available_tool_call_ids: set[str] = set()
for msg in recent_messages:
available_tool_call_ids |= _extract_tool_call_ids_from_message(msg)
# Find orphan tool responses (responses whose tool_call_id is missing)
orphan_tool_call_ids: set[str] = set()
for msg in recent_messages:
response_ids = _extract_tool_response_ids_from_message(msg)
for tc_id in response_ids:
if tc_id not in available_tool_call_ids:
orphan_tool_call_ids.add(tc_id)
if not orphan_tool_call_ids:
# No orphans, slice is valid
return recent_messages
# Find the assistant messages that contain the orphan tool_call_ids
# Search backwards from start_index in all_messages
messages_to_prepend: list[dict] = []
for i in range(start_index - 1, -1, -1):
msg = all_messages[i]
msg_tool_ids = _extract_tool_call_ids_from_message(msg)
if msg_tool_ids & orphan_tool_call_ids:
# This assistant message has tool_calls we need
# Also collect its contiguous tool responses that follow it
assistant_and_responses: list[dict] = [msg]
# Scan forward from this assistant to collect tool responses
for j in range(i + 1, start_index):
following_msg = all_messages[j]
following_response_ids = _extract_tool_response_ids_from_message(
following_msg
)
if following_response_ids and following_response_ids & msg_tool_ids:
assistant_and_responses.append(following_msg)
elif not _is_tool_response_message(following_msg):
# Stop at first non-tool-response message
break
# Prepend the assistant and its tool responses (maintain order)
messages_to_prepend = assistant_and_responses + messages_to_prepend
# Mark these as found
orphan_tool_call_ids -= msg_tool_ids
# Also add this assistant's tool_call_ids to available set
available_tool_call_ids |= msg_tool_ids
if not orphan_tool_call_ids:
# Found all missing assistants
break
if orphan_tool_call_ids:
# Some tool_call_ids couldn't be resolved - remove those tool responses
# This shouldn't happen in normal operation but handles edge cases
logger.warning(
f"Could not find assistant messages for tool_call_ids: {orphan_tool_call_ids}. "
"Removing orphan tool responses."
)
recent_messages = _remove_orphan_tool_responses(
recent_messages, orphan_tool_call_ids
)
if messages_to_prepend:
logger.info(
f"Extended recent messages by {len(messages_to_prepend)} to preserve "
f"tool_call/tool_response pairs"
)
return messages_to_prepend + recent_messages
return recent_messages
async def _summarize_messages_llm(
messages: list[dict],
client: AsyncOpenAI,
model: str,
timeout: float = 30.0,
) -> str:
"""Summarize messages using an LLM."""
conversation = []
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if content and role in ("user", "assistant", "tool"):
conversation.append(f"{role.upper()}: {content}")
conversation_text = "\n\n".join(conversation)
if not conversation_text:
return "No conversation history available."
# Limit to ~100k chars for safety
MAX_CHARS = 100_000
if len(conversation_text) > MAX_CHARS:
conversation_text = conversation_text[:MAX_CHARS] + "\n\n[truncated]"
response = await client.with_options(timeout=timeout).chat.completions.create(
model=model,
messages=[
{
"role": "system",
"content": (
"Create a detailed summary of the conversation so far. "
"This summary will be used as context when continuing the conversation.\n\n"
"Before writing the summary, analyze each message chronologically to identify:\n"
"- User requests and their explicit goals\n"
"- Your approach and key decisions made\n"
"- Technical specifics (file names, tool outputs, function signatures)\n"
"- Errors encountered and resolutions applied\n\n"
"You MUST include ALL of the following sections:\n\n"
"## 1. Primary Request and Intent\n"
"The user's explicit goals and what they are trying to accomplish.\n\n"
"## 2. Key Technical Concepts\n"
"Technologies, frameworks, tools, and patterns being used or discussed.\n\n"
"## 3. Files and Resources Involved\n"
"Specific files examined or modified, with relevant snippets and identifiers.\n\n"
"## 4. Errors and Fixes\n"
"Problems encountered, error messages, and their resolutions. "
"Include any user feedback on fixes.\n\n"
"## 5. Problem Solving\n"
"Issues that have been resolved and how they were addressed.\n\n"
"## 6. All User Messages\n"
"A complete list of all user inputs (excluding tool outputs) to preserve their exact requests.\n\n"
"## 7. Pending Tasks\n"
"Work items the user explicitly requested that have not yet been completed.\n\n"
"## 8. Current Work\n"
"Precise description of what was being worked on most recently, including relevant context.\n\n"
"## 9. Next Steps\n"
"What should happen next, aligned with the user's most recent requests. "
"Include verbatim quotes of recent instructions if relevant."
),
},
{"role": "user", "content": f"Summarize:\n\n{conversation_text}"},
],
max_tokens=1500,
temperature=0.3,
)
return response.choices[0].message.content or "No summary available."
async def compress_context(
messages: list[dict],
target_tokens: int = DEFAULT_TOKEN_THRESHOLD,
*,
model: str = "gpt-4o",
client: AsyncOpenAI | None = None,
keep_recent: int = DEFAULT_KEEP_RECENT,
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
) -> CompressResult:
"""
Unified context compression that combines summarization and truncation strategies.
Strategy (in order):
1. **LLM summarization** If client provided, summarize old messages into a
single context message while keeping recent messages intact. This is the
primary strategy for chat service.
2. **Content truncation** Progressively halve a per-message cap and truncate
bloated message content (tool outputs, large pastes). Preserves all messages
but shortens their content. Primary strategy when client=None (LLM blocks).
3. **Middle-out deletion** Delete whole messages one at a time from the center
outward, skipping tool messages and objective messages.
4. **First/last trim** Truncate first and last message content as last resort.
Parameters
----------
messages Complete chat history (will be deep-copied).
target_tokens Hard ceiling for prompt size.
model Model name for tokenization and summarization.
client AsyncOpenAI client. If provided, enables LLM summarization
as the first strategy. If None, skips to truncation strategies.
keep_recent Number of recent messages to preserve during summarization.
reserve Tokens to reserve for model response.
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap before moving to deletions.
Returns
-------
CompressResult with compressed messages and metadata.
"""
# Guard clause for empty messages
if not messages:
return CompressResult(
messages=[],
token_count=0,
was_compacted=False,
original_token_count=0,
)
token_model = _normalize_model_for_tokenizer(model)
enc = encoding_for_model(token_model)
msgs = deepcopy(messages)
def total_tokens() -> int:
return sum(_msg_tokens(m, enc) for m in msgs)
original_count = total_tokens()
# Already under limit
if original_count + reserve <= target_tokens:
return CompressResult(
messages=msgs,
token_count=original_count,
was_compacted=False,
original_token_count=original_count,
)
messages_summarized = 0
messages_dropped = 0
# ---- STEP 1: LLM summarization (if client provided) -------------------
# This is the primary compression strategy for chat service.
# Summarize old messages while keeping recent ones intact.
if client is not None:
has_system = len(msgs) > 0 and msgs[0].get("role") == "system"
system_msg = msgs[0] if has_system else None
# Calculate old vs recent messages
if has_system:
if len(msgs) > keep_recent + 1:
old_msgs = msgs[1:-keep_recent]
recent_msgs = msgs[-keep_recent:]
else:
old_msgs = []
recent_msgs = msgs[1:] if len(msgs) > 1 else []
else:
if len(msgs) > keep_recent:
old_msgs = msgs[:-keep_recent]
recent_msgs = msgs[-keep_recent:]
else:
old_msgs = []
recent_msgs = msgs
# Ensure tool pairs stay intact
slice_start = max(0, len(msgs) - keep_recent)
recent_msgs = _ensure_tool_pairs_intact(recent_msgs, msgs, slice_start)
if old_msgs:
try:
summary_text = await _summarize_messages_llm(old_msgs, client, model)
summary_msg = {
"role": "assistant",
"content": f"[Previous conversation summary — for context only]: {summary_text}",
}
messages_summarized = len(old_msgs)
if has_system:
msgs = [system_msg, summary_msg] + recent_msgs
else:
msgs = [summary_msg] + recent_msgs
logger.info(
f"Context summarized: {original_count} -> {total_tokens()} tokens, "
f"summarized {messages_summarized} messages"
)
except Exception as e:
logger.warning(f"Summarization failed, continuing with truncation: {e}")
# Fall through to content truncation
# ---- STEP 2: Normalize content ----------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
# Always run this before truncation to ensure consistent token counting.
for i, m in enumerate(msgs):
if not isinstance(m.get("content"), str) and m.get("content") is not None:
if _is_tool_message(m):
continue
if i == 0 or i == len(msgs) - 1:
continue
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 3: Token-aware content truncation ---------------------------
# Progressively halve per-message cap and truncate bloated content.
# This preserves all messages but shortens their content.
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]:
if _is_tool_message(m):
_truncate_tool_message_content(m, enc, cap)
continue
if _is_objective_message(m):
continue
content = m.get("content") or ""
if _tok_len(content, enc) > cap:
m["content"] = _truncate_middle_tokens(content, enc, cap)
cap //= 2
# ---- STEP 4: Middle-out deletion --------------------------------------
# Delete messages one at a time from the center outward.
# This is more granular than dropping all old messages at once.
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
deletable: list[int] = []
for i in range(1, len(msgs) - 1):
msg = msgs[i]
if (
msg is not None
and not _is_tool_message(msg)
and not _is_objective_message(msg)
):
deletable.append(i)
if not deletable:
break
centre = len(msgs) // 2
to_delete = min(deletable, key=lambda i: abs(i - centre))
del msgs[to_delete]
messages_dropped += 1
# ---- STEP 5: Final trim on first/last ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1):
msg = msgs[idx]
if msg is None:
continue
if _is_tool_message(msg):
_truncate_tool_message_content(msg, enc, cap)
continue
text = msg.get("content") or ""
if _tok_len(text, enc) > cap:
msg["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2
# Filter out any None values that may have been introduced
final_msgs: list[dict] = [m for m in msgs if m is not None]
final_count = sum(_msg_tokens(m, enc) for m in final_msgs)
error = None
if final_count + reserve > target_tokens:
error = f"Could not compress below target ({final_count + reserve} > {target_tokens})"
logger.warning(error)
return CompressResult(
messages=final_msgs,
token_count=final_count,
was_compacted=True,
error=error,
original_token_count=original_count,
messages_summarized=messages_summarized,
messages_dropped=messages_dropped,
)

View File

@@ -1,21 +1,10 @@
"""Tests for prompt utility functions, especially tool call token counting.""" """Tests for prompt utility functions, especially tool call token counting."""
from unittest.mock import AsyncMock, MagicMock
import pytest import pytest
from tiktoken import encoding_for_model from tiktoken import encoding_for_model
from backend.util import json from backend.util import json
from backend.util.prompt import ( from backend.util.prompt import _msg_tokens, estimate_token_count
CompressResult,
_ensure_tool_pairs_intact,
_msg_tokens,
_normalize_model_for_tokenizer,
_truncate_middle_tokens,
_truncate_tool_message_content,
compress_context,
estimate_token_count,
)
class TestMsgTokens: class TestMsgTokens:
@@ -287,690 +276,3 @@ class TestEstimateTokenCount:
assert total_tokens == expected_total assert total_tokens == expected_total
assert total_tokens > 20 # Should be substantial assert total_tokens > 20 # Should be substantial
class TestNormalizeModelForTokenizer:
"""Test model name normalization for tiktoken."""
def test_openai_models_unchanged(self):
"""Test that OpenAI models are returned as-is."""
assert _normalize_model_for_tokenizer("gpt-4o") == "gpt-4o"
assert _normalize_model_for_tokenizer("gpt-4") == "gpt-4"
assert _normalize_model_for_tokenizer("gpt-3.5-turbo") == "gpt-3.5-turbo"
def test_claude_models_normalized(self):
"""Test that Claude models are normalized to gpt-4o."""
assert _normalize_model_for_tokenizer("claude-3-opus") == "gpt-4o"
assert _normalize_model_for_tokenizer("claude-3-sonnet") == "gpt-4o"
assert _normalize_model_for_tokenizer("anthropic/claude-3-haiku") == "gpt-4o"
def test_openrouter_paths_extracted(self):
"""Test that OpenRouter model paths are handled."""
assert _normalize_model_for_tokenizer("openai/gpt-4o") == "gpt-4o"
assert _normalize_model_for_tokenizer("anthropic/claude-3-opus") == "gpt-4o"
def test_unknown_models_default_to_gpt4o(self):
"""Test that unknown models default to gpt-4o."""
assert _normalize_model_for_tokenizer("some-random-model") == "gpt-4o"
assert _normalize_model_for_tokenizer("llama-3-70b") == "gpt-4o"
class TestTruncateToolMessageContent:
"""Test tool message content truncation."""
@pytest.fixture
def enc(self):
return encoding_for_model("gpt-4o")
def test_truncate_openai_tool_message(self, enc):
"""Test truncation of OpenAI-style tool message with string content."""
long_content = "x" * 10000
msg = {"role": "tool", "tool_call_id": "call_123", "content": long_content}
_truncate_tool_message_content(msg, enc, max_tokens=100)
# Content should be truncated
assert len(msg["content"]) < len(long_content)
assert "" in msg["content"] # Has ellipsis marker
def test_truncate_anthropic_tool_result(self, enc):
"""Test truncation of Anthropic-style tool_result."""
long_content = "y" * 10000
msg = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": long_content,
}
],
}
_truncate_tool_message_content(msg, enc, max_tokens=100)
# Content should be truncated
result_content = msg["content"][0]["content"]
assert len(result_content) < len(long_content)
assert "" in result_content
def test_preserve_tool_use_blocks(self, enc):
"""Test that tool_use blocks are not truncated."""
msg = {
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_123",
"name": "some_function",
"input": {"key": "value" * 1000}, # Large input
}
],
}
original = json.dumps(msg["content"][0]["input"])
_truncate_tool_message_content(msg, enc, max_tokens=10)
# tool_use should be unchanged
assert json.dumps(msg["content"][0]["input"]) == original
def test_no_truncation_when_under_limit(self, enc):
"""Test that short content is not modified."""
msg = {"role": "tool", "tool_call_id": "call_123", "content": "Short content"}
original = msg["content"]
_truncate_tool_message_content(msg, enc, max_tokens=1000)
assert msg["content"] == original
class TestTruncateMiddleTokens:
"""Test middle truncation of text."""
@pytest.fixture
def enc(self):
return encoding_for_model("gpt-4o")
def test_truncates_long_text(self, enc):
"""Test that long text is truncated with ellipsis in middle."""
long_text = "word " * 1000
result = _truncate_middle_tokens(long_text, enc, max_tok=50)
assert len(enc.encode(result)) <= 52 # Allow some slack for ellipsis
assert "" in result
assert result.startswith("word") # Head preserved
assert result.endswith("word ") # Tail preserved
def test_preserves_short_text(self, enc):
"""Test that short text is not modified."""
short_text = "Hello world"
result = _truncate_middle_tokens(short_text, enc, max_tok=100)
assert result == short_text
class TestEnsureToolPairsIntact:
"""Test tool call/response pair preservation for both OpenAI and Anthropic formats."""
# ---- OpenAI Format Tests ----
def test_openai_adds_missing_tool_call(self):
"""Test that orphaned OpenAI tool_response gets its tool_call prepended."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool response)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_call message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert "tool_calls" in result[0]
def test_openai_keeps_complete_pairs(self):
"""Test that complete OpenAI pairs are unchanged."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result"},
]
recent = all_msgs[1:] # Include both tool_call and response
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert len(result) == 2 # No messages added
def test_openai_multiple_tool_calls(self):
"""Test multiple OpenAI tool calls in one assistant message."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "f1"}},
{"id": "call_2", "type": "function", "function": {"name": "f2"}},
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "result1"},
{"role": "tool", "tool_call_id": "call_2", "content": "result2"},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (first tool response)
recent = [all_msgs[2], all_msgs[3], all_msgs[4]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with both tool_calls
assert len(result) == 4
assert result[0]["role"] == "assistant"
assert len(result[0]["tool_calls"]) == 2
# ---- Anthropic Format Tests ----
def test_anthropic_adds_missing_tool_use(self):
"""Test that orphaned Anthropic tool_result gets its tool_use prepended."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_123",
"name": "get_weather",
"input": {"location": "SF"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_123",
"content": "22°C and sunny",
}
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool_result)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_use message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert result[0]["content"][0]["type"] == "tool_use"
def test_anthropic_keeps_complete_pairs(self):
"""Test that complete Anthropic pairs are unchanged."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_456",
"name": "calculator",
"input": {"expr": "2+2"},
}
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_456",
"content": "4",
}
],
},
]
recent = all_msgs[1:] # Include both tool_use and result
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert len(result) == 2 # No messages added
def test_anthropic_multiple_tool_uses(self):
"""Test multiple Anthropic tool_use blocks in one message."""
all_msgs = [
{"role": "system", "content": "System"},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me check both..."},
{
"type": "tool_use",
"id": "toolu_1",
"name": "get_weather",
"input": {"city": "NYC"},
},
{
"type": "tool_use",
"id": "toolu_2",
"name": "get_weather",
"input": {"city": "LA"},
},
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_1",
"content": "Cold",
},
{
"type": "tool_result",
"tool_use_id": "toolu_2",
"content": "Warm",
},
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (tool_result)
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with both tool_uses
assert len(result) == 3
assert result[0]["role"] == "assistant"
tool_use_count = sum(
1 for b in result[0]["content"] if b.get("type") == "tool_use"
)
assert tool_use_count == 2
# ---- Mixed/Edge Case Tests ----
def test_anthropic_with_type_message_field(self):
"""Test Anthropic format with 'type': 'message' field (smart_decision_maker style)."""
all_msgs = [
{"role": "system", "content": "You are helpful."},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_abc",
"name": "search",
"input": {"q": "test"},
}
],
},
{
"role": "user",
"type": "message", # Extra field from smart_decision_maker
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_abc",
"content": "Found results",
}
],
},
{"role": "user", "content": "Thanks!"},
]
# Recent messages start at index 2 (the tool_result with 'type': 'message')
recent = [all_msgs[2], all_msgs[3]]
start_index = 2
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the tool_use message
assert len(result) == 3
assert result[0]["role"] == "assistant"
assert result[0]["content"][0]["type"] == "tool_use"
def test_handles_no_tool_messages(self):
"""Test messages without tool calls."""
all_msgs = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
recent = all_msgs
start_index = 0
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
assert result == all_msgs
def test_handles_empty_messages(self):
"""Test empty message list."""
result = _ensure_tool_pairs_intact([], [], 0)
assert result == []
def test_mixed_text_and_tool_content(self):
"""Test Anthropic message with mixed text and tool_use content."""
all_msgs = [
{
"role": "assistant",
"content": [
{"type": "text", "text": "I'll help you with that."},
{
"type": "tool_use",
"id": "toolu_mixed",
"name": "search",
"input": {"q": "test"},
},
],
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_mixed",
"content": "Found results",
}
],
},
{"role": "assistant", "content": "Here are the results..."},
]
# Start from tool_result
recent = [all_msgs[1], all_msgs[2]]
start_index = 1
result = _ensure_tool_pairs_intact(recent, all_msgs, start_index)
# Should prepend the assistant message with tool_use
assert len(result) == 3
assert result[0]["content"][0]["type"] == "text"
assert result[0]["content"][1]["type"] == "tool_use"
class TestCompressContext:
"""Test the async compress_context function."""
@pytest.mark.asyncio
async def test_no_compression_needed(self):
"""Test messages under limit return without compression."""
messages = [
{"role": "system", "content": "You are helpful."},
{"role": "user", "content": "Hello!"},
]
result = await compress_context(messages, target_tokens=100000)
assert isinstance(result, CompressResult)
assert result.was_compacted is False
assert len(result.messages) == 2
assert result.error is None
@pytest.mark.asyncio
async def test_truncation_without_client(self):
"""Test that truncation works without LLM client."""
long_content = "x" * 50000
messages = [
{"role": "system", "content": "System"},
{"role": "user", "content": long_content},
{"role": "assistant", "content": "Response"},
]
result = await compress_context(
messages, target_tokens=1000, client=None, reserve=100
)
assert result.was_compacted is True
# Should have truncated without summarization
assert result.messages_summarized == 0
@pytest.mark.asyncio
async def test_with_mocked_llm_client(self):
"""Test summarization with mocked LLM client."""
# Create many messages to trigger summarization
messages = [{"role": "system", "content": "System prompt"}]
for i in range(30):
messages.append({"role": "user", "content": f"User message {i} " * 100})
messages.append(
{"role": "assistant", "content": f"Assistant response {i} " * 100}
)
# Mock the AsyncOpenAI client
mock_client = AsyncMock()
mock_response = MagicMock()
mock_response.choices = [MagicMock()]
mock_response.choices[0].message.content = "Summary of conversation"
mock_client.with_options.return_value.chat.completions.create = AsyncMock(
return_value=mock_response
)
result = await compress_context(
messages,
target_tokens=5000,
client=mock_client,
keep_recent=5,
reserve=500,
)
assert result.was_compacted is True
# Should have attempted summarization
assert mock_client.with_options.called or result.messages_summarized > 0
@pytest.mark.asyncio
async def test_preserves_tool_pairs(self):
"""Test that tool call/response pairs stay together."""
messages = [
{"role": "system", "content": "System"},
{"role": "user", "content": "Do something"},
{
"role": "assistant",
"tool_calls": [
{"id": "call_1", "type": "function", "function": {"name": "func"}}
],
},
{"role": "tool", "tool_call_id": "call_1", "content": "Result " * 1000},
{"role": "assistant", "content": "Done!"},
]
result = await compress_context(
messages, target_tokens=500, client=None, reserve=50
)
# Check that if tool response exists, its call exists too
tool_call_ids = set()
tool_response_ids = set()
for msg in result.messages:
if "tool_calls" in msg:
for tc in msg["tool_calls"]:
tool_call_ids.add(tc["id"])
if msg.get("role") == "tool":
tool_response_ids.add(msg.get("tool_call_id"))
# All tool responses should have their calls
assert tool_response_ids <= tool_call_ids
@pytest.mark.asyncio
async def test_returns_error_when_cannot_compress(self):
"""Test that error is returned when compression fails."""
# Single huge message that can't be compressed enough
messages = [
{"role": "user", "content": "x" * 100000},
]
result = await compress_context(
messages, target_tokens=100, client=None, reserve=50
)
# Should have an error since we can't get below 100 tokens
assert result.error is not None
assert result.was_compacted is True
@pytest.mark.asyncio
async def test_empty_messages(self):
"""Test that empty messages list returns early without error."""
result = await compress_context([], target_tokens=1000)
assert result.messages == []
assert result.token_count == 0
assert result.was_compacted is False
assert result.error is None
class TestRemoveOrphanToolResponses:
"""Test _remove_orphan_tool_responses helper function."""
def test_removes_openai_orphan(self):
"""Test removal of orphan OpenAI tool response."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "tool", "tool_call_id": "call_orphan", "content": "result"},
{"role": "user", "content": "Hello"},
]
orphan_ids = {"call_orphan"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
assert result[0]["role"] == "user"
def test_keeps_valid_openai_tool(self):
"""Test that valid OpenAI tool responses are kept."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "tool", "tool_call_id": "call_valid", "content": "result"},
]
orphan_ids = {"call_other"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
assert result[0]["tool_call_id"] == "call_valid"
def test_filters_anthropic_mixed_blocks(self):
"""Test filtering individual orphan blocks from Anthropic message with mixed valid/orphan."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_valid",
"content": "valid result",
},
{
"type": "tool_result",
"tool_use_id": "toolu_orphan",
"content": "orphan result",
},
],
},
]
orphan_ids = {"toolu_orphan"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert len(result) == 1
# Should only have the valid tool_result, orphan filtered out
assert len(result[0]["content"]) == 1
assert result[0]["content"][0]["tool_use_id"] == "toolu_valid"
def test_removes_anthropic_all_orphan(self):
"""Test removal of Anthropic message when all tool_results are orphans."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_orphan1",
"content": "result1",
},
{
"type": "tool_result",
"tool_use_id": "toolu_orphan2",
"content": "result2",
},
],
},
]
orphan_ids = {"toolu_orphan1", "toolu_orphan2"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
# Message should be completely removed since no content left
assert len(result) == 0
def test_preserves_non_tool_messages(self):
"""Test that non-tool messages are preserved."""
from backend.util.prompt import _remove_orphan_tool_responses
messages = [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
]
orphan_ids = {"some_id"}
result = _remove_orphan_tool_responses(messages, orphan_ids)
assert result == messages
class TestCompressResultDataclass:
"""Test CompressResult dataclass."""
def test_default_values(self):
"""Test default values are set correctly."""
result = CompressResult(
messages=[{"role": "user", "content": "test"}],
token_count=10,
was_compacted=False,
)
assert result.error is None
assert result.original_token_count == 0 # Defaults to 0, not None
assert result.messages_summarized == 0
assert result.messages_dropped == 0
def test_all_fields(self):
"""Test all fields can be set."""
result = CompressResult(
messages=[{"role": "user", "content": "test"}],
token_count=100,
was_compacted=True,
error="Some error",
original_token_count=500,
messages_summarized=10,
messages_dropped=5,
)
assert result.token_count == 100
assert result.was_compacted is True
assert result.error == "Some error"
assert result.original_token_count == 500
assert result.messages_summarized == 10
assert result.messages_dropped == 5

View File

@@ -1,32 +0,0 @@
"""Validation utilities."""
import re
_UUID_V4_PATTERN = re.compile(
r"[a-f0-9]{8}-[a-f0-9]{4}-4[a-f0-9]{3}-[89ab][a-f0-9]{3}-[a-f0-9]{12}",
re.IGNORECASE,
)
def is_uuid_v4(text: str) -> bool:
"""Check if text is a valid UUID v4.
Args:
text: String to validate
Returns:
True if the text is a valid UUID v4, False otherwise
"""
return bool(_UUID_V4_PATTERN.fullmatch(text.strip()))
def extract_uuids(text: str) -> list[str]:
"""Extract all UUID v4 strings from text.
Args:
text: String to search for UUIDs
Returns:
List of unique UUIDs found (lowercase)
"""
return list({m.lower() for m in _UUID_V4_PATTERN.findall(text)})

View File

@@ -193,6 +193,7 @@ Below is a comprehensive list of all available blocks, categorized by their prim
| [Get Current Time](block-integrations/text.md#get-current-time) | This block outputs the current time | | [Get Current Time](block-integrations/text.md#get-current-time) | This block outputs the current time |
| [Match Text Pattern](block-integrations/text.md#match-text-pattern) | Matches text against a regex pattern and forwards data to positive or negative output based on the match | | [Match Text Pattern](block-integrations/text.md#match-text-pattern) | Matches text against a regex pattern and forwards data to positive or negative output based on the match |
| [Text Decoder](block-integrations/text.md#text-decoder) | Decodes a string containing escape sequences into actual text | | [Text Decoder](block-integrations/text.md#text-decoder) | Decodes a string containing escape sequences into actual text |
| [Text Encoder](block-integrations/text.md#text-encoder) | Encodes a string by converting special characters into escape sequences |
| [Text Replace](block-integrations/text.md#text-replace) | This block is used to replace a text with a new text | | [Text Replace](block-integrations/text.md#text-replace) | This block is used to replace a text with a new text |
| [Text Split](block-integrations/text.md#text-split) | This block is used to split a text into a list of strings | | [Text Split](block-integrations/text.md#text-split) | This block is used to split a text into a list of strings |
| [Word Character Count](block-integrations/text.md#word-character-count) | Counts the number of words and characters in a given text | | [Word Character Count](block-integrations/text.md#word-character-count) | Counts the number of words and characters in a given text |

View File

@@ -380,6 +380,42 @@ This is useful when working with data from APIs or files where escape sequences
--- ---
## Text Encoder
### What it is
Encodes a string by converting special characters into escape sequences
### How it works
<!-- MANUAL: how_it_works -->
The Text Encoder takes the input string and applies Python's `unicode_escape` encoding (equivalent to `codecs.encode(text, "unicode_escape").decode("utf-8")`) to transform special characters like newlines, tabs, and backslashes into their escaped forms.
The block relies on the input schema to ensure the value is a string; non-string inputs are rejected by validation, and any encoding failures surface as block errors. Non-ASCII characters are emitted as `\uXXXX` sequences, which is useful for ASCII-only payloads.
<!-- END MANUAL -->
### Inputs
| Input | Description | Type | Required |
|-------|-------------|------|----------|
| text | A string containing special characters to be encoded | str | Yes |
### Outputs
| Output | Description | Type |
|--------|-------------|------|
| error | Error message if encoding fails | str |
| encoded_text | The encoded text with special characters converted to escape sequences | str |
### Possible use case
<!-- MANUAL: use_case -->
**JSON Payload Preparation**: Encode multiline or quoted text before embedding it in JSON string fields to ensure proper escaping.
**Config/ENV Generation**: Convert template text into escaped strings for `.env` or YAML values that require special character handling.
**Snapshot Fixtures**: Produce stable escaped strings for golden files or API tests where consistent text representation is needed.
<!-- END MANUAL -->
---
## Text Replace ## Text Replace
### What it is ### What it is