mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(platform): harden retry logic — walk exception chains, handle write failures, add 20 edge-case tests
- _is_prompt_too_long walks __cause__/__context__ to detect wrapped errors
- write_transcript_to_tempfile failure now sets transcript_caused_error
- _run_compression catches all exceptions for truncation fallback
- _flatten_tool_result_content handles None content
- Remove circular import workaround (ChatConfig as top-level import)
- Add 20 new tests: exception chains, malformed transcripts, unicode, tool use/result pairs, session rollback, write failure propagation
This commit is contained in:
@@ -68,6 +68,38 @@ class TestIsPromptTooLong:
|
||||
err = RuntimeError("prompt is too long")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
def test_walks_cause_chain(self):
    """An error wrapped via explicit chaining (``__cause__``) is detected."""
    wrapped = Exception("prompt is too long: 250000 > 200000")
    wrapper = RuntimeError("SDK process failed")
    wrapper.__cause__ = wrapped
    assert _is_prompt_too_long(wrapper) is True
|
||||
|
||||
def test_walks_context_chain(self):
    """An error wrapped via implicit chaining (``__context__``) is detected."""
    wrapped = Exception("context_length_exceeded")
    wrapper = RuntimeError("during handling")
    wrapper.__context__ = wrapped
    assert _is_prompt_too_long(wrapper) is True
|
||||
|
||||
def test_no_infinite_loop_on_circular_chain(self):
    """Circular exception chains terminate without hanging."""
    # Two exceptions pointing at each other must not loop forever.
    first, second = Exception("error a"), Exception("error b")
    first.__cause__ = second
    second.__cause__ = first
    assert _is_prompt_too_long(first) is False
|
||||
|
||||
def test_deep_chain(self):
    """Deeply nested exception chain is walked down to the root cause."""
    root: BaseException = Exception("request too large")
    for depth in range(10):
        layer = RuntimeError(f"layer {depth}")
        layer.__cause__ = root
        root = layer
    assert _is_prompt_too_long(root) is True
|
||||
|
||||
def test_patterns_constant_is_tuple(self):
    """The patterns constant is an immutable tuple with at least two entries."""
    # The test name and docstring promise a tuple, but the original only
    # checked the length; assert the type too so an accidental regression
    # to a list or a single string (which is also len-able and iterable
    # per-character) fails loudly.
    assert isinstance(_PROMPT_TOO_LONG_PATTERNS, tuple)
    assert len(_PROMPT_TOO_LONG_PATTERNS) >= 2
|
||||
@@ -145,6 +177,16 @@ class TestFlattenToolResultContent:
|
||||
def test_raw_string(self):
    """A bare string block passes through unchanged."""
    flattened = _flatten_tool_result_content(["raw"])
    assert flattened == "raw"
|
||||
|
||||
def test_tool_result_with_none_content(self):
    """tool_result with content=None should produce empty string."""
    payload = [{"type": "tool_result", "tool_use_id": "x", "content": None}]
    flattened = _flatten_tool_result_content(payload)
    assert flattened == ""
|
||||
|
||||
def test_tool_result_with_empty_list_content(self):
    """tool_result with content=[] should produce empty string."""
    payload = [{"type": "tool_result", "tool_use_id": "x", "content": []}]
    flattened = _flatten_tool_result_content(payload)
    assert flattened == ""
|
||||
|
||||
def test_empty(self):
    """An empty block list flattens to an empty string."""
    assert _flatten_tool_result_content(list()) == ""
|
||||
|
||||
@@ -238,6 +280,93 @@ class TestTranscriptToMessages:
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == "tool output"
|
||||
|
||||
def test_malformed_json_lines_skipped(self):
    """Malformed JSON lines in transcript are silently skipped."""
    raw_lines = [
        _make_entry("user", "user", "Hello"),
        "this is not valid json",
        _make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
    ]
    parsed = _transcript_to_messages("\n".join(raw_lines) + "\n")
    # Only the two valid entries survive; the garbage line is dropped.
    assert len(parsed) == 2
|
||||
|
||||
def test_empty_lines_skipped(self):
    """Empty lines and whitespace-only lines are skipped."""
    raw_lines = [
        _make_entry("user", "user", "Hello"),
        "",
        "   ",
        _make_entry("assistant", "assistant", [{"type": "text", "text": "Hi"}]),
    ]
    parsed = _transcript_to_messages("\n".join(raw_lines) + "\n")
    # Only the two real entries remain.
    assert len(parsed) == 2
|
||||
|
||||
def test_unicode_content_preserved(self):
    """Unicode characters survive transcript roundtrip."""
    user_text = "Hello 你好 🌍"
    assistant_text = "Bonjour 日本語 émojis 🎉"
    raw_lines = [
        _make_entry("user", "user", user_text),
        _make_entry(
            "assistant",
            "assistant",
            [{"type": "text", "text": assistant_text}],
        ),
    ]
    parsed = _transcript_to_messages("\n".join(raw_lines) + "\n")
    assert parsed[0]["content"] == user_text
    assert parsed[1]["content"] == assistant_text
|
||||
|
||||
def test_entry_without_role_skipped(self):
    """Entries with missing role in message are skipped."""
    # Hand-build an entry whose "message" dict lacks the "role" key.
    roleless_entry = json.dumps(
        {
            "type": "user",
            "uuid": str(uuid4()),
            "parentUuid": None,
            "message": {"content": "no role here"},
        }
    )
    raw_lines = [roleless_entry, _make_entry("user", "user", "Hello")]
    parsed = _transcript_to_messages("\n".join(raw_lines) + "\n")
    # The role-less entry is dropped; only the valid one remains.
    assert len(parsed) == 1
    assert parsed[0]["content"] == "Hello"
|
||||
|
||||
def test_tool_use_and_result_pairs(self):
    """Tool use + tool result pairs are properly flattened."""
    assistant_blocks = [
        {"type": "text", "text": "Let me check."},
        {"type": "tool_use", "name": "read_file", "input": {"path": "/x"}},
    ]
    result_blocks = [
        {
            "type": "tool_result",
            "tool_use_id": "abc",
            "content": [{"type": "text", "text": "file contents"}],
        }
    ]
    raw_lines = [
        _make_entry("assistant", "assistant", assistant_blocks),
        _make_entry("user", "user", result_blocks),
    ]
    parsed = _transcript_to_messages("\n".join(raw_lines) + "\n")
    assert len(parsed) == 2
    # Assistant turn keeps its text and gains a tool_use placeholder.
    assert "Let me check." in parsed[0]["content"]
    assert "[tool_use: read_file]" in parsed[0]["content"]
    # The tool result is flattened to its inner text.
    assert parsed[1]["content"] == "file contents"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _messages_to_transcript
|
||||
|
||||
@@ -27,6 +27,8 @@ from backend.util import json
|
||||
|
||||
from .service import _is_prompt_too_long
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
@@ -633,15 +635,29 @@ class TestRetryEdgeCases:
|
||||
"""Edge cases for the retry logic components."""
|
||||
|
||||
def test_is_prompt_too_long_with_nested_exception(self):
    """Chained exception with prompt-too-long in __cause__ is detected."""
    # The source span contained BOTH the stale pre-fix assertion
    # (`... is False`) and the post-fix assertion (`... is True`) — a
    # corrupted diff merge that made the test self-contradictory. Keep
    # only the post-fix behavior: the chain is walked.
    inner = Exception("prompt is too long: 250000 > 200000")
    outer = RuntimeError("SDK error")
    outer.__cause__ = inner
    # Walks the exception chain to find prompt-too-long in __cause__.
    assert _is_prompt_too_long(outer) is True
    assert _is_prompt_too_long(inner) is True
|
||||
|
||||
def test_is_prompt_too_long_with_context_exception(self):
    """Chained exception with prompt-too-long in __context__ is detected."""
    wrapped = Exception("context_length_exceeded")
    wrapper = RuntimeError("wrapper")
    wrapper.__context__ = wrapped
    assert _is_prompt_too_long(wrapper) is True
|
||||
|
||||
def test_is_prompt_too_long_no_infinite_loop(self):
    """Circular exception chain doesn't cause infinite loop."""
    first = Exception("error a")
    second = Exception("error b")
    # Build a cycle: first -> second -> first.
    first.__cause__ = second
    second.__cause__ = first  # circular
    assert _is_prompt_too_long(first) is False
|
||||
|
||||
def test_is_prompt_too_long_case_insensitive(self):
    """Pattern matching must be case-insensitive."""
    shouted = Exception("PROMPT IS TOO LONG")
    assert _is_prompt_too_long(shouted) is True
|
||||
@@ -743,3 +759,139 @@ class TestRetryEdgeCases:
|
||||
entries = [json.loads(line) for line in output.strip().split("\n")]
|
||||
for i in range(1, len(entries)):
|
||||
assert entries[i]["parentUuid"] == entries[i - 1]["uuid"]
|
||||
|
||||
|
||||
class TestRetryStateReset:
    """Verify state is properly reset between retry attempts."""

    def test_session_messages_rollback_on_retry(self):
        """Simulate session.messages rollback as done in service.py."""
        history = ["msg1", "msg2"]  # pre-existing
        checkpoint = len(history)

        # Simulate streaming appending partial messages mid-attempt.
        history += ["partial_assistant", "tool_result"]
        assert len(history) == 4

        # Rollback (as done at line 1410 in service.py).
        history = history[:checkpoint]
        assert len(history) == 2
        assert history == ["msg1", "msg2"]

    def test_write_transcript_failure_sets_error_flag(self):
        """When write_transcript_to_tempfile fails, transcript_caused_error
        must be set True to prevent uploading stale data."""
        # Simulate the logic from service.py lines 1012-1020.
        error_flagged = False
        resume_enabled = True
        tempfile_path = None  # write_transcript_to_tempfile returned None

        if not tempfile_path:
            resume_enabled = False
            error_flagged = True

        assert error_flagged is True
        assert resume_enabled is False

    @pytest.mark.asyncio
    async def test_compact_returns_none_preserves_error_flag(self):
        """When compaction returns None, transcript_caused_error is set."""
        sample = _build_transcript([("user", "A"), ("assistant", "B")])
        error_flagged = False

        fake_cfg = type("Cfg", (), {"model": "m", "api_key": "k", "base_url": "u"})()
        with (
            patch("backend.copilot.config.ChatConfig", return_value=fake_cfg),
            patch(
                "backend.copilot.sdk.transcript._run_compression",
                new_callable=AsyncMock,
                side_effect=RuntimeError("boom"),
            ),
        ):
            compacted = await compact_transcript(sample)

        # compact_transcript returns None on failure.
        assert compacted is None
        # Caller sets transcript_caused_error.
        if not compacted:
            error_flagged = True
        assert error_flagged is True
|
||||
|
||||
|
||||
class TestTranscriptEdgeCases:
    """Edge cases in transcript parsing and generation."""

    def test_transcript_with_very_long_content(self):
        """Large content doesn't corrupt the transcript format."""
        huge_text = "x" * 100_000
        transcript = _build_transcript([("user", huge_text), ("assistant", "ok")])
        roundtripped = _transcript_to_messages(transcript)
        assert len(roundtripped) == 2
        assert roundtripped[0]["content"] == huge_text

    def test_transcript_with_special_json_chars(self):
        """Content with JSON special characters is handled."""
        transcript = _build_transcript(
            [
                ("user", 'Hello "world" with \\backslash and \nnewline'),
                ("assistant", "Tab\there and null\x00byte"),
            ]
        )
        roundtripped = _transcript_to_messages(transcript)
        assert len(roundtripped) == 2
        assert '"world"' in roundtripped[0]["content"]

    def test_messages_to_transcript_empty_content(self):
        """Messages with empty content produce valid transcript."""
        empty_pair = [
            {"role": "user", "content": ""},
            {"role": "assistant", "content": ""},
        ]
        transcript = _messages_to_transcript(empty_pair)
        assert validate_transcript(transcript)
        assert len(_transcript_to_messages(transcript)) == 2

    def test_consecutive_same_role_messages(self):
        """Multiple consecutive user or assistant messages are preserved."""
        originals = [
            {"role": "user", "content": "First"},
            {"role": "user", "content": "Second"},
            {"role": "assistant", "content": "Reply"},
        ]
        roundtripped = _transcript_to_messages(_messages_to_transcript(originals))
        assert len(roundtripped) == 3
        assert roundtripped[0]["content"] == "First"
        assert roundtripped[1]["content"] == "Second"

    def test_flatten_assistant_with_only_tool_use(self):
        """Assistant message with only tool_use blocks (no text)."""
        tool_only = [
            {"type": "tool_use", "name": "bash", "input": {"cmd": "ls"}},
            {"type": "tool_use", "name": "read", "input": {"path": "/f"}},
        ]
        flattened = _flatten_assistant_content(tool_only)
        assert "[tool_use: bash]" in flattened
        assert "[tool_use: read]" in flattened

    def test_flatten_tool_result_nested_image(self):
        """Tool result containing image blocks uses placeholder."""
        mixed_blocks = [
            {
                "type": "tool_result",
                "tool_use_id": "x",
                "content": [
                    {"type": "image", "source": {"type": "base64", "data": "abc"}},
                    {"type": "text", "text": "screenshot above"},
                ],
            }
        ]
        # json.dumps fallback for image, text for text.
        assert "screenshot above" in _flatten_tool_result_content(mixed_blocks)
|
||||
|
||||
@@ -97,9 +97,20 @@ _PROMPT_TOO_LONG_PATTERNS = (
|
||||
|
||||
|
||||
def _is_prompt_too_long(err: BaseException) -> bool:
    """Return True if *err* indicates the prompt exceeds the model's limit.

    Walks the exception chain (``__cause__`` / ``__context__``) so that
    wrapped errors (e.g. ``RuntimeError`` wrapping an API error) are
    detected too.
    """
    # The source span contained the pre-fix one-line body (str(err) check
    # only) interleaved with the new chain-walking implementation — a
    # corrupted diff merge. Keep only the chain-walking version.
    seen: set[int] = set()
    current: BaseException | None = err
    # Track visited ids so a circular __cause__/__context__ chain terminates.
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        if any(p in str(current).lower() for p in _PROMPT_TOO_LONG_PATTERNS):
            return True
        # PEP 3134: prefer the explicit cause, fall back to implicit context.
        current = current.__cause__ or current.__context__
    return False
|
||||
|
||||
|
||||
def _setup_langfuse_otel() -> None:
|
||||
@@ -998,7 +1009,17 @@ async def stream_chat_completion_sdk(
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
compacted, session_id, sdk_cwd
|
||||
)
|
||||
use_resume = bool(resume_file)
|
||||
if not resume_file:
|
||||
logger.warning(
|
||||
"%s Failed to write compacted transcript, "
|
||||
"dropping transcript",
|
||||
log_prefix,
|
||||
)
|
||||
transcript_builder = TranscriptBuilder()
|
||||
use_resume = False
|
||||
transcript_caused_error = True
|
||||
else:
|
||||
use_resume = True
|
||||
transcript_msg_count = 0
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
@@ -19,17 +19,14 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import uuid4
|
||||
|
||||
import openai
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult, compress_context
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.config import ChatConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# UUIDs are hex + hyphens; strip everything else to prevent path injection.
|
||||
@@ -621,7 +618,7 @@ def _flatten_tool_result_content(blocks: list) -> str:
|
||||
str_parts: list[str] = []
|
||||
for block in blocks:
|
||||
if isinstance(block, dict) and block.get("type") == "tool_result":
|
||||
inner = block.get("content", "")
|
||||
inner = block.get("content") or ""
|
||||
if isinstance(inner, list):
|
||||
for sub in inner:
|
||||
if isinstance(sub, dict):
|
||||
@@ -722,7 +719,7 @@ async def _run_compression(
|
||||
api_key=cfg.api_key, base_url=cfg.base_url, timeout=30.0
|
||||
) as client:
|
||||
return await compress_context(messages=messages, model=model, client=client)
|
||||
except (openai.APIError, openai.APITimeoutError, OSError) as e:
|
||||
except Exception as e:
|
||||
logger.warning("%s LLM compaction failed, using truncation: %s", log_prefix, e)
|
||||
return await compress_context(messages=messages, model=model, client=None)
|
||||
|
||||
@@ -738,11 +735,6 @@ async def compact_transcript(
|
||||
|
||||
Returns the compacted JSONL string, or ``None`` on failure.
|
||||
"""
|
||||
# Local import: ChatConfig is in TYPE_CHECKING for annotations but
|
||||
# needs a runtime import here. Top-level would create a circular
|
||||
# dependency (config → … → transcript).
|
||||
from backend.copilot.config import ChatConfig
|
||||
|
||||
cfg = ChatConfig()
|
||||
messages = _transcript_to_messages(content)
|
||||
if len(messages) < 2:
|
||||
|
||||
Reference in New Issue
Block a user