mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
refactor(backend/copilot): retry on any streaming error, not just prompt-too-long
Replace the prompt-too-long-only retry with a general error retry strategy:

- RetryStrategy.COMPACT_THEN_FALLBACK: try Plan B (compact transcript), then Plan C (DB fallback) — used for all errors by default.
- RetryStrategy.FALLBACK_ONLY: skip compaction and go straight to Plan C — used for transcript/JSON parse errors, where compaction can't help.

Extract _apply_plan_b() and _apply_plan_c() helpers to deduplicate the inline retry setup logic. Replace _is_prompt_too_long() with _classify_error(), which determines the retry strategy.
This commit is contained in:
@@ -9,7 +9,7 @@ import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .service import _PROMPT_TOO_LONG_PATTERNS, _is_prompt_too_long
|
||||
from .service import RetryStrategy, _classify_error
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
@@ -20,12 +20,12 @@ from .transcript import (
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_prompt_too_long
|
||||
# _classify_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsPromptTooLong:
|
||||
"""Tests for _is_prompt_too_long error detector."""
|
||||
class TestClassifyError:
|
||||
"""Tests for _classify_error — maps errors to RetryStrategy values."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_msg",
|
||||
@@ -33,54 +33,49 @@ class TestIsPromptTooLong:
|
||||
"prompt is too long: 250000 tokens > 200000 maximum",
|
||||
"Error: prompt is too long",
|
||||
"context_length_exceeded",
|
||||
"prompt_too_long",
|
||||
"The prompt is too long for this model",
|
||||
"PROMPT IS TOO LONG", # case-insensitive
|
||||
"Error: CONTEXT_LENGTH_EXCEEDED",
|
||||
"request too large", # HTTP 413 from Anthropic API
|
||||
"Request too large for model",
|
||||
],
|
||||
)
|
||||
def test_detects_prompt_too_long_errors(self, error_msg: str):
|
||||
err = Exception(error_msg)
|
||||
assert _is_prompt_too_long(err) is True
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_msg",
|
||||
[
|
||||
"request too large",
|
||||
"Connection timeout",
|
||||
"Authentication failed",
|
||||
"Rate limit exceeded",
|
||||
"Internal server error",
|
||||
"Invalid API key",
|
||||
"Network unreachable",
|
||||
"SDK process exited with code 1",
|
||||
"",
|
||||
"context_length is 4096", # partial match should NOT trigger
|
||||
],
|
||||
)
|
||||
def test_rejects_non_prompt_errors(self, error_msg: str):
|
||||
err = Exception(error_msg)
|
||||
assert _is_prompt_too_long(err) is False
|
||||
def test_general_errors_return_compact_then_fallback(self, error_msg: str):
|
||||
"""Most errors (including prompt-too-long) → COMPACT_THEN_FALLBACK."""
|
||||
assert (
|
||||
_classify_error(Exception(error_msg)) == RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
)
|
||||
|
||||
def test_handles_non_exception_types(self):
|
||||
"""_is_prompt_too_long should work with any BaseException."""
|
||||
err = RuntimeError("prompt is too long")
|
||||
assert _is_prompt_too_long(err) is True
|
||||
@pytest.mark.parametrize(
|
||||
"error_msg",
|
||||
[
|
||||
"invalid json in transcript",
|
||||
"json decode error at position 42",
|
||||
"JSONDecodeError: Expecting value",
|
||||
"failed to read resume file",
|
||||
"session file not found",
|
||||
"malformed jsonl entry",
|
||||
],
|
||||
)
|
||||
def test_transcript_errors_return_fallback_only(self, error_msg: str):
|
||||
"""Transcript/JSON parse errors → FALLBACK_ONLY."""
|
||||
assert _classify_error(Exception(error_msg)) == RetryStrategy.FALLBACK_ONLY
|
||||
|
||||
def test_walks_cause_chain(self):
|
||||
"""_is_prompt_too_long walks __cause__ to find wrapped errors."""
|
||||
inner = Exception("prompt is too long: 250000 > 200000")
|
||||
"""Walks __cause__ to find transcript errors in wrapped exceptions."""
|
||||
inner = Exception("invalid json in transcript")
|
||||
outer = RuntimeError("SDK process failed")
|
||||
outer.__cause__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
assert _classify_error(outer) == RetryStrategy.FALLBACK_ONLY
|
||||
|
||||
def test_walks_context_chain(self):
|
||||
"""_is_prompt_too_long walks __context__ for implicit chaining."""
|
||||
inner = Exception("context_length_exceeded")
|
||||
"""Walks __context__ for implicit exception chaining."""
|
||||
inner = Exception("json decode error")
|
||||
outer = RuntimeError("during handling")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
assert _classify_error(outer) == RetryStrategy.FALLBACK_ONLY
|
||||
|
||||
def test_no_infinite_loop_on_circular_chain(self):
|
||||
"""Circular exception chains terminate without hanging."""
|
||||
@@ -88,24 +83,25 @@ class TestIsPromptTooLong:
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a
|
||||
assert _is_prompt_too_long(a) is False
|
||||
assert _classify_error(a) == RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
|
||||
def test_deep_chain(self):
|
||||
"""Deeply nested exception chain is walked."""
|
||||
bottom = Exception("request too large")
|
||||
bottom = Exception("malformed jsonl")
|
||||
current = bottom
|
||||
for i in range(10):
|
||||
wrapper = RuntimeError(f"layer {i}")
|
||||
wrapper.__cause__ = current
|
||||
current = wrapper
|
||||
assert _is_prompt_too_long(current) is True
|
||||
assert _classify_error(current) == RetryStrategy.FALLBACK_ONLY
|
||||
|
||||
def test_patterns_constant_is_tuple(self):
|
||||
"""Verify the patterns constant exists and is iterable."""
|
||||
assert len(_PROMPT_TOO_LONG_PATTERNS) >= 2
|
||||
for p in _PROMPT_TOO_LONG_PATTERNS:
|
||||
assert isinstance(p, str)
|
||||
assert p == p.lower(), f"Pattern {p!r} should be lowercase"
|
||||
def test_case_insensitive(self):
|
||||
"""Pattern matching is case-insensitive."""
|
||||
assert _classify_error(Exception("INVALID JSON")) == RetryStrategy.FALLBACK_ONLY
|
||||
assert (
|
||||
_classify_error(Exception("Resume File not found"))
|
||||
== RetryStrategy.FALLBACK_ONLY
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Integration tests for the try-compact-retry loop scenarios.
|
||||
"""Integration tests for the retry/fallback loop scenarios.
|
||||
|
||||
These tests exercise the retry decision logic end-to-end by simulating
|
||||
the state transitions that happen in ``stream_chat_completion_sdk`` when
|
||||
the SDK raises prompt-too-long errors.
|
||||
the SDK raises streaming errors.
|
||||
|
||||
Scenario matrix (from the design doc):
|
||||
Scenario matrix:
|
||||
1. Normal flow — no error, no retry
|
||||
2. Prompt-too-long → compact succeeds → retry succeeds
|
||||
3. Prompt-too-long → compact fails → DB fallback succeeds
|
||||
4. Prompt-too-long → no transcript → DB fallback succeeds
|
||||
5. Prompt-too-long → compact succeeds → retry fails → DB fallback succeeds
|
||||
6. All 3 attempts exhausted → StreamError(prompt_too_long)
|
||||
7. Non-prompt-too-long error → no retry, StreamError(sdk_stream_error)
|
||||
2. Error → compact succeeds → retry succeeds
|
||||
3. Error → compact fails → DB fallback succeeds
|
||||
4. Error → no transcript → DB fallback succeeds
|
||||
5. Error × 2 → attempt 3 DB fallback succeeds
|
||||
6. All 3 attempts exhausted → StreamError(all_attempts_exhausted)
|
||||
7. Transcript parse error → skip compact, DB fallback directly
|
||||
8. Compaction returns identical content → treated as compact failure → DB fallback
|
||||
9. transcript_caused_error → finally skips upload
|
||||
"""
|
||||
@@ -25,7 +25,7 @@ import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .service import _is_prompt_too_long
|
||||
from .service import RetryStrategy, _classify_error
|
||||
from .transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
@@ -98,19 +98,22 @@ def _mock_compress_result(
|
||||
|
||||
|
||||
class TestScenarioNormalFlow:
|
||||
"""When no prompt-too-long error occurs, no retry logic fires."""
|
||||
"""When no error occurs, no retry logic fires."""
|
||||
|
||||
def test_is_prompt_too_long_returns_false_for_normal_errors(self):
|
||||
"""Normal SDK errors should not trigger retry."""
|
||||
normal_errors = [
|
||||
def test_general_errors_get_compact_then_fallback(self):
|
||||
"""General SDK errors should get COMPACT_THEN_FALLBACK strategy."""
|
||||
general_errors = [
|
||||
"Connection refused",
|
||||
"SDK process exited with code 1",
|
||||
"Authentication failed",
|
||||
"Rate limit exceeded",
|
||||
"Internal server error",
|
||||
"prompt is too long",
|
||||
]
|
||||
for msg in normal_errors:
|
||||
assert _is_prompt_too_long(Exception(msg)) is False, msg
|
||||
for msg in general_errors:
|
||||
assert (
|
||||
_classify_error(Exception(msg)) == RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
), msg
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -319,46 +322,42 @@ class TestScenarioDoubleFailDBFallback:
|
||||
|
||||
|
||||
class TestScenarioAllAttemptsExhausted:
|
||||
"""All 3 attempts hit prompt-too-long — final StreamError is emitted."""
|
||||
"""All 3 attempts fail — final StreamError is emitted."""
|
||||
|
||||
def test_exhaustion_state_variables(self):
|
||||
"""Verify the state after exhausting all retry attempts."""
|
||||
# Simulate the retry loop state
|
||||
_MAX_QUERY_ATTEMPTS = 3
|
||||
_prompt_too_long = False
|
||||
_retry_strategy: str | None = None
|
||||
transcript_caused_error = False
|
||||
|
||||
for _query_attempt in range(_MAX_QUERY_ATTEMPTS):
|
||||
# Every attempt hits prompt-too-long
|
||||
_prompt_too_long = True
|
||||
# The `continue` in real code skips post-processing
|
||||
_retry_strategy = RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
|
||||
# After loop: check exhaustion
|
||||
assert _prompt_too_long is True
|
||||
# In the real code, this sets transcript_caused_error = True
|
||||
assert _retry_strategy is not None
|
||||
transcript_caused_error = True
|
||||
assert transcript_caused_error is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 7: Non-prompt-too-long error — no retry
|
||||
# Scenario 7: Transcript parse error → skip compact, DB fallback directly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestScenarioNonPromptError:
|
||||
"""A non-prompt-too-long SDK error yields StreamError immediately,
|
||||
no retry."""
|
||||
class TestScenarioTranscriptParseError:
|
||||
"""Transcript/JSON parse errors skip compaction and go straight to
|
||||
DB fallback (FALLBACK_ONLY strategy)."""
|
||||
|
||||
def test_generic_errors_not_retried(self):
|
||||
"""Verify _is_prompt_too_long rejects generic errors."""
|
||||
generic_errors = [
|
||||
Exception("SDK process exited with code 1"),
|
||||
RuntimeError("Connection reset"),
|
||||
ValueError("Invalid argument"),
|
||||
Exception("context_length is 4096"), # partial match
|
||||
def test_transcript_errors_get_fallback_only(self):
|
||||
"""Verify transcript parse errors return FALLBACK_ONLY."""
|
||||
transcript_errors = [
|
||||
Exception("invalid json in line 5"),
|
||||
RuntimeError("json decode error"),
|
||||
ValueError("malformed jsonl entry"),
|
||||
Exception("failed to read resume file"),
|
||||
]
|
||||
for err in generic_errors:
|
||||
assert _is_prompt_too_long(err) is False, str(err)
|
||||
for err in transcript_errors:
|
||||
assert _classify_error(err) == RetryStrategy.FALLBACK_ONLY, str(err)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -477,55 +476,66 @@ class TestRetryStateMachine:
|
||||
attempt_results: list[str],
|
||||
transcript_content: str = "some_content",
|
||||
compact_result: str | None = "compacted_content",
|
||||
error_strategy: str = RetryStrategy.COMPACT_THEN_FALLBACK,
|
||||
) -> dict:
|
||||
"""Simulate the retry loop and return final state.
|
||||
|
||||
Args:
|
||||
attempt_results: List of outcomes per attempt.
|
||||
"success" = stream completes normally
|
||||
"prompt_too_long" = prompt-too-long error
|
||||
"error" = streaming error
|
||||
transcript_content: Initial transcript content ("" = none)
|
||||
compact_result: Result of compact_transcript (None = failure)
|
||||
error_strategy: RetryStrategy for errors
|
||||
"""
|
||||
_MAX_QUERY_ATTEMPTS = 3
|
||||
_prompt_too_long = False
|
||||
_retry_strategy: str | None = None
|
||||
transcript_caused_error = False
|
||||
use_resume = bool(transcript_content)
|
||||
stream_completed = False
|
||||
attempts_made = 0
|
||||
_plan_c_applied = False
|
||||
|
||||
for _query_attempt in range(min(_MAX_QUERY_ATTEMPTS, len(attempt_results))):
|
||||
if _query_attempt > 0:
|
||||
_prompt_too_long = False
|
||||
_last_strategy = _retry_strategy
|
||||
_retry_strategy = None
|
||||
stream_completed = False
|
||||
|
||||
if _query_attempt == 1 and transcript_content:
|
||||
# Plan B or Plan C?
|
||||
if (
|
||||
_last_strategy == RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
and transcript_content
|
||||
and not _plan_c_applied
|
||||
):
|
||||
if compact_result and compact_result != transcript_content:
|
||||
use_resume = True
|
||||
else:
|
||||
use_resume = False
|
||||
transcript_caused_error = True
|
||||
_plan_c_applied = True
|
||||
else:
|
||||
use_resume = False
|
||||
transcript_caused_error = True
|
||||
_plan_c_applied = True
|
||||
|
||||
attempts_made += 1
|
||||
result = attempt_results[_query_attempt]
|
||||
|
||||
if result == "prompt_too_long":
|
||||
_prompt_too_long = True
|
||||
if result == "error":
|
||||
_retry_strategy = error_strategy
|
||||
continue # skip post-stream
|
||||
|
||||
# Stream succeeded
|
||||
stream_completed = True
|
||||
break
|
||||
|
||||
if _prompt_too_long:
|
||||
if _retry_strategy is not None:
|
||||
transcript_caused_error = True
|
||||
|
||||
return {
|
||||
"attempts_made": attempts_made,
|
||||
"prompt_too_long": _prompt_too_long,
|
||||
"retry_strategy": _retry_strategy,
|
||||
"transcript_caused_error": transcript_caused_error,
|
||||
"stream_completed": stream_completed,
|
||||
"use_resume": use_resume,
|
||||
@@ -535,7 +545,7 @@ class TestRetryStateMachine:
|
||||
"""Scenario 1: Success on first attempt."""
|
||||
state = self._simulate_retry_loop(["success"])
|
||||
assert state["attempts_made"] == 1
|
||||
assert state["prompt_too_long"] is False
|
||||
assert state["retry_strategy"] is None
|
||||
assert state["transcript_caused_error"] is False
|
||||
assert state["stream_completed"] is True
|
||||
assert state["use_resume"] is True
|
||||
@@ -543,12 +553,12 @@ class TestRetryStateMachine:
|
||||
def test_compact_and_retry_succeeds(self):
|
||||
"""Scenario 2: Fail, compact, succeed on attempt 2."""
|
||||
state = self._simulate_retry_loop(
|
||||
["prompt_too_long", "success"],
|
||||
["error", "success"],
|
||||
transcript_content="original",
|
||||
compact_result="compacted",
|
||||
)
|
||||
assert state["attempts_made"] == 2
|
||||
assert state["prompt_too_long"] is False
|
||||
assert state["retry_strategy"] is None
|
||||
assert state["transcript_caused_error"] is False
|
||||
assert state["stream_completed"] is True
|
||||
assert state["use_resume"] is True # compacted transcript used
|
||||
@@ -556,12 +566,12 @@ class TestRetryStateMachine:
|
||||
def test_compact_fails_db_fallback_succeeds(self):
|
||||
"""Scenario 3: Fail, compact fails, DB fallback succeeds."""
|
||||
state = self._simulate_retry_loop(
|
||||
["prompt_too_long", "success"],
|
||||
["error", "success"],
|
||||
transcript_content="original",
|
||||
compact_result=None, # compact fails
|
||||
)
|
||||
assert state["attempts_made"] == 2
|
||||
assert state["prompt_too_long"] is False
|
||||
assert state["retry_strategy"] is None
|
||||
assert state["transcript_caused_error"] is True # DB fallback
|
||||
assert state["stream_completed"] is True
|
||||
assert state["use_resume"] is False
|
||||
@@ -569,11 +579,11 @@ class TestRetryStateMachine:
|
||||
def test_no_transcript_db_fallback_succeeds(self):
|
||||
"""Scenario 4: No transcript, DB fallback on attempt 2."""
|
||||
state = self._simulate_retry_loop(
|
||||
["prompt_too_long", "success"],
|
||||
["error", "success"],
|
||||
transcript_content="", # no transcript
|
||||
)
|
||||
assert state["attempts_made"] == 2
|
||||
assert state["prompt_too_long"] is False
|
||||
assert state["retry_strategy"] is None
|
||||
assert state["transcript_caused_error"] is True
|
||||
assert state["stream_completed"] is True
|
||||
assert state["use_resume"] is False
|
||||
@@ -581,12 +591,12 @@ class TestRetryStateMachine:
|
||||
def test_double_fail_db_fallback_succeeds(self):
|
||||
"""Scenario 5: Fail, compact succeeds but retry fails, DB fallback."""
|
||||
state = self._simulate_retry_loop(
|
||||
["prompt_too_long", "prompt_too_long", "success"],
|
||||
["error", "error", "success"],
|
||||
transcript_content="original",
|
||||
compact_result="compacted",
|
||||
)
|
||||
assert state["attempts_made"] == 3
|
||||
assert state["prompt_too_long"] is False
|
||||
assert state["retry_strategy"] is None
|
||||
assert state["transcript_caused_error"] is True
|
||||
assert state["stream_completed"] is True
|
||||
assert state["use_resume"] is False # dropped for attempt 3
|
||||
@@ -594,19 +604,19 @@ class TestRetryStateMachine:
|
||||
def test_all_attempts_exhausted(self):
|
||||
"""Scenario 6: All 3 attempts fail."""
|
||||
state = self._simulate_retry_loop(
|
||||
["prompt_too_long", "prompt_too_long", "prompt_too_long"],
|
||||
["error", "error", "error"],
|
||||
transcript_content="original",
|
||||
compact_result="compacted",
|
||||
)
|
||||
assert state["attempts_made"] == 3
|
||||
assert state["prompt_too_long"] is True
|
||||
assert state["retry_strategy"] is not None
|
||||
assert state["transcript_caused_error"] is True
|
||||
assert state["stream_completed"] is False
|
||||
|
||||
def test_compact_identical_triggers_db_fallback(self):
|
||||
"""Scenario 8: Compaction returns identical content."""
|
||||
state = self._simulate_retry_loop(
|
||||
["prompt_too_long", "success"],
|
||||
["error", "success"],
|
||||
transcript_content="original",
|
||||
compact_result="original", # Same as input!
|
||||
)
|
||||
@@ -617,14 +627,26 @@ class TestRetryStateMachine:
|
||||
def test_no_transcript_all_exhausted(self):
|
||||
"""No transcript + all attempts fail."""
|
||||
state = self._simulate_retry_loop(
|
||||
["prompt_too_long", "prompt_too_long", "prompt_too_long"],
|
||||
["error", "error", "error"],
|
||||
transcript_content="",
|
||||
)
|
||||
assert state["attempts_made"] == 3
|
||||
assert state["prompt_too_long"] is True
|
||||
assert state["retry_strategy"] is not None
|
||||
assert state["transcript_caused_error"] is True
|
||||
assert state["stream_completed"] is False
|
||||
|
||||
def test_fallback_only_skips_compact(self):
|
||||
"""FALLBACK_ONLY strategy skips Plan B, goes straight to Plan C."""
|
||||
state = self._simulate_retry_loop(
|
||||
["error", "success"],
|
||||
transcript_content="original",
|
||||
compact_result="compacted", # would succeed, but should be skipped
|
||||
error_strategy=RetryStrategy.FALLBACK_ONLY,
|
||||
)
|
||||
assert state["attempts_made"] == 2
|
||||
assert state["transcript_caused_error"] is True
|
||||
assert state["use_resume"] is False # Plan C, not Plan B
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
@@ -634,35 +656,36 @@ class TestRetryStateMachine:
|
||||
class TestRetryEdgeCases:
|
||||
"""Edge cases for the retry logic components."""
|
||||
|
||||
def test_is_prompt_too_long_with_nested_exception(self):
|
||||
"""Chained exception with prompt-too-long in __cause__ is detected."""
|
||||
inner = Exception("prompt is too long: 250000 > 200000")
|
||||
def test_classify_error_with_nested_exception(self):
|
||||
"""Chained exception with transcript error in __cause__ is detected."""
|
||||
inner = Exception("invalid json in transcript")
|
||||
outer = RuntimeError("SDK error")
|
||||
outer.__cause__ = inner
|
||||
# Walks the exception chain to find prompt-too-long in __cause__
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
assert _is_prompt_too_long(inner) is True
|
||||
assert _classify_error(outer) == RetryStrategy.FALLBACK_ONLY
|
||||
assert _classify_error(inner) == RetryStrategy.FALLBACK_ONLY
|
||||
|
||||
def test_is_prompt_too_long_with_context_exception(self):
|
||||
"""Chained exception with prompt-too-long in __context__ is detected."""
|
||||
inner = Exception("context_length_exceeded")
|
||||
def test_classify_error_with_context_exception(self):
|
||||
"""Chained exception via __context__ is detected."""
|
||||
inner = Exception("json decode error")
|
||||
outer = RuntimeError("wrapper")
|
||||
outer.__context__ = inner
|
||||
assert _is_prompt_too_long(outer) is True
|
||||
assert _classify_error(outer) == RetryStrategy.FALLBACK_ONLY
|
||||
|
||||
def test_is_prompt_too_long_no_infinite_loop(self):
|
||||
def test_classify_error_no_infinite_loop(self):
|
||||
"""Circular exception chain doesn't cause infinite loop."""
|
||||
a = Exception("error a")
|
||||
b = Exception("error b")
|
||||
a.__cause__ = b
|
||||
b.__cause__ = a # circular
|
||||
assert _is_prompt_too_long(a) is False
|
||||
assert _classify_error(a) == RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
|
||||
def test_is_prompt_too_long_case_insensitive(self):
|
||||
def test_classify_error_case_insensitive(self):
|
||||
"""Pattern matching must be case-insensitive."""
|
||||
assert _is_prompt_too_long(Exception("PROMPT IS TOO LONG")) is True
|
||||
assert _is_prompt_too_long(Exception("Prompt_Too_Long")) is True
|
||||
assert _is_prompt_too_long(Exception("CONTEXT_LENGTH_EXCEEDED")) is True
|
||||
assert _classify_error(Exception("INVALID JSON")) == RetryStrategy.FALLBACK_ONLY
|
||||
assert (
|
||||
_classify_error(Exception("RESUME FILE missing"))
|
||||
== RetryStrategy.FALLBACK_ONLY
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compact_transcript_with_single_message(self):
|
||||
|
||||
@@ -88,29 +88,89 @@ from .transcript_builder import TranscriptBuilder
|
||||
logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
_PROMPT_TOO_LONG_PATTERNS = (
|
||||
"prompt is too long",
|
||||
"prompt_too_long",
|
||||
"context_length_exceeded",
|
||||
"request too large",
|
||||
|
||||
class RetryStrategy:
    """String constants naming the recovery path to take after a streaming error.

    Attributes:
        COMPACT_THEN_FALLBACK: Try Plan B (compact transcript) first, then
            Plan C (DB).
        FALLBACK_ONLY: Skip compaction and go straight to Plan C (DB).
    """

    # Plan B, then Plan C.
    COMPACT_THEN_FALLBACK = "compact_then_fallback"

    # Plan C only.
    FALLBACK_ONLY = "fallback_only"
|
||||
|
||||
|
||||
# Patterns checked against the full exception chain (lowercased).
|
||||
_TRANSCRIPT_ERROR_PATTERNS = (
|
||||
"invalid json",
|
||||
"json decode",
|
||||
"jsondecodeerror",
|
||||
"resume file",
|
||||
"session file",
|
||||
"malformed jsonl",
|
||||
)
|
||||
|
||||
|
||||
def _classify_error(err: BaseException) -> str:
    """Classify a streaming error into a :class:`RetryStrategy` value.

    Walks the exception chain (``__cause__``, otherwise ``__context__``) so
    that wrapped errors are detected too. Object ids already visited are
    remembered so a circular chain terminates.

    Every message in the chain is lowercased and matched on its own. Matching
    a single space-joined string of all messages would let a multi-word
    pattern match across the boundary between two unrelated messages (e.g.
    an outer error ending in "invalid" wrapping one starting with "json").

    Args:
        err: The exception raised by the failed streaming attempt.

    Returns:
        ``RetryStrategy.FALLBACK_ONLY`` if any message in the chain contains
        one of ``_TRANSCRIPT_ERROR_PATTERNS`` (compaction won't help if the
        file itself is broken); otherwise
        ``RetryStrategy.COMPACT_THEN_FALLBACK`` (prompt-too-long, transient
        500s, timeouts, ...).
    """
    seen: set[int] = set()
    current: BaseException | None = err

    while current is not None and id(current) not in seen:
        seen.add(id(current))
        message = str(current).lower()
        if any(pattern in message for pattern in _TRANSCRIPT_ERROR_PATTERNS):
            return RetryStrategy.FALLBACK_ONLY
        current = current.__cause__ or current.__context__

    return RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
|
||||
|
||||
async def _apply_plan_b(
    transcript_content: str,
    session_id: str,
    sdk_cwd: str,
    log_prefix: str,
) -> tuple[TranscriptBuilder, bool, str | None, bool]:
    """Plan B: retry with a compacted (LLM-summarized) transcript.

    Awaits ``compact_transcript`` and accepts its output only when it is
    non-empty, differs from *transcript_content*, and passes
    ``validate_transcript``. An accepted result is loaded into a new
    ``TranscriptBuilder`` and written out via ``write_transcript_to_tempfile``.

    Returns:
        ``(transcript_builder, use_resume, resume_file, success)``. On
        success: ``(builder, True, resume_file, True)``. On any failure:
        ``(TranscriptBuilder(), False, None, False)``.
    """

    def _give_up() -> tuple[TranscriptBuilder, bool, str | None, bool]:
        # Shared failure exit: rejected compaction and failed write both land here.
        logger.warning("%s Plan B — compaction failed, will try Plan C", log_prefix)
        return TranscriptBuilder(), False, None, False

    summary = await compact_transcript(transcript_content, log_prefix=log_prefix)

    # Evaluation order is significant: validation only runs on output that
    # exists and actually differs from the input.
    if (
        not summary
        or summary == transcript_content
        or not validate_transcript(summary)
    ):
        return _give_up()

    logger.info("%s Plan B — using compacted transcript", log_prefix)
    builder = TranscriptBuilder()
    builder.load_previous(summary, log_prefix=log_prefix)
    resume_path = write_transcript_to_tempfile(summary, session_id, sdk_cwd)
    if not resume_path:
        logger.warning("%s Plan B — failed to write compacted transcript", log_prefix)
        return _give_up()

    return builder, True, resume_path, True
|
||||
|
||||
|
||||
def _apply_plan_c(log_prefix: str) -> tuple[TranscriptBuilder, bool, None]:
    """Plan C: abandon the transcript and rely on DB message history.

    Logs a warning, then hands back the state for a transcript-less retry.

    Returns:
        ``(transcript_builder, use_resume, resume_file)`` — always a new
        ``TranscriptBuilder()``, ``False`` and ``None``.
    """
    logger.warning("%s Plan C — dropping transcript, using DB fallback", log_prefix)
    fresh_builder = TranscriptBuilder()
    use_resume = False
    resume_file = None
    return fresh_builder, use_resume, resume_file
|
||||
|
||||
|
||||
def _setup_langfuse_otel() -> None:
|
||||
@@ -977,70 +1037,46 @@ async def stream_chat_completion_sdk(
|
||||
query_message = f"{query_message}\n\n{attachments.hint}"
|
||||
|
||||
_MAX_QUERY_ATTEMPTS = 3
|
||||
_prompt_too_long = False
|
||||
_retry_strategy: str | None = None # set on error, None = no error
|
||||
_last_stream_error: Exception | None = None
|
||||
_plan_c_applied = False
|
||||
|
||||
for _query_attempt in range(_MAX_QUERY_ATTEMPTS):
|
||||
if _query_attempt > 0:
|
||||
_prompt_too_long = False
|
||||
_last_retry_strategy = _retry_strategy
|
||||
_retry_strategy = None
|
||||
_last_stream_error = None
|
||||
stream_completed = False
|
||||
|
||||
logger.info(
|
||||
"%s Prompt-too-long retry attempt %d/%d",
|
||||
"%s Retry attempt %d/%d (strategy=%s)",
|
||||
log_prefix,
|
||||
_query_attempt + 1,
|
||||
_MAX_QUERY_ATTEMPTS,
|
||||
_last_retry_strategy,
|
||||
)
|
||||
if _query_attempt == 1 and transcript_content:
|
||||
compacted = await compact_transcript(
|
||||
transcript_content, log_prefix=log_prefix
|
||||
|
||||
# Decide: Plan B (compact) or Plan C (DB fallback)?
|
||||
if (
|
||||
_last_retry_strategy == RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
and transcript_content
|
||||
and not _plan_c_applied
|
||||
):
|
||||
tb, use_resume, resume_file, success = await _apply_plan_b(
|
||||
transcript_content, session_id, sdk_cwd, log_prefix
|
||||
)
|
||||
if (
|
||||
compacted
|
||||
and compacted != transcript_content
|
||||
and validate_transcript(compacted)
|
||||
):
|
||||
logger.info(
|
||||
"%s Using compacted transcript for retry",
|
||||
log_prefix,
|
||||
)
|
||||
transcript_builder = TranscriptBuilder()
|
||||
transcript_builder.load_previous(
|
||||
compacted, log_prefix=log_prefix
|
||||
)
|
||||
resume_file = write_transcript_to_tempfile(
|
||||
compacted, session_id, sdk_cwd
|
||||
)
|
||||
if not resume_file:
|
||||
logger.warning(
|
||||
"%s Failed to write compacted transcript, "
|
||||
"dropping transcript",
|
||||
log_prefix,
|
||||
)
|
||||
transcript_builder = TranscriptBuilder()
|
||||
use_resume = False
|
||||
transcript_caused_error = True
|
||||
else:
|
||||
use_resume = True
|
||||
transcript_msg_count = 0
|
||||
else:
|
||||
logger.warning(
|
||||
"%s Compaction failed, dropping transcript",
|
||||
log_prefix,
|
||||
)
|
||||
transcript_builder = TranscriptBuilder()
|
||||
use_resume = False
|
||||
resume_file = None
|
||||
transcript_msg_count = 0
|
||||
transcript_builder = tb
|
||||
transcript_msg_count = 0
|
||||
if not success:
|
||||
transcript_caused_error = True
|
||||
_plan_c_applied = True
|
||||
else:
|
||||
logger.warning(
|
||||
"%s Dropping transcript, using DB fallback",
|
||||
log_prefix,
|
||||
transcript_builder, use_resume, resume_file = _apply_plan_c(
|
||||
log_prefix
|
||||
)
|
||||
transcript_builder = TranscriptBuilder()
|
||||
use_resume = False
|
||||
resume_file = None
|
||||
transcript_msg_count = 0
|
||||
transcript_caused_error = True
|
||||
_plan_c_applied = True
|
||||
|
||||
# Rebuild SDK options with updated resume state
|
||||
sdk_options_kwargs_retry = dict(sdk_options_kwargs)
|
||||
@@ -1162,28 +1198,17 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
break
|
||||
except Exception as stream_err:
|
||||
if _is_prompt_too_long(stream_err):
|
||||
logger.warning(
|
||||
"%s Prompt too long (attempt %d/%d): %s",
|
||||
log_prefix,
|
||||
_query_attempt + 1,
|
||||
_MAX_QUERY_ATTEMPTS,
|
||||
stream_err,
|
||||
)
|
||||
_prompt_too_long = True
|
||||
else:
|
||||
logger.error(
|
||||
"%s Stream error from SDK: %s",
|
||||
log_prefix,
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
|
||||
yield StreamError(
|
||||
errorText=f"SDK stream error: {stream_err}",
|
||||
code="sdk_stream_error",
|
||||
)
|
||||
_retry_strategy = _classify_error(stream_err)
|
||||
_last_stream_error = stream_err
|
||||
logger.warning(
|
||||
"%s Stream error (attempt %d/%d, " "strategy=%s): %s",
|
||||
log_prefix,
|
||||
_query_attempt + 1,
|
||||
_MAX_QUERY_ATTEMPTS,
|
||||
_retry_strategy,
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
@@ -1423,13 +1448,12 @@ async def stream_chat_completion_sdk(
|
||||
log_prefix,
|
||||
)
|
||||
|
||||
# On prompt-too-long, skip post-stream processing — the retry
|
||||
# loop will either compact and retry, or exhaust attempts.
|
||||
# Roll back any partial messages appended during the failed
|
||||
# attempt to prevent duplicates on retry.
|
||||
if _prompt_too_long:
|
||||
# On error, skip post-stream processing — the retry loop
|
||||
# will compact / fallback / exhaust attempts. Roll back any
|
||||
# partial messages appended during the failed attempt.
|
||||
if _retry_strategy is not None:
|
||||
session.messages = session.messages[:_pre_attempt_msg_count]
|
||||
continue # goes to next iteration of for _query_attempt
|
||||
continue # next retry attempt
|
||||
|
||||
# Safety net: if tools are still unresolved after the
|
||||
# streaming loop (e.g. StopAsyncIteration before ResultMessage,
|
||||
@@ -1496,22 +1520,29 @@ async def stream_chat_completion_sdk(
|
||||
) and not has_appended_assistant:
|
||||
session.messages.append(assistant_response)
|
||||
|
||||
# If we reach here, the stream completed normally (no
|
||||
# prompt-too-long). The _prompt_too_long case is handled
|
||||
# by the `continue` after the inner finally block.
|
||||
# Stream completed normally — exit the retry loop.
|
||||
# Errors are handled by the `continue` above.
|
||||
break
|
||||
|
||||
if _prompt_too_long:
|
||||
# All retry attempts exhausted — surface error to the user.
|
||||
if _retry_strategy is not None:
|
||||
transcript_caused_error = True
|
||||
ended_with_stream_error = True
|
||||
logger.error(
|
||||
"%s All %d query attempts exhausted — prompt too long",
|
||||
"%s All %d query attempts exhausted (strategy=%s): %s",
|
||||
log_prefix,
|
||||
_MAX_QUERY_ATTEMPTS,
|
||||
_retry_strategy,
|
||||
_last_stream_error,
|
||||
)
|
||||
yield StreamError(
|
||||
errorText="The conversation is too long for the model. "
|
||||
"Please start a new session.",
|
||||
code="prompt_too_long",
|
||||
errorText=(
|
||||
"The conversation is too long for the model. "
|
||||
"Please start a new session."
|
||||
if _retry_strategy == RetryStrategy.COMPACT_THEN_FALLBACK
|
||||
else f"SDK stream error: {_last_stream_error}"
|
||||
),
|
||||
code="all_attempts_exhausted",
|
||||
)
|
||||
|
||||
# Transcript upload is handled exclusively in the finally block
|
||||
|
||||
Reference in New Issue
Block a user