mirror of
https://github.com/Significant-Gravitas/AutoGPT.git
synced 2026-04-08 03:00:28 -04:00
fix(copilot): address remaining PR #12623 review items
Blockers: - Rename `_resolve_use_sdk`/`_resolve_effective_mode` to `resolve_use_sdk_for_mode`/`resolve_effective_mode` in processor.py so the mode-routing logic is importable; tests now exercise the production functions directly instead of a local copy. - Extract `is_transcript_stale`/`should_upload_transcript` helpers in baseline/service.py and cover them with direct unit tests, replacing the duplicated boolean expressions in transcript_integration_test. Should-fix: - Add `TestTranscriptLifecycle` that drives the download -> validate -> build -> upload flow end-to-end with mocked storage. - Avoid the triple JSONL parse on upload: rely on the transcript builder's `last_entry_type == "assistant"` invariant and thread `skip_strip=True` through `upload_transcript` for builder-generated content. - Run `_load_prior_transcript` and `_build_system_prompt` concurrently via `asyncio.gather` on the request critical path. - Add a compression round-trip test proving `tool_calls` and `tool_call_id` survive `_compress_session_messages`. - Extract the inline mode-toggle JSX into a dedicated `ModeToggleButton` sub-component. Nice-to-have: - Introduce `CopilotMode` type alias in `copilot/config.py` and reuse it across backend routes, executor utils, processor, and baseline service. - Bound the shielded transcript upload with `asyncio.wait_for(..., 30)` so a hung storage backend cannot block response completion. - Trim the 7 private re-exports from `sdk/transcript.py` shim; tests that needed the privates now import them from the canonical `backend.copilot.transcript`. - Upload the transcript and its metadata sidecar concurrently via `asyncio.gather` with `return_exceptions=True`. Nits: - Rename `isFastModeEnabled` to `showModeToggle`. - Narrow `except Exception` to `(ValueError, TypeError, orjson.JSONDecodeError)` around tool-call argument parsing. - Replace `role=\"switch\" aria-checked` with `aria-pressed` on the toggle button (a11y-correct for a toggle button role). 
- Surface a streaming-specific tooltip when the toggle is disabled.
This commit is contained in:
@@ -4,7 +4,7 @@ import asyncio
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated, Literal
|
||||
from typing import Annotated
|
||||
from uuid import uuid4
|
||||
|
||||
from autogpt_libs import auth
|
||||
@@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from backend.copilot import service as chat_service
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -111,7 +111,7 @@ class StreamChatRequest(BaseModel):
|
||||
file_ids: list[str] | None = Field(
|
||||
default=None, max_length=20
|
||||
) # Workspace file IDs attached to this message
|
||||
mode: Literal["fast", "extended_thinking"] | None = Field(
|
||||
mode: CopilotMode | None = Field(
|
||||
default=None,
|
||||
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
|
||||
"If None, uses the server default (extended_thinking).",
|
||||
|
||||
@@ -12,12 +12,13 @@ import uuid
|
||||
from collections.abc import AsyncGenerator, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, cast
|
||||
|
||||
import orjson
|
||||
from langfuse import propagate_attributes
|
||||
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.copilot.context import set_execution_context
|
||||
from backend.copilot.model import (
|
||||
ChatMessage,
|
||||
@@ -55,6 +56,7 @@ from backend.copilot.tracking import track_user_message
|
||||
from backend.copilot.transcript import (
|
||||
STOP_REASON_END_TURN,
|
||||
STOP_REASON_TOOL_USE,
|
||||
TranscriptDownload,
|
||||
download_transcript,
|
||||
upload_transcript,
|
||||
validate_transcript,
|
||||
@@ -82,7 +84,7 @@ _background_tasks: set[asyncio.Task[Any]] = set()
|
||||
_MAX_TOOL_ROUNDS = 30
|
||||
|
||||
|
||||
def _resolve_baseline_model(mode: Literal["fast", "extended_thinking"] | None) -> str:
|
||||
def _resolve_baseline_model(mode: CopilotMode | None) -> str:
|
||||
"""Pick the model for the baseline path based on the per-request mode.
|
||||
|
||||
Only ``mode='fast'`` downgrades to the cheaper/faster model. Any other
|
||||
@@ -356,7 +358,7 @@ def _record_turn_to_transcript(
|
||||
for tc in response.tool_calls:
|
||||
try:
|
||||
args = orjson.loads(tc.arguments) if tc.arguments else {}
|
||||
except Exception as parse_err:
|
||||
except (ValueError, TypeError, orjson.JSONDecodeError) as parse_err:
|
||||
logger.debug(
|
||||
"[Baseline] Failed to parse tool_call arguments "
|
||||
"(tool=%s, id=%s): %s",
|
||||
@@ -466,8 +468,7 @@ async def _compress_session_messages(
|
||||
|
||||
if result.was_compacted:
|
||||
logger.info(
|
||||
"[Baseline] Context compacted: %d -> %d tokens "
|
||||
"(%d summarized, %d dropped)",
|
||||
"[Baseline] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
|
||||
result.original_token_count,
|
||||
result.token_count,
|
||||
result.messages_summarized,
|
||||
@@ -486,6 +487,39 @@ async def _compress_session_messages(
|
||||
return messages
|
||||
|
||||
|
||||
def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
    """Decide whether a stored transcript lags behind the live session.

    A transcript is considered stale only when it carries a concrete
    ``message_count`` that falls short of ``session_msg_count - 1``,
    meaning the session has already moved past what was persisted.
    Loading such a transcript would silently lose the intermediate
    turns, so callers should skip both load and upload in that case.

    Missing downloads (``dl is None``) and legacy transcripts with an
    unknown count (``message_count == 0``) are deliberately reported
    as fresh — transcripts uploaded before count tracking existed must
    remain usable.
    """
    # Nothing stored, or count tracking absent: treat as usable.
    if dl is None or not dl.message_count:
        return False
    required_prefix = session_msg_count - 1
    return dl.message_count < required_prefix
|
||||
|
||||
|
||||
def should_upload_transcript(
    user_id: str | None, transcript_covers_prefix: bool
) -> bool:
    """Gate the final transcript upload.

    An upload is permitted only when both conditions hold:

    * a logged-in user exists (the storage key is derived from it), and
    * the transcript covered the session prefix when it was loaded.

    Without prefix coverage, uploading would replace a more complete
    stored transcript with a partial one built from just this turn.
    """
    # No user → no storage key → nothing to upload to.
    if not user_id:
        return False
    return transcript_covers_prefix
|
||||
|
||||
|
||||
async def _load_prior_transcript(
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
@@ -513,10 +547,7 @@ async def _load_prior_transcript(
|
||||
logger.warning("[Baseline] Downloaded transcript but invalid")
|
||||
return False
|
||||
|
||||
# Reject stale transcripts: if msg_count is known and doesn't cover
|
||||
# the current session, loading it would silently drop intermediate
|
||||
# turns from the transcript.
|
||||
if dl.message_count and dl.message_count < session_msg_count - 1:
|
||||
if is_transcript_stale(dl, session_msg_count):
|
||||
logger.warning(
|
||||
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
|
||||
dl.message_count,
|
||||
@@ -539,21 +570,37 @@ async def _upload_final_transcript(
|
||||
transcript_builder: TranscriptBuilder,
|
||||
session_msg_count: int,
|
||||
) -> None:
|
||||
"""Serialize and upload the transcript for next-turn continuity."""
|
||||
"""Serialize and upload the transcript for next-turn continuity.
|
||||
|
||||
Uses the builder's own invariants to decide whether to upload,
|
||||
avoiding a JSONL re-parse. A builder that ends with an assistant
|
||||
entry is structurally complete; a builder that doesn't (empty, or
|
||||
ends mid-turn) is skipped.
|
||||
"""
|
||||
try:
|
||||
content = transcript_builder.to_jsonl()
|
||||
if content and validate_transcript(content):
|
||||
await asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=content,
|
||||
message_count=session_msg_count,
|
||||
log_prefix="[Baseline]",
|
||||
)
|
||||
if transcript_builder.last_entry_type != "assistant":
|
||||
logger.debug(
|
||||
"[Baseline] No complete assistant turn to upload (last_entry=%s)",
|
||||
transcript_builder.last_entry_type,
|
||||
)
|
||||
else:
|
||||
logger.debug("[Baseline] No valid transcript to upload")
|
||||
return
|
||||
content = transcript_builder.to_jsonl()
|
||||
if not content:
|
||||
logger.debug("[Baseline] Empty transcript content, skipping upload")
|
||||
return
|
||||
upload_task = asyncio.shield(
|
||||
upload_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
content=content,
|
||||
message_count=session_msg_count,
|
||||
log_prefix="[Baseline]",
|
||||
skip_strip=True,
|
||||
)
|
||||
)
|
||||
# Bound the shielded upload: a hung storage backend must not
|
||||
# block the response from finishing.
|
||||
await asyncio.wait_for(upload_task, timeout=30)
|
||||
except Exception as upload_err:
|
||||
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
|
||||
|
||||
@@ -564,7 +611,7 @@ async def stream_chat_completion_baseline(
|
||||
is_user_message: bool = True,
|
||||
user_id: str | None = None,
|
||||
session: ChatSession | None = None,
|
||||
mode: Literal["fast", "extended_thinking"] | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
**_kwargs: Any,
|
||||
) -> AsyncGenerator[StreamBaseResponse, None]:
|
||||
"""Baseline LLM with tool calling via OpenAI-compatible API.
|
||||
@@ -601,13 +648,28 @@ async def stream_chat_completion_baseline(
|
||||
transcript_builder = TranscriptBuilder()
|
||||
transcript_covers_prefix = True
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
prompt_task = _build_system_prompt(user_id, has_conversation_history=False)
|
||||
else:
|
||||
prompt_task = _build_system_prompt(user_id=None, has_conversation_history=True)
|
||||
|
||||
# Run download + prompt build concurrently — both are independent I/O
|
||||
# on the request critical path.
|
||||
if user_id and len(session.messages) > 1:
|
||||
transcript_covers_prefix = await _load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
session_msg_count=len(session.messages),
|
||||
transcript_builder=transcript_builder,
|
||||
transcript_covers_prefix, (base_system_prompt, _) = await asyncio.gather(
|
||||
_load_prior_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
session_msg_count=len(session.messages),
|
||||
transcript_builder=transcript_builder,
|
||||
),
|
||||
prompt_task,
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await prompt_task
|
||||
|
||||
# Append user message to transcript.
|
||||
# Always append when the message is present and is from the user,
|
||||
@@ -633,18 +695,6 @@ async def stream_chat_completion_baseline(
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
|
||||
# Build system prompt only on the first turn to avoid mid-conversation
|
||||
# changes from concurrent chats updating business understanding.
|
||||
is_first_turn = len(session.messages) <= 1
|
||||
if is_first_turn:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id, has_conversation_history=False
|
||||
)
|
||||
else:
|
||||
base_system_prompt, _ = await _build_system_prompt(
|
||||
user_id=None, has_conversation_history=True
|
||||
)
|
||||
|
||||
# Append tool documentation and technical notes
|
||||
system_prompt = base_system_prompt + get_baseline_supplement()
|
||||
|
||||
@@ -839,7 +889,7 @@ async def stream_chat_completion_baseline(
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
if user_id and transcript_covers_prefix:
|
||||
if user_id and should_upload_transcript(user_id, transcript_covers_prefix):
|
||||
await _upload_final_transcript(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
|
||||
@@ -4,11 +4,18 @@ These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState`
|
||||
without requiring API keys, database connections, or network access.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.baseline.service import (
|
||||
_baseline_conversation_updater,
|
||||
_BaselineStreamState,
|
||||
_compress_session_messages,
|
||||
)
|
||||
from backend.copilot.model import ChatMessage
|
||||
from backend.copilot.transcript_builder import TranscriptBuilder
|
||||
from backend.util.prompt import CompressResult
|
||||
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
|
||||
|
||||
|
||||
@@ -232,3 +239,129 @@ class TestBaselineConversationUpdater:
|
||||
# Should not raise — invalid JSON falls back to {} in transcript
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
|
||||
|
||||
|
||||
class TestCompressSessionMessagesPreservesToolCalls:
|
||||
"""``_compress_session_messages`` must round-trip tool_calls + tool_call_id.
|
||||
|
||||
Compression serialises ChatMessage to dict for ``compress_context`` and
|
||||
reifies the result back to ChatMessage. A regression that drops
|
||||
``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
|
||||
list and break downstream tool-execution rounds.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_compressed_output_keeps_tool_calls_and_ids(self):
|
||||
# Simulate compression that returns a summary + the most recent
|
||||
# assistant(tool_call) + tool(tool_result) intact.
|
||||
summary = {"role": "system", "content": "prior turns: user asked X"}
|
||||
assistant_with_tc = {
|
||||
"role": "assistant",
|
||||
"content": "calling tool",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "tc_abc",
|
||||
"type": "function",
|
||||
"function": {"name": "search", "arguments": '{"q":"y"}'},
|
||||
}
|
||||
],
|
||||
}
|
||||
tool_result = {
|
||||
"role": "tool",
|
||||
"tool_call_id": "tc_abc",
|
||||
"content": "search result",
|
||||
}
|
||||
|
||||
compress_result = CompressResult(
|
||||
messages=[summary, assistant_with_tc, tool_result],
|
||||
token_count=100,
|
||||
was_compacted=True,
|
||||
original_token_count=5000,
|
||||
messages_summarized=10,
|
||||
messages_dropped=0,
|
||||
)
|
||||
|
||||
# Input: messages that should be compressed.
|
||||
input_messages = [
|
||||
ChatMessage(role="user", content="q1"),
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="calling tool",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "tc_abc",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "search",
|
||||
"arguments": '{"q":"y"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(
|
||||
role="tool",
|
||||
tool_call_id="tc_abc",
|
||||
content="search result",
|
||||
),
|
||||
]
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.compress_context",
|
||||
new=AsyncMock(return_value=compress_result),
|
||||
):
|
||||
compressed = await _compress_session_messages(
|
||||
input_messages, model="openrouter/anthropic/claude-opus-4"
|
||||
)
|
||||
|
||||
# Summary, assistant(tool_calls), tool(tool_call_id).
|
||||
assert len(compressed) == 3
|
||||
# Assistant message must keep its tool_calls intact.
|
||||
assistant_msg = compressed[1]
|
||||
assert assistant_msg.role == "assistant"
|
||||
assert assistant_msg.tool_calls is not None
|
||||
assert len(assistant_msg.tool_calls) == 1
|
||||
assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
|
||||
assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
|
||||
# Tool-role message must keep tool_call_id for OpenAI linkage.
|
||||
tool_msg = compressed[2]
|
||||
assert tool_msg.role == "tool"
|
||||
assert tool_msg.tool_call_id == "tc_abc"
|
||||
assert tool_msg.content == "search result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_uncompressed_passthrough_keeps_fields(self):
|
||||
"""When compression is a no-op (was_compacted=False), the original
|
||||
messages must be returned unchanged — including tool_calls."""
|
||||
input_messages = [
|
||||
ChatMessage(
|
||||
role="assistant",
|
||||
content="c",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "t1",
|
||||
"type": "function",
|
||||
"function": {"name": "f", "arguments": "{}"},
|
||||
}
|
||||
],
|
||||
),
|
||||
ChatMessage(role="tool", tool_call_id="t1", content="ok"),
|
||||
]
|
||||
|
||||
noop_result = CompressResult(
|
||||
messages=[], # ignored when was_compacted=False
|
||||
token_count=10,
|
||||
was_compacted=False,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.baseline.service.compress_context",
|
||||
new=AsyncMock(return_value=noop_result),
|
||||
):
|
||||
out = await _compress_session_messages(
|
||||
input_messages, model="openrouter/anthropic/claude-opus-4"
|
||||
)
|
||||
|
||||
assert out is input_messages # same list returned
|
||||
assert out[0].tool_calls is not None
|
||||
assert out[0].tool_calls[0]["id"] == "t1"
|
||||
assert out[1].tool_call_id == "t1"
|
||||
|
||||
@@ -16,6 +16,8 @@ from backend.copilot.baseline.service import (
|
||||
_record_turn_to_transcript,
|
||||
_resolve_baseline_model,
|
||||
_upload_final_transcript,
|
||||
is_transcript_stale,
|
||||
should_upload_transcript,
|
||||
)
|
||||
from backend.copilot.service import config
|
||||
from backend.copilot.transcript import (
|
||||
@@ -455,3 +457,211 @@ class TestRoundTrip:
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
assert builder.entry_count == initial_count
|
||||
|
||||
|
||||
class TestIsTranscriptStale:
|
||||
"""``is_transcript_stale`` gates prior-transcript loading."""
|
||||
|
||||
def test_none_download_is_not_stale(self):
|
||||
assert is_transcript_stale(None, session_msg_count=5) is False
|
||||
|
||||
def test_zero_message_count_is_not_stale(self):
|
||||
"""Legacy transcripts without msg_count tracking must remain usable."""
|
||||
dl = TranscriptDownload(content="", message_count=0)
|
||||
assert is_transcript_stale(dl, session_msg_count=20) is False
|
||||
|
||||
def test_stale_when_covers_less_than_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=2)
|
||||
# session has 6 messages; transcript must cover at least 5 (6-1).
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is True
|
||||
|
||||
def test_fresh_when_covers_full_prefix(self):
|
||||
dl = TranscriptDownload(content="", message_count=5)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_fresh_when_exceeds_prefix(self):
|
||||
"""Race: transcript ahead of session count is still acceptable."""
|
||||
dl = TranscriptDownload(content="", message_count=10)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
def test_boundary_equal_to_prefix_minus_one(self):
|
||||
dl = TranscriptDownload(content="", message_count=5)
|
||||
assert is_transcript_stale(dl, session_msg_count=6) is False
|
||||
|
||||
|
||||
class TestShouldUploadTranscript:
|
||||
"""``should_upload_transcript`` gates the final upload."""
|
||||
|
||||
def test_upload_allowed_for_user_with_coverage(self):
|
||||
assert should_upload_transcript("user-1", True) is True
|
||||
|
||||
def test_upload_skipped_when_no_user(self):
|
||||
assert should_upload_transcript(None, True) is False
|
||||
|
||||
def test_upload_skipped_when_empty_user(self):
|
||||
assert should_upload_transcript("", True) is False
|
||||
|
||||
def test_upload_skipped_without_coverage(self):
|
||||
"""Partial transcript must never clobber a more complete stored one."""
|
||||
assert should_upload_transcript("user-1", False) is False
|
||||
|
||||
def test_upload_skipped_when_no_user_and_no_coverage(self):
|
||||
assert should_upload_transcript(None, False) is False
|
||||
|
||||
|
||||
class TestTranscriptLifecycle:
|
||||
"""End-to-end: download → validate → build → upload.
|
||||
|
||||
Simulates the full transcript lifecycle inside
|
||||
``stream_chat_completion_baseline`` by mocking the storage layer and
|
||||
driving each step through the real helpers.
|
||||
"""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_full_lifecycle_happy_path(self):
|
||||
"""Fresh download, append a turn, upload covers the session."""
|
||||
builder = TranscriptBuilder()
|
||||
prior = _make_transcript_content("user", "assistant")
|
||||
download = TranscriptDownload(content=prior, message_count=2)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=download),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
# --- 1. Download & load prior transcript ---
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=3,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
assert covers is True
|
||||
|
||||
# --- 2. Append a new user turn + a new assistant response ---
|
||||
builder.append_user(content="follow-up question")
|
||||
_record_turn_to_transcript(
|
||||
LLMLoopResponse(
|
||||
response_text="follow-up answer",
|
||||
tool_calls=[],
|
||||
raw_response=None,
|
||||
),
|
||||
tool_results=None,
|
||||
transcript_builder=builder,
|
||||
model="test-model",
|
||||
)
|
||||
|
||||
# --- 3. Gate + upload ---
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is True
|
||||
)
|
||||
await _upload_final_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
transcript_builder=builder,
|
||||
session_msg_count=4,
|
||||
)
|
||||
|
||||
upload_mock.assert_awaited_once()
|
||||
assert upload_mock.await_args is not None
|
||||
uploaded = upload_mock.await_args.kwargs["content"]
|
||||
assert "follow-up question" in uploaded
|
||||
assert "follow-up answer" in uploaded
|
||||
# Original prior-turn content preserved.
|
||||
assert "user message 0" in uploaded
|
||||
assert "assistant message 1" in uploaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_stale_download_suppresses_upload(self):
|
||||
"""Stale download → covers=False → upload must be skipped."""
|
||||
builder = TranscriptBuilder()
|
||||
# session has 10 msgs but stored transcript only covers 2 → stale.
|
||||
stale = TranscriptDownload(
|
||||
content=_make_transcript_content("user", "assistant"),
|
||||
message_count=2,
|
||||
)
|
||||
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=stale),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=10,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
|
||||
assert covers is False
|
||||
# The caller's gate mirrors the production path.
|
||||
assert (
|
||||
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_anonymous_user_skips_upload(self):
|
||||
"""Anonymous (user_id=None) → upload gate must return False."""
|
||||
builder = TranscriptBuilder()
|
||||
builder.append_user(content="hi")
|
||||
builder.append_assistant(
|
||||
content_blocks=[{"type": "text", "text": "hello"}],
|
||||
model="test-model",
|
||||
stop_reason=STOP_REASON_END_TURN,
|
||||
)
|
||||
|
||||
assert (
|
||||
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
|
||||
is False
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lifecycle_missing_download_still_uploads_new_content(self):
|
||||
"""No prior transcript → covers defaults to True in the service,
|
||||
new turn should upload cleanly."""
|
||||
builder = TranscriptBuilder()
|
||||
upload_mock = AsyncMock(return_value=None)
|
||||
with (
|
||||
patch(
|
||||
"backend.copilot.baseline.service.download_transcript",
|
||||
new=AsyncMock(return_value=None),
|
||||
),
|
||||
patch(
|
||||
"backend.copilot.baseline.service.upload_transcript",
|
||||
new=upload_mock,
|
||||
),
|
||||
):
|
||||
covers = await _load_prior_transcript(
|
||||
user_id="user-1",
|
||||
session_id="session-1",
|
||||
session_msg_count=1,
|
||||
transcript_builder=builder,
|
||||
)
|
||||
# No download: covers is False, so the production path would
|
||||
# skip upload. This protects against overwriting a future
|
||||
# more-complete transcript with a single-turn snapshot.
|
||||
assert covers is False
|
||||
assert (
|
||||
should_upload_transcript(
|
||||
user_id="user-1", transcript_covers_prefix=covers
|
||||
)
|
||||
is False
|
||||
)
|
||||
upload_mock.assert_not_awaited()
|
||||
|
||||
@@ -8,6 +8,14 @@ from pydantic_settings import BaseSettings
|
||||
|
||||
from backend.util.clients import OPENROUTER_BASE_URL
|
||||
|
||||
# Per-request routing mode for a single chat turn.
|
||||
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
|
||||
# - 'extended_thinking': route to the Claude Agent SDK path with the default
|
||||
# (opus) model.
|
||||
# ``None`` means "no override"; the server falls back to the Claude Code
|
||||
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
|
||||
CopilotMode = Literal["fast", "extended_thinking"]
|
||||
|
||||
|
||||
class ChatConfig(BaseSettings):
|
||||
"""Configuration for the chat system."""
|
||||
|
||||
@@ -10,11 +10,10 @@ import os
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
from typing import Literal
|
||||
|
||||
from backend.copilot import stream_registry
|
||||
from backend.copilot.baseline import stream_chat_completion_baseline
|
||||
from backend.copilot.config import ChatConfig
|
||||
from backend.copilot.config import ChatConfig, CopilotMode
|
||||
from backend.copilot.response_model import StreamError
|
||||
from backend.copilot.sdk import service as sdk_service
|
||||
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
|
||||
@@ -34,10 +33,10 @@ logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]"
|
||||
# ============ Mode Routing ============ #
|
||||
|
||||
|
||||
async def _resolve_effective_mode(
|
||||
mode: Literal["fast", "extended_thinking"] | None,
|
||||
async def resolve_effective_mode(
|
||||
mode: CopilotMode | None,
|
||||
user_id: str | None,
|
||||
) -> Literal["fast", "extended_thinking"] | None:
|
||||
) -> CopilotMode | None:
|
||||
"""Strip ``mode`` when the user is not entitled to the toggle.
|
||||
|
||||
The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
|
||||
@@ -57,8 +56,8 @@ async def _resolve_effective_mode(
|
||||
return mode
|
||||
|
||||
|
||||
async def _resolve_use_sdk(
|
||||
mode: Literal["fast", "extended_thinking"] | None,
|
||||
async def resolve_use_sdk_for_mode(
|
||||
mode: CopilotMode | None,
|
||||
user_id: str | None,
|
||||
*,
|
||||
use_claude_code_subscription: bool,
|
||||
@@ -306,10 +305,8 @@ class CoPilotProcessor:
|
||||
else:
|
||||
# Enforce server-side feature-flag gate so unauthorised
|
||||
# users cannot force a mode by crafting the request.
|
||||
effective_mode = await _resolve_effective_mode(
|
||||
entry.mode, entry.user_id
|
||||
)
|
||||
use_sdk = await _resolve_use_sdk(
|
||||
effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
|
||||
use_sdk = await resolve_use_sdk_for_mode(
|
||||
effective_mode,
|
||||
entry.user_id,
|
||||
use_claude_code_subscription=config.use_claude_code_subscription,
|
||||
|
||||
@@ -14,10 +14,13 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.copilot.executor.processor import _resolve_effective_mode, _resolve_use_sdk
|
||||
from backend.copilot.executor.processor import (
|
||||
resolve_effective_mode,
|
||||
resolve_use_sdk_for_mode,
|
||||
)
|
||||
|
||||
|
||||
class TestResolveUseSdk:
|
||||
class TestResolveUseSdkForMode:
|
||||
"""Tests for the per-request mode routing logic."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -28,7 +31,7 @@ class TestResolveUseSdk:
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
await resolve_use_sdk_for_mode(
|
||||
"fast",
|
||||
"user-1",
|
||||
use_claude_code_subscription=True,
|
||||
@@ -45,7 +48,7 @@ class TestResolveUseSdk:
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
await resolve_use_sdk_for_mode(
|
||||
"extended_thinking",
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
@@ -62,7 +65,7 @@ class TestResolveUseSdk:
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=True,
|
||||
@@ -79,7 +82,7 @@ class TestResolveUseSdk:
|
||||
new=AsyncMock(return_value=True),
|
||||
) as flag_mock:
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
@@ -98,7 +101,7 @@ class TestResolveUseSdk:
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
@@ -115,7 +118,7 @@ class TestResolveUseSdk:
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert (
|
||||
await _resolve_use_sdk(
|
||||
await resolve_use_sdk_for_mode(
|
||||
None,
|
||||
"user-1",
|
||||
use_claude_code_subscription=False,
|
||||
@@ -135,7 +138,7 @@ class TestResolveEffectiveMode:
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
) as flag_mock:
|
||||
assert await _resolve_effective_mode(None, "user-1") is None
|
||||
assert await resolve_effective_mode(None, "user-1") is None
|
||||
flag_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -145,8 +148,8 @@ class TestResolveEffectiveMode:
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
assert await _resolve_effective_mode("fast", "user-1") is None
|
||||
assert await _resolve_effective_mode("extended_thinking", "user-1") is None
|
||||
assert await resolve_effective_mode("fast", "user-1") is None
|
||||
assert await resolve_effective_mode("extended_thinking", "user-1") is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mode_preserved_when_flag_enabled(self):
|
||||
@@ -155,9 +158,9 @@ class TestResolveEffectiveMode:
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
assert await _resolve_effective_mode("fast", "user-1") == "fast"
|
||||
assert await resolve_effective_mode("fast", "user-1") == "fast"
|
||||
assert (
|
||||
await _resolve_effective_mode("extended_thinking", "user-1")
|
||||
await resolve_effective_mode("extended_thinking", "user-1")
|
||||
== "extended_thinking"
|
||||
)
|
||||
|
||||
@@ -168,5 +171,5 @@ class TestResolveEffectiveMode:
|
||||
"backend.copilot.executor.processor.is_feature_enabled",
|
||||
new=AsyncMock(return_value=False),
|
||||
) as flag_mock:
|
||||
assert await _resolve_effective_mode("fast", None) is None
|
||||
assert await resolve_effective_mode("fast", None) is None
|
||||
flag_mock.assert_awaited_once()
|
||||
|
||||
@@ -6,10 +6,10 @@ Defines two exchanges and queues following the graph executor pattern:
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from backend.copilot.config import CopilotMode
|
||||
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
|
||||
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
|
||||
|
||||
@@ -157,7 +157,7 @@ class CoPilotExecutionEntry(BaseModel):
|
||||
file_ids: list[str] | None = None
|
||||
"""Workspace file IDs attached to the user's message"""
|
||||
|
||||
mode: Literal["fast", "extended_thinking"] | None = None
|
||||
mode: CopilotMode | None = None
|
||||
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
|
||||
|
||||
|
||||
@@ -179,7 +179,7 @@ async def enqueue_copilot_turn(
|
||||
is_user_message: bool = True,
|
||||
context: dict[str, str] | None = None,
|
||||
file_ids: list[str] | None = None,
|
||||
mode: Literal["fast", "extended_thinking"] | None = None,
|
||||
mode: CopilotMode | None = None,
|
||||
) -> None:
|
||||
"""Enqueue a CoPilot task for processing by the executor service.
|
||||
|
||||
|
||||
@@ -8,20 +8,19 @@ from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
from backend.util.prompt import CompressResult
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _friendly_error_text, _is_prompt_too_long
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _flatten_assistant_content
|
||||
|
||||
@@ -26,18 +26,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_test_transcript as _build_transcript
|
||||
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
from .transcript_builder import TranscriptBuilder
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1405,9 +1404,9 @@ class TestStreamChatCompletionRetryIntegration:
|
||||
events.append(event)
|
||||
|
||||
# Should NOT retry — only 1 attempt for auth errors
|
||||
assert attempt_count[0] == 1, (
|
||||
f"Expected 1 attempt (no retry for auth error), " f"got {attempt_count[0]}"
|
||||
)
|
||||
assert (
|
||||
attempt_count[0] == 1
|
||||
), f"Expected 1 attempt (no retry for auth error), got {attempt_count[0]}"
|
||||
errors = [e for e in events if isinstance(e, StreamError)]
|
||||
assert errors, "Expected StreamError"
|
||||
assert errors[0].code == "sdk_stream_error"
|
||||
|
||||
@@ -27,20 +27,19 @@ from backend.copilot.response_model import (
|
||||
StreamTextDelta,
|
||||
StreamTextStart,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_messages_to_transcript,
|
||||
_rechain_tail,
|
||||
_transcript_to_messages,
|
||||
compact_transcript,
|
||||
validate_transcript,
|
||||
)
|
||||
from backend.util import json
|
||||
|
||||
from .conftest import build_structured_transcript
|
||||
from .response_adapter import SDKResponseAdapter
|
||||
from .service import _format_sdk_content_blocks
|
||||
from .transcript import compact_transcript, validate_transcript
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures: realistic thinking block content
|
||||
|
||||
@@ -1,28 +1,19 @@
|
||||
"""Re-export from shared ``backend.copilot.transcript`` for backward compat.
|
||||
"""Re-export public API from shared ``backend.copilot.transcript``.
|
||||
|
||||
The canonical implementation now lives at ``backend.copilot.transcript``
|
||||
so both the SDK and baseline paths can import without cross-package
|
||||
dependencies. All symbols are re-exported here so existing ``from
|
||||
dependencies. Public symbols are re-exported here so existing ``from
|
||||
.transcript import ...`` statements within the ``sdk`` package continue
|
||||
to work without modification.
|
||||
"""
|
||||
|
||||
from backend.copilot.transcript import (
|
||||
_MAX_PROJECT_DIRS_TO_SWEEP,
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
COMPACT_MSG_ID_PREFIX,
|
||||
ENTRY_TYPE_MESSAGE,
|
||||
STOP_REASON_END_TURN,
|
||||
STRIPPABLE_TYPES,
|
||||
TRANSCRIPT_STORAGE_PREFIX,
|
||||
TranscriptDownload,
|
||||
_find_last_assistant_entry,
|
||||
_flatten_assistant_content,
|
||||
_flatten_tool_result_content,
|
||||
_messages_to_transcript,
|
||||
_rechain_tail,
|
||||
_run_compression,
|
||||
_transcript_to_messages,
|
||||
cleanup_stale_project_dirs,
|
||||
compact_transcript,
|
||||
delete_transcript,
|
||||
@@ -43,15 +34,6 @@ __all__ = [
|
||||
"STRIPPABLE_TYPES",
|
||||
"TRANSCRIPT_STORAGE_PREFIX",
|
||||
"TranscriptDownload",
|
||||
"_MAX_PROJECT_DIRS_TO_SWEEP",
|
||||
"_STALE_PROJECT_DIR_SECONDS",
|
||||
"_find_last_assistant_entry",
|
||||
"_flatten_assistant_content",
|
||||
"_flatten_tool_result_content",
|
||||
"_messages_to_transcript",
|
||||
"_rechain_tail",
|
||||
"_run_compression",
|
||||
"_transcript_to_messages",
|
||||
"cleanup_stale_project_dirs",
|
||||
"compact_transcript",
|
||||
"delete_transcript",
|
||||
|
||||
@@ -850,7 +850,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_client_uses_truncation(self):
|
||||
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated"}]
|
||||
@@ -885,7 +885,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_success_returns_llm_result(self):
|
||||
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
llm_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "LLM summary"}]
|
||||
@@ -916,7 +916,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_failure_falls_back_to_truncation(self):
|
||||
"""Path (c): LLM call raises → truncation fallback used instead."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated fallback"}]
|
||||
@@ -953,7 +953,7 @@ class TestRunCompression:
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_timeout_falls_back_to_truncation(self):
|
||||
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
|
||||
from .transcript import _run_compression
|
||||
from backend.copilot.transcript import _run_compression
|
||||
|
||||
truncation_result = self._make_compress_result(
|
||||
True, [{"role": "user", "content": "truncated after timeout"}]
|
||||
@@ -1007,7 +1007,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories matching copilot pattern older than threshold are removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1039,7 +1039,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
|
||||
"""Directories not matching copilot pattern are left alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
@@ -1062,7 +1062,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
|
||||
"""A directory exactly at the TTL boundary should NOT be removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1088,7 +1088,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
|
||||
"""Regular files matching the copilot pattern are not removed."""
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1114,7 +1114,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
|
||||
"""If the projects base directory doesn't exist, return 0 gracefully."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
nonexistent = str(tmp_path / "does-not-exist" / "projects")
|
||||
monkeypatch.setattr(
|
||||
@@ -1129,7 +1129,7 @@ class TestCleanupStaleProjectDirs:
|
||||
"""When encoded_cwd is supplied only that directory is swept."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
@@ -1160,7 +1160,7 @@ class TestCleanupStaleProjectDirs:
|
||||
|
||||
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
|
||||
"""Scoped sweep leaves a fresh directory alone."""
|
||||
from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
|
||||
from backend.copilot.transcript import cleanup_stale_project_dirs
|
||||
|
||||
projects_dir = tmp_path / "projects"
|
||||
projects_dir.mkdir()
|
||||
@@ -1181,7 +1181,7 @@ class TestCleanupStaleProjectDirs:
|
||||
"""Scoped sweep refuses to remove a non-copilot directory."""
|
||||
import time
|
||||
|
||||
from backend.copilot.sdk.transcript import (
|
||||
from backend.copilot.transcript import (
|
||||
_STALE_PROJECT_DIR_SECONDS,
|
||||
cleanup_stale_project_dirs,
|
||||
)
|
||||
|
||||
@@ -658,6 +658,7 @@ async def upload_transcript(
|
||||
content: str,
|
||||
message_count: int = 0,
|
||||
log_prefix: str = "[Transcript]",
|
||||
skip_strip: bool = False,
|
||||
) -> None:
|
||||
"""Strip progress entries and stale thinking blocks, then upload transcript.
|
||||
|
||||
@@ -670,11 +671,18 @@ async def upload_transcript(
|
||||
Args:
|
||||
content: Complete JSONL transcript (from TranscriptBuilder).
|
||||
message_count: ``len(session.messages)`` at upload time.
|
||||
skip_strip: When ``True``, skip the strip + re-validate pass.
|
||||
Safe for builder-generated content (baseline path) which
|
||||
never emits progress entries or stale thinking blocks.
|
||||
"""
|
||||
# Strip metadata entries and stale thinking blocks in a single parse pass.
|
||||
# SDK-built transcripts shouldn't have progress entries, but strip for safety.
|
||||
stripped = strip_for_upload(content)
|
||||
if not validate_transcript(stripped):
|
||||
if skip_strip:
|
||||
# Caller guarantees the content is already clean and valid.
|
||||
stripped = content
|
||||
else:
|
||||
# Strip metadata entries and stale thinking blocks in a single parse.
|
||||
# SDK-built transcripts may have progress entries; strip for safety.
|
||||
stripped = strip_for_upload(content)
|
||||
if not skip_strip and not validate_transcript(stripped):
|
||||
# Log entry types for debugging — helps identify why validation failed
|
||||
entry_types = [
|
||||
json.loads(line, fallback={"type": "INVALID_JSON"}).get("type", "?")
|
||||
@@ -695,27 +703,34 @@ async def upload_transcript(
|
||||
storage = await get_workspace_storage()
|
||||
wid, fid, fname = _storage_path_parts(user_id, session_id)
|
||||
encoded = stripped.encode("utf-8")
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
meta_encoded = json.dumps(meta).encode("utf-8")
|
||||
|
||||
await storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
)
|
||||
|
||||
# Update metadata so message_count stays current. The gap-fill logic
|
||||
# in _build_query_message relies on it to avoid re-compressing messages.
|
||||
try:
|
||||
meta = {"message_count": message_count, "uploaded_at": time.time()}
|
||||
mwid, mfid, mfname = _meta_storage_path_parts(user_id, session_id)
|
||||
await storage.store(
|
||||
# Transcript + metadata are independent objects at different keys, so
|
||||
# write them concurrently. ``return_exceptions`` keeps a metadata
|
||||
# failure from sinking the transcript write.
|
||||
transcript_result, metadata_result = await asyncio.gather(
|
||||
storage.store(
|
||||
workspace_id=wid,
|
||||
file_id=fid,
|
||||
filename=fname,
|
||||
content=encoded,
|
||||
),
|
||||
storage.store(
|
||||
workspace_id=mwid,
|
||||
file_id=mfid,
|
||||
filename=mfname,
|
||||
content=json.dumps(meta).encode("utf-8"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, e)
|
||||
content=meta_encoded,
|
||||
),
|
||||
return_exceptions=True,
|
||||
)
|
||||
if isinstance(transcript_result, BaseException):
|
||||
raise transcript_result
|
||||
if isinstance(metadata_result, BaseException):
|
||||
# Metadata is best-effort — the gap-fill logic in
|
||||
# _build_query_message tolerates a missing metadata file.
|
||||
logger.warning("%s Failed to write metadata: %s", log_prefix, metadata_result)
|
||||
|
||||
logger.info(
|
||||
"%s Uploaded %dB (stripped from %dB, msg_count=%d)",
|
||||
|
||||
@@ -9,10 +9,10 @@ import { toast } from "@/components/molecules/Toast/use-toast";
|
||||
import { InputGroup } from "@/components/ui/input-group";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
|
||||
import { Brain, Lightning } from "@phosphor-icons/react";
|
||||
import { ChangeEvent, useEffect, useState } from "react";
|
||||
import { AttachmentMenu } from "./components/AttachmentMenu";
|
||||
import { FileChips } from "./components/FileChips";
|
||||
import { ModeToggleButton } from "./components/ModeToggleButton";
|
||||
import { RecordingButton } from "./components/RecordingButton";
|
||||
import { RecordingIndicator } from "./components/RecordingIndicator";
|
||||
import { useCopilotUIStore } from "../../store";
|
||||
@@ -47,7 +47,7 @@ export function ChatInput({
|
||||
onDroppedFilesConsumed,
|
||||
}: Props) {
|
||||
const { copilotMode, setCopilotMode } = useCopilotUIStore();
|
||||
const isFastModeEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
|
||||
const [files, setFiles] = useState<File[]>([]);
|
||||
|
||||
function handleToggleMode() {
|
||||
@@ -179,43 +179,12 @@ export function ChatInput({
|
||||
onFilesSelected={handleFilesSelected}
|
||||
disabled={isBusy}
|
||||
/>
|
||||
{isFastModeEnabled && (
|
||||
<button
|
||||
type="button"
|
||||
role="switch"
|
||||
aria-checked={copilotMode === "fast"}
|
||||
disabled={isStreaming}
|
||||
onClick={handleToggleMode}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
copilotMode === "extended_thinking"
|
||||
? "bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-300"
|
||||
: "bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-300",
|
||||
isStreaming && "cursor-not-allowed opacity-50",
|
||||
)}
|
||||
aria-label={
|
||||
copilotMode === "extended_thinking"
|
||||
? "Switch to Fast mode"
|
||||
: "Switch to Extended Thinking mode"
|
||||
}
|
||||
title={
|
||||
copilotMode === "extended_thinking"
|
||||
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
|
||||
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
|
||||
}
|
||||
>
|
||||
{copilotMode === "extended_thinking" ? (
|
||||
<>
|
||||
<Brain size={14} />
|
||||
Thinking
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Lightning size={14} />
|
||||
Fast
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
{showModeToggle && (
|
||||
<ModeToggleButton
|
||||
mode={copilotMode}
|
||||
isStreaming={isStreaming}
|
||||
onToggle={handleToggleMode}
|
||||
/>
|
||||
)}
|
||||
</PromptInputTools>
|
||||
|
||||
|
||||
@@ -159,20 +159,29 @@ describe("ChatInput mode toggle", () => {
|
||||
expect(button.hasAttribute("disabled")).toBe(true);
|
||||
});
|
||||
|
||||
it("exposes role='switch' with aria-checked", () => {
|
||||
it("exposes aria-pressed=true in extended_thinking mode", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotMode = "extended_thinking";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
const button = screen.getByRole("switch");
|
||||
expect(button.getAttribute("aria-checked")).toBe("false");
|
||||
const button = screen.getByLabelText(/switch to fast mode/i);
|
||||
expect(button.getAttribute("aria-pressed")).toBe("true");
|
||||
});
|
||||
|
||||
it("sets aria-checked=true in fast mode", () => {
|
||||
it("sets aria-pressed=false in fast mode", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotMode = "fast";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
const button = screen.getByRole("switch");
|
||||
expect(button.getAttribute("aria-checked")).toBe("true");
|
||||
const button = screen.getByLabelText(/switch to extended thinking/i);
|
||||
expect(button.getAttribute("aria-pressed")).toBe("false");
|
||||
});
|
||||
|
||||
it("uses streaming-specific tooltip when disabled", () => {
|
||||
mockFlagValue = true;
|
||||
render(<ChatInput onSend={mockOnSend} isStreaming />);
|
||||
const button = screen.getByLabelText(/switch to fast mode/i);
|
||||
expect(button.getAttribute("title")).toBe(
|
||||
"Mode cannot be changed while streaming",
|
||||
);
|
||||
});
|
||||
|
||||
it("shows a toast when the user toggles mode", async () => {
|
||||
@@ -180,7 +189,7 @@ describe("ChatInput mode toggle", () => {
|
||||
mockFlagValue = true;
|
||||
mockCopilotMode = "extended_thinking";
|
||||
render(<ChatInput onSend={mockOnSend} />);
|
||||
fireEvent.click(screen.getByRole("switch"));
|
||||
fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
|
||||
expect(toast).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
title: expect.stringMatching(/switched to fast mode/i),
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Brain, Lightning } from "@phosphor-icons/react";
|
||||
|
||||
type CopilotMode = "extended_thinking" | "fast";
|
||||
|
||||
interface Props {
|
||||
mode: CopilotMode;
|
||||
isStreaming: boolean;
|
||||
onToggle: () => void;
|
||||
}
|
||||
|
||||
export function ModeToggleButton({ mode, isStreaming, onToggle }: Props) {
|
||||
const isExtended = mode === "extended_thinking";
|
||||
return (
|
||||
<button
|
||||
type="button"
|
||||
aria-pressed={isExtended}
|
||||
disabled={isStreaming}
|
||||
onClick={onToggle}
|
||||
className={cn(
|
||||
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
|
||||
isExtended
|
||||
? "bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-300"
|
||||
: "bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-300",
|
||||
isStreaming && "cursor-not-allowed opacity-50",
|
||||
)}
|
||||
aria-label={
|
||||
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
|
||||
}
|
||||
title={
|
||||
isStreaming
|
||||
? "Mode cannot be changed while streaming"
|
||||
: isExtended
|
||||
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
|
||||
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
|
||||
}
|
||||
>
|
||||
{isExtended ? (
|
||||
<>
|
||||
<Brain size={14} />
|
||||
Thinking
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Lightning size={14} />
|
||||
Fast
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
Reference in New Issue
Block a user