mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(copilot): address all review items on PR #12623
Blockers: - B1: use config.model (opus) for baseline unless mode='fast' explicitly downgrades — prevents silent quality drop for all baseline users - B2: extract _resolve_use_sdk / _resolve_effective_mode to production code in processor.py; tests now exercise real routing logic - B3: add real unit tests covering transcript download/validate/load/ append/backfill/upload round-trip via new service helpers Should-fix: - S1: server-side CHAT_MODE_OPTION flag gate in processor blocks unauthorised users from bypassing the UI toggle via crafted requests - S2: frontend gates copilotMode behind the CHAT_MODE_OPTION flag in useCopilotPage so stale localStorage values aren't sent when off - S3: Pydantic validation tests for StreamChatRequest.mode literal - S4: extract _load_prior_transcript / _upload_final_transcript helpers and split _baseline_conversation_updater into message-mutation and transcript-recording concerns Nice-to-have + Nits: - N1: parallelize GCS transcript content + metadata downloads via gather - N3: add role='switch' and aria-checked to mode toggle button - N5: toast notification when user toggles mode - Nit1: drop export on ChatInput Props - Nit2: inline dismissRateLimit, drop unused useCallback import - Nit3: replace 'end_turn'/'tool_use' magic strings with constants - Nit4: log malformed tool_call JSON parse errors at debug level
This commit is contained in:
@@ -541,3 +541,41 @@ def test_create_session_rejects_nested_metadata(
|
||||
)
|
||||
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
class TestStreamChatRequestModeValidation:
|
||||
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
|
||||
|
||||
def test_rejects_invalid_mode_value(self) -> None:
|
||||
"""Any string outside the Literal set must raise ValidationError."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
|
||||
|
||||
def test_accepts_fast_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="fast")
|
||||
assert req.mode == "fast"
|
||||
|
||||
def test_accepts_extended_thinking_mode(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode="extended_thinking")
|
||||
assert req.mode == "extended_thinking"
|
||||
|
||||
def test_accepts_none_mode(self) -> None:
|
||||
"""``mode=None`` is valid (server decides via feature flags)."""
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi", mode=None)
|
||||
assert req.mode is None
|
||||
|
||||
def test_mode_defaults_to_none_when_omitted(self) -> None:
|
||||
from backend.api.features.chat.routes import StreamChatRequest
|
||||
|
||||
req = StreamChatRequest(message="hi")
|
||||
assert req.mode is None
|
||||
|
||||
@@ -12,7 +12,7 @@ import uuid
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
@@ -53,6 +53,8 @@ from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.copilot.tools import execute_tool, get_available_tools
|
||||
from backend.copilot.tracking import track_user_message
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
download_transcript,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
@@ -80,6 +82,19 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
def _resolve_baseline_model(mode: Literal["fast", "extended_thinking"] | None) -> str:
|
||||
"""Pick the model for the baseline path based on the per-request mode.
|
||||
|
||||
Only ``mode='fast'`` downgrades to the cheaper/faster model. Any other
|
||||
value (including ``None`` and ``'extended_thinking'``) preserves the
|
||||
default model so that users who never select a mode don't get
|
||||
silently moved to the cheaper tier.
|
||||
"""
|
||||
if mode == "fast":
|
||||
return config.fast_model
|
||||
return config.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class _BaselineStreamState:
|
||||
"""Mutable state shared between the tool-call loop callbacks.
|
||||
@@ -88,6 +103,7 @@ class _BaselineStreamState:
|
||||
can be module-level functions instead of deeply nested closures.
|
||||
"""
|
||||
|
||||
model: str = ""
|
||||
pending_events: list[StreamBaseResponse] = field(default_factory=list)
|
||||
assistant_text: str = ""
|
||||
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
@@ -115,7 +131,7 @@ async def _baseline_llm_caller(
|
||||
if tools:
|
||||
typed_tools = cast(list[ChatCompletionToolParam], tools)
|
||||
response = await client.chat.completions.create(
|
||||
model=config.fast_model,
|
||||
model=state.model,
|
||||
messages=typed_messages,
|
||||
tools=typed_tools,
|
||||
stream=True,
|
||||
@@ -123,7 +139,7 @@ async def _baseline_llm_caller(
|
||||
)
|
||||
else:
|
||||
response = await client.chat.completions.create(
|
||||
model=config.fast_model,
|
||||
model=state.model,
|
||||
messages=typed_messages,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
@@ -285,20 +301,17 @@ async def _baseline_tool_executor(
|
||||
)
|
||||
|
||||
|
||||
def _baseline_conversation_updater(
|
||||
def _mutate_openai_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
*,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
model: str = "",
|
||||
tool_results: list[ToolCallResult] | None,
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
"""Append assistant / tool-result entries to the OpenAI message list.
|
||||
|
||||
Extracted from ``stream_chat_completion_baseline`` for readability.
|
||||
This is the side-effect boundary for the next LLM call — no transcript
|
||||
mutation happens here.
|
||||
"""
|
||||
if tool_results:
|
||||
# Build assistant message with tool_calls
|
||||
assistant_msg: dict[str, Any] = {"role": "assistant"}
|
||||
if response.response_text:
|
||||
assistant_msg["content"] = response.response_text
|
||||
@@ -311,14 +324,46 @@ def _baseline_conversation_updater(
|
||||
for tc in response.tool_calls
|
||||
]
|
||||
messages.append(assistant_msg)
|
||||
# Record assistant message (with tool_calls) to transcript
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
elif response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
|
||||
|
||||
def _record_turn_to_transcript(
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None,
|
||||
*,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
model: str,
|
||||
) -> None:
|
||||
"""Append assistant + tool-result entries to the transcript builder.
|
||||
|
||||
Kept separate from :func:`_mutate_openai_messages` so the two
|
||||
concerns (next-LLM-call payload vs. durable conversation log) can
|
||||
evolve independently.
|
||||
"""
|
||||
if tool_results:
|
||||
content_blocks: list[dict[str, Any]] = []
|
||||
if response.response_text:
|
||||
content_blocks.append({"type": "text", "text": response.response_text})
|
||||
for tc in response.tool_calls:
|
||||
try:
|
||||
args = orjson.loads(tc.arguments) if tc.arguments else {}
|
||||
except Exception:
|
||||
except Exception as parse_err:
|
||||
logger.debug(
|
||||
"[Baseline] Failed to parse tool_call arguments "
|
||||
"(tool=%s, id=%s): %s",
|
||||
tc.name,
|
||||
tc.id,
|
||||
parse_err,
|
||||
)
|
||||
args = {}
|
||||
content_blocks.append(
|
||||
{
|
||||
@@ -332,16 +377,9 @@ def _baseline_conversation_updater(
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=content_blocks,
|
||||
model=model,
|
||||
stop_reason="tool_use",
|
||||
stop_reason=STOP_REASON_TOOL_USE,
|
||||
)
|
||||
for tr in tool_results:
|
||||
messages.append(
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": tr.tool_call_id,
|
||||
"content": tr.content,
|
||||
}
|
||||
)
|
||||
# Record tool result to transcript AFTER the assistant tool_use
|
||||
# block to maintain correct Anthropic API ordering:
|
||||
# assistant(tool_use) → user(tool_result)
|
||||
@@ -349,15 +387,34 @@ def _baseline_conversation_updater(
|
||||
tool_use_id=tr.tool_call_id,
|
||||
content=tr.content,
|
||||
)
|
||||
else:
|
||||
if response.response_text:
|
||||
messages.append({"role": "assistant", "content": response.response_text})
|
||||
# Record final text to transcript
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": response.response_text}],
|
||||
model=model,
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
elif response.response_text:
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": response.response_text}],
|
||||
model=model,
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
|
||||
def _baseline_conversation_updater(
|
||||
messages: list[dict[str, Any]],
|
||||
response: LLMLoopResponse,
|
||||
tool_results: list[ToolCallResult] | None = None,
|
||||
*,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
model: str = "",
|
||||
) -> None:
|
||||
"""Update OpenAI message list with assistant response + tool results.
|
||||
|
||||
Thin composition of :func:`_mutate_openai_messages` and
|
||||
:func:`_record_turn_to_transcript`.
|
||||
"""
|
||||
_mutate_openai_messages(messages, response, tool_results)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=transcript_builder,
|
||||
model=model,
|
||||
)
|
||||
|
||||
|
||||
async def _update_title_async(
|
||||
@@ -374,6 +431,7 @@ async def _update_title_async(
|
||||
|
||||
async def _compress_session_messages(
|
||||
messages: list[ChatMessage],
|
||||
model: str,
|
||||
) -> list[ChatMessage]:
|
||||
"""Compress session messages if they exceed the model's token limit.
|
||||
|
||||
@@ -395,14 +453,14 @@ async def _compress_session_messages(
|
||||
try:
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.fast_model,
|
||||
model=model,
|
||||
client=_get_openai_client(),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
|
||||
result = await compress_context(
|
||||
messages=messages_dict,
|
||||
model=config.fast_model,
|
||||
model=model,
|
||||
client=None,
|
||||
)
|
||||
|
||||
@@ -428,12 +486,85 @@ async def _compress_session_messages(
|
||||
return messages
|
||||
|
||||
|
||||
async def _load_prior_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
session_msg_count: int,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
) -> bool:
|
||||
"""Download and load the prior transcript into ``transcript_builder``.
|
||||
|
||||
Returns ``True`` when the loaded transcript fully covers the session
|
||||
prefix; ``False`` otherwise (stale, missing, invalid, or download
|
||||
error). Callers should suppress uploads when this returns ``False``
|
||||
to avoid overwriting a more complete version in storage.
|
||||
"""
|
||||
try:
|
||||
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Transcript download failed: %s", e)
|
||||
return False
|
||||
|
||||
if dl is None:
|
||||
logger.debug("[Baseline] No transcript available")
|
||||
return False
|
||||
|
||||
if not validate_transcript(dl.content):
|
||||
logger.warning("[Baseline] Downloaded transcript but invalid")
|
||||
return False
|
||||
|
||||
# Reject stale transcripts: if msg_count is known and doesn't cover
|
||||
# the current session, loading it would silently drop intermediate
|
||||
# turns from the transcript.
|
||||
if dl.message_count and dl.message_count < session_msg_count - 1:
|
||||
logger.warning(
|
||||
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
|
||||
dl.message_count,
|
||||
session_msg_count,
|
||||
)
|
||||
return False
|
||||
|
||||
transcript_builder.load_previous(dl.content, log_prefix="[Baseline]")
|
||||
logger.info(
|
||||
"[Baseline] Loaded transcript: %dB, msg_count=%d",
|
||||
len(dl.content),
|
||||
dl.message_count,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def _upload_final_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
transcript_builder: TranscriptBuilder,
|
||||
session_msg_count: int,
|
||||
) -> None:
|
||||
"""Serialize and upload the transcript for next-turn continuity."""
|
||||
try:
|
||||
content = transcript_builder.to_jsonl()
|
||||
if content and validate_transcript(content):
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=content,
|
||||
message_count=session_msg_count,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.debug("[Baseline] No valid transcript to upload")
|
||||
except Exception as upload_err:
|
||||
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
|
||||
|
||||
|
||||
async def stream_chat_completion_baseline(
|
||||
session_id: str,
|
||||
message: str | None = None,
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
mode: Literal["fast", "extended_thinking"] | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Baseline LLM with tool calling via OpenAI-compatible API.
|
||||
@@ -462,43 +593,21 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
session = await upsert_chat_session(session)
|
||||
|
||||
# Select model based on the per-request mode. 'fast' downgrades to
|
||||
# the cheaper/faster model; everything else keeps the default.
|
||||
active_model = _resolve_baseline_model(mode)
|
||||
|
||||
# --- Transcript support (feature parity with SDK path) ---
|
||||
transcript_builder = TranscriptBuilder()
|
||||
transcript_covers_prefix = True
|
||||
|
||||
if user_id and len(session.messages) > 1:
|
||||
try:
|
||||
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
|
||||
if dl and validate_transcript(dl.content):
|
||||
# Reject stale transcripts: if msg_count is known and
|
||||
# doesn't cover the current session, loading it would
|
||||
# silently drop intermediate turns from the transcript.
|
||||
session_msg_count = len(session.messages)
|
||||
if dl.message_count and dl.message_count < session_msg_count - 1:
|
||||
logger.warning(
|
||||
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
|
||||
dl.message_count,
|
||||
session_msg_count,
|
||||
)
|
||||
transcript_covers_prefix = False
|
||||
else:
|
||||
transcript_builder.load_previous(
|
||||
dl.content, log_prefix="[Baseline]"
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] Loaded transcript: %dB, msg_count=%d",
|
||||
len(dl.content),
|
||||
dl.message_count,
|
||||
)
|
||||
elif dl:
|
||||
logger.warning("[Baseline] Downloaded transcript but invalid")
|
||||
transcript_covers_prefix = False
|
||||
else:
|
||||
logger.debug("[Baseline] No transcript available")
|
||||
transcript_covers_prefix = False
|
||||
except Exception as e:
|
||||
logger.warning("[Baseline] Transcript download failed: %s", e)
|
||||
transcript_covers_prefix = False
|
||||
transcript_covers_prefix = await _load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
session_msg_count=len(session.messages),
|
||||
transcript_builder=transcript_builder,
|
||||
)
|
||||
|
||||
# Append user message to transcript.
|
||||
# Always append when the message is present and is from the user,
|
||||
@@ -540,7 +649,9 @@ async def stream_chat_completion_baseline(
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
# Compress context if approaching the model's token limit
|
||||
messages_for_context = await _compress_session_messages(session.messages)
|
||||
messages_for_context = await _compress_session_messages(
|
||||
session.messages, model=active_model
|
||||
)
|
||||
|
||||
# Build OpenAI message list from session history.
|
||||
# Include tool_calls on assistant messages and tool-role results so the
|
||||
@@ -590,7 +701,7 @@ async def stream_chat_completion_baseline(
|
||||
logger.warning("[Baseline] Langfuse trace context setup failed")
|
||||
|
||||
_stream_error = False # Track whether an error occurred during streaming
|
||||
state = _BaselineStreamState()
|
||||
state = _BaselineStreamState(model=active_model)
|
||||
|
||||
# Bind extracted module-level callbacks to this request's state/session
|
||||
# using functools.partial so they satisfy the Protocol signatures.
|
||||
@@ -602,7 +713,7 @@ async def stream_chat_completion_baseline(
|
||||
_bound_conversation_updater = partial(
|
||||
_baseline_conversation_updater,
|
||||
transcript_builder=transcript_builder,
|
||||
model=config.fast_model,
|
||||
model=active_model,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -681,10 +792,10 @@ async def stream_chat_completion_baseline(
|
||||
and not (_stream_error and not state.assistant_text)
|
||||
):
|
||||
state.turn_prompt_tokens = max(
|
||||
estimate_token_count(openai_messages, model=config.fast_model), 1
|
||||
estimate_token_count(openai_messages, model=active_model), 1
|
||||
)
|
||||
state.turn_completion_tokens = estimate_token_count_str(
|
||||
state.assistant_text, model=config.fast_model
|
||||
state.assistant_text, model=active_model
|
||||
)
|
||||
logger.info(
|
||||
"[Baseline] No streaming usage reported; estimated tokens: "
|
||||
@@ -724,27 +835,17 @@ async def stream_chat_completion_baseline(
|
||||
if transcript_builder.last_entry_type != "assistant":
|
||||
transcript_builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": state.assistant_text}],
|
||||
model=config.fast_model,
|
||||
stop_reason="end_turn",
|
||||
model=active_model,
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
if user_id and transcript_covers_prefix:
|
||||
try:
|
||||
_transcript_content = transcript_builder.to_jsonl()
|
||||
if _transcript_content and validate_transcript(_transcript_content):
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=_transcript_content,
|
||||
message_count=len(session.messages),
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.debug("[Baseline] No valid transcript to upload")
|
||||
except Exception as upload_err:
|
||||
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
|
||||
await _upload_final_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
transcript_builder=transcript_builder,
|
||||
session_msg_count=len(session.messages),
|
||||
)
|
||||
|
||||
# Yield usage and finish AFTER try/finally (not inside finally).
|
||||
# PEP 525 prohibits yielding from finally in async generators during
|
||||
|
||||
@@ -1,19 +1,34 @@
|
||||
"""Unit tests for baseline transcript integration logic.
|
||||
"""Integration tests for baseline transcript flow.
|
||||
|
||||
Tests cover stale transcript detection, transcript_covers_prefix gating,
|
||||
and partial backfill on stream error — without API keys or network access.
|
||||
Exercises the real helpers in ``baseline/service.py`` that download,
|
||||
validate, load, append to, backfill, and upload the transcript.
|
||||
Storage is mocked via ``download_transcript`` / ``upload_transcript``
|
||||
patches; no network access is required.
|
||||
"""
|
||||
|
||||
import json as stdlib_json
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.transcript import TranscriptDownload, validate_transcript
|
||||
from backend.copilot.baseline.service import (
|
||||
_load_prior_transcript,
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
)
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
)
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
def _make_transcript_content(*roles: str) -> str:
|
||||
"""Build minimal valid JSONL transcript from role names."""
|
||||
"""Build a minimal valid JSONL transcript from role names."""
|
||||
lines = []
|
||||
parent = ""
|
||||
for i, role in enumerate(roles):
|
||||
@@ -30,169 +45,413 @@ def _make_transcript_content(*roles: str) -> str:
|
||||
if role == "assistant":
|
||||
entry["message"]["id"] = f"msg_{i}"
|
||||
entry["message"]["model"] = "test-model"
|
||||
entry["message"]["type"] = "message"
|
||||
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
|
||||
lines.append(stdlib_json.dumps(entry))
|
||||
parent = uid
|
||||
return "\n".join(lines) + "\n"
|
||||
|
||||
|
||||
class TestStaleTranscriptDetection:
|
||||
"""Tests for the stale-transcript detection logic in the baseline service.
|
||||
class TestResolveBaselineModel:
|
||||
"""Model selection honours the per-request mode."""
|
||||
|
||||
When the downloaded transcript's message_count is behind the session's
|
||||
message count, the transcript is considered stale and skipped.
|
||||
"""
|
||||
def test_fast_mode_selects_fast_model(self):
|
||||
assert _resolve_baseline_model("fast") == config.fast_model
|
||||
|
||||
def test_stale_transcript_detected(self):
|
||||
"""Transcript with fewer messages than the session is flagged as stale."""
|
||||
dl = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=2,
|
||||
def test_extended_thinking_selects_default_model(self):
|
||||
assert _resolve_baseline_model("extended_thinking") == config.model
|
||||
|
||||
def test_none_mode_selects_default_model(self):
|
||||
"""Critical: baseline users without a mode MUST keep the default (opus)."""
|
||||
assert _resolve_baseline_model(None) == config.model
|
||||
|
||||
def test_default_and_fast_models_differ(self):
|
||||
"""Sanity: the two tiers are actually distinct in production config."""
|
||||
assert config.model != config.fast_model
|
||||
|
||||
|
||||
class TestLoadPriorTranscript:
|
||||
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_loads_fresh_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rejects_stale_transcript(self):
|
||||
"""msg_count strictly less than session-1 is treated as stale."""
|
||||
builder = TranscriptBuilder()
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
# session has 6 messages, transcript only covers 2 → stale.
|
||||
download = TranscriptDownload(content=content, message_count=2)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=6,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_transcript_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content='{"type":"progress","uuid":"a"}\n',
|
||||
message_count=1,
|
||||
)
|
||||
session_msg_count = 6 # 4 new messages since transcript was uploaded
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
is_stale = dl.message_count and dl.message_count < session_msg_count - 1
|
||||
assert is_stale
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
def test_fresh_transcript_accepted(self):
|
||||
"""Transcript covering the session prefix is not flagged as stale."""
|
||||
dl = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=4,
|
||||
)
|
||||
session_msg_count = 5 # Only 1 new message (the user turn just sent)
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_exception_returns_false(self):
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("boom")),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=2,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
is_stale = dl.message_count and dl.message_count < session_msg_count - 1
|
||||
assert not is_stale
|
||||
assert covers is False
|
||||
assert builder.is_empty
|
||||
|
||||
def test_zero_message_count_not_stale(self):
|
||||
"""When message_count is 0 (unknown), staleness check is skipped."""
|
||||
dl = TranscriptDownload(
|
||||
@pytest.mark.asyncio
|
||||
async def test_zero_message_count_not_stale(self):
|
||||
"""When msg_count is 0 (unknown), staleness check is skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
download = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=0,
|
||||
)
|
||||
session_msg_count = 10
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=20,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
is_stale = dl.message_count and dl.message_count < session_msg_count - 1
|
||||
assert not is_stale # 0 is falsy, so check is skipped
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
|
||||
|
||||
class TestTranscriptCoversPrefix:
|
||||
"""Tests for transcript_covers_prefix gating in the baseline upload path.
|
||||
class TestUploadFinalTranscript:
|
||||
"""``_upload_final_transcript`` serialises and calls storage."""
|
||||
|
||||
When transcript_covers_prefix is False, the transcript is NOT uploaded
|
||||
to avoid overwriting a more complete version in storage.
|
||||
"""
|
||||
|
||||
def test_no_download_sets_covers_false(self):
|
||||
"""When no transcript is available, covers_prefix should be False."""
|
||||
dl = None
|
||||
transcript_covers_prefix = dl is not None
|
||||
assert not transcript_covers_prefix
|
||||
|
||||
def test_invalid_transcript_sets_covers_false(self):
|
||||
"""When downloaded transcript fails validation, covers_prefix is False."""
|
||||
content = '{"type":"progress","uuid":"a"}\n'
|
||||
assert not validate_transcript(content)
|
||||
|
||||
def test_valid_transcript_sets_covers_true(self):
|
||||
"""When downloaded transcript is valid and fresh, covers_prefix is True."""
|
||||
content = _make_transcript_content("user", "assistant")
|
||||
assert validate_transcript(content)
|
||||
|
||||
|
||||
class TestPartialBackfill:
|
||||
"""Tests for partial backfill of assistant text on stream error.
|
||||
|
||||
When the stream aborts mid-round, the conversation updater may not have
|
||||
recorded the partial assistant text. The finally block backfills it
|
||||
so mode-switching after a failed turn doesn't lose the partial response.
|
||||
"""
|
||||
|
||||
def test_backfill_appends_when_last_entry_not_assistant(self):
|
||||
"""When the last transcript entry is not an assistant, backfill appends."""
|
||||
@pytest.mark.asyncio
|
||||
async def test_uploads_valid_transcript(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("user question")
|
||||
|
||||
assistant_text = "partial response before error"
|
||||
assert builder.last_entry_type != "assistant"
|
||||
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": assistant_text}],
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason="end_turn",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=2,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
call_kwargs = upload_mock.await_args.kwargs
|
||||
assert call_kwargs["user_id"] == "user-1"
|
||||
assert call_kwargs["session_id"] == "session-1"
|
||||
assert call_kwargs["message_count"] == 2
|
||||
assert "hello" in call_kwargs["content"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_upload_when_builder_empty(self):
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=0,
|
||||
)
|
||||
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_swallows_upload_exceptions(self):
|
||||
"""Upload failures should not propagate (flow continues for the user)."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
|
||||
):
|
||||
# Should not raise.
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=2,
|
||||
)
|
||||
|
||||
|
||||
class TestRecordTurnToTranscript:
|
||||
"""``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""
|
||||
|
||||
def test_records_final_assistant_text(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text="hello there",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 2
|
||||
assert builder.last_entry_type == "assistant"
|
||||
jsonl = builder.to_jsonl()
|
||||
assert "partial response before error" in jsonl
|
||||
assert "hello there" in jsonl
|
||||
assert STOP_REASON_END_TURN in jsonl
|
||||
|
||||
def test_no_backfill_when_last_entry_is_assistant(self):
|
||||
"""When the conversation updater already recorded the assistant turn,
|
||||
backfill should be skipped (checked via last_entry_type)."""
|
||||
def test_records_tool_use_then_tool_result(self):
|
||||
"""Anthropic ordering: assistant(tool_use) → user(tool_result)."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("user question")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "already recorded"}],
|
||||
builder.append_user(content="use a tool")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[
|
||||
LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
|
||||
],
|
||||
raw_response=None,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
|
||||
]
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
stop_reason="end_turn",
|
||||
)
|
||||
|
||||
assert builder.last_entry_type == "assistant"
|
||||
initial_count = builder.entry_count
|
||||
# user, assistant(tool_use), user(tool_result) = 3 entries
|
||||
assert builder.entry_count == 3
|
||||
jsonl = builder.to_jsonl()
|
||||
assert STOP_REASON_TOOL_USE in jsonl
|
||||
assert "tool_use" in jsonl
|
||||
assert "tool_result" in jsonl
|
||||
assert "call-1" in jsonl
|
||||
|
||||
# Simulating the backfill guard: don't append if already assistant
|
||||
def test_records_nothing_on_empty_response(self):
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 1
|
||||
|
||||
def test_malformed_tool_args_dont_crash(self):
|
||||
"""Bad JSON in tool arguments falls back to {} without raising."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
response = LLMLoopResponse(
|
||||
response_text=None,
|
||||
tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
|
||||
raw_response=None,
|
||||
)
|
||||
tool_results = [
|
||||
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
|
||||
]
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
assert builder.entry_count == 3
|
||||
jsonl = builder.to_jsonl()
|
||||
assert '"input":{}' in jsonl
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
"""End-to-end: load prior → append new turn → upload."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_round_trip(self):
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
|
||||
builder = TranscriptBuilder()
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
assert builder.entry_count == 2
|
||||
|
||||
# New user turn.
|
||||
builder.append_user(content="new question")
|
||||
assert builder.entry_count == 3
|
||||
|
||||
# New assistant turn.
|
||||
response = LLMLoopResponse(
|
||||
response_text="new answer",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
)
|
||||
_record_turn_to_transcript(
|
||||
response,
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
assert builder.entry_count == 4
|
||||
|
||||
# Upload.
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
):
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=4,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "new question" in uploaded
|
||||
assert "new answer" in uploaded
|
||||
# Original content preserved in the round trip.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_backfill_append_guard(self):
|
||||
"""Backfill only runs when the last entry is not already assistant."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
|
||||
# Simulate the backfill guard from stream_chat_completion_baseline.
|
||||
assistant_text = "partial text before error"
|
||||
if builder.last_entry_type != "assistant":
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": assistant_text}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert builder.last_entry_type == "assistant"
|
||||
assert "partial text before error" in builder.to_jsonl()
|
||||
|
||||
# Second invocation: the guard must prevent double-append.
|
||||
initial_count = builder.entry_count
|
||||
if builder.last_entry_type != "assistant":
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "duplicate"}],
|
||||
model="test-model",
|
||||
stop_reason="end_turn",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert builder.entry_count == initial_count
|
||||
|
||||
def test_no_backfill_when_no_assistant_text(self):
|
||||
"""When stream_error is True but no assistant text was produced,
|
||||
no backfill should occur."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user("user question")
|
||||
|
||||
assistant_text = ""
|
||||
_stream_error = True
|
||||
|
||||
# Simulating the guard from service.py:
|
||||
# if _stream_error and state.assistant_text:
|
||||
should_backfill = _stream_error and assistant_text
|
||||
assert not should_backfill
|
||||
|
||||
|
||||
class TestTranscriptUploadGating:
|
||||
"""Tests for the upload gating logic.
|
||||
|
||||
Upload only happens when user_id is set AND transcript_covers_prefix is True.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_skipped_without_user_id(self):
|
||||
"""No upload when user_id is None."""
|
||||
user_id = None
|
||||
transcript_covers_prefix = True
|
||||
|
||||
should_upload = user_id and transcript_covers_prefix
|
||||
assert not should_upload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_skipped_when_prefix_not_covered(self):
|
||||
"""No upload when transcript doesn't cover the session prefix."""
|
||||
user_id = "user-1"
|
||||
transcript_covers_prefix = False
|
||||
|
||||
should_upload = user_id and transcript_covers_prefix
|
||||
assert not should_upload
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_proceeds_when_conditions_met(self):
|
||||
"""Upload proceeds when user_id is set and prefix is covered."""
|
||||
user_id = "user-1"
|
||||
transcript_covers_prefix = True
|
||||
|
||||
should_upload = user_id and transcript_covers_prefix
|
||||
assert should_upload
|
||||
|
||||
@@ -10,6 +10,7 @@ import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
@@ -30,6 +31,57 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
|
||||
|
||||
|
||||
# ============ Mode Routing ============ #
|
||||
|
||||
|
||||
async def _resolve_effective_mode(
|
||||
mode: Literal["fast", "extended_thinking"] | None,
|
||||
user_id: str | None,
|
||||
) -> Literal["fast", "extended_thinking"] | None:
|
||||
"""Strip ``mode`` when the user is not entitled to the toggle.
|
||||
|
||||
The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
|
||||
processor enforces the same gate server-side so an authenticated
|
||||
user cannot bypass the flag by crafting a request directly.
|
||||
"""
|
||||
if mode is None:
|
||||
return None
|
||||
allowed = await is_feature_enabled(
|
||||
Flag.CHAT_MODE_OPTION,
|
||||
user_id or "anonymous",
|
||||
default=False,
|
||||
)
|
||||
if not allowed:
|
||||
logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
|
||||
return None
|
||||
return mode
|
||||
|
||||
|
||||
async def _resolve_use_sdk(
|
||||
mode: Literal["fast", "extended_thinking"] | None,
|
||||
user_id: str | None,
|
||||
*,
|
||||
use_claude_code_subscription: bool,
|
||||
config_default: bool,
|
||||
) -> bool:
|
||||
"""Pick the SDK vs baseline path for a single turn.
|
||||
|
||||
Per-request ``mode`` wins whenever it is set (after the
|
||||
``CHAT_MODE_OPTION`` gate has been applied upstream). Otherwise
|
||||
falls back to the Claude Code subscription override, then the
|
||||
``COPILOT_SDK`` LaunchDarkly flag, then the config default.
|
||||
"""
|
||||
if mode == "fast":
|
||||
return False
|
||||
if mode == "extended_thinking":
|
||||
return True
|
||||
return use_claude_code_subscription or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
user_id or "anonymous",
|
||||
default=config_default,
|
||||
)
|
||||
|
||||
|
||||
# ============ Module Entry Points ============ #
|
||||
|
||||
# Thread-local storage for processor instances
|
||||
@@ -250,23 +302,19 @@ class CoPilotProcessor:
|
||||
if config.test_mode:
|
||||
stream_fn = stream_chat_completion_dummy
|
||||
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
|
||||
effective_mode = None
|
||||
else:
|
||||
# Per-request mode override from the frontend takes priority.
|
||||
# 'fast' → baseline (OpenAI-compatible), 'extended_thinking' → SDK.
|
||||
if entry.mode == "fast":
|
||||
use_sdk = False
|
||||
elif entry.mode == "extended_thinking":
|
||||
use_sdk = True
|
||||
else:
|
||||
# No mode specified — fall back to feature flag / config.
|
||||
use_sdk = (
|
||||
config.use_claude_code_subscription
|
||||
or await is_feature_enabled(
|
||||
Flag.COPILOT_SDK,
|
||||
entry.user_id or "anonymous",
|
||||
default=config.use_claude_agent_sdk,
|
||||
)
|
||||
)
|
||||
# Enforce server-side feature-flag gate so unauthorised
|
||||
# users cannot force a mode by crafting the request.
|
||||
effective_mode = await _resolve_effective_mode(
|
||||
entry.mode, entry.user_id
|
||||
)
|
||||
use_sdk = await _resolve_use_sdk(
|
||||
effective_mode,
|
||||
entry.user_id,
|
||||
use_claude_code_subscription=config.use_claude_code_subscription,
|
||||
config_default=config.use_claude_agent_sdk,
|
||||
)
|
||||
stream_fn = (
|
||||
sdk_service.stream_chat_completion_sdk
|
||||
if use_sdk
|
||||
@@ -274,7 +322,7 @@ class CoPilotProcessor:
|
||||
)
|
||||
log.info(
|
||||
f"Using {'SDK' if use_sdk else 'baseline'} service "
|
||||
f"(mode={entry.mode or 'default'})"
|
||||
f"(mode={effective_mode or 'default'})"
|
||||
)
|
||||
|
||||
# Stream chat completion and publish chunks to Redis.
|
||||
@@ -287,6 +335,7 @@ class CoPilotProcessor:
|
||||
user_id=entry.user_id,
|
||||
context=entry.context,
|
||||
file_ids=entry.file_ids,
|
||||
mode=effective_mode,
|
||||
)
|
||||
async for chunk in stream_registry.stream_and_publish(
|
||||
session_id=entry.session_id,
|
||||
|
||||
@@ -4,104 +4,169 @@ Tests cover the mode→service mapping:
|
||||
- 'fast' → baseline service
|
||||
- 'extended_thinking' → SDK service
|
||||
- None → feature flag / config fallback
|
||||
|
||||
as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
|
||||
the real production helpers from ``processor.py`` so the routing logic
|
||||
has meaningful coverage.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.executor.processor import _resolve_effective_mode, _resolve_use_sdk
|
||||
|
||||
|
||||
def _resolve_use_sdk(
|
||||
mode: Literal["fast", "extended_thinking"] | None,
|
||||
use_claude_code_subscription: bool = False,
|
||||
feature_flag_value: bool = False,
|
||||
config_default: bool = True,
|
||||
) -> bool:
|
||||
"""Replicate the mode-routing logic from CoPilotProcessor._execute_async.
|
||||
|
||||
Extracted here so we can test it in isolation without instantiating the
|
||||
full processor or its event loop.
|
||||
"""
|
||||
if mode == "fast":
|
||||
return False
|
||||
elif mode == "extended_thinking":
|
||||
return True
|
||||
else:
|
||||
return use_claude_code_subscription or feature_flag_value or config_default
|
||||
|
||||
|
||||
class TestModeRouting:
|
||||
class TestResolveUseSdk:
|
||||
"""Tests for the per-request mode routing logic."""
|
||||
|
||||
def test_fast_mode_uses_baseline(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_fast_mode_uses_baseline(self):
|
||||
"""mode='fast' always routes to baseline, regardless of flags."""
|
||||
assert _resolve_use_sdk("fast") is False
|
||||
assert _resolve_use_sdk("fast", use_claude_code_subscription=True) is False
|
||||
assert _resolve_use_sdk("fast", feature_flag_value=True) is False
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
"fast",
|
||||
"user-1",
|
||||
use_claude_code_subscription=True,
|
||||
config_default=True,
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_extended_thinking_uses_sdk(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_extended_thinking_uses_sdk(self):
|
||||
"""mode='extended_thinking' always routes to SDK, regardless of flags."""
|
||||
assert _resolve_use_sdk("extended_thinking") is True
|
||||
assert (
|
||||
_resolve_use_sdk("extended_thinking", use_claude_code_subscription=False)
|
||||
is True
|
||||
)
|
||||
assert _resolve_use_sdk("extended_thinking", config_default=False) is True
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
"extended_thinking",
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=False,
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_none_mode_uses_subscription_override(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_uses_subscription_override(self):
|
||||
"""mode=None with claude_code_subscription=True routes to SDK."""
|
||||
assert (
|
||||
_resolve_use_sdk(
|
||||
None,
|
||||
use_claude_code_subscription=True,
|
||||
feature_flag_value=False,
|
||||
config_default=False,
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=True,
|
||||
config_default=False,
|
||||
)
|
||||
is True
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_none_mode_uses_feature_flag(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_uses_feature_flag(self):
|
||||
"""mode=None with feature flag enabled routes to SDK."""
|
||||
assert (
|
||||
_resolve_use_sdk(
|
||||
None,
|
||||
use_claude_code_subscription=False,
|
||||
feature_flag_value=True,
|
||||
config_default=False,
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
) as flag_mock:
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=False,
|
||||
)
|
||||
is True
|
||||
)
|
||||
is True
|
||||
)
|
||||
flag_mock.assert_awaited_once()
|
||||
|
||||
def test_none_mode_uses_config_default(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_uses_config_default(self):
|
||||
"""mode=None falls back to config.use_claude_agent_sdk."""
|
||||
assert (
|
||||
_resolve_use_sdk(
|
||||
None,
|
||||
use_claude_code_subscription=False,
|
||||
feature_flag_value=False,
|
||||
config_default=True,
|
||||
# When LaunchDarkly returns the default (True), we expect SDK routing.
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=True,
|
||||
)
|
||||
is True
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_none_mode_all_disabled(self):
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_all_disabled(self):
|
||||
"""mode=None with all flags off routes to baseline."""
|
||||
assert (
|
||||
_resolve_use_sdk(
|
||||
None,
|
||||
use_claude_code_subscription=False,
|
||||
feature_flag_value=False,
|
||||
config_default=False,
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
config_default=False,
|
||||
)
|
||||
is False
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
def test_none_mode_precedence_subscription_over_flag(self):
|
||||
"""Claude Code subscription takes priority over feature flag."""
|
||||
assert (
|
||||
_resolve_use_sdk(
|
||||
None,
|
||||
use_claude_code_subscription=True,
|
||||
feature_flag_value=False,
|
||||
config_default=False,
|
||||
|
||||
class TestResolveEffectiveMode:
|
||||
"""Tests for the CHAT_MODE_OPTION server-side gate."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_none_mode_passes_through(self):
|
||||
"""mode=None is returned as-is without a flag check."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
) as flag_mock:
|
||||
assert await _resolve_effective_mode(None, "user-1") is None
|
||||
flag_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mode_stripped_when_flag_disabled(self):
|
||||
"""When CHAT_MODE_OPTION is off, mode is dropped to None."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert await _resolve_effective_mode("fast", "user-1") is None
|
||||
assert await _resolve_effective_mode("extended_thinking", "user-1") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mode_preserved_when_flag_enabled(self):
|
||||
"""When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert await _resolve_effective_mode("fast", "user-1") == "fast"
|
||||
assert (
|
||||
await _resolve_effective_mode("extended_thinking", "user-1")
|
||||
== "extended_thinking"
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_anonymous_user_with_mode(self):
|
||||
"""Anonymous users (user_id=None) still pass through the gate."""
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
) as flag_mock:
|
||||
assert await _resolve_effective_mode("fast", None) is None
|
||||
flag_mock.assert_awaited_once()
|
||||
|
||||
@@ -735,33 +735,44 @@ async def download_transcript(
|
||||
|
||||
Returns a ``TranscriptDownload`` with the JSONL content and the
|
||||
``message_count`` watermark from the upload, or ``None`` if not found.
|
||||
|
||||
The content and metadata fetches run concurrently since they are
|
||||
independent objects in the bucket.
|
||||
"""
|
||||
storage = await get_workspace_storage()
|
||||
path = _build_storage_path(user_id, session_id, storage)
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
|
||||
try:
|
||||
data = await storage.retrieve(path)
|
||||
content = data.decode("utf-8")
|
||||
except FileNotFoundError:
|
||||
content_task = asyncio.create_task(storage.retrieve(path))
|
||||
meta_task = asyncio.create_task(storage.retrieve(meta_path))
|
||||
content_result, meta_result = await asyncio.gather(
|
||||
content_task, meta_task, return_exceptions=True
|
||||
)
|
||||
|
||||
if isinstance(content_result, FileNotFoundError):
|
||||
logger.debug("%s No transcript in storage", log_prefix)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
|
||||
if isinstance(content_result, BaseException):
|
||||
logger.warning(
|
||||
"%s Failed to download transcript: %s", log_prefix, content_result
|
||||
)
|
||||
return None
|
||||
|
||||
# Try to load metadata (best-effort — old transcripts won't have it)
|
||||
content = content_result.decode("utf-8")
|
||||
|
||||
# Metadata is best-effort — old transcripts won't have it.
|
||||
message_count = 0
|
||||
uploaded_at = 0.0
|
||||
try:
|
||||
meta_path = _build_meta_storage_path(user_id, session_id, storage)
|
||||
meta_data = await storage.retrieve(meta_path)
|
||||
meta = json.loads(meta_data.decode("utf-8"), fallback={})
|
||||
if isinstance(meta_result, FileNotFoundError):
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
elif isinstance(meta_result, BaseException):
|
||||
logger.debug(
|
||||
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
|
||||
)
|
||||
else:
|
||||
meta = json.loads(meta_result.decode("utf-8"), fallback={})
|
||||
message_count = meta.get("message_count", 0)
|
||||
uploaded_at = meta.get("uploaded_at", 0.0)
|
||||
except FileNotFoundError:
|
||||
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
|
||||
except Exception as e:
|
||||
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
|
||||
|
||||
logger.info(
|
||||
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
|
||||
@@ -803,6 +814,7 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
|
||||
|
||||
# JSONL protocol values used in transcript serialization.
|
||||
STOP_REASON_END_TURN = "end_turn"
|
||||
STOP_REASON_TOOL_USE = "tool_use"
|
||||
COMPACT_MSG_ID_PREFIX = "msg_compact_"
|
||||
ENTRY_TYPE_MESSAGE = "message"
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ class Flag(str, Enum):
|
||||
AGENT_ACTIVITY = "agent-activity"
|
||||
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
|
||||
CHAT = "chat"
|
||||
CHAT_MODE_OPTION = "chat-mode-option"
|
||||
COPILOT_SDK = "copilot-sdk"
|
||||
COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit"
|
||||
COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit"
|
||||
|
||||
@@ -5,6 +5,7 @@ import {
|
||||
PromptInputTextarea,
|
||||
PromptInputTools,
|
||||
} from "@/components/ai-elements/prompt-input";
|
||||
import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { InputGroup } from "@/components/ui/input-group";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
@@ -18,7 +19,7 @@ import { useCopilotUIStore } from "../../store";
|
||||
import { useChatInput } from "./useChatInput";
|
||||
import { useVoiceRecording } from "./useVoiceRecording";
|
||||
|
||||
export interface Props {
|
||||
interface Props {
|
||||
onSend: (message: string, files?: File[]) => void | Promise<void>;
|
||||
disabled?: boolean;
|
||||
isStreaming?: boolean;
|
||||
@@ -49,6 +50,22 @@ export function ChatInput({
|
||||
const isFastModeEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
|
||||
function handleToggleMode() {
|
||||
const next =
|
||||
copilotMode === "extended_thinking" ? "fast" : "extended_thinking";
|
||||
setCopilotMode(next);
|
||||
toast({
|
||||
title:
|
||||
next === "fast"
|
||||
? "Switched to Fast mode"
|
||||
: "Switched to Extended Thinking mode",
|
||||
description:
|
||||
next === "fast"
|
||||
? "Response quality may differ."
|
||||
: "Responses may take longer.",
|
||||
});
|
||||
}
|
||||
|
||||
// Merge files dropped onto the chat window into internal state.
|
||||
useEffect(() => {
|
||||
if (droppedFiles && droppedFiles.length > 0) {
|
||||
@@ -165,14 +182,10 @@ export function ChatInput({
|
||||
{isFastModeEnabled && (
|
||||
<button
|
||||
type="button"
|
||||
role="switch"
|
||||
aria-checked={copilotMode === "fast"}
|
||||
disabled={isStreaming}
|
||||
onClick={() =>
|
||||
setCopilotMode(
|
||||
copilotMode === "extended_thinking"
|
||||
? "fast"
|
||||
: "extended_thinking",
|
||||
)
|
||||
}
|
||||
onClick={handleToggleMode}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
copilotMode === "extended_thinking"
|
||||
|
||||
@@ -27,6 +27,11 @@ vi.mock("@/services/feature-flags/use-get-flag", () => ({
|
||||
useGetFlag: () => mockFlagValue,
|
||||
}));
|
||||
|
||||
vi.mock("@/components/molecules/Toast/use-toast", () => ({
|
||||
toast: vi.fn(),
|
||||
useToast: () => ({ toast: vi.fn(), dismiss: vi.fn() }),
|
||||
}));
|
||||
|
||||
vi.mock("../useVoiceRecording", () => ({
|
||||
useVoiceRecording: () => ({
|
||||
isRecording: false,
|
||||
@@ -153,4 +158,33 @@ describe("ChatInput mode toggle", () => {
|
||||
const button = screen.getByLabelText(/switch to fast mode/i);
|
||||
expect(button.hasAttribute("disabled")).toBe(true);
|
||||
});
|
||||
|
||||
it("exposes role='switch' with aria-checked", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotMode = "extended_thinking";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
const button = screen.getByRole("switch");
|
||||
expect(button.getAttribute("aria-checked")).toBe("false");
|
||||
});
|
||||
|
||||
it("sets aria-checked=true in fast mode", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotMode = "fast";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
const button = screen.getByRole("switch");
|
||||
expect(button.getAttribute("aria-checked")).toBe("true");
|
||||
});
|
||||
|
||||
it("shows a toast when the user toggles mode", async () => {
|
||||
const { toast } = await import("@/components/molecules/Toast/use-toast");
|
||||
mockFlagValue = true;
|
||||
mockCopilotMode = "extended_thinking";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByRole("switch"));
|
||||
expect(toast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: expect.stringMatching(/switched to fast mode/i),
|
||||
}),
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
@@ -10,6 +10,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
|
||||
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import type { FileUIPart } from "ai";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useCopilotUIStore } from "./store";
|
||||
import { useChatSession } from "./useChatSession";
|
||||
@@ -32,6 +33,8 @@ export function useCopilotPage() {
|
||||
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
|
||||
const queryClient = useQueryClient();
|
||||
|
||||
const isModeToggleEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
|
||||
const {
|
||||
sessionToDelete,
|
||||
setSessionToDelete,
|
||||
@@ -69,7 +72,7 @@ export function useCopilotPage() {
|
||||
hydratedMessages,
|
||||
hasActiveStream,
|
||||
refetchSession,
|
||||
copilotMode,
|
||||
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
|
||||
});
|
||||
|
||||
useCopilotNotifications(sessionId);
|
||||
|
||||
@@ -10,7 +10,7 @@ import { useChat } from "@ai-sdk/react";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { DefaultChatTransport } from "ai";
|
||||
import type { FileUIPart, UIMessage } from "ai";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import {
|
||||
deduplicateMessages,
|
||||
extractSendMessageText,
|
||||
@@ -40,8 +40,8 @@ interface UseCopilotStreamArgs {
|
||||
hydratedMessages: UIMessage[] | undefined;
|
||||
hasActiveStream: boolean;
|
||||
refetchSession: () => Promise<{ data?: unknown }>;
|
||||
/** Autopilot mode to use for requests. */
|
||||
copilotMode: "extended_thinking" | "fast";
|
||||
/** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
|
||||
copilotMode: "extended_thinking" | "fast" | undefined;
|
||||
}
|
||||
|
||||
export function useCopilotStream({
|
||||
@@ -53,7 +53,9 @@ export function useCopilotStream({
|
||||
}: UseCopilotStreamArgs) {
|
||||
const queryClient = useQueryClient();
|
||||
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
|
||||
const dismissRateLimit = useCallback(() => setRateLimitMessage(null), []);
|
||||
function dismissRateLimit() {
|
||||
setRateLimitMessage(null);
|
||||
}
|
||||
// Use a ref for copilotMode so the transport closure always reads the
|
||||
// latest value without recreating the DefaultChatTransport (which would
|
||||
// reset useChat's internal Chat instance and break mid-session streaming).
|
||||
@@ -89,7 +91,7 @@ export function useCopilotStream({
|
||||
is_user_message: last.role === "user",
|
||||
context: null,
|
||||
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
|
||||
mode: copilotModeRef.current,
|
||||
mode: copilotModeRef.current ?? null,
|
||||
},
|
||||
headers: await getAuthHeaders(),
|
||||
};
|
||||
|
||||
@@ -27,7 +27,7 @@ const defaultFlags = {
|
||||
[Flag.AGENT_FAVORITING]: false,
|
||||
[Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS,
|
||||
[Flag.ENABLE_PLATFORM_PAYMENT]: false,
|
||||
[Flag.CHAT_MODE_OPTION]: false,
|
||||
[Flag.CHAT_MODE_OPTION]: true,
|
||||
};
|
||||
|
||||
type FlagValues = typeof defaultFlags;
|
||||
|
||||
Reference in New Issue
Block a user