fix(copilot): address remaining PR #12623 review items

Blockers:
- Rename `_resolve_use_sdk`/`_resolve_effective_mode` to
  `resolve_use_sdk_for_mode`/`resolve_effective_mode` in processor.py
  so the mode-routing logic is importable; tests now exercise the
  production functions directly instead of a local copy.
- Extract `is_transcript_stale`/`should_upload_transcript` helpers in
  baseline/service.py and cover them with direct unit tests, replacing
  the duplicated boolean expressions in transcript_integration_test.

Should-fix:
- Add `TestTranscriptLifecycle` that drives the download -> validate ->
  build -> upload flow end-to-end with mocked storage.
- Avoid the triple JSONL parse on upload: rely on the transcript
  builder's `last_entry_type == "assistant"` invariant and thread
  `skip_strip=True` through `upload_transcript` for builder-generated
  content.
- Run `_load_prior_transcript` and `_build_system_prompt` concurrently
  via `asyncio.gather` on the request critical path.
- Add a compression round-trip test proving `tool_calls` and
  `tool_call_id` survive `_compress_session_messages`.
- Extract the inline mode-toggle JSX into a dedicated
  `ModeToggleButton` sub-component.

Nice-to-have:
- Introduce `CopilotMode` type alias in `copilot/config.py` and reuse
  it across backend routes, executor utils, processor, and baseline
  service.
- Bound the shielded transcript upload with `asyncio.wait_for(..., 30)`
  so a hung storage backend cannot block response completion.
- Trim the 7 private re-exports from `sdk/transcript.py` shim; tests
  that needed the privates now import them from the canonical
  `backend.copilot.transcript`.
- Upload the transcript and its metadata sidecar concurrently via
  `asyncio.gather` with `return_exceptions=True`.

Nits:
- Rename `isFastModeEnabled` to `showModeToggle`.
- Narrow `except Exception` to `(ValueError, TypeError,
  orjson.JSONDecodeError)` around tool-call argument parsing.
- Replace `role="switch" aria-checked` with `aria-pressed` on the
  toggle button (a11y-correct for a toggle button role).
- Surface a streaming-specific tooltip when the toggle is disabled.
This commit is contained in:
Zamil Majdy
2026-04-05 12:41:14 +02:00
parent 4e0d6bbde5
commit 84c3dd7000
17 changed files with 623 additions and 197 deletions

View File

@@ -4,7 +4,7 @@ import asyncio
import logging
import re
from collections.abc import AsyncGenerator
from typing import Annotated, Literal
from typing import Annotated
from uuid import uuid4
from autogpt_libs import auth
@@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
from backend.copilot.config import ChatConfig
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
@@ -111,7 +111,7 @@ class StreamChatRequest(BaseModel):
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
mode: Literal["fast", "extended_thinking"] | None = Field(
mode: CopilotMode | None = Field(
default=None,
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
"If None, uses the server default (extended_thinking).",

View File

@@ -12,12 +12,13 @@ import uuid
from collections.abc import AsyncGenerator, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Literal, cast
from typing import Any, cast
import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from backend.copilot.config import CopilotMode
from backend.copilot.context import set_execution_context
from backend.copilot.model import (
ChatMessage,
@@ -55,6 +56,7 @@ from backend.copilot.tracking import track_user_message
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
download_transcript,
upload_transcript,
validate_transcript,
@@ -82,7 +84,7 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
def _resolve_baseline_model(mode: Literal["fast", "extended_thinking"] | None) -> str:
def _resolve_baseline_model(mode: CopilotMode | None) -> str:
"""Pick the model for the baseline path based on the per-request mode.
Only ``mode='fast'`` downgrades to the cheaper/faster model. Any other
@@ -356,7 +358,7 @@ def _record_turn_to_transcript(
for tc in response.tool_calls:
try:
args = orjson.loads(tc.arguments) if tc.arguments else {}
except Exception as parse_err:
except (ValueError, TypeError, orjson.JSONDecodeError) as parse_err:
logger.debug(
"[Baseline] Failed to parse tool_call arguments "
"(tool=%s, id=%s): %s",
@@ -466,8 +468,7 @@ async def _compress_session_messages(
if result.was_compacted:
logger.info(
"[Baseline] Context compacted: %d -> %d tokens "
"(%d summarized, %d dropped)",
"[Baseline] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
@@ -486,6 +487,39 @@ async def _compress_session_messages(
return messages
def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
"""Return ``True`` when a download doesn't cover the current session.
A transcript is stale when it has a known ``message_count`` and that
count doesn't reach ``session_msg_count - 1`` (i.e. the session has
already advanced beyond what the stored transcript captures).
Loading a stale transcript would silently drop intermediate turns,
so callers should treat stale as "skip load, skip upload".
An unknown ``message_count`` (``0``) is treated as **not stale**
because older transcripts uploaded before msg_count tracking
existed must still be usable.
"""
if dl is None:
return False
if not dl.message_count:
return False
return dl.message_count < session_msg_count - 1
def should_upload_transcript(
user_id: str | None, transcript_covers_prefix: bool
) -> bool:
"""Return ``True`` when the caller should upload the final transcript.
Uploads require a logged-in user (for the storage key) *and* a
transcript that covered the session prefix when loaded — otherwise
we'd be overwriting a more complete version in storage with a
partial one built from just the current turn.
"""
return bool(user_id) and transcript_covers_prefix
async def _load_prior_transcript(
user_id: str,
session_id: str,
@@ -513,10 +547,7 @@ async def _load_prior_transcript(
logger.warning("[Baseline] Downloaded transcript but invalid")
return False
# Reject stale transcripts: if msg_count is known and doesn't cover
# the current session, loading it would silently drop intermediate
# turns from the transcript.
if dl.message_count and dl.message_count < session_msg_count - 1:
if is_transcript_stale(dl, session_msg_count):
logger.warning(
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
dl.message_count,
@@ -539,21 +570,37 @@ async def _upload_final_transcript(
transcript_builder: TranscriptBuilder,
session_msg_count: int,
) -> None:
"""Serialize and upload the transcript for next-turn continuity."""
"""Serialize and upload the transcript for next-turn continuity.
Uses the builder's own invariants to decide whether to upload,
avoiding a JSONL re-parse. A builder that ends with an assistant
entry is structurally complete; a builder that doesn't (empty, or
ends mid-turn) is skipped.
"""
try:
content = transcript_builder.to_jsonl()
if content and validate_transcript(content):
await asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=content,
message_count=session_msg_count,
log_prefix="[Baseline]",
)
if transcript_builder.last_entry_type != "assistant":
logger.debug(
"[Baseline] No complete assistant turn to upload (last_entry=%s)",
transcript_builder.last_entry_type,
)
else:
logger.debug("[Baseline] No valid transcript to upload")
return
content = transcript_builder.to_jsonl()
if not content:
logger.debug("[Baseline] Empty transcript content, skipping upload")
return
upload_task = asyncio.shield(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=content,
message_count=session_msg_count,
log_prefix="[Baseline]",
skip_strip=True,
)
)
# Bound the shielded upload: a hung storage backend must not
# block the response from finishing.
await asyncio.wait_for(upload_task, timeout=30)
except Exception as upload_err:
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
@@ -564,7 +611,7 @@ async def stream_chat_completion_baseline(
is_user_message: bool = True,
user_id: str | None = None,
session: ChatSession | None = None,
mode: Literal["fast", "extended_thinking"] | None = None,
mode: CopilotMode | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Baseline LLM with tool calling via OpenAI-compatible API.
@@ -601,13 +648,28 @@ async def stream_chat_completion_baseline(
transcript_builder = TranscriptBuilder()
transcript_covers_prefix = True
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
prompt_task = _build_system_prompt(user_id, has_conversation_history=False)
else:
prompt_task = _build_system_prompt(user_id=None, has_conversation_history=True)
# Run download + prompt build concurrently — both are independent I/O
# on the request critical path.
if user_id and len(session.messages) > 1:
transcript_covers_prefix = await _load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
transcript_builder=transcript_builder,
transcript_covers_prefix, (base_system_prompt, _) = await asyncio.gather(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
transcript_builder=transcript_builder,
),
prompt_task,
)
else:
base_system_prompt, _ = await prompt_task
# Append user message to transcript.
# Always append when the message is present and is from the user,
@@ -633,18 +695,6 @@ async def stream_chat_completion_baseline(
message_id = str(uuid.uuid4())
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
base_system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
base_system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement()
@@ -839,7 +889,7 @@ async def stream_chat_completion_baseline(
stop_reason=STOP_REASON_END_TURN,
)
if user_id and transcript_covers_prefix:
if user_id and should_upload_transcript(user_id, transcript_covers_prefix):
await _upload_final_transcript(
user_id=user_id,
session_id=session_id,

View File

@@ -4,11 +4,18 @@ These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState`
without requiring API keys, database connections, or network access.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_baseline_conversation_updater,
_BaselineStreamState,
_compress_session_messages,
)
from backend.copilot.model import ChatMessage
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.prompt import CompressResult
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
@@ -232,3 +239,129 @@ class TestBaselineConversationUpdater:
# Should not raise — invalid JSON falls back to {} in transcript
assert len(messages) == 2
assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
class TestCompressSessionMessagesPreservesToolCalls:
"""``_compress_session_messages`` must round-trip tool_calls + tool_call_id.
Compression serialises ChatMessage to dict for ``compress_context`` and
reifies the result back to ChatMessage. A regression that drops
``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
list and break downstream tool-execution rounds.
"""
@pytest.mark.asyncio
async def test_compressed_output_keeps_tool_calls_and_ids(self):
# Simulate compression that returns a summary + the most recent
# assistant(tool_call) + tool(tool_result) intact.
summary = {"role": "system", "content": "prior turns: user asked X"}
assistant_with_tc = {
"role": "assistant",
"content": "calling tool",
"tool_calls": [
{
"id": "tc_abc",
"type": "function",
"function": {"name": "search", "arguments": '{"q":"y"}'},
}
],
}
tool_result = {
"role": "tool",
"tool_call_id": "tc_abc",
"content": "search result",
}
compress_result = CompressResult(
messages=[summary, assistant_with_tc, tool_result],
token_count=100,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=0,
)
# Input: messages that should be compressed.
input_messages = [
ChatMessage(role="user", content="q1"),
ChatMessage(
role="assistant",
content="calling tool",
tool_calls=[
{
"id": "tc_abc",
"type": "function",
"function": {
"name": "search",
"arguments": '{"q":"y"}',
},
}
],
),
ChatMessage(
role="tool",
tool_call_id="tc_abc",
content="search result",
),
]
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=compress_result),
):
compressed = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
# Summary, assistant(tool_calls), tool(tool_call_id).
assert len(compressed) == 3
# Assistant message must keep its tool_calls intact.
assistant_msg = compressed[1]
assert assistant_msg.role == "assistant"
assert assistant_msg.tool_calls is not None
assert len(assistant_msg.tool_calls) == 1
assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
# Tool-role message must keep tool_call_id for OpenAI linkage.
tool_msg = compressed[2]
assert tool_msg.role == "tool"
assert tool_msg.tool_call_id == "tc_abc"
assert tool_msg.content == "search result"
@pytest.mark.asyncio
async def test_uncompressed_passthrough_keeps_fields(self):
"""When compression is a no-op (was_compacted=False), the original
messages must be returned unchanged — including tool_calls."""
input_messages = [
ChatMessage(
role="assistant",
content="c",
tool_calls=[
{
"id": "t1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", tool_call_id="t1", content="ok"),
]
noop_result = CompressResult(
messages=[], # ignored when was_compacted=False
token_count=10,
was_compacted=False,
)
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=noop_result),
):
out = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
assert out is input_messages # same list returned
assert out[0].tool_calls is not None
assert out[0].tool_calls[0]["id"] == "t1"
assert out[1].tool_call_id == "t1"

View File

@@ -16,6 +16,8 @@ from backend.copilot.baseline.service import (
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.service import config
from backend.copilot.transcript import (
@@ -455,3 +457,211 @@ class TestRoundTrip:
stop_reason=STOP_REASON_END_TURN,
)
assert builder.entry_count == initial_count
class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""
def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False
def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False
def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True
def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""
def test_upload_allowed_for_user_with_coverage(self):
assert should_upload_transcript("user-1", True) is True
def test_upload_skipped_when_no_user(self):
assert should_upload_transcript(None, True) is False
def test_upload_skipped_when_empty_user(self):
assert should_upload_transcript("", True) is False
def test_upload_skipped_without_coverage(self):
"""Partial transcript must never clobber a more complete stored one."""
assert should_upload_transcript("user-1", False) is False
def test_upload_skipped_when_no_user_and_no_coverage(self):
assert should_upload_transcript(None, False) is False
class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
driving each step through the real helpers.
"""
@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
# --- 2. Append a new user turn + a new assistant response ---
builder.append_user(content="follow-up question")
_record_turn_to_transcript(
LLMLoopResponse(
response_text="follow-up answer",
tool_calls=[],
raw_response=None,
),
tool_results=None,
transcript_builder=builder,
model="test-model",
)
# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is True
)
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=2,
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=stale),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
is False
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
"""Anonymous (user_id=None) → upload gate must return False."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert (
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
is False
)
@pytest.mark.asyncio
async def test_lifecycle_missing_download_still_uploads_new_content(self):
"""No prior transcript → covers defaults to True in the service,
new turn should upload cleanly."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
transcript_builder=builder,
)
# No download: covers is False, so the production path would
# skip upload. This protects against overwriting a future
# more-complete transcript with a single-turn snapshot.
assert covers is False
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is False
)
upload_mock.assert_not_awaited()

View File

@@ -8,6 +8,14 @@ from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
# Per-request routing mode for a single chat turn.
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
# - 'extended_thinking': route to the Claude Agent SDK path with the default
# (opus) model.
# ``None`` means "no override"; the server falls back to the Claude Code
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""

View File

@@ -10,11 +10,10 @@ import os
import subprocess
import threading
import time
from typing import Literal
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.response_model import StreamError
from backend.copilot.sdk import service as sdk_service
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@@ -34,10 +33,10 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]"
# ============ Mode Routing ============ #
async def _resolve_effective_mode(
mode: Literal["fast", "extended_thinking"] | None,
async def resolve_effective_mode(
mode: CopilotMode | None,
user_id: str | None,
) -> Literal["fast", "extended_thinking"] | None:
) -> CopilotMode | None:
"""Strip ``mode`` when the user is not entitled to the toggle.
The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
@@ -57,8 +56,8 @@ async def _resolve_effective_mode(
return mode
async def _resolve_use_sdk(
mode: Literal["fast", "extended_thinking"] | None,
async def resolve_use_sdk_for_mode(
mode: CopilotMode | None,
user_id: str | None,
*,
use_claude_code_subscription: bool,
@@ -306,10 +305,8 @@ class CoPilotProcessor:
else:
# Enforce server-side feature-flag gate so unauthorised
# users cannot force a mode by crafting the request.
effective_mode = await _resolve_effective_mode(
entry.mode, entry.user_id
)
use_sdk = await _resolve_use_sdk(
effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
use_sdk = await resolve_use_sdk_for_mode(
effective_mode,
entry.user_id,
use_claude_code_subscription=config.use_claude_code_subscription,

View File

@@ -14,10 +14,13 @@ from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.executor.processor import _resolve_effective_mode, _resolve_use_sdk
from backend.copilot.executor.processor import (
resolve_effective_mode,
resolve_use_sdk_for_mode,
)
class TestResolveUseSdk:
class TestResolveUseSdkForMode:
"""Tests for the per-request mode routing logic."""
@pytest.mark.asyncio
@@ -28,7 +31,7 @@ class TestResolveUseSdk:
new=AsyncMock(return_value=True),
):
assert (
await _resolve_use_sdk(
await resolve_use_sdk_for_mode(
"fast",
"user-1",
use_claude_code_subscription=True,
@@ -45,7 +48,7 @@ class TestResolveUseSdk:
new=AsyncMock(return_value=False),
):
assert (
await _resolve_use_sdk(
await resolve_use_sdk_for_mode(
"extended_thinking",
"user-1",
use_claude_code_subscription=False,
@@ -62,7 +65,7 @@ class TestResolveUseSdk:
new=AsyncMock(return_value=False),
):
assert (
await _resolve_use_sdk(
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=True,
@@ -79,7 +82,7 @@ class TestResolveUseSdk:
new=AsyncMock(return_value=True),
) as flag_mock:
assert (
await _resolve_use_sdk(
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
@@ -98,7 +101,7 @@ class TestResolveUseSdk:
new=AsyncMock(return_value=True),
):
assert (
await _resolve_use_sdk(
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
@@ -115,7 +118,7 @@ class TestResolveUseSdk:
new=AsyncMock(return_value=False),
):
assert (
await _resolve_use_sdk(
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
@@ -135,7 +138,7 @@ class TestResolveEffectiveMode:
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await _resolve_effective_mode(None, "user-1") is None
assert await resolve_effective_mode(None, "user-1") is None
flag_mock.assert_not_awaited()
@pytest.mark.asyncio
@@ -145,8 +148,8 @@ class TestResolveEffectiveMode:
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert await _resolve_effective_mode("fast", "user-1") is None
assert await _resolve_effective_mode("extended_thinking", "user-1") is None
assert await resolve_effective_mode("fast", "user-1") is None
assert await resolve_effective_mode("extended_thinking", "user-1") is None
@pytest.mark.asyncio
async def test_mode_preserved_when_flag_enabled(self):
@@ -155,9 +158,9 @@ class TestResolveEffectiveMode:
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert await _resolve_effective_mode("fast", "user-1") == "fast"
assert await resolve_effective_mode("fast", "user-1") == "fast"
assert (
await _resolve_effective_mode("extended_thinking", "user-1")
await resolve_effective_mode("extended_thinking", "user-1")
== "extended_thinking"
)
@@ -168,5 +171,5 @@ class TestResolveEffectiveMode:
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await _resolve_effective_mode("fast", None) is None
assert await resolve_effective_mode("fast", None) is None
flag_mock.assert_awaited_once()

View File

@@ -6,10 +6,10 @@ Defines two exchanges and queues following the graph executor pattern:
"""
import logging
from typing import Literal
from pydantic import BaseModel
from backend.copilot.config import CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -157,7 +157,7 @@ class CoPilotExecutionEntry(BaseModel):
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
mode: Literal["fast", "extended_thinking"] | None = None
mode: CopilotMode | None = None
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
@@ -179,7 +179,7 @@ async def enqueue_copilot_turn(
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
mode: Literal["fast", "extended_thinking"] | None = None,
mode: CopilotMode | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.

View File

@@ -8,20 +8,19 @@ from uuid import uuid4
import pytest
from backend.util import json
from backend.util.prompt import CompressResult
from .conftest import build_test_transcript as _build_transcript
from .service import _friendly_error_text, _is_prompt_too_long
from .transcript import (
from backend.copilot.transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_run_compression,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
from backend.util import json
from backend.util.prompt import CompressResult
from .conftest import build_test_transcript as _build_transcript
from .service import _friendly_error_text, _is_prompt_too_long
from .transcript import compact_transcript, validate_transcript
# ---------------------------------------------------------------------------
# _flatten_assistant_content

View File

@@ -26,18 +26,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.util import json
from .conftest import build_test_transcript as _build_transcript
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
from .transcript import (
from backend.copilot.transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
from backend.util import json
from .conftest import build_test_transcript as _build_transcript
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
from .transcript import compact_transcript, validate_transcript
from .transcript_builder import TranscriptBuilder
# ---------------------------------------------------------------------------
@@ -1405,9 +1404,9 @@ class TestStreamChatCompletionRetryIntegration:
events.append(event)
# Should NOT retry — only 1 attempt for auth errors
assert attempt_count[0] == 1, (
f"Expected 1 attempt (no retry for auth error), " f"got {attempt_count[0]}"
)
assert (
attempt_count[0] == 1
), f"Expected 1 attempt (no retry for auth error), got {attempt_count[0]}"
errors = [e for e in events if isinstance(e, StreamError)]
assert errors, "Expected StreamError"
assert errors[0].code == "sdk_stream_error"

View File

@@ -27,20 +27,19 @@ from backend.copilot.response_model import (
StreamTextDelta,
StreamTextStart,
)
from backend.util import json
from .conftest import build_structured_transcript
from .response_adapter import SDKResponseAdapter
from .service import _format_sdk_content_blocks
from .transcript import (
from backend.copilot.transcript import (
_find_last_assistant_entry,
_flatten_assistant_content,
_messages_to_transcript,
_rechain_tail,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
from backend.util import json
from .conftest import build_structured_transcript
from .response_adapter import SDKResponseAdapter
from .service import _format_sdk_content_blocks
from .transcript import compact_transcript, validate_transcript
# ---------------------------------------------------------------------------
# Fixtures: realistic thinking block content

View File

@@ -1,28 +1,19 @@
"""Re-export from shared ``backend.copilot.transcript`` for backward compat.
"""Re-export public API from shared ``backend.copilot.transcript``.
The canonical implementation now lives at ``backend.copilot.transcript``
so both the SDK and baseline paths can import without cross-package
dependencies. All symbols are re-exported here so existing ``from
dependencies. Public symbols are re-exported here so existing ``from
.transcript import ...`` statements within the ``sdk`` package continue
to work without modification.
"""
from backend.copilot.transcript import (
_MAX_PROJECT_DIRS_TO_SWEEP,
_STALE_PROJECT_DIR_SECONDS,
COMPACT_MSG_ID_PREFIX,
ENTRY_TYPE_MESSAGE,
STOP_REASON_END_TURN,
STRIPPABLE_TYPES,
TRANSCRIPT_STORAGE_PREFIX,
TranscriptDownload,
_find_last_assistant_entry,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_rechain_tail,
_run_compression,
_transcript_to_messages,
cleanup_stale_project_dirs,
compact_transcript,
delete_transcript,
@@ -43,15 +34,6 @@ __all__ = [
"STRIPPABLE_TYPES",
"TRANSCRIPT_STORAGE_PREFIX",
"TranscriptDownload",
"_MAX_PROJECT_DIRS_TO_SWEEP",
"_STALE_PROJECT_DIR_SECONDS",
"_find_last_assistant_entry",
"_flatten_assistant_content",
"_flatten_tool_result_content",
"_messages_to_transcript",
"_rechain_tail",
"_run_compression",
"_transcript_to_messages",
"cleanup_stale_project_dirs",
"compact_transcript",
"delete_transcript",

View File

@@ -850,7 +850,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_no_client_uses_truncation(self):
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
from .transcript import _run_compression
from backend.copilot.transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated"}]
@@ -885,7 +885,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_llm_success_returns_llm_result(self):
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
from .transcript import _run_compression
from backend.copilot.transcript import _run_compression
llm_result = self._make_compress_result(
True, [{"role": "user", "content": "LLM summary"}]
@@ -916,7 +916,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_llm_failure_falls_back_to_truncation(self):
"""Path (c): LLM call raises → truncation fallback used instead."""
from .transcript import _run_compression
from backend.copilot.transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated fallback"}]
@@ -953,7 +953,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_llm_timeout_falls_back_to_truncation(self):
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
from .transcript import _run_compression
from backend.copilot.transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated after timeout"}]
@@ -1007,7 +1007,7 @@ class TestCleanupStaleProjectDirs:
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories matching copilot pattern older than threshold are removed."""
from backend.copilot.sdk.transcript import (
from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1039,7 +1039,7 @@ class TestCleanupStaleProjectDirs:
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories not matching copilot pattern are left alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
from backend.copilot.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
@@ -1062,7 +1062,7 @@ class TestCleanupStaleProjectDirs:
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
"""A directory exactly at the TTL boundary should NOT be removed."""
from backend.copilot.sdk.transcript import (
from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1088,7 +1088,7 @@ class TestCleanupStaleProjectDirs:
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
"""Regular files matching the copilot pattern are not removed."""
from backend.copilot.sdk.transcript import (
from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1114,7 +1114,7 @@ class TestCleanupStaleProjectDirs:
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
"""If the projects base directory doesn't exist, return 0 gracefully."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
from backend.copilot.transcript import cleanup_stale_project_dirs
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
@@ -1129,7 +1129,7 @@ class TestCleanupStaleProjectDirs:
"""When encoded_cwd is supplied only that directory is swept."""
import time
from backend.copilot.sdk.transcript import (
from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1160,7 +1160,7 @@ class TestCleanupStaleProjectDirs:
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep leaves a fresh directory alone."""
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
from backend.copilot.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
@@ -1181,7 +1181,7 @@ class TestCleanupStaleProjectDirs:
"""Scoped sweep refuses to remove a non-copilot directory."""
import time
from backend.copilot.sdk.transcript import (
from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)

View File

@@ -658,6 +658,7 @@ async def upload_transcript(
content: str,
message_count: int = 0,
log_prefix: str = "[Transcript]",
skip_strip: bool = False,
) -> None:
"""Strip progress entries and stale thinking blocks, then upload transcript.
@@ -670,11 +671,18 @@ async def upload_transcript(
Args:
content: Complete JSONL transcript (from TranscriptBuilder).
message_count: ``len(session.messages)`` at upload time.
skip_strip: When ``True``, skip the strip + re-validate pass.
Safe for builder-generated content (baseline path) which
never emits progress entries or stale thinking blocks.
"""
# Strip metadata entries and stale thinking blocks in a single parse pass.
# SDK-built transcripts shouldn't have progress entries, but strip for safety.
stripped = strip_for_upload(content)
if not validate_transcript(stripped):
if skip_strip:
# Caller guarantees the content is already clean and valid.
stripped = content
else:
# Strip metadata entries and stale thinking blocks in a single parse.
# SDK-built transcripts may have progress entries; strip for safety.
stripped = strip_for_upload(content)
if not skip_strip and not validate_transcript(stripped):
# Log entry types for debugging — helps identify why validation failed
entry_types = [
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
@@ -695,27 +703,34 @@ async def upload_transcript(
storage = await get_workspace_storage()
wid, fid, fname = _storage_path_parts(user_id, session_id)
encoded = stripped.encode("utf-8")
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
meta_encoded = json.dumps(meta).encode("utf-8")
await storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
)
# Update metadata so message_count stays current. The gap-fill logic
# in _build_query_message relies on it to avoid re-compressing messages.
try:
meta = {"message_count": message_count, "uploaded_at": time.time()}
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
await storage.store(
# Transcript + metadata are independent objects at different keys, so
# write them concurrently. ``return_exceptions`` keeps a metadata
# failure from sinking the transcript write.
transcript_result, metadata_result = await asyncio.gather(
storage.store(
workspace_id=wid,
file_id=fid,
filename=fname,
content=encoded,
),
storage.store(
workspace_id=mwid,
file_id=mfid,
filename=mfname,
content=json.dumps(meta).encode("utf-8"),
)
except Exception as e:
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
content=meta_encoded,
),
return_exceptions=True,
)
if isinstance(transcript_result, BaseException):
raise transcript_result
if isinstance(metadata_result, BaseException):
# Metadata is best-effort — the gap-fill logic in
# _build_query_message tolerates a missing metadata file.
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
logger.info(
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",

View File

@@ -9,10 +9,10 @@ import { toast } from "@/components/molecules/Toast/use-toast";
import { InputGroup } from "@/components/ui/input-group";
import { cn } from "@/lib/utils";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { Brain, Lightning } from "@phosphor-icons/react";
import { ChangeEvent, useEffect, useState } from "react";
import { AttachmentMenu } from "./components/AttachmentMenu";
import { FileChips } from "./components/FileChips";
import { ModeToggleButton } from "./components/ModeToggleButton";
import { RecordingButton } from "./components/RecordingButton";
import { RecordingIndicator } from "./components/RecordingIndicator";
import { useCopilotUIStore } from "../../store";
@@ -47,7 +47,7 @@ export function ChatInput({
onDroppedFilesConsumed,
}: Props) {
const { copilotMode, setCopilotMode } = useCopilotUIStore();
const isFastModeEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
const [files, setFiles] = useState<File[]>([]);
function handleToggleMode() {
@@ -179,43 +179,12 @@ export function ChatInput({
onFilesSelected={handleFilesSelected}
disabled={isBusy}
/>
{isFastModeEnabled && (
<button
type="button"
role="switch"
aria-checked={copilotMode === "fast"}
disabled={isStreaming}
onClick={handleToggleMode}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
copilotMode === "extended_thinking"
? "bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-300"
: "bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-300",
isStreaming && "cursor-not-allowed opacity-50",
)}
aria-label={
copilotMode === "extended_thinking"
? "Switch to Fast mode"
: "Switch to Extended Thinking mode"
}
title={
copilotMode === "extended_thinking"
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
}
>
{copilotMode === "extended_thinking" ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
{showModeToggle && (
<ModeToggleButton
mode={copilotMode}
isStreaming={isStreaming}
onToggle={handleToggleMode}
/>
)}
</PromptInputTools>

View File

@@ -159,20 +159,29 @@ describe("ChatInput mode toggle", () => {
expect(button.hasAttribute("disabled")).toBe(true);
});
it("exposes role='switch' with aria-checked", () => {
it("exposes aria-pressed=true in extended_thinking mode", () => {
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
const button = screen.getByRole("switch");
expect(button.getAttribute("aria-checked")).toBe("false");
const button = screen.getByLabelText(/switch to fast mode/i);
expect(button.getAttribute("aria-pressed")).toBe("true");
});
it("sets aria-checked=true in fast mode", () => {
it("sets aria-pressed=false in fast mode", () => {
mockFlagValue = true;
mockCopilotMode = "fast";
render(<ChatInput onSend={mockOnSend} />);
const button = screen.getByRole("switch");
expect(button.getAttribute("aria-checked")).toBe("true");
const button = screen.getByLabelText(/switch to extended thinking/i);
expect(button.getAttribute("aria-pressed")).toBe("false");
});
it("uses streaming-specific tooltip when disabled", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
const button = screen.getByLabelText(/switch to fast mode/i);
expect(button.getAttribute("title")).toBe(
"Mode cannot be changed while streaming",
);
});
it("shows a toast when the user toggles mode", async () => {
@@ -180,7 +189,7 @@ describe("ChatInput mode toggle", () => {
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByRole("switch"));
fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to fast mode/i),

View File

@@ -0,0 +1,53 @@
"use client";
import { cn } from "@/lib/utils";
import { Brain, Lightning } from "@phosphor-icons/react";
type CopilotMode = "extended_thinking" | "fast";
interface Props {
mode: CopilotMode;
isStreaming: boolean;
onToggle: () => void;
}
export function ModeToggleButton({ mode, isStreaming, onToggle }: Props) {
const isExtended = mode === "extended_thinking";
return (
<button
type="button"
aria-pressed={isExtended}
disabled={isStreaming}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isExtended
? "bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-300"
: "bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-300",
isStreaming && "cursor-not-allowed opacity-50",
)}
aria-label={
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
}
title={
isStreaming
? "Mode cannot be changed while streaming"
: isExtended
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
);
}