fix(copilot): address all review items on PR #12623

Blockers:
- B1: use config.model (opus) for baseline unless mode='fast' explicitly
  downgrades — prevents silent quality drop for all baseline users
- B2: extract _resolve_use_sdk / _resolve_effective_mode to production
  code in processor.py; tests now exercise real routing logic
- B3: add real unit tests covering transcript download/validate/load/
  append/backfill/upload round-trip via new service helpers

Should-fix:
- S1: server-side CHAT_MODE_OPTION flag gate in processor blocks
  unauthorised users from bypassing the UI toggle via crafted requests
- S2: frontend gates copilotMode behind the CHAT_MODE_OPTION flag in
  useCopilotPage so stale localStorage values aren't sent when off
- S3: Pydantic validation tests for StreamChatRequest.mode literal
- S4: extract _load_prior_transcript / _upload_final_transcript helpers
  and split _baseline_conversation_updater into message-mutation and
  transcript-recording concerns

Nice-to-have + Nits:
- N1: parallelize GCS transcript content + metadata downloads via gather
- N3: add role='switch' and aria-checked to mode toggle button
- N5: toast notification when user toggles mode
- Nit1: drop export on ChatInput Props
- Nit2: inline dismissRateLimit, drop unused useCallback import
- Nit3: replace 'end_turn'/'tool_use' magic strings with constants
- Nit4: log malformed tool_call JSON parse errors at debug level
This commit is contained in:
Zamil Majdy
2026-04-05 11:34:55 +02:00
parent d9c59e3616
commit 4e0d6bbde5
12 changed files with 923 additions and 346 deletions

View File

@@ -541,3 +541,41 @@ def test_create_session_rejects_nested_metadata(
)
assert response.status_code == 422
class TestStreamChatRequestModeValidation:
    """Pydantic-level validation of the ``mode`` field on StreamChatRequest."""

    def test_rejects_invalid_mode_value(self) -> None:
        """Any string outside the Literal set must raise ValidationError."""
        from backend.api.features.chat.routes import StreamChatRequest
        from pydantic import ValidationError

        with pytest.raises(ValidationError):
            StreamChatRequest(message="hi", mode="turbo")  # type: ignore[arg-type]

    def test_accepts_fast_mode(self) -> None:
        from backend.api.features.chat.routes import StreamChatRequest

        request = StreamChatRequest(message="hi", mode="fast")
        assert request.mode == "fast"

    def test_accepts_extended_thinking_mode(self) -> None:
        from backend.api.features.chat.routes import StreamChatRequest

        request = StreamChatRequest(message="hi", mode="extended_thinking")
        assert request.mode == "extended_thinking"

    def test_accepts_none_mode(self) -> None:
        """``mode=None`` is valid (server decides via feature flags)."""
        from backend.api.features.chat.routes import StreamChatRequest

        request = StreamChatRequest(message="hi", mode=None)
        assert request.mode is None

    def test_mode_defaults_to_none_when_omitted(self) -> None:
        from backend.api.features.chat.routes import StreamChatRequest

        request = StreamChatRequest(message="hi")
        assert request.mode is None

View File

@@ -12,7 +12,7 @@ import uuid
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, cast
from typing import Any, Literal, cast
import orjson
from langfuse import propagate_attributes
@@ -53,6 +53,8 @@ from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
download_transcript,
upload_transcript,
validate_transcript,
@@ -80,6 +82,19 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
def _resolve_baseline_model(mode: Literal["fast", "extended_thinking"] | None) -> str:
    """Return the model name the baseline path should use for this request.

    The cheaper/faster tier is chosen only when the caller explicitly asked
    for ``mode='fast'``. Every other value — ``None`` as well as
    ``'extended_thinking'`` — keeps the configured default model, so users
    who never select a mode are not silently moved to the cheaper tier.
    """
    return config.fast_model if mode == "fast" else config.model
@dataclass
class _BaselineStreamState:
"""Mutable state shared between the tool-call loop callbacks.
@@ -88,6 +103,7 @@ class _BaselineStreamState:
can be module-level functions instead of deeply nested closures.
"""
model: str = ""
pending_events: list[StreamBaseResponse] = field(default_factory=list)
assistant_text: str = ""
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@@ -115,7 +131,7 @@ async def _baseline_llm_caller(
if tools:
typed_tools = cast(list[ChatCompletionToolParam], tools)
response = await client.chat.completions.create(
model=config.fast_model,
model=state.model,
messages=typed_messages,
tools=typed_tools,
stream=True,
@@ -123,7 +139,7 @@ async def _baseline_llm_caller(
)
else:
response = await client.chat.completions.create(
model=config.fast_model,
model=state.model,
messages=typed_messages,
stream=True,
stream_options={"include_usage": True},
@@ -285,20 +301,17 @@ async def _baseline_tool_executor(
)
def _baseline_conversation_updater(
def _mutate_openai_messages(
messages: list[dict[str, Any]],
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None = None,
*,
transcript_builder: TranscriptBuilder,
model: str = "",
tool_results: list[ToolCallResult] | None,
) -> None:
"""Update OpenAI message list with assistant response + tool results.
"""Append assistant / tool-result entries to the OpenAI message list.
Extracted from ``stream_chat_completion_baseline`` for readability.
This is the side-effect boundary for the next LLM call — no transcript
mutation happens here.
"""
if tool_results:
# Build assistant message with tool_calls
assistant_msg: dict[str, Any] = {"role": "assistant"}
if response.response_text:
assistant_msg["content"] = response.response_text
@@ -311,14 +324,46 @@ def _baseline_conversation_updater(
for tc in response.tool_calls
]
messages.append(assistant_msg)
# Record assistant message (with tool_calls) to transcript
for tr in tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_call_id,
"content": tr.content,
}
)
elif response.response_text:
messages.append({"role": "assistant", "content": response.response_text})
def _record_turn_to_transcript(
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None,
*,
transcript_builder: TranscriptBuilder,
model: str,
) -> None:
"""Append assistant + tool-result entries to the transcript builder.
Kept separate from :func:`_mutate_openai_messages` so the two
concerns (next-LLM-call payload vs. durable conversation log) can
evolve independently.
"""
if tool_results:
content_blocks: list[dict[str, Any]] = []
if response.response_text:
content_blocks.append({"type": "text", "text": response.response_text})
for tc in response.tool_calls:
try:
args = orjson.loads(tc.arguments) if tc.arguments else {}
except Exception:
except Exception as parse_err:
logger.debug(
"[Baseline] Failed to parse tool_call arguments "
"(tool=%s, id=%s): %s",
tc.name,
tc.id,
parse_err,
)
args = {}
content_blocks.append(
{
@@ -332,16 +377,9 @@ def _baseline_conversation_updater(
transcript_builder.append_assistant(
content_blocks=content_blocks,
model=model,
stop_reason="tool_use",
stop_reason=STOP_REASON_TOOL_USE,
)
for tr in tool_results:
messages.append(
{
"role": "tool",
"tool_call_id": tr.tool_call_id,
"content": tr.content,
}
)
# Record tool result to transcript AFTER the assistant tool_use
# block to maintain correct Anthropic API ordering:
# assistant(tool_use) → user(tool_result)
@@ -349,15 +387,34 @@ def _baseline_conversation_updater(
tool_use_id=tr.tool_call_id,
content=tr.content,
)
else:
if response.response_text:
messages.append({"role": "assistant", "content": response.response_text})
# Record final text to transcript
transcript_builder.append_assistant(
content_blocks=[{"type": "text", "text": response.response_text}],
model=model,
stop_reason="end_turn",
)
elif response.response_text:
transcript_builder.append_assistant(
content_blocks=[{"type": "text", "text": response.response_text}],
model=model,
stop_reason=STOP_REASON_END_TURN,
)
def _baseline_conversation_updater(
    messages: list[dict[str, Any]],
    response: LLMLoopResponse,
    tool_results: list[ToolCallResult] | None = None,
    *,
    transcript_builder: TranscriptBuilder,
    model: str = "",
) -> None:
    """Update OpenAI message list with assistant response + tool results.

    Delegates to :func:`_mutate_openai_messages` (the payload for the next
    LLM call) and :func:`_record_turn_to_transcript` (the durable
    conversation log), keeping the two concerns composed but separate.
    """
    # Mutate the OpenAI message list first, then mirror the turn into the
    # transcript; the two steps are independent of each other.
    _mutate_openai_messages(messages, response, tool_results)
    _record_turn_to_transcript(
        response, tool_results, transcript_builder=transcript_builder, model=model
    )
async def _update_title_async(
@@ -374,6 +431,7 @@ async def _update_title_async(
async def _compress_session_messages(
messages: list[ChatMessage],
model: str,
) -> list[ChatMessage]:
"""Compress session messages if they exceed the model's token limit.
@@ -395,14 +453,14 @@ async def _compress_session_messages(
try:
result = await compress_context(
messages=messages_dict,
model=config.fast_model,
model=model,
client=_get_openai_client(),
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
result = await compress_context(
messages=messages_dict,
model=config.fast_model,
model=model,
client=None,
)
@@ -428,12 +486,85 @@ async def _compress_session_messages(
return messages
async def _load_prior_transcript(
    user_id: str,
    session_id: str,
    session_msg_count: int,
    transcript_builder: TranscriptBuilder,
) -> bool:
    """Download and load the prior transcript into ``transcript_builder``.

    Returns ``True`` when the loaded transcript fully covers the session
    prefix; ``False`` otherwise (stale, missing, invalid, or download
    error). Callers should suppress uploads when this returns ``False``
    to avoid overwriting a more complete version in storage.
    """
    try:
        download = await download_transcript(
            user_id, session_id, log_prefix="[Baseline]"
        )
    except Exception as exc:
        logger.warning("[Baseline] Transcript download failed: %s", exc)
        return False

    # Guard clauses: nothing stored, or stored content fails validation.
    if download is None:
        logger.debug("[Baseline] No transcript available")
        return False
    if not validate_transcript(download.content):
        logger.warning("[Baseline] Downloaded transcript but invalid")
        return False

    # Reject stale transcripts: if msg_count is known and doesn't cover
    # the current session, loading it would silently drop intermediate
    # turns from the transcript. (A msg_count of 0 means "unknown" and
    # skips this check.)
    stale = download.message_count and download.message_count < session_msg_count - 1
    if stale:
        logger.warning(
            "[Baseline] Transcript stale: covers %d of %d messages, skipping",
            download.message_count,
            session_msg_count,
        )
        return False

    transcript_builder.load_previous(download.content, log_prefix="[Baseline]")
    logger.info(
        "[Baseline] Loaded transcript: %dB, msg_count=%d",
        len(download.content),
        download.message_count,
    )
    return True
async def _upload_final_transcript(
    user_id: str,
    session_id: str,
    transcript_builder: TranscriptBuilder,
    session_msg_count: int,
) -> None:
    """Serialize and upload the transcript for next-turn continuity.

    Best-effort: upload failures are logged and swallowed so the chat
    flow is never interrupted by storage problems.
    """
    try:
        payload = transcript_builder.to_jsonl()
        # Only upload content that round-trips validation; anything else
        # is silently skipped (debug-logged) rather than written out.
        if not payload or not validate_transcript(payload):
            logger.debug("[Baseline] No valid transcript to upload")
            return
        # Shield the upload so cancellation of the caller doesn't leave a
        # half-written transcript in storage.
        await asyncio.shield(
            upload_transcript(
                user_id=user_id,
                session_id=session_id,
                content=payload,
                message_count=session_msg_count,
                log_prefix="[Baseline]",
            )
        )
    except Exception as upload_err:
        logger.error("[Baseline] Transcript upload failed: %s", upload_err)
async def stream_chat_completion_baseline(
session_id: str,
message: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
session: ChatSession | None = None,
mode: Literal["fast", "extended_thinking"] | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Baseline LLM with tool calling via OpenAI-compatible API.
@@ -462,43 +593,21 @@ async def stream_chat_completion_baseline(
session = await upsert_chat_session(session)
# Select model based on the per-request mode. 'fast' downgrades to
# the cheaper/faster model; everything else keeps the default.
active_model = _resolve_baseline_model(mode)
# --- Transcript support (feature parity with SDK path) ---
transcript_builder = TranscriptBuilder()
transcript_covers_prefix = True
if user_id and len(session.messages) > 1:
try:
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
if dl and validate_transcript(dl.content):
# Reject stale transcripts: if msg_count is known and
# doesn't cover the current session, loading it would
# silently drop intermediate turns from the transcript.
session_msg_count = len(session.messages)
if dl.message_count and dl.message_count < session_msg_count - 1:
logger.warning(
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
dl.message_count,
session_msg_count,
)
transcript_covers_prefix = False
else:
transcript_builder.load_previous(
dl.content, log_prefix="[Baseline]"
)
logger.info(
"[Baseline] Loaded transcript: %dB, msg_count=%d",
len(dl.content),
dl.message_count,
)
elif dl:
logger.warning("[Baseline] Downloaded transcript but invalid")
transcript_covers_prefix = False
else:
logger.debug("[Baseline] No transcript available")
transcript_covers_prefix = False
except Exception as e:
logger.warning("[Baseline] Transcript download failed: %s", e)
transcript_covers_prefix = False
transcript_covers_prefix = await _load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
transcript_builder=transcript_builder,
)
# Append user message to transcript.
# Always append when the message is present and is from the user,
@@ -540,7 +649,9 @@ async def stream_chat_completion_baseline(
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
messages_for_context = await _compress_session_messages(
session.messages, model=active_model
)
# Build OpenAI message list from session history.
# Include tool_calls on assistant messages and tool-role results so the
@@ -590,7 +701,7 @@ async def stream_chat_completion_baseline(
logger.warning("[Baseline] Langfuse trace context setup failed")
_stream_error = False # Track whether an error occurred during streaming
state = _BaselineStreamState()
state = _BaselineStreamState(model=active_model)
# Bind extracted module-level callbacks to this request's state/session
# using functools.partial so they satisfy the Protocol signatures.
@@ -602,7 +713,7 @@ async def stream_chat_completion_baseline(
_bound_conversation_updater = partial(
_baseline_conversation_updater,
transcript_builder=transcript_builder,
model=config.fast_model,
model=active_model,
)
try:
@@ -681,10 +792,10 @@ async def stream_chat_completion_baseline(
and not (_stream_error and not state.assistant_text)
):
state.turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.fast_model), 1
estimate_token_count(openai_messages, model=active_model), 1
)
state.turn_completion_tokens = estimate_token_count_str(
state.assistant_text, model=config.fast_model
state.assistant_text, model=active_model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
@@ -724,27 +835,17 @@ async def stream_chat_completion_baseline(
if transcript_builder.last_entry_type != "assistant":
transcript_builder.append_assistant(
content_blocks=[{"type": "text", "text": state.assistant_text}],
model=config.fast_model,
stop_reason="end_turn",
model=active_model,
stop_reason=STOP_REASON_END_TURN,
)
if user_id and transcript_covers_prefix:
try:
_transcript_content = transcript_builder.to_jsonl()
if _transcript_content and validate_transcript(_transcript_content):
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=_transcript_content,
message_count=len(session.messages),
log_prefix="[Baseline]",
)
)
else:
logger.debug("[Baseline] No valid transcript to upload")
except Exception as upload_err:
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
await _upload_final_transcript(
user_id=user_id,
session_id=session_id,
transcript_builder=transcript_builder,
session_msg_count=len(session.messages),
)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during

View File

@@ -1,19 +1,34 @@
"""Unit tests for baseline transcript integration logic.
"""Integration tests for baseline transcript flow.
Tests cover stale transcript detection, transcript_covers_prefix gating,
and partial backfill on stream error — without API keys or network access.
Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
import json as stdlib_json
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.transcript import TranscriptDownload, validate_transcript
from backend.copilot.baseline.service import (
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
)
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
def _make_transcript_content(*roles: str) -> str:
"""Build minimal valid JSONL transcript from role names."""
"""Build a minimal valid JSONL transcript from role names."""
lines = []
parent = ""
for i, role in enumerate(roles):
@@ -30,169 +45,413 @@ def _make_transcript_content(*roles: str) -> str:
if role == "assistant":
entry["message"]["id"] = f"msg_{i}"
entry["message"]["model"] = "test-model"
entry["message"]["type"] = "message"
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
lines.append(stdlib_json.dumps(entry))
parent = uid
return "\n".join(lines) + "\n"
class TestStaleTranscriptDetection:
"""Tests for the stale-transcript detection logic in the baseline service.
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
When the downloaded transcript's message_count is behind the session's
message count, the transcript is considered stale and skipped.
"""
def test_fast_mode_selects_fast_model(self):
assert _resolve_baseline_model("fast") == config.fast_model
def test_stale_transcript_detected(self):
"""Transcript with fewer messages than the session is flagged as stale."""
dl = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=2,
def test_extended_thinking_selects_default_model(self):
assert _resolve_baseline_model("extended_thinking") == config.model
def test_none_mode_selects_default_model(self):
"""Critical: baseline users without a mode MUST keep the default (opus)."""
assert _resolve_baseline_model(None) == config.model
def test_default_and_fast_models_differ(self):
"""Sanity: the two tiers are actually distinct in production config."""
assert config.model != config.fast_model
class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""msg_count strictly less than session-1 is treated as stale."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_missing_transcript_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_invalid_transcript_returns_false(self):
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
message_count=1,
)
session_msg_count = 6 # 4 new messages since transcript was uploaded
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
is_stale = dl.message_count and dl.message_count < session_msg_count - 1
assert is_stale
assert covers is False
assert builder.is_empty
def test_fresh_transcript_accepted(self):
"""Transcript covering the session prefix is not flagged as stale."""
dl = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=4,
)
session_msg_count = 5 # Only 1 new message (the user turn just sent)
@pytest.mark.asyncio
async def test_download_exception_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
is_stale = dl.message_count and dl.message_count < session_msg_count - 1
assert not is_stale
assert covers is False
assert builder.is_empty
def test_zero_message_count_not_stale(self):
"""When message_count is 0 (unknown), staleness check is skipped."""
dl = TranscriptDownload(
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=0,
)
session_msg_count = 10
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
transcript_builder=builder,
)
is_stale = dl.message_count and dl.message_count < session_msg_count - 1
assert not is_stale # 0 is falsy, so check is skipped
assert covers is True
assert builder.entry_count == 2
class TestTranscriptCoversPrefix:
"""Tests for transcript_covers_prefix gating in the baseline upload path.
class TestUploadFinalTranscript:
"""``_upload_final_transcript`` serialises and calls storage."""
When transcript_covers_prefix is False, the transcript is NOT uploaded
to avoid overwriting a more complete version in storage.
"""
def test_no_download_sets_covers_false(self):
"""When no transcript is available, covers_prefix should be False."""
dl = None
transcript_covers_prefix = dl is not None
assert not transcript_covers_prefix
def test_invalid_transcript_sets_covers_false(self):
"""When downloaded transcript fails validation, covers_prefix is False."""
content = '{"type":"progress","uuid":"a"}\n'
assert not validate_transcript(content)
def test_valid_transcript_sets_covers_true(self):
"""When downloaded transcript is valid and fresh, covers_prefix is True."""
content = _make_transcript_content("user", "assistant")
assert validate_transcript(content)
class TestPartialBackfill:
"""Tests for partial backfill of assistant text on stream error.
When the stream aborts mid-round, the conversation updater may not have
recorded the partial assistant text. The finally block backfills it
so mode-switching after a failed turn doesn't lose the partial response.
"""
def test_backfill_appends_when_last_entry_not_assistant(self):
"""When the last transcript entry is not an assistant, backfill appends."""
@pytest.mark.asyncio
async def test_uploads_valid_transcript(self):
builder = TranscriptBuilder()
builder.append_user("user question")
assistant_text = "partial response before error"
assert builder.last_entry_type != "assistant"
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": assistant_text}],
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason="end_turn",
stop_reason=STOP_REASON_END_TURN,
)
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
call_kwargs = upload_mock.await_args.kwargs
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=0,
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_swallows_upload_exceptions(self):
"""Upload failures should not propagate (flow continues for the user)."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
):
# Should not raise.
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
class TestRecordTurnToTranscript:
"""``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""
def test_records_final_assistant_text(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text="hello there",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
jsonl = builder.to_jsonl()
assert "partial response before error" in jsonl
assert "hello there" in jsonl
assert STOP_REASON_END_TURN in jsonl
def test_no_backfill_when_last_entry_is_assistant(self):
"""When the conversation updater already recorded the assistant turn,
backfill should be skipped (checked via last_entry_type)."""
def test_records_tool_use_then_tool_result(self):
"""Anthropic ordering: assistant(tool_use) → user(tool_result)."""
builder = TranscriptBuilder()
builder.append_user("user question")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "already recorded"}],
builder.append_user(content="use a tool")
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
stop_reason="end_turn",
)
assert builder.last_entry_type == "assistant"
initial_count = builder.entry_count
# user, assistant(tool_use), user(tool_result) = 3 entries
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert STOP_REASON_TOOL_USE in jsonl
assert "tool_use" in jsonl
assert "tool_result" in jsonl
assert "call-1" in jsonl
# Simulating the backfill guard: don't append if already assistant
def test_records_nothing_on_empty_response(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 1
def test_malformed_tool_args_dont_crash(self):
"""Bad JSON in tool arguments falls back to {} without raising."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert '"input":{}' in jsonl
class TestRoundTrip:
"""End-to-end: load prior → append new turn → upload."""
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
# New user turn.
builder.append_user(content="new question")
assert builder.entry_count == 3
# New assistant turn.
response = LLMLoopResponse(
response_text="new answer",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 4
# Upload.
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_backfill_append_guard(self):
"""Backfill only runs when the last entry is not already assistant."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
# Simulate the backfill guard from stream_chat_completion_baseline.
assistant_text = "partial text before error"
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": assistant_text}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.last_entry_type == "assistant"
assert "partial text before error" in builder.to_jsonl()
# Second invocation: the guard must prevent double-append.
initial_count = builder.entry_count
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": "duplicate"}],
model="test-model",
stop_reason="end_turn",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.entry_count == initial_count
def test_no_backfill_when_no_assistant_text(self):
    """When stream_error is True but no assistant text was produced,
    no backfill should occur."""
    transcript = TranscriptBuilder()
    transcript.append_user("user question")
    # Mirror the guard from service.py:
    #   if _stream_error and state.assistant_text:
    produced_text = ""
    stream_errored = True
    assert not (stream_errored and produced_text)
class TestTranscriptUploadGating:
    """Tests for the upload gating logic.

    Upload only happens when user_id is set AND transcript_covers_prefix
    is True; each test evaluates the gate expression directly.
    """

    @pytest.mark.asyncio
    async def test_upload_skipped_without_user_id(self):
        """No upload when user_id is None."""
        uid = None
        covers_prefix = True
        assert not (uid and covers_prefix)

    @pytest.mark.asyncio
    async def test_upload_skipped_when_prefix_not_covered(self):
        """No upload when transcript doesn't cover the session prefix."""
        uid = "user-1"
        covers_prefix = False
        assert not (uid and covers_prefix)

    @pytest.mark.asyncio
    async def test_upload_proceeds_when_conditions_met(self):
        """Upload proceeds when user_id is set and prefix is covered."""
        uid = "user-1"
        covers_prefix = True
        assert uid and covers_prefix

View File

@@ -10,6 +10,7 @@ import os
import subprocess
import threading
import time
from typing import Literal
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
@@ -30,6 +31,57 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
# ============ Mode Routing ============ #
async def _resolve_effective_mode(
    mode: Literal["fast", "extended_thinking"] | None,
    user_id: str | None,
) -> Literal["fast", "extended_thinking"] | None:
    """Strip ``mode`` when the user is not entitled to the toggle.

    The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
    processor enforces the same gate server-side so an authenticated
    user cannot bypass the flag by crafting a request directly.
    """
    # No mode requested — nothing to gate, skip the flag lookup entirely.
    if mode is None:
        return None
    flag_on = await is_feature_enabled(
        Flag.CHAT_MODE_OPTION,
        user_id or "anonymous",
        default=False,
    )
    if flag_on:
        return mode
    logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
    return None
async def _resolve_use_sdk(
    mode: Literal["fast", "extended_thinking"] | None,
    user_id: str | None,
    *,
    use_claude_code_subscription: bool,
    config_default: bool,
) -> bool:
    """Pick the SDK vs baseline path for a single turn.

    Per-request ``mode`` wins whenever it is set (after the
    ``CHAT_MODE_OPTION`` gate has been applied upstream). Otherwise
    falls back to the Claude Code subscription override, then the
    ``COPILOT_SDK`` LaunchDarkly flag, then the config default.
    """
    # Explicit per-request selection short-circuits every fallback.
    if mode == "fast":
        return False
    if mode == "extended_thinking":
        return True
    # Subscription override wins before we consult the feature flag,
    # so no flag lookup is made when the subscription is active.
    if use_claude_code_subscription:
        return True
    return await is_feature_enabled(
        Flag.COPILOT_SDK,
        user_id or "anonymous",
        default=config_default,
    )
# ============ Module Entry Points ============ #
# Thread-local storage for processor instances
@@ -250,23 +302,19 @@ class CoPilotProcessor:
if config.test_mode:
stream_fn = stream_chat_completion_dummy
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
effective_mode = None
else:
# Per-request mode override from the frontend takes priority.
# 'fast' → baseline (OpenAI-compatible), 'extended_thinking' → SDK.
if entry.mode == "fast":
use_sdk = False
elif entry.mode == "extended_thinking":
use_sdk = True
else:
# No mode specified — fall back to feature flag / config.
use_sdk = (
config.use_claude_code_subscription
or await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
)
# Enforce server-side feature-flag gate so unauthorised
# users cannot force a mode by crafting the request.
effective_mode = await _resolve_effective_mode(
entry.mode, entry.user_id
)
use_sdk = await _resolve_use_sdk(
effective_mode,
entry.user_id,
use_claude_code_subscription=config.use_claude_code_subscription,
config_default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
@@ -274,7 +322,7 @@ class CoPilotProcessor:
)
log.info(
f"Using {'SDK' if use_sdk else 'baseline'} service "
f"(mode={entry.mode or 'default'})"
f"(mode={effective_mode or 'default'})"
)
# Stream chat completion and publish chunks to Redis.
@@ -287,6 +335,7 @@ class CoPilotProcessor:
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
mode=effective_mode,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,

View File

@@ -4,104 +4,169 @@ Tests cover the mode→service mapping:
- 'fast' → baseline service
- 'extended_thinking' → SDK service
- None → feature flag / config fallback
as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""
from typing import Literal
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.executor.processor import _resolve_effective_mode, _resolve_use_sdk
def _resolve_use_sdk(
mode: Literal["fast", "extended_thinking"] | None,
use_claude_code_subscription: bool = False,
feature_flag_value: bool = False,
config_default: bool = True,
) -> bool:
"""Replicate the mode-routing logic from CoPilotProcessor._execute_async.
Extracted here so we can test it in isolation without instantiating the
full processor or its event loop.
"""
if mode == "fast":
return False
elif mode == "extended_thinking":
return True
else:
return use_claude_code_subscription or feature_flag_value or config_default
class TestResolveUseSdk:
    """Tests for the per-request mode routing logic.

    Exercises the real ``_resolve_use_sdk`` production helper with the
    LaunchDarkly lookup patched, covering the explicit-mode short
    circuits and every fallback branch.
    """

    @pytest.mark.asyncio
    async def test_fast_mode_uses_baseline(self):
        """mode='fast' always routes to baseline, regardless of flags."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ):
            assert (
                await _resolve_use_sdk(
                    "fast",
                    "user-1",
                    use_claude_code_subscription=True,
                    config_default=True,
                )
                is False
            )

    @pytest.mark.asyncio
    async def test_extended_thinking_uses_sdk(self):
        """mode='extended_thinking' always routes to SDK, regardless of flags."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert (
                await _resolve_use_sdk(
                    "extended_thinking",
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is True
            )

    @pytest.mark.asyncio
    async def test_none_mode_uses_subscription_override(self):
        """mode=None with claude_code_subscription=True routes to SDK."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert (
                await _resolve_use_sdk(
                    None,
                    "user-1",
                    use_claude_code_subscription=True,
                    config_default=False,
                )
                is True
            )

    @pytest.mark.asyncio
    async def test_none_mode_uses_feature_flag(self):
        """mode=None with feature flag enabled routes to SDK."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ) as flag_mock:
            assert (
                await _resolve_use_sdk(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is True
            )
        flag_mock.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_none_mode_uses_config_default(self):
        """mode=None falls back to config.use_claude_agent_sdk."""
        # When LaunchDarkly returns the default (True), we expect SDK routing.
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ):
            assert (
                await _resolve_use_sdk(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=True,
                )
                is True
            )

    @pytest.mark.asyncio
    async def test_none_mode_all_disabled(self):
        """mode=None with all flags off routes to baseline."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert (
                await _resolve_use_sdk(
                    None,
                    "user-1",
                    use_claude_code_subscription=False,
                    config_default=False,
                )
                is False
            )
class TestResolveEffectiveMode:
    """Tests for the CHAT_MODE_OPTION server-side gate."""

    @pytest.mark.asyncio
    async def test_none_mode_passes_through(self):
        """mode=None is returned as-is without a flag check."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ) as flag_mock:
            assert await _resolve_effective_mode(None, "user-1") is None
            # mode=None must short-circuit before any LaunchDarkly call.
            flag_mock.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_mode_stripped_when_flag_disabled(self):
        """When CHAT_MODE_OPTION is off, mode is dropped to None."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ):
            assert await _resolve_effective_mode("fast", "user-1") is None
            assert await _resolve_effective_mode("extended_thinking", "user-1") is None

    @pytest.mark.asyncio
    async def test_mode_preserved_when_flag_enabled(self):
        """When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=True),
        ):
            assert await _resolve_effective_mode("fast", "user-1") == "fast"
            assert (
                await _resolve_effective_mode("extended_thinking", "user-1")
                == "extended_thinking"
            )

    @pytest.mark.asyncio
    async def test_anonymous_user_with_mode(self):
        """Anonymous users (user_id=None) still pass through the gate."""
        with patch(
            "backend.copilot.executor.processor.is_feature_enabled",
            new=AsyncMock(return_value=False),
        ) as flag_mock:
            assert await _resolve_effective_mode("fast", None) is None
        flag_mock.assert_awaited_once()

View File

@@ -735,33 +735,44 @@ async def download_transcript(
Returns a ``TranscriptDownload`` with the JSONL content and the
``message_count`` watermark from the upload, or ``None`` if not found.
The content and metadata fetches run concurrently since they are
independent objects in the bucket.
"""
storage = await get_workspace_storage()
path = _build_storage_path(user_id, session_id, storage)
meta_path = _build_meta_storage_path(user_id, session_id, storage)
try:
data = await storage.retrieve(path)
content = data.decode("utf-8")
except FileNotFoundError:
content_task = asyncio.create_task(storage.retrieve(path))
meta_task = asyncio.create_task(storage.retrieve(meta_path))
content_result, meta_result = await asyncio.gather(
content_task, meta_task, return_exceptions=True
)
if isinstance(content_result, FileNotFoundError):
logger.debug("%s No transcript in storage", log_prefix)
return None
except Exception as e:
logger.warning("%s Failed to download transcript: %s", log_prefix, e)
if isinstance(content_result, BaseException):
logger.warning(
"%s Failed to download transcript: %s", log_prefix, content_result
)
return None
# Try to load metadata (best-effort — old transcripts won't have it)
content = content_result.decode("utf-8")
# Metadata is best-effort — old transcripts won't have it.
message_count = 0
uploaded_at = 0.0
try:
meta_path = _build_meta_storage_path(user_id, session_id, storage)
meta_data = await storage.retrieve(meta_path)
meta = json.loads(meta_data.decode("utf-8"), fallback={})
if isinstance(meta_result, FileNotFoundError):
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
elif isinstance(meta_result, BaseException):
logger.debug(
"%s Failed to load transcript metadata: %s", log_prefix, meta_result
)
else:
meta = json.loads(meta_result.decode("utf-8"), fallback={})
message_count = meta.get("message_count", 0)
uploaded_at = meta.get("uploaded_at", 0.0)
except FileNotFoundError:
pass # No metadata — treat as unknown (msg_count=0 → always fill gap)
except Exception as e:
logger.debug("%s Failed to load transcript metadata: %s", log_prefix, e)
logger.info(
"%s Downloaded %dB (msg_count=%d)", log_prefix, len(content), message_count
@@ -803,6 +814,7 @@ async def delete_transcript(user_id: str, session_id: str) -> None:
# JSONL protocol values used in transcript serialization.
STOP_REASON_END_TURN = "end_turn"  # assistant finished its reply (Anthropic stop_reason value)
STOP_REASON_TOOL_USE = "tool_use"  # assistant stopped to request a tool call (Anthropic stop_reason value)
COMPACT_MSG_ID_PREFIX = "msg_compact_"  # presumably the id prefix for compaction-summary entries — confirm at call sites
ENTRY_TYPE_MESSAGE = "message"  # discriminator for chat-message transcript entries

View File

@@ -38,6 +38,7 @@ class Flag(str, Enum):
AGENT_ACTIVITY = "agent-activity"
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
CHAT = "chat"
CHAT_MODE_OPTION = "chat-mode-option"
COPILOT_SDK = "copilot-sdk"
COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit"
COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit"

View File

@@ -5,6 +5,7 @@ import {
PromptInputTextarea,
PromptInputTools,
} from "@/components/ai-elements/prompt-input";
import { toast } from "@/components/molecules/Toast/use-toast";
import { InputGroup } from "@/components/ui/input-group";
import { cn } from "@/lib/utils";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
@@ -18,7 +19,7 @@ import { useCopilotUIStore } from "../../store";
import { useChatInput } from "./useChatInput";
import { useVoiceRecording } from "./useVoiceRecording";
export interface Props {
interface Props {
onSend: (message: string, files?: File[]) => void | Promise<void>;
disabled?: boolean;
isStreaming?: boolean;
@@ -49,6 +50,22 @@ export function ChatInput({
const isFastModeEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
const [files, setFiles] = useState<File[]>([]);
function handleToggleMode() {
  // Flip between the two copilot modes and confirm the change via toast.
  const switchingToFast = copilotMode === "extended_thinking";
  const next = switchingToFast ? "fast" : "extended_thinking";
  setCopilotMode(next);
  toast({
    title: switchingToFast
      ? "Switched to Fast mode"
      : "Switched to Extended Thinking mode",
    description: switchingToFast
      ? "Response quality may differ."
      : "Responses may take longer.",
  });
}
// Merge files dropped onto the chat window into internal state.
useEffect(() => {
if (droppedFiles && droppedFiles.length > 0) {
@@ -165,14 +182,10 @@ export function ChatInput({
{isFastModeEnabled && (
<button
type="button"
role="switch"
aria-checked={copilotMode === "fast"}
disabled={isStreaming}
onClick={() =>
setCopilotMode(
copilotMode === "extended_thinking"
? "fast"
: "extended_thinking",
)
}
onClick={handleToggleMode}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
copilotMode === "extended_thinking"

View File

@@ -27,6 +27,11 @@ vi.mock("@/services/feature-flags/use-get-flag", () => ({
useGetFlag: () => mockFlagValue,
}));
vi.mock("@/components/molecules/Toast/use-toast", () => ({
toast: vi.fn(),
useToast: () => ({ toast: vi.fn(), dismiss: vi.fn() }),
}));
vi.mock("../useVoiceRecording", () => ({
useVoiceRecording: () => ({
isRecording: false,
@@ -153,4 +158,33 @@ describe("ChatInput mode toggle", () => {
const button = screen.getByLabelText(/switch to fast mode/i);
expect(button.hasAttribute("disabled")).toBe(true);
});
it("exposes role='switch' with aria-checked", () => {
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
const button = screen.getByRole("switch");
expect(button.getAttribute("aria-checked")).toBe("false");
});
it("sets aria-checked=true in fast mode", () => {
mockFlagValue = true;
mockCopilotMode = "fast";
render(<ChatInput onSend={mockOnSend} />);
const button = screen.getByRole("switch");
expect(button.getAttribute("aria-checked")).toBe("true");
});
it("shows a toast when the user toggles mode", async () => {
const { toast } = await import("@/components/molecules/Toast/use-toast");
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByRole("switch"));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to fast mode/i),
}),
);
});
});

View File

@@ -10,6 +10,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import type { FileUIPart } from "ai";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useEffect, useRef, useState } from "react";
import { useCopilotUIStore } from "./store";
import { useChatSession } from "./useChatSession";
@@ -32,6 +33,8 @@ export function useCopilotPage() {
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
const queryClient = useQueryClient();
const isModeToggleEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
const {
sessionToDelete,
setSessionToDelete,
@@ -69,7 +72,7 @@ export function useCopilotPage() {
hydratedMessages,
hasActiveStream,
refetchSession,
copilotMode,
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
});
useCopilotNotifications(sessionId);

View File

@@ -10,7 +10,7 @@ import { useChat } from "@ai-sdk/react";
import { useQueryClient } from "@tanstack/react-query";
import { DefaultChatTransport } from "ai";
import type { FileUIPart, UIMessage } from "ai";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useEffect, useMemo, useRef, useState } from "react";
import {
deduplicateMessages,
extractSendMessageText,
@@ -40,8 +40,8 @@ interface UseCopilotStreamArgs {
hydratedMessages: UIMessage[] | undefined;
hasActiveStream: boolean;
refetchSession: () => Promise<{ data?: unknown }>;
/** Autopilot mode to use for requests. */
copilotMode: "extended_thinking" | "fast";
/** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
copilotMode: "extended_thinking" | "fast" | undefined;
}
export function useCopilotStream({
@@ -53,7 +53,9 @@ export function useCopilotStream({
}: UseCopilotStreamArgs) {
const queryClient = useQueryClient();
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
const dismissRateLimit = useCallback(() => setRateLimitMessage(null), []);
function dismissRateLimit() {
setRateLimitMessage(null);
}
// Use a ref for copilotMode so the transport closure always reads the
// latest value without recreating the DefaultChatTransport (which would
// reset useChat's internal Chat instance and break mid-session streaming).
@@ -89,7 +91,7 @@ export function useCopilotStream({
is_user_message: last.role === "user",
context: null,
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
mode: copilotModeRef.current,
mode: copilotModeRef.current ?? null,
},
headers: await getAuthHeaders(),
};

View File

@@ -27,7 +27,7 @@ const defaultFlags = {
[Flag.AGENT_FAVORITING]: false,
[Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS,
[Flag.ENABLE_PLATFORM_PAYMENT]: false,
[Flag.CHAT_MODE_OPTION]: false,
[Flag.CHAT_MODE_OPTION]: true,
};
type FlagValues = typeof defaultFlags;