mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
38 Commits
dev
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fff9faf13c | ||
|
|
f95772f0af | ||
|
|
79b8ad80fe | ||
|
|
8a4bc0b1e4 | ||
|
|
644d39d6be | ||
|
|
e2add1ba5b | ||
|
|
1a52b0d02c | ||
|
|
b101069eaf | ||
|
|
de094eee36 | ||
|
|
bddc633a11 | ||
|
|
2411cc386d | ||
|
|
49bef40ef0 | ||
|
|
eeb2f08d6d | ||
|
|
eda02f9ce6 | ||
|
|
2a969e5018 | ||
|
|
a68f48e6b7 | ||
|
|
2bf5a37646 | ||
|
|
289a19d402 | ||
|
|
e57e48272a | ||
|
|
c2f421cb42 | ||
|
|
e3d589b180 | ||
|
|
8de935c84b | ||
|
|
a55653f8c1 | ||
|
|
3e6faf2de7 | ||
|
|
22e8c5c353 | ||
|
|
b3d9e9e856 | ||
|
|
32bfe1b209 | ||
|
|
b220fe4347 | ||
|
|
61513b9dad | ||
|
|
e753aee7a0 | ||
|
|
3f24a003ad | ||
|
|
a369fbe169 | ||
|
|
d3173605eb | ||
|
|
98c27653f2 | ||
|
|
dced534df3 | ||
|
|
4ebe294707 | ||
|
|
2e8e115cd1 | ||
|
|
5ca49a8ec9 |
@@ -146,6 +146,32 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
claude_agent_fallback_model: str = Field(
|
||||
default="claude-sonnet-4-20250514",
|
||||
description="Fallback model when the primary model is unavailable (e.g. 529 "
|
||||
"overloaded). The SDK automatically retries with this cheaper model.",
|
||||
)
|
||||
claude_agent_max_turns: int = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=10000,
|
||||
description="Maximum number of agentic turns (tool-use loops) per query. "
|
||||
"Prevents runaway tool loops from burning budget.",
|
||||
)
|
||||
claude_agent_max_budget_usd: float = Field(
|
||||
default=100.0,
|
||||
ge=0.01,
|
||||
le=1000.0,
|
||||
description="Maximum spend in USD per SDK query. The CLI aborts the "
|
||||
"request if this budget is exceeded.",
|
||||
)
|
||||
claude_agent_max_transient_retries: int = Field(
|
||||
default=3,
|
||||
ge=0,
|
||||
le=10,
|
||||
description="Maximum number of retries for transient API errors "
|
||||
"(429, 5xx, ECONNRESET) before surfacing the error to the user.",
|
||||
)
|
||||
use_openrouter: bool = Field(
|
||||
default=True,
|
||||
description="Enable routing API calls through the OpenRouter proxy. "
|
||||
|
||||
@@ -44,12 +44,31 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
# Transient Anthropic API error detection
|
||||
# ---------------------------------------------------------------------------
|
||||
# Patterns in error text that indicate a transient Anthropic API error
|
||||
# (ECONNRESET / dropped TCP connection) which is retryable.
|
||||
# which is retryable. Covers:
|
||||
# - Connection-level: ECONNRESET, dropped TCP connections
|
||||
# - HTTP 429: rate-limit / too-many-requests
|
||||
# - HTTP 5xx: server errors
|
||||
#
|
||||
# Prefer specific status-code patterns over natural-language phrases
|
||||
# (e.g. "overloaded", "bad gateway") — those phrases can appear in
|
||||
# application-level SDK messages and would trigger spurious retries.
|
||||
_TRANSIENT_ERROR_PATTERNS = (
|
||||
# Connection-level
|
||||
"socket connection was closed unexpectedly",
|
||||
"ECONNRESET",
|
||||
"connection was forcibly closed",
|
||||
"network socket disconnected",
|
||||
# 429 rate-limit patterns
|
||||
"rate limit",
|
||||
"rate_limit",
|
||||
"too many requests",
|
||||
"status code 429",
|
||||
# 5xx server error patterns (status-code-specific to avoid false positives)
|
||||
"status code 529",
|
||||
"status code 500",
|
||||
"status code 502",
|
||||
"status code 503",
|
||||
"status code 504",
|
||||
)
|
||||
|
||||
FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
|
||||
|
||||
@@ -8,6 +8,8 @@ circular import through ``executor`` → ``credit`` → ``block_cost_config``).
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.sdk.subscription import validate_subscription
|
||||
|
||||
@@ -26,14 +28,14 @@ def build_sdk_env(
|
||||
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
|
||||
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
|
||||
``ANTHROPIC_API_KEY`` from the parent environment.
|
||||
2. **Direct Anthropic** — subprocess inherits ``ANTHROPIC_API_KEY``
|
||||
from the parent environment (no overrides needed).
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
|
||||
When *sdk_cwd* is provided, ``CLAUDE_CODE_TMPDIR`` is set so that
|
||||
the CLI writes temp/sub-agent output inside the per-session workspace
|
||||
directory rather than an inaccessible system temp path.
|
||||
All modes receive workspace isolation (``CLAUDE_CODE_TMPDIR``) and
|
||||
security hardening env vars to prevent .claude.md loading, prompt
|
||||
history persistence, auto-memory writes, and non-essential traffic.
|
||||
"""
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
@@ -43,40 +45,51 @@ def build_sdk_env(
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
return env
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
if not config.openrouter_active:
|
||||
elif not config.openrouter_active:
|
||||
env = {}
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
return env
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
else:
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
def _safe(v: str) -> str:
|
||||
# Keep only printable ASCII (0x20–0x7e); strip control chars,
|
||||
# null bytes, and non-ASCII to produce a valid HTTP header value
|
||||
# (RFC 7230 §3.2.6).
|
||||
return re.sub(r"[^\x20-\x7e]", "", v).strip()[:128]
|
||||
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
parts = []
|
||||
if session_id:
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
# --- Common: workspace isolation + security hardening (all modes) ---
|
||||
# Route subagent temp files into the per-session workspace so output
|
||||
# files are accessible (fixes /tmp/claude-0/ permission errors in E2B).
|
||||
if sdk_cwd:
|
||||
env["CLAUDE_CODE_TMPDIR"] = sdk_cwd
|
||||
|
||||
# Harden multi-tenant deployment: prevent loading untrusted workspace
|
||||
# .claude.md files, persisting prompt history, writing auto-memory,
|
||||
# and sending non-essential telemetry traffic.
|
||||
# These are undocumented CLI internals validated against
|
||||
# claude-agent-sdk 0.1.45 — re-verify when upgrading the SDK.
|
||||
env["CLAUDE_CODE_DISABLE_CLAUDE_MDS"] = "1"
|
||||
env["CLAUDE_CODE_SKIP_PROMPT_HISTORY"] = "1"
|
||||
env["CLAUDE_CODE_DISABLE_AUTO_MEMORY"] = "1"
|
||||
env["CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC"] = "1"
|
||||
|
||||
return env
|
||||
|
||||
@@ -41,11 +41,9 @@ class TestBuildSdkEnvSubscription:
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
assert result["ANTHROPIC_API_KEY"] == ""
|
||||
assert result["ANTHROPIC_AUTH_TOKEN"] == ""
|
||||
assert result["ANTHROPIC_BASE_URL"] == ""
|
||||
mock_validate.assert_called_once()
|
||||
|
||||
@patch(
|
||||
@@ -68,18 +66,20 @@ class TestBuildSdkEnvSubscription:
|
||||
|
||||
|
||||
class TestBuildSdkEnvDirectAnthropic:
|
||||
"""When OpenRouter is inactive, return empty dict (inherit parent env)."""
|
||||
"""When OpenRouter is inactive, no ANTHROPIC_* overrides (inherit parent env)."""
|
||||
|
||||
def test_returns_empty_dict_when_openrouter_inactive(self):
|
||||
def test_no_anthropic_key_overrides_when_openrouter_inactive(self):
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch("backend.copilot.sdk.env.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {}
|
||||
assert "ANTHROPIC_API_KEY" not in result
|
||||
assert "ANTHROPIC_AUTH_TOKEN" not in result
|
||||
assert "ANTHROPIC_BASE_URL" not in result
|
||||
|
||||
def test_returns_empty_dict_when_openrouter_flag_true_but_no_key(self):
|
||||
def test_no_anthropic_key_overrides_when_openrouter_flag_true_but_no_key(self):
|
||||
"""OpenRouter flag is True but no api_key => openrouter_active is False."""
|
||||
cfg = _make_config(use_openrouter=True, base_url="https://openrouter.ai/api/v1")
|
||||
# Force api_key to None after construction (field_validator may pick up env vars)
|
||||
@@ -90,7 +90,9 @@ class TestBuildSdkEnvDirectAnthropic:
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
assert result == {}
|
||||
assert "ANTHROPIC_API_KEY" not in result
|
||||
assert "ANTHROPIC_AUTH_TOKEN" not in result
|
||||
assert "ANTHROPIC_BASE_URL" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -234,12 +236,12 @@ class TestBuildSdkEnvModePriority:
|
||||
|
||||
result = build_sdk_env()
|
||||
|
||||
# Should get subscription result, not OpenRouter
|
||||
assert result == {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
# Should get subscription result (blanked keys), not OpenRouter proxy
|
||||
assert result["ANTHROPIC_API_KEY"] == ""
|
||||
assert result["ANTHROPIC_AUTH_TOKEN"] == ""
|
||||
assert result["ANTHROPIC_BASE_URL"] == ""
|
||||
# OpenRouter-specific key must NOT be present
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,535 @@
|
||||
"""Tests for P0 guardrails: _resolve_fallback_model, security env vars, TMPDIR."""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.constants import is_transient_api_error
|
||||
|
||||
|
||||
def _make_config(**overrides) -> ChatConfig:
|
||||
"""Create a ChatConfig with safe defaults, applying *overrides*."""
|
||||
defaults = {
|
||||
"use_claude_code_subscription": False,
|
||||
"use_openrouter": False,
|
||||
"api_key": None,
|
||||
"base_url": None,
|
||||
}
|
||||
defaults.update(overrides)
|
||||
return ChatConfig(**defaults)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _resolve_fallback_model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SVC = "backend.copilot.sdk.service"
|
||||
_ENV = "backend.copilot.sdk.env"
|
||||
|
||||
|
||||
class TestResolveFallbackModel:
|
||||
"""Provider-aware fallback model resolution."""
|
||||
|
||||
def test_returns_none_when_empty(self):
|
||||
cfg = _make_config(claude_agent_fallback_model="")
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
assert _resolve_fallback_model() is None
|
||||
|
||||
def test_strips_provider_prefix(self):
|
||||
"""OpenRouter-style 'anthropic/claude-sonnet-4-...' is stripped."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="anthropic/claude-sonnet-4-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="sk-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result == "claude-sonnet-4-20250514"
|
||||
assert "/" not in result
|
||||
|
||||
def test_dots_replaced_for_direct_anthropic(self):
|
||||
"""Direct Anthropic requires hyphen-separated versions."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="claude-sonnet-4.5-20250514",
|
||||
use_openrouter=False,
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result is not None
|
||||
assert "." not in result
|
||||
assert result == "claude-sonnet-4-5-20250514"
|
||||
|
||||
def test_dots_preserved_for_openrouter(self):
|
||||
"""OpenRouter uses dot-separated versions — don't normalise."""
|
||||
cfg = _make_config(
|
||||
claude_agent_fallback_model="claude-sonnet-4.5-20250514",
|
||||
use_openrouter=True,
|
||||
api_key="sk-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result == "claude-sonnet-4.5-20250514"
|
||||
|
||||
def test_default_value(self):
|
||||
"""Default fallback model resolves to a valid string."""
|
||||
cfg = _make_config()
|
||||
with patch(f"{_SVC}.config", cfg):
|
||||
from backend.copilot.sdk.service import _resolve_fallback_model
|
||||
|
||||
result = _resolve_fallback_model()
|
||||
|
||||
assert result is not None
|
||||
assert "sonnet" in result.lower() or "claude" in result.lower()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Security & isolation env vars
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
_SECURITY_VARS = (
|
||||
"CLAUDE_CODE_DISABLE_CLAUDE_MDS",
|
||||
"CLAUDE_CODE_SKIP_PROMPT_HISTORY",
|
||||
"CLAUDE_CODE_DISABLE_AUTO_MEMORY",
|
||||
"CLAUDE_CODE_DISABLE_NONESSENTIAL_TRAFFIC",
|
||||
)
|
||||
|
||||
|
||||
class TestSecurityEnvVars:
|
||||
"""Verify security env vars are set in the returned dict for every auth mode.
|
||||
|
||||
Tests call ``build_sdk_env()`` directly and assert the vars are present
|
||||
in the returned dict — not just present somewhere in the source file.
|
||||
"""
|
||||
|
||||
def test_security_vars_set_in_openrouter_mode(self):
|
||||
"""Mode 3 (OpenRouter): security vars must be in the returned env."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="s1", user_id="u1")
|
||||
|
||||
for var in _SECURITY_VARS:
|
||||
assert env.get(var) == "1", f"{var} not set in OpenRouter mode"
|
||||
|
||||
def test_security_vars_set_in_direct_anthropic_mode(self):
|
||||
"""Mode 2 (direct Anthropic): security vars must be in the returned env."""
|
||||
cfg = _make_config(use_claude_code_subscription=False, use_openrouter=False)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
for var in _SECURITY_VARS:
|
||||
assert env.get(var) == "1", f"{var} not set in direct Anthropic mode"
|
||||
|
||||
def test_security_vars_set_in_subscription_mode(self):
|
||||
"""Mode 1 (subscription): security vars must be in the returned env."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with (
|
||||
patch(f"{_ENV}.config", cfg),
|
||||
patch(f"{_ENV}.validate_subscription"),
|
||||
):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="s1", user_id="u1")
|
||||
|
||||
for var in _SECURITY_VARS:
|
||||
assert env.get(var) == "1", f"{var} not set in subscription mode"
|
||||
|
||||
def test_tmpdir_set_when_sdk_cwd_provided(self):
|
||||
"""CLAUDE_CODE_TMPDIR must be set when sdk_cwd is provided."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(sdk_cwd="/workspace/session-1")
|
||||
|
||||
assert env.get("CLAUDE_CODE_TMPDIR") == "/workspace/session-1"
|
||||
|
||||
def test_tmpdir_absent_when_sdk_cwd_not_provided(self):
|
||||
"""CLAUDE_CODE_TMPDIR must NOT be set when sdk_cwd is None."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert "CLAUDE_CODE_TMPDIR" not in env
|
||||
|
||||
def test_home_not_overridden(self):
|
||||
"""HOME must NOT be overridden — would break git/ssh/npm in subprocesses."""
|
||||
cfg = _make_config(use_openrouter=False)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert "HOME" not in env
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config defaults
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigDefaults:
|
||||
"""Verify ChatConfig P0 fields have correct defaults."""
|
||||
|
||||
def test_fallback_model_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_fallback_model
|
||||
assert "sonnet" in cfg.claude_agent_fallback_model.lower()
|
||||
|
||||
def test_max_turns_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_turns == 1000
|
||||
|
||||
def test_max_budget_usd_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_budget_usd == 100.0
|
||||
|
||||
def test_max_transient_retries_default(self):
|
||||
cfg = _make_config()
|
||||
assert cfg.claude_agent_max_transient_retries == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_sdk_env — all 3 auth modes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBuildSdkEnv:
|
||||
"""Verify build_sdk_env returns correct dicts for each auth mode."""
|
||||
|
||||
def test_subscription_mode_clears_keys(self):
|
||||
"""Mode 1: subscription clears API key / auth token / base URL."""
|
||||
cfg = _make_config(use_claude_code_subscription=True)
|
||||
with (
|
||||
patch(f"{_ENV}.config", cfg),
|
||||
patch(f"{_ENV}.validate_subscription"),
|
||||
):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="s1", user_id="u1")
|
||||
|
||||
assert env["ANTHROPIC_API_KEY"] == ""
|
||||
assert env["ANTHROPIC_AUTH_TOKEN"] == ""
|
||||
assert env["ANTHROPIC_BASE_URL"] == ""
|
||||
|
||||
def test_direct_anthropic_inherits_api_key(self):
|
||||
"""Mode 2: direct Anthropic doesn't set ANTHROPIC_* keys (inherits from parent)."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=False,
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert "ANTHROPIC_API_KEY" not in env
|
||||
assert "ANTHROPIC_AUTH_TOKEN" not in env
|
||||
assert "ANTHROPIC_BASE_URL" not in env
|
||||
|
||||
def test_openrouter_sets_base_url_and_auth(self):
|
||||
"""Mode 3: OpenRouter sets base URL, auth token, and clears API key."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env(session_id="sess-1", user_id="user-1")
|
||||
|
||||
assert env["ANTHROPIC_BASE_URL"] == "https://openrouter.ai/api"
|
||||
assert env["ANTHROPIC_AUTH_TOKEN"] == "sk-or-test"
|
||||
assert env["ANTHROPIC_API_KEY"] == ""
|
||||
assert "x-session-id: sess-1" in env["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
assert "x-user-id: user-1" in env["ANTHROPIC_CUSTOM_HEADERS"]
|
||||
|
||||
def test_openrouter_no_headers_when_ids_empty(self):
|
||||
"""Mode 3: No custom headers when session_id/user_id are not given."""
|
||||
cfg = _make_config(
|
||||
use_claude_code_subscription=False,
|
||||
use_openrouter=True,
|
||||
api_key="sk-or-test",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
with patch(f"{_ENV}.config", cfg):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert "ANTHROPIC_CUSTOM_HEADERS" not in env
|
||||
|
||||
def test_all_modes_return_mutable_dict(self):
|
||||
"""build_sdk_env must return a mutable dict (not None) in every mode."""
|
||||
for cfg in (
|
||||
_make_config(use_claude_code_subscription=True),
|
||||
_make_config(use_openrouter=False),
|
||||
_make_config(
|
||||
use_openrouter=True,
|
||||
api_key="k",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
),
|
||||
):
|
||||
with (
|
||||
patch(f"{_ENV}.config", cfg),
|
||||
patch(f"{_ENV}.validate_subscription"),
|
||||
):
|
||||
from backend.copilot.sdk.env import build_sdk_env
|
||||
|
||||
env = build_sdk_env()
|
||||
|
||||
assert isinstance(env, dict)
|
||||
env["CLAUDE_CODE_TMPDIR"] = "/tmp/test"
|
||||
assert env["CLAUDE_CODE_TMPDIR"] == "/tmp/test"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# is_transient_api_error
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsTransientApiError:
|
||||
"""Verify that is_transient_api_error detects all transient patterns."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_text",
|
||||
[
|
||||
"socket connection was closed unexpectedly",
|
||||
"ECONNRESET",
|
||||
"connection was forcibly closed",
|
||||
"network socket disconnected",
|
||||
],
|
||||
)
|
||||
def test_connection_level_errors(self, error_text: str):
|
||||
assert is_transient_api_error(error_text)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_text",
|
||||
[
|
||||
"rate limit exceeded",
|
||||
"rate_limit_error",
|
||||
"Too Many Requests",
|
||||
"status code 429",
|
||||
],
|
||||
)
|
||||
def test_429_rate_limit_errors(self, error_text: str):
|
||||
assert is_transient_api_error(error_text)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_text",
|
||||
[
|
||||
# Status-code-specific patterns (preferred — no false-positive risk)
|
||||
"status code 529",
|
||||
"status code 500",
|
||||
"status code 502",
|
||||
"status code 503",
|
||||
"status code 504",
|
||||
],
|
||||
)
|
||||
def test_5xx_server_errors(self, error_text: str):
|
||||
assert is_transient_api_error(error_text)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_text",
|
||||
[
|
||||
"invalid_api_key",
|
||||
"Authentication failed",
|
||||
"prompt is too long",
|
||||
"model not found",
|
||||
"",
|
||||
# Natural-language phrases intentionally NOT matched — they are too
|
||||
# broad and could appear in application-level SDK messages unrelated
|
||||
# to Anthropic API transient conditions.
|
||||
"API is overloaded",
|
||||
"Internal Server Error",
|
||||
"Bad Gateway",
|
||||
"Service Unavailable",
|
||||
"Gateway Timeout",
|
||||
],
|
||||
)
|
||||
def test_non_transient_errors(self, error_text: str):
|
||||
assert not is_transient_api_error(error_text)
|
||||
|
||||
def test_case_insensitive(self):
|
||||
assert is_transient_api_error("SOCKET CONNECTION WAS CLOSED UNEXPECTEDLY")
|
||||
assert is_transient_api_error("econnreset")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _HandledStreamError.already_yielded contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestHandledStreamErrorAlreadyYielded:
|
||||
"""Verify the already_yielded semantics on _HandledStreamError."""
|
||||
|
||||
def test_default_already_yielded_is_true(self):
|
||||
"""Non-transient callers (circuit-breaker, idle timeout) don't pass the flag —
|
||||
the default True means the outer loop won't yield a duplicate StreamError."""
|
||||
from backend.copilot.sdk.service import _HandledStreamError
|
||||
|
||||
exc = _HandledStreamError("some error", code="circuit_breaker_empty_tool_calls")
|
||||
assert exc.already_yielded is True
|
||||
|
||||
def test_transient_error_sets_already_yielded_false(self):
|
||||
"""Transient errors pass already_yielded=False so the outer loop
|
||||
yields StreamError only once (when retries are exhausted)."""
|
||||
from backend.copilot.sdk.service import _HandledStreamError
|
||||
|
||||
exc = _HandledStreamError(
|
||||
"transient",
|
||||
code="transient_api_error",
|
||||
already_yielded=False,
|
||||
)
|
||||
assert exc.already_yielded is False
|
||||
|
||||
def test_backoff_capped_at_30s(self):
|
||||
"""Exponential backoff must be capped at 30 seconds.
|
||||
|
||||
With max_transient_retries=10, uncapped 2^9=512s would stall users
|
||||
for 8+ minutes. min(30, 2**(n-1)) keeps the ceiling at 30s.
|
||||
"""
|
||||
# Check that 2^(10-1)=512 would exceed 30 but min() caps it.
|
||||
assert min(30, 2 ** (10 - 1)) == 30
|
||||
# Verify the formula is monotonically non-decreasing and capped.
|
||||
backoffs = [min(30, 2 ** (n - 1)) for n in range(1, 11)]
|
||||
assert all(b <= 30 for b in backoffs)
|
||||
assert backoffs[-1] == 30 # last retry is capped
|
||||
assert backoffs[0] == 1 # first retry starts at 1s
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Config validators for max_turns / max_budget_usd
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConfigValidators:
|
||||
"""Verify ge/le bounds on max_turns and max_budget_usd."""
|
||||
|
||||
def test_max_turns_rejects_zero(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_turns=0)
|
||||
|
||||
def test_max_turns_rejects_negative(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_turns=-1)
|
||||
|
||||
def test_max_turns_rejects_above_10000(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_turns=10001)
|
||||
|
||||
def test_max_turns_accepts_boundary_values(self):
|
||||
cfg_low = _make_config(claude_agent_max_turns=1)
|
||||
assert cfg_low.claude_agent_max_turns == 1
|
||||
cfg_high = _make_config(claude_agent_max_turns=10000)
|
||||
assert cfg_high.claude_agent_max_turns == 10000
|
||||
|
||||
def test_max_budget_rejects_zero(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_budget_usd=0.0)
|
||||
|
||||
def test_max_budget_rejects_negative(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_budget_usd=-1.0)
|
||||
|
||||
def test_max_budget_rejects_above_1000(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_budget_usd=1000.01)
|
||||
|
||||
def test_max_budget_accepts_boundary_values(self):
|
||||
cfg_low = _make_config(claude_agent_max_budget_usd=0.01)
|
||||
assert cfg_low.claude_agent_max_budget_usd == 0.01
|
||||
cfg_high = _make_config(claude_agent_max_budget_usd=1000.0)
|
||||
assert cfg_high.claude_agent_max_budget_usd == 1000.0
|
||||
|
||||
def test_max_transient_retries_rejects_negative(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_transient_retries=-1)
|
||||
|
||||
def test_max_transient_retries_rejects_above_10(self):
|
||||
with pytest.raises(ValidationError):
|
||||
_make_config(claude_agent_max_transient_retries=11)
|
||||
|
||||
def test_max_transient_retries_accepts_boundary_values(self):
|
||||
cfg_low = _make_config(claude_agent_max_transient_retries=0)
|
||||
assert cfg_low.claude_agent_max_transient_retries == 0
|
||||
cfg_high = _make_config(claude_agent_max_transient_retries=10)
|
||||
assert cfg_high.claude_agent_max_transient_retries == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# transient_exhausted SSE code contract
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTransientExhaustedErrorCode:
|
||||
"""Verify transient-exhausted path emits the correct SSE error code."""
|
||||
|
||||
def test_transient_exhausted_uses_transient_api_error_code(self):
|
||||
"""When except-Exception transient retries are exhausted, the SSE
|
||||
StreamError must use code='transient_api_error', not 'sdk_stream_error'.
|
||||
|
||||
This ensures the frontend shows the same 'Try again' affordance as
|
||||
the _HandledStreamError path.
|
||||
"""
|
||||
from backend.copilot.constants import FRIENDLY_TRANSIENT_MSG
|
||||
|
||||
# Simulate the post-loop branching logic extracted from service.py
|
||||
attempts_exhausted = False
|
||||
transient_exhausted = True
|
||||
stream_err: Exception | None = ConnectionResetError("ECONNRESET")
|
||||
|
||||
if attempts_exhausted:
|
||||
error_code = "all_attempts_exhausted"
|
||||
error_text = "conversation too long"
|
||||
elif transient_exhausted:
|
||||
error_code = "transient_api_error"
|
||||
error_text = FRIENDLY_TRANSIENT_MSG
|
||||
else:
|
||||
error_code = "sdk_stream_error"
|
||||
error_text = f"SDK stream error: {stream_err}"
|
||||
|
||||
assert error_code == "transient_api_error"
|
||||
assert error_text == FRIENDLY_TRANSIENT_MSG
|
||||
|
||||
def test_non_transient_exhausted_uses_sdk_stream_error_code(self):
|
||||
"""Non-transient fatal errors (auth, network) keep 'sdk_stream_error'."""
|
||||
attempts_exhausted = False
|
||||
transient_exhausted = False
|
||||
|
||||
if attempts_exhausted:
|
||||
error_code = "all_attempts_exhausted"
|
||||
elif transient_exhausted:
|
||||
error_code = "transient_api_error"
|
||||
else:
|
||||
error_code = "sdk_stream_error"
|
||||
|
||||
assert error_code == "sdk_stream_error"
|
||||
@@ -260,13 +260,13 @@ def test_result_error_emits_error_and_finish():
|
||||
is_error=True,
|
||||
num_turns=0,
|
||||
session_id="s1",
|
||||
result="API rate limited",
|
||||
result="Invalid API key provided",
|
||||
)
|
||||
results = adapter.convert_message(msg)
|
||||
# No step was open, so no FinishStep — just Error + Finish
|
||||
assert len(results) == 2
|
||||
assert isinstance(results[0], StreamError)
|
||||
assert "API rate limited" in results[0].errorText
|
||||
assert "Invalid API key provided" in results[0].errorText
|
||||
assert isinstance(results[1], StreamFinish)
|
||||
|
||||
|
||||
|
||||
@@ -105,6 +105,10 @@ def test_agent_options_accepts_all_our_fields():
|
||||
"env",
|
||||
"resume",
|
||||
"max_buffer_size",
|
||||
"stderr",
|
||||
"fallback_model",
|
||||
"max_turns",
|
||||
"max_budget_usd",
|
||||
]
|
||||
sig = inspect.signature(ClaudeAgentOptions)
|
||||
for field in fields_we_use:
|
||||
|
||||
@@ -545,17 +545,34 @@ async def _iter_sdk_messages(
|
||||
pass
|
||||
|
||||
|
||||
def _normalize_model_name(raw_model: str) -> str:
|
||||
"""Normalize a model name for the current routing configuration.
|
||||
|
||||
Applies two transformations shared by both the primary and fallback
|
||||
model resolution paths:
|
||||
|
||||
1. **Strip provider prefix** — OpenRouter-style names like
|
||||
``"anthropic/claude-opus-4.6"`` are reduced to ``"claude-opus-4.6"``.
|
||||
2. **Dot-to-hyphen conversion** — when *not* routing through OpenRouter
|
||||
the direct Anthropic API requires hyphen-separated versions
|
||||
(``"claude-opus-4-6"``), so dots are replaced with hyphens.
|
||||
"""
|
||||
model = raw_model
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
# OpenRouter uses dots in versions (claude-opus-4.6) but the direct
|
||||
# Anthropic API requires hyphens (claude-opus-4-6). Only normalise
|
||||
# when NOT routing through OpenRouter.
|
||||
if not config.openrouter_active:
|
||||
model = model.replace(".", "-")
|
||||
return model
|
||||
|
||||
|
||||
def _resolve_sdk_model() -> str | None:
|
||||
"""Resolve the model name for the Claude Agent SDK CLI.
|
||||
|
||||
Uses `config.claude_agent_model` if set, otherwise derives from
|
||||
`config.model` by stripping the OpenRouter provider prefix (e.g.,
|
||||
`"anthropic/claude-opus-4.6"` → `"claude-opus-4-6"`).
|
||||
|
||||
OpenRouter uses dot-separated versions (`claude-opus-4.6`) while the
|
||||
direct Anthropic API uses hyphen-separated versions (`claude-opus-4-6`).
|
||||
Normalisation is only applied when the SDK will actually talk to
|
||||
Anthropic directly (not through OpenRouter).
|
||||
`config.model` via :func:`_normalize_model_name`.
|
||||
|
||||
When `use_claude_code_subscription` is enabled and no explicit
|
||||
`claude_agent_model` is set, returns `None` so the CLI uses the
|
||||
@@ -565,15 +582,18 @@ def _resolve_sdk_model() -> str | None:
|
||||
return config.claude_agent_model
|
||||
if config.use_claude_code_subscription:
|
||||
return None
|
||||
model = config.model
|
||||
if "/" in model:
|
||||
model = model.split("/", 1)[1]
|
||||
# OpenRouter uses dots in versions (claude-opus-4.6) but the direct
|
||||
# Anthropic API requires hyphens (claude-opus-4-6). Only normalise
|
||||
# when NOT routing through OpenRouter.
|
||||
if not config.openrouter_active:
|
||||
model = model.replace(".", "-")
|
||||
return model
|
||||
return _normalize_model_name(config.model)
|
||||
|
||||
|
||||
def _resolve_fallback_model() -> str | None:
|
||||
"""Resolve the fallback model name via :func:`_normalize_model_name`.
|
||||
|
||||
Returns ``None`` when no fallback is configured (empty string).
|
||||
"""
|
||||
raw = config.claude_agent_fallback_model
|
||||
if not raw:
|
||||
return None
|
||||
return _normalize_model_name(raw)
|
||||
|
||||
|
||||
def _make_sdk_cwd(session_id: str) -> str:
|
||||
@@ -1063,17 +1083,25 @@ def _dispatch_response(
|
||||
|
||||
|
||||
class _HandledStreamError(Exception):
|
||||
"""Raised by `_run_stream_attempt` after it has already yielded a
|
||||
`StreamError` to the client (e.g. transient API error, circuit breaker).
|
||||
"""Raised by `_run_stream_attempt` when an attempt fails and the outer
|
||||
retry loop must roll back session state.
|
||||
|
||||
This signals the outer retry loop that the attempt failed so it can
|
||||
perform session-message rollback and set the `ended_with_stream_error`
|
||||
flag, **without** yielding a duplicate `StreamError` to the client.
|
||||
Two sub-cases:
|
||||
|
||||
* ``already_yielded=True`` (default) — a ``StreamError`` was already sent
|
||||
to the client inside ``_run_stream_attempt`` (circuit-breaker, idle
|
||||
timeout, etc.). The outer loop must **not** yield another one.
|
||||
* ``already_yielded=False`` — the error is transient and the outer loop
|
||||
will decide whether to retry or surface the error. If retrying it
|
||||
yields a ``StreamStatus("retrying…")``; if exhausted it yields the
|
||||
``StreamError`` itself so the client sees it only once.
|
||||
|
||||
Attributes:
|
||||
error_msg: The user-facing error message to persist.
|
||||
code: Machine-readable error code (e.g. ``circuit_breaker_empty_tool_calls``).
|
||||
retryable: Whether the frontend should offer a retry button.
|
||||
already_yielded: ``True`` when ``StreamError`` was already sent to the
|
||||
client before this exception was raised.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -1082,11 +1110,13 @@ class _HandledStreamError(Exception):
|
||||
error_msg: str | None = None,
|
||||
code: str | None = None,
|
||||
retryable: bool = True,
|
||||
already_yielded: bool = True,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.error_msg = error_msg
|
||||
self.code = code
|
||||
self.retryable = retryable
|
||||
self.already_yielded = already_yielded
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -1377,15 +1407,12 @@ async def _run_stream_attempt(
|
||||
)
|
||||
stream_error_msg = FRIENDLY_TRANSIENT_MSG
|
||||
stream_error_code = "transient_api_error"
|
||||
_append_error_marker(
|
||||
ctx.session,
|
||||
stream_error_msg,
|
||||
retryable=True,
|
||||
)
|
||||
yield StreamError(
|
||||
errorText=stream_error_msg,
|
||||
code=stream_error_code,
|
||||
)
|
||||
# Do NOT yield StreamError or append error marker here.
|
||||
# The outer retry loop decides: if a retry is available it
|
||||
# yields StreamStatus("retrying…"); if retries are exhausted
|
||||
# it appends the marker and yields StreamError exactly once.
|
||||
# Yielding StreamError before the retry decision causes the
|
||||
# client to display an error that is immediately superseded.
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
|
||||
@@ -1658,14 +1685,16 @@ async def _run_stream_attempt(
|
||||
) and not acc.has_appended_assistant:
|
||||
ctx.session.messages.append(acc.assistant_response)
|
||||
|
||||
# If the attempt ended with a transient error that was already surfaced
|
||||
# to the client (StreamError yielded above), raise so the outer retry
|
||||
# loop can rollback session messages and set its error flags properly.
|
||||
# Raise so the outer retry loop can rollback session messages.
|
||||
# already_yielded=False for transient_api_error: StreamError was NOT
|
||||
# sent to the client yet (the outer loop does it when retries are
|
||||
# exhausted, avoiding a premature error flash before the retry).
|
||||
if ended_with_stream_error:
|
||||
raise _HandledStreamError(
|
||||
"Stream error handled — StreamError already yielded",
|
||||
"Stream error handled",
|
||||
error_msg=stream_error_msg,
|
||||
code=stream_error_code,
|
||||
already_yielded=(stream_error_code != "transient_api_error"),
|
||||
)
|
||||
|
||||
|
||||
@@ -1960,10 +1989,29 @@ async def stream_chat_completion_sdk(
|
||||
allowed = get_copilot_tool_names(use_e2b=use_e2b)
|
||||
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
|
||||
|
||||
# Flag set by _on_stderr when the SDK logs that it switched to the
|
||||
# fallback model (e.g. on a 529 overloaded error). Checked once per
|
||||
# heartbeat cycle and emitted as a StreamStatus notification.
|
||||
fallback_model_activated = False
|
||||
|
||||
def _on_stderr(line: str) -> None:
|
||||
"""Log a stderr line emitted by the Claude CLI subprocess."""
|
||||
nonlocal fallback_model_activated
|
||||
sid = session_id[:12] if session_id else "?"
|
||||
logger.info("[SDK] [%s] CLI stderr: %s", sid, line.rstrip())
|
||||
# Detect SDK fallback-model activation. The CLI logs a
|
||||
# message containing "fallback model" when it switches models
|
||||
# after a 529/overloaded error. Match "fallback model" rather
|
||||
# than just "fallback" to avoid false positives from unrelated
|
||||
# stderr lines (e.g. tool-level retries, cached result fallbacks).
|
||||
lower = line.lower()
|
||||
if not fallback_model_activated and "fallback model" in lower:
|
||||
fallback_model_activated = True
|
||||
logger.warning(
|
||||
"[SDK] [%s] Fallback model activated — primary model "
|
||||
"overloaded, switching to fallback",
|
||||
sid,
|
||||
)
|
||||
|
||||
sdk_options_kwargs: dict[str, Any] = {
|
||||
"system_prompt": system_prompt,
|
||||
@@ -1974,6 +2022,15 @@ async def stream_chat_completion_sdk(
|
||||
"cwd": sdk_cwd,
|
||||
"max_buffer_size": config.claude_agent_max_buffer_size,
|
||||
"stderr": _on_stderr,
|
||||
# --- P0 guardrails ---
|
||||
# fallback_model: SDK auto-retries with this cheaper model on
|
||||
# 529 (overloaded) errors, avoiding user-visible failures.
|
||||
"fallback_model": _resolve_fallback_model(),
|
||||
# max_turns: hard cap on agentic tool-use loops per query to
|
||||
# prevent runaway execution from burning budget.
|
||||
"max_turns": config.claude_agent_max_turns,
|
||||
# max_budget_usd: per-query spend ceiling enforced by the CLI.
|
||||
"max_budget_usd": config.claude_agent_max_budget_usd,
|
||||
}
|
||||
if sdk_model:
|
||||
sdk_options_kwargs["model"] = sdk_model
|
||||
@@ -2058,8 +2115,29 @@ async def stream_chat_completion_sdk(
|
||||
# ---------------------------------------------------------------
|
||||
ended_with_stream_error = False
|
||||
attempts_exhausted = False
|
||||
transient_exhausted = False
|
||||
stream_err: Exception | None = None
|
||||
|
||||
# Transient retry helper — deduplicates the logic shared between
|
||||
# _HandledStreamError and the generic except-Exception handler.
|
||||
transient_retries = 0
|
||||
max_transient_retries = config.claude_agent_max_transient_retries
|
||||
|
||||
def _next_transient_backoff() -> int | None:
|
||||
"""Return the next backoff delay in seconds, or ``None`` to surface the error.
|
||||
|
||||
Returns the backoff seconds if a retry should be attempted,
|
||||
or ``None`` if retries are exhausted or events were already
|
||||
yielded. Mutates outer ``transient_retries`` via nonlocal.
|
||||
"""
|
||||
nonlocal transient_retries
|
||||
if events_yielded > 0:
|
||||
return None
|
||||
transient_retries += 1
|
||||
if transient_retries > max_transient_retries:
|
||||
return None
|
||||
return min(30, 2 ** (transient_retries - 1)) # 1s, 2s, 4s, …, cap 30s
|
||||
|
||||
state = _RetryState(
|
||||
options=options,
|
||||
query_message=query_message,
|
||||
@@ -2072,7 +2150,19 @@ async def stream_chat_completion_sdk(
|
||||
usage=_TokenUsage(),
|
||||
)
|
||||
|
||||
for attempt in range(_MAX_STREAM_ATTEMPTS):
|
||||
attempt = 0
|
||||
_last_reset_attempt = -1
|
||||
while attempt < _MAX_STREAM_ATTEMPTS:
|
||||
# Reset transient retry counter per context-level attempt so
|
||||
# each attempt (original, compacted, no-transcript) gets the
|
||||
# full retry budget for transient errors.
|
||||
# Only reset when the attempt number actually changes —
|
||||
# transient retries `continue` back to the loop top without
|
||||
# incrementing `attempt`, so resetting unconditionally would
|
||||
# create an infinite retry loop.
|
||||
if attempt != _last_reset_attempt:
|
||||
transient_retries = 0
|
||||
_last_reset_attempt = attempt
|
||||
# Clear any stale stash signal from the previous attempt so
|
||||
# wait_for_stash() doesn't fire prematurely on a leftover event.
|
||||
reset_stash_event()
|
||||
@@ -2127,7 +2217,15 @@ async def stream_chat_completion_sdk(
|
||||
state.usage.reset()
|
||||
|
||||
pre_attempt_msg_count = len(session.messages)
|
||||
# Snapshot transcript builder state — it maintains an
|
||||
# independent _entries list from session.messages, so rolling
|
||||
# back session.messages alone would leave duplicate entries
|
||||
# from the failed attempt in the uploaded transcript.
|
||||
pre_transcript_entries = list(state.transcript_builder._entries)
|
||||
pre_transcript_uuid = state.transcript_builder._last_uuid
|
||||
events_yielded = 0
|
||||
fallback_model_activated = False
|
||||
fallback_notified = False
|
||||
|
||||
try:
|
||||
async for event in _run_stream_attempt(stream_ctx, state):
|
||||
@@ -2143,9 +2241,24 @@ async def stream_chat_completion_sdk(
|
||||
StreamToolInputStart,
|
||||
StreamToolInputAvailable,
|
||||
StreamToolOutputAvailable,
|
||||
# Transient StreamError and StreamStatus are
|
||||
# ephemeral notifications, not content. Counting
|
||||
# them would prevent the backoff retry from firing
|
||||
# because _next_transient_backoff() returns None
|
||||
# when events_yielded > 0.
|
||||
StreamError,
|
||||
StreamStatus,
|
||||
),
|
||||
):
|
||||
events_yielded += 1
|
||||
# Emit a one-time StreamStatus when the SDK switches
|
||||
# to the fallback model (detected via stderr).
|
||||
if fallback_model_activated and not fallback_notified:
|
||||
fallback_notified = True
|
||||
yield StreamStatus(
|
||||
message="Primary model overloaded — "
|
||||
"using fallback model for this request"
|
||||
)
|
||||
yield event
|
||||
break # Stream completed — exit retry loop
|
||||
except asyncio.CancelledError:
|
||||
@@ -2162,6 +2275,31 @@ async def stream_chat_completion_sdk(
|
||||
# session messages and set the error flag — do NOT set
|
||||
# stream_err so the post-loop code won't emit a
|
||||
# duplicate StreamError.
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
state.transcript_builder._entries = pre_transcript_entries
|
||||
state.transcript_builder._last_uuid = pre_transcript_uuid
|
||||
# Check if this is a transient error we can retry with backoff.
|
||||
# exc.code is the only reliable signal — str(exc) is always the
|
||||
# static "Stream error handled — StreamError already yielded" message.
|
||||
if exc.code == "transient_api_error":
|
||||
backoff = _next_transient_backoff()
|
||||
if backoff is not None:
|
||||
logger.warning(
|
||||
"%s Transient error — retrying in %ds (%d/%d)",
|
||||
log_prefix,
|
||||
backoff,
|
||||
transient_retries,
|
||||
max_transient_retries,
|
||||
)
|
||||
yield StreamStatus(
|
||||
message=f"Connection interrupted, retrying in {backoff}s…"
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
state.adapter = SDKResponseAdapter(
|
||||
message_id=message_id, session_id=session_id
|
||||
)
|
||||
state.usage.reset()
|
||||
continue # retry the same context-level attempt
|
||||
logger.warning(
|
||||
"%s Stream error handled in attempt "
|
||||
"(attempt %d/%d, code=%s, events_yielded=%d)",
|
||||
@@ -2171,7 +2309,6 @@ async def stream_chat_completion_sdk(
|
||||
exc.code or "transient",
|
||||
events_yielded,
|
||||
)
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
# transcript_builder still contains entries from the aborted
|
||||
# attempt that no longer match session.messages. Skip upload
|
||||
# so a future --resume doesn't replay rolled-back content.
|
||||
@@ -2186,22 +2323,37 @@ async def stream_chat_completion_sdk(
|
||||
retryable=True,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
# For transient errors the StreamError was deliberately NOT
|
||||
# yielded inside _run_stream_attempt (already_yielded=False)
|
||||
# so the client didn't see a premature error flash. Yield it
|
||||
# now that we know retries are exhausted.
|
||||
# For non-transient errors (circuit breaker, idle timeout)
|
||||
# already_yielded=True — do NOT yield again.
|
||||
if not exc.already_yielded:
|
||||
yield StreamError(
|
||||
errorText=exc.error_msg or FRIENDLY_TRANSIENT_MSG,
|
||||
code=exc.code or "transient_api_error",
|
||||
)
|
||||
break
|
||||
except Exception as e:
|
||||
stream_err = e
|
||||
is_context_error = _is_prompt_too_long(e)
|
||||
is_transient = is_transient_api_error(str(e))
|
||||
logger.warning(
|
||||
"%s Stream error (attempt %d/%d, context_error=%s, "
|
||||
"events_yielded=%d): %s",
|
||||
"transient=%s, events_yielded=%d): %s",
|
||||
log_prefix,
|
||||
attempt + 1,
|
||||
_MAX_STREAM_ATTEMPTS,
|
||||
is_context_error,
|
||||
is_transient,
|
||||
events_yielded,
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
session.messages = session.messages[:pre_attempt_msg_count]
|
||||
state.transcript_builder._entries = pre_transcript_entries
|
||||
state.transcript_builder._last_uuid = pre_transcript_uuid
|
||||
if events_yielded > 0:
|
||||
# Events were already sent to the frontend and cannot be
|
||||
# unsent. Retrying would produce duplicate/inconsistent
|
||||
@@ -2214,16 +2366,50 @@ async def stream_chat_completion_sdk(
|
||||
skip_transcript_upload = True
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
# Transient API errors (ECONNRESET, 429, 5xx) — retry
|
||||
# with exponential backoff via the shared helper.
|
||||
if is_transient:
|
||||
backoff = _next_transient_backoff()
|
||||
if backoff is not None:
|
||||
logger.warning(
|
||||
"%s Transient exception — retrying in %ds (%d/%d)",
|
||||
log_prefix,
|
||||
backoff,
|
||||
transient_retries,
|
||||
max_transient_retries,
|
||||
)
|
||||
yield StreamStatus(
|
||||
message=f"Connection interrupted, retrying in {backoff}s…"
|
||||
)
|
||||
await asyncio.sleep(backoff)
|
||||
state.adapter = SDKResponseAdapter(
|
||||
message_id=message_id, session_id=session_id
|
||||
)
|
||||
state.usage.reset()
|
||||
continue # retry same context-level attempt
|
||||
# Retries exhausted — persist retryable marker so the
|
||||
# frontend shows "Try again" after refresh.
|
||||
# Mirrors the _HandledStreamError exhausted-retry path
|
||||
# at line ~2310.
|
||||
transient_exhausted = True
|
||||
skip_transcript_upload = True
|
||||
_append_error_marker(
|
||||
session, FRIENDLY_TRANSIENT_MSG, retryable=True
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
|
||||
if not is_context_error:
|
||||
# Non-context errors (network, auth, rate-limit) should
|
||||
# not trigger compaction — surface the error immediately.
|
||||
# Non-context, non-transient errors (auth, fatal)
|
||||
# should not trigger compaction — surface immediately.
|
||||
skip_transcript_upload = True
|
||||
ended_with_stream_error = True
|
||||
break
|
||||
attempt += 1 # advance to next context-level attempt
|
||||
continue
|
||||
else:
|
||||
# All retry attempts exhausted (loop ended without break)
|
||||
# skip_transcript_upload is already set by _reduce_context
|
||||
# while condition became False — all attempts exhausted without
|
||||
# break. skip_transcript_upload is already set by _reduce_context
|
||||
# when the transcript was dropped (transcript_lost=True).
|
||||
ended_with_stream_error = True
|
||||
attempts_exhausted = True
|
||||
@@ -2252,25 +2438,24 @@ async def stream_chat_completion_sdk(
|
||||
yield response
|
||||
|
||||
if ended_with_stream_error and stream_err is not None:
|
||||
# Use distinct error codes: "all_attempts_exhausted" when all
|
||||
# retries were consumed vs "sdk_stream_error" for non-context
|
||||
# errors that broke the loop immediately (network, auth, etc.).
|
||||
# Use distinct error codes depending on how the loop ended:
|
||||
# • "all_attempts_exhausted" — context compaction ran out of room
|
||||
# • "transient_api_error" — 429/5xx/ECONNRESET retries exhausted
|
||||
# • "sdk_stream_error" — non-context, non-transient fatal error
|
||||
safe_err = str(stream_err).replace("\n", " ").replace("\r", "")[:500]
|
||||
if attempts_exhausted:
|
||||
error_text = (
|
||||
"Your conversation is too long. "
|
||||
"Please start a new chat or clear some history."
|
||||
)
|
||||
error_code = "all_attempts_exhausted"
|
||||
elif transient_exhausted:
|
||||
error_text = FRIENDLY_TRANSIENT_MSG
|
||||
error_code = "transient_api_error"
|
||||
else:
|
||||
error_text = _friendly_error_text(safe_err)
|
||||
yield StreamError(
|
||||
errorText=error_text,
|
||||
code=(
|
||||
"all_attempts_exhausted"
|
||||
if attempts_exhausted
|
||||
else "sdk_stream_error"
|
||||
),
|
||||
)
|
||||
error_code = "sdk_stream_error"
|
||||
yield StreamError(errorText=error_text, code=error_code)
|
||||
|
||||
# Copy token usage from retry state to outer-scope accumulators
|
||||
# so the finally block can persist them.
|
||||
|
||||
@@ -10,6 +10,7 @@ import pytest
|
||||
|
||||
from .service import (
|
||||
_is_sdk_disconnect_error,
|
||||
_normalize_model_name,
|
||||
_prepare_file_attachments,
|
||||
_resolve_sdk_model,
|
||||
_safe_close_sdk_client,
|
||||
@@ -405,6 +406,49 @@ def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestNormalizeModelName:
|
||||
"""Tests for _normalize_model_name — shared provider-aware normalization."""
|
||||
|
||||
def test_strips_provider_prefix(self, monkeypatch, _clean_config_env):
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6"
|
||||
|
||||
def test_dots_preserved_for_openrouter(self, monkeypatch, _clean_config_env):
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4.6"
|
||||
|
||||
def test_no_prefix_no_dots(self, monkeypatch, _clean_config_env):
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert (
|
||||
_normalize_model_name("claude-sonnet-4-20250514")
|
||||
== "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
|
||||
class TestResolveSdkModel:
|
||||
"""Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user