Compare commits

...

58 Commits

Author SHA1 Message Date
Test Runner
bada09ff7e test: add UI verification screenshots for PR #12623 mode toggle 2026-04-05 14:41:48 +02:00
Zamil Majdy
2e7bbd879d test: add E2E test screenshots for PR #12623 2026-04-05 14:15:35 +02:00
Zamil Majdy
cdb2699477 fix(frontend): guard copilotMode storage.get behind isClient
Missed this initializer when porting the SSR guards — copilotMode was
calling storage.get() unconditionally while its neighbours
(completedSessionIDs, isSoundEnabled, isNotificationsEnabled) already
check isClient first. Flagged by Sentry (MEDIUM): storage.get() is
designed to throw on the server, so every SSR pass was generating
noise in error tracking.
2026-04-05 13:19:13 +02:00
Zamil Majdy
7d061a0de0 fix(frontend): surface a toast when send is suppressed during reconnect
shouldSuppressDuplicateSend returned a boolean for two very different
cases (reconnect-in-flight vs duplicate text echo), so the caller had
to silently drop both. Flagged by Sentry (MEDIUM): the React state
driving the disabled-input UI lags the synchronous ref by one render,
so a user can click send against a still-enabled input and have the
message vanish with no feedback.

Split into getSendSuppressionReason() returning "reconnecting" |
"duplicate" | null. useCopilotStream shows a toast for reconnect
("Wait for the connection to resume before sending") and keeps the
silent drop for duplicate-text echoes from the session. Boolean wrapper
retained as deprecated for back-compat.
2026-04-05 13:12:14 +02:00
Zamil Majdy
fe660d9aaf fix(copilot): track transcript upload as background task to prevent orphaned coroutine on timeout
asyncio.wait_for(asyncio.shield(coroutine), ...) leaks the inner
coroutine when the timeout fires because shield wraps a raw coroutine
that has no task to register against _background_tasks. Flagged by
Sentry (MEDIUM): under sustained GCS latency the orphaned tasks
accumulate and hold storage connections.

Use asyncio.create_task() to get a Task we can track, add it to
_background_tasks, and shield *that* task inside wait_for. On timeout,
the task keeps running and is cleaned up via the done_callback.
2026-04-05 13:04:43 +02:00
Zamil Majdy
84c3dd7000 fix(copilot): address remaining PR #12623 review items
Blockers:
- Rename `_resolve_use_sdk`/`_resolve_effective_mode` to
  `resolve_use_sdk_for_mode`/`resolve_effective_mode` in processor.py
  so the mode-routing logic is importable; tests now exercise the
  production functions directly instead of a local copy.
- Extract `is_transcript_stale`/`should_upload_transcript` helpers in
  baseline/service.py and cover them with direct unit tests, replacing
  the duplicated boolean expressions in transcript_integration_test.

Should-fix:
- Add `TestTranscriptLifecycle` that drives the download -> validate ->
  build -> upload flow end-to-end with mocked storage.
- Avoid the triple JSONL parse on upload: rely on the transcript
  builder's `last_entry_type == "assistant"` invariant and thread
  `skip_strip=True` through `upload_transcript` for builder-generated
  content.
- Run `_load_prior_transcript` and `_build_system_prompt` concurrently
  via `asyncio.gather` on the request critical path.
- Add a compression round-trip test proving `tool_calls` and
  `tool_call_id` survive `_compress_session_messages`.
- Extract the inline mode-toggle JSX into a dedicated
  `ModeToggleButton` sub-component.

Nice-to-have:
- Introduce `CopilotMode` type alias in `copilot/config.py` and reuse
  it across backend routes, executor utils, processor, and baseline
  service.
- Bound the shielded transcript upload with `asyncio.wait_for(..., 30)`
  so a hung storage backend cannot block response completion.
- Trim the 7 private re-exports from `sdk/transcript.py` shim; tests
  that needed the privates now import them from the canonical
  `backend.copilot.transcript`.
- Upload the transcript and its metadata sidecar concurrently via
  `asyncio.gather` with `return_exceptions=True`.

Nits:
- Rename `isFastModeEnabled` to `showModeToggle`.
- Narrow `except Exception` to `(ValueError, TypeError,
  orjson.JSONDecodeError)` around tool-call argument parsing.
- Replace `role=\"switch\" aria-checked` with `aria-pressed` on the
  toggle button (`aria-pressed` is the correct ARIA state for a toggle button).
- Surface a streaming-specific tooltip when the toggle is disabled.
2026-04-05 12:41:14 +02:00
Zamil Majdy
4e0d6bbde5 fix(copilot): address all review items on PR #12623
Blockers:
- B1: use config.model (opus) for baseline unless mode='fast' explicitly
  downgrades — prevents silent quality drop for all baseline users
- B2: extract _resolve_use_sdk / _resolve_effective_mode to production
  code in processor.py; tests now exercise real routing logic
- B3: add real unit tests covering transcript download/validate/load/
  append/backfill/upload round-trip via new service helpers

Should-fix:
- S1: server-side CHAT_MODE_OPTION flag gate in processor blocks
  unauthorised users from bypassing the UI toggle via crafted requests
- S2: frontend gates copilotMode behind the CHAT_MODE_OPTION flag in
  useCopilotPage so stale localStorage values aren't sent when off
- S3: Pydantic validation tests for StreamChatRequest.mode literal
- S4: extract _load_prior_transcript / _upload_final_transcript helpers
  and split _baseline_conversation_updater into message-mutation and
  transcript-recording concerns

Nice-to-have + Nits:
- N1: parallelize GCS transcript content + metadata downloads via gather
- N3: add role='switch' and aria-checked to mode toggle button
- N5: toast notification when user toggles mode
- Nit1: drop export on ChatInput Props
- Nit2: inline dismissRateLimit, drop unused useCallback import
- Nit3: replace 'end_turn'/'tool_use' magic strings with constants
- Nit4: log malformed tool_call JSON parse errors at debug level
2026-04-05 11:34:55 +02:00
Zamil Majdy
d9c59e3616 test(backend): add unit tests for transcript, rate_limit, and executor utils
Cover pure helper functions in the copilot modules to bring patch
coverage above the 80% codecov threshold:

- transcript_test.py: _sanitize_id, _flatten_assistant_content,
  _flatten_tool_result_content, _transcript_to_messages,
  _messages_to_transcript, _find_last_assistant_entry, _rechain_tail,
  strip_for_upload, validate_transcript, storage path helpers
- rate_limit_test.py: _daily_key, _weekly_key, _daily_reset_time,
  _weekly_reset_time, acquire_reset_lock, release_reset_lock,
  get_daily_reset_count, increment_daily_reset_count, reset_user_usage
- executor/utils_test.py: CoPilotExecutionEntry, CancelCoPilotEvent,
  create_copilot_queue_config, CoPilotLogMetadata
2026-04-03 23:45:44 +02:00
Zamil Majdy
457c2f4dca fix(copilot): address 5 should-fix review items from PR #12623
1. Combine strip_progress_entries + strip_stale_thinking_blocks into a
   single-parse strip_for_upload() to eliminate redundant JSONL parsing
   on transcript upload (3 parse passes -> 1).

2. Add baseline transcript integration tests covering stale detection,
   transcript_covers_prefix gating, partial backfill, and upload guards.

3. Add mode routing unit tests verifying fast->baseline,
   extended_thinking->SDK, and None->feature-flag/config fallback.

4. Extract hardcoded model/mode strings in fixer.py to named constants
   (ORCHESTRATOR_DEFAULT_MODEL, ORCHESTRATOR_DEFAULT_EXECUTION_MODE).

5. Add min-h-11 min-w-11 (44px) to copilot mode toggle button for
   WCAG-compliant touch target sizing.
2026-04-03 22:48:29 +02:00
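
Item 1 above can be sketched as a single-pass filter; the JSONL entry shapes below are assumptions for illustration, not the project's actual transcript schema:

```python
import json


def strip_for_upload(jsonl: str) -> str:
    """One-pass sketch of the combined strip.

    Instead of parse -> strip progress -> serialize -> parse -> strip
    thinking -> serialize, each line is parsed exactly once and both
    filters are applied before writing it back out.
    """
    out: list[str] = []
    for line in jsonl.splitlines():
        if not line.strip():
            continue
        entry = json.loads(line)
        # Filter 1: drop transient progress entries entirely.
        if entry.get("type") == "progress":
            continue
        # Filter 2: drop stale thinking blocks left over from SDK sessions.
        message = entry.get("message")
        if isinstance(message, dict) and isinstance(message.get("content"), list):
            message["content"] = [
                block
                for block in message["content"]
                if not (isinstance(block, dict) and block.get("type") == "thinking")
            ]
        out.append(json.dumps(entry))
    return "\n".join(out)
```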
Zamil Majdy
86c64ddb8e fix: resolve merge conflict with dev in baseline/service.py 2026-04-03 20:11:42 +02:00
Zamil Majdy
3a46fde82b fix(blocks): apply OrchestratorBlock defaults before AI model fixer
Move fix_orchestrator_blocks() before fix_ai_model_parameter() in
apply_all_fixes so the orchestrator-specific model (claude-opus-4-6) is
set first. Previously, fix_ai_model_parameter ran first and set the
generic default (gpt-4o) on OrchestratorBlock nodes; the later
fix_orchestrator_blocks then skipped the model field because it was
already present.
2026-04-03 16:40:43 +02:00
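
The ordering bug reproduces with a minimal pair of fill-if-missing fixers. The dict-shaped node and `setdefault` mechanics are assumptions for illustration; the function names mirror the real ones:

```python
from typing import Any


def fix_ai_model_parameter(node: dict[str, Any]) -> None:
    # Generic fixer: only fills the field when it is missing.
    node.setdefault("model", "gpt-4o")


def fix_orchestrator_blocks(node: dict[str, Any]) -> None:
    # Orchestrator-specific fixer: also fills only missing fields.
    if node.get("block") == "OrchestratorBlock":
        node.setdefault("model", "claude-opus-4-6")


def apply_all_fixes(node: dict[str, Any]) -> dict[str, Any]:
    # Specific before generic: because both fixers skip already-present
    # fields, whichever runs first wins on OrchestratorBlock nodes.
    fix_orchestrator_blocks(node)
    fix_ai_model_parameter(node)
    return node
```

Reversing the two calls reproduces the original bug: the generic default lands first and the orchestrator fixer then skips the already-present field.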
Zamil Majdy
25a216885f fix(backend): sort imports in sdk/service.py (isort) 2026-04-03 16:30:23 +02:00
Zamil Majdy
e621ecb824 style: remove unnecessary blank lines around COPILOT_MODE entries 2026-04-03 16:28:07 +02:00
Zamil Majdy
c82c3c34ef test(frontend): add coverage for copilot mode toggle and dedup helpers 2026-04-03 16:16:47 +02:00
Zamil Majdy
97268bb22c merge: resolve import conflict (keep both transcript + rate_limit imports) 2026-04-03 16:01:12 +02:00
Zamil Majdy
571a4c9540 style(frontend): format copilot store.ts with prettier 2026-04-03 15:55:38 +02:00
Zamil Majdy
6c7cdea55c fix(backend): format service_unit_test.py (isort + black) 2026-04-03 15:54:39 +02:00
Zamil Majdy
374694f64a ci: retrigger codecov evaluation 2026-04-03 15:08:47 +02:00
Zamil Majdy
7062ea7244 test(frontend): add coverage for completedSessions persistence in copilot store
Exercise the persistCompletedSessions code paths (add, clear single,
clear all) and verify COPILOT_COMPLETED_SESSIONS is cleaned in
clearCopilotLocalData, improving codecov/patch coverage for #12623.
2026-04-03 14:45:29 +02:00
Zamil Majdy
248293a1de style: format store.test.ts with prettier 2026-04-03 14:18:33 +02:00
Zamil Majdy
7b15c4d350 merge: resolve conflicts (keep both COPILOT_MODE and COPILOT_COMPLETED_SESSIONS) 2026-04-03 14:12:26 +02:00
Zamil Majdy
816736826e merge: resolve conflict in feature flags 2026-04-03 13:17:42 +02:00
Zamil Majdy
b753cb7d0b fix(copilot): emit empty-string parentUuid and backfill partial transcript on error
- TranscriptBuilder now uses parentUuid="" for root entries (matching
  _messages_to_transcript canonical format) instead of None which was
  dropped by exclude_none serialization
- Backfill partial assistant text into transcript before upload when
  stream aborts mid-round, preventing transcript divergence on
  mode-switch after failed turns
2026-04-03 13:13:05 +02:00
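
Why `None` is the wrong sentinel here can be shown with a minimal exclude_none-style serializer (a stdlib stand-in for the model serialization the commit refers to):

```python
import json
from typing import Any


def dump_exclude_none(entry: dict[str, Any]) -> str:
    """Minimal stand-in for exclude_none serialization: keys whose value
    is None are dropped from the output entirely."""
    return json.dumps({k: v for k, v in entry.items() if v is not None})


# A root entry with parentUuid=None loses the field on serialization,
# while parentUuid="" survives and matches the canonical transcript format.
root_with_none = {"uuid": "u1", "parentUuid": None, "type": "user"}
root_with_empty = {"uuid": "u1", "parentUuid": "", "type": "user"}
```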
Zamil Majdy
60b6101f25 fix(frontend): prevent duplicate message submission via ref guard and reconnect dedup
Bug 1: Add synchronous isSubmittingRef guard in useChatInput to prevent
double-submit when user double-taps Enter or network is slow. The existing
isSending state is asynchronous and leaves a gap between calls.

Bug 2: Wrap sendMessage in useCopilotStream to block POSTs during active
reconnect cycles and skip re-sending messages the session already contains.
The reconnect path exclusively uses resumeStream (GET), so any sendMessage
call during reconnect would be a duplicate.
2026-04-03 13:03:35 +02:00
Zamil Majdy
45fe984d66 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-mode-toggle 2026-04-03 12:57:04 +02:00
Zamil Majdy
da10cf6f47 refactor: rename flag to CHAT_MODE_OPTION 2026-04-03 12:30:17 +02:00
Zamil Majdy
5feeaaf39a refactor: rename flag to COPILOT_FAST_MODE_OPTION 2026-04-03 12:29:43 +02:00
Zamil Majdy
09706ca8d2 feat(frontend): gate mode toggle behind COPILOT_FAST_MODE feature flag 2026-04-03 12:29:02 +02:00
Zamil Majdy
18dd829a89 feat(frontend): hide mode toggle button until baseline is fully tested
Backend bug fixes (duplicate message, tool_calls compression, stale
thinking blocks, stop_reason) are ready to ship. The toggle UI is
hidden with the HTML hidden attribute — remove it to re-enable.
2026-04-03 12:26:20 +02:00
Zamil Majdy
538e8619da Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-mode-toggle 2026-04-02 15:55:54 +02:00
Zamil Majdy
ad77e881c9 fix(backend/copilot): strip stale thinking blocks in upload_transcript
Add strip_stale_thinking_blocks() call to upload_transcript() alongside
the existing strip_progress_entries(). When a user switches from SDK
(extended_thinking) to baseline (fast) mode and back, the re-downloaded
transcript may contain stale thinking blocks from the SDK session.
Without stripping, these blocks consume significant tokens and trigger
unnecessary compaction cycles.
2026-04-02 14:50:50 +02:00
Zamil Majdy
49c7ab4011 fix(backend/copilot): set correct stop_reason in baseline transcript entries
Set stop_reason="tool_use" for assistant messages with tool calls and
stop_reason="end_turn" for final text responses. This ensures the
transcript format is compatible with the SDK's --resume flag when a
user switches from fast to extended_thinking mode mid-conversation.
2026-04-02 14:39:47 +02:00
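
The stop_reason rule can be sketched as a small helper. The content-block shape is an assumption; the constant names match the `STOP_REASON_*` constants introduced later in this PR:

```python
from typing import Any

STOP_REASON_TOOL_USE = "tool_use"
STOP_REASON_END_TURN = "end_turn"


def assistant_stop_reason(content_blocks: list[dict[str, Any]]) -> str:
    """Pick the stop_reason a resume-compatible transcript expects:
    a turn that requests tools records 'tool_use'; a final text-only
    reply records 'end_turn'."""
    has_tool_use = any(b.get("type") == "tool_use" for b in content_blocks)
    return STOP_REASON_TOOL_USE if has_tool_use else STOP_REASON_END_TURN
```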
Zamil Majdy
927c6e7db0 fix(frontend): add aria-label and disabled state to mode toggle button
- Add aria-label for screen reader accessibility
- Disable button during streaming to prevent confusing mode switches mid-turn
- Add opacity/cursor styling when disabled
2026-04-02 14:38:00 +02:00
Zamil Majdy
114f91ff53 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-mode-toggle 2026-04-02 14:32:47 +02:00
Zamil Majdy
697b15ce81 fix(backend/copilot): always append user message to transcript on retries
When a duplicate user message was suppressed (e.g. network retry), the
user turn was not added to the transcript builder while the assistant
reply still was, creating a malformed assistant-after-assistant structure
that broke conversation resumption. Now the user message is always
appended to the transcript when present and is_user_message, regardless
of whether the session-level dedup suppressed it.
2026-04-02 06:18:26 +02:00
Zamil Majdy
7f986bc565 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-mode-toggle 2026-04-01 20:51:50 +02:00
Zamil Majdy
6f679a0e32 fix(backend/copilot): preserve tool_calls and tool_call_id through context compression 2026-04-01 20:27:33 +02:00
Zamil Majdy
05495d8478 chore: remove accidentally committed test screenshots 2026-04-01 19:16:18 +02:00
Zamil Majdy
1a645e1e37 fix(backend/copilot): align _flatten_assistant_content with master (drop tool_use blocks)
The merge conflict resolution copied the pre-#12625 version of
_flatten_assistant_content which converts tool_use blocks to
[tool_use: name] placeholders. Master's #12625 changed this to
drop tool_use blocks entirely to prevent the model from mimicking
them as plain text. Align the canonical transcript.py with master.
2026-04-01 18:14:59 +02:00
Zamil Majdy
fd1d706315 fix(frontend): replace lucide-react icons with Phosphor equivalents in mode toggle
Use Brain and Lightning from @phosphor-icons/react instead of Brain and
Zap from lucide-react to comply with the project icon guidelines.
2026-04-01 18:04:44 +02:00
Zamil Majdy
89264091ad fix(backend/copilot): add missing strip_stale_thinking_blocks to canonical transcript module
The merge conflict resolution moved transcript.py to a re-export wrapper
but failed to copy strip_stale_thinking_blocks into the canonical
backend.copilot.transcript module. This caused an ImportError in
transcript_test.py which imports from the sdk wrapper.
2026-04-01 18:00:41 +02:00
Zamil Majdy
14ad37b0c7 fix: resolve merge conflict in transcript.py re-export module 2026-04-01 17:53:57 +02:00
Zamil Majdy
389cd28879 test: add round 3 E2E screenshots for PR #12623 2026-04-01 17:01:10 +02:00
Zamil Majdy
f0a3afda7d Add test screenshots for PR #12623 2026-04-01 08:49:33 +02:00
Zamil Majdy
9ffecbac02 fix(backend/copilot): add missing mode param to enqueue_copilot_turn docstring 2026-04-01 08:03:35 +02:00
Zamil Majdy
c2709fbc28 Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-mode-toggle 2026-04-01 06:14:49 +02:00
Zamil Majdy
3adbaacc0e Merge branch 'dev' of github.com:Significant-Gravitas/AutoGPT into feat/copilot-mode-toggle 2026-03-31 19:07:34 +02:00
Zamil Majdy
56e0b568a4 fix(backend): update tests for transcript module move and new fixer defaults
- Update patch targets in transcript tests from
  backend.copilot.sdk.transcript to backend.copilot.transcript since
  the re-export shim only re-exports public symbols; private names
  like _projects_base and get_openai_client live in the canonical module.
- Update orchestrator fixer test assertions to account for 2 new
  _SDM_DEFAULTS (execution_mode, model) and add execution_mode to the
  E2E test's mock block inputSchema.
2026-03-31 18:26:18 +02:00
Zamil Majdy
0b0777ac87 fix(copilot): update fix_orchestrator_blocks docstring to list all 6 defaults
The docstring only listed 4 defaults but _SDM_DEFAULTS has 6 entries
including execution_mode and model. Updated to reflect the actual behavior.
2026-03-31 17:49:54 +02:00
Zamil Majdy
698b1599cb fix(copilot): reject stale transcripts in baseline service 2026-03-31 17:41:06 +02:00
Zamil Majdy
a2f94f08d9 fix(copilot): address review comments round 3 2026-03-31 17:35:11 +02:00
Zamil Majdy
0c6f20f728 feat(copilot): set extended_thinking + Opus as OrchestratorBlock defaults
Update the agent generator fixer defaults so generated agents inherit
the copilot's default reasoning mode (extended_thinking with Opus).
User-set values are preserved — the fixer only fills in missing fields.
2026-03-31 17:23:06 +02:00
Zamil Majdy
d100b2515b fix(copilot): include tool messages in baseline conversation context
The baseline was only including user/assistant text messages when
building the OpenAI message list, dropping all tool_calls and tool
results. This meant the model had no memory of previous tool
invocations or their outputs in multi-turn conversations.

Now includes assistant messages with tool_calls and tool-role messages
with tool_call_id, giving the model full conversation context.
2026-03-31 17:12:37 +02:00
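
A sketch of the fixed context builder, assuming plain-dict messages (the real code works with ChatMessage objects):

```python
from typing import Any


def build_llm_messages(
    session_messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
    """Carry tool_calls and tool results into the LLM payload.

    Filtering to user/assistant text only erases every tool invocation
    from the model's view. Keeping `tool_calls` on assistant messages and
    `tool_call_id` on tool-role messages lets the API pair each result
    with the call that produced it across turns.
    """
    out: list[dict[str, Any]] = []
    for m in session_messages:
        msg: dict[str, Any] = {"role": m["role"]}
        if m.get("content") is not None:
            msg["content"] = m["content"]
        if m["role"] == "assistant" and m.get("tool_calls"):
            msg["tool_calls"] = m["tool_calls"]
        if m["role"] == "tool" and m.get("tool_call_id"):
            msg["tool_call_id"] = m["tool_call_id"]
        out.append(msg)
    return out
```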
Zamil Majdy
14113f96a9 feat(copilot): use Sonnet for fast mode, Opus for extended thinking
Add `fast_model` config field (default: anthropic/claude-sonnet-4) so
fast mode uses a faster/cheaper model while extended thinking keeps
using Opus. The baseline service now uses config.fast_model for all
LLM calls.
2026-03-31 17:07:04 +02:00
Zamil Majdy
ee40a4b9a8 refactor(copilot): move transcript modules to shared location 2026-03-31 16:29:48 +02:00
Zamil Majdy
0008cafc3b fix(copilot): fix transcript ordering and mode toggle mid-session
- Fix transcript ordering: move append_tool_result from tool executor
  to conversation updater so entries follow correct API order
  (assistant tool_use → user tool_result)
- Fix mode toggle mid-session: use useRef for copilotMode so transport
  closure reads latest value without recreating DefaultChatTransport
- Use Literal type for mode in CoPilotExecutionEntry for type safety
2026-03-31 16:02:36 +02:00
Zamil Majdy
f55bc84fe7 fix(copilot): address PR review comments
- Use Literal["fast", "extended_thinking"] for mode validation (blocker)
- Wrap transcript upload in asyncio.shield() (should fix)
- Restore top-level estimate_token_count imports (nice to have)
- Guard localStorage copilotMode read against invalid values (should fix)
- Replace inline SVGs with lucide-react Brain/Zap icons (nice to have)
2026-03-31 15:52:06 +02:00
Zamil Majdy
3cfee4c4b5 feat(copilot): add mode toggle and baseline transcript support
- Add transcript support to baseline autopilot (download/upload/build)
  for feature parity with SDK path, enabling seamless mode switching
- Thread `mode` field through full stack: StreamChatRequest → queue →
  executor → service selection (fast=baseline, extended_thinking=SDK)
- Add mode toggle button in ChatInput UI with brain/lightning icons
- Persist mode preference in localStorage via Zustand store
2026-03-31 15:46:23 +02:00
54 changed files with 5778 additions and 1478 deletions


@@ -15,7 +15,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
from backend.copilot import service as chat_service
from backend.copilot import stream_registry
-from backend.copilot.config import ChatConfig
+from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.executor.utils import enqueue_cancel_task, enqueue_copilot_turn
from backend.copilot.model import (
ChatMessage,
@@ -111,6 +111,11 @@ class StreamChatRequest(BaseModel):
file_ids: list[str] | None = Field(
default=None, max_length=20
) # Workspace file IDs attached to this message
mode: CopilotMode | None = Field(
default=None,
description="Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. "
"If None, uses the server default (extended_thinking).",
)
class CreateSessionRequest(BaseModel):
@@ -840,6 +845,7 @@ async def stream_chat_post(
is_user_message=request.is_user_message,
context=request.context,
file_ids=sanitized_file_ids,
mode=request.mode,
)
setup_time = (time.perf_counter() - stream_start_time) * 1000


@@ -541,3 +541,41 @@ def test_create_session_rejects_nested_metadata(
)
assert response.status_code == 422
class TestStreamChatRequestModeValidation:
"""Pydantic-level validation of the ``mode`` field on StreamChatRequest."""
def test_rejects_invalid_mode_value(self) -> None:
"""Any string outside the Literal set must raise ValidationError."""
from pydantic import ValidationError
from backend.api.features.chat.routes import StreamChatRequest
with pytest.raises(ValidationError):
StreamChatRequest(message="hi", mode="turbo") # type: ignore[arg-type]
def test_accepts_fast_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="fast")
assert req.mode == "fast"
def test_accepts_extended_thinking_mode(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode="extended_thinking")
assert req.mode == "extended_thinking"
def test_accepts_none_mode(self) -> None:
"""``mode=None`` is valid (server decides via feature flags)."""
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi", mode=None)
assert req.mode is None
def test_mode_defaults_to_none_when_omitted(self) -> None:
from backend.api.features.chat.routes import StreamChatRequest
req = StreamChatRequest(message="hi")
assert req.mode is None


@@ -18,6 +18,7 @@ import orjson
from langfuse import propagate_attributes
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionToolParam
from backend.copilot.config import CopilotMode
from backend.copilot.context import set_execution_context
from backend.copilot.model import (
ChatMessage,
@@ -52,6 +53,15 @@ from backend.copilot.service import (
from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
download_transcript,
upload_transcript,
validate_transcript,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.exceptions import NotFoundError
from backend.util.prompt import (
compress_context,
@@ -74,6 +84,19 @@ _background_tasks: set[asyncio.Task[Any]] = set()
_MAX_TOOL_ROUNDS = 30
def _resolve_baseline_model(mode: CopilotMode | None) -> str:
"""Pick the model for the baseline path based on the per-request mode.
Only ``mode='fast'`` downgrades to the cheaper/faster model. Any other
value (including ``None`` and ``'extended_thinking'``) preserves the
default model so that users who never select a mode don't get
silently moved to the cheaper tier.
"""
if mode == "fast":
return config.fast_model
return config.model
@dataclass
class _BaselineStreamState:
"""Mutable state shared between the tool-call loop callbacks.
@@ -82,6 +105,7 @@ class _BaselineStreamState:
can be module-level functions instead of deeply nested closures.
"""
model: str = ""
pending_events: list[StreamBaseResponse] = field(default_factory=list)
assistant_text: str = ""
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
@@ -109,7 +133,7 @@ async def _baseline_llm_caller(
if tools:
typed_tools = cast(list[ChatCompletionToolParam], tools)
response = await client.chat.completions.create(
-model=config.model,
+model=state.model,
messages=typed_messages,
tools=typed_tools,
stream=True,
@@ -117,7 +141,7 @@ async def _baseline_llm_caller(
)
else:
response = await client.chat.completions.create(
-model=config.model,
+model=state.model,
messages=typed_messages,
stream=True,
stream_options={"include_usage": True},
@@ -279,17 +303,17 @@ async def _baseline_tool_executor(
)
-def _baseline_conversation_updater(
+def _mutate_openai_messages(
messages: list[dict[str, Any]],
response: LLMLoopResponse,
-tool_results: list[ToolCallResult] | None = None,
+tool_results: list[ToolCallResult] | None,
) -> None:
"""Update OpenAI message list with assistant response + tool results.
"""Append assistant / tool-result entries to the OpenAI message list.
Extracted from ``stream_chat_completion_baseline`` for readability.
This is the side-effect boundary for the next LLM call — no transcript
mutation happens here.
"""
if tool_results:
# Build assistant message with tool_calls
assistant_msg: dict[str, Any] = {"role": "assistant"}
if response.response_text:
assistant_msg["content"] = response.response_text
@@ -310,9 +334,89 @@ def _baseline_conversation_updater(
"content": tr.content,
}
)
-else:
+elif response.response_text:
messages.append({"role": "assistant", "content": response.response_text})
def _record_turn_to_transcript(
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None,
*,
transcript_builder: TranscriptBuilder,
model: str,
) -> None:
"""Append assistant + tool-result entries to the transcript builder.
Kept separate from :func:`_mutate_openai_messages` so the two
concerns (next-LLM-call payload vs. durable conversation log) can
evolve independently.
"""
if tool_results:
content_blocks: list[dict[str, Any]] = []
if response.response_text:
-messages.append({"role": "assistant", "content": response.response_text})
+content_blocks.append({"type": "text", "text": response.response_text})
for tc in response.tool_calls:
try:
args = orjson.loads(tc.arguments) if tc.arguments else {}
except (ValueError, TypeError, orjson.JSONDecodeError) as parse_err:
logger.debug(
"[Baseline] Failed to parse tool_call arguments "
"(tool=%s, id=%s): %s",
tc.name,
tc.id,
parse_err,
)
args = {}
content_blocks.append(
{
"type": "tool_use",
"id": tc.id,
"name": tc.name,
"input": args,
}
)
if content_blocks:
transcript_builder.append_assistant(
content_blocks=content_blocks,
model=model,
stop_reason=STOP_REASON_TOOL_USE,
)
for tr in tool_results:
# Record tool result to transcript AFTER the assistant tool_use
# block to maintain correct Anthropic API ordering:
# assistant(tool_use) → user(tool_result)
transcript_builder.append_tool_result(
tool_use_id=tr.tool_call_id,
content=tr.content,
)
elif response.response_text:
transcript_builder.append_assistant(
content_blocks=[{"type": "text", "text": response.response_text}],
model=model,
stop_reason=STOP_REASON_END_TURN,
)
def _baseline_conversation_updater(
messages: list[dict[str, Any]],
response: LLMLoopResponse,
tool_results: list[ToolCallResult] | None = None,
*,
transcript_builder: TranscriptBuilder,
model: str = "",
) -> None:
"""Update OpenAI message list with assistant response + tool results.
Thin composition of :func:`_mutate_openai_messages` and
:func:`_record_turn_to_transcript`.
"""
_mutate_openai_messages(messages, response, tool_results)
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=transcript_builder,
model=model,
)
async def _update_title_async(
@@ -329,6 +433,7 @@ async def _update_title_async(
async def _compress_session_messages(
messages: list[ChatMessage],
model: str,
) -> list[ChatMessage]:
"""Compress session messages if they exceed the model's token limit.
@@ -341,45 +446,178 @@ async def _compress_session_messages(
msg_dict: dict[str, Any] = {"role": msg.role}
if msg.content:
msg_dict["content"] = msg.content
if msg.tool_calls:
msg_dict["tool_calls"] = msg.tool_calls
if msg.tool_call_id:
msg_dict["tool_call_id"] = msg.tool_call_id
messages_dict.append(msg_dict)
try:
result = await compress_context(
messages=messages_dict,
-model=config.model,
+model=model,
client=_get_openai_client(),
)
except Exception as e:
logger.warning("[Baseline] Context compression with LLM failed: %s", e)
result = await compress_context(
messages=messages_dict,
-model=config.model,
+model=model,
client=None,
)
if result.was_compacted:
logger.info(
"[Baseline] Context compacted: %d -> %d tokens "
"(%d summarized, %d dropped)",
"[Baseline] Context compacted: %d -> %d tokens (%d summarized, %d dropped)",
result.original_token_count,
result.token_count,
result.messages_summarized,
result.messages_dropped,
)
return [
-ChatMessage(role=m["role"], content=m.get("content"))
+ChatMessage(
+role=m["role"],
+content=m.get("content"),
+tool_calls=m.get("tool_calls"),
+tool_call_id=m.get("tool_call_id"),
+)
for m in result.messages
]
return messages
def is_transcript_stale(dl: TranscriptDownload | None, session_msg_count: int) -> bool:
"""Return ``True`` when a download doesn't cover the current session.
A transcript is stale when it has a known ``message_count`` and that
count doesn't reach ``session_msg_count - 1`` (i.e. the session has
already advanced beyond what the stored transcript captures).
Loading a stale transcript would silently drop intermediate turns,
so callers should treat stale as "skip load, skip upload".
An unknown ``message_count`` (``0``) is treated as **not stale**
because older transcripts uploaded before msg_count tracking
existed must still be usable.
"""
if dl is None:
return False
if not dl.message_count:
return False
return dl.message_count < session_msg_count - 1
def should_upload_transcript(
user_id: str | None, transcript_covers_prefix: bool
) -> bool:
"""Return ``True`` when the caller should upload the final transcript.
Uploads require a logged-in user (for the storage key) *and* a
transcript that covered the session prefix when loaded — otherwise
we'd be overwriting a more complete version in storage with a
partial one built from just the current turn.
"""
return bool(user_id) and transcript_covers_prefix
async def _load_prior_transcript(
user_id: str,
session_id: str,
session_msg_count: int,
transcript_builder: TranscriptBuilder,
) -> bool:
"""Download and load the prior transcript into ``transcript_builder``.
Returns ``True`` when the loaded transcript fully covers the session
prefix; ``False`` otherwise (stale, missing, invalid, or download
error). Callers should suppress uploads when this returns ``False``
to avoid overwriting a more complete version in storage.
"""
try:
dl = await download_transcript(user_id, session_id, log_prefix="[Baseline]")
except Exception as e:
logger.warning("[Baseline] Transcript download failed: %s", e)
return False
if dl is None:
logger.debug("[Baseline] No transcript available")
return False
if not validate_transcript(dl.content):
logger.warning("[Baseline] Downloaded transcript but invalid")
return False
if is_transcript_stale(dl, session_msg_count):
logger.warning(
"[Baseline] Transcript stale: covers %d of %d messages, skipping",
dl.message_count,
session_msg_count,
)
return False
transcript_builder.load_previous(dl.content, log_prefix="[Baseline]")
logger.info(
"[Baseline] Loaded transcript: %dB, msg_count=%d",
len(dl.content),
dl.message_count,
)
return True
async def _upload_final_transcript(
user_id: str,
session_id: str,
transcript_builder: TranscriptBuilder,
session_msg_count: int,
) -> None:
"""Serialize and upload the transcript for next-turn continuity.
Uses the builder's own invariants to decide whether to upload,
avoiding a JSONL re-parse. A builder that ends with an assistant
entry is structurally complete; a builder that doesn't (empty, or
ends mid-turn) is skipped.
"""
try:
if transcript_builder.last_entry_type != "assistant":
logger.debug(
"[Baseline] No complete assistant turn to upload (last_entry=%s)",
transcript_builder.last_entry_type,
)
return
content = transcript_builder.to_jsonl()
if not content:
logger.debug("[Baseline] Empty transcript content, skipping upload")
return
# Track the upload as a background task so a timeout doesn't leak an
# orphaned coroutine; shield it so cancellation of this caller doesn't
# abort the in-flight GCS write.
upload_task = asyncio.create_task(
upload_transcript(
user_id=user_id,
session_id=session_id,
content=content,
message_count=session_msg_count,
log_prefix="[Baseline]",
skip_strip=True,
)
)
_background_tasks.add(upload_task)
upload_task.add_done_callback(_background_tasks.discard)
# Bound the wait: a hung storage backend must not block the response
# from finishing. The task keeps running in _background_tasks on
# timeout and will be cleaned up when it resolves.
await asyncio.wait_for(asyncio.shield(upload_task), timeout=30)
except Exception as upload_err:
logger.error("[Baseline] Transcript upload failed: %s", upload_err)
async def stream_chat_completion_baseline(
session_id: str,
message: str | None = None,
is_user_message: bool = True,
user_id: str | None = None,
session: ChatSession | None = None,
mode: CopilotMode | None = None,
**_kwargs: Any,
) -> AsyncGenerator[StreamBaseResponse, None]:
"""Baseline LLM with tool calling via OpenAI-compatible API.
@@ -408,6 +646,47 @@ async def stream_chat_completion_baseline(
session = await upsert_chat_session(session)
# Select model based on the per-request mode. 'fast' downgrades to
# the cheaper/faster model; everything else keeps the default.
active_model = _resolve_baseline_model(mode)
# --- Transcript support (feature parity with SDK path) ---
transcript_builder = TranscriptBuilder()
transcript_covers_prefix = True
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
prompt_task = _build_system_prompt(user_id, has_conversation_history=False)
else:
prompt_task = _build_system_prompt(user_id=None, has_conversation_history=True)
# Run download + prompt build concurrently — both are independent I/O
# on the request critical path.
if user_id and len(session.messages) > 1:
transcript_covers_prefix, (base_system_prompt, _) = await asyncio.gather(
_load_prior_transcript(
user_id=user_id,
session_id=session_id,
session_msg_count=len(session.messages),
transcript_builder=transcript_builder,
),
prompt_task,
)
else:
base_system_prompt, _ = await prompt_task
# Append user message to transcript.
# Always append when the message is present and is from the user,
# even on duplicate-suppressed retries (is_new_message=False).
# The loaded transcript may be stale (uploaded before the previous
# attempt stored this message), so skipping it would leave the
# transcript without the user turn, creating a malformed
# assistant-after-assistant structure when the LLM reply is added.
if message and is_user_message:
transcript_builder.append_user(content=message)
# Generate title for new sessions
if is_user_message and not session.title:
user_messages = [m for m in session.messages if m.role == "user"]
@@ -422,30 +701,38 @@ async def stream_chat_completion_baseline(
message_id = str(uuid.uuid4())
# Build system prompt only on the first turn to avoid mid-conversation
# changes from concurrent chats updating business understanding.
is_first_turn = len(session.messages) <= 1
if is_first_turn:
base_system_prompt, _ = await _build_system_prompt(
user_id, has_conversation_history=False
)
else:
base_system_prompt, _ = await _build_system_prompt(
user_id=None, has_conversation_history=True
)
# Append tool documentation and technical notes
system_prompt = base_system_prompt + get_baseline_supplement()
# Compress context if approaching the model's token limit
messages_for_context = await _compress_session_messages(session.messages)
messages_for_context = await _compress_session_messages(
session.messages, model=active_model
)
# Build OpenAI message list from session history
# Build OpenAI message list from session history.
# Include tool_calls on assistant messages and tool-role results so the
# model retains full context of what tools were invoked and their outcomes.
openai_messages: list[dict[str, Any]] = [
{"role": "system", "content": system_prompt}
]
for msg in messages_for_context:
if msg.role in ("user", "assistant") and msg.content:
if msg.role == "assistant":
entry: dict[str, Any] = {"role": "assistant"}
if msg.content:
entry["content"] = msg.content
if msg.tool_calls:
entry["tool_calls"] = msg.tool_calls
if msg.content or msg.tool_calls:
openai_messages.append(entry)
elif msg.role == "tool" and msg.tool_call_id:
openai_messages.append(
{
"role": "tool",
"tool_call_id": msg.tool_call_id,
"content": msg.content or "",
}
)
elif msg.role == "user" and msg.content:
openai_messages.append({"role": msg.role, "content": msg.content})
tools = get_available_tools()
@@ -470,7 +757,7 @@ async def stream_chat_completion_baseline(
logger.warning("[Baseline] Langfuse trace context setup failed")
_stream_error = False # Track whether an error occurred during streaming
state = _BaselineStreamState()
state = _BaselineStreamState(model=active_model)
# Bind extracted module-level callbacks to this request's state/session
# using functools.partial so they satisfy the Protocol signatures.
@@ -479,6 +766,12 @@ async def stream_chat_completion_baseline(
_baseline_tool_executor, state=state, user_id=user_id, session=session
)
_bound_conversation_updater = partial(
_baseline_conversation_updater,
transcript_builder=transcript_builder,
model=active_model,
)
try:
loop_result = None
async for loop_result in tool_call_loop(
@@ -486,7 +779,7 @@ async def stream_chat_completion_baseline(
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_baseline_conversation_updater,
update_conversation=_bound_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
# Drain buffered events after each iteration (real-time streaming)
@@ -555,10 +848,10 @@ async def stream_chat_completion_baseline(
and not (_stream_error and not state.assistant_text)
):
state.turn_prompt_tokens = max(
estimate_token_count(openai_messages, model=config.model), 1
estimate_token_count(openai_messages, model=active_model), 1
)
state.turn_completion_tokens = estimate_token_count_str(
state.assistant_text, model=config.model
state.assistant_text, model=active_model
)
logger.info(
"[Baseline] No streaming usage reported; estimated tokens: "
@@ -589,6 +882,27 @@ async def stream_chat_completion_baseline(
except Exception as persist_err:
logger.error("[Baseline] Failed to persist session: %s", persist_err)
# --- Upload transcript for next-turn continuity ---
# Backfill partial assistant text that wasn't recorded by the
# conversation updater (e.g. when the stream aborted mid-round).
# Without this, mode-switching after a failed turn would lose
# the partial assistant response from the transcript.
if _stream_error and state.assistant_text:
if transcript_builder.last_entry_type != "assistant":
transcript_builder.append_assistant(
content_blocks=[{"type": "text", "text": state.assistant_text}],
model=active_model,
stop_reason=STOP_REASON_END_TURN,
)
if user_id and should_upload_transcript(user_id, transcript_covers_prefix):
await _upload_final_transcript(
user_id=user_id,
session_id=session_id,
transcript_builder=transcript_builder,
session_msg_count=len(session.messages),
)
# Yield usage and finish AFTER try/finally (not inside finally).
# PEP 525 prohibits yielding from finally in async generators during
# aclose() — doing so raises RuntimeError on client disconnect.


@@ -0,0 +1,367 @@
"""Unit tests for baseline service pure-logic helpers.
These tests cover ``_baseline_conversation_updater`` and ``_BaselineStreamState``
without requiring API keys, database connections, or network access.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_baseline_conversation_updater,
_BaselineStreamState,
_compress_session_messages,
)
from backend.copilot.model import ChatMessage
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.prompt import CompressResult
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
class TestBaselineStreamState:
def test_defaults(self):
state = _BaselineStreamState()
assert state.pending_events == []
assert state.assistant_text == ""
assert state.text_started is False
assert state.turn_prompt_tokens == 0
assert state.turn_completion_tokens == 0
assert state.text_block_id # Should be a UUID string
def test_mutable_fields(self):
state = _BaselineStreamState()
state.assistant_text = "hello"
state.turn_prompt_tokens = 100
state.turn_completion_tokens = 50
assert state.assistant_text == "hello"
assert state.turn_prompt_tokens == 100
assert state.turn_completion_tokens == 50
class TestBaselineConversationUpdater:
"""Tests for _baseline_conversation_updater which updates the OpenAI
message list and transcript builder after each LLM call."""
def _make_transcript_builder(self) -> TranscriptBuilder:
builder = TranscriptBuilder()
builder.append_user("test question")
return builder
def test_text_only_response(self):
"""When the LLM returns text without tool calls, the updater appends
a single assistant message and records it in the transcript."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text="Hello, world!",
tool_calls=[],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
_baseline_conversation_updater(
messages,
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 1
assert messages[0]["role"] == "assistant"
assert messages[0]["content"] == "Hello, world!"
# Transcript should have user + assistant
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
def test_tool_calls_response(self):
"""When the LLM returns tool calls, the updater appends the assistant
message with tool_calls and tool result messages."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text="Let me search...",
tool_calls=[
LLMToolCall(
id="tc_1",
name="search",
arguments='{"query": "test"}',
),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(
tool_call_id="tc_1",
tool_name="search",
content="Found result",
),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# Messages: assistant (with tool_calls) + tool result
assert len(messages) == 2
assert messages[0]["role"] == "assistant"
assert messages[0]["content"] == "Let me search..."
assert len(messages[0]["tool_calls"]) == 1
assert messages[0]["tool_calls"][0]["id"] == "tc_1"
assert messages[1]["role"] == "tool"
assert messages[1]["tool_call_id"] == "tc_1"
assert messages[1]["content"] == "Found result"
# Transcript: user + assistant(tool_use) + user(tool_result)
assert builder.entry_count == 3
def test_tool_calls_without_text(self):
"""Tool calls without accompanying text should still work."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="run", arguments="{}"),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="run", content="done"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 2
assert "content" not in messages[0] # No text content
assert messages[0]["tool_calls"][0]["function"]["name"] == "run"
def test_no_text_no_tools(self):
"""When the response has no text and no tool calls, nothing is appended."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
_baseline_conversation_updater(
messages,
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert len(messages) == 0
# Only the user entry from setup
assert builder.entry_count == 1
def test_multiple_tool_calls(self):
"""Multiple tool calls in a single response are all recorded."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="tool_a", arguments="{}"),
LLMToolCall(id="tc_2", name="tool_b", arguments='{"x": 1}'),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="tool_a", content="result_a"),
ToolCallResult(tool_call_id="tc_2", tool_name="tool_b", content="result_b"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# 1 assistant + 2 tool results
assert len(messages) == 3
assert len(messages[0]["tool_calls"]) == 2
assert messages[1]["tool_call_id"] == "tc_1"
assert messages[2]["tool_call_id"] == "tc_2"
def test_invalid_tool_arguments_handled(self):
"""Tool call with invalid JSON arguments: the arguments field is
stored as-is in the message, and orjson failure falls back to {}
in the transcript content_blocks."""
messages: list = []
builder = self._make_transcript_builder()
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="tc_1", name="tool_x", arguments="not-json"),
],
raw_response=None,
prompt_tokens=0,
completion_tokens=0,
)
tool_results = [
ToolCallResult(tool_call_id="tc_1", tool_name="tool_x", content="ok"),
]
_baseline_conversation_updater(
messages,
response,
tool_results=tool_results,
transcript_builder=builder,
model="test-model",
)
# Should not raise — invalid JSON falls back to {} in transcript
assert len(messages) == 2
assert messages[0]["tool_calls"][0]["function"]["arguments"] == "not-json"
class TestCompressSessionMessagesPreservesToolCalls:
"""``_compress_session_messages`` must round-trip tool_calls + tool_call_id.
Compression serialises ChatMessage to dict for ``compress_context`` and
reifies the result back to ChatMessage. A regression that drops
``tool_calls`` or ``tool_call_id`` would corrupt the OpenAI message
list and break downstream tool-execution rounds.
"""
@pytest.mark.asyncio
async def test_compressed_output_keeps_tool_calls_and_ids(self):
# Simulate compression that returns a summary + the most recent
# assistant(tool_call) + tool(tool_result) intact.
summary = {"role": "system", "content": "prior turns: user asked X"}
assistant_with_tc = {
"role": "assistant",
"content": "calling tool",
"tool_calls": [
{
"id": "tc_abc",
"type": "function",
"function": {"name": "search", "arguments": '{"q":"y"}'},
}
],
}
tool_result = {
"role": "tool",
"tool_call_id": "tc_abc",
"content": "search result",
}
compress_result = CompressResult(
messages=[summary, assistant_with_tc, tool_result],
token_count=100,
was_compacted=True,
original_token_count=5000,
messages_summarized=10,
messages_dropped=0,
)
# Input: messages that should be compressed.
input_messages = [
ChatMessage(role="user", content="q1"),
ChatMessage(
role="assistant",
content="calling tool",
tool_calls=[
{
"id": "tc_abc",
"type": "function",
"function": {
"name": "search",
"arguments": '{"q":"y"}',
},
}
],
),
ChatMessage(
role="tool",
tool_call_id="tc_abc",
content="search result",
),
]
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=compress_result),
):
compressed = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
# Summary, assistant(tool_calls), tool(tool_call_id).
assert len(compressed) == 3
# Assistant message must keep its tool_calls intact.
assistant_msg = compressed[1]
assert assistant_msg.role == "assistant"
assert assistant_msg.tool_calls is not None
assert len(assistant_msg.tool_calls) == 1
assert assistant_msg.tool_calls[0]["id"] == "tc_abc"
assert assistant_msg.tool_calls[0]["function"]["name"] == "search"
# Tool-role message must keep tool_call_id for OpenAI linkage.
tool_msg = compressed[2]
assert tool_msg.role == "tool"
assert tool_msg.tool_call_id == "tc_abc"
assert tool_msg.content == "search result"
@pytest.mark.asyncio
async def test_uncompressed_passthrough_keeps_fields(self):
"""When compression is a no-op (was_compacted=False), the original
messages must be returned unchanged — including tool_calls."""
input_messages = [
ChatMessage(
role="assistant",
content="c",
tool_calls=[
{
"id": "t1",
"type": "function",
"function": {"name": "f", "arguments": "{}"},
}
],
),
ChatMessage(role="tool", tool_call_id="t1", content="ok"),
]
noop_result = CompressResult(
messages=[], # ignored when was_compacted=False
token_count=10,
was_compacted=False,
)
with patch(
"backend.copilot.baseline.service.compress_context",
new=AsyncMock(return_value=noop_result),
):
out = await _compress_session_messages(
input_messages, model="openrouter/anthropic/claude-opus-4"
)
assert out is input_messages # same list returned
assert out[0].tool_calls is not None
assert out[0].tool_calls[0]["id"] == "t1"
assert out[1].tool_call_id == "t1"


@@ -0,0 +1,667 @@
"""Integration tests for baseline transcript flow.
Exercises the real helpers in ``baseline/service.py`` that download,
validate, load, append to, backfill, and upload the transcript.
Storage is mocked via ``download_transcript`` / ``upload_transcript``
patches; no network access is required.
"""
import json as stdlib_json
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.baseline.service import (
_load_prior_transcript,
_record_turn_to_transcript,
_resolve_baseline_model,
_upload_final_transcript,
is_transcript_stale,
should_upload_transcript,
)
from backend.copilot.service import config
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
STOP_REASON_TOOL_USE,
TranscriptDownload,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallResult
def _make_transcript_content(*roles: str) -> str:
"""Build a minimal valid JSONL transcript from role names."""
lines = []
parent = ""
for i, role in enumerate(roles):
uid = f"uuid-{i}"
entry: dict = {
"type": role,
"uuid": uid,
"parentUuid": parent,
"message": {
"role": role,
"content": [{"type": "text", "text": f"{role} message {i}"}],
},
}
if role == "assistant":
entry["message"]["id"] = f"msg_{i}"
entry["message"]["model"] = "test-model"
entry["message"]["type"] = "message"
entry["message"]["stop_reason"] = STOP_REASON_END_TURN
lines.append(stdlib_json.dumps(entry))
parent = uid
return "\n".join(lines) + "\n"
class TestResolveBaselineModel:
"""Model selection honours the per-request mode."""
def test_fast_mode_selects_fast_model(self):
assert _resolve_baseline_model("fast") == config.fast_model
def test_extended_thinking_selects_default_model(self):
assert _resolve_baseline_model("extended_thinking") == config.model
def test_none_mode_selects_default_model(self):
"""Critical: baseline users without a mode MUST keep the default (opus)."""
assert _resolve_baseline_model(None) == config.model
def test_default_and_fast_models_differ(self):
"""Sanity: the two tiers are actually distinct in production config."""
assert config.model != config.fast_model
class TestLoadPriorTranscript:
"""``_load_prior_transcript`` wraps the download + validate + load flow."""
@pytest.mark.asyncio
async def test_loads_fresh_transcript(self):
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
@pytest.mark.asyncio
async def test_rejects_stale_transcript(self):
"""A transcript whose msg_count is less than session_msg_count - 1 is stale."""
builder = TranscriptBuilder()
content = _make_transcript_content("user", "assistant")
# session has 6 messages, transcript only covers 2 → stale.
download = TranscriptDownload(content=content, message_count=2)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=6,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_missing_transcript_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_invalid_transcript_returns_false(self):
builder = TranscriptBuilder()
download = TranscriptDownload(
content='{"type":"progress","uuid":"a"}\n',
message_count=1,
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_download_exception_returns_false(self):
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(side_effect=RuntimeError("boom")),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=2,
transcript_builder=builder,
)
assert covers is False
assert builder.is_empty
@pytest.mark.asyncio
async def test_zero_message_count_not_stale(self):
"""When msg_count is 0 (unknown), staleness check is skipped."""
builder = TranscriptBuilder()
download = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=0,
)
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=20,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
class TestUploadFinalTranscript:
"""``_upload_final_transcript`` serialises and calls storage."""
@pytest.mark.asyncio
async def test_uploads_valid_transcript(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
call_kwargs = upload_mock.await_args.kwargs
assert call_kwargs["user_id"] == "user-1"
assert call_kwargs["session_id"] == "session-1"
assert call_kwargs["message_count"] == 2
assert "hello" in call_kwargs["content"]
@pytest.mark.asyncio
async def test_skips_upload_when_builder_empty(self):
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=0,
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_swallows_upload_exceptions(self):
"""Upload failures should not propagate (flow continues for the user)."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=AsyncMock(side_effect=RuntimeError("storage unavailable")),
):
# Should not raise.
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=2,
)
class TestRecordTurnToTranscript:
"""``_record_turn_to_transcript`` translates LLMLoopResponse → transcript."""
def test_records_final_assistant_text(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text="hello there",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
jsonl = builder.to_jsonl()
assert "hello there" in jsonl
assert STOP_REASON_END_TURN in jsonl
def test_records_tool_use_then_tool_result(self):
"""Anthropic ordering: assistant(tool_use) → user(tool_result)."""
builder = TranscriptBuilder()
builder.append_user(content="use a tool")
response = LLMLoopResponse(
response_text=None,
tool_calls=[
LLMToolCall(id="call-1", name="echo", arguments='{"text":"hi"}')
],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="hi")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
)
# user, assistant(tool_use), user(tool_result) = 3 entries
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert STOP_REASON_TOOL_USE in jsonl
assert "tool_use" in jsonl
assert "tool_result" in jsonl
assert "call-1" in jsonl
def test_records_nothing_on_empty_response(self):
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 1
def test_malformed_tool_args_dont_crash(self):
"""Bad JSON in tool arguments falls back to {} without raising."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
response = LLMLoopResponse(
response_text=None,
tool_calls=[LLMToolCall(id="call-1", name="echo", arguments="{not-json")],
raw_response=None,
)
tool_results = [
ToolCallResult(tool_call_id="call-1", tool_name="echo", content="ok")
]
_record_turn_to_transcript(
response,
tool_results,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 3
jsonl = builder.to_jsonl()
assert '"input":{}' in jsonl
class TestRoundTrip:
"""End-to-end: load prior → append new turn → upload."""
@pytest.mark.asyncio
async def test_full_round_trip(self):
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
builder = TranscriptBuilder()
with patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
assert builder.entry_count == 2
# New user turn.
builder.append_user(content="new question")
assert builder.entry_count == 3
# New assistant turn.
response = LLMLoopResponse(
response_text="new answer",
tool_calls=[],
raw_response=None,
)
_record_turn_to_transcript(
response,
tool_results=None,
transcript_builder=builder,
model="test-model",
)
assert builder.entry_count == 4
# Upload.
upload_mock = AsyncMock(return_value=None)
with patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
):
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "new question" in uploaded
assert "new answer" in uploaded
# Original content preserved in the round trip.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_backfill_append_guard(self):
"""Backfill only runs when the last entry is not already assistant."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
# Simulate the backfill guard from stream_chat_completion_baseline.
assistant_text = "partial text before error"
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": assistant_text}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.last_entry_type == "assistant"
assert "partial text before error" in builder.to_jsonl()
# Second invocation: the guard must prevent double-append.
initial_count = builder.entry_count
if builder.last_entry_type != "assistant":
builder.append_assistant(
content_blocks=[{"type": "text", "text": "duplicate"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert builder.entry_count == initial_count
class TestIsTranscriptStale:
"""``is_transcript_stale`` gates prior-transcript loading."""
def test_none_download_is_not_stale(self):
assert is_transcript_stale(None, session_msg_count=5) is False
def test_zero_message_count_is_not_stale(self):
"""Legacy transcripts without msg_count tracking must remain usable."""
dl = TranscriptDownload(content="", message_count=0)
assert is_transcript_stale(dl, session_msg_count=20) is False
def test_stale_when_covers_less_than_prefix(self):
dl = TranscriptDownload(content="", message_count=2)
# session has 6 messages; transcript must cover at least 5 (6-1).
assert is_transcript_stale(dl, session_msg_count=6) is True
def test_fresh_when_covers_full_prefix(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_fresh_when_exceeds_prefix(self):
"""Race: transcript ahead of session count is still acceptable."""
dl = TranscriptDownload(content="", message_count=10)
assert is_transcript_stale(dl, session_msg_count=6) is False
def test_boundary_equal_to_prefix_minus_one(self):
dl = TranscriptDownload(content="", message_count=5)
assert is_transcript_stale(dl, session_msg_count=6) is False
class TestShouldUploadTranscript:
"""``should_upload_transcript`` gates the final upload."""
def test_upload_allowed_for_user_with_coverage(self):
assert should_upload_transcript("user-1", True) is True
def test_upload_skipped_when_no_user(self):
assert should_upload_transcript(None, True) is False
def test_upload_skipped_when_empty_user(self):
assert should_upload_transcript("", True) is False
def test_upload_skipped_without_coverage(self):
"""Partial transcript must never clobber a more complete stored one."""
assert should_upload_transcript("user-1", False) is False
def test_upload_skipped_when_no_user_and_no_coverage(self):
assert should_upload_transcript(None, False) is False
class TestTranscriptLifecycle:
"""End-to-end: download → validate → build → upload.
Simulates the full transcript lifecycle inside
``stream_chat_completion_baseline`` by mocking the storage layer and
driving each step through the real helpers.
"""
@pytest.mark.asyncio
async def test_full_lifecycle_happy_path(self):
"""Fresh download, append a turn, upload covers the session."""
builder = TranscriptBuilder()
prior = _make_transcript_content("user", "assistant")
download = TranscriptDownload(content=prior, message_count=2)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=download),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
# --- 1. Download & load prior transcript ---
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=3,
transcript_builder=builder,
)
assert covers is True
# --- 2. Append a new user turn + a new assistant response ---
builder.append_user(content="follow-up question")
_record_turn_to_transcript(
LLMLoopResponse(
response_text="follow-up answer",
tool_calls=[],
raw_response=None,
),
tool_results=None,
transcript_builder=builder,
model="test-model",
)
# --- 3. Gate + upload ---
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is True
)
await _upload_final_transcript(
user_id="user-1",
session_id="session-1",
transcript_builder=builder,
session_msg_count=4,
)
upload_mock.assert_awaited_once()
assert upload_mock.await_args is not None
uploaded = upload_mock.await_args.kwargs["content"]
assert "follow-up question" in uploaded
assert "follow-up answer" in uploaded
# Original prior-turn content preserved.
assert "user message 0" in uploaded
assert "assistant message 1" in uploaded
@pytest.mark.asyncio
async def test_lifecycle_stale_download_suppresses_upload(self):
"""Stale download → covers=False → upload must be skipped."""
builder = TranscriptBuilder()
# session has 10 msgs but stored transcript only covers 2 → stale.
stale = TranscriptDownload(
content=_make_transcript_content("user", "assistant"),
message_count=2,
)
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=stale),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=10,
transcript_builder=builder,
)
assert covers is False
# The caller's gate mirrors the production path.
assert (
should_upload_transcript(user_id="user-1", transcript_covers_prefix=covers)
is False
)
upload_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_lifecycle_anonymous_user_skips_upload(self):
"""Anonymous (user_id=None) → upload gate must return False."""
builder = TranscriptBuilder()
builder.append_user(content="hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason=STOP_REASON_END_TURN,
)
assert (
should_upload_transcript(user_id=None, transcript_covers_prefix=True)
is False
)
@pytest.mark.asyncio
async def test_lifecycle_missing_download_skips_upload(self):
"""No prior transcript → covers is False, so the upload gate skips the
upload rather than overwrite with a single-turn snapshot."""
builder = TranscriptBuilder()
upload_mock = AsyncMock(return_value=None)
with (
patch(
"backend.copilot.baseline.service.download_transcript",
new=AsyncMock(return_value=None),
),
patch(
"backend.copilot.baseline.service.upload_transcript",
new=upload_mock,
),
):
covers = await _load_prior_transcript(
user_id="user-1",
session_id="session-1",
session_msg_count=1,
transcript_builder=builder,
)
# No download: covers is False, so the production path would
# skip upload. This protects against overwriting a future
# more-complete transcript with a single-turn snapshot.
assert covers is False
assert (
should_upload_transcript(
user_id="user-1", transcript_covers_prefix=covers
)
is False
)
upload_mock.assert_not_awaited()


@@ -8,13 +8,26 @@ from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
# Per-request routing mode for a single chat turn.
# - 'fast': route to the baseline OpenAI-compatible path with the cheaper model.
# - 'extended_thinking': route to the Claude Agent SDK path with the default
# (opus) model.
# ``None`` means "no override"; the server falls back to the Claude Code
# subscription flag → LaunchDarkly COPILOT_SDK → config.use_claude_agent_sdk.
CopilotMode = Literal["fast", "extended_thinking"]
class ChatConfig(BaseSettings):
"""Configuration for the chat system."""
# OpenAI API Configuration
model: str = Field(
default="anthropic/claude-opus-4.6", description="Default model to use"
default="anthropic/claude-opus-4.6",
description="Default model for extended thinking mode",
)
fast_model: str = Field(
default="anthropic/claude-sonnet-4",
description="Model for fast mode (baseline path). Should be faster/cheaper than the default model.",
)
title_model: str = Field(
default="openai/gpt-4o-mini",


@@ -13,7 +13,7 @@ import time
from backend.copilot import stream_registry
from backend.copilot.baseline import stream_chat_completion_baseline
from backend.copilot.config import ChatConfig
from backend.copilot.config import ChatConfig, CopilotMode
from backend.copilot.response_model import StreamError
from backend.copilot.sdk import service as sdk_service
from backend.copilot.sdk.dummy import stream_chat_completion_dummy
@@ -30,6 +30,57 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
# ============ Mode Routing ============ #
async def resolve_effective_mode(
mode: CopilotMode | None,
user_id: str | None,
) -> CopilotMode | None:
"""Strip ``mode`` when the user is not entitled to the toggle.
The UI gates the mode toggle behind ``CHAT_MODE_OPTION``; the
processor enforces the same gate server-side so an authenticated
user cannot bypass the flag by crafting a request directly.
"""
if mode is None:
return None
allowed = await is_feature_enabled(
Flag.CHAT_MODE_OPTION,
user_id or "anonymous",
default=False,
)
if not allowed:
logger.info(f"Ignoring mode={mode} — CHAT_MODE_OPTION is disabled for user")
return None
return mode
async def resolve_use_sdk_for_mode(
mode: CopilotMode | None,
user_id: str | None,
*,
use_claude_code_subscription: bool,
config_default: bool,
) -> bool:
"""Pick the SDK vs baseline path for a single turn.
Per-request ``mode`` wins whenever it is set (after the
``CHAT_MODE_OPTION`` gate has been applied upstream). Otherwise
falls back to the Claude Code subscription override, then the
``COPILOT_SDK`` LaunchDarkly flag, then the config default.
"""
if mode == "fast":
return False
if mode == "extended_thinking":
return True
return use_claude_code_subscription or await is_feature_enabled(
Flag.COPILOT_SDK,
user_id or "anonymous",
default=config_default,
)
# ============ Module Entry Points ============ #
# Thread-local storage for processor instances
@@ -250,21 +301,26 @@ class CoPilotProcessor:
if config.test_mode:
stream_fn = stream_chat_completion_dummy
log.warning("Using DUMMY service (CHAT_TEST_MODE=true)")
effective_mode = None
else:
use_sdk = (
config.use_claude_code_subscription
or await is_feature_enabled(
Flag.COPILOT_SDK,
entry.user_id or "anonymous",
default=config.use_claude_agent_sdk,
)
# Enforce server-side feature-flag gate so unauthorised
# users cannot force a mode by crafting the request.
effective_mode = await resolve_effective_mode(entry.mode, entry.user_id)
use_sdk = await resolve_use_sdk_for_mode(
effective_mode,
entry.user_id,
use_claude_code_subscription=config.use_claude_code_subscription,
config_default=config.use_claude_agent_sdk,
)
stream_fn = (
sdk_service.stream_chat_completion_sdk
if use_sdk
else stream_chat_completion_baseline
)
log.info(f"Using {'SDK' if use_sdk else 'baseline'} service")
log.info(
f"Using {'SDK' if use_sdk else 'baseline'} service "
f"(mode={effective_mode or 'default'})"
)
# Stream chat completion and publish chunks to Redis.
# stream_and_publish wraps the raw stream with registry
@@ -276,6 +332,7 @@ class CoPilotProcessor:
user_id=entry.user_id,
context=entry.context,
file_ids=entry.file_ids,
mode=effective_mode,
)
async for chunk in stream_registry.stream_and_publish(
session_id=entry.session_id,


@@ -0,0 +1,175 @@
"""Unit tests for CoPilot mode routing logic in the processor.
Tests cover the mode→service mapping:
- 'fast' → baseline service
- 'extended_thinking' → SDK service
- None → feature flag / config fallback
as well as the ``CHAT_MODE_OPTION`` server-side gate. The tests import
the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.executor.processor import (
resolve_effective_mode,
resolve_use_sdk_for_mode,
)
class TestResolveUseSdkForMode:
"""Tests for the per-request mode routing logic."""
@pytest.mark.asyncio
async def test_fast_mode_uses_baseline(self):
"""mode='fast' always routes to baseline, regardless of flags."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert (
await resolve_use_sdk_for_mode(
"fast",
"user-1",
use_claude_code_subscription=True,
config_default=True,
)
is False
)
@pytest.mark.asyncio
async def test_extended_thinking_uses_sdk(self):
"""mode='extended_thinking' always routes to SDK, regardless of flags."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
"extended_thinking",
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_uses_subscription_override(self):
"""mode=None with claude_code_subscription=True routes to SDK."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=True,
config_default=False,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_uses_feature_flag(self):
"""mode=None with feature flag enabled routes to SDK."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
) as flag_mock:
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is True
)
flag_mock.assert_awaited_once()
@pytest.mark.asyncio
async def test_none_mode_uses_config_default(self):
"""mode=None falls back to config.use_claude_agent_sdk."""
# The mock stands in for LaunchDarkly falling through to the config
# default (True), so we expect SDK routing.
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=True,
)
is True
)
@pytest.mark.asyncio
async def test_none_mode_all_disabled(self):
"""mode=None with all flags off routes to baseline."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert (
await resolve_use_sdk_for_mode(
None,
"user-1",
use_claude_code_subscription=False,
config_default=False,
)
is False
)
class TestResolveEffectiveMode:
"""Tests for the CHAT_MODE_OPTION server-side gate."""
@pytest.mark.asyncio
async def test_none_mode_passes_through(self):
"""mode=None is returned as-is without a flag check."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await resolve_effective_mode(None, "user-1") is None
flag_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_mode_stripped_when_flag_disabled(self):
"""When CHAT_MODE_OPTION is off, mode is dropped to None."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
):
assert await resolve_effective_mode("fast", "user-1") is None
assert await resolve_effective_mode("extended_thinking", "user-1") is None
@pytest.mark.asyncio
async def test_mode_preserved_when_flag_enabled(self):
"""When CHAT_MODE_OPTION is on, the user-selected mode is preserved."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=True),
):
assert await resolve_effective_mode("fast", "user-1") == "fast"
assert (
await resolve_effective_mode("extended_thinking", "user-1")
== "extended_thinking"
)
@pytest.mark.asyncio
async def test_anonymous_user_with_mode(self):
"""Anonymous users (user_id=None) still pass through the gate."""
with patch(
"backend.copilot.executor.processor.is_feature_enabled",
new=AsyncMock(return_value=False),
) as flag_mock:
assert await resolve_effective_mode("fast", None) is None
flag_mock.assert_awaited_once()


@@ -9,6 +9,7 @@ import logging
from pydantic import BaseModel
from backend.copilot.config import CopilotMode
from backend.data.rabbitmq import Exchange, ExchangeType, Queue, RabbitMQConfig
from backend.util.logging import TruncatedLogger, is_structured_logging_enabled
@@ -156,6 +157,9 @@ class CoPilotExecutionEntry(BaseModel):
file_ids: list[str] | None = None
"""Workspace file IDs attached to the user's message"""
mode: CopilotMode | None = None
"""Autopilot mode override: 'fast' or 'extended_thinking'. None = server default."""
class CancelCoPilotEvent(BaseModel):
"""Event to cancel a CoPilot operation."""
@@ -175,6 +179,7 @@ async def enqueue_copilot_turn(
is_user_message: bool = True,
context: dict[str, str] | None = None,
file_ids: list[str] | None = None,
mode: CopilotMode | None = None,
) -> None:
"""Enqueue a CoPilot task for processing by the executor service.
@@ -186,6 +191,7 @@ async def enqueue_copilot_turn(
is_user_message: Whether the message is from the user (vs system/assistant)
context: Optional context for the message (e.g., {url: str, content: str})
file_ids: Optional workspace file IDs attached to the user's message
mode: CoPilot mode override ('fast' or 'extended_thinking'). None = server default.
"""
from backend.util.clients import get_async_copilot_queue
@@ -197,6 +203,7 @@ async def enqueue_copilot_turn(
is_user_message=is_user_message,
context=context,
file_ids=file_ids,
mode=mode,
)
queue_client = await get_async_copilot_queue()


@@ -0,0 +1,123 @@
"""Tests for CoPilot executor utils (queue config, message models, logging)."""
from backend.copilot.executor.utils import (
COPILOT_EXECUTION_EXCHANGE,
COPILOT_EXECUTION_QUEUE_NAME,
COPILOT_EXECUTION_ROUTING_KEY,
CancelCoPilotEvent,
CoPilotExecutionEntry,
CoPilotLogMetadata,
create_copilot_queue_config,
)
class TestCoPilotExecutionEntry:
def test_basic_fields(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="hello",
)
assert entry.session_id == "s1"
assert entry.user_id == "u1"
assert entry.message == "hello"
assert entry.is_user_message is True
assert entry.mode is None
assert entry.context is None
assert entry.file_ids is None
def test_mode_field(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
mode="fast",
)
assert entry.mode == "fast"
entry2 = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
mode="extended_thinking",
)
assert entry2.mode == "extended_thinking"
def test_optional_fields(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="test",
turn_id="t1",
context={"url": "https://example.com"},
file_ids=["f1", "f2"],
is_user_message=False,
)
assert entry.turn_id == "t1"
assert entry.context == {"url": "https://example.com"}
assert entry.file_ids == ["f1", "f2"]
assert entry.is_user_message is False
def test_serialization_roundtrip(self):
entry = CoPilotExecutionEntry(
session_id="s1",
user_id="u1",
message="hello",
mode="fast",
)
json_str = entry.model_dump_json()
restored = CoPilotExecutionEntry.model_validate_json(json_str)
assert restored == entry
class TestCancelCoPilotEvent:
def test_basic(self):
event = CancelCoPilotEvent(session_id="s1")
assert event.session_id == "s1"
def test_serialization(self):
event = CancelCoPilotEvent(session_id="s1")
restored = CancelCoPilotEvent.model_validate_json(event.model_dump_json())
assert restored.session_id == "s1"
class TestCreateCopilotQueueConfig:
def test_returns_valid_config(self):
config = create_copilot_queue_config()
assert len(config.exchanges) == 2
assert len(config.queues) == 2
def test_execution_queue_properties(self):
config = create_copilot_queue_config()
exec_queue = next(
q for q in config.queues if q.name == COPILOT_EXECUTION_QUEUE_NAME
)
assert exec_queue.durable is True
assert exec_queue.exchange == COPILOT_EXECUTION_EXCHANGE
assert exec_queue.routing_key == COPILOT_EXECUTION_ROUTING_KEY
def test_cancel_queue_uses_fanout(self):
config = create_copilot_queue_config()
cancel_queue = next(
q for q in config.queues if q.name != COPILOT_EXECUTION_QUEUE_NAME
)
assert cancel_queue.exchange is not None
assert cancel_queue.exchange.type.value == "fanout"
class TestCoPilotLogMetadata:
def test_creates_logger_with_metadata(self):
import logging
base_logger = logging.getLogger("test")
log = CoPilotLogMetadata(base_logger, session_id="s1", user_id="u1")
assert log is not None
def test_filters_none_values(self):
import logging
base_logger = logging.getLogger("test")
log = CoPilotLogMetadata(
base_logger, session_id="s1", user_id=None, turn_id="t1"
)
assert log is not None


@@ -13,12 +13,21 @@ from .rate_limit import (
RateLimitExceeded,
SubscriptionTier,
UsageWindow,
_daily_key,
_daily_reset_time,
_weekly_key,
_weekly_reset_time,
acquire_reset_lock,
check_rate_limit,
get_daily_reset_count,
get_global_rate_limits,
get_usage_status,
get_user_tier,
increment_daily_reset_count,
record_token_usage,
release_reset_lock,
reset_daily_usage,
reset_user_usage,
set_user_tier,
)
@@ -1210,3 +1219,205 @@ class TestTierLimitsEnforced:
assert daily == biz_daily # 20x
# Should NOT raise — usage is within the BUSINESS tier allowance
await check_rate_limit(_USER, daily, weekly)
# ---------------------------------------------------------------------------
# Private key/reset helpers
# ---------------------------------------------------------------------------
class TestKeyHelpers:
def test_daily_key_format(self):
now = datetime(2026, 4, 3, 12, 0, 0, tzinfo=UTC)
key = _daily_key("user-1", now=now)
assert "daily" in key
assert "user-1" in key
assert "2026-04-03" in key
def test_daily_key_defaults_to_now(self):
key = _daily_key("user-1")
assert "daily" in key
assert "user-1" in key
def test_weekly_key_format(self):
now = datetime(2026, 4, 3, 12, 0, 0, tzinfo=UTC)
key = _weekly_key("user-1", now=now)
assert "weekly" in key
assert "user-1" in key
assert "2026-W" in key
def test_weekly_key_defaults_to_now(self):
key = _weekly_key("user-1")
assert "weekly" in key
def test_daily_reset_time_is_next_midnight(self):
now = datetime(2026, 4, 3, 15, 30, 0, tzinfo=UTC)
reset = _daily_reset_time(now=now)
assert reset == datetime(2026, 4, 4, 0, 0, 0, tzinfo=UTC)
def test_daily_reset_time_defaults_to_now(self):
reset = _daily_reset_time()
assert reset.hour == 0
assert reset.minute == 0
def test_weekly_reset_time_is_next_monday(self):
# 2026-04-03 is a Friday
now = datetime(2026, 4, 3, 15, 30, 0, tzinfo=UTC)
reset = _weekly_reset_time(now=now)
assert reset.weekday() == 0 # Monday
assert reset == datetime(2026, 4, 6, 0, 0, 0, tzinfo=UTC)
def test_weekly_reset_time_defaults_to_now(self):
reset = _weekly_reset_time()
assert reset.weekday() == 0 # Monday
# ---------------------------------------------------------------------------
# acquire_reset_lock / release_reset_lock
# ---------------------------------------------------------------------------
class TestResetLock:
@pytest.mark.asyncio
async def test_acquire_lock_success(self):
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=True)
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
result = await acquire_reset_lock("user-1")
assert result is True
@pytest.mark.asyncio
async def test_acquire_lock_already_held(self):
mock_redis = AsyncMock()
mock_redis.set = AsyncMock(return_value=False)
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
result = await acquire_reset_lock("user-1")
assert result is False
@pytest.mark.asyncio
async def test_acquire_lock_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=RedisError("down"),
):
result = await acquire_reset_lock("user-1")
assert result is False
@pytest.mark.asyncio
async def test_release_lock_success(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
await release_reset_lock("user-1")
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_release_lock_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=RedisError("down"),
):
# Should not raise
await release_reset_lock("user-1")
# ---------------------------------------------------------------------------
# get_daily_reset_count / increment_daily_reset_count
# ---------------------------------------------------------------------------
class TestDailyResetCount:
@pytest.mark.asyncio
async def test_get_count_returns_value(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value="3")
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
count = await get_daily_reset_count("user-1")
assert count == 3
@pytest.mark.asyncio
async def test_get_count_returns_zero_when_no_key(self):
mock_redis = AsyncMock()
mock_redis.get = AsyncMock(return_value=None)
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
count = await get_daily_reset_count("user-1")
assert count == 0
@pytest.mark.asyncio
async def test_get_count_returns_none_when_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=RedisError("down"),
):
count = await get_daily_reset_count("user-1")
assert count is None
@pytest.mark.asyncio
async def test_increment_count(self):
mock_pipe = MagicMock()
mock_pipe.incr = MagicMock()
mock_pipe.expire = MagicMock()
mock_pipe.execute = AsyncMock()
mock_redis = AsyncMock()
mock_redis.pipeline = MagicMock(return_value=mock_pipe)
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
await increment_daily_reset_count("user-1")
mock_pipe.incr.assert_called_once()
mock_pipe.expire.assert_called_once()
@pytest.mark.asyncio
async def test_increment_count_redis_unavailable(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=RedisError("down"),
):
# Should not raise
await increment_daily_reset_count("user-1")
# ---------------------------------------------------------------------------
# reset_user_usage
# ---------------------------------------------------------------------------
class TestResetUserUsage:
@pytest.mark.asyncio
async def test_resets_daily_key(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
await reset_user_usage("user-1")
mock_redis.delete.assert_called_once()
@pytest.mark.asyncio
async def test_resets_daily_and_weekly(self):
mock_redis = AsyncMock()
with patch(
"backend.copilot.rate_limit.get_redis_async", return_value=mock_redis
):
await reset_user_usage("user-1", reset_weekly=True)
args = mock_redis.delete.call_args[0]
assert len(args) == 2 # both daily and weekly keys
@pytest.mark.asyncio
async def test_raises_on_redis_failure(self):
with patch(
"backend.copilot.rate_limit.get_redis_async",
side_effect=RedisError("down"),
):
with pytest.raises(RedisError):
await reset_user_usage("user-1")


@@ -8,20 +8,19 @@ from uuid import uuid4
import pytest
from backend.util import json
from backend.util.prompt import CompressResult
from .conftest import build_test_transcript as _build_transcript
from .service import _friendly_error_text, _is_prompt_too_long
from .transcript import (
from backend.copilot.transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_run_compression,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
from backend.util import json
from backend.util.prompt import CompressResult
from .conftest import build_test_transcript as _build_transcript
from .service import _friendly_error_text, _is_prompt_too_long
from .transcript import compact_transcript, validate_transcript
# ---------------------------------------------------------------------------
# _flatten_assistant_content
@@ -403,7 +402,7 @@ class TestCompactTranscript:
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
@@ -438,7 +437,7 @@ class TestCompactTranscript:
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
@@ -462,7 +461,7 @@ class TestCompactTranscript:
]
)
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
side_effect=RuntimeError("LLM unavailable"),
):
@@ -568,11 +567,11 @@ class TestRunCompressionTimeout:
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
"backend.copilot.transcript.get_openai_client",
return_value="fake-client",
),
patch(
"backend.copilot.sdk.transcript.compress_context",
"backend.copilot.transcript.compress_context",
side_effect=_mock_compress,
),
):
@@ -602,11 +601,11 @@ class TestRunCompressionTimeout:
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
"backend.copilot.transcript.get_openai_client",
return_value=None,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
"backend.copilot.transcript.compress_context",
new_callable=AsyncMock,
return_value=truncation_result,
) as mock_compress,


@@ -26,18 +26,17 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from backend.util import json
from .conftest import build_test_transcript as _build_transcript
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
from .transcript import (
from backend.copilot.transcript import (
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
from backend.util import json
from .conftest import build_test_transcript as _build_transcript
from .service import _MAX_STREAM_ATTEMPTS, _reduce_context
from .transcript import compact_transcript, validate_transcript
from .transcript_builder import TranscriptBuilder
# ---------------------------------------------------------------------------
@@ -113,7 +112,7 @@ class TestScenarioCompactAndRetry:
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
@@ -170,7 +169,7 @@ class TestScenarioCompactFailsFallback:
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
side_effect=RuntimeError("LLM unavailable"),
),
@@ -261,7 +260,7 @@ class TestScenarioDoubleFailDBFallback:
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
@@ -337,7 +336,7 @@ class TestScenarioCompactionIdentical:
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
@@ -730,7 +729,7 @@ class TestRetryEdgeCases:
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
),
@@ -841,7 +840,7 @@ class TestRetryStateReset:
)(),
),
patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
side_effect=RuntimeError("boom"),
),
@@ -1405,9 +1404,9 @@ class TestStreamChatCompletionRetryIntegration:
events.append(event)
# Should NOT retry — only 1 attempt for auth errors
assert attempt_count[0] == 1, (
f"Expected 1 attempt (no retry for auth error), " f"got {attempt_count[0]}"
)
assert (
attempt_count[0] == 1
), f"Expected 1 attempt (no retry for auth error), got {attempt_count[0]}"
errors = [e for e in events if isinstance(e, StreamError)]
assert errors, "Expected StreamError"
assert errors[0].code == "sdk_stream_error"


@@ -34,6 +34,17 @@ from pydantic import BaseModel
from backend.copilot.context import get_workspace_manager
from backend.copilot.permissions import apply_tool_permissions
from backend.copilot.rate_limit import get_user_tier
from backend.copilot.transcript import (
_run_compression,
cleanup_stale_project_dirs,
compact_transcript,
download_transcript,
read_compacted_entries,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
from backend.copilot.transcript_builder import TranscriptBuilder
from backend.data.redis_client import get_redis_async
from backend.executor.cluster_lock import AsyncClusterLock
from backend.util.exceptions import NotFoundError
@@ -94,17 +105,6 @@ from .tool_adapter import (
set_execution_context,
wait_for_stash,
)
from .transcript import (
_run_compression,
cleanup_stale_project_dirs,
compact_transcript,
download_transcript,
read_compacted_entries,
upload_transcript,
validate_transcript,
write_transcript_to_tempfile,
)
from .transcript_builder import TranscriptBuilder
logger = logging.getLogger(__name__)
config = ChatConfig()


@@ -27,20 +27,19 @@ from backend.copilot.response_model import (
StreamTextDelta,
StreamTextStart,
)
from backend.util import json
from .conftest import build_structured_transcript
from .response_adapter import SDKResponseAdapter
from .service import _format_sdk_content_blocks
from .transcript import (
from backend.copilot.transcript import (
_find_last_assistant_entry,
_flatten_assistant_content,
_messages_to_transcript,
_rechain_tail,
_transcript_to_messages,
compact_transcript,
validate_transcript,
)
from backend.util import json
from .conftest import build_structured_transcript
from .response_adapter import SDKResponseAdapter
from .service import _format_sdk_content_blocks
from .transcript import compact_transcript, validate_transcript
# ---------------------------------------------------------------------------
# Fixtures: realistic thinking block content
@@ -439,7 +438,7 @@ class TestCompactTranscriptThinkingBlocks:
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
@@ -498,7 +497,7 @@ class TestCompactTranscriptThinkingBlocks:
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
side_effect=mock_compression,
):
await compact_transcript(transcript, model="test-model")
@@ -551,7 +550,7 @@ class TestCompactTranscriptThinkingBlocks:
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
@@ -601,7 +600,7 @@ class TestCompactTranscriptThinkingBlocks:
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
@@ -638,7 +637,7 @@ class TestCompactTranscriptThinkingBlocks:
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):
@@ -699,7 +698,7 @@ class TestCompactTranscriptThinkingBlocks:
},
)()
with patch(
"backend.copilot.sdk.transcript._run_compression",
"backend.copilot.transcript._run_compression",
new_callable=AsyncMock,
return_value=mock_result,
):

File diff suppressed because it is too large.


@@ -1,235 +1,10 @@
"""Build complete JSONL transcript from SDK messages.
"""Re-export from shared ``backend.copilot.transcript_builder`` for backward compat.
The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.
Flow:
Turn 1: Upload [msg1, msg2]
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
The transcript is never incremental - always the complete atomic state.
The canonical implementation now lives at ``backend.copilot.transcript_builder``
so both the SDK and baseline paths can import without cross-package
dependencies.
"""
import logging
from typing import Any
from uuid import uuid4
from backend.copilot.transcript_builder import TranscriptBuilder, TranscriptEntry
from pydantic import BaseModel
from backend.util import json
from .transcript import STRIPPABLE_TYPES
logger = logging.getLogger(__name__)
class TranscriptEntry(BaseModel):
"""Single transcript entry (user or assistant turn)."""
type: str
uuid: str
parentUuid: str | None
isCompactSummary: bool | None = None
message: dict[str, Any]
class TranscriptBuilder:
"""Build complete JSONL transcript from SDK messages.
This builder maintains the FULL conversation state, not incremental changes.
The output is always the complete active context.
"""
def __init__(self) -> None:
self._entries: list[TranscriptEntry] = []
self._last_uuid: str | None = None
def _last_is_assistant(self) -> bool:
return bool(self._entries) and self._entries[-1].type == "assistant"
def _last_message_id(self) -> str:
"""Return the message.id of the last entry, or '' if none."""
if self._entries:
return self._entries[-1].message.get("id", "")
return ""
@staticmethod
def _parse_entry(data: dict) -> TranscriptEntry | None:
"""Parse a single transcript entry, filtering strippable types.
Returns ``None`` for entries that should be skipped (strippable types
that are not compaction summaries).
"""
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
return None
return TranscriptEntry(
type=entry_type,
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid"),
isCompactSummary=data.get("isCompactSummary"),
message=data.get("message", {}),
)
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
This loads the FULL previous context. As new messages come in,
we append to this state. The final output is the complete context
(previous + new), not just the delta.
"""
if not content or not content.strip():
return
lines = content.strip().split("\n")
for line_num, line in enumerate(lines, 1):
if not line.strip():
continue
data = json.loads(line, fallback=None)
if data is None:
logger.warning(
"%s Failed to parse transcript line %d/%d",
log_prefix,
line_num,
len(lines),
)
continue
entry = self._parse_entry(data)
if entry is None:
continue
self._entries.append(entry)
self._last_uuid = entry.uuid
logger.info(
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
log_prefix,
len(self._entries),
self._last_uuid[:12] if self._last_uuid else None,
)
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
"""Append a user entry."""
msg_uuid = uuid or str(uuid4())
self._entries.append(
TranscriptEntry(
type="user",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={"role": "user", "content": content},
)
)
self._last_uuid = msg_uuid
def append_tool_result(self, tool_use_id: str, content: str) -> None:
"""Append a tool result as a user entry (one per tool call)."""
self.append_user(
content=[
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
]
)
def append_assistant(
self,
content_blocks: list[dict],
model: str = "",
stop_reason: str | None = None,
) -> None:
"""Append an assistant entry.
Consecutive assistant entries automatically share the same message ID
so the CLI can merge them (thinking → text → tool_use) into a single
API message on ``--resume``. A new ID is assigned whenever an
assistant entry follows a non-assistant entry (user message or tool
result), because that marks the start of a new API response.
"""
message_id = (
self._last_message_id()
if self._last_is_assistant()
else f"msg_sdk_{uuid4().hex[:24]}"
)
msg_uuid = str(uuid4())
self._entries.append(
TranscriptEntry(
type="assistant",
uuid=msg_uuid,
parentUuid=self._last_uuid,
message={
"role": "assistant",
"model": model,
"id": message_id,
"type": "message",
"content": content_blocks,
"stop_reason": stop_reason,
"stop_sequence": None,
},
)
)
self._last_uuid = msg_uuid
def replace_entries(
self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
) -> None:
"""Replace all entries with compacted entries from the CLI session file.
Called after mid-stream compaction so TranscriptBuilder mirrors the
CLI's active context (compaction summary + post-compaction entries).
Builds the new list first and validates it's non-empty before swapping,
so corrupt input cannot wipe the conversation history.
"""
new_entries: list[TranscriptEntry] = []
for data in compacted_entries:
entry = self._parse_entry(data)
if entry is not None:
new_entries.append(entry)
if not new_entries:
logger.warning(
"%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
log_prefix,
len(compacted_entries),
len(self._entries),
)
return
old_count = len(self._entries)
self._entries = new_entries
self._last_uuid = new_entries[-1].uuid
logger.info(
"%s TranscriptBuilder compacted: %d entries -> %d entries",
log_prefix,
old_count,
len(self._entries),
)
def to_jsonl(self) -> str:
"""Export complete context as JSONL.
Consecutive assistant entries are kept separate to match the
native CLI format — the SDK merges them internally on resume.
Returns the FULL conversation state (all entries), not incremental.
This output REPLACES any previous transcript.
"""
if not self._entries:
return ""
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""
return len(self._entries)
@property
def is_empty(self) -> bool:
"""Whether this builder has any entries."""
return len(self._entries) == 0
__all__ = ["TranscriptBuilder", "TranscriptEntry"]


@@ -303,7 +303,7 @@ class TestDeleteTranscript:
mock_storage.delete = AsyncMock()
with patch(
"backend.copilot.sdk.transcript.get_workspace_storage",
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -323,7 +323,7 @@ class TestDeleteTranscript:
)
with patch(
"backend.copilot.sdk.transcript.get_workspace_storage",
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -341,7 +341,7 @@ class TestDeleteTranscript:
)
with patch(
"backend.copilot.sdk.transcript.get_workspace_storage",
"backend.copilot.transcript.get_workspace_storage",
new_callable=AsyncMock,
return_value=mock_storage,
):
@@ -850,7 +850,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_no_client_uses_truncation(self):
"""Path (a): ``get_openai_client()`` returns None → truncation only."""
-from .transcript import _run_compression
+from backend.copilot.transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated"}]
@@ -858,11 +858,11 @@ class TestRunCompression:
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
"backend.copilot.transcript.get_openai_client",
return_value=None,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
"backend.copilot.transcript.compress_context",
new_callable=AsyncMock,
return_value=truncation_result,
) as mock_compress,
@@ -885,7 +885,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_llm_success_returns_llm_result(self):
"""Path (b): ``get_openai_client()`` returns a client → LLM compresses."""
-from .transcript import _run_compression
+from backend.copilot.transcript import _run_compression
llm_result = self._make_compress_result(
True, [{"role": "user", "content": "LLM summary"}]
@@ -894,11 +894,11 @@ class TestRunCompression:
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
"backend.copilot.transcript.get_openai_client",
return_value=mock_client,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
"backend.copilot.transcript.compress_context",
new_callable=AsyncMock,
return_value=llm_result,
) as mock_compress,
@@ -916,7 +916,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_llm_failure_falls_back_to_truncation(self):
"""Path (c): LLM call raises → truncation fallback used instead."""
-from .transcript import _run_compression
+from backend.copilot.transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated fallback"}]
@@ -932,11 +932,11 @@ class TestRunCompression:
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
"backend.copilot.transcript.get_openai_client",
return_value=mock_client,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
"backend.copilot.transcript.compress_context",
side_effect=_compress_side_effect,
),
):
@@ -953,7 +953,7 @@ class TestRunCompression:
@pytest.mark.asyncio
async def test_llm_timeout_falls_back_to_truncation(self):
"""Path (d): LLM call exceeds timeout → truncation fallback used."""
-from .transcript import _run_compression
+from backend.copilot.transcript import _run_compression
truncation_result = self._make_compress_result(
True, [{"role": "user", "content": "truncated after timeout"}]
@@ -970,19 +970,19 @@ class TestRunCompression:
fake_client = MagicMock()
with (
patch(
"backend.copilot.sdk.transcript.get_openai_client",
"backend.copilot.transcript.get_openai_client",
return_value=fake_client,
),
patch(
"backend.copilot.sdk.transcript.compress_context",
"backend.copilot.transcript.compress_context",
side_effect=_compress_side_effect,
),
patch(
"backend.copilot.sdk.transcript._COMPACTION_TIMEOUT_SECONDS",
"backend.copilot.transcript._COMPACTION_TIMEOUT_SECONDS",
0.05,
),
patch(
"backend.copilot.sdk.transcript._TRUNCATION_TIMEOUT_SECONDS",
"backend.copilot.transcript._TRUNCATION_TIMEOUT_SECONDS",
5,
),
):
@@ -1007,7 +1007,7 @@ class TestCleanupStaleProjectDirs:
def test_removes_old_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories matching copilot pattern older than threshold are removed."""
-from backend.copilot.sdk.transcript import (
+from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1015,7 +1015,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: str(projects_dir),
)
@@ -1039,12 +1039,12 @@ class TestCleanupStaleProjectDirs:
def test_ignores_non_copilot_dirs(self, tmp_path, monkeypatch):
"""Directories not matching copilot pattern are left alone."""
-from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
+from backend.copilot.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: str(projects_dir),
)
@@ -1062,7 +1062,7 @@ class TestCleanupStaleProjectDirs:
def test_ttl_boundary_not_removed(self, tmp_path, monkeypatch):
"""A directory exactly at the TTL boundary should NOT be removed."""
-from backend.copilot.sdk.transcript import (
+from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1070,7 +1070,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: str(projects_dir),
)
@@ -1088,7 +1088,7 @@ class TestCleanupStaleProjectDirs:
def test_skips_non_directory_entries(self, tmp_path, monkeypatch):
"""Regular files matching the copilot pattern are not removed."""
-from backend.copilot.sdk.transcript import (
+from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1096,7 +1096,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: str(projects_dir),
)
@@ -1114,11 +1114,11 @@ class TestCleanupStaleProjectDirs:
def test_missing_base_dir_returns_zero(self, tmp_path, monkeypatch):
"""If the projects base directory doesn't exist, return 0 gracefully."""
-from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
+from backend.copilot.transcript import cleanup_stale_project_dirs
nonexistent = str(tmp_path / "does-not-exist" / "projects")
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: nonexistent,
)
@@ -1129,7 +1129,7 @@ class TestCleanupStaleProjectDirs:
"""When encoded_cwd is supplied only that directory is swept."""
import time
-from backend.copilot.sdk.transcript import (
+from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1137,7 +1137,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: str(projects_dir),
)
@@ -1160,12 +1160,12 @@ class TestCleanupStaleProjectDirs:
def test_scoped_fresh_dir_not_removed(self, tmp_path, monkeypatch):
"""Scoped sweep leaves a fresh directory alone."""
-from backend.copilot.sdk.transcript import cleanup_stale_project_dirs
+from backend.copilot.transcript import cleanup_stale_project_dirs
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: str(projects_dir),
)
@@ -1181,7 +1181,7 @@ class TestCleanupStaleProjectDirs:
"""Scoped sweep refuses to remove a non-copilot directory."""
import time
-from backend.copilot.sdk.transcript import (
+from backend.copilot.transcript import (
_STALE_PROJECT_DIR_SECONDS,
cleanup_stale_project_dirs,
)
@@ -1189,7 +1189,7 @@ class TestCleanupStaleProjectDirs:
projects_dir = tmp_path / "projects"
projects_dir.mkdir()
monkeypatch.setattr(
"backend.copilot.sdk.transcript._projects_base",
"backend.copilot.transcript._projects_base",
lambda: str(projects_dir),
)


@@ -7,7 +7,7 @@ import pytest
from .model import create_chat_session, get_chat_session, upsert_chat_session
from .response_model import StreamError, StreamTextDelta
from .sdk import service as sdk_service
-from .sdk.transcript import download_transcript
+from .transcript import download_transcript
logger = logging.getLogger(__name__)


@@ -33,12 +33,23 @@ _GET_CURRENT_DATE_BLOCK_ID = "b29c1b50-5d0e-4d9f-8f9d-1b0e6fcbf0b1"
_GMAIL_SEND_BLOCK_ID = "6c27abc2-e51d-499e-a85f-5a0041ba94f0"
_TEXT_REPLACE_BLOCK_ID = "7e7c87ab-3469-4bcc-9abe-67705091b713"
+# Default OrchestratorBlock model/mode — kept in sync with ChatConfig.model.
+# ChatConfig uses the OpenRouter format ("anthropic/claude-opus-4.6");
+# OrchestratorBlock uses the native Anthropic model name.
+ORCHESTRATOR_DEFAULT_MODEL = "claude-opus-4-6"
+ORCHESTRATOR_DEFAULT_EXECUTION_MODE = "extended_thinking"
# Defaults applied to OrchestratorBlock nodes by the fixer.
-_SDM_DEFAULTS: dict[str, int | bool] = {
+# execution_mode and model match the copilot's default (extended thinking
+# with Opus) so generated agents inherit the same reasoning capabilities.
+# If the user explicitly sets these fields, the fixer won't override them.
+_SDM_DEFAULTS: dict[str, int | bool | str] = {
"agent_mode_max_iterations": 10,
"conversation_compaction": True,
"retry": 3,
"multiple_tool_calls": False,
+"execution_mode": ORCHESTRATOR_DEFAULT_EXECUTION_MODE,
+"model": ORCHESTRATOR_DEFAULT_MODEL,
}
@@ -1649,6 +1660,8 @@ class AgentFixer:
2. ``conversation_compaction`` defaults to ``True``
3. ``retry`` defaults to ``3``
4. ``multiple_tool_calls`` defaults to ``False``
+5. ``execution_mode`` defaults to ``"extended_thinking"``
+6. ``model`` defaults to ``"claude-opus-4-6"``
Args:
agent: The agent dictionary to fix
@@ -1748,6 +1761,12 @@ class AgentFixer:
agent = self.fix_node_x_coordinates(agent, node_lookup=node_lookup)
agent = self.fix_getcurrentdate_offset(agent)
+# Apply OrchestratorBlock defaults BEFORE fix_ai_model_parameter so that
+# the orchestrator-specific model (claude-opus-4-6) is set first and
+# fix_ai_model_parameter sees it as a valid allowed model instead of
+# overwriting it with the generic default (gpt-4o).
+agent = self.fix_orchestrator_blocks(agent)
# Apply fixes that require blocks information
if blocks:
agent = self.fix_invalid_nested_sink_links(
@@ -1765,9 +1784,6 @@ class AgentFixer:
# Apply fixes for MCPToolBlock nodes
agent = self.fix_mcp_tool_blocks(agent)
-# Apply fixes for OrchestratorBlock nodes (agent-mode defaults)
-agent = self.fix_orchestrator_blocks(agent)
# Apply fixes for AgentExecutorBlock nodes (sub-agents)
if library_agents:
agent = self.fix_agent_executor_blocks(agent, library_agents)

File diff suppressed because it is too large


@@ -0,0 +1,240 @@
"""Build complete JSONL transcript from SDK messages.
The transcript represents the FULL active context at any point in time.
Each upload REPLACES the previous transcript atomically.
Flow:
Turn 1: Upload [msg1, msg2]
Turn 2: Download [msg1, msg2] → Upload [msg1, msg2, msg3, msg4] (REPLACE)
Turn 3: Download [msg1, msg2, msg3, msg4] → Upload [all messages] (REPLACE)
The transcript is never incremental - always the complete atomic state.
"""
import logging
from typing import Any
from uuid import uuid4
from pydantic import BaseModel
from backend.util import json
from .transcript import STRIPPABLE_TYPES
logger = logging.getLogger(__name__)
class TranscriptEntry(BaseModel):
"""Single transcript entry (user or assistant turn)."""
type: str
uuid: str
parentUuid: str = ""
isCompactSummary: bool | None = None
message: dict[str, Any]
class TranscriptBuilder:
"""Build complete JSONL transcript from SDK messages.
This builder maintains the FULL conversation state, not incremental changes.
The output is always the complete active context.
"""
def __init__(self) -> None:
self._entries: list[TranscriptEntry] = []
self._last_uuid: str | None = None
def _last_is_assistant(self) -> bool:
return bool(self._entries) and self._entries[-1].type == "assistant"
def _last_message_id(self) -> str:
"""Return the message.id of the last entry, or '' if none."""
if self._entries:
return self._entries[-1].message.get("id", "")
return ""
@staticmethod
def _parse_entry(data: dict) -> TranscriptEntry | None:
"""Parse a single transcript entry, filtering strippable types.
Returns ``None`` for entries that should be skipped (strippable types
that are not compaction summaries).
"""
entry_type = data.get("type", "")
if entry_type in STRIPPABLE_TYPES and not data.get("isCompactSummary"):
return None
return TranscriptEntry(
type=entry_type,
uuid=data.get("uuid") or str(uuid4()),
parentUuid=data.get("parentUuid") or "",
isCompactSummary=data.get("isCompactSummary"),
message=data.get("message", {}),
)
def load_previous(self, content: str, log_prefix: str = "[Transcript]") -> None:
"""Load complete previous transcript.
This loads the FULL previous context. As new messages come in,
we append to this state. The final output is the complete context
(previous + new), not just the delta.
"""
if not content or not content.strip():
return
lines = content.strip().split("\n")
for line_num, line in enumerate(lines, 1):
if not line.strip():
continue
data = json.loads(line, fallback=None)
if data is None:
logger.warning(
"%s Failed to parse transcript line %d/%d",
log_prefix,
line_num,
len(lines),
)
continue
entry = self._parse_entry(data)
if entry is None:
continue
self._entries.append(entry)
self._last_uuid = entry.uuid
logger.info(
"%s Loaded %d entries from previous transcript (last_uuid=%s)",
log_prefix,
len(self._entries),
self._last_uuid[:12] if self._last_uuid else None,
)
def append_user(self, content: str | list[dict], uuid: str | None = None) -> None:
"""Append a user entry."""
msg_uuid = uuid or str(uuid4())
self._entries.append(
TranscriptEntry(
type="user",
uuid=msg_uuid,
parentUuid=self._last_uuid or "",
message={"role": "user", "content": content},
)
)
self._last_uuid = msg_uuid
def append_tool_result(self, tool_use_id: str, content: str) -> None:
"""Append a tool result as a user entry (one per tool call)."""
self.append_user(
content=[
{"type": "tool_result", "tool_use_id": tool_use_id, "content": content}
]
)
def append_assistant(
self,
content_blocks: list[dict],
model: str = "",
stop_reason: str | None = None,
) -> None:
"""Append an assistant entry.
Consecutive assistant entries automatically share the same message ID
so the CLI can merge them (thinking → text → tool_use) into a single
API message on ``--resume``. A new ID is assigned whenever an
assistant entry follows a non-assistant entry (user message or tool
result), because that marks the start of a new API response.
"""
message_id = (
self._last_message_id()
if self._last_is_assistant()
else f"msg_sdk_{uuid4().hex[:24]}"
)
msg_uuid = str(uuid4())
self._entries.append(
TranscriptEntry(
type="assistant",
uuid=msg_uuid,
parentUuid=self._last_uuid or "",
message={
"role": "assistant",
"model": model,
"id": message_id,
"type": "message",
"content": content_blocks,
"stop_reason": stop_reason,
"stop_sequence": None,
},
)
)
self._last_uuid = msg_uuid
def replace_entries(
self, compacted_entries: list[dict], log_prefix: str = "[Transcript]"
) -> None:
"""Replace all entries with compacted entries from the CLI session file.
Called after mid-stream compaction so TranscriptBuilder mirrors the
CLI's active context (compaction summary + post-compaction entries).
Builds the new list first and validates it's non-empty before swapping,
so corrupt input cannot wipe the conversation history.
"""
new_entries: list[TranscriptEntry] = []
for data in compacted_entries:
entry = self._parse_entry(data)
if entry is not None:
new_entries.append(entry)
if not new_entries:
logger.warning(
"%s replace_entries produced 0 entries from %d inputs, keeping old (%d entries)",
log_prefix,
len(compacted_entries),
len(self._entries),
)
return
old_count = len(self._entries)
self._entries = new_entries
self._last_uuid = new_entries[-1].uuid
logger.info(
"%s TranscriptBuilder compacted: %d entries -> %d entries",
log_prefix,
old_count,
len(self._entries),
)
def to_jsonl(self) -> str:
"""Export complete context as JSONL.
Consecutive assistant entries are kept separate to match the
native CLI format — the SDK merges them internally on resume.
Returns the FULL conversation state (all entries), not incremental.
This output REPLACES any previous transcript.
"""
if not self._entries:
return ""
lines = [entry.model_dump_json(exclude_none=True) for entry in self._entries]
return "\n".join(lines) + "\n"
@property
def entry_count(self) -> int:
"""Total number of entries in the complete context."""
return len(self._entries)
@property
def is_empty(self) -> bool:
"""Whether this builder has any entries."""
return len(self._entries) == 0
@property
def last_entry_type(self) -> str | None:
"""Type of the last entry, or None if empty."""
return self._entries[-1].type if self._entries else None


@@ -0,0 +1,260 @@
"""Tests for canonical TranscriptBuilder (backend.copilot.transcript_builder).
These tests directly import from the canonical module to ensure codecov
patch coverage for the new file.
"""
from backend.copilot.transcript_builder import TranscriptBuilder, TranscriptEntry
from backend.util import json
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
USER_MSG = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hello"},
}
ASST_MSG = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_1",
"type": "message",
"content": [{"type": "text", "text": "hi"}],
"stop_reason": "end_turn",
"stop_sequence": None,
},
}
class TestTranscriptEntry:
def test_basic_construction(self):
entry = TranscriptEntry(
type="user", uuid="u1", message={"role": "user", "content": "hi"}
)
assert entry.type == "user"
assert entry.uuid == "u1"
assert entry.parentUuid == ""
assert entry.isCompactSummary is None
def test_optional_fields(self):
entry = TranscriptEntry(
type="summary",
uuid="s1",
parentUuid="p1",
isCompactSummary=True,
message={"role": "user", "content": "summary"},
)
assert entry.isCompactSummary is True
assert entry.parentUuid == "p1"
class TestTranscriptBuilderInit:
def test_starts_empty(self):
builder = TranscriptBuilder()
assert builder.is_empty
assert builder.entry_count == 0
assert builder.last_entry_type is None
assert builder.to_jsonl() == ""
class TestAppendUser:
def test_appends_user_entry(self):
builder = TranscriptBuilder()
builder.append_user("hello")
assert builder.entry_count == 1
assert builder.last_entry_type == "user"
def test_chains_parent_uuid(self):
builder = TranscriptBuilder()
builder.append_user("first", uuid="u1")
builder.append_user("second", uuid="u2")
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[0]["parentUuid"] == ""
assert entries[1]["parentUuid"] == "u1"
def test_custom_uuid(self):
builder = TranscriptBuilder()
builder.append_user("hello", uuid="custom-id")
output = builder.to_jsonl()
entry = json.loads(output.strip())
assert entry["uuid"] == "custom-id"
class TestAppendToolResult:
def test_appends_as_user_entry(self):
builder = TranscriptBuilder()
builder.append_tool_result(tool_use_id="tc_1", content="result text")
assert builder.entry_count == 1
assert builder.last_entry_type == "user"
output = builder.to_jsonl()
entry = json.loads(output.strip())
content = entry["message"]["content"]
assert len(content) == 1
assert content[0]["type"] == "tool_result"
assert content[0]["tool_use_id"] == "tc_1"
assert content[0]["content"] == "result text"
class TestAppendAssistant:
def test_appends_assistant_entry(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "hello"}],
model="test-model",
stop_reason="end_turn",
)
assert builder.entry_count == 2
assert builder.last_entry_type == "assistant"
def test_consecutive_assistants_share_message_id(self):
builder = TranscriptBuilder()
builder.append_user("hi")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "part 1"}],
model="m",
)
builder.append_assistant(
content_blocks=[{"type": "text", "text": "part 2"}],
model="m",
)
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
# The two assistant entries share the same message ID
assert entries[1]["message"]["id"] == entries[2]["message"]["id"]
def test_non_consecutive_assistants_get_different_ids(self):
builder = TranscriptBuilder()
builder.append_user("q1")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "a1"}],
model="m",
)
builder.append_user("q2")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "a2"}],
model="m",
)
output = builder.to_jsonl()
entries = [json.loads(line) for line in output.strip().split("\n")]
assert entries[1]["message"]["id"] != entries[3]["message"]["id"]
class TestLoadPrevious:
def test_loads_valid_entries(self):
content = _make_jsonl(USER_MSG, ASST_MSG)
builder = TranscriptBuilder()
builder.load_previous(content)
assert builder.entry_count == 2
def test_skips_empty_content(self):
builder = TranscriptBuilder()
builder.load_previous("")
assert builder.is_empty
builder.load_previous(" ")
assert builder.is_empty
def test_skips_strippable_types(self):
progress = {"type": "progress", "uuid": "p1", "message": {}}
content = _make_jsonl(USER_MSG, progress, ASST_MSG)
builder = TranscriptBuilder()
builder.load_previous(content)
assert builder.entry_count == 2 # progress was skipped
def test_preserves_compact_summary(self):
compact = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "summary"},
}
content = _make_jsonl(compact, ASST_MSG)
builder = TranscriptBuilder()
builder.load_previous(content)
assert builder.entry_count == 2
def test_skips_invalid_json_lines(self):
content = '{"type":"user","uuid":"u1","message":{}}\nnot-valid-json\n'
builder = TranscriptBuilder()
builder.load_previous(content)
assert builder.entry_count == 1
class TestToJsonl:
def test_roundtrip(self):
builder = TranscriptBuilder()
builder.append_user("hello", uuid="u1")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "world"}],
model="m",
)
output = builder.to_jsonl()
assert output.endswith("\n")
lines = output.strip().split("\n")
assert len(lines) == 2
for line in lines:
parsed = json.loads(line)
assert "type" in parsed
assert "uuid" in parsed
assert "message" in parsed
class TestReplaceEntries:
def test_replaces_all_entries(self):
builder = TranscriptBuilder()
builder.append_user("old")
builder.append_assistant(
content_blocks=[{"type": "text", "text": "old answer"}], model="m"
)
assert builder.entry_count == 2
compacted = [
{
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "compacted"},
}
]
builder.replace_entries(compacted)
assert builder.entry_count == 1
def test_empty_replacement_keeps_existing(self):
builder = TranscriptBuilder()
builder.append_user("keep me")
builder.replace_entries([])
assert builder.entry_count == 1
class TestParseEntry:
def test_filters_strippable_non_compact(self):
result = TranscriptBuilder._parse_entry(
{"type": "progress", "uuid": "p1", "message": {}}
)
assert result is None
def test_keeps_compact_summary(self):
result = TranscriptBuilder._parse_entry(
{
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {},
}
)
assert result is not None
assert result.isCompactSummary is True
def test_generates_uuid_if_missing(self):
result = TranscriptBuilder._parse_entry(
{"type": "user", "message": {"role": "user", "content": "hi"}}
)
assert result is not None
assert result.uuid # Should be a generated UUID


@@ -0,0 +1,726 @@
"""Tests for canonical transcript module (backend.copilot.transcript).
Covers pure helper functions that are not exercised by the SDK re-export tests.
"""
from __future__ import annotations
from unittest.mock import MagicMock
from backend.util import json
from .transcript import (
TranscriptDownload,
_build_path_from_parts,
_find_last_assistant_entry,
_flatten_assistant_content,
_flatten_tool_result_content,
_messages_to_transcript,
_meta_storage_path_parts,
_rechain_tail,
_sanitize_id,
_storage_path_parts,
_transcript_to_messages,
strip_for_upload,
validate_transcript,
)
def _make_jsonl(*entries: dict) -> str:
return "\n".join(json.dumps(e) for e in entries) + "\n"
# ---------------------------------------------------------------------------
# _sanitize_id
# ---------------------------------------------------------------------------
class TestSanitizeId:
def test_uuid_passes_through(self):
assert _sanitize_id("abcdef12-3456-7890-abcd-ef1234567890") == (
"abcdef12-3456-7890-abcd-ef1234567890"
)
def test_strips_non_hex_characters(self):
# Only hex chars (0-9, a-f, A-F) and hyphens are kept
result = _sanitize_id("abc/../../etc/passwd")
assert "/" not in result
assert "." not in result
# 'p', 's', 'w' are not hex chars, so they are stripped
assert all(c in "0123456789abcdefABCDEF-" for c in result)
def test_truncates_to_max_len(self):
long_id = "a" * 100
result = _sanitize_id(long_id, max_len=10)
assert len(result) == 10
def test_empty_returns_unknown(self):
assert _sanitize_id("") == "unknown"
def test_none_returns_unknown(self):
assert _sanitize_id(None) == "unknown" # type: ignore[arg-type]
def test_special_chars_only_returns_unknown(self):
assert _sanitize_id("!@#$%^&*()") == "unknown"
# ---------------------------------------------------------------------------
# _storage_path_parts / _meta_storage_path_parts
# ---------------------------------------------------------------------------
class TestStoragePathParts:
def test_returns_triple(self):
prefix, uid, fname = _storage_path_parts("user-1", "sess-2")
assert prefix == "chat-transcripts"
assert "e" in uid # hex chars from "user-1" sanitized
assert fname.endswith(".jsonl")
def test_meta_returns_meta_json(self):
prefix, uid, fname = _meta_storage_path_parts("user-1", "sess-2")
assert prefix == "chat-transcripts"
assert fname.endswith(".meta.json")
# ---------------------------------------------------------------------------
# _build_path_from_parts
# ---------------------------------------------------------------------------
class TestBuildPathFromParts:
def test_gcs_backend(self):
from backend.util.workspace_storage import GCSWorkspaceStorage
mock_gcs = MagicMock(spec=GCSWorkspaceStorage)
mock_gcs.bucket_name = "my-bucket"
path = _build_path_from_parts(("wid", "fid", "file.jsonl"), mock_gcs)
assert path == "gcs://my-bucket/workspaces/wid/fid/file.jsonl"
def test_local_backend(self):
# Use a plain object (not MagicMock) so isinstance(GCSWorkspaceStorage) is False
local_backend = type("LocalBackend", (), {})()
path = _build_path_from_parts(("wid", "fid", "file.jsonl"), local_backend)
assert path == "local://wid/fid/file.jsonl"
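A sketch consistent with these two cases (hypothetical — the real function dispatches on `isinstance(backend, GCSWorkspaceStorage)`; a `bucket_name` attribute check stands in for that here):

```python
def build_path_from_parts(parts: tuple[str, str, str], backend) -> str:
    """Render a storage path for either a GCS or a local backend."""
    # GCS backends expose bucket_name; anything else is treated as local.
    if hasattr(backend, "bucket_name"):
        return f"gcs://{backend.bucket_name}/workspaces/" + "/".join(parts)
    return "local://" + "/".join(parts)
```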
# ---------------------------------------------------------------------------
# TranscriptDownload dataclass
# ---------------------------------------------------------------------------
class TestTranscriptDownload:
def test_defaults(self):
td = TranscriptDownload(content="hello")
assert td.content == "hello"
assert td.message_count == 0
assert td.uploaded_at == 0.0
def test_custom_values(self):
td = TranscriptDownload(content="data", message_count=5, uploaded_at=123.45)
assert td.message_count == 5
assert td.uploaded_at == 123.45
# ---------------------------------------------------------------------------
# _flatten_assistant_content
# ---------------------------------------------------------------------------
class TestFlattenAssistantContent:
def test_text_blocks(self):
blocks = [
{"type": "text", "text": "Hello"},
{"type": "text", "text": "World"},
]
assert _flatten_assistant_content(blocks) == "Hello\nWorld"
def test_thinking_blocks_stripped(self):
blocks = [
{"type": "thinking", "thinking": "hmm..."},
{"type": "text", "text": "answer"},
{"type": "redacted_thinking", "data": "secret"},
]
assert _flatten_assistant_content(blocks) == "answer"
def test_tool_use_blocks_stripped(self):
blocks = [
{"type": "text", "text": "I'll run a tool"},
{"type": "tool_use", "name": "bash", "id": "tc1", "input": {}},
]
assert _flatten_assistant_content(blocks) == "I'll run a tool"
def test_string_blocks(self):
blocks = ["hello", "world"]
assert _flatten_assistant_content(blocks) == "hello\nworld"
def test_empty_blocks(self):
assert _flatten_assistant_content([]) == ""
def test_unknown_dict_blocks_skipped(self):
blocks = [{"type": "image", "data": "base64..."}]
assert _flatten_assistant_content(blocks) == ""
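The flattening rules exercised above reduce to a short sketch (hypothetical — a stand-in for the real `_flatten_assistant_content`, not its actual body):

```python
def flatten_assistant_content(blocks: list) -> str:
    """Join text blocks with newlines; drop thinking/tool_use/unknown blocks."""
    parts = []
    for block in blocks:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(block.get("text", ""))
        # thinking, redacted_thinking, tool_use, and unknown dicts are skipped
    return "\n".join(parts)
```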
# ---------------------------------------------------------------------------
# _flatten_tool_result_content
# ---------------------------------------------------------------------------
class TestFlattenToolResultContent:
def test_tool_result_with_text_content(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "tc1",
"content": [{"type": "text", "text": "output data"}],
}
]
assert _flatten_tool_result_content(blocks) == "output data"
def test_tool_result_with_string_content(self):
blocks = [
{"type": "tool_result", "tool_use_id": "tc1", "content": "simple string"}
]
assert _flatten_tool_result_content(blocks) == "simple string"
def test_tool_result_with_image_placeholder(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "tc1",
"content": [{"type": "image", "data": "base64..."}],
}
]
assert _flatten_tool_result_content(blocks) == "[__image__]"
def test_tool_result_with_document_placeholder(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "tc1",
"content": [{"type": "document", "data": "base64..."}],
}
]
assert _flatten_tool_result_content(blocks) == "[__document__]"
def test_tool_result_with_none_content(self):
blocks = [{"type": "tool_result", "tool_use_id": "tc1", "content": None}]
assert _flatten_tool_result_content(blocks) == ""
def test_text_block_outside_tool_result(self):
blocks = [{"type": "text", "text": "standalone"}]
assert _flatten_tool_result_content(blocks) == "standalone"
def test_unknown_dict_block_placeholder(self):
blocks = [{"type": "custom_widget", "data": "x"}]
assert _flatten_tool_result_content(blocks) == "[__custom_widget__]"
def test_string_blocks(self):
blocks = ["raw text"]
assert _flatten_tool_result_content(blocks) == "raw text"
def test_empty_blocks(self):
assert _flatten_tool_result_content([]) == ""
def test_mixed_content_in_tool_result(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "tc1",
"content": [
{"type": "text", "text": "line1"},
{"type": "image", "data": "..."},
"raw string",
],
}
]
result = _flatten_tool_result_content(blocks)
assert "line1" in result
assert "[__image__]" in result
assert "raw string" in result
def test_tool_result_with_dict_without_text_key(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "tc1",
"content": [{"count": 42}],
}
]
result = _flatten_tool_result_content(blocks)
assert "42" in result
def test_tool_result_content_list_with_list_content(self):
blocks = [
{
"type": "tool_result",
"tool_use_id": "tc1",
"content": [{"type": "text", "text": None}],
}
]
result = _flatten_tool_result_content(blocks)
assert result == "None"
# ---------------------------------------------------------------------------
# _transcript_to_messages
# ---------------------------------------------------------------------------
USER_ENTRY = {
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "hello"},
}
ASST_ENTRY = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_1",
"content": [{"type": "text", "text": "hi there"}],
},
}
PROGRESS_ENTRY = {
"type": "progress",
"uuid": "p1",
"parentUuid": "u1",
"data": {},
}
class TestTranscriptToMessages:
def test_basic_conversion(self):
content = _make_jsonl(USER_ENTRY, ASST_ENTRY)
messages = _transcript_to_messages(content)
assert len(messages) == 2
assert messages[0] == {"role": "user", "content": "hello"}
assert messages[1]["role"] == "assistant"
assert messages[1]["content"] == "hi there"
def test_skips_strippable_types(self):
content = _make_jsonl(USER_ENTRY, PROGRESS_ENTRY, ASST_ENTRY)
messages = _transcript_to_messages(content)
assert len(messages) == 2
def test_skips_entries_without_role(self):
no_role = {"type": "user", "uuid": "x", "message": {"content": "no role"}}
content = _make_jsonl(no_role)
messages = _transcript_to_messages(content)
assert len(messages) == 0
def test_handles_string_content(self):
entry = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "plain string"},
}
content = _make_jsonl(entry)
messages = _transcript_to_messages(content)
assert messages[0]["content"] == "plain string"
def test_handles_tool_result_content(self):
entry = {
"type": "user",
"uuid": "u1",
"message": {
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "tc1", "content": "output"}
],
},
}
content = _make_jsonl(entry)
messages = _transcript_to_messages(content)
assert messages[0]["content"] == "output"
def test_handles_none_content(self):
entry = {
"type": "assistant",
"uuid": "a1",
"message": {"role": "assistant", "content": None},
}
content = _make_jsonl(entry)
messages = _transcript_to_messages(content)
assert messages[0]["content"] == ""
def test_skips_invalid_json(self):
content = "not valid json\n"
messages = _transcript_to_messages(content)
assert len(messages) == 0
def test_preserves_compact_summary(self):
compact = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "summary of conversation"},
}
content = _make_jsonl(compact)
messages = _transcript_to_messages(content)
assert len(messages) == 1
def test_strips_summary_without_compact_flag(self):
summary = {
"type": "summary",
"uuid": "s1",
"message": {"role": "user", "content": "summary"},
}
content = _make_jsonl(summary)
messages = _transcript_to_messages(content)
assert len(messages) == 0
# ---------------------------------------------------------------------------
# _messages_to_transcript
# ---------------------------------------------------------------------------
class TestMessagesToTranscript:
def test_basic_roundtrip(self):
messages = [
{"role": "user", "content": "hello"},
{"role": "assistant", "content": "world"},
]
result = _messages_to_transcript(messages)
assert result.endswith("\n")
lines = result.strip().split("\n")
assert len(lines) == 2
user_entry = json.loads(lines[0])
assert user_entry["type"] == "user"
assert user_entry["message"]["role"] == "user"
assert user_entry["message"]["content"] == "hello"
assert user_entry["parentUuid"] == ""
asst_entry = json.loads(lines[1])
assert asst_entry["type"] == "assistant"
assert asst_entry["message"]["role"] == "assistant"
assert asst_entry["message"]["content"] == [{"type": "text", "text": "world"}]
assert asst_entry["parentUuid"] == user_entry["uuid"]
def test_empty_messages(self):
assert _messages_to_transcript([]) == ""
def test_assistant_has_message_envelope(self):
messages = [{"role": "assistant", "content": "test"}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
msg = entry["message"]
assert "id" in msg
assert msg["id"].startswith("msg_compact_")
assert msg["type"] == "message"
assert msg["stop_reason"] == "end_turn"
assert msg["stop_sequence"] is None
def test_uuid_chain(self):
messages = [
{"role": "user", "content": "a"},
{"role": "assistant", "content": "b"},
{"role": "user", "content": "c"},
]
result = _messages_to_transcript(messages)
lines = result.strip().split("\n")
entries = [json.loads(line) for line in lines]
assert entries[0]["parentUuid"] == ""
assert entries[1]["parentUuid"] == entries[0]["uuid"]
assert entries[2]["parentUuid"] == entries[1]["uuid"]
def test_assistant_with_empty_content(self):
messages = [{"role": "assistant", "content": ""}]
result = _messages_to_transcript(messages)
entry = json.loads(result.strip())
assert entry["message"]["content"] == []
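The serialization these tests describe can be sketched as follows (hypothetical — the envelope fields and the `msg_compact_` id scheme are taken from the assertions above, everything else is an assumption):

```python
import json
import uuid as uuidlib

def messages_to_transcript(messages: list[dict]) -> str:
    """Serialize role/content messages back into parent-chained JSONL."""
    lines, parent = [], ""
    for msg in messages:
        entry_uuid = str(uuidlib.uuid4())
        if msg["role"] == "assistant":
            # Assistant entries carry a full message envelope.
            content = [{"type": "text", "text": msg["content"]}] if msg["content"] else []
            envelope = {
                "role": "assistant",
                "id": f"msg_compact_{entry_uuid[:8]}",
                "type": "message",
                "content": content,
                "stop_reason": "end_turn",
                "stop_sequence": None,
            }
        else:
            envelope = {"role": msg["role"], "content": msg["content"]}
        lines.append(json.dumps({
            "type": msg["role"],
            "uuid": entry_uuid,
            "parentUuid": parent,
            "message": envelope,
        }))
        parent = entry_uuid
    return "\n".join(lines) + "\n" if lines else ""
```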
# ---------------------------------------------------------------------------
# _find_last_assistant_entry
# ---------------------------------------------------------------------------
class TestFindLastAssistantEntry:
def test_splits_at_last_assistant(self):
user = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
}
asst = {
"type": "assistant",
"uuid": "a1",
"message": {"role": "assistant", "id": "msg1", "content": "answer"},
}
content = _make_jsonl(user, asst)
prefix, tail = _find_last_assistant_entry(content)
assert len(prefix) == 1
assert len(tail) == 1
def test_no_assistant_returns_all_in_prefix(self):
user1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
}
user2 = {
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "hey"},
}
content = _make_jsonl(user1, user2)
prefix, tail = _find_last_assistant_entry(content)
assert len(prefix) == 2
assert len(tail) == 0
def test_multi_entry_turn_preserved(self):
user = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "q"},
}
asst1 = {
"type": "assistant",
"uuid": "a1",
"message": {
"role": "assistant",
"id": "msg_turn",
"content": [{"type": "thinking", "thinking": "hmm"}],
},
}
asst2 = {
"type": "assistant",
"uuid": "a2",
"message": {
"role": "assistant",
"id": "msg_turn",
"content": [{"type": "text", "text": "answer"}],
},
}
content = _make_jsonl(user, asst1, asst2)
prefix, tail = _find_last_assistant_entry(content)
assert len(prefix) == 1 # just the user
assert len(tail) == 2 # both assistant entries
def test_assistant_without_id(self):
user = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "q"},
}
asst = {
"type": "assistant",
"uuid": "a1",
"message": {"role": "assistant", "content": "no id"},
}
content = _make_jsonl(user, asst)
prefix, tail = _find_last_assistant_entry(content)
assert len(prefix) == 1
assert len(tail) == 1
def test_trailing_user_after_assistant(self):
user1 = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "q"},
}
asst = {
"type": "assistant",
"uuid": "a1",
"message": {"role": "assistant", "id": "msg1", "content": "a"},
}
user2 = {
"type": "user",
"uuid": "u2",
"message": {"role": "user", "content": "follow"},
}
content = _make_jsonl(user1, asst, user2)
prefix, tail = _find_last_assistant_entry(content)
assert len(prefix) == 1 # user1
assert len(tail) == 2 # asst + user2
# ---------------------------------------------------------------------------
# _rechain_tail
# ---------------------------------------------------------------------------
class TestRechainTail:
def test_empty_tail(self):
assert _rechain_tail("some prefix\n", []) == ""
def test_patches_first_entry_parent(self):
prefix_entry = {"uuid": "last-prefix-uuid", "type": "user", "message": {}}
prefix = json.dumps(prefix_entry) + "\n"
tail_entry = {
"uuid": "t1",
"parentUuid": "old-parent",
"type": "assistant",
"message": {},
}
tail_lines = [json.dumps(tail_entry)]
result = _rechain_tail(prefix, tail_lines)
parsed = json.loads(result.strip())
assert parsed["parentUuid"] == "last-prefix-uuid"
def test_chains_consecutive_tail_entries(self):
prefix_entry = {"uuid": "p1", "type": "user", "message": {}}
prefix = json.dumps(prefix_entry) + "\n"
t1 = {"uuid": "t1", "parentUuid": "old1", "type": "assistant", "message": {}}
t2 = {"uuid": "t2", "parentUuid": "old2", "type": "user", "message": {}}
tail_lines = [json.dumps(t1), json.dumps(t2)]
result = _rechain_tail(prefix, tail_lines)
entries = [json.loads(line) for line in result.strip().split("\n")]
assert entries[0]["parentUuid"] == "p1"
assert entries[1]["parentUuid"] == "t1"
def test_non_dict_lines_passed_through(self):
prefix_entry = {"uuid": "p1", "type": "user", "message": {}}
prefix = json.dumps(prefix_entry) + "\n"
tail_lines = ["not-a-json-dict"]
result = _rechain_tail(prefix, tail_lines)
assert "not-a-json-dict" in result
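The re-chaining contract above — patch the first tail entry onto the last prefix entry, chain the rest consecutively, pass non-JSON lines through — can be sketched as (hypothetical, not the real `_rechain_tail` body):

```python
import json

def rechain_tail(prefix: str, tail_lines: list[str]) -> str:
    """Re-parent tail entries onto the last prefix entry, then onto each other."""
    if not tail_lines:
        return ""
    prefix_entries = [ln for ln in prefix.strip().split("\n") if ln]
    parent = json.loads(prefix_entries[-1]).get("uuid", "") if prefix_entries else ""
    out = []
    for line in tail_lines:
        try:
            entry = json.loads(line)
        except json.JSONDecodeError:
            out.append(line)  # non-JSON lines pass through untouched
            continue
        if not isinstance(entry, dict):
            out.append(line)
            continue
        entry["parentUuid"] = parent
        parent = entry.get("uuid", parent)
        out.append(json.dumps(entry))
    return "\n".join(out) + "\n"
```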
# ---------------------------------------------------------------------------
# strip_for_upload (combined single-parse)
# ---------------------------------------------------------------------------
class TestStripForUpload:
def test_strips_progress_and_thinking(self):
user = {
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "hi"},
}
progress = {"type": "progress", "uuid": "p1", "parentUuid": "u1", "data": {}}
asst_old = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "p1",
"message": {
"role": "assistant",
"id": "msg_old",
"content": [
{"type": "thinking", "thinking": "stale thinking"},
{"type": "text", "text": "old answer"},
],
},
}
user2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1",
"message": {"role": "user", "content": "next"},
}
asst_new = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"role": "assistant",
"id": "msg_new",
"content": [
{"type": "thinking", "thinking": "fresh thinking"},
{"type": "text", "text": "new answer"},
],
},
}
content = _make_jsonl(user, progress, asst_old, user2, asst_new)
result = strip_for_upload(content)
lines = result.strip().split("\n")
# Progress should be stripped -> 4 entries remain
assert len(lines) == 4
# asst_old should be reparented since its parent (progress) was stripped
entries = [json.loads(line) for line in lines]
types = [e.get("type") for e in entries]
assert "progress" not in types
# Old assistant thinking stripped, new assistant thinking preserved
old_asst = next(
e for e in entries if e.get("message", {}).get("id") == "msg_old"
)
old_content = old_asst["message"]["content"]
old_types = [b["type"] for b in old_content if isinstance(b, dict)]
assert "thinking" not in old_types
assert "text" in old_types
new_asst = next(
e for e in entries if e.get("message", {}).get("id") == "msg_new"
)
new_content = new_asst["message"]["content"]
new_types = [b["type"] for b in new_content if isinstance(b, dict)]
assert "thinking" in new_types # last assistant preserved
def test_empty_content(self):
result = strip_for_upload("")
# Empty string produces a single empty line after split, resulting in "\n"
assert result.strip() == ""
def test_preserves_compact_summary(self):
compact = {
"type": "summary",
"uuid": "cs1",
"isCompactSummary": True,
"message": {"role": "user", "content": "summary"},
}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "cs1",
"message": {"role": "assistant", "id": "msg1", "content": "answer"},
}
content = _make_jsonl(compact, asst)
result = strip_for_upload(content)
lines = result.strip().split("\n")
assert len(lines) == 2
def test_no_assistant_entries(self):
user = {
"type": "user",
"uuid": "u1",
"message": {"role": "user", "content": "hi"},
}
content = _make_jsonl(user)
result = strip_for_upload(content)
lines = result.strip().split("\n")
assert len(lines) == 1
# ---------------------------------------------------------------------------
# validate_transcript (additional edge cases)
# ---------------------------------------------------------------------------
class TestValidateTranscript:
def test_valid_with_assistant(self):
content = _make_jsonl(
USER_ENTRY,
ASST_ENTRY,
)
assert validate_transcript(content) is True
def test_none_returns_false(self):
assert validate_transcript(None) is False
def test_whitespace_only_returns_false(self):
assert validate_transcript(" \n ") is False
def test_no_assistant_returns_false(self):
content = _make_jsonl(USER_ENTRY)
assert validate_transcript(content) is False
def test_invalid_json_returns_false(self):
assert validate_transcript("not json\n") is False
def test_assistant_only_is_valid(self):
content = _make_jsonl(ASST_ENTRY)
assert validate_transcript(content) is True

View File

@@ -38,6 +38,7 @@ class Flag(str, Enum):
AGENT_ACTIVITY = "agent-activity"
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment"
CHAT = "chat"
+CHAT_MODE_OPTION = "chat-mode-option"
COPILOT_SDK = "copilot-sdk"
COPILOT_DAILY_TOKEN_LIMIT = "copilot-daily-token-limit"
COPILOT_WEEKLY_TOKEN_LIMIT = "copilot-weekly-token-limit"

View File

@@ -140,7 +140,9 @@ class TestFixOrchestratorBlocks:
assert defaults["conversation_compaction"] is True
assert defaults["retry"] == 3
assert defaults["multiple_tool_calls"] is False
-assert len(fixer.fixes_applied) == 4
+assert defaults["execution_mode"] == "extended_thinking"
+assert defaults["model"] == "claude-opus-4-6"
+assert len(fixer.fixes_applied) == 6
def test_preserves_existing_values(self):
"""Existing user-set values are never overwritten."""
@@ -153,6 +155,8 @@ class TestFixOrchestratorBlocks:
"conversation_compaction": False,
"retry": 1,
"multiple_tool_calls": True,
+"execution_mode": "built_in",
+"model": "gpt-4o",
}
)
],
@@ -166,6 +170,8 @@ class TestFixOrchestratorBlocks:
assert defaults["conversation_compaction"] is False
assert defaults["retry"] == 1
assert defaults["multiple_tool_calls"] is True
+assert defaults["execution_mode"] == "built_in"
+assert defaults["model"] == "gpt-4o"
assert len(fixer.fixes_applied) == 0
def test_partial_defaults(self):
@@ -189,7 +195,9 @@ class TestFixOrchestratorBlocks:
assert defaults["conversation_compaction"] is True # filled
assert defaults["retry"] == 3 # filled
assert defaults["multiple_tool_calls"] is False # filled
-assert len(fixer.fixes_applied) == 3
+assert defaults["execution_mode"] == "extended_thinking" # filled
+assert defaults["model"] == "claude-opus-4-6" # filled
+assert len(fixer.fixes_applied) == 5
def test_skips_non_sdm_nodes(self):
"""Non-Orchestrator nodes are untouched."""
@@ -258,11 +266,13 @@ class TestFixOrchestratorBlocks:
result = fixer.fix_orchestrator_blocks(agent)
defaults = result["nodes"][0]["input_default"]
-assert defaults["agent_mode_max_iterations"] == 10 # None default
-assert defaults["conversation_compaction"] is True # None default
+assert defaults["agent_mode_max_iterations"] == 10 # None -> default
+assert defaults["conversation_compaction"] is True # None -> default
assert defaults["retry"] == 3 # kept
assert defaults["multiple_tool_calls"] is False # kept
-assert len(fixer.fixes_applied) == 2
+assert defaults["execution_mode"] == "extended_thinking" # filled
+assert defaults["model"] == "claude-opus-4-6" # filled
+assert len(fixer.fixes_applied) == 4
def test_multiple_sdm_nodes(self):
"""Multiple SDM nodes are all fixed independently."""
@@ -277,11 +287,11 @@ class TestFixOrchestratorBlocks:
result = fixer.fix_orchestrator_blocks(agent)
-# First node: 3 defaults filled (agent_mode was already set)
+# First node: 5 defaults filled (agent_mode was already set)
assert result["nodes"][0]["input_default"]["agent_mode_max_iterations"] == 3
-# Second node: all 4 defaults filled
+# Second node: all 6 defaults filled
assert result["nodes"][1]["input_default"]["agent_mode_max_iterations"] == 10
-assert len(fixer.fixes_applied) == 7 # 3 + 4
+assert len(fixer.fixes_applied) == 11 # 5 + 6
def test_registered_in_apply_all_fixes(self):
"""fix_orchestrator_blocks runs as part of apply_all_fixes."""
@@ -655,6 +665,7 @@ class TestOrchestratorE2EPipeline:
"conversation_compaction": {"type": "boolean"},
"retry": {"type": "integer"},
"multiple_tool_calls": {"type": "boolean"},
+"execution_mode": {"type": "string"},
},
"required": ["prompt"],
},

View File

@@ -0,0 +1,133 @@
import { describe, it, expect, beforeEach } from "vitest";
import { useOnboardingWizardStore } from "../store";
beforeEach(() => {
useOnboardingWizardStore.getState().reset();
});
describe("useOnboardingWizardStore", () => {
describe("initial state", () => {
it("starts at step 1 with empty fields", () => {
const state = useOnboardingWizardStore.getState();
expect(state.currentStep).toBe(1);
expect(state.name).toBe("");
expect(state.role).toBe("");
expect(state.otherRole).toBe("");
expect(state.painPoints).toEqual([]);
expect(state.otherPainPoint).toBe("");
});
});
describe("setName", () => {
it("updates the name", () => {
useOnboardingWizardStore.getState().setName("Alice");
expect(useOnboardingWizardStore.getState().name).toBe("Alice");
});
});
describe("setRole", () => {
it("updates the role", () => {
useOnboardingWizardStore.getState().setRole("Engineer");
expect(useOnboardingWizardStore.getState().role).toBe("Engineer");
});
});
describe("setOtherRole", () => {
it("updates the other role text", () => {
useOnboardingWizardStore.getState().setOtherRole("Designer");
expect(useOnboardingWizardStore.getState().otherRole).toBe("Designer");
});
});
describe("togglePainPoint", () => {
it("adds a pain point", () => {
useOnboardingWizardStore.getState().togglePainPoint("slow builds");
expect(useOnboardingWizardStore.getState().painPoints).toEqual([
"slow builds",
]);
});
it("removes a pain point when toggled again", () => {
useOnboardingWizardStore.getState().togglePainPoint("slow builds");
useOnboardingWizardStore.getState().togglePainPoint("slow builds");
expect(useOnboardingWizardStore.getState().painPoints).toEqual([]);
});
it("handles multiple pain points", () => {
useOnboardingWizardStore.getState().togglePainPoint("slow builds");
useOnboardingWizardStore.getState().togglePainPoint("no tests");
expect(useOnboardingWizardStore.getState().painPoints).toEqual([
"slow builds",
"no tests",
]);
useOnboardingWizardStore.getState().togglePainPoint("slow builds");
expect(useOnboardingWizardStore.getState().painPoints).toEqual([
"no tests",
]);
});
});
describe("setOtherPainPoint", () => {
it("updates the other pain point text", () => {
useOnboardingWizardStore.getState().setOtherPainPoint("flaky CI");
expect(useOnboardingWizardStore.getState().otherPainPoint).toBe(
"flaky CI",
);
});
});
describe("nextStep", () => {
it("increments the step", () => {
useOnboardingWizardStore.getState().nextStep();
expect(useOnboardingWizardStore.getState().currentStep).toBe(2);
});
it("clamps at step 4", () => {
useOnboardingWizardStore.getState().goToStep(4);
useOnboardingWizardStore.getState().nextStep();
expect(useOnboardingWizardStore.getState().currentStep).toBe(4);
});
});
describe("prevStep", () => {
it("decrements the step", () => {
useOnboardingWizardStore.getState().goToStep(3);
useOnboardingWizardStore.getState().prevStep();
expect(useOnboardingWizardStore.getState().currentStep).toBe(2);
});
it("clamps at step 1", () => {
useOnboardingWizardStore.getState().prevStep();
expect(useOnboardingWizardStore.getState().currentStep).toBe(1);
});
});
describe("goToStep", () => {
it("jumps to an arbitrary step", () => {
useOnboardingWizardStore.getState().goToStep(3);
expect(useOnboardingWizardStore.getState().currentStep).toBe(3);
});
});
describe("reset", () => {
it("resets all fields to defaults", () => {
useOnboardingWizardStore.getState().setName("Alice");
useOnboardingWizardStore.getState().setRole("Engineer");
useOnboardingWizardStore.getState().setOtherRole("Other");
useOnboardingWizardStore.getState().togglePainPoint("slow builds");
useOnboardingWizardStore.getState().setOtherPainPoint("flaky CI");
useOnboardingWizardStore.getState().goToStep(3);
useOnboardingWizardStore.getState().reset();
const state = useOnboardingWizardStore.getState();
expect(state.currentStep).toBe(1);
expect(state.name).toBe("");
expect(state.role).toBe("");
expect(state.otherRole).toBe("");
expect(state.painPoints).toEqual([]);
expect(state.otherPainPoint).toBe("");
});
});
});

View File

@@ -0,0 +1,221 @@
import { describe, expect, it, beforeEach, vi } from "vitest";
import { useCopilotUIStore } from "../store";
vi.mock("@sentry/nextjs", () => ({
captureException: vi.fn(),
}));
vi.mock("@/services/environment", () => ({
environment: {
isServerSide: vi.fn(() => false),
},
}));
describe("useCopilotUIStore", () => {
beforeEach(() => {
window.localStorage.clear();
useCopilotUIStore.setState({
initialPrompt: null,
sessionToDelete: null,
isDrawerOpen: false,
completedSessionIDs: new Set<string>(),
isNotificationsEnabled: false,
isSoundEnabled: true,
showNotificationDialog: false,
copilotMode: "extended_thinking",
});
});
describe("initialPrompt", () => {
it("starts as null", () => {
expect(useCopilotUIStore.getState().initialPrompt).toBeNull();
});
it("sets and clears prompt", () => {
useCopilotUIStore.getState().setInitialPrompt("Hello");
expect(useCopilotUIStore.getState().initialPrompt).toBe("Hello");
useCopilotUIStore.getState().setInitialPrompt(null);
expect(useCopilotUIStore.getState().initialPrompt).toBeNull();
});
});
describe("sessionToDelete", () => {
it("starts as null", () => {
expect(useCopilotUIStore.getState().sessionToDelete).toBeNull();
});
it("sets and clears a delete target", () => {
useCopilotUIStore
.getState()
.setSessionToDelete({ id: "abc", title: "Test" });
expect(useCopilotUIStore.getState().sessionToDelete).toEqual({
id: "abc",
title: "Test",
});
useCopilotUIStore.getState().setSessionToDelete(null);
expect(useCopilotUIStore.getState().sessionToDelete).toBeNull();
});
});
describe("drawer", () => {
it("starts closed", () => {
expect(useCopilotUIStore.getState().isDrawerOpen).toBe(false);
});
it("opens and closes", () => {
useCopilotUIStore.getState().setDrawerOpen(true);
expect(useCopilotUIStore.getState().isDrawerOpen).toBe(true);
useCopilotUIStore.getState().setDrawerOpen(false);
expect(useCopilotUIStore.getState().isDrawerOpen).toBe(false);
});
});
describe("completedSessionIDs", () => {
it("starts empty", () => {
expect(useCopilotUIStore.getState().completedSessionIDs.size).toBe(0);
});
it("adds a completed session", () => {
useCopilotUIStore.getState().addCompletedSession("s1");
expect(useCopilotUIStore.getState().completedSessionIDs.has("s1")).toBe(
true,
);
});
it("persists added sessions to localStorage", () => {
useCopilotUIStore.getState().addCompletedSession("s1");
useCopilotUIStore.getState().addCompletedSession("s2");
const raw = window.localStorage.getItem("copilot-completed-sessions");
expect(raw).not.toBeNull();
const parsed = JSON.parse(raw!) as string[];
expect(parsed).toContain("s1");
expect(parsed).toContain("s2");
});
it("clears a single completed session", () => {
useCopilotUIStore.getState().addCompletedSession("s1");
useCopilotUIStore.getState().addCompletedSession("s2");
useCopilotUIStore.getState().clearCompletedSession("s1");
expect(useCopilotUIStore.getState().completedSessionIDs.has("s1")).toBe(
false,
);
expect(useCopilotUIStore.getState().completedSessionIDs.has("s2")).toBe(
true,
);
});
it("updates localStorage when a session is cleared", () => {
useCopilotUIStore.getState().addCompletedSession("s1");
useCopilotUIStore.getState().addCompletedSession("s2");
useCopilotUIStore.getState().clearCompletedSession("s1");
const raw = window.localStorage.getItem("copilot-completed-sessions");
const parsed = JSON.parse(raw!) as string[];
expect(parsed).not.toContain("s1");
expect(parsed).toContain("s2");
});
it("clears all completed sessions", () => {
useCopilotUIStore.getState().addCompletedSession("s1");
useCopilotUIStore.getState().addCompletedSession("s2");
useCopilotUIStore.getState().clearAllCompletedSessions();
expect(useCopilotUIStore.getState().completedSessionIDs.size).toBe(0);
});
it("removes localStorage key when all sessions are cleared", () => {
useCopilotUIStore.getState().addCompletedSession("s1");
useCopilotUIStore.getState().clearAllCompletedSessions();
expect(
window.localStorage.getItem("copilot-completed-sessions"),
).toBeNull();
});
});
describe("sound toggle", () => {
it("starts enabled", () => {
expect(useCopilotUIStore.getState().isSoundEnabled).toBe(true);
});
it("toggles sound off and on", () => {
useCopilotUIStore.getState().toggleSound();
expect(useCopilotUIStore.getState().isSoundEnabled).toBe(false);
useCopilotUIStore.getState().toggleSound();
expect(useCopilotUIStore.getState().isSoundEnabled).toBe(true);
});
it("persists to localStorage", () => {
useCopilotUIStore.getState().toggleSound();
expect(window.localStorage.getItem("copilot-sound-enabled")).toBe(
"false",
);
});
});
describe("copilotMode", () => {
it("defaults to extended_thinking", () => {
expect(useCopilotUIStore.getState().copilotMode).toBe(
"extended_thinking",
);
});
it("sets mode to fast", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
expect(useCopilotUIStore.getState().copilotMode).toBe("fast");
expect(window.localStorage.getItem("copilot-mode")).toBe("fast");
});
it("sets mode back to extended_thinking", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
useCopilotUIStore.getState().setCopilotMode("extended_thinking");
expect(useCopilotUIStore.getState().copilotMode).toBe(
"extended_thinking",
);
});
});
describe("clearCopilotLocalData", () => {
it("resets state and clears localStorage keys", () => {
useCopilotUIStore.getState().setCopilotMode("fast");
useCopilotUIStore.getState().setNotificationsEnabled(true);
useCopilotUIStore.getState().toggleSound();
useCopilotUIStore.getState().addCompletedSession("s1");
useCopilotUIStore.getState().clearCopilotLocalData();
const state = useCopilotUIStore.getState();
expect(state.copilotMode).toBe("extended_thinking");
expect(state.isNotificationsEnabled).toBe(false);
expect(state.isSoundEnabled).toBe(true);
expect(state.completedSessionIDs.size).toBe(0);
expect(window.localStorage.getItem("copilot-mode")).toBeNull();
expect(
window.localStorage.getItem("copilot-notifications-enabled"),
).toBeNull();
expect(window.localStorage.getItem("copilot-sound-enabled")).toBeNull();
expect(
window.localStorage.getItem("copilot-completed-sessions"),
).toBeNull();
});
});
describe("notifications", () => {
it("sets notification preference", () => {
useCopilotUIStore.getState().setNotificationsEnabled(true);
expect(useCopilotUIStore.getState().isNotificationsEnabled).toBe(true);
expect(window.localStorage.getItem("copilot-notifications-enabled")).toBe(
"true",
);
});
it("shows and hides notification dialog", () => {
useCopilotUIStore.getState().setShowNotificationDialog(true);
expect(useCopilotUIStore.getState().showNotificationDialog).toBe(true);
useCopilotUIStore.getState().setShowNotificationDialog(false);
expect(useCopilotUIStore.getState().showNotificationDialog).toBe(false);
});
});
});

View File

@@ -5,17 +5,21 @@ import {
PromptInputTextarea,
PromptInputTools,
} from "@/components/ai-elements/prompt-input";
import { toast } from "@/components/molecules/Toast/use-toast";
import { InputGroup } from "@/components/ui/input-group";
import { cn } from "@/lib/utils";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { ChangeEvent, useEffect, useState } from "react";
import { AttachmentMenu } from "./components/AttachmentMenu";
import { FileChips } from "./components/FileChips";
import { ModeToggleButton } from "./components/ModeToggleButton";
import { RecordingButton } from "./components/RecordingButton";
import { RecordingIndicator } from "./components/RecordingIndicator";
import { useCopilotUIStore } from "../../store";
import { useChatInput } from "./useChatInput";
import { useVoiceRecording } from "./useVoiceRecording";
-export interface Props {
+interface Props {
onSend: (message: string, files?: File[]) => void | Promise<void>;
disabled?: boolean;
isStreaming?: boolean;
@@ -42,8 +46,26 @@ export function ChatInput({
droppedFiles,
onDroppedFilesConsumed,
}: Props) {
const { copilotMode, setCopilotMode } = useCopilotUIStore();
const showModeToggle = useGetFlag(Flag.CHAT_MODE_OPTION);
const [files, setFiles] = useState<File[]>([]);
function handleToggleMode() {
const next =
copilotMode === "extended_thinking" ? "fast" : "extended_thinking";
setCopilotMode(next);
toast({
title:
next === "fast"
? "Switched to Fast mode"
: "Switched to Extended Thinking mode",
description:
next === "fast"
? "Response quality may differ."
: "Responses may take longer.",
});
}
// Merge files dropped onto the chat window into internal state.
useEffect(() => {
if (droppedFiles && droppedFiles.length > 0) {
@@ -157,6 +179,13 @@ export function ChatInput({
onFilesSelected={handleFilesSelected}
disabled={isBusy}
/>
{showModeToggle && (
<ModeToggleButton
mode={copilotMode}
isStreaming={isStreaming}
onToggle={handleToggleMode}
/>
)}
</PromptInputTools>
<div className="flex items-center gap-4">

View File

@@ -0,0 +1,199 @@
import {
render,
screen,
fireEvent,
cleanup,
} from "@/tests/integrations/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import { ChatInput } from "../ChatInput";
let mockCopilotMode = "extended_thinking";
const mockSetCopilotMode = vi.fn((mode: string) => {
mockCopilotMode = mode;
});
vi.mock("@/app/(platform)/copilot/store", () => ({
useCopilotUIStore: () => ({
copilotMode: mockCopilotMode,
setCopilotMode: mockSetCopilotMode,
initialPrompt: null,
setInitialPrompt: vi.fn(),
}),
}));
let mockFlagValue = false;
vi.mock("@/services/feature-flags/use-get-flag", () => ({
Flag: { CHAT_MODE_OPTION: "CHAT_MODE_OPTION" },
useGetFlag: () => mockFlagValue,
}));
vi.mock("@/components/molecules/Toast/use-toast", () => ({
toast: vi.fn(),
useToast: () => ({ toast: vi.fn(), dismiss: vi.fn() }),
}));
vi.mock("../useVoiceRecording", () => ({
useVoiceRecording: () => ({
isRecording: false,
isTranscribing: false,
elapsedTime: 0,
toggleRecording: vi.fn(),
handleKeyDown: vi.fn(),
showMicButton: false,
isInputDisabled: false,
audioStream: null,
}),
}));
vi.mock("@/components/ai-elements/prompt-input", () => ({
PromptInputBody: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PromptInputFooter: ({ children }: { children: React.ReactNode }) => (
<div>{children}</div>
),
PromptInputSubmit: ({ disabled }: { disabled?: boolean }) => (
<button disabled={disabled} data-testid="submit">
Send
</button>
),
PromptInputTextarea: (props: {
id?: string;
value?: string;
onChange?: React.ChangeEventHandler<HTMLTextAreaElement>;
disabled?: boolean;
placeholder?: string;
}) => (
<textarea
id={props.id}
value={props.value}
onChange={props.onChange}
disabled={props.disabled}
placeholder={props.placeholder}
data-testid="textarea"
/>
),
PromptInputTools: ({ children }: { children: React.ReactNode }) => (
<div data-testid="tools">{children}</div>
),
}));
vi.mock("@/components/ui/input-group", () => ({
InputGroup: ({
children,
className,
}: {
children: React.ReactNode;
className?: string;
}) => <div className={className}>{children}</div>,
}));
vi.mock("../components/AttachmentMenu", () => ({
AttachmentMenu: () => <div data-testid="attachment-menu" />,
}));
vi.mock("../components/FileChips", () => ({
FileChips: () => null,
}));
vi.mock("../components/RecordingButton", () => ({
RecordingButton: () => null,
}));
vi.mock("../components/RecordingIndicator", () => ({
RecordingIndicator: () => null,
}));
const mockOnSend = vi.fn();
afterEach(() => {
cleanup();
vi.clearAllMocks();
mockCopilotMode = "extended_thinking";
});
describe("ChatInput mode toggle", () => {
it("does not render mode toggle when flag is disabled", () => {
mockFlagValue = false;
render(<ChatInput onSend={mockOnSend} />);
expect(screen.queryByLabelText(/switch to/i)).toBeNull();
});
it("renders mode toggle when flag is enabled", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} />);
expect(screen.getByLabelText(/switch to fast mode/i)).toBeDefined();
});
it("shows Thinking label in extended_thinking mode", () => {
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
expect(screen.getByText("Thinking")).toBeDefined();
});
it("shows Fast label in fast mode", () => {
mockFlagValue = true;
mockCopilotMode = "fast";
render(<ChatInput onSend={mockOnSend} />);
expect(screen.getByText("Fast")).toBeDefined();
});
it("toggles from extended_thinking to fast on click", () => {
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
expect(mockSetCopilotMode).toHaveBeenCalledWith("fast");
});
it("toggles from fast to extended_thinking on click", () => {
mockFlagValue = true;
mockCopilotMode = "fast";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to extended thinking/i));
expect(mockSetCopilotMode).toHaveBeenCalledWith("extended_thinking");
});
it("disables toggle button when streaming", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
const button = screen.getByLabelText(/switch to fast mode/i);
expect(button.hasAttribute("disabled")).toBe(true);
});
it("exposes aria-pressed=true in extended_thinking mode", () => {
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
const button = screen.getByLabelText(/switch to fast mode/i);
expect(button.getAttribute("aria-pressed")).toBe("true");
});
it("sets aria-pressed=false in fast mode", () => {
mockFlagValue = true;
mockCopilotMode = "fast";
render(<ChatInput onSend={mockOnSend} />);
const button = screen.getByLabelText(/switch to extended thinking/i);
expect(button.getAttribute("aria-pressed")).toBe("false");
});
it("uses streaming-specific tooltip when disabled", () => {
mockFlagValue = true;
render(<ChatInput onSend={mockOnSend} isStreaming />);
const button = screen.getByLabelText(/switch to fast mode/i);
expect(button.getAttribute("title")).toBe(
"Mode cannot be changed while streaming",
);
});
it("shows a toast when the user toggles mode", async () => {
const { toast } = await import("@/components/molecules/Toast/use-toast");
mockFlagValue = true;
mockCopilotMode = "extended_thinking";
render(<ChatInput onSend={mockOnSend} />);
fireEvent.click(screen.getByLabelText(/switch to fast mode/i));
expect(toast).toHaveBeenCalledWith(
expect.objectContaining({
title: expect.stringMatching(/switched to fast mode/i),
}),
);
});
});

View File

@@ -0,0 +1,122 @@
import { renderHook, act } from "@testing-library/react";
import { describe, expect, it, vi, beforeEach } from "vitest";
import { useChatInput } from "../useChatInput";
vi.mock("@/app/(platform)/copilot/store", () => ({
useCopilotUIStore: () => ({
initialPrompt: null,
setInitialPrompt: vi.fn(),
}),
}));
describe("useChatInput", () => {
const mockOnSend = vi.fn();
beforeEach(() => {
vi.clearAllMocks();
mockOnSend.mockResolvedValue(undefined);
});
it("does not send when value is empty", async () => {
const { result } = renderHook(() => useChatInput({ onSend: mockOnSend }));
await act(async () => {
await result.current.handleSend();
});
expect(mockOnSend).not.toHaveBeenCalled();
});
it("sends trimmed value and clears input", async () => {
const { result } = renderHook(() => useChatInput({ onSend: mockOnSend }));
act(() => {
result.current.setValue(" hello ");
});
await act(async () => {
await result.current.handleSend();
});
expect(mockOnSend).toHaveBeenCalledWith("hello");
expect(result.current.value).toBe("");
});
it("does not send when disabled", async () => {
const { result } = renderHook(() =>
useChatInput({ onSend: mockOnSend, disabled: true }),
);
act(() => {
result.current.setValue("hello");
});
await act(async () => {
await result.current.handleSend();
});
expect(mockOnSend).not.toHaveBeenCalled();
});
it("prevents double-submit via ref guard", async () => {
let resolveFirst: () => void;
const slowSend = vi.fn(
() =>
new Promise<void>((resolve) => {
resolveFirst = resolve;
}),
);
const { result } = renderHook(() => useChatInput({ onSend: slowSend }));
act(() => {
result.current.setValue("hello");
});
act(() => {
void result.current.handleSend();
});
await act(async () => {
await result.current.handleSend();
});
expect(slowSend).toHaveBeenCalledTimes(1);
await act(async () => {
resolveFirst!();
});
});
it("allows sending empty when canSendEmpty is true", async () => {
const { result } = renderHook(() =>
useChatInput({ onSend: mockOnSend, canSendEmpty: true }),
);
await act(async () => {
await result.current.handleSend();
});
expect(mockOnSend).toHaveBeenCalledWith("");
});
it("resets isSending after onSend throws", async () => {
mockOnSend.mockRejectedValue(new Error("fail"));
const { result } = renderHook(() => useChatInput({ onSend: mockOnSend }));
act(() => {
result.current.setValue("hello");
});
await act(async () => {
try {
await result.current.handleSend();
} catch {
// expected
}
});
expect(result.current.isSending).toBe(false);
});
});

View File

@@ -0,0 +1,53 @@
"use client";
import { cn } from "@/lib/utils";
import { Brain, Lightning } from "@phosphor-icons/react";
type CopilotMode = "extended_thinking" | "fast";
interface Props {
mode: CopilotMode;
isStreaming: boolean;
onToggle: () => void;
}
export function ModeToggleButton({ mode, isStreaming, onToggle }: Props) {
const isExtended = mode === "extended_thinking";
return (
<button
type="button"
aria-pressed={isExtended}
disabled={isStreaming}
onClick={onToggle}
className={cn(
"inline-flex min-h-11 min-w-11 items-center justify-center gap-1 rounded-md px-2 py-1 text-xs font-medium transition-colors",
isExtended
? "bg-purple-100 text-purple-700 hover:bg-purple-200 dark:bg-purple-900/30 dark:text-purple-300"
: "bg-amber-100 text-amber-700 hover:bg-amber-200 dark:bg-amber-900/30 dark:text-amber-300",
isStreaming && "cursor-not-allowed opacity-50",
)}
aria-label={
isExtended ? "Switch to Fast mode" : "Switch to Extended Thinking mode"
}
title={
isStreaming
? "Mode cannot be changed while streaming"
: isExtended
? "Extended Thinking mode — deeper reasoning (click to switch to Fast mode)"
: "Fast mode — quicker responses (click to switch to Extended Thinking)"
}
>
{isExtended ? (
<>
<Brain size={14} />
Thinking
</>
) : (
<>
<Lightning size={14} />
Fast
</>
)}
</button>
);
}

View File

@@ -1,5 +1,5 @@
import { useCopilotUIStore } from "@/app/(platform)/copilot/store";
import { ChangeEvent, FormEvent, useEffect, useState } from "react";
import { ChangeEvent, FormEvent, useEffect, useRef, useState } from "react";
interface Args {
onSend: (message: string) => void;
@@ -17,6 +17,9 @@ export function useChatInput({
}: Args) {
const [value, setValue] = useState("");
const [isSending, setIsSending] = useState(false);
// Synchronous guard against double-submit — refs update immediately,
// unlike state which batches and can leave a gap for a second call.
const isSubmittingRef = useRef(false);
const { initialPrompt, setInitialPrompt } = useCopilotUIStore();
useEffect(
@@ -47,12 +50,15 @@ export function useChatInput({
async function handleSend() {
if (disabled || isSending || (!value.trim() && !canSendEmpty)) return;
if (isSubmittingRef.current) return;
isSubmittingRef.current = true;
setIsSending(true);
try {
await onSend(value.trim());
setValue("");
} finally {
isSubmittingRef.current = false;
setIsSending(false);
}
}
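The ref guard in `handleSend` above can be isolated into a framework-free sketch. This is a hypothetical standalone model (the names `makeGuardedSend` and `isSubmitting` are illustrative, not from the codebase): a plain mutable flag plays the role of `isSubmittingRef.current`, updating synchronously so a second click that lands before the first `await` resolves is rejected immediately.

```typescript
// Illustrative sketch: a synchronous flag closes the double-submit window
// that batched React state updates would leave open.
type SendFn = (msg: string) => Promise<void>;

function makeGuardedSend(onSend: SendFn) {
  let isSubmitting = false; // stands in for isSubmittingRef.current

  return async function handleSend(msg: string): Promise<boolean> {
    if (isSubmitting) return false; // second call bails out synchronously
    isSubmitting = true;
    try {
      await onSend(msg);
      return true;
    } finally {
      isSubmitting = false; // re-arm once the send settles, success or failure
    }
  };
}
```

Two back-to-back calls produce a single `onSend` invocation; the second returns `false` before the first has settled, which is exactly the race the "prevents double-submit via ref guard" test above pins down.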

View File

@@ -1,8 +1,11 @@
import type { UIMessage } from "ai";
import { describe, expect, it } from "vitest";
import {
ORIGINAL_TITLE,
extractSendMessageText,
formatNotificationTitle,
parseSessionIDs,
shouldSuppressDuplicateSend,
} from "./helpers";
describe("formatNotificationTitle", () => {
@@ -74,3 +77,118 @@ describe("parseSessionIDs", () => {
expect(parseSessionIDs('["a","a","b"]')).toEqual(new Set(["a", "b"]));
});
});
describe("extractSendMessageText", () => {
it("extracts text from a string argument", () => {
expect(extractSendMessageText("hello")).toBe("hello");
});
it("extracts text from an object with text property", () => {
expect(extractSendMessageText({ text: "world" })).toBe("world");
});
it("returns empty string for null", () => {
expect(extractSendMessageText(null)).toBe("");
});
it("returns empty string for undefined", () => {
expect(extractSendMessageText(undefined)).toBe("");
});
it("converts numbers to string", () => {
expect(extractSendMessageText(42)).toBe("42");
});
});
function makeMsg(role: "user" | "assistant", text: string): UIMessage {
return {
id: `msg-${Math.random()}`,
role,
parts: [{ type: "text", text }],
};
}
describe("shouldSuppressDuplicateSend", () => {
it("suppresses when reconnect is scheduled", () => {
expect(
shouldSuppressDuplicateSend({
text: "hello",
isReconnectScheduled: true,
lastSubmittedText: null,
messages: [],
}),
).toBe(true);
});
it("allows send when not reconnecting and no prior submission", () => {
expect(
shouldSuppressDuplicateSend({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: null,
messages: [],
}),
).toBe(false);
});
it("suppresses when text matches last submitted AND last user message", () => {
const messages = [makeMsg("user", "hello"), makeMsg("assistant", "hi")];
expect(
shouldSuppressDuplicateSend({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages,
}),
).toBe(true);
});
it("allows send when text matches last submitted but differs from last user message", () => {
const messages = [
makeMsg("user", "different"),
makeMsg("assistant", "reply"),
];
expect(
shouldSuppressDuplicateSend({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages,
}),
).toBe(false);
});
it("allows send when text differs from last submitted", () => {
const messages = [makeMsg("user", "hello")];
expect(
shouldSuppressDuplicateSend({
text: "new message",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages,
}),
).toBe(false);
});
it("allows send when text is empty", () => {
expect(
shouldSuppressDuplicateSend({
text: "",
isReconnectScheduled: false,
lastSubmittedText: "",
messages: [],
}),
).toBe(false);
});
it("allows send with empty messages array even if text matches lastSubmitted", () => {
expect(
shouldSuppressDuplicateSend({
text: "hello",
isReconnectScheduled: false,
lastSubmittedText: "hello",
messages: [],
}),
).toBe(false);
});
});

View File

@@ -65,6 +65,72 @@ export function resolveInProgressTools(
}));
}
/**
* Extract the user-visible text from the arguments passed to `sendMessage`.
* Handles both `sendMessage("hello")` and `sendMessage({ text: "hello" })`.
*/
export function extractSendMessageText(firstArg: unknown): string {
if (firstArg && typeof firstArg === "object" && "text" in firstArg)
return (firstArg as { text: string }).text;
return String(firstArg ?? "");
}
interface SuppressDuplicateArgs {
text: string;
isReconnectScheduled: boolean;
lastSubmittedText: string | null;
messages: UIMessage[];
}
/**
* Reason a sendMessage was suppressed, or ``null`` to pass through.
*
* - ``"reconnecting"``: the stream is reconnecting; the caller should
* notify the user (the UI may not yet reflect the disabled state).
* - ``"duplicate"``: the same text was just submitted and echoed back
* by the session — safe to silently drop (user double-clicked).
*/
export type SuppressReason = "reconnecting" | "duplicate" | null;
/**
* Determine whether a sendMessage call should be suppressed to prevent
* duplicate POSTs during reconnect cycles or re-submits of the same text.
*
* Returns the reason so callers can surface user-visible feedback when
* the suppression isn't just a silent duplicate.
*/
export function getSendSuppressionReason({
text,
isReconnectScheduled,
lastSubmittedText,
messages,
}: SuppressDuplicateArgs): SuppressReason {
if (isReconnectScheduled) return "reconnecting";
if (text && lastSubmittedText === text) {
const lastUserMsg = messages.filter((m) => m.role === "user").pop();
const lastUserText = lastUserMsg?.parts
?.map((p) => ("text" in p ? p.text : ""))
.join("")
.trim();
if (lastUserText === text) return "duplicate";
}
return null;
}
/**
* Backwards-compatible boolean wrapper for ``getSendSuppressionReason``.
*
* @deprecated Call ``getSendSuppressionReason`` directly so callers can
* distinguish between reconnect and duplicate suppression.
*/
export function shouldSuppressDuplicateSend(
args: SuppressDuplicateArgs,
): boolean {
return getSendSuppressionReason(args) !== null;
}
/**
* Deduplicate messages by ID and by consecutive content fingerprint.
*

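The decision table encoded by `getSendSuppressionReason` can be exercised in isolation. The sketch below re-declares a simplified message shape (`Msg` and `Part` are stand-ins for the `UIMessage` type from the `ai` package) and mirrors the logic above:

```typescript
// Simplified stand-ins for the UIMessage shape used by the real helper.
type Part = { type: "text"; text: string };
type Msg = { role: "user" | "assistant"; parts: Part[] };
type SuppressReason = "reconnecting" | "duplicate" | null;

function getSendSuppressionReason(args: {
  text: string;
  isReconnectScheduled: boolean;
  lastSubmittedText: string | null;
  messages: Msg[];
}): SuppressReason {
  // Reconnect in flight: never POST, the caller should notify the user.
  if (args.isReconnectScheduled) return "reconnecting";
  // Duplicate only when the text matches BOTH the last submission and the
  // last user message echoed back by the session.
  if (args.text && args.lastSubmittedText === args.text) {
    const lastUser = args.messages.filter((m) => m.role === "user").pop();
    const lastText = lastUser?.parts.map((p) => p.text).join("").trim();
    if (lastText === args.text) return "duplicate";
  }
  return null;
}
```

Note the two-sided duplicate check: matching `lastSubmittedText` alone is not enough, which is why the test "allows send when text matches last submitted but differs from last user message" expects a pass-through.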
View File

@@ -47,6 +47,10 @@ interface CopilotUIState {
showNotificationDialog: boolean;
setShowNotificationDialog: (show: boolean) => void;
/** Autopilot mode: 'extended_thinking' (default) or 'fast'. */
copilotMode: "extended_thinking" | "fast";
setCopilotMode: (mode: "extended_thinking" | "fast") => void;
clearCopilotLocalData: () => void;
}
@@ -104,16 +108,27 @@ export const useCopilotUIStore = create<CopilotUIState>((set) => ({
showNotificationDialog: false,
setShowNotificationDialog: (show) => set({ showNotificationDialog: show }),
copilotMode:
isClient && storage.get(Key.COPILOT_MODE) === "fast"
? "fast"
: "extended_thinking",
setCopilotMode: (mode) => {
storage.set(Key.COPILOT_MODE, mode);
set({ copilotMode: mode });
},
clearCopilotLocalData: () => {
storage.clean(Key.COPILOT_NOTIFICATIONS_ENABLED);
storage.clean(Key.COPILOT_SOUND_ENABLED);
storage.clean(Key.COPILOT_NOTIFICATION_BANNER_DISMISSED);
storage.clean(Key.COPILOT_NOTIFICATION_DIALOG_DISMISSED);
storage.clean(Key.COPILOT_MODE);
storage.clean(Key.COPILOT_COMPLETED_SESSIONS);
set({
completedSessionIDs: new Set<string>(),
isNotificationsEnabled: false,
isSoundEnabled: true,
copilotMode: "extended_thinking",
});
if (isClient) {
document.title = ORIGINAL_TITLE;

View File

@@ -10,6 +10,7 @@ import { useBreakpoint } from "@/lib/hooks/useBreakpoint";
import { useSupabase } from "@/lib/supabase/hooks/useSupabase";
import { useQueryClient } from "@tanstack/react-query";
import type { FileUIPart } from "ai";
import { Flag, useGetFlag } from "@/services/feature-flags/use-get-flag";
import { useEffect, useRef, useState } from "react";
import { useCopilotUIStore } from "./store";
import { useChatSession } from "./useChatSession";
@@ -32,8 +33,15 @@ export function useCopilotPage() {
const [pendingMessage, setPendingMessage] = useState<string | null>(null);
const queryClient = useQueryClient();
-const { sessionToDelete, setSessionToDelete, isDrawerOpen, setDrawerOpen } =
-useCopilotUIStore();
+const isModeToggleEnabled = useGetFlag(Flag.CHAT_MODE_OPTION);
+const {
+sessionToDelete,
+setSessionToDelete,
+isDrawerOpen,
+setDrawerOpen,
+copilotMode,
+} = useCopilotUIStore();
const {
sessionId,
@@ -64,6 +72,7 @@ export function useCopilotPage() {
hydratedMessages,
hasActiveStream,
refetchSession,
copilotMode: isModeToggleEnabled ? copilotMode : undefined,
});
useCopilotNotifications(sessionId);

View File

@@ -10,11 +10,13 @@ import { useChat } from "@ai-sdk/react";
import { useQueryClient } from "@tanstack/react-query";
import { DefaultChatTransport } from "ai";
import type { FileUIPart, UIMessage } from "ai";
-import { useCallback, useEffect, useMemo, useRef, useState } from "react";
+import { useEffect, useMemo, useRef, useState } from "react";
import {
deduplicateMessages,
extractSendMessageText,
hasActiveBackendStream,
resolveInProgressTools,
getSendSuppressionReason,
} from "./helpers";
const RECONNECT_BASE_DELAY_MS = 1_000;
@@ -38,6 +40,8 @@ interface UseCopilotStreamArgs {
hydratedMessages: UIMessage[] | undefined;
hasActiveStream: boolean;
refetchSession: () => Promise<{ data?: unknown }>;
/** Autopilot mode to use for requests. `undefined` = let backend decide via feature flags. */
copilotMode: "extended_thinking" | "fast" | undefined;
}
export function useCopilotStream({
@@ -45,10 +49,18 @@ export function useCopilotStream({
hydratedMessages,
hasActiveStream,
refetchSession,
copilotMode,
}: UseCopilotStreamArgs) {
const queryClient = useQueryClient();
const [rateLimitMessage, setRateLimitMessage] = useState<string | null>(null);
-const dismissRateLimit = useCallback(() => setRateLimitMessage(null), []);
+function dismissRateLimit() {
+setRateLimitMessage(null);
+}
// Use a ref for copilotMode so the transport closure always reads the
// latest value without recreating the DefaultChatTransport (which would
// reset useChat's internal Chat instance and break mid-session streaming).
const copilotModeRef = useRef(copilotMode);
copilotModeRef.current = copilotMode;
// Connect directly to the Python backend for SSE, bypassing the Next.js
// serverless proxy. This eliminates the Vercel 800s function timeout that
@@ -79,6 +91,7 @@ export function useCopilotStream({
is_user_message: last.role === "user",
context: null,
file_ids: fileIds && fileIds.length > 0 ? fileIds : null,
mode: copilotModeRef.current ?? null,
},
headers: await getAuthHeaders(),
};
@@ -147,9 +160,14 @@ export function useCopilotStream({
}, delay);
}
// Tracks the text of the last user message that was submitted via sendMessage.
// During a reconnect cycle, if the session already contains this message, we
// must not POST it again — only GET-resume is safe.
const lastSubmittedMsgRef = useRef<string | null>(null);
const {
messages: rawMessages,
-sendMessage,
+sendMessage: sdkSendMessage,
stop: sdkStop,
status,
error,
@@ -236,6 +254,36 @@ export function useCopilotStream({
},
});
// Wrap sdkSendMessage to guard against re-sending the user message during a
// reconnect cycle. If the session already has the message (i.e. we are in a
// reconnect/resume flow), only GET-resume is safe — never re-POST.
const sendMessage: typeof sdkSendMessage = async (...args) => {
const text = extractSendMessageText(args[0]);
const suppressReason = getSendSuppressionReason({
text,
isReconnectScheduled: isReconnectScheduledRef.current,
lastSubmittedText: lastSubmittedMsgRef.current,
messages: rawMessages,
});
if (suppressReason === "reconnecting") {
// The ref flips to ``true`` synchronously while the React state that
// drives the UI's disabled state only updates on the next render, so
// the user may have clicked send against a still-enabled input. Tell
// them their message wasn't dropped silently.
toast({
title: "Reconnecting",
description: "Wait for the connection to resume before sending.",
});
return;
}
if (suppressReason === "duplicate") return;
lastSubmittedMsgRef.current = text;
return sdkSendMessage(...args);
};
// Deduplicate messages continuously to prevent duplicates when resuming streams
const messages = useMemo(
() => deduplicateMessages(rawMessages),
@@ -381,6 +429,7 @@ export function useCopilotStream({
setRateLimitMessage(null);
hasShownDisconnectToast.current = false;
isUserStoppingRef.current = false;
lastSubmittedMsgRef.current = null;
setReconnectExhausted(false);
setIsSyncing(false);
hasResumedRef.current.clear();
@@ -409,6 +458,7 @@ export function useCopilotStream({
if (status === "ready") {
reconnectAttemptsRef.current = 0;
hasShownDisconnectToast.current = false;
lastSubmittedMsgRef.current = null;
setReconnectExhausted(false);
}
}
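The `copilotModeRef` comment above describes a general pattern: build a long-lived object once, and have it read mutable state through a getter instead of capturing the value at construction time. A minimal framework-free sketch (the names `makeTransport` and `buildBody` are illustrative, not the real `DefaultChatTransport` API):

```typescript
// The transport-like object is created once; it reads the mode lazily per
// request via the getter, so it never captures a stale value.
function makeTransport(getMode: () => string | undefined) {
  return {
    buildBody: () => ({ mode: getMode() ?? null }),
  };
}

// A plain object models the React ref: mutating .current is visible to the
// stable closure without recreating the transport.
const modeRef: { current: string | undefined } = { current: undefined };
const transport = makeTransport(() => modeRef.current);
```

Recreating the transport on every mode change would be the alternative, but as the comment notes, that resets `useChat`'s internal `Chat` instance and breaks mid-session streaming; the getter indirection avoids it.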

View File

@@ -13175,6 +13175,14 @@
{ "type": "null" }
],
"title": "File Ids"
},
"mode": {
"anyOf": [
{ "type": "string", "enum": ["fast", "extended_thinking"] },
{ "type": "null" }
],
"title": "Mode",
"description": "Autopilot mode: 'fast' for baseline LLM, 'extended_thinking' for Claude Agent SDK. If None, uses the server default (extended_thinking)."
}
},
"type": "object",

View File

@@ -13,6 +13,7 @@ export enum Flag {
AGENT_FAVORITING = "agent-favoriting",
MARKETPLACE_SEARCH_TERMS = "marketplace-search-terms",
ENABLE_PLATFORM_PAYMENT = "enable-platform-payment",
CHAT_MODE_OPTION = "chat-mode-option",
}
const isPwMockEnabled = process.env.NEXT_PUBLIC_PW_TEST === "true";
@@ -26,6 +27,7 @@ const defaultFlags = {
[Flag.AGENT_FAVORITING]: false,
[Flag.MARKETPLACE_SEARCH_TERMS]: DEFAULT_SEARCH_TERMS,
[Flag.ENABLE_PLATFORM_PAYMENT]: false,
[Flag.CHAT_MODE_OPTION]: true,
};
type FlagValues = typeof defaultFlags;

View File

@@ -0,0 +1,68 @@
import { describe, expect, it, beforeEach, vi } from "vitest";
vi.mock("@sentry/nextjs", () => ({
captureException: vi.fn(),
}));
vi.mock("@/services/environment", () => ({
environment: {
isServerSide: vi.fn(() => false),
},
}));
import { Key, storage } from "../local-storage";
import { environment } from "@/services/environment";
describe("storage", () => {
beforeEach(() => {
window.localStorage.clear();
vi.mocked(environment.isServerSide).mockReturnValue(false);
});
describe("set and get", () => {
it("stores and retrieves a value", () => {
storage.set(Key.COPILOT_MODE, "fast");
expect(storage.get(Key.COPILOT_MODE)).toBe("fast");
});
it("returns null for unset keys", () => {
expect(storage.get(Key.COPILOT_MODE)).toBeNull();
});
});
describe("clean", () => {
it("removes a stored value", () => {
storage.set(Key.COPILOT_SOUND_ENABLED, "true");
storage.clean(Key.COPILOT_SOUND_ENABLED);
expect(storage.get(Key.COPILOT_SOUND_ENABLED)).toBeNull();
});
});
describe("server-side guard", () => {
it("returns undefined for get when on server side", () => {
vi.mocked(environment.isServerSide).mockReturnValue(true);
expect(storage.get(Key.COPILOT_MODE)).toBeUndefined();
});
it("returns undefined for set when on server side", () => {
vi.mocked(environment.isServerSide).mockReturnValue(true);
expect(storage.set(Key.COPILOT_MODE, "fast")).toBeUndefined();
});
it("returns undefined for clean when on server side", () => {
vi.mocked(environment.isServerSide).mockReturnValue(true);
expect(storage.clean(Key.COPILOT_MODE)).toBeUndefined();
});
});
});
describe("Key enum", () => {
it("has expected keys", () => {
expect(Key.COPILOT_MODE).toBe("copilot-mode");
expect(Key.COPILOT_SOUND_ENABLED).toBe("copilot-sound-enabled");
expect(Key.COPILOT_NOTIFICATIONS_ENABLED).toBe(
"copilot-notifications-enabled",
);
expect(Key.CHAT_SESSION_ID).toBe("chat_session_id");
});
});
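The server-side guard these tests pin down can be sketched without the real `storage` module. Assuming the guard reduces to a `typeof window` check (an assumption — the actual implementation delegates to `environment.isServerSide()` and also reports misuse to Sentry), a minimal version that models only the return values the tests assert looks like:

```typescript
// Minimal SSR-safe localStorage wrapper: every method checks the environment
// first and returns undefined instead of touching window on the server.
const isServerSide = typeof window === "undefined";

const storage = {
  get(key: string): string | null | undefined {
    if (isServerSide) return undefined; // never touch localStorage during SSR
    return window.localStorage.getItem(key);
  },
  set(key: string, value: string): void {
    if (isServerSide) return; // silent no-op on the server
    window.localStorage.setItem(key, value);
  },
};
```

Run under Node (no `window`), `get` returns `undefined` and `set` is a no-op, matching the "server-side guard" describe block above.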

View File

@@ -15,6 +15,7 @@ export enum Key {
COPILOT_NOTIFICATIONS_ENABLED = "copilot-notifications-enabled",
COPILOT_NOTIFICATION_BANNER_DISMISSED = "copilot-notification-banner-dismissed",
COPILOT_NOTIFICATION_DIALOG_DISMISSED = "copilot-notification-dialog-dismissed",
COPILOT_MODE = "copilot-mode",
COPILOT_COMPLETED_SESSIONS = "copilot-completed-sessions",
}

12 binary screenshot files added (sizes 33–108 KiB; not rendered in this view).