feat(block): Introduce context-window aware prompt compaction for LLM & SmartDecision blocks (#10252)

Calling an LLM through this block can sometimes fail when the accumulated prompt exceeds the model's context window.
A prompt compaction algorithm is now applied (enabled by default) to keep the prompt that is sent within the context-window limit.


### Changes 🏗️

````
Heuristics
--------
* Prefer shrinking message content over truncating the conversation.
* If compacting the content is still not enough, reduce the number of messages in the conversation.
* The rest of the implementation is adjusted to minimize LLM call failures.

Strategy
--------
1. **Token-aware truncation** – progressively halve a per-message cap
   (`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
   *content* of every message except the first and last.  Tool shells
   are included: we keep the envelope but shorten huge payloads.
2. **Middle-out deletion** – if still over the limit, delete the whole
   messages working outward from the centre, **skipping** any message
   that contains ``tool_calls`` or has ``role == "tool"``.
3. **Last-chance trim** – if still too big, truncate the *first* and
   *last* message bodies down to `floor_cap` tokens.
4. If the prompt is *still* too large:
     • raise ``ValueError``      when ``lossy_ok == False``
     • return the partially-trimmed prompt when ``lossy_ok == True`` (default)
````
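
The sketch below (using a hypothetical `fit_prompt` helper) shows roughly how this is wired in before each call, mirroring the changes to `llm_call` in the diff further down:

```python
from backend.util.prompt import compress_prompt, estimate_token_count


def fit_prompt(
    prompt: list[dict], context_window: int, compress_prompt_to_fit: bool = True
) -> tuple[list[dict], int]:
    """Compact the prompt the way llm_call does, then report the remaining output budget."""
    if compress_prompt_to_fit:
        prompt = compress_prompt(
            messages=prompt,
            target_tokens=context_window // 2,  # aim for half the window
            lossy_ok=True,  # send a best-effort prompt rather than raising
        )
    estimated_input_tokens = estimate_token_count(prompt)
    available_tokens = max(context_window - estimated_input_tokens, 0)
    return prompt, available_tokens
```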

### Checklist 📋

#### For code changes:
- [x] I have clearly listed my changes in the PR description
- [x] I have made a test plan
- [x] I have tested my changes according to the test plan:
  <!-- Put your test plan here: -->
- [x] Ran a SmartDecisionMaker (SDM) block in a loop until it hit 200,000 tokens using the OpenAI o3 model.
Zamil Majdy
2025-06-27 08:07:50 -07:00
committed by GitHub
parent c01beaf003
commit c4056cbae9
6 changed files with 236 additions and 29 deletions

View File

@@ -23,6 +23,7 @@ from backend.data.model import (
from backend.integrations.providers import ProviderName
from backend.util import json
from backend.util.logging import TruncatedLogger
from backend.util.prompt import compress_prompt, estimate_token_count
from backend.util.text import TextFormatter
logger = TruncatedLogger(logging.getLogger(__name__), "[LLM-Block]")
@@ -306,13 +307,6 @@ def convert_openai_tool_fmt_to_anthropic(
return anthropic_tools
def estimate_token_count(prompt_messages: list[dict]) -> int:
char_count = sum(len(str(msg.get("content", ""))) for msg in prompt_messages)
message_overhead = len(prompt_messages) * 4
estimated_tokens = (char_count // 4) + message_overhead
return int(estimated_tokens * 1.2)
async def llm_call(
credentials: APIKeyCredentials,
llm_model: LlmModel,
@@ -321,7 +315,8 @@ async def llm_call(
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
parallel_tool_calls: bool | None = None,
parallel_tool_calls=None,
compress_prompt_to_fit: bool = True,
) -> LLMResponse:
"""
Make a call to a language model.
@@ -344,10 +339,17 @@ async def llm_call(
- completion_tokens: The number of tokens used in the completion.
"""
provider = llm_model.metadata.provider
context_window = llm_model.context_window
if compress_prompt_to_fit:
prompt = compress_prompt(
messages=prompt,
target_tokens=llm_model.context_window // 2,
lossy_ok=True,
)
# Calculate available tokens based on context window and input length
estimated_input_tokens = estimate_token_count(prompt)
context_window = llm_model.context_window
model_max_output = llm_model.max_output_tokens or int(2**15)
user_max = max_tokens or model_max_output
available_tokens = max(context_window - estimated_input_tokens, 0)
@@ -358,14 +360,10 @@ async def llm_call(
oai_client = openai.AsyncOpenAI(api_key=credentials.api_key.get_secret_value())
response_format = None
if llm_model in [LlmModel.O1_MINI, LlmModel.O1_PREVIEW]:
sys_messages = [p["content"] for p in prompt if p["role"] == "system"]
usr_messages = [p["content"] for p in prompt if p["role"] != "system"]
prompt = [
{"role": "user", "content": "\n".join(sys_messages)},
{"role": "user", "content": "\n".join(usr_messages)},
]
elif json_format:
if llm_model.startswith("o") or parallel_tool_calls is None:
parallel_tool_calls = openai.NOT_GIVEN
if json_format:
response_format = {"type": "json_object"}
response = await oai_client.chat.completions.create(
@@ -374,9 +372,7 @@ async def llm_call(
response_format=response_format, # type: ignore
max_completion_tokens=max_tokens,
tools=tools_param, # type: ignore
parallel_tool_calls=(
openai.NOT_GIVEN if parallel_tool_calls is None else parallel_tool_calls
),
parallel_tool_calls=parallel_tool_calls,
)
if response.choices[0].message.tool_calls:
@@ -699,7 +695,11 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
default=None,
description="The maximum number of tokens to generate in the chat completion.",
)
compress_prompt_to_fit: bool = SchemaField(
advanced=True,
default=True,
description="Whether to compress the prompt to fit within the model's context window.",
)
ollama_host: str = SchemaField(
advanced=True,
default="localhost:11434",
@@ -757,6 +757,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
llm_model: LlmModel,
prompt: list[dict],
json_format: bool,
compress_prompt_to_fit: bool,
max_tokens: int | None,
tools: list[dict] | None = None,
ollama_host: str = "localhost:11434",
@@ -774,6 +775,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
max_tokens=max_tokens,
tools=tools,
ollama_host=ollama_host,
compress_prompt_to_fit=compress_prompt_to_fit,
)
async def run(
@@ -832,7 +834,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
except JSONDecodeError as e:
return f"JSON decode error: {e}"
logger.info(f"LLM request: {prompt}")
logger.debug(f"LLM request: {prompt}")
retry_prompt = ""
llm_model = input_data.model
@@ -842,6 +844,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
credentials=credentials,
llm_model=llm_model,
prompt=prompt,
compress_prompt_to_fit=input_data.compress_prompt_to_fit,
json_format=bool(input_data.expected_format),
ollama_host=input_data.ollama_host,
max_tokens=input_data.max_tokens,
@@ -853,7 +856,7 @@ class AIStructuredResponseGeneratorBlock(AIBlockBase):
output_token_count=llm_response.completion_tokens,
)
)
logger.info(f"LLM attempt-{retry_count} response: {response_text}")
logger.debug(f"LLM attempt-{retry_count} response: {response_text}")
if input_data.expected_format:

View File

@@ -85,7 +85,7 @@ def _get_tool_responses(entry: dict[str, Any]) -> list[str]:
return tool_call_ids
def _create_tool_response(call_id: str, output: dict[str, Any]) -> dict[str, Any]:
def _create_tool_response(call_id: str, output: Any) -> dict[str, Any]:
"""
Create a tool response message for either OpenAI or Anthropic,
based on the tool_id format.
@@ -212,6 +212,15 @@ class SmartDecisionMakerBlock(Block):
"link like the output of `StoreValue` or `AgentInput` block"
)
# Check that both conversation_history and last_tool_output are connected together
if any(link.sink_name == "conversation_history" for link in links) != any(
link.sink_name == "last_tool_output" for link in links
):
raise ValueError(
"Last Tool Output is needed when Conversation History is used, "
"and vice versa. Please connect both inputs together."
)
return missing_links
@classmethod
@@ -222,8 +231,15 @@ class SmartDecisionMakerBlock(Block):
conversation_history = data.get("conversation_history", [])
pending_tool_calls = get_pending_tool_calls(conversation_history)
last_tool_output = data.get("last_tool_output")
if not last_tool_output and pending_tool_calls:
# Tool call is pending, wait for the tool output to be provided.
if last_tool_output is None and pending_tool_calls:
return {"last_tool_output"}
# No tool call is pending, wait for the conversation history to be updated.
if last_tool_output is not None and not pending_tool_calls:
return {"conversation_history"}
return set()
class Output(BlockSchema):
@@ -433,7 +449,7 @@ class SmartDecisionMakerBlock(Block):
prompt = [json.to_dict(p) for p in input_data.conversation_history if p]
pending_tool_calls = get_pending_tool_calls(input_data.conversation_history)
if pending_tool_calls and not input_data.last_tool_output:
if pending_tool_calls and input_data.last_tool_output is None:
raise ValueError(f"Tool call requires an output for {pending_tool_calls}")
# Prefill all missing tool calls with the last tool output.
@@ -497,7 +513,7 @@ class SmartDecisionMakerBlock(Block):
max_tokens=input_data.max_tokens,
tools=tool_functions,
ollama_host=input_data.ollama_host,
parallel_tool_calls=True if input_data.multiple_tool_calls else None,
parallel_tool_calls=input_data.multiple_tool_calls,
)
if not response.tool_calls:

View File

@@ -0,0 +1,181 @@
import json
from copy import deepcopy
from tiktoken import encoding_for_model
# ---------------------------------------------------------------------------#
# INTERNAL UTILITIES #
# ---------------------------------------------------------------------------#
def _tok_len(text: str, enc) -> int:
"""True token length of *text* in tokenizer *enc* (no wrapper cost)."""
return len(enc.encode(text))
def _msg_tokens(msg: dict, enc) -> int:
"""
OpenAI counts ≈3 wrapper tokens per chat message, plus 1 if "name"
is present, plus the tokenised content length.
"""
WRAPPER = 3 + (1 if "name" in msg else 0)
return WRAPPER + _tok_len(msg.get("content") or "", enc)
def _truncate_middle_tokens(text: str, enc, max_tok: int) -> str:
"""
Return *text* shortened to ≈max_tok tokens by keeping the head & tail
and inserting an ellipsis token in the middle.
"""
ids = enc.encode(text)
if len(ids) <= max_tok:
return text # nothing to do
# Split the allowance between the two ends:
head = max_tok // 2 - 1 # -1 for the ellipsis
tail = max_tok - head - 1
mid = enc.encode("…")
return enc.decode(ids[:head] + mid + ids[-tail:])
# ---------------------------------------------------------------------------#
# PUBLIC API #
# ---------------------------------------------------------------------------#
def compress_prompt(
messages: list[dict],
target_tokens: int,
*,
model: str = "gpt-4o",
reserve: int = 2_048,
start_cap: int = 8_192,
floor_cap: int = 128,
lossy_ok: bool = True,
) -> list[dict]:
"""
Shrink *messages* so that::
token_count(prompt) + reserve ≤ target_tokens
Strategy
--------
1. **Token-aware truncation** progressively halve a per-message cap
(`start_cap`, `start_cap/2`, … `floor_cap`) and apply it to the
*content* of every message except the first and last. Tool shells
are included: we keep the envelope but shorten huge payloads.
2. **Middle-out deletion** if still over the limit, delete whole
messages working outward from the centre, **skipping** any message
that contains ``tool_calls`` or has ``role == "tool"``.
3. **Last-chance trim** if still too big, truncate the *first* and
*last* message bodies down to `floor_cap` tokens.
4. If the prompt is *still* too large:
• raise ``ValueError`` when ``lossy_ok == False``
• return the partially-trimmed prompt when ``lossy_ok == True`` (default)
Parameters
----------
messages Complete chat history (will be deep-copied).
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
target_tokens Hard ceiling for prompt size **excluding** the model's
forthcoming answer.
reserve How many tokens you want to leave available for that
answer (`max_tokens` in your subsequent completion call).
start_cap Initial per-message truncation ceiling (tokens).
floor_cap Lowest cap we'll accept before moving to deletions.
lossy_ok If *True* return best-effort prompt instead of raising
after all trim passes have been exhausted.
Returns
-------
list[dict] A *new* messages list that abides by the rules above.
"""
enc = encoding_for_model(model) # best-match tokenizer
msgs = deepcopy(messages) # never mutate caller
def total_tokens() -> int:
"""Current size of *msgs* in tokens."""
return sum(_msg_tokens(m, enc) for m in msgs)
original_token_count = total_tokens()
if original_token_count + reserve <= target_tokens:
return msgs
# ---- STEP 0 : normalise content --------------------------------------
# Convert non-string payloads to strings so token counting is coherent.
for m in msgs[1:-1]: # keep the first & last intact
if not isinstance(m.get("content"), str) and m.get("content") is not None:
# Reasonable 20k-char ceiling prevents pathological blobs
content_str = json.dumps(m["content"], separators=(",", ":"))
if len(content_str) > 20_000:
content_str = _truncate_middle_tokens(content_str, enc, 20_000)
m["content"] = content_str
# ---- STEP 1 : token-aware truncation ---------------------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for m in msgs[1:-1]: # keep first & last intact
if _tok_len(m.get("content") or "", enc) > cap:
m["content"] = _truncate_middle_tokens(m["content"], enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 2 : middle-out deletion -----------------------------------
while total_tokens() + reserve > target_tokens and len(msgs) > 2:
centre = len(msgs) // 2
# Build a symmetrical centre-out index walk: centre, centre+1, centre-1, ...
order = [centre] + [
i
for pair in zip(range(centre + 1, len(msgs) - 1), range(centre - 1, 0, -1))
for i in pair
]
removed = False
for i in order:
msg = msgs[i]
if "tool_calls" in msg or msg.get("role") == "tool":
continue # protect tool shells
del msgs[i]
removed = True
break
if not removed: # nothing more we can drop
break
# ---- STEP 3 : final safety-net trim on first & last ------------------
cap = start_cap
while total_tokens() + reserve > target_tokens and cap >= floor_cap:
for idx in (0, -1): # first and last
text = msgs[idx].get("content") or ""
if _tok_len(text, enc) > cap:
msgs[idx]["content"] = _truncate_middle_tokens(text, enc, cap)
cap //= 2 # tighten the screw
# ---- STEP 4 : success or fail-gracefully -----------------------------
if total_tokens() + reserve > target_tokens and not lossy_ok:
raise ValueError(
"compress_prompt: prompt still exceeds budget "
f"({total_tokens() + reserve} > {target_tokens})."
)
return msgs
def estimate_token_count(
messages: list[dict],
*,
model: str = "gpt-4o",
) -> int:
"""
Return the true token count of *messages* when encoded for *model*.
Parameters
----------
messages Complete chat history.
model Model name; passed to tiktoken to pick the right
tokenizer (gpt-4o → 'o200k_base', others fallback).
Returns
-------
int Token count.
"""
enc = encoding_for_model(model) # best-match tokenizer
return sum(_msg_tokens(m, enc) for m in messages)
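
A minimal usage sketch of the new utilities; the message list and token budgets below are illustrative only:

```python
from backend.util.prompt import compress_prompt, estimate_token_count

history = [
    {"role": "system", "content": "You are a helpful agent."},
    {"role": "user", "content": "Summarise this transcript: " + "lorem ipsum " * 5_000},
    {"role": "assistant", "content": "Working on it."},
]

print("before:", estimate_token_count(history))
compacted = compress_prompt(history, target_tokens=4_096, reserve=1_024, lossy_ok=True)
print("after:", estimate_token_count(compacted))  # fits 4_096 - 1_024 unless only a lossy result was possible
```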

View File

@@ -430,7 +430,13 @@ class Requests:
) as response:
if self.raise_for_status:
response.raise_for_status()
try:
response.raise_for_status()
except ClientResponseError as e:
body = await response.read()
raise Exception(
f"HTTP {response.status} Error: {response.reason}, Body: {body.decode(errors='replace')}"
) from e
# If allowed and a redirect is received, follow the redirect manually
if allow_redirects and response.status in (301, 302, 303, 307, 308):

View File

@@ -6380,4 +6380,4 @@ cffi = ["cffi (>=1.11)"]
[metadata]
lock-version = "2.1"
python-versions = ">=3.10,<3.13"
content-hash = "35f6516ea0e72a0b4381842f4a6ad6d01ed263e01baabb09e554f9a63ca8b175"
content-hash = "bd117a21d817a2a735ed923c383713dd08469938ef5f7d07c4222da1acca2b5c"

View File

@@ -69,6 +69,7 @@ zerobouncesdk = "^1.1.1"
pytest-snapshot = "^0.9.0"
aiofiles = "^24.1.0"
pyclamd = "^0.4.0"
tiktoken = "^0.9.0"
[tool.poetry.group.dev.dependencies]
aiohappyeyeballs = "^2.6.1"