mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
Compare commits
28 Commits
feat/keep-
...
fix/copilo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3fc3e13eb0 | ||
|
|
be0ebea612 | ||
|
|
bd79186f6b | ||
|
|
f068583d7f | ||
|
|
3eee249b96 | ||
|
|
1fc37f0e61 | ||
|
|
7f9c06e758 | ||
|
|
dc5a305805 | ||
|
|
4ee4643b06 | ||
|
|
3f9bc9a365 | ||
|
|
f416c2e472 | ||
|
|
14317531c7 | ||
|
|
9d8b5e3091 | ||
|
|
1c23ae0df5 | ||
|
|
d871d6c12c | ||
|
|
a83416d86b | ||
|
|
253c4bbc63 | ||
|
|
c0a91be65e | ||
|
|
64d82797b5 | ||
|
|
1565564bce | ||
|
|
0614b22a72 | ||
|
|
feeed4645c | ||
|
|
ccd69df357 | ||
|
|
1d5598df3d | ||
|
|
84f3ca9a62 | ||
|
|
94af0b264c | ||
|
|
a31fc008e8 | ||
|
|
2e8b984f8e |
@@ -115,10 +115,22 @@ class ChatConfig(BaseSettings):
|
||||
description="Use --resume for multi-turn conversations instead of "
|
||||
"history compression. Falls back to compression when unavailable.",
|
||||
)
|
||||
use_openrouter: bool = Field(
|
||||
default=True,
|
||||
description="Enable routing API calls through the OpenRouter proxy. "
|
||||
"The actual decision also requires ``api_key`` and ``base_url`` — "
|
||||
"use the ``openrouter_active`` property for the final answer.",
|
||||
)
|
||||
use_claude_code_subscription: bool = Field(
|
||||
default=False,
|
||||
description="For personal/dev use: use Claude Code CLI subscription auth instead of API keys. Requires `claude login` on the host. Only works with SDK mode.",
|
||||
)
|
||||
test_mode: bool = Field(
|
||||
default=False,
|
||||
description="Use dummy service instead of real LLM calls. "
|
||||
"Send __test_transient_error__, __test_fatal_error__, or "
|
||||
"__test_slow_response__ to trigger specific scenarios.",
|
||||
)
|
||||
|
||||
# E2B Sandbox Configuration
|
||||
use_e2b_sandbox: bool = Field(
|
||||
@@ -146,6 +158,21 @@ class ChatConfig(BaseSettings):
|
||||
description="E2B lifecycle action on timeout: 'pause' (default, free) or 'kill'.",
|
||||
)
|
||||
|
||||
@property
|
||||
def openrouter_active(self) -> bool:
|
||||
"""True when OpenRouter is enabled AND credentials are usable.
|
||||
|
||||
Single source of truth for "will the SDK route through OpenRouter?".
|
||||
Checks the flag *and* that ``api_key`` + a valid ``base_url`` are
|
||||
present — mirrors the fallback logic in ``_build_sdk_env``.
|
||||
"""
|
||||
if not self.use_openrouter:
|
||||
return False
|
||||
base = (self.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
return bool(self.api_key and base and base.startswith("http"))
|
||||
|
||||
@property
|
||||
def e2b_active(self) -> bool:
|
||||
"""True when E2B is enabled and the API key is present.
|
||||
@@ -168,15 +195,6 @@ class ChatConfig(BaseSettings):
|
||||
"""
|
||||
return self.e2b_api_key if self.e2b_active else None
|
||||
|
||||
@field_validator("use_e2b_sandbox", mode="before")
|
||||
@classmethod
|
||||
def get_use_e2b_sandbox(cls, v):
|
||||
"""Get use_e2b_sandbox from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_E2B_SANDBOX", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("e2b_api_key", mode="before")
|
||||
@classmethod
|
||||
def get_e2b_api_key(cls, v):
|
||||
@@ -219,26 +237,6 @@ class ChatConfig(BaseSettings):
|
||||
v = OPENROUTER_BASE_URL
|
||||
return v
|
||||
|
||||
@field_validator("use_claude_agent_sdk", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_agent_sdk(cls, v):
|
||||
"""Get use_claude_agent_sdk from environment if not provided."""
|
||||
# Check environment variable - default to True if not set
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_AGENT_SDK", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
# Default to True (SDK enabled by default)
|
||||
return True if v is None else v
|
||||
|
||||
@field_validator("use_claude_code_subscription", mode="before")
|
||||
@classmethod
|
||||
def get_use_claude_code_subscription(cls, v):
|
||||
"""Get use_claude_code_subscription from environment if not provided."""
|
||||
env_val = os.getenv("CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "").lower()
|
||||
if env_val:
|
||||
return env_val in ("true", "1", "yes", "on")
|
||||
return False if v is None else v
|
||||
|
||||
# Prompt paths for different contexts
|
||||
PROMPT_PATHS: dict[str, str] = {
|
||||
"default": "prompts/chat_system.md",
|
||||
@@ -248,6 +246,7 @@ class ChatConfig(BaseSettings):
|
||||
class Config:
|
||||
"""Pydantic config."""
|
||||
|
||||
env_prefix = "CHAT_"
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore" # Ignore extra environment variables
|
||||
|
||||
@@ -6,19 +6,70 @@ from .config import ChatConfig
|
||||
|
||||
# Env vars that the ChatConfig validators read — must be cleared so they don't
|
||||
# override the explicit constructor values we pass in each test.
|
||||
_E2B_ENV_VARS = (
|
||||
_ENV_VARS_TO_CLEAR = (
|
||||
"CHAT_USE_E2B_SANDBOX",
|
||||
"CHAT_E2B_API_KEY",
|
||||
"E2B_API_KEY",
|
||||
"CHAT_USE_OPENROUTER",
|
||||
"CHAT_API_KEY",
|
||||
"OPEN_ROUTER_API_KEY",
|
||||
"OPENAI_API_KEY",
|
||||
"CHAT_BASE_URL",
|
||||
"OPENROUTER_BASE_URL",
|
||||
"OPENAI_BASE_URL",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_e2b_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _E2B_ENV_VARS:
|
||||
def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for var in _ENV_VARS_TO_CLEAR:
|
||||
monkeypatch.delenv(var, raising=False)
|
||||
|
||||
|
||||
class TestOpenrouterActive:
|
||||
"""Tests for the openrouter_active property."""
|
||||
|
||||
def test_enabled_with_credentials_returns_true(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_enabled_but_missing_api_key_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key=None,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_disabled_returns_false_despite_credentials(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
def test_strips_v1_suffix_and_still_valid(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
)
|
||||
assert cfg.openrouter_active is True
|
||||
|
||||
def test_invalid_base_url_returns_false(self):
|
||||
cfg = ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="not-a-url",
|
||||
)
|
||||
assert cfg.openrouter_active is False
|
||||
|
||||
|
||||
class TestE2BActive:
|
||||
"""Tests for the e2b_active property — single source of truth for E2B usage."""
|
||||
|
||||
|
||||
@@ -4,6 +4,9 @@
|
||||
# The hex suffix makes accidental LLM generation of these strings virtually
|
||||
# impossible, avoiding false-positive marker detection in normal conversation.
|
||||
COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]" # Renders as ErrorCard
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX = (
|
||||
"[__COPILOT_RETRYABLE_ERROR_a9c2__]" # ErrorCard + retry
|
||||
)
|
||||
COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]" # Renders as system info message
|
||||
|
||||
# Prefix for all synthetic IDs generated by CoPilot block execution.
|
||||
@@ -35,3 +38,24 @@ def parse_node_id_from_exec_id(node_exec_id: str) -> str:
|
||||
Format: "{node_id}:{random_hex}" → returns "{node_id}".
|
||||
"""
|
||||
return node_exec_id.rsplit(COPILOT_NODE_EXEC_ID_SEPARATOR, 1)[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transient Anthropic API error detection
|
||||
# ---------------------------------------------------------------------------
|
||||
# Patterns in error text that indicate a transient Anthropic API error
|
||||
# (ECONNRESET / dropped TCP connection) which is retryable.
|
||||
_TRANSIENT_ERROR_PATTERNS = (
|
||||
"socket connection was closed unexpectedly",
|
||||
"ECONNRESET",
|
||||
"connection was forcibly closed",
|
||||
"network socket disconnected",
|
||||
)
|
||||
|
||||
FRIENDLY_TRANSIENT_MSG = "Anthropic connection interrupted — please retry"
|
||||
|
||||
|
||||
def is_transient_api_error(error_text: str) -> bool:
|
||||
"""Return True if *error_text* matches a known transient Anthropic API error."""
|
||||
lower = error_text.lower()
|
||||
return any(pat.lower() in lower for pat in _TRANSIENT_ERROR_PATTERNS)
|
||||
|
||||
@@ -16,6 +16,7 @@ from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.response_model import StreamFinish
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
from backend.executor.cluster_lock import ClusterLock
|
||||
from backend.util.decorator import error_logged
|
||||
from backend.util.feature_flag import Flag, is_feature_enabled
|
||||
@@ -246,17 +247,25 @@ class CoPilotProcessor:
|
||||
# Choose service based on LaunchDarkly flag.
|
||||
# Claude Code subscription forces SDK mode (CLI subprocess auth).
|
||||
config = ChatConfig()
|
||||
use_sdk = config.use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
else:
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
else stream_chat_completion_baseline
|
||||
)
|
||||
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
async for chunk in stream_fn(
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
"""Dummy SDK service for testing copilot streaming.
|
||||
|
||||
Returns mock streaming responses without calling Claude Agent SDK.
|
||||
Enable via COPILOT_TEST_MODE=true environment variable.
|
||||
Enable via CHAT_TEST_MODE=true in .env (ChatConfig.test_mode).
|
||||
|
||||
WARNING: This is for testing only. Do not use in production.
|
||||
|
||||
Magic keywords (case-insensitive, anywhere in message):
|
||||
__test_transient_error__ — Simulate a transient Anthropic API error
|
||||
(ECONNRESET). Streams partial text, then
|
||||
yields StreamError with retryable prefix.
|
||||
__test_fatal_error__ — Simulate a non-retryable SDK error.
|
||||
__test_slow_response__ — Simulate a slow response (2s per word).
|
||||
(no keyword) — Normal dummy response.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -12,12 +20,39 @@ import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from ..model import ChatSession
|
||||
from ..response_model import StreamBaseResponse, StreamStart, StreamTextDelta
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
)
|
||||
from ..model import ChatMessage, ChatSession, get_chat_session, upsert_chat_session
|
||||
from ..response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _safe_upsert(session: ChatSession) -> None:
|
||||
"""Best-effort session persist — skip silently if DB is unavailable."""
|
||||
try:
|
||||
await upsert_chat_session(session)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not persist session (DB unavailable)")
|
||||
|
||||
|
||||
def _has_keyword(message: str | None, keyword: str) -> bool:
|
||||
return keyword in (message or "").lower()
|
||||
|
||||
|
||||
async def stream_chat_completion_dummy(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
@@ -36,24 +71,89 @@ async def stream_chat_completion_dummy(
|
||||
- No timeout occurs
|
||||
- Text arrives in chunks
|
||||
- StreamFinish is sent by mark_session_completed
|
||||
|
||||
See module docstring for magic keywords that trigger error scenarios.
|
||||
"""
|
||||
logger.warning(
|
||||
f"[TEST MODE] Using dummy copilot streaming for session {session_id}"
|
||||
)
|
||||
|
||||
# Load session from DB (matches SDK service behaviour) so error markers
|
||||
# and the assistant reply are persisted and survive page refresh.
|
||||
# Best-effort: skip if DB is unavailable (e.g. unit tests).
|
||||
if session is None:
|
||||
try:
|
||||
session = await get_chat_session(session_id, user_id)
|
||||
except Exception:
|
||||
logger.debug("[TEST MODE] Could not load session (DB unavailable)")
|
||||
session = None
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
text_block_id = str(uuid.uuid4())
|
||||
|
||||
# Start the stream
|
||||
# Start the stream (matches baseline: StreamStart → StreamStartStep)
|
||||
yield StreamStart(messageId=message_id, sessionId=session_id)
|
||||
yield StreamStartStep()
|
||||
|
||||
# Simulate streaming text response with delays
|
||||
# --- Magic keyword: transient error (retryable) -------------------------
|
||||
if _has_keyword(message, "__test_transient_error__"):
|
||||
# Stream some partial text first (simulates mid-stream failure)
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for word in ["Working", "on", "it..."]:
|
||||
yield StreamTextDelta(id=text_block_id, delta=f"{word} ")
|
||||
await asyncio.sleep(0.1)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
yield StreamFinishStep()
|
||||
# Persist retryable marker so "Try Again" button shows after refresh
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_RETRYABLE_ERROR_PREFIX} {FRIENDLY_TRANSIENT_MSG}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(
|
||||
errorText=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
)
|
||||
return
|
||||
|
||||
# --- Magic keyword: fatal error (non-retryable) -------------------------
|
||||
if _has_keyword(message, "__test_fatal_error__"):
|
||||
yield StreamFinishStep()
|
||||
error_msg = "Internal SDK error: model refused to respond"
|
||||
# Persist non-retryable error marker
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content=f"{COPILOT_ERROR_PREFIX} {error_msg}",
|
||||
)
|
||||
)
|
||||
await _safe_upsert(session)
|
||||
yield StreamError(errorText=error_msg, code="sdk_error")
|
||||
return
|
||||
|
||||
# --- Magic keyword: slow response ---------------------------------------
|
||||
delay = 2.0 if _has_keyword(message, "__test_slow_response__") else 0.1
|
||||
|
||||
# --- Normal dummy response ----------------------------------------------
|
||||
dummy_response = "I counted: 1... 2... 3. All done!"
|
||||
words = dummy_response.split()
|
||||
|
||||
yield StreamTextStart(id=text_block_id)
|
||||
for i, word in enumerate(words):
|
||||
# Add space except for last word
|
||||
text = word if i == len(words) - 1 else f"{word} "
|
||||
yield StreamTextDelta(id=text_block_id, delta=text)
|
||||
# Small delay to simulate real streaming
|
||||
await asyncio.sleep(0.1)
|
||||
await asyncio.sleep(delay)
|
||||
yield StreamTextEnd(id=text_block_id)
|
||||
|
||||
# Persist the assistant reply so it survives page refresh
|
||||
if session:
|
||||
session.messages.append(ChatMessage(role="assistant", content=dummy_response))
|
||||
await _safe_upsert(session)
|
||||
|
||||
yield StreamFinishStep()
|
||||
yield StreamFinish()
|
||||
|
||||
@@ -20,6 +20,7 @@ from claude_agent_sdk import (
|
||||
UserMessage,
|
||||
)
|
||||
|
||||
from backend.copilot.constants import FRIENDLY_TRANSIENT_MSG, is_transient_api_error
|
||||
from backend.copilot.response_model import (
|
||||
StreamBaseResponse,
|
||||
StreamError,
|
||||
@@ -214,10 +215,12 @@ class SDKResponseAdapter:
|
||||
if sdk_message.subtype == "success":
|
||||
responses.append(StreamFinish())
|
||||
elif sdk_message.subtype in ("error", "error_during_execution"):
|
||||
error_msg = sdk_message.result or "Unknown error"
|
||||
responses.append(
|
||||
StreamError(errorText=str(error_msg), code="sdk_error")
|
||||
)
|
||||
raw_error = str(sdk_message.result or "Unknown error")
|
||||
if is_transient_api_error(raw_error):
|
||||
error_text, code = FRIENDLY_TRANSIENT_MSG, "transient_api_error"
|
||||
else:
|
||||
error_text, code = raw_error, "sdk_error"
|
||||
responses.append(StreamError(errorText=error_text, code=code))
|
||||
responses.append(StreamFinish())
|
||||
else:
|
||||
logger.warning(
|
||||
|
||||
@@ -37,7 +37,13 @@ from backend.util.prompt import compress_context
|
||||
from backend.util.settings import Settings
|
||||
|
||||
from ..config import ChatConfig
|
||||
from ..constants import COPILOT_ERROR_PREFIX, COPILOT_SYSTEM_PREFIX
|
||||
from ..constants import (
|
||||
COPILOT_ERROR_PREFIX,
|
||||
COPILOT_RETRYABLE_ERROR_PREFIX,
|
||||
COPILOT_SYSTEM_PREFIX,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
is_transient_api_error,
|
||||
)
|
||||
from ..model import (
|
||||
ChatMessage,
|
||||
ChatSession,
|
||||
@@ -90,6 +96,28 @@ logger = logging.getLogger(__name__)
|
||||
config = ChatConfig()
|
||||
|
||||
|
||||
def _append_error_marker(
|
||||
session: ChatSession | None,
|
||||
display_msg: str,
|
||||
*,
|
||||
retryable: bool = False,
|
||||
) -> None:
|
||||
"""Append a copilot error marker to *session* so it persists across refresh.
|
||||
|
||||
Args:
|
||||
session: The chat session to append to (no-op if ``None``).
|
||||
display_msg: User-visible error text.
|
||||
retryable: If ``True``, use the retryable prefix so the frontend
|
||||
shows a "Try Again" button.
|
||||
"""
|
||||
if session is None:
|
||||
return
|
||||
prefix = COPILOT_RETRYABLE_ERROR_PREFIX if retryable else COPILOT_ERROR_PREFIX
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=f"{prefix} {display_msg}")
|
||||
)
|
||||
|
||||
|
||||
def _setup_langfuse_otel() -> None:
|
||||
"""Configure OTEL tracing for the Claude Agent SDK → Langfuse.
|
||||
|
||||
@@ -158,7 +186,12 @@ def _resolve_sdk_model() -> str | None:
|
||||
|
||||
Uses ``config.claude_agent_model`` if set, otherwise derives from
|
||||
``config.model`` by stripping the OpenRouter provider prefix (e.g.,
|
||||
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4.6"``).
|
||||
``"anthropic/claude-opus-4.6"`` → ``"claude-opus-4-6"``).
|
||||
|
||||
OpenRouter uses dot-separated versions (``claude-opus-4.6``) while the
|
||||
direct Anthropic API uses hyphen-separated versions (``claude-opus-4-6``).
|
||||
Normalisation is only applied when the SDK will actually talk to
|
||||
Anthropic directly (not through OpenRouter).
|
||||
|
||||
When ``use_claude_code_subscription`` is enabled and no explicit
|
||||
``claude_agent_model`` is set, returns ``None`` so the CLI uses the
|
||||
@@ -170,7 +203,12 @@ def _resolve_sdk_model() -> str | None:
|
||||
return None
|
||||
model = config.model
|
||||
if "/" in model:
|
||||
return model.split("/", 1)[1]
|
||||
model = model.split("/", 1)[1]
|
||||
# OpenRouter uses dots in versions (claude-opus-4.6) but the direct
|
||||
# Anthropic API requires hyphens (claude-opus-4-6). Only normalise
|
||||
# when NOT routing through OpenRouter.
|
||||
if not config.openrouter_active:
|
||||
model = model.replace(".", "-")
|
||||
return model
|
||||
|
||||
|
||||
@@ -209,61 +247,57 @@ def _build_sdk_env(
|
||||
session_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""Build env vars for the SDK CLI process.
|
||||
"""Build env vars for the SDK CLI subprocess.
|
||||
|
||||
Routes API calls through OpenRouter (or a custom base_url) using
|
||||
the same ``config.api_key`` / ``config.base_url`` as the non-SDK path.
|
||||
This gives per-call token and cost tracking on the OpenRouter dashboard.
|
||||
|
||||
When *session_id* is provided, an ``x-session-id`` custom header is
|
||||
injected via ``ANTHROPIC_CUSTOM_HEADERS`` so that OpenRouter Broadcast
|
||||
forwards traces (including cost/usage) to Langfuse for the
|
||||
``/api/v1/messages`` endpoint.
|
||||
|
||||
Only overrides ``ANTHROPIC_API_KEY`` when a valid proxy URL and auth
|
||||
token are both present — otherwise returns an empty dict so the SDK
|
||||
falls back to its default credentials.
|
||||
Three modes (checked in order):
|
||||
1. **Subscription** — clears all keys; CLI uses ``claude login`` auth.
|
||||
2. **Direct Anthropic** — returns ``{}``; subprocess inherits
|
||||
``ANTHROPIC_API_KEY`` from the parent environment.
|
||||
3. **OpenRouter** (default) — overrides base URL and auth token to
|
||||
route through the proxy, with Langfuse trace headers.
|
||||
"""
|
||||
env: dict[str, str] = {}
|
||||
|
||||
# --- Mode 1: Claude Code subscription auth ---
|
||||
if config.use_claude_code_subscription:
|
||||
# Claude Code subscription: let the CLI use its own logged-in auth.
|
||||
# Explicitly clear API key env vars so the subprocess doesn't pick
|
||||
# them up from the parent process and bypass subscription auth.
|
||||
_validate_claude_code_subscription()
|
||||
env["ANTHROPIC_API_KEY"] = ""
|
||||
env["ANTHROPIC_AUTH_TOKEN"] = ""
|
||||
env["ANTHROPIC_BASE_URL"] = ""
|
||||
elif config.api_key and config.base_url:
|
||||
# Strip /v1 suffix — SDK expects the base URL without a version path
|
||||
base = config.base_url.rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
if not base or not base.startswith("http"):
|
||||
# Invalid base_url — don't override SDK defaults
|
||||
return env
|
||||
env["ANTHROPIC_BASE_URL"] = base
|
||||
env["ANTHROPIC_AUTH_TOKEN"] = config.api_key
|
||||
# Must be explicitly empty so the CLI uses AUTH_TOKEN instead
|
||||
env["ANTHROPIC_API_KEY"] = ""
|
||||
return {
|
||||
"ANTHROPIC_API_KEY": "",
|
||||
"ANTHROPIC_AUTH_TOKEN": "",
|
||||
"ANTHROPIC_BASE_URL": "",
|
||||
}
|
||||
|
||||
# --- Mode 2: Direct Anthropic (no proxy hop) ---
|
||||
# ``openrouter_active`` checks the flag *and* credential presence.
|
||||
if not config.openrouter_active:
|
||||
if config.use_openrouter:
|
||||
logger.warning(
|
||||
"[SDK] OpenRouter enabled but api_key/base_url missing or "
|
||||
"invalid; falling back to direct Anthropic mode"
|
||||
)
|
||||
return {}
|
||||
|
||||
# --- Mode 3: OpenRouter proxy ---
|
||||
# Strip /v1 suffix — SDK expects the base URL without a version path.
|
||||
base = (config.base_url or "").rstrip("/")
|
||||
if base.endswith("/v1"):
|
||||
base = base[:-3]
|
||||
env: dict[str, str] = {
|
||||
"ANTHROPIC_BASE_URL": base,
|
||||
"ANTHROPIC_AUTH_TOKEN": config.api_key or "",
|
||||
"ANTHROPIC_API_KEY": "", # force CLI to use AUTH_TOKEN
|
||||
}
|
||||
|
||||
# Inject broadcast headers so OpenRouter forwards traces to Langfuse.
|
||||
# The ``x-session-id`` header is *required* for the Anthropic-native
|
||||
# ``/messages`` endpoint — without it broadcast silently drops the
|
||||
# trace even when org-level Langfuse integration is configured.
|
||||
def _safe(value: str) -> str:
|
||||
"""Strip CR/LF to prevent header injection, then truncate."""
|
||||
return value.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
def _safe(v: str) -> str:
|
||||
"""Sanitise a header value: strip newlines/whitespace and cap length."""
|
||||
return v.replace("\r", "").replace("\n", "").strip()[:128]
|
||||
|
||||
headers: list[str] = []
|
||||
parts = []
|
||||
if session_id:
|
||||
headers.append(f"x-session-id: {_safe(session_id)}")
|
||||
parts.append(f"x-session-id: {_safe(session_id)}")
|
||||
if user_id:
|
||||
headers.append(f"x-user-id: {_safe(user_id)}")
|
||||
# Only inject headers when routing through OpenRouter/proxy — they're
|
||||
# meaningless (and leak internal IDs) when using subscription mode.
|
||||
if headers and env.get("ANTHROPIC_BASE_URL"):
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(headers)
|
||||
parts.append(f"x-user-id: {_safe(user_id)}")
|
||||
if parts:
|
||||
env["ANTHROPIC_CUSTOM_HEADERS"] = "\n".join(parts)
|
||||
|
||||
return env
|
||||
|
||||
@@ -655,13 +689,17 @@ async def stream_chat_completion_sdk(
|
||||
# Type narrowing: session is guaranteed ChatSession after the check above
|
||||
session = cast(ChatSession, session)
|
||||
|
||||
# Clean up stale error markers from previous turn before starting new turn
|
||||
# If the last message contains an error marker, remove it (user is retrying)
|
||||
if (
|
||||
# Clean up ALL trailing error markers from previous turn before starting
|
||||
# a new turn. Multiple markers can accumulate when a mid-stream error is
|
||||
# followed by a cleanup error in __aexit__ (both append a marker).
|
||||
while (
|
||||
len(session.messages) > 0
|
||||
and session.messages[-1].role == "assistant"
|
||||
and session.messages[-1].content
|
||||
and COPILOT_ERROR_PREFIX in session.messages[-1].content
|
||||
and (
|
||||
COPILOT_ERROR_PREFIX in session.messages[-1].content
|
||||
or COPILOT_RETRYABLE_ERROR_PREFIX in session.messages[-1].content
|
||||
)
|
||||
):
|
||||
logger.info(
|
||||
"[SDK] [%s] Removing stale error marker from previous turn",
|
||||
@@ -806,7 +844,7 @@ async def stream_chat_completion_sdk(
|
||||
)
|
||||
except Exception as transcript_err:
|
||||
logger.warning(
|
||||
"%s Transcript download failed, continuing without " "--resume: %s",
|
||||
"%s Transcript download failed, continuing without --resume: %s",
|
||||
log_prefix,
|
||||
transcript_err,
|
||||
)
|
||||
@@ -829,7 +867,7 @@ async def stream_chat_completion_sdk(
|
||||
is_valid = validate_transcript(dl.content)
|
||||
dl_lines = dl.content.strip().split("\n") if dl.content else []
|
||||
logger.info(
|
||||
"%s Downloaded transcript: %dB, %d lines, " "msg_count=%d, valid=%s",
|
||||
"%s Downloaded transcript: %dB, %d lines, msg_count=%d, valid=%s",
|
||||
log_prefix,
|
||||
len(dl.content),
|
||||
len(dl_lines),
|
||||
@@ -1048,23 +1086,36 @@ async def stream_chat_completion_sdk(
|
||||
# Exception in receive_response() — capture it
|
||||
# so the session can still be saved and the
|
||||
# frontend gets a clean finish.
|
||||
logger.error(
|
||||
if is_transient_api_error(str(stream_err)):
|
||||
log, display, code = (
|
||||
logger.warning,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
"transient_api_error",
|
||||
)
|
||||
else:
|
||||
log, display, code = (
|
||||
logger.error,
|
||||
f"SDK stream error: {stream_err}",
|
||||
"sdk_stream_error",
|
||||
)
|
||||
|
||||
log(
|
||||
"%s Stream error from SDK: %s",
|
||||
log_prefix,
|
||||
stream_err,
|
||||
exc_info=True,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
|
||||
yield StreamError(
|
||||
errorText=f"SDK stream error: {stream_err}",
|
||||
code="sdk_stream_error",
|
||||
_append_error_marker(
|
||||
session,
|
||||
display,
|
||||
retryable=(code == "transient_api_error"),
|
||||
)
|
||||
yield StreamError(errorText=display, code=code)
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"%s Received: %s %s "
|
||||
"(unresolved=%d, current=%d, resolved=%d)",
|
||||
"%s Received: %s %s (unresolved=%d, current=%d, resolved=%d)",
|
||||
log_prefix,
|
||||
type(sdk_msg).__name__,
|
||||
getattr(sdk_msg, "subtype", ""),
|
||||
@@ -1078,15 +1129,42 @@ async def stream_chat_completion_sdk(
|
||||
# so we can debug Anthropic API 400s surfaced by the CLI.
|
||||
sdk_error = getattr(sdk_msg, "error", None)
|
||||
if isinstance(sdk_msg, AssistantMessage) and sdk_error:
|
||||
error_text = str(sdk_error)
|
||||
error_preview = str(sdk_msg.content)[:500]
|
||||
logger.error(
|
||||
"[SDK] [%s] AssistantMessage has error=%s, "
|
||||
"content_blocks=%d, content_preview=%s",
|
||||
session_id[:12],
|
||||
sdk_error,
|
||||
len(sdk_msg.content),
|
||||
str(sdk_msg.content)[:500],
|
||||
error_preview,
|
||||
)
|
||||
|
||||
# Intercept transient API errors (socket closed,
|
||||
# ECONNRESET) — replace the raw message with a
|
||||
# user-friendly error text and use the retryable
|
||||
# error prefix so the frontend shows a retry button.
|
||||
# Check both the error field and content for patterns.
|
||||
if is_transient_api_error(error_text) or is_transient_api_error(
|
||||
error_preview
|
||||
):
|
||||
logger.warning(
|
||||
"%s Transient Anthropic API error detected, "
|
||||
"suppressing raw error text",
|
||||
log_prefix,
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
_append_error_marker(
|
||||
session,
|
||||
FRIENDLY_TRANSIENT_MSG,
|
||||
retryable=True,
|
||||
)
|
||||
yield StreamError(
|
||||
errorText=FRIENDLY_TRANSIENT_MSG,
|
||||
code="transient_api_error",
|
||||
)
|
||||
break
|
||||
|
||||
# Race-condition fix: SDK hooks (PostToolUse) are
|
||||
# executed asynchronously via start_soon() — the next
|
||||
# message can arrive before the hook stashes output.
|
||||
@@ -1212,7 +1290,7 @@ async def stream_chat_completion_sdk(
|
||||
extra,
|
||||
)
|
||||
|
||||
# Log errors being sent to frontend
|
||||
# Persist error markers so they survive page refresh
|
||||
if isinstance(response, StreamError):
|
||||
logger.error(
|
||||
"%s Sending error to frontend: %s (code=%s)",
|
||||
@@ -1220,6 +1298,12 @@ async def stream_chat_completion_sdk(
|
||||
response.errorText,
|
||||
response.code,
|
||||
)
|
||||
_append_error_marker(
|
||||
session,
|
||||
response.errorText,
|
||||
retryable=(response.code == "transient_api_error"),
|
||||
)
|
||||
ended_with_stream_error = True
|
||||
|
||||
yield response
|
||||
|
||||
@@ -1434,14 +1518,18 @@ async def stream_chat_completion_sdk(
|
||||
else:
|
||||
logger.error("%s Error: %s", log_prefix, error_msg, exc_info=True)
|
||||
|
||||
# Append error marker to session (non-invasive text parsing approach)
|
||||
# The finally block will persist the session with this error marker
|
||||
if session:
|
||||
session.messages.append(
|
||||
ChatMessage(
|
||||
role="assistant", content=f"{COPILOT_ERROR_PREFIX} {error_msg}"
|
||||
)
|
||||
)
|
||||
is_transient = is_transient_api_error(error_msg)
|
||||
if is_transient:
|
||||
display_msg, code = FRIENDLY_TRANSIENT_MSG, "transient_api_error"
|
||||
else:
|
||||
display_msg, code = error_msg, "sdk_error"
|
||||
|
||||
# Append error marker to session (non-invasive text parsing approach).
|
||||
# The finally block will persist the session with this error marker.
|
||||
# Skip if a marker was already appended inside the stream loop
|
||||
# (ended_with_stream_error) to avoid duplicate stale markers.
|
||||
if not ended_with_stream_error:
|
||||
_append_error_marker(session, display_msg, retryable=is_transient)
|
||||
logger.debug(
|
||||
"%s Appended error marker, will be persisted in finally",
|
||||
log_prefix,
|
||||
@@ -1453,10 +1541,7 @@ async def stream_chat_completion_sdk(
|
||||
isinstance(e, RuntimeError) and "cancel scope" in str(e)
|
||||
)
|
||||
if not is_cancellation:
|
||||
yield StreamError(
|
||||
errorText=error_msg,
|
||||
code="sdk_error",
|
||||
)
|
||||
yield StreamError(errorText=display_msg, code=code)
|
||||
|
||||
raise
|
||||
finally:
|
||||
|
||||
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from .service import _prepare_file_attachments
|
||||
from .service import _prepare_file_attachments, _resolve_sdk_model
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -288,3 +288,127 @@ class TestPromptSupplement:
|
||||
# Count how many times this tool appears as a bullet point
|
||||
count = docs.count(f"- **`{tool_name}`**")
|
||||
assert count == 1, f"Tool '{tool_name}' appears {count} times (should be 1)"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# Environment variables that the ChatConfig validators read. Tests remove all
# of them so that the values passed explicitly to the constructor are the ones
# that take effect.
# ---------------------------------------------------------------------------
_CONFIG_ENV_VARS = (
    "CHAT_USE_OPENROUTER",
    "CHAT_API_KEY",
    "OPEN_ROUTER_API_KEY",
    "OPENAI_API_KEY",
    "CHAT_BASE_URL",
    "OPENROUTER_BASE_URL",
    "OPENAI_BASE_URL",
    "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
    "CHAT_USE_CLAUDE_AGENT_SDK",
)
|
||||
|
||||
|
||||
@pytest.fixture()
def _clean_config_env(monkeypatch: pytest.MonkeyPatch) -> None:
    """Remove every variable in ``_CONFIG_ENV_VARS`` from the environment.

    Deletion goes through ``monkeypatch``, so the previous environment is
    restored automatically when the test finishes. Variables that are not
    set are ignored (``raising=False``).
    """
    for env_name in _CONFIG_ENV_VARS:
        monkeypatch.delenv(env_name, raising=False)
|
||||
|
||||
|
||||
class TestResolveSdkModel:
    """Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""

    @staticmethod
    def _install_config(
        monkeypatch,
        *,
        use_openrouter,
        api_key,
        base_url,
        model="anthropic/claude-opus-4.6",
        claude_agent_model=None,
        use_claude_code_subscription=False,
    ):
        """Build a ChatConfig from explicit values and patch it into the SDK service.

        The config module is imported lazily, inside the call, exactly as the
        individual tests did before this helper was factored out.
        """
        from backend.copilot import config as cfg_mod

        cfg = cfg_mod.ChatConfig(
            model=model,
            claude_agent_model=claude_agent_model,
            use_openrouter=use_openrouter,
            api_key=api_key,
            base_url=base_url,
            use_claude_code_subscription=use_claude_code_subscription,
        )
        monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)

    def test_openrouter_active_keeps_dots(self, monkeypatch, _clean_config_env):
        """With OpenRouter fully configured, the dotted version is preserved."""
        self._install_config(
            monkeypatch,
            use_openrouter=True,
            api_key="or-key",
            base_url="https://openrouter.ai/api/v1",
        )
        assert _resolve_sdk_model() == "claude-opus-4.6"

    def test_openrouter_disabled_normalizes_to_hyphens(
        self, monkeypatch, _clean_config_env
    ):
        """With OpenRouter disabled, dots in the version become hyphens."""
        self._install_config(
            monkeypatch,
            use_openrouter=False,
            api_key=None,
            base_url=None,
        )
        assert _resolve_sdk_model() == "claude-opus-4-6"

    def test_openrouter_enabled_but_missing_key_normalizes(
        self, monkeypatch, _clean_config_env
    ):
        """OpenRouter enabled without an api_key falls back to direct Anthropic
        and normalizes dots to hyphens."""
        self._install_config(
            monkeypatch,
            use_openrouter=True,
            api_key=None,
            base_url="https://openrouter.ai/api/v1",
        )
        assert _resolve_sdk_model() == "claude-opus-4-6"

    def test_explicit_claude_agent_model_takes_precedence(
        self, monkeypatch, _clean_config_env
    ):
        """An explicitly configured claude_agent_model is returned unchanged."""
        self._install_config(
            monkeypatch,
            claude_agent_model="claude-sonnet-4-5-20250514",
            use_openrouter=True,
            api_key="or-key",
            base_url="https://openrouter.ai/api/v1",
        )
        assert _resolve_sdk_model() == "claude-sonnet-4-5-20250514"

    def test_subscription_mode_returns_none(self, monkeypatch, _clean_config_env):
        """In Claude Code subscription mode the result is None (CLI picks model)."""
        self._install_config(
            monkeypatch,
            use_openrouter=False,
            api_key=None,
            base_url=None,
            use_claude_code_subscription=True,
        )
        assert _resolve_sdk_model() is None

    def test_model_without_provider_prefix(self, monkeypatch, _clean_config_env):
        """A model name with no provider prefix is still normalized correctly."""
        self._install_config(
            monkeypatch,
            model="claude-opus-4.6",
            use_openrouter=False,
            api_key=None,
            base_url=None,
        )
        assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
@@ -4,11 +4,12 @@ These tests verify the complete copilot flow using dummy implementations
|
||||
for agent generator and SDK service, allowing automated testing without
|
||||
external LLM calls.
|
||||
|
||||
Enable test mode with COPILOT_TEST_MODE=true environment variable.
|
||||
Enable test mode with CHAT_TEST_MODE=true environment variable (or in .env).
|
||||
|
||||
Note: StreamFinish is NOT emitted by the dummy service — it is published
|
||||
by mark_session_completed in the processor layer. These tests only cover
|
||||
the service-level streaming output (StreamStart + StreamTextDelta).
|
||||
The dummy service emits the full AI SDK protocol event sequence:
|
||||
StreamStart → StreamStartStep → StreamTextStart → StreamTextDelta(s) →
|
||||
StreamTextEnd → StreamFinishStep → StreamFinish.
|
||||
The processor skips StreamFinish and publishes its own via mark_session_completed.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -20,9 +21,14 @@ import pytest
|
||||
from backend.copilot.model import ChatMessage, ChatSession, upsert_chat_session
|
||||
from backend.copilot.response_model import (
|
||||
StreamError,
|
||||
StreamFinish,
|
||||
StreamFinishStep,
|
||||
StreamHeartbeat,
|
||||
StreamStart,
|
||||
StreamStartStep,
|
||||
StreamTextDelta,
|
||||
StreamTextEnd,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
|
||||
@@ -30,9 +36,9 @@ from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_test_mode():
|
||||
"""Enable test mode for all tests in this module."""
|
||||
os.environ["COPILOT_TEST_MODE"] = "true"
|
||||
os.environ["CHAT_TEST_MODE"] = "true"
|
||||
yield
|
||||
os.environ.pop("COPILOT_TEST_MODE", None)
|
||||
os.environ.pop("CHAT_TEST_MODE", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -110,9 +116,14 @@ async def test_streaming_event_types():
|
||||
):
|
||||
event_types.add(type(event).__name__)
|
||||
|
||||
# Required event types (StreamFinish is published by processor, not service)
|
||||
# Required event types for full AI SDK protocol
|
||||
assert "StreamStart" in event_types, "Missing StreamStart"
|
||||
assert "StreamStartStep" in event_types, "Missing StreamStartStep"
|
||||
assert "StreamTextStart" in event_types, "Missing StreamTextStart"
|
||||
assert "StreamTextDelta" in event_types, "Missing StreamTextDelta"
|
||||
assert "StreamTextEnd" in event_types, "Missing StreamTextEnd"
|
||||
assert "StreamFinishStep" in event_types, "Missing StreamFinishStep"
|
||||
assert "StreamFinish" in event_types, "Missing StreamFinish"
|
||||
|
||||
print(f"✅ Event types: {sorted(event_types)}")
|
||||
|
||||
@@ -175,16 +186,17 @@ async def test_streaming_heartbeat_timing():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_handling():
|
||||
"""Test that errors are properly formatted and sent."""
|
||||
# This would require a dummy that can trigger errors
|
||||
# For now, just verify error event structure
|
||||
|
||||
"""Test that error events have correct SSE structure."""
|
||||
error = StreamError(errorText="Test error", code="test_error")
|
||||
assert error.errorText == "Test error"
|
||||
assert error.code == "test_error"
|
||||
assert str(error.type.value) in ["error", "error"]
|
||||
|
||||
print("✅ Error structure verified")
|
||||
# Verify to_sse() strips code (AI SDK protocol compliance)
|
||||
sse = error.to_sse()
|
||||
assert '"errorText"' in sse
|
||||
assert '"code"' not in sse, "to_sse() must strip code field for AI SDK"
|
||||
|
||||
print("✅ Error structure verified (code stripped in SSE)")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -326,20 +338,85 @@ async def test_stream_completeness():
|
||||
):
|
||||
events.append(event)
|
||||
|
||||
# Check for required events (StreamFinish is published by processor)
|
||||
has_start = any(isinstance(e, StreamStart) for e in events)
|
||||
has_text = any(isinstance(e, StreamTextDelta) for e in events)
|
||||
|
||||
assert has_start, "Stream must include StreamStart"
|
||||
assert has_text, "Stream must include text deltas"
|
||||
# Check for all required event types
|
||||
assert any(isinstance(e, StreamStart) for e in events), "Missing StreamStart"
|
||||
assert any(
|
||||
isinstance(e, StreamStartStep) for e in events
|
||||
), "Missing StreamStartStep"
|
||||
assert any(
|
||||
isinstance(e, StreamTextStart) for e in events
|
||||
), "Missing StreamTextStart"
|
||||
assert any(
|
||||
isinstance(e, StreamTextDelta) for e in events
|
||||
), "Missing StreamTextDelta"
|
||||
assert any(isinstance(e, StreamTextEnd) for e in events), "Missing StreamTextEnd"
|
||||
assert any(
|
||||
isinstance(e, StreamFinishStep) for e in events
|
||||
), "Missing StreamFinishStep"
|
||||
assert any(isinstance(e, StreamFinish) for e in events), "Missing StreamFinish"
|
||||
|
||||
# Verify exactly one start
|
||||
start_count = sum(1 for e in events if isinstance(e, StreamStart))
|
||||
assert start_count == 1, f"Should have exactly 1 StreamStart, got {start_count}"
|
||||
|
||||
print(
|
||||
f"✅ Completeness: 1 start, {sum(1 for e in events if isinstance(e, StreamTextDelta))} text deltas"
|
||||
)
|
||||
print(f"✅ Completeness: {len(events)} events, full protocol sequence")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_transient_error_shows_retryable():
    """Check that __test_transient_error__ streams partial text, then a retryable StreamError."""
    events = [
        event
        async for event in stream_chat_completion_dummy(
            session_id="test-transient",
            message="please fail __test_transient_error__",
            is_user_message=True,
            user_id="test-user",
        )
    ]

    # The stream opens with StreamStart.
    assert isinstance(events[0], StreamStart)

    # Some partial text is delivered before the failure.
    text_events = [e for e in events if isinstance(e, StreamTextDelta)]
    assert len(text_events) > 0, "Should stream partial text before error"

    # Exactly one StreamError is emitted, carrying the transient code and the
    # user-facing "connection interrupted" wording.
    error_events = [e for e in events if isinstance(e, StreamError)]
    assert len(error_events) == 1, "Should have exactly one StreamError"
    stream_error = error_events[0]
    assert stream_error.code == "transient_api_error"
    assert "connection interrupted" in stream_error.errorText.lower()

    print(f"✅ Transient error: {len(text_events)} partial deltas + retryable error")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fatal_error_not_retryable():
    """Check that __test_fatal_error__ produces a StreamError whose code is not retryable."""
    events = [
        event
        async for event in stream_chat_completion_dummy(
            session_id="test-fatal",
            message="__test_fatal_error__",
            is_user_message=True,
            user_id="test-user",
        )
    ]

    assert isinstance(events[0], StreamStart)

    # A single StreamError with the sdk_error code (i.e. not the transient one).
    error_events = [e for e in events if isinstance(e, StreamError)]
    assert len(error_events) == 1
    fatal_code = error_events[0].code
    assert fatal_code == "sdk_error"
    assert "transient" not in fatal_code

    # Fatal errors fail immediately, so no text deltas may appear.
    text_events = [e for e in events if isinstance(e, StreamTextDelta)]
    assert len(text_events) == 0, "Fatal error should not stream any text"

    print("✅ Fatal error: immediate error, no partial text")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -395,6 +472,8 @@ if __name__ == "__main__":
|
||||
asyncio.run(test_message_deduplication())
|
||||
asyncio.run(test_event_ordering())
|
||||
asyncio.run(test_stream_completeness())
|
||||
asyncio.run(test_transient_error_shows_retryable())
|
||||
asyncio.run(test_fatal_error_not_retryable())
|
||||
asyncio.run(test_text_delta_consistency())
|
||||
|
||||
print("=" * 60)
|
||||
|
||||
@@ -49,6 +49,20 @@ export const ChatContainer = ({
|
||||
!!isSessionError;
|
||||
const inputLayoutId = "copilot-2-chat-input";
|
||||
|
||||
// Retry: re-send the text of the most recent user message. Passed down as
// onRetry so ErrorCard can offer it on transient errors.
function handleRetry() {
  // Walk backwards to find the latest message authored by the user.
  let lastUserMsg: (typeof messages)[number] | undefined;
  for (let i = messages.length - 1; i >= 0; i--) {
    if (messages[i].role === "user") {
      lastUserMsg = messages[i];
      break;
    }
  }
  if (!lastUserMsg) return;

  // Concatenate its text parts; non-text parts are ignored.
  let lastText = "";
  for (const part of lastUserMsg.parts) {
    if (part.type === "text") {
      lastText += part.text;
    }
  }

  // Nothing is sent when the message carried no text.
  if (lastText) {
    onSend(lastText);
  }
}
|
||||
|
||||
return (
|
||||
<CopilotChatActionsProvider onSend={onSend}>
|
||||
<LayoutGroup id="copilot-2-chat-layout">
|
||||
@@ -61,6 +75,7 @@ export const ChatContainer = ({
|
||||
error={error}
|
||||
isLoading={isLoadingSession}
|
||||
sessionID={sessionId}
|
||||
onRetry={handleRetry}
|
||||
/>
|
||||
<motion.div
|
||||
initial={{ opacity: 0 }}
|
||||
|
||||
@@ -31,11 +31,13 @@ interface Props {
|
||||
error: Error | undefined;
|
||||
isLoading: boolean;
|
||||
sessionID?: string | null;
|
||||
onRetry?: () => void;
|
||||
}
|
||||
|
||||
function renderSegments(
|
||||
segments: RenderSegment[],
|
||||
messageID: string,
|
||||
onRetry?: () => void,
|
||||
): React.ReactNode[] {
|
||||
return segments.map((seg, segIdx) => {
|
||||
if (seg.kind === "collapsed-group") {
|
||||
@@ -47,6 +49,7 @@ function renderSegments(
|
||||
part={seg.part}
|
||||
messageID={messageID}
|
||||
partIndex={seg.index}
|
||||
onRetry={onRetry}
|
||||
/>
|
||||
);
|
||||
});
|
||||
@@ -102,6 +105,7 @@ export function ChatMessagesContainer({
|
||||
error,
|
||||
isLoading,
|
||||
sessionID,
|
||||
onRetry,
|
||||
}: Props) {
|
||||
const lastMessage = messages[messages.length - 1];
|
||||
const graphExecId = useMemo(() => extractGraphExecId(messages), [messages]);
|
||||
@@ -161,9 +165,12 @@ export function ChatMessagesContainer({
|
||||
(p): p is Extract<typeof p, { type: "text" }> => p.type === "text",
|
||||
);
|
||||
const lastTextPart = textParts[textParts.length - 1];
|
||||
const markerType =
|
||||
lastTextPart !== undefined
|
||||
? parseSpecialMarkers(lastTextPart.text).markerType
|
||||
: null;
|
||||
const hasErrorMarker =
|
||||
lastTextPart !== undefined &&
|
||||
parseSpecialMarkers(lastTextPart.text).markerType === "error";
|
||||
markerType === "error" || markerType === "retryable_error";
|
||||
const showActions =
|
||||
isLastInTurn &&
|
||||
!isCurrentlyStreaming &&
|
||||
@@ -209,13 +216,18 @@ export function ChatMessagesContainer({
|
||||
</ReasoningCollapse>
|
||||
)}
|
||||
{responseSegments
|
||||
? renderSegments(responseSegments, message.id)
|
||||
? renderSegments(
|
||||
responseSegments,
|
||||
message.id,
|
||||
isLastAssistant ? onRetry : undefined,
|
||||
)
|
||||
: message.parts.map((part, i) => (
|
||||
<MessagePartRenderer
|
||||
key={`${message.id}-${i}`}
|
||||
part={part}
|
||||
messageID={message.id}
|
||||
partIndex={i}
|
||||
onRetry={isLastAssistant ? onRetry : undefined}
|
||||
/>
|
||||
))}
|
||||
{isLastInTurn && !isCurrentlyStreaming && (
|
||||
|
||||
@@ -69,9 +69,15 @@ interface Props {
|
||||
part: UIMessage<unknown, UIDataTypes, UITools>["parts"][number];
|
||||
messageID: string;
|
||||
partIndex: number;
|
||||
onRetry?: () => void;
|
||||
}
|
||||
|
||||
export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
export function MessagePartRenderer({
|
||||
part,
|
||||
messageID,
|
||||
partIndex,
|
||||
onRetry,
|
||||
}: Props) {
|
||||
const key = `${messageID}-${partIndex}`;
|
||||
|
||||
switch (part.type) {
|
||||
@@ -80,7 +86,7 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
part.text,
|
||||
);
|
||||
|
||||
if (markerType === "error") {
|
||||
if (markerType === "error" || markerType === "retryable_error") {
|
||||
const lowerMarker = markerText.toLowerCase();
|
||||
const isCancellation =
|
||||
lowerMarker === "operation cancelled" ||
|
||||
@@ -100,6 +106,7 @@ export function MessagePartRenderer({ part, messageID, partIndex }: Props) {
|
||||
key={key}
|
||||
responseError={{ message: markerText }}
|
||||
context="execution"
|
||||
onRetry={markerType === "retryable_error" ? onRetry : undefined}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -172,16 +172,22 @@ export function getTurnMessages(
|
||||
// The hex suffix makes it virtually impossible for an LLM to accidentally
|
||||
// produce these strings in normal conversation.
|
||||
const COPILOT_ERROR_PREFIX = "[__COPILOT_ERROR_f7a1__]";
|
||||
const COPILOT_RETRYABLE_ERROR_PREFIX = "[__COPILOT_RETRYABLE_ERROR_a9c2__]";
|
||||
const COPILOT_SYSTEM_PREFIX = "[__COPILOT_SYSTEM_e3b0__]";
|
||||
|
||||
export type MarkerType = "error" | "system" | null;
|
||||
export type MarkerType = "error" | "retryable_error" | "system" | null;
|
||||
|
||||
/**
 * Return `s` with every regex metacharacter backslash-escaped, so the result
 * can be embedded in a `RegExp` source and match the input literally.
 */
function escapeRegExp(s: string): string {
  return s.replace(/[.*+?^${}()|[\]\\]/g, (ch) => `\\${ch}`);
}
|
||||
|
||||
// Pre-compiled marker regexes (avoids re-creating on every call / render)
|
||||
// Pre-compiled marker regexes (avoids re-creating on every call / render).
|
||||
// Retryable check must come first since it's more specific.
|
||||
const RETRYABLE_ERROR_MARKER_RE = new RegExp(
|
||||
`${escapeRegExp(COPILOT_RETRYABLE_ERROR_PREFIX)}\\s*(.+?)$`,
|
||||
"s",
|
||||
);
|
||||
const ERROR_MARKER_RE = new RegExp(
|
||||
`${escapeRegExp(COPILOT_ERROR_PREFIX)}\\s*(.+?)$`,
|
||||
"s",
|
||||
@@ -196,6 +202,15 @@ export function parseSpecialMarkers(text: string): {
|
||||
markerText: string;
|
||||
cleanText: string;
|
||||
} {
|
||||
const retryableMatch = text.match(RETRYABLE_ERROR_MARKER_RE);
|
||||
if (retryableMatch) {
|
||||
return {
|
||||
markerType: "retryable_error",
|
||||
markerText: retryableMatch[1].trim(),
|
||||
cleanText: text.replace(retryableMatch[0], "").trim(),
|
||||
};
|
||||
}
|
||||
|
||||
const errorMatch = text.match(ERROR_MARKER_RE);
|
||||
if (errorMatch) {
|
||||
return {
|
||||
|
||||
@@ -222,12 +222,17 @@ export function useCopilotStream({
|
||||
return;
|
||||
}
|
||||
|
||||
// Only reconnect on network errors (not HTTP errors), and never
|
||||
// reconnect when the user explicitly stopped the stream.
|
||||
// Reconnect on network errors or transient API errors so the
|
||||
// persisted retryable-error marker is loaded and the "Try Again"
|
||||
// button appears. Without this, transient errors only show in the
|
||||
// onError callback (where StreamError strips the retryable prefix).
|
||||
if (isUserStoppingRef.current) return;
|
||||
const isNetworkError =
|
||||
error.name === "TypeError" || error.name === "AbortError";
|
||||
if (isNetworkError) {
|
||||
const isTransientApiError = errorDetail.includes(
|
||||
"connection interrupted",
|
||||
);
|
||||
if (isNetworkError || isTransientApiError) {
|
||||
handleReconnect(sessionId);
|
||||
}
|
||||
},
|
||||
|
||||
Reference in New Issue
Block a user