fix(platform): harden retry logic — walk exception chains, handle write failures, add 20 edge-case tests

- _is_prompt_too_long walks __cause__/__context__ to detect wrapped errors
- write_transcript_to_tempfile failure now sets transcript_caused_error
- _run_compression catches all exceptions for truncation fallback
- _flatten_tool_result_content handles None content
- Remove circular import workaround (ChatConfig as top-level import)
- Add 20 new tests: exception chains, malformed transcripts, unicode,
  tool use/result pairs, session rollback, write failure propagation
This commit is contained in:
Zamil Majdy
2026-03-14 10:34:21 +07:00
parent 8c8e596302
commit 5dc8d6c848
4 changed files with 313 additions and 19 deletions

View File

@@ -68,6 +68,38 @@ class TestIsPromptTooLong:
err = RuntimeError("prompt is too long")
assert _is_prompt_too_long(err) is True
def test_walks_cause_chain(self):
    """A prompt-too-long error wrapped via explicit chaining is found."""
    wrapped = Exception("prompt is too long: 250000 > 200000")
    wrapper = RuntimeError("SDK process failed")
    wrapper.__cause__ = wrapped
    assert _is_prompt_too_long(wrapper) is True
def test_walks_context_chain(self):
    """Implicitly chained (``__context__``) errors are also inspected."""
    wrapped = Exception("context_length_exceeded")
    wrapper = RuntimeError("during handling")
    wrapper.__context__ = wrapped
    assert _is_prompt_too_long(wrapper) is True
def test_no_infinite_loop_on_circular_chain(self):
    """A cycle in the exception chain terminates instead of hanging."""
    first = Exception("error a")
    second = Exception("error b")
    first.__cause__ = second
    second.__cause__ = first  # close the cycle
    assert _is_prompt_too_long(first) is False
def test_deep_chain(self):
    """A match buried ten wrapper levels deep is still detected."""
    chain = Exception("request too large")
    for depth in range(10):
        layer = RuntimeError(f"layer {depth}")
        layer.__cause__ = chain
        chain = layer
    assert _is_prompt_too_long(chain) is True
def test_patterns_constant_is_tuple(self):
    """The module-level patterns constant exists and has entries."""
    patterns = _PROMPT_TOO_LONG_PATTERNS
    assert len(patterns) >= 2
@@ -145,6 +177,16 @@ class TestFlattenToolResultContent:
def test_raw_string(self):
    """A bare string element passes through unchanged."""
    flattened = _flatten_tool_result_content(["raw"])
    assert flattened == "raw"
def test_tool_result_with_none_content(self):
    """tool_result with content=None should produce empty string."""
    block = {"type": "tool_result", "tool_use_id": "x", "content": None}
    assert _flatten_tool_result_content([block]) == ""
def test_tool_result_with_empty_list_content(self):
    """tool_result with content=[] should produce empty string."""
    block = {"type": "tool_result", "tool_use_id": "x", "content": []}
    assert _flatten_tool_result_content([block]) == ""
def test_empty(self):
    """An empty block list flattens to the empty string."""
    flattened = _flatten_tool_result_content([])
    assert flattened == ""
@@ -238,6 +280,93 @@ class TestTranscriptToMessages:
assert len(messages) == 1
assert messages[0]["content"] == "tool output"
def test_malformed_json_lines_skipped(self):
    """Lines that fail to parse as JSON are dropped without raising."""
    entries = [
        _make_entry("user", "user", "Hello"),
        "this is not valid json",
        _make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
    ]
    parsed = _transcript_to_messages("\n".join(entries) + "\n")
    assert len(parsed) == 2
def test_empty_lines_skipped(self):
    """Blank and whitespace-only lines are ignored by the parser."""
    entries = [
        _make_entry("user", "user", "Hello"),
        "",
        "   ",
        _make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
    ]
    parsed = _transcript_to_messages("\n".join(entries) + "\n")
    assert len(parsed) == 2
def test_unicode_content_preserved(self):
    """Non-ASCII text survives the transcript -> messages roundtrip."""
    reply_blocks = [{"type": "text", "text": "Bonjour 日本語 émojis 🎉"}]
    entries = [
        _make_entry("user", "user", "Hello 你好 🌍"),
        _make_entry("assistant", "assistant", reply_blocks),
    ]
    parsed = _transcript_to_messages("\n".join(entries) + "\n")
    assert parsed[0]["content"] == "Hello 你好 🌍"
    assert parsed[1]["content"] == "Bonjour 日本語 émojis 🎉"
def test_entry_without_role_skipped(self):
    """An entry whose message dict lacks a role is dropped."""
    roleless = json.dumps(
        {
            "type": "user",
            "uuid": str(uuid4()),
            "parentUuid": None,
            "message": {"content": "no role here"},
        }
    )
    transcript = roleless + "\n" + _make_entry("user", "user", "Hello") + "\n"
    parsed = _transcript_to_messages(transcript)
    assert len(parsed) == 1
    assert parsed[0]["content"] == "Hello"
def test_tool_use_and_result_pairs(self):
    """tool_use and tool_result blocks flatten into readable text."""
    assistant_entry = _make_entry(
        "assistant",
        "assistant",
        [
            {"type": "text", "text": "Let me check."},
            {"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
        ],
    )
    user_entry = _make_entry(
        "user",
        "user",
        [
            {
                "type": "tool_result",
                "tool_use_id": "abc",
                "content": [{"type": "text", "text": "file contents"}],
            }
        ],
    )
    parsed = _transcript_to_messages(assistant_entry + "\n" + user_entry + "\n")
    assert len(parsed) == 2
    assert "Let me check." in parsed[0]["content"]
    assert "[tool_use: read_file]" in parsed[0]["content"]
    assert parsed[1]["content"] == "file contents"
# ---------------------------------------------------------------------------
# _messages_to_transcript

View File

@@ -27,6 +27,8 @@ from backend.util import json
from .service import _is_prompt_too_long
from .transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
@@ -633,15 +635,29 @@ class TestRetryEdgeCases:
"""Edge cases for the retry logic components."""
def test_is_prompt_too_long_with_nested_exception(self):
    """Chained exception with prompt-too-long in __cause__ is detected."""
    # Diff residue removed: the span previously carried BOTH the pre-fix
    # docstring/comments (asserting `is False`) and the post-fix ones.
    # Keep only the post-fix behavior: the detector walks the chain.
    inner = Exception("prompt is too long: 250000 > 200000")
    outer = RuntimeError("SDK error")
    outer.__cause__ = inner
    # Walks the exception chain to find prompt-too-long in __cause__
    assert _is_prompt_too_long(outer) is True
    assert _is_prompt_too_long(inner) is True
def test_is_prompt_too_long_with_context_exception(self):
    """A prompt-too-long error reachable only via __context__ is found."""
    wrapped = Exception("context_length_exceeded")
    wrapper = RuntimeError("wrapper")
    wrapper.__context__ = wrapped
    assert _is_prompt_too_long(wrapper) is True
def test_is_prompt_too_long_no_infinite_loop(self):
    """A circular __cause__ chain must terminate, not spin forever."""
    first = Exception("error a")
    second = Exception("error b")
    first.__cause__ = second
    second.__cause__ = first  # circular
    assert _is_prompt_too_long(first) is False
def test_is_prompt_too_long_case_insensitive(self):
    """Pattern matching must be case-insensitive."""
    loud = Exception("PROMPT IS TOO LONG")
    assert _is_prompt_too_long(loud) is True
@@ -743,3 +759,139 @@ class TestRetryEdgeCases:
entries = [json.loads(line) for line in output.strip().split("\n")]
for i in range(1, len(entries)):
assert entries[i]["parentUuid"] == entries[i - 1]["uuid"]
class TestRetryStateReset:
    """Verify state is properly reset between retry attempts."""

    def test_session_messages_rollback_on_retry(self):
        """Simulate session.messages rollback as done in service.py."""
        history = ["msg1", "msg2"]  # pre-existing
        checkpoint = len(history)
        # Streaming appends partial output before a failure hits.
        history.append("partial_assistant")
        history.append("tool_result")
        assert len(history) == 4
        # Rollback (as done at line 1410 in service.py)
        history = history[:checkpoint]
        assert len(history) == 2
        assert history == ["msg1", "msg2"]

    def test_write_transcript_failure_sets_error_flag(self):
        """When write_transcript_to_tempfile fails, transcript_caused_error
        must be set True to prevent uploading stale data."""
        # Mirrors the flag-setting logic from service.py lines 1012-1020.
        transcript_caused_error = False
        use_resume = True
        resume_file = None  # write_transcript_to_tempfile returned None
        if not resume_file:
            use_resume = False
            transcript_caused_error = True
        assert transcript_caused_error is True
        assert use_resume is False

    @pytest.mark.asyncio
    async def test_compact_returns_none_preserves_error_flag(self):
        """When compaction returns None, transcript_caused_error is set."""
        transcript = _build_transcript([("user", "A"), ("assistant", "B")])
        transcript_caused_error = False
        fake_cfg = type("Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"})()
        with (
            patch("backend.copilot.config.ChatConfig", return_value=fake_cfg),
            patch(
                "backend.copilot.sdk.transcript._run_compression",
                new_callable=AsyncMock,
                side_effect=RuntimeError("boom"),
            ),
        ):
            compacted = await compact_transcript(transcript)
        # compact_transcript returns None on failure
        assert compacted is None
        # Caller sets transcript_caused_error
        if not compacted:
            transcript_caused_error = True
        assert transcript_caused_error is True
class TestTranscriptEdgeCases:
    """Edge cases in transcript parsing and generation."""

    def test_transcript_with_very_long_content(self):
        """Large content doesn't corrupt the transcript format."""
        huge = "x" * 100_000
        transcript = _build_transcript([("user", huge), ("assistant", "ok")])
        restored = _transcript_to_messages(transcript)
        assert len(restored) == 2
        assert restored[0]["content"] == huge

    def test_transcript_with_special_json_chars(self):
        """Content with JSON special characters is handled."""
        transcript = _build_transcript(
            [
                ("user", 'Hello "world" with \\backslash and \nnewline'),
                ("assistant", "Tab\there and null\x00byte"),
            ]
        )
        restored = _transcript_to_messages(transcript)
        assert len(restored) == 2
        assert '"world"' in restored[0]["content"]

    def test_messages_to_transcript_empty_content(self):
        """Messages with empty content produce a valid transcript."""
        empty_pair = [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": ""},
        ]
        transcript = _messages_to_transcript(empty_pair)
        assert validate_transcript(transcript)
        assert len(_transcript_to_messages(transcript)) == 2

    def test_consecutive_same_role_messages(self):
        """Multiple consecutive user or assistant messages are preserved."""
        originals = [
            {"role": "user", "content": "First"},
            {"role": "user", "content": "Second"},
            {"role": "assistant", "content": "Reply"},
        ]
        restored = _transcript_to_messages(_messages_to_transcript(originals))
        assert len(restored) == 3
        assert restored[0]["content"] == "First"
        assert restored[1]["content"] == "Second"

    def test_flatten_assistant_with_only_tool_use(self):
        """Assistant message with only tool_use blocks (no text)."""
        flattened = _flatten_assistant_content(
            [
                {"type": "tool_use", "name": "bash", "input": {"cmd": "ls"}},
                {"type": "tool_use", "name": "read", "input": {"path": "/f"}},
            ]
        )
        assert "[tool_use: bash]" in flattened
        assert "[tool_use: read]" in flattened

    def test_flatten_tool_result_nested_image(self):
        """Tool result containing image blocks uses placeholder."""
        block = {
            "type": "tool_result",
            "tool_use_id": "x",
            "content": [
                {"type": "image", "source": {"type": "base64", "data": "abc"}},
                {"type": "text", "text": "screenshot above"},
            ],
        }
        flattened = _flatten_tool_result_content([block])
        # json.dumps fallback for the image block; text blocks kept verbatim.
        assert "screenshot above" in flattened

View File

@@ -97,9 +97,20 @@ _PROMPT_TOO_LONG_PATTERNS = (
def _is_prompt_too_long(err: BaseException) -> bool:
    """Return True if *err* indicates the prompt exceeds the model's limit.

    Walks the exception chain (``__cause__`` / ``__context__``) so that
    wrapped errors (e.g. ``RuntimeError`` wrapping an API error) are
    detected too.
    """
    # Diff residue removed: the old single-message body (its docstring and
    # unconditional `return any(...)`) preceded this version in the span,
    # which made the chain walk below unreachable dead code.
    seen: set[int] = set()
    current: BaseException | None = err
    # Track visited exceptions by id() so circular chains terminate.
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        if any(p in str(current).lower() for p in _PROMPT_TOO_LONG_PATTERNS):
            return True
        # NOTE(review): prefers __cause__; a non-matching __cause__ chain
        # shadows __context__ on the same exception.
        current = current.__cause__ or current.__context__
    return False
def _setup_langfuse_otel() -> None:
@@ -998,7 +1009,17 @@ async def stream_chat_completion_sdk(
resume_file = write_transcript_to_tempfile(
compacted, session_id, sdk_cwd
)
use_resume = bool(resume_file)
if not resume_file:
logger.warning(
"%s Failed to write compacted transcript, "
"dropping transcript",
log_prefix,
)
transcript_builder = TranscriptBuilder()
use_resume = False
transcript_caused_error = True
else:
use_resume = True
transcript_msg_count = 0
else:
logger.warning(

View File

@@ -19,17 +19,14 @@ import shutil
import time
from dataclasses import dataclass
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import uuid4
import openai
from backend.copilot.config import ChatConfig
from backend.util import json
from backend.util.prompt import CompressResult, compress_context
if TYPE_CHECKING:
from backend.copilot.config import ChatConfig
logger = logging.getLogger(__name__)
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
@@ -621,7 +618,7 @@ def _flatten_tool_result_content(blocks: list) -> str:
str_parts: list[str] = []
for block in blocks:
if isinstance(block, dict) and block.get("type") == "tool_result":
inner = block.get("content", "")
inner = block.get("content") or ""
if isinstance(inner, list):
for sub in inner:
if isinstance(sub, dict):
@@ -722,7 +719,7 @@ async def _run_compression(
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
) as client:
return await compress_context(messages=messages, model=model, client=client)
except (openai.APIError, openai.APITimeoutError, OSError) as e:
except Exception as e:
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
return await compress_context(messages=messages, model=model, client=None)
@@ -738,11 +735,6 @@ async def compact_transcript(
Returns the compacted JSONL string, or ``None`` on failure.
"""
# Local import: ChatConfig is in TYPE_CHECKING for annotations but
# needs a runtime import here. Top-level would create a circular
# dependency (config → … → transcript).
from backend.copilot.config import ChatConfig
cfg = ChatConfig()
messages = _transcript_to_messages(content)
if len(messages) < 2: