Merge remote-tracking branch 'origin/dev' into feat/task-decomposition-copilot

# Conflicts:
#	autogpt_platform/backend/backend/api/features/chat/routes.py
#	autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md
#	autogpt_platform/backend/backend/copilot/tools/tool_schema_test.py
anvyle
2026-04-23 06:25:11 +02:00
63 changed files with 7603 additions and 906 deletions

View File

@@ -0,0 +1,245 @@
---
name: pr-polish
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
user-invocable: true
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---
# PR Polish
**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:
1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned.
4. Every non-bot, non-author issue comment has been acknowledged (replied-to).
5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending.
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient.
**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 rounds in practice.
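Put together, the exit condition is a single boolean gate over the six checks above. A minimal sketch, assuming a hypothetical `PollResult` snapshot assembled from the queries described later in this document (every name here is illustrative, not part of the skill):
```python
from dataclasses import dataclass

# Conclusions that count as "clean" CI per criterion 5.
CLEAN_CI = {"success", "skipped", "neutral"}


@dataclass
class PollResult:
    """Hypothetical snapshot of one poll; fields mirror criteria 1-5."""

    new_findings: int
    unresolved_threads: int
    unacknowledged_reviews: int
    unacknowledged_comments: int
    ci_conclusions: list[str | None]  # None means the check is still running


def is_merge_ready(poll: PollResult, consecutive_clean_polls: int) -> bool:
    """True only when every exit criterion from the list above holds."""
    ci_green = all(c in CLEAN_CI for c in poll.ci_conclusions)
    nothing_open = (
        poll.new_findings == 0
        and poll.unresolved_threads == 0
        and poll.unacknowledged_reviews == 0
        and poll.unacknowledged_comments == 0
    )
    # Criterion 6: two consecutive quiet polls after CI settles.
    return ci_green and nothing_open and consecutive_clean_polls >= 2
```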
## TodoWrite
Before starting, write two todos so the user can see the loop progression:
- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.
Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).
## Find the PR
```bash
ARG_PR="${ARG:-}"
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
ARG_PR="${BASH_REMATCH[1]}"
fi
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
exit 1
fi
echo "Polishing PR #$PR"
```
## The outer loop
```text
round = 0
while round < _MAX_ROUNDS:
    round += 1
    baseline = snapshot_state(PR)       # see "Snapshotting state" below
    invoke_skill("pr-review", PR)       # posts findings as inline comments / top-level review
    findings = diff_state(PR, baseline)
    if findings.total == 0:
        break                           # no new findings → go to polish polling
    invoke_skill("pr-address", PR)      # resolves every unresolved thread + CI failure

# Post-loop: polish polling (see below).
polish_polling(PR)
```
### Snapshotting state
Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):
```bash
# Inline threads — total count + latest databaseId per thread
gh api graphql -f query="
{
repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
pullRequest(number: ${PR}) {
reviewThreads(first: 100) {
totalCount
nodes {
id
isResolved
comments(last: 1) { nodes { databaseId } }
}
}
}
}
}" > /tmp/baseline_threads.json
# Top-level reviews — count + latest id per non-empty review
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
--jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
> /tmp/baseline_reviews.json
# Issue comments — count + latest id per non-bot, non-author comment.
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
# accounts like coderabbitai, github-actions, sentry-io). The author is
# filtered by comparing login to the PR author, passed to jq via --arg.
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
# gh's built-in --jq flag has no --arg support, so pipe through jq instead
# (-s merges the paginated pages into a single input before filtering).
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate |
  jq -s --arg author "$AUTHOR" \
    '[.[][] | select(.user.type != "Bot" and .user.login != $author)
      | {id, user: .user.login, created_at}]' \
  > /tmp/baseline_issue_comments.json
```
### Diffing after a review
After `/pr-review` runs, any of the following counts as a new finding and requires another address round (a diff sketch follows this list):
- New inline thread `id` not in the baseline.
- An existing thread whose latest comment `databaseId` is higher than the baseline's (new reply on an old thread).
- A new top-level review `id` with a non-empty body.
- A new issue comment `id` from a non-bot, non-author user.
If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.
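A minimal Python sketch of that diff step, assuming a fresh snapshot has been written the same way as the baseline above (the `/tmp/fresh/*.json` paths and helper names are illustrative, not part of the skill):
```python
import json


def _threads(path: str) -> dict[str, int]:
    """Map thread id -> databaseId of its latest comment, read from a snapshot
    file shaped like the GraphQL response saved by the bash above."""
    with open(path) as f:
        data = json.load(f)
    nodes = data["data"]["repository"]["pullRequest"]["reviewThreads"]["nodes"]
    return {
        n["id"]: (n["comments"]["nodes"][0]["databaseId"] if n["comments"]["nodes"] else 0)
        for n in nodes
    }


def _ids(path: str) -> set[int]:
    """Set of ids from a flat REST snapshot (reviews / issue comments)."""
    with open(path) as f:
        return {item["id"] for item in json.load(f)}


def count_new_findings(baseline: str = "/tmp", fresh: str = "/tmp/fresh") -> int:
    """Non-zero means another /pr-address round is needed."""
    base_threads = _threads(f"{baseline}/baseline_threads.json")
    new_threads = _threads(f"{fresh}/threads.json")
    # New thread, or a new reply on an existing thread.
    new_or_updated = [
        tid for tid, latest in new_threads.items()
        if tid not in base_threads or latest > base_threads[tid]
    ]
    new_reviews = _ids(f"{fresh}/reviews.json") - _ids(f"{baseline}/baseline_reviews.json")
    new_comments = _ids(f"{fresh}/issue_comments.json") - _ids(
        f"{baseline}/baseline_issue_comments.json"
    )
    return len(new_or_updated) + len(new_reviews) + len(new_comments)
```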
## Polish polling
Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals:
```text
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
clean_polls = 0
required_clean = 2
while clean_polls < required_clean:
    # 1. CI gate — any terminal non-success conclusion (not just "failure")
    #    must trigger /pr-address. "success", "skipped", "neutral" are clean;
    #    anything else (including cancelled, timed_out, action_required) is a
    #    blocker that won't self-resolve.
    ci = fetch_check_runs(PR)
    if any ci.conclusion in NON_SUCCESS_TERMINAL:
        invoke_skill("pr-address", PR)   # address failures + any new comments
        baseline = snapshot_state(PR)    # reset — push during address invalidates old baseline
        clean_polls = 0
        continue
    if any ci.conclusion is None (still in_progress):
        sleep 60; continue               # wait without counting this as clean

    # 2. Comment / thread gate
    threads = fetch_unresolved_threads(PR)
    new_issue_comments = diff_against_baseline(issue_comments)
    new_reviews = diff_against_baseline(reviews)
    if threads or new_issue_comments or new_reviews:
        invoke_skill("pr-address", PR)
        baseline = snapshot_state(PR)    # reset — the address loop just dealt with these,
                                         # otherwise they stay "new" relative to the old baseline forever
        clean_polls = 0
        continue

    # 3. Mergeability gate
    mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
    if mergeable == false (CONFLICTING):
        resolve_conflicts(PR)            # see pr-address skill
        clean_polls = 0
        continue
    if mergeable is null (UNKNOWN):
        sleep 60; continue

    clean_polls += 1
    sleep 60
```
Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.
### Why 2 clean polls, not 1
A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.
### Why check every source on each poll
`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:
- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles).
- Issue comments from human reviewers (not caught by inline thread polling).
- Sentry bug predictions that land on new line numbers post-push.
- Merge conflicts introduced by a race between your push and a merge to `dev`.
## Invocation pattern
Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.
```python
Skill(skill="pr-review", args=pr_url)
Skill(skill="pr-address", args=pr_url)
```
After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.
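A minimal sketch of that verification, shelling out to `gh` (the helper name and the 100-thread page size are illustrative; a real check would paginate and also re-check reviews and issue comments):
```python
import json
import subprocess


def unresolved_thread_count(pr: int) -> int:
    """Count review threads GitHub still reports as unresolved."""
    query = (
        '{ repository(owner: "Significant-Gravitas", name: "AutoGPT") {'
        f'  pullRequest(number: {pr}) {{'
        '    reviewThreads(first: 100) { nodes { isResolved } }'
        '  } } }'
    )
    out = subprocess.run(
        ["gh", "api", "graphql", "-f", f"query={query}"],
        capture_output=True, text=True, check=True,
    ).stdout
    nodes = json.loads(out)["data"]["repository"]["pullRequest"]["reviewThreads"]["nodes"]
    return sum(1 for node in nodes if not node["isResolved"])


# Report ORCHESTRATOR:DONE only when the live count is actually zero:
# assert unresolved_thread_count(pr_number) == 0
```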
### **Auto-continue: do NOT end your response between child skills**
`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:
- Do NOT summarize and stop.
- Do NOT wait for user confirmation to continue.
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.
The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).
If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.
## GitHub rate limits
This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:
- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.
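A minimal budget-check sketch for deciding when to back off (the floor constant mirrors the ~200 threshold above; the helper names are illustrative):
```python
import json
import subprocess

GRAPHQL_BUDGET_FLOOR = 200  # back off below this remaining-call count


def graphql_budget_remaining() -> int:
    """Read the remaining GraphQL quota via `gh api rate_limit`."""
    out = subprocess.run(
        ["gh", "api", "rate_limit", "--jq", ".resources.graphql.remaining"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    return int(out)


def should_fall_back_to_rest() -> bool:
    """Switch reads to REST and queue GraphQL-only thread resolutions
    when the GraphQL budget is nearly exhausted."""
    return graphql_budget_remaining() < GRAPHQL_BUDGET_FLOOR
```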
## Safety valves
- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is a CI `failure` that will never self-resolve.
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved.
## Reporting
When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary:
```
PR #{N} polish complete ({rounds_completed} rounds):
- {X} inline threads opened and resolved
- {Y} CI failures fixed
- {Z} new commits pushed
Final state: CI green, {total} threads all resolved, mergeable.
```
If exiting via `_MAX_ROUNDS`, flag explicitly:
```
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
- {N} threads still unresolved: {titles}
- CI status: {summary}
Needs human review.
```
## When to use this skill
Use when the user says any of:
- "polish this PR"
- "keep reviewing and addressing until it's mergeable"
- "loop /pr-review + /pr-address until done"
- "make sure the PR is actually merge-ready"
Do **not** use when:
- User wants just one review pass (→ `/pr-review`).
- User wants to address already-posted comments without further self-review (→ `/pr-address`).
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.

View File

@@ -260,6 +260,32 @@ Use a `trap` so release runs even on `exit 1`:
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
```
### **Release the lock AS SOON AS the test run is done**
The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if:
- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually.
- You're leaving docker containers up.
- You're tailing logs for a minute or two.
Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone.
Concretely, the sequence at the end of every `/pr-test` run (success or failure) is:
```bash
# 1. Write the final report + post PR comment — done above in Step 5/6.
# 2. Release the lock right now, even if the app is still up.
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
>> /Users/majdyz/Code/AutoGPT/.ign.testing.log
# 3. Optionally leave the app running and note it so the user knows:
echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
echo " pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
```
If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test.
### Shared status log
`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:
@@ -755,6 +781,19 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations
**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.
**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path.
**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`):
```bash
# Reject any local-looking absolute path or home-dir shortcut in the body
if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then
  echo "ABORT: local filesystem paths detected in PR comment body."
  echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting."
  exit 1
fi
```
```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"

View File

@@ -76,6 +76,7 @@ from backend.copilot.tools.models import (
SetupRequirementsResponse,
SuggestedGoalResponse,
TaskDecompositionResponse,
TodoWriteResponse,
UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message
@@ -1443,6 +1444,7 @@ ToolResponseUnion = (
| MemorySearchResponse
| MemoryForgetCandidatesResponse
| MemoryForgetConfirmResponse
| TodoWriteResponse
)

View File

@@ -50,13 +50,13 @@ _VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
# (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without
# coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React
# re-render of the non-virtualised chat list, which paint-storms the browser
# main thread and freezes the UI. Batching into ~32-char / ~40 ms windows
# cuts the event rate ~100x while staying snappy enough that the Reasoning
# main thread and freezes the UI. Batching into ~64-char / ~50 ms windows
# cuts the event rate ~150x while staying snappy enough that the Reasoning
# collapse still feels live (well under the ~100 ms perceptual threshold).
# Per-delta persistence to ``session.messages`` stays granular — we only
# coalesce the *wire* emission.
_COALESCE_MIN_CHARS = 32
_COALESCE_MAX_INTERVAL_MS = 40.0
_COALESCE_MIN_CHARS = 64
_COALESCE_MAX_INTERVAL_MS = 50.0
class ReasoningDetail(BaseModel):
@@ -242,6 +242,12 @@ class BaselineReasoningEmitter:
fresh ``ChatMessage(role="reasoning")`` is appended and mutated
in-place as further deltas arrive; :meth:`close` drops the reference
but leaves the appended row intact.
``render_in_ui=False`` suppresses only the live wire events
(``StreamReasoning*``); the ``role='reasoning'`` persistence row is
still appended so ``convertChatSessionToUiMessages.ts`` can hydrate
the reasoning bubble on reload. The state machine advances
identically either way.
"""
def __init__(
@@ -250,21 +256,19 @@ class BaselineReasoningEmitter:
*,
coalesce_min_chars: int = _COALESCE_MIN_CHARS,
coalesce_max_interval_ms: float = _COALESCE_MAX_INTERVAL_MS,
render_in_ui: bool = True,
) -> None:
self._block_id: str = str(uuid.uuid4())
self._open: bool = False
self._session_messages = session_messages
self._current_row: ChatMessage | None = None
# Coalescing state — ``_pending_delta`` accumulates reasoning text
# between wire flushes. Providers like Kimi K2.6 emit very fine-
# grained chunks; batching them reduces Redis ``xadd`` + SSE + React
# re-render load by ~100x for equivalent text output. Tuning knobs
# are kwargs so tests can disable coalescing (``=0``) for
# deterministic event assertions.
# Coalescing state — tests can disable (``=0``) for deterministic
# event assertions.
self._coalesce_min_chars = coalesce_min_chars
self._coalesce_max_interval_ms = coalesce_max_interval_ms
self._pending_delta: str = ""
self._last_flush_monotonic: float = 0.0
self._render_in_ui = render_in_ui
@property
def is_open(self) -> bool:
@@ -296,8 +300,9 @@ class BaselineReasoningEmitter:
# syscalls off the hot path without changing semantics.
now = time.monotonic()
if not self._open:
events.append(StreamReasoningStart(id=self._block_id))
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
if self._render_in_ui:
events.append(StreamReasoningStart(id=self._block_id))
events.append(StreamReasoningDelta(id=self._block_id, delta=text))
self._open = True
self._last_flush_monotonic = now
if self._session_messages is not None:
@@ -305,17 +310,15 @@ class BaselineReasoningEmitter:
self._session_messages.append(self._current_row)
return events
# Persist per-delta (no coalescing here — the session snapshot stays
# consistent at every chunk boundary, independent of the wire
# coalesce window).
if self._current_row is not None:
self._current_row.content = (self._current_row.content or "") + text
self._pending_delta += text
if self._should_flush_pending(now):
events.append(
StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
)
if self._render_in_ui:
events.append(
StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
)
self._pending_delta = ""
self._last_flush_monotonic = now
return events
@@ -348,12 +351,13 @@ class BaselineReasoningEmitter:
if not self._open:
return []
events: list[StreamBaseResponse] = []
if self._pending_delta:
events.append(
StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
)
self._pending_delta = ""
events.append(StreamReasoningEnd(id=self._block_id))
if self._render_in_ui:
if self._pending_delta:
events.append(
StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
)
events.append(StreamReasoningEnd(id=self._block_id))
self._pending_delta = ""
self._open = False
self._block_id = str(uuid.uuid4())
self._current_row = None

View File

@@ -452,3 +452,63 @@ class TestReasoningPersistence:
events = emitter.on_delta(_delta(reasoning="pure wire"))
assert len(events) == 2 # start + delta, no crash
# Nothing else to assert — just proves None session is supported.
class TestBaselineReasoningEmitterRenderFlag:
"""``render_in_ui=False`` must silence ``StreamReasoning*`` wire events
AND drop persistence of ``role="reasoning"`` rows — the operator hides
the collapse on both the live wire and on reload. Persistence is tied
to the wire events because the frontend's hydration path unconditionally
re-renders persisted reasoning rows; keeping them would make the flag a
no-op post-reload. These tests pin the contract in both directions so
future refactors can't flip only one half."""
def test_render_off_suppresses_start_and_delta(self):
emitter = BaselineReasoningEmitter(render_in_ui=False)
events = emitter.on_delta(_delta(reasoning="hidden"))
# No wire events, but state advanced (is_open == True) so close()
# below has something to rotate.
assert events == []
assert emitter.is_open is True
def test_render_off_suppresses_close_end(self):
emitter = BaselineReasoningEmitter(render_in_ui=False)
emitter.on_delta(_delta(reasoning="hidden"))
events = emitter.close()
assert events == []
assert emitter.is_open is False
def test_render_off_still_persists(self):
"""Persistence is decoupled from the render flag — session
transcript always keeps the ``role="reasoning"`` row so audit
and ``--resume``-equivalent replay never lose thinking text.
The frontend gates rendering separately."""
session: list[ChatMessage] = []
emitter = BaselineReasoningEmitter(session, render_in_ui=False)
emitter.on_delta(_delta(reasoning="part one "))
emitter.on_delta(_delta(reasoning="part two"))
emitter.close()
assert len(session) == 1
assert session[0].role == "reasoning"
assert session[0].content == "part one part two"
def test_render_off_rotates_block_id_between_sessions(self):
"""Even with wire events silenced the block id must rotate on close,
otherwise a hypothetical mid-session flip would reuse a stale id."""
emitter = BaselineReasoningEmitter(render_in_ui=False)
emitter.on_delta(_delta(reasoning="first"))
first_block_id = emitter._block_id
emitter.close()
emitter.on_delta(_delta(reasoning="second"))
assert emitter._block_id != first_block_id
def test_render_on_is_default(self):
"""Defaulting to True preserves backward compat — existing callers
that don't pass the kwarg keep emitting wire events as before."""
emitter = BaselineReasoningEmitter()
events = emitter.on_delta(_delta(reasoning="hello"))
assert len(events) == 2
assert isinstance(events[0], StreamReasoningStart)
assert isinstance(events[1], StreamReasoningDelta)

View File

@@ -15,7 +15,7 @@ import re
import shutil
import tempfile
import uuid
from collections.abc import AsyncGenerator, Mapping, Sequence
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import TYPE_CHECKING, Any, cast
@@ -45,6 +45,8 @@ from backend.copilot.model import (
maybe_append_user_message,
upsert_chat_session,
)
from backend.copilot.model_router import resolve_model
from backend.copilot.moonshot import is_moonshot_model
from backend.copilot.pending_message_helpers import (
combine_pending_with_current,
drain_pending_safe,
@@ -82,7 +84,7 @@ from backend.copilot.service import (
from backend.copilot.session_cleanup import prune_orphan_tool_calls
from backend.copilot.thinking_stripper import ThinkingStripper as _ThinkingStripper
from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tools import ToolGroup, execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.copilot.transcript import (
STOP_REASON_END_TURN,
@@ -318,20 +320,17 @@ def _filter_tools_by_permissions(
]
def _resolve_baseline_model(tier: CopilotLlmModel | None) -> str:
async def _resolve_baseline_model(
tier: CopilotLlmModel | None, user_id: str | None
) -> str:
"""Pick the model for the baseline path based on the per-request tier.
Baseline resolves independently of SDK via the ``fast_*_model`` cells
of the (path, tier) matrix. ``'standard'`` / ``None`` picks Kimi
K2.6 by default (cheap + OpenRouter ``reasoning`` support);
``'advanced'`` picks Opus by default so the advanced tier is a clean
A/B against the SDK advanced tier — same model, different path —
isolating reasoning-wire + cache differences from model capability.
Both defaults are overridable per ``CHAT_FAST_*_MODEL`` env vars.
Delegates to :func:`copilot.model_router.resolve_model` so the
``(fast, tier)`` cell is LD-overridable per user. ``None`` tier
maps to ``"standard"``.
"""
if tier == "advanced":
return config.fast_advanced_model
return config.fast_standard_model
tier_name = "advanced" if tier == "advanced" else "standard"
return await resolve_model("fast", tier_name, user_id, config=config)
@dataclass
@@ -343,7 +342,21 @@ class _BaselineStreamState:
"""
model: str = ""
pending_events: list[StreamBaseResponse] = field(default_factory=list)
# Live delivery channel drained concurrently by ``stream_chat_completion_baseline``
# so reasoning / text / tool events reach the SSE wire **during** the upstream
# LLM stream, not after ``_baseline_llm_caller`` returns. Before this was a
# ``list`` drained per ``tool_call_loop`` iteration, so any model with
# extended thinking (Anthropic via OpenRouter, Moonshot, future reasoning
# routes) froze the UI for the entire duration of each LLM round before
# flushing the backlog in one burst. The queue is single-producer (the
# streaming loop) / single-consumer (the outer async-gen yield loop);
# ``None`` is the close sentinel.
pending_events: asyncio.Queue[StreamBaseResponse | None] = field(
default_factory=asyncio.Queue
)
# Mirror of every event put on ``pending_events`` — kept for unit tests that
# inspect post-hoc what was emitted. Not consumed by production code.
emitted_events: list[StreamBaseResponse] = field(default_factory=list)
assistant_text: str = ""
text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
text_started: bool = False
@@ -382,31 +395,71 @@ class _BaselineStreamState:
# frontend's ``convertChatSessionToUiMessages`` relies on these
# rows to render the Reasoning collapse after the AI SDK's
# stream-end hydrate swaps in the DB-backed message list.
self.reasoning_emitter = BaselineReasoningEmitter(self.session_messages)
# ``render_in_ui`` is sourced from ``config.render_reasoning_in_ui``
# so the operator can silence the reasoning collapse globally
# without dropping the persisted audit trail.
self.reasoning_emitter = BaselineReasoningEmitter(
self.session_messages,
render_in_ui=config.render_reasoning_in_ui,
)
def _emit(state: "_BaselineStreamState", event: StreamBaseResponse) -> None:
"""Queue *event* for the live SSE wire AND mirror into ``emitted_events``.
Single helper so every streaming producer (LLM stream loop, tool executor,
conversation updater) posts to the same single-consumer queue. The mirror
list is read-only from production code — it exists so unit tests can assert
on the full sequence emitted during one call.
"""
state.pending_events.put_nowait(event)
state.emitted_events.append(event)
def _emit_all(
state: "_BaselineStreamState", events: Iterable[StreamBaseResponse]
) -> None:
"""Queue *events* in order — convenience for emitter batches."""
for event in events:
_emit(state, event)
def _is_anthropic_model(model: str) -> bool:
"""Return True if *model* routes to Anthropic (native or via OpenRouter).
Cache-control markers on message content + the ``anthropic-beta`` header
are Anthropic-specific. OpenAI rejects the unknown ``cache_control``
field with a 400 ("Extra inputs are not permitted") and Grok / other
providers behave similarly. OpenRouter strips unknown headers but
passes through ``cache_control`` on the body regardless of provider —
which would also fail when OpenRouter routes to a non-Anthropic model.
Examples that return True:
- ``anthropic/claude-sonnet-4-6`` (OpenRouter route)
- ``claude-3-5-sonnet-20241022`` (direct Anthropic API)
- ``anthropic.claude-3-5-sonnet`` (Bedrock-style)
False for ``openai/gpt-4o``, ``google/gemini-2.5-pro``, ``xai/grok-4``
etc.
etc. Moonshot is False here too even though Moonshot's
Anthropic-compat endpoint honours ``cache_control`` — use
:func:`_supports_prompt_cache_markers` for the cache-gating decision,
which also allows Moonshot routes. This function stays scoped to
"genuinely Anthropic" so callers that need the stricter check (e.g.
``anthropic-beta`` header emission) keep their existing semantics.
"""
lowered = model.lower()
return "claude" in lowered or lowered.startswith("anthropic")
def _supports_prompt_cache_markers(model: str) -> bool:
"""Return True when *model* accepts Anthropic-style ``cache_control``.
Superset of :func:`_is_anthropic_model` — also allows Moonshot
(``moonshotai/*``), whose OpenRouter Anthropic-compat endpoint
honours the marker and empirically lifts cache hit rate on
continuation turns from near-zero (Moonshot's own automatic prefix
cache, which drifts readily) to the 60-95% Anthropic ballpark.
OpenAI / Grok / Gemini still 400 on ``cache_control``, so this
function returns False for those providers — add new vendors here
only after verifying their endpoint accepts the field.
"""
return _is_anthropic_model(model) or is_moonshot_model(model)
def _fresh_ephemeral_cache_control() -> dict[str, str]:
"""Return a FRESH ephemeral ``cache_control`` dict each call.
@@ -519,7 +572,7 @@ async def _baseline_llm_caller(
Extracted from ``stream_chat_completion_baseline`` for readability.
"""
state.pending_events.append(StreamStartStep())
_emit(state, StreamStartStep())
# Fresh thinking-strip state per round so a malformed unclosed
# block in one LLM call cannot silently drop content in the next.
state.thinking_stripper = _ThinkingStripper()
@@ -527,19 +580,24 @@ async def _baseline_llm_caller(
round_text = ""
try:
client = _get_openai_client()
# Cache markers are Anthropic-specific. For OpenAI/Grok/other
# providers, leaving them on would trigger a 400 ("Extra inputs
# are not permitted" on cache_control). Tools were precomputed
# in stream_chat_completion_baseline via _mark_tools_with_cache_control
# (only when the model was Anthropic), so on non-Anthropic routes
# tools ship without cache_control on the last entry too.
# Cache markers are accepted by Anthropic AND Moonshot (via OR's
# Anthropic-compat endpoint). OpenAI/Grok/Gemini 400 on the
# unknown ``cache_control`` field — tools were precomputed in
# stream_chat_completion_baseline via _mark_tools_with_cache_control
# with the same gate, so on unsupported routes tools ship
# unmarked too.
#
# `extra_body` `usage.include=true` asks OpenRouter to embed the real
# generation cost into the final usage chunk — required by the
# cost-based rate limiter in routes.py. Separate from the Anthropic
# The ``anthropic-beta`` header is only emitted for genuinely
# Anthropic routes (see :func:`_is_anthropic_model`) — Moonshot
# doesn't need the beta header; sending it is a no-op but we
# keep the check strict for clarity.
#
# `extra_body` `usage.include=true` asks OpenRouter to embed the
# real generation cost into the final usage chunk — required by
# the cost-based rate limiter in routes.py. Separate from the
# caching headers, always sent.
is_anthropic = _is_anthropic_model(state.model)
if is_anthropic:
supports_cache = _supports_prompt_cache_markers(state.model)
if supports_cache:
# Build the cached system dict once per session and splice it in
# on each round. The full ``messages`` list grows with every
# tool call, so copying the entire list just to mutate index 0
@@ -555,7 +613,11 @@ async def _baseline_llm_caller(
final_messages = [state.cached_system_message, *messages[1:]]
else:
final_messages = messages
extra_headers = _fresh_anthropic_caching_headers()
extra_headers = (
_fresh_anthropic_caching_headers()
if _is_anthropic_model(state.model)
else None
)
else:
final_messages = messages
extra_headers = None
@@ -621,31 +683,30 @@ async def _baseline_llm_caller(
if not delta:
continue
state.pending_events.extend(state.reasoning_emitter.on_delta(delta))
_emit_all(state, state.reasoning_emitter.on_delta(delta))
if delta.content:
# Text and reasoning must not interleave on the wire — the
# AI SDK maps distinct start/end pairs to distinct UI
# parts. Close any open reasoning block before emitting
# the first text delta of this run.
state.pending_events.extend(state.reasoning_emitter.close())
_emit_all(state, state.reasoning_emitter.close())
emit = state.thinking_stripper.process(delta.content)
if emit:
if not state.text_started:
state.pending_events.append(
StreamTextStart(id=state.text_block_id)
)
_emit(state, StreamTextStart(id=state.text_block_id))
state.text_started = True
round_text += emit
state.pending_events.append(
StreamTextDelta(id=state.text_block_id, delta=emit)
_emit(
state,
StreamTextDelta(id=state.text_block_id, delta=emit),
)
if delta.tool_calls:
# Same rule as the text branch: close any open reasoning
# block before a tool_use starts so the AI SDK treats
# reasoning and tool-use as distinct parts.
state.pending_events.extend(state.reasoning_emitter.close())
_emit_all(state, state.reasoning_emitter.close())
for tc in delta.tool_calls:
idx = tc.index
if idx not in tool_calls_by_index:
@@ -676,19 +737,17 @@ async def _baseline_llm_caller(
# ``async for chunk in response`` would otherwise leave reasoning
# and/or text unterminated and only ``StreamFinishStep`` emitted —
# the Reasoning / Text collapses would never finalise.
state.pending_events.extend(state.reasoning_emitter.close())
_emit_all(state, state.reasoning_emitter.close())
# Flush any buffered text held back by the thinking stripper.
tail = state.thinking_stripper.flush()
if tail:
if not state.text_started:
state.pending_events.append(StreamTextStart(id=state.text_block_id))
_emit(state, StreamTextStart(id=state.text_block_id))
state.text_started = True
round_text += tail
state.pending_events.append(
StreamTextDelta(id=state.text_block_id, delta=tail)
)
_emit(state, StreamTextDelta(id=state.text_block_id, delta=tail))
if state.text_started:
state.pending_events.append(StreamTextEnd(id=state.text_block_id))
_emit(state, StreamTextEnd(id=state.text_block_id))
state.text_started = False
state.text_block_id = str(uuid.uuid4())
# Always persist partial text so the session history stays consistent,
@@ -696,7 +755,7 @@ async def _baseline_llm_caller(
state.assistant_text += round_text
# Always emit StreamFinishStep to match the StreamStartStep,
# even if an exception occurred during streaming.
state.pending_events.append(StreamFinishStep())
_emit(state, StreamFinishStep())
# Convert to shared format
llm_tool_calls = [
@@ -738,13 +797,14 @@ async def _baseline_tool_executor(
except orjson.JSONDecodeError as parse_err:
parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
logger.warning("[Baseline] %s", parse_error)
state.pending_events.append(
_emit(
state,
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=parse_error,
success=False,
)
),
)
return ToolCallResult(
tool_call_id=tool_call_id,
@@ -753,15 +813,17 @@ async def _baseline_tool_executor(
is_error=True,
)
state.pending_events.append(
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
_emit(
state,
StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name),
)
state.pending_events.append(
_emit(
state,
StreamToolInputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
input=tool_args,
)
),
)
# Announce the tool call to the session so in-turn guards like
@@ -785,7 +847,7 @@ async def _baseline_tool_executor(
session=session,
tool_call_id=tool_call_id,
)
state.pending_events.append(result)
_emit(state, result)
tool_output = (
result.output if isinstance(result.output, str) else str(result.output)
)
@@ -802,13 +864,14 @@ async def _baseline_tool_executor(
error_output,
exc_info=True,
)
state.pending_events.append(
_emit(
state,
StreamToolOutputAvailable(
toolCallId=tool_call_id,
toolName=tool_name,
output=error_output,
success=False,
)
),
)
return ToolCallResult(
tool_call_id=tool_call_id,
@@ -1317,7 +1380,7 @@ async def stream_chat_completion_baseline(
# Select model based on the per-request tier toggle (standard / advanced).
# The path (fast vs extended_thinking) is already decided — we're in the
# baseline (fast) path; ``mode`` is accepted for logging parity only.
active_model = _resolve_baseline_model(model)
active_model = await _resolve_baseline_model(model, user_id)
# --- E2B sandbox setup (feature parity with SDK path) ---
e2b_sandbox = None
@@ -1583,7 +1646,10 @@ async def stream_chat_completion_baseline(
openai_messages[i]["content"] = text
break
tools = get_available_tools()
disabled_tool_groups: list[ToolGroup] = []
if not graphiti_enabled:
disabled_tool_groups.append("graphiti")
tools = get_available_tools(disabled_groups=disabled_tool_groups)
# --- Permission filtering ---
if permissions is not None:
@@ -1594,9 +1660,10 @@ async def stream_chat_completion_baseline(
# _baseline_llm_caller) avoids re-copying ~43 tool dicts on every LLM
# round of the tool-call loop.
#
# Only apply to Anthropic routes — OpenAI/Grok/other providers would
# 400 on the unknown ``cache_control`` field inside tool definitions.
if _is_anthropic_model(active_model):
# Applies to Anthropic AND Moonshot routes — OpenAI/Grok/Gemini 400
# on the unknown ``cache_control`` field inside tool definitions, so
# the gate stays narrow (see :func:`_supports_prompt_cache_markers`).
if _supports_prompt_cache_markers(active_model):
tools = cast(
list[ChatCompletionToolParam], _mark_tools_with_cache_control(tools)
)
@@ -1660,139 +1727,172 @@ async def stream_chat_completion_baseline(
state=state,
)
try:
loop_result = None
async for loop_result in tool_call_loop(
messages=openai_messages,
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_bound_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
# Drain buffered events after each iteration (real-time streaming)
for evt in state.pending_events:
yield evt
state.pending_events.clear()
# Run the tool-call loop concurrently with the event consumer so
# ``StreamReasoning*`` / ``StreamText*`` deltas emitted inside
# ``_baseline_llm_caller`` reach the SSE wire DURING the upstream LLM
# stream instead of only at iteration boundaries. Any reasoning route
# that streams for several minutes per round (extended thinking on
# Anthropic / Moonshot / future providers) would otherwise freeze the
# UI for the whole window before flushing the backlog in one burst.
loop_result_holder: list[Any] = [None]
loop_task: asyncio.Task[None] | None = None
# Inject any messages the user queued while the turn was
# running. ``tool_call_loop`` mutates ``openai_messages``
# in-place, so appending here means the model sees the new
# messages on its next LLM call.
#
# IMPORTANT: skip when the loop has already finished (no
# more LLM calls are coming). ``tool_call_loop`` yields
# a final ``ToolCallLoopResult`` on both paths:
# - natural finish: ``finished_naturally=True``
# - hit max_iterations: ``finished_naturally=False``
# and ``iterations >= max_iterations``
# In either case the loop is about to return on the next
# ``async for`` step, so draining here would silently
# lose the message (the user sees 202 but the model never
# reads the text). Those messages stay in the buffer and
# get picked up at the start of the next turn.
is_final_yield = (
loop_result.finished_naturally
or loop_result.iterations >= _MAX_TOOL_ROUNDS
)
if is_final_yield:
continue
try:
pending = await drain_pending_messages(session_id)
except Exception:
logger.warning(
"[Baseline] mid-loop drain_pending_messages failed for session %s",
session_id,
exc_info=True,
)
pending = []
if pending:
# Flush any buffered assistant/tool messages from completed
# rounds into session.messages BEFORE appending the pending
# user message. ``_baseline_conversation_updater`` only
# records assistant+tool rounds into ``state.session_messages``
# — they are normally batch-flushed in the finally block.
# Without this in-order flush, the mid-loop pending user
# message lands before the preceding round's assistant/tool
# entries, producing chronologically-wrong session.messages
# on persist (user interposed between an assistant tool_call
# and its tool-result), which breaks OpenAI tool-call ordering
# invariants on the next turn's replay.
async def _run_tool_call_loop() -> None:
# Read/write the current session via ``_session_holder`` so this
# closure doesn't need to ``nonlocal session`` — pyright can't narrow
# the outer ``session: ChatSession | None`` through a nested scope,
# but the holder is typed non-optional after the preflight guard
# above.
try:
async for loop_result in tool_call_loop(
messages=openai_messages,
tools=tools,
llm_call=_bound_llm_caller,
execute_tool=_bound_tool_executor,
update_conversation=_bound_conversation_updater,
max_iterations=_MAX_TOOL_ROUNDS,
):
loop_result_holder[0] = loop_result
# Inject any messages the user queued while the turn was
# running. ``tool_call_loop`` mutates ``openai_messages``
# in-place, so appending here means the model sees the new
# messages on its next LLM call.
#
# Also persist any assistant text from text-only rounds (rounds
# with no tool calls, which ``_baseline_conversation_updater``
# does NOT record in session_messages). If we only update
# ``_flushed_assistant_text_len`` without persisting the text,
# that text is silently lost: the finally block only appends
# assistant_text[_flushed_assistant_text_len:], so text generated
# before this drain never reaches session.messages.
recorded_text = "".join(
m.content or ""
for m in state.session_messages
if m.role == "assistant"
# IMPORTANT: skip when the loop has already finished (no
# more LLM calls are coming). ``tool_call_loop`` yields
# a final ``ToolCallLoopResult`` on both paths:
# - natural finish: ``finished_naturally=True``
# - hit max_iterations: ``finished_naturally=False``
# and ``iterations >= max_iterations``
# In either case the loop is about to return on the next
# ``async for`` step, so draining here would silently
# lose the message (the user sees 202 but the model never
# reads the text). Those messages stay in the buffer and
# get picked up at the start of the next turn.
is_final_yield = (
loop_result.finished_naturally
or loop_result.iterations >= _MAX_TOOL_ROUNDS
)
unflushed_text = state.assistant_text[
state._flushed_assistant_text_len :
]
text_only_text = (
unflushed_text[len(recorded_text) :]
if unflushed_text.startswith(recorded_text)
else unflushed_text
)
if text_only_text.strip():
session.messages.append(
ChatMessage(role="assistant", content=text_only_text)
if is_final_yield:
continue
try:
pending = await drain_pending_messages(session_id)
except Exception:
logger.warning(
"[Baseline] mid-loop drain_pending_messages failed for "
"session %s",
session_id,
exc_info=True,
)
for _buffered in state.session_messages:
session.messages.append(_buffered)
state.session_messages.clear()
# Record how much assistant_text has been covered by the
# structured entries just flushed, so the finally block's
# final-text dedup doesn't re-append rounds already persisted.
state._flushed_assistant_text_len = len(state.assistant_text)
pending = []
if pending:
# Flush any buffered assistant/tool messages from completed
# rounds into session.messages BEFORE appending the pending
# user message. ``_baseline_conversation_updater`` only
# records assistant+tool rounds into ``state.session_messages``
# — they are normally batch-flushed in the finally block.
# Without this in-order flush, the mid-loop pending user
# message lands before the preceding round's assistant/tool
# entries, producing chronologically-wrong session.messages
# on persist (user interposed between an assistant tool_call
# and its tool-result), which breaks OpenAI tool-call ordering
# invariants on the next turn's replay.
#
# Also persist any assistant text from text-only rounds (rounds
# with no tool calls, which ``_baseline_conversation_updater``
# does NOT record in session_messages). If we only update
# ``_flushed_assistant_text_len`` without persisting the text,
# that text is silently lost: the finally block only appends
# assistant_text[_flushed_assistant_text_len:], so text generated
# before this drain never reaches session.messages.
recorded_text = "".join(
m.content or ""
for m in state.session_messages
if m.role == "assistant"
)
unflushed_text = state.assistant_text[
state._flushed_assistant_text_len :
]
text_only_text = (
unflushed_text[len(recorded_text) :]
if unflushed_text.startswith(recorded_text)
else unflushed_text
)
current_session = _session_holder[0]
if text_only_text.strip():
current_session.messages.append(
ChatMessage(role="assistant", content=text_only_text)
)
for _buffered in state.session_messages:
current_session.messages.append(_buffered)
state.session_messages.clear()
# Record how much assistant_text has been covered by the
# structured entries just flushed, so the finally block's
# final-text dedup doesn't re-append rounds already persisted.
state._flushed_assistant_text_len = len(state.assistant_text)
# Persist the assistant/tool flush BEFORE the pending append
# so a later pending-persist failure can roll back the
# pending rows without also discarding LLM output.
session = await persist_session_safe(session, "[Baseline]")
# ``upsert_chat_session`` may return a *new* ``ChatSession``
# instance (e.g. when a concurrent title update has written a
# newer title to Redis, it returns ``session.model_copy``).
# Keep ``_session_holder`` in sync so subsequent tool rounds
# executed via ``_bound_tool_executor`` see the fresh session
# — any tool-side mutations on the stale object would be
# discarded when the new one is persisted in the ``finally``.
_session_holder[0] = session
# Persist the assistant/tool flush BEFORE the pending append
# so a later pending-persist failure can roll back the
# pending rows without also discarding LLM output.
current_session = await persist_session_safe(
current_session, "[Baseline]"
)
# ``upsert_chat_session`` may return a *new* ``ChatSession``
# instance (e.g. when a concurrent title update has written a
# newer title to Redis, it returns ``session.model_copy``).
# Keep ``_session_holder`` in sync so subsequent tool rounds
# executed via ``_bound_tool_executor`` see the fresh session
# — any tool-side mutations on the stale object would be
# discarded when the new one is persisted in the ``finally``.
_session_holder[0] = current_session
# ``format_pending_as_user_message`` embeds file attachments
# and context URL/page content into the content string so
# the in-session transcript is a faithful copy of what the
# model actually saw. We also mirror each push into
# ``openai_messages`` so the model's next LLM round sees it.
#
# Pre-compute the formatted dicts once so both the openai
# messages append and the content_of lookup inside the
# shared helper use the same string — and so ``on_rollback``
# can trim ``openai_messages`` to the recorded anchor.
formatted_by_pm = {
id(pm): format_pending_as_user_message(pm) for pm in pending
}
_openai_anchor = len(openai_messages)
for pm in pending:
openai_messages.append(formatted_by_pm[id(pm)])
# ``format_pending_as_user_message`` embeds file attachments
# and context URL/page content into the content string so
# the in-session transcript is a faithful copy of what the
# model actually saw. We also mirror each push into
# ``openai_messages`` so the model's next LLM round sees it.
#
# Pre-compute the formatted dicts once so both the openai
# messages append and the content_of lookup inside the
# shared helper use the same string — and so ``on_rollback``
# can trim ``openai_messages`` to the recorded anchor.
formatted_by_pm = {
id(pm): format_pending_as_user_message(pm) for pm in pending
}
_openai_anchor = len(openai_messages)
for pm in pending:
openai_messages.append(formatted_by_pm[id(pm)])
def _trim_openai_on_rollback(_session_anchor: int) -> None:
del openai_messages[_openai_anchor:]
def _trim_openai_on_rollback(_session_anchor: int) -> None:
del openai_messages[_openai_anchor:]
await persist_pending_as_user_rows(
session,
transcript_builder,
pending,
log_prefix="[Baseline]",
content_of=lambda pm: formatted_by_pm[id(pm)]["content"],
on_rollback=_trim_openai_on_rollback,
)
await persist_pending_as_user_rows(
current_session,
transcript_builder,
pending,
log_prefix="[Baseline]",
content_of=lambda pm: formatted_by_pm[id(pm)]["content"],
on_rollback=_trim_openai_on_rollback,
)
finally:
# Always post the sentinel so the outer consumer exits — even if
# ``tool_call_loop`` raised. ``_baseline_llm_caller``'s own
# finally block has already pushed ``StreamReasoningEnd`` /
# ``StreamTextEnd`` / ``StreamFinishStep`` at this point, so the
# sentinel only terminates the consumer; it does not suppress
# any still-unflushed events.
state.pending_events.put_nowait(None)
loop_task = asyncio.create_task(_run_tool_call_loop())
try:
while True:
evt = await state.pending_events.get()
if evt is None:
break
yield evt
# Sentinel received — surface any exception the inner task hit.
await loop_task
loop_result = loop_result_holder[0]
if loop_result and not loop_result.finished_naturally:
limit_msg = (
f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
@@ -1803,25 +1903,34 @@ async def stream_chat_completion_baseline(
errorText=limit_msg,
code="baseline_tool_round_limit",
)
except Exception as e:
_stream_error = True
error_msg = str(e) or type(e).__name__
logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
# ``_baseline_llm_caller``'s finally block closes any open
# reasoning / text blocks and appends ``StreamFinishStep`` on
# both normal and exception paths, so pending_events already has
# the correct protocol ordering:
# StreamStartStep -> StreamReasoningStart -> ...deltas... ->
# StreamReasoningEnd -> StreamTextStart -> ...deltas... ->
# StreamTextEnd -> StreamFinishStep
# Just drain what's buffered, then yield the error.
for evt in state.pending_events:
yield evt
state.pending_events.clear()
# Drain any queued tail events (reasoning/text close + finish step)
# that ``_baseline_llm_caller``'s finally block pushed before the
# sentinel arrived — without this the frontend would be missing the
# matching end / finish parts for the partial round.
while not state.pending_events.empty():
evt = state.pending_events.get_nowait()
if evt is not None:
yield evt
yield StreamError(errorText=error_msg, code="baseline_error")
# Still persist whatever we got
finally:
# Cancel the inner task if we're unwinding early (client disconnect,
# unexpected error in the consumer) so it doesn't keep streaming
# tokens into a dead queue.
if loop_task is not None and not loop_task.done():
loop_task.cancel()
try:
await loop_task
except (asyncio.CancelledError, Exception):
pass
# Re-sync the outer ``session`` binding in case the inner task
# reassigned it via a mid-loop ``persist_session_safe`` call.
session = _session_holder[0]
# In-flight tool-call announcements are only meaningful for the
# current turn; clear at the top of the outer finally so the next
# turn starts with a clean scratch buffer even if one of the

View File

@@ -21,6 +21,7 @@ from backend.copilot.baseline.service import (
_is_anthropic_model,
_mark_system_message_with_cache_control,
_mark_tools_with_cache_control,
_supports_prompt_cache_markers,
)
from backend.copilot.model import ChatMessage
from backend.copilot.response_model import (
@@ -39,7 +40,10 @@ from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallRe
class TestBaselineStreamState:
def test_defaults(self):
state = _BaselineStreamState()
assert state.pending_events == []
# ``pending_events`` is an asyncio.Queue now (live SSE channel).
# The durable inspection view is ``emitted_events``.
assert state.pending_events.empty()
assert state.emitted_events == []
assert state.assistant_text == ""
assert state.text_started is False
assert state.turn_prompt_tokens == 0
@@ -1687,7 +1691,7 @@ class TestBaselineReasoningStreaming:
state=state,
)
types = [type(e).__name__ for e in state.pending_events]
types = [type(e).__name__ for e in state.emitted_events]
assert "StreamReasoningStart" in types
assert "StreamReasoningDelta" in types
assert "StreamReasoningEnd" in types
@@ -1702,14 +1706,14 @@ class TestBaselineReasoningStreaming:
# a fresh id after the reasoning-end rotation.
reasoning_ids = {
e.id
for e in state.pending_events
for e in state.emitted_events
if isinstance(
e, (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd)
)
}
text_ids = {
e.id
for e in state.pending_events
for e in state.emitted_events
if isinstance(e, (StreamTextStart, StreamTextDelta, StreamTextEnd))
}
assert len(reasoning_ids) == 1
@@ -1717,7 +1721,7 @@ class TestBaselineReasoningStreaming:
assert reasoning_ids.isdisjoint(text_ids)
combined = "".join(
e.delta for e in state.pending_events if isinstance(e, StreamReasoningDelta)
e.delta for e in state.emitted_events if isinstance(e, StreamReasoningDelta)
)
assert combined == "thinking... more"
@@ -1759,7 +1763,7 @@ class TestBaselineReasoningStreaming:
# A reasoning-end must have been emitted — this is the tool_calls
# branch's responsibility, not the stream-end cleanup.
types = [type(e).__name__ for e in state.pending_events]
types = [type(e).__name__ for e in state.emitted_events]
assert "StreamReasoningStart" in types
assert "StreamReasoningEnd" in types
@@ -1802,7 +1806,7 @@ class TestBaselineReasoningStreaming:
state=state,
)
types = [type(e).__name__ for e in state.pending_events]
types = [type(e).__name__ for e in state.emitted_events]
# The reasoning block was opened, the exception fired, and the
# finally block must have closed it before emitting the finish
# step.
@@ -1861,12 +1865,14 @@ class TestBaselineReasoningStreaming:
assert "reasoning" not in extra_body
@pytest.mark.asyncio
async def test_kimi_route_sends_reasoning_but_no_cache_control(self):
"""Kimi K2.6 is the default fast_model and sends ``reasoning`` via
OpenRouter's unified extension. It must NOT receive ``cache_control``
markers or the ``anthropic-beta`` header — Moonshot uses its own
auto-caching and those Anthropic-only fields would either get
silently dropped or (worst case) 400 on a future provider change."""
async def test_kimi_route_sends_reasoning_and_cache_control(self):
"""Kimi K2.6 (Moonshot via OpenRouter's Anthropic-compat endpoint)
accepts ``cache_control: {type: ephemeral}`` on the system block
and the last tool — the endpoint honours the marker and lifts
cache hit rate on continuation turns from near-zero (Moonshot's
auto-caching drifts) to the Anthropic ~60-95% ballpark. The
``anthropic-beta`` header stays off because Moonshot doesn't need
it; OpenRouter would strip the unknown header anyway."""
state = _BaselineStreamState(model="moonshotai/kimi-k2.6")
mock_client = MagicMock()
@@ -1898,15 +1904,29 @@ class TestBaselineReasoningStreaming:
# cheap-but-still-reasoning-capable path.
assert "reasoning" in extra_body
assert extra_body["reasoning"]["max_tokens"] > 0
# Anthropic-only fields stay off.
assert "extra_headers" not in call_kwargs
# No ``anthropic-beta`` header — that beta is specifically for
# native Anthropic endpoints; Moonshot's shim accepts
# ``cache_control`` without it, and sending it would be wasted
# bytes (OR strips it before forwarding to Moonshot).
assert "extra_headers" not in call_kwargs or not call_kwargs.get(
"extra_headers"
)
# System block MUST carry ``cache_control`` so Moonshot's cache
# breakpoint is honoured. The cached system-message builder
# emits list-shape content with the marker on the first (and
# only) block — assert on that shape.
sys_msg = call_kwargs["messages"][0]
sys_content = sys_msg.get("content")
if isinstance(sys_content, list):
assert all("cache_control" not in block for block in sys_content)
tools = call_kwargs.get("tools", [])
for t in tools:
assert "cache_control" not in t
assert isinstance(
sys_content, list
), "Cached system message should be a list-shape content block"
assert any(
"cache_control" in block for block in sys_content if isinstance(block, dict)
), "Kimi system message should now carry cache_control markers"
# Tool-level cache marking is applied by ``stream_chat_completion_baseline``
# (see ``_mark_tools_with_cache_control``) before tools reach
# ``_baseline_llm_caller``, so this unit test doesn't exercise
# that path — covered by the outer integration test.
@pytest.mark.asyncio
async def test_reasoning_only_stream_still_closes_block(self):
@@ -1935,7 +1955,7 @@ class TestBaselineReasoningStreaming:
state=state,
)
types = [type(e).__name__ for e in state.pending_events]
types = [type(e).__name__ for e in state.emitted_events]
assert "StreamReasoningStart" in types
assert "StreamReasoningEnd" in types
# No text was produced — no text events should be emitted.
@@ -2006,3 +2026,55 @@ class TestBaselineReasoningStreaming:
reasoning_rows = [m for m in state.session_messages if m.role == "reasoning"]
assert len(reasoning_rows) == 1
assert reasoning_rows[0].content == "first thought"
class TestSupportsPromptCacheMarkers:
"""``_supports_prompt_cache_markers`` is the widened gate for
emitting ``cache_control`` markers on message content. It's a
superset of ``_is_anthropic_model`` that ALSO admits Moonshot
(whose Anthropic-compat endpoint honours the marker) while keeping
the False answer for OpenAI / Grok / Gemini (which 400 on the
unknown field)."""
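# A minimal sketch of the widened gate this class describes, assuming it
# simply ORs the two narrower predicates (the real implementation lives in
# ``baseline/service.py`` and may differ in detail):
#
#     def _supports_prompt_cache_markers(model: str | None) -> bool:
#         # Anthropic keeps its existing True answer; Moonshot's
#         # Anthropic-compat endpoint is additionally admitted.
#         return _is_anthropic_model(model) or moonshot_supports_cache_control(model)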
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-sonnet-4-6",
"claude-3-5-sonnet-20241022",
"anthropic.claude-3-5-sonnet",
"ANTHROPIC/Claude-Opus",
],
)
def test_anthropic_routes_are_supported(self, model):
assert _supports_prompt_cache_markers(model) is True
@pytest.mark.parametrize(
"model",
[
"moonshotai/kimi-k2.6",
"moonshotai/kimi-k2-thinking",
"moonshotai/kimi-k2.5",
"moonshotai/kimi-k3.0", # future SKU
],
)
def test_moonshot_routes_are_supported(self, model):
"""The whole reason this predicate exists — Moonshot must be
True even though ``_is_anthropic_model`` is False for it."""
assert _supports_prompt_cache_markers(model) is True
# Verify this is strictly wider than the anthropic-only check.
assert _is_anthropic_model(model) is False
@pytest.mark.parametrize(
"model",
[
"openai/gpt-4o",
"google/gemini-2.5-pro",
"xai/grok-4",
"meta-llama/llama-3.3-70b-instruct",
"deepseek/deepseek-v3",
],
)
def test_other_providers_still_rejected(self, model):
"""Regression guard: OpenAI/Grok/Gemini still 400 on
``cache_control``, so the widened gate must keep them out."""
assert _supports_prompt_cache_markers(model) is False

View File

@@ -67,34 +67,40 @@ class TestResolveBaselineModel:
Baseline reads the ``fast_*_model`` cells of the (path, tier) matrix
and never falls through to the SDK-side ``thinking_*_model`` cells.
Default routing:
- ``standard`` / ``None`` → ``config.fast_standard_model`` (Sonnet by default; per-user LD overrides)
- ``advanced`` → ``config.fast_advanced_model`` (Opus — same as SDK's
advanced tier, so the advanced A/B isolates path differences)
Without a user_id (so no LD context) the resolver returns the
``ChatConfig`` static default; per-user overrides are exercised in
``copilot/model_router_test.py``.
"""
def test_advanced_tier_selects_fast_advanced_model(self):
assert _resolve_baseline_model("advanced") == config.fast_advanced_model
@pytest.mark.asyncio
async def test_advanced_tier_selects_fast_advanced_model(self):
assert (
await _resolve_baseline_model("advanced", None)
== config.fast_advanced_model
)
def test_standard_tier_selects_fast_standard_model(self):
assert _resolve_baseline_model("standard") == config.fast_standard_model
@pytest.mark.asyncio
async def test_standard_tier_selects_fast_standard_model(self):
assert (
await _resolve_baseline_model("standard", None)
== config.fast_standard_model
)
def test_none_tier_selects_fast_standard_model(self):
"""Baseline users without a tier get the cheap fast-standard default."""
assert _resolve_baseline_model(None) == config.fast_standard_model
@pytest.mark.asyncio
async def test_none_tier_selects_fast_standard_model(self):
"""Baseline users without a tier get the fast-standard default."""
assert await _resolve_baseline_model(None, None) == config.fast_standard_model
def test_fast_standard_default_is_kimi(self):
"""Shipped default: Kimi K2.6 on the baseline standard cell.
Asserts the declared ``Field`` default — env-independent — so a
deploy-time ``CHAT_FAST_STANDARD_MODEL`` rollback override
doesn't fail CI while still pinning the shipped default.
"""
def test_fast_standard_default_is_sonnet(self):
"""Shipped default: Sonnet on the baseline standard cell — the
non-Anthropic routes ship via the LD flag instead of a config
change. Asserts the declared ``Field`` default so a deploy-time
``CHAT_FAST_STANDARD_MODEL`` override doesn't flake CI."""
from backend.copilot.config import ChatConfig
assert (
ChatConfig.model_fields["fast_standard_model"].default
== "moonshotai/kimi-k2.6"
== "anthropic/claude-sonnet-4-6"
)
def test_fast_advanced_default_is_opus(self):
@@ -108,18 +114,6 @@ class TestResolveBaselineModel:
== "anthropic/claude-opus-4.7"
)
def test_standard_cells_diverge_across_paths(self):
"""The whole point of the split: baseline cheap (Kimi) vs SDK
Anthropic-only (Sonnet). If the shipped standard defaults ever
collapse to the same value someone lost the cost savings.
Checked against ``Field`` defaults, not the env-backed singleton."""
from backend.copilot.config import ChatConfig
assert (
ChatConfig.model_fields["thinking_standard_model"].default
!= ChatConfig.model_fields["fast_standard_model"].default
)
def test_standard_and_advanced_cells_differ_on_fast(self):
"""Advanced tier defaults to a different model than standard on
the baseline path. Checked against declared ``Field`` defaults

View File

@@ -3,7 +3,7 @@
import os
from typing import Literal
from pydantic import AliasChoices, Field, field_validator
from pydantic import AliasChoices, Field, field_validator, model_validator
from pydantic_settings import BaseSettings
from backend.util.clients import OPENROUTER_BASE_URL
@@ -43,26 +43,20 @@ class ChatConfig(BaseSettings):
# ``CHAT_FAST_MODEL``) are preserved via ``validation_alias`` so
# existing deployments continue to override the same effective cell.
fast_standard_model: str = Field(
default="moonshotai/kimi-k2.6",
default="anthropic/claude-sonnet-4-6",
validation_alias=AliasChoices(
"CHAT_FAST_STANDARD_MODEL",
"CHAT_FAST_MODEL",
),
description="Baseline path, 'standard' / ``None`` tier. Kimi K2.6 "
"by default: ~5x cheaper input and ~5.4x cheaper output than Sonnet, "
"SWE-Bench Verified parity with Opus, and OpenRouter advertises the "
"``reasoning`` + ``include_reasoning`` extension params on the "
"Moonshot endpoints — so the baseline reasoning plumbing lights up "
"without provider-specific code. Roll back to the Anthropic route "
"via ``CHAT_FAST_STANDARD_MODEL=anthropic/claude-sonnet-4-6`` (then "
"``cache_control`` breakpoints reactivate via "
"``_is_anthropic_model``).",
description="Baseline path, 'standard' / ``None`` tier. Per-user "
"overrides flow through the ``copilot-fast-standard-model`` LD flag "
"(see ``copilot/model_router.py``); this value is the fallback.",
)
fast_advanced_model: str = Field(
default="anthropic/claude-opus-4.7",
validation_alias=AliasChoices("CHAT_FAST_ADVANCED_MODEL"),
description="Baseline path, 'advanced' tier. Opus by default. "
"Override via ``CHAT_FAST_ADVANCED_MODEL``.",
description="Baseline path, 'advanced' tier. LD override: "
"``copilot-fast-advanced-model``.",
)
thinking_standard_model: str = Field(
default="anthropic/claude-sonnet-4-6",
@@ -71,11 +65,7 @@ class ChatConfig(BaseSettings):
"CHAT_MODEL",
),
description="SDK (extended-thinking) path, 'standard' / ``None`` "
"tier. Sonnet by default: the Claude Agent SDK CLI only speaks to "
"Anthropic endpoints, so the standard SDK tier has to stay on an "
"Anthropic model regardless of what the baseline path runs. "
"Override via ``CHAT_THINKING_STANDARD_MODEL`` (legacy "
"``CHAT_MODEL`` still honored).",
"tier. LD override: ``copilot-thinking-standard-model``.",
)
thinking_advanced_model: str = Field(
default="anthropic/claude-opus-4.7",
@@ -83,17 +73,18 @@ class ChatConfig(BaseSettings):
"CHAT_THINKING_ADVANCED_MODEL",
"CHAT_ADVANCED_MODEL",
),
description="SDK (extended-thinking) path, 'advanced' tier. Opus "
"by default. Override via ``CHAT_THINKING_ADVANCED_MODEL`` "
"(legacy ``CHAT_ADVANCED_MODEL`` still honored).",
description="SDK (extended-thinking) path, 'advanced' tier. LD "
"override: ``copilot-thinking-advanced-model``.",
)
title_model: str = Field(
default="openai/gpt-4o-mini",
description="Model to use for generating session titles (should be fast/cheap)",
)
simulation_model: str = Field(
default="google/gemini-2.5-flash",
description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
default="google/gemini-2.5-flash-lite",
description="Model for dry-run block simulation (should be fast/cheap with good JSON output). "
"Gemini 2.5 Flash-Lite is ~3x cheaper than Flash ($0.10/$0.40 vs $0.30/$1.20 per MTok) "
"with JSON-mode reliability adequate for shape-matching block outputs.",
)
api_key: str | None = Field(default=None, description="OpenAI API key")
base_url: str | None = Field(
@@ -249,13 +240,28 @@ class ChatConfig(BaseSettings):
"``max_thinking_tokens`` kwarg so the CLI falls back to model default "
"(which, without the flag, leaves extended thinking off).",
)
render_reasoning_in_ui: bool = Field(
default=True,
description="Render reasoning as live UI parts "
"(``StreamReasoning*`` wire events). False suppresses the live "
"wire events only; ``role='reasoning'`` rows are always persisted "
"so the reasoning bubble hydrates on reload. Tokens are billed "
"upstream regardless.",
)
stream_replay_count: int = Field(
default=200,
ge=1,
le=10000,
description="Max Redis stream entries replayed on SSE reconnect.",
)
claude_agent_thinking_effort: Literal["low", "medium", "high", "max"] | None = (
Field(
default=None,
description="Thinking effort level: 'low', 'medium', 'high', 'max', or None. "
"Only applies to models with extended thinking (Opus). "
"Sonnet doesn't have extended thinking — setting effort on Sonnet "
"can cause <internal_reasoning> tag leaks. "
"Applies to models that emit a reasoning channel — Opus (extended "
"thinking) and Kimi K2.6 (OpenRouter ``reasoning`` extension lit "
"up by #12871). Sonnet does not have extended thinking — setting "
"effort on Sonnet can cause <internal_reasoning> tag leaks. "
"None = let the model decide. Override via CHAT_CLAUDE_AGENT_THINKING_EFFORT.",
)
)
@@ -287,6 +293,31 @@ class ChatConfig(BaseSettings):
"(24h, permanent) TTL option — see "
"https://platform.claude.com/docs/en/build-with-claude/prompt-caching.",
)
sdk_include_partial_messages: bool = Field(
default=True,
description="Stream SDK responses token-by-token instead of in "
"one lump at the end. Set to False if the SDK path starts "
"double-writing text or dropping the tail of long messages.",
)
sdk_reconcile_openrouter_cost: bool = Field(
default=True,
description="Query OpenRouter's ``/api/v1/generation?id=`` after each "
"SDK turn and record the authoritative ``total_cost`` instead of the "
"Claude Agent SDK CLI's estimate. Covers every OpenRouter-routed "
"SDK turn regardless of vendor — the CLI's static Anthropic pricing "
"table is accurate for Anthropic models (Sonnet/Opus via OpenRouter "
"bill at Anthropic's own rates, penny-for-penny), but the reconcile "
"catches any future rate change the CLI hasn't picked up and makes "
"non-Anthropic cost (Kimi et al) correct — real billed amount, "
"matching the baseline path's ``usage.cost`` read since #12864. "
"Kill-switch for emergencies: set ``CHAT_SDK_RECONCILE_OPENROUTER_COST"
"=false`` to fall back to the CLI's ``total_cost_usd`` reported "
"synchronously (accurate-for-Anthropic / over-billed-for-Kimi). "
"Tradeoff: 0.5-2s window between turn end and cost write; rate-limit "
"counter briefly unaware, back-to-back turns in that window see "
"stale state. The alternative (writing an estimate sync then a "
"correction delta) would double-count the rate limit.",
)
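# Illustrative only: a minimal sketch of the reconcile lookup this flag
# enables. The helper name and the response nesting are assumptions; the
# real task lives in the ``openrouter_cost`` module.
#
#     async def _fetch_openrouter_cost(generation_id: str, api_key: str) -> float | None:
#         async with httpx.AsyncClient() as client:
#             resp = await client.get(
#                 f"{OPENROUTER_BASE_URL}/generation",
#                 params={"id": generation_id},
#                 headers={"Authorization": f"Bearer {api_key}"},
#             )
#         resp.raise_for_status()
#         return resp.json().get("data", {}).get("total_cost")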
claude_agent_cli_path: str | None = Field(
default=None,
description="Optional explicit path to a Claude Code CLI binary. "
@@ -457,6 +488,59 @@ class ChatConfig(BaseSettings):
)
return v
@model_validator(mode="after")
def _validate_sdk_model_vendor_compatibility(self) -> "ChatConfig":
"""Fail at config load when an SDK model slug is incompatible with
explicit direct-Anthropic mode.
The SDK path's ``_normalize_model_name`` raises ``ValueError`` when
a non-Anthropic vendor slug (e.g. ``moonshotai/kimi-k2.6``) is paired
with direct-Anthropic mode — but that fires inside the request loop,
so a misconfigured deployment would surface a 500 to every user
instead of failing visibly at boot.
Only the **explicit** opt-out (``use_openrouter=False``) is checked
here, not the credential-missing path. Build environments and
OpenAPI-schema export jobs construct ``ChatConfig()`` without any
OpenRouter credentials in the env — that's not a misconfiguration,
it's "config loads ok, but no SDK turn will succeed until creds are
wired". The runtime guard in ``_normalize_model_name`` still
catches the credential-missing path on the first SDK turn.
Covers all three SDK fields that flow through
``_normalize_model_name``: primary tier
(``thinking_standard_model``), advanced tier
(``thinking_advanced_model``), and fallback model
(``claude_agent_fallback_model`` via ``_resolve_fallback_model``).
Skipped when ``use_claude_code_subscription=True`` because the
subscription path resolves the model to ``None`` (CLI default)
and never calls ``_normalize_model_name``. Empty fallback strings
are also skipped (no fallback configured).
"""
if self.use_claude_code_subscription:
return self
if self.use_openrouter:
return self
for field_name in (
"thinking_standard_model",
"thinking_advanced_model",
"claude_agent_fallback_model",
):
value: str = getattr(self, field_name)
if not value or "/" not in value:
continue
if value.split("/", 1)[0] != "anthropic":
raise ValueError(
f"Direct-Anthropic mode (use_openrouter=False) "
f"requires an Anthropic model for {field_name}, got "
f"{value!r}. Set CHAT_THINKING_STANDARD_MODEL / "
f"CHAT_THINKING_ADVANCED_MODEL / "
f"CHAT_CLAUDE_AGENT_FALLBACK_MODEL to an anthropic/* "
f"slug, or set CHAT_USE_OPENROUTER=true."
)
return self
# Prompt paths for different contexts
PROMPT_PATHS: dict[str, str] = {
"default": "prompts/chat_system.md",

View File

@@ -5,12 +5,17 @@ import pytest
from .config import ChatConfig
# Env vars that the ChatConfig validators read — must be cleared so they don't
# override the explicit constructor values we pass in each test.
# override the explicit constructor values we pass in each test. Includes the
# SDK/baseline model aliases so a leftover ``CHAT_MODEL=...`` in the developer
# or CI environment can't change whether
# ``_validate_sdk_model_vendor_compatibility`` raises.
_ENV_VARS_TO_CLEAR = (
"CHAT_USE_E2B_SANDBOX",
"CHAT_E2B_API_KEY",
"E2B_API_KEY",
"CHAT_USE_OPENROUTER",
"CHAT_USE_CLAUDE_AGENT_SDK",
"CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
"CHAT_API_KEY",
"OPEN_ROUTER_API_KEY",
"OPENAI_API_KEY",
@@ -19,6 +24,16 @@ _ENV_VARS_TO_CLEAR = (
"OPENAI_BASE_URL",
"CHAT_CLAUDE_AGENT_CLI_PATH",
"CLAUDE_AGENT_CLI_PATH",
"CHAT_FAST_STANDARD_MODEL",
"CHAT_FAST_MODEL",
"CHAT_FAST_ADVANCED_MODEL",
"CHAT_THINKING_STANDARD_MODEL",
"CHAT_THINKING_ADVANCED_MODEL",
"CHAT_MODEL",
"CHAT_ADVANCED_MODEL",
"CHAT_CLAUDE_AGENT_FALLBACK_MODEL",
"CHAT_RENDER_REASONING_IN_UI",
"CHAT_STREAM_REPLAY_COUNT",
)
@@ -28,6 +43,22 @@ def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv(var, raising=False)
def _make_direct_safe_config(**kwargs) -> ChatConfig:
"""Build a ``ChatConfig`` for tests that pass ``use_openrouter=False``
but aren't exercising the SDK vendor-compatibility validator.
Pins ``thinking_standard_model``/``thinking_advanced_model`` to anthropic/*
so the construction passes ``_validate_sdk_model_vendor_compatibility``
without each test having to repeat the override.
"""
defaults: dict = {
"thinking_standard_model": "anthropic/claude-sonnet-4-6",
"thinking_advanced_model": "anthropic/claude-opus-4-7",
}
defaults.update(kwargs)
return ChatConfig(**defaults)
class TestOpenrouterActive:
"""Tests for the openrouter_active property."""
@@ -48,7 +79,7 @@ class TestOpenrouterActive:
assert cfg.openrouter_active is False
def test_disabled_returns_false_despite_credentials(self):
cfg = ChatConfig(
cfg = _make_direct_safe_config(
use_openrouter=False,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
@@ -164,3 +195,133 @@ class TestClaudeAgentCliPathEnvFallback:
monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", str(tmp_path))
with pytest.raises(Exception, match="not a regular file"):
ChatConfig()
class TestSdkModelVendorCompatibility:
"""``model_validator`` that fails fast on SDK model vs routing-mode
mismatch — see PR #12878 iteration-2 review. Mirrors the runtime
guard in ``_normalize_model_name`` so misconfig surfaces at boot
instead of as a 500 on the first SDK turn."""
def test_direct_anthropic_with_kimi_override_raises(self):
"""A non-Anthropic SDK model must fail at config load when the
deployment has no OpenRouter credentials."""
with pytest.raises(Exception, match="requires an Anthropic model"):
ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="moonshotai/kimi-k2.6",
)
def test_direct_anthropic_with_anthropic_default_succeeds(self):
"""Direct-Anthropic mode is fine when both SDK slugs are anthropic/*
— which is the default after the LD-routed model rollout."""
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
assert cfg.thinking_standard_model == "anthropic/claude-sonnet-4-6"
def test_openrouter_with_kimi_override_succeeds(self):
"""Kimi slug round-trips cleanly when OpenRouter is on — exercised
via the LD-flag override path in production."""
cfg = ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=False,
thinking_standard_model="moonshotai/kimi-k2.6",
)
assert cfg.thinking_standard_model == "moonshotai/kimi-k2.6"
def test_subscription_mode_skips_check(self):
"""Subscription path resolves the model to None and bypasses
``_normalize_model_name``, so the slug check is skipped."""
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=True,
)
assert cfg.use_claude_code_subscription is True
def test_advanced_tier_also_validated(self):
"""Both standard and advanced SDK slugs are checked."""
with pytest.raises(Exception, match="thinking_advanced_model"):
ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="moonshotai/kimi-k2.6",
)
def test_fallback_model_also_validated(self):
"""``claude_agent_fallback_model`` flows through
``_normalize_model_name`` via ``_resolve_fallback_model`` so the
same direct-Anthropic guard applies."""
with pytest.raises(Exception, match="claude_agent_fallback_model"):
ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
claude_agent_fallback_model="moonshotai/kimi-k2.6",
)
def test_empty_fallback_skipped(self):
"""Empty ``claude_agent_fallback_model`` (no fallback configured)
must not trip the validator — the fallback-disabled state is
intentional and shouldn't require a placeholder anthropic/* slug."""
cfg = ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
claude_agent_fallback_model="",
)
assert cfg.claude_agent_fallback_model == ""
class TestRenderReasoningInUi:
"""``render_reasoning_in_ui`` gates reasoning wire events globally."""
def test_defaults_to_true(self):
"""Default must stay True — flipping it silences the reasoning
collapse for every user, which is an opt-in operator decision."""
cfg = ChatConfig()
assert cfg.render_reasoning_in_ui is True
def test_env_override_false(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CHAT_RENDER_REASONING_IN_UI", "false")
cfg = ChatConfig()
assert cfg.render_reasoning_in_ui is False
class TestStreamReplayCount:
"""``stream_replay_count`` caps the SSE reconnect replay batch size."""
def test_default_is_200(self):
"""200 covers a full Kimi turn after coalescing (~150 events) while
bounding the replay storm from 1000+ chunks."""
cfg = ChatConfig()
assert cfg.stream_replay_count == 200
def test_env_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("CHAT_STREAM_REPLAY_COUNT", "500")
cfg = ChatConfig()
assert cfg.stream_replay_count == 500
def test_zero_rejected(self):
"""count=0 would make XREAD replay nothing — rejected via ge=1."""
with pytest.raises(Exception):
ChatConfig(stream_replay_count=0)

View File

@@ -105,25 +105,46 @@ class CoPilotExecutor(AppProcess):
time.sleep(1e5)
def cleanup(self):
"""Graceful shutdown with active execution waiting."""
pid = os.getpid()
logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
"""Graceful shutdown — mirrors ``backend.executor.manager`` pattern.
# Signal the consumer thread to stop
1. Stop consumer immediately (both the Python flag that gates
``_handle_run_message`` and ``channel.stop_consuming()`` at
the broker), so no new work enters.
2. Passively wait for ``active_tasks`` to drain — each turn's
own ``finally`` publishes its terminal state via
``mark_session_completed``. When a turn exits, ``on_run_done``
removes it from ``active_tasks`` and releases its cluster lock.
3. Stop message consumer threads + disconnect pika clients.
4. Shut down the thread-pool executor (cancels pending, leaves
running threads alone — process exit handles them).
5. Release any cluster locks still held (defensive — on_run_done's
finally should have already released them).
The zombie-session bug this PR targets is handled inside each
turn's own lifecycle by :func:`sync_fail_close_session`, NOT by
cleanup — so cleanup can stay as a simple "wait, then teardown"
and matches agent-executor's proven pattern.
"""
pid = os.getpid()
prefix = f"[cleanup {pid}]"
logger.info(f"{prefix} Starting graceful shutdown...")
# 1. Stop consumer — flag AND broker-side
try:
self.stop_consuming.set()
run_channel = self.run_client.get_channel()
run_channel.connection.add_callback_threadsafe(
lambda: run_channel.stop_consuming()
)
logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
logger.info(f"{prefix} Consumer has been signaled to stop")
except Exception as e:
logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
logger.error(f"{prefix} Error stopping consumer: {e}")
# Wait for active executions to complete
# 2. Wait for in-flight turns to finish naturally
if self.active_tasks:
logger.info(
f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
f"{prefix} Waiting for {len(self.active_tasks)} active tasks "
f"to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
)
start_time = time.monotonic()
@@ -138,38 +159,42 @@ class CoPilotExecutor(AppProcess):
if not self.active_tasks:
break
# Refresh cluster locks periodically
current_time = time.monotonic()
if current_time - last_refresh >= lock_refresh_interval:
now = time.monotonic()
if now - last_refresh >= lock_refresh_interval:
for lock in list(self._task_locks.values()):
try:
lock.refresh()
except Exception as e:
logger.warning(
f"[cleanup {pid}] Failed to refresh lock: {e}"
)
last_refresh = current_time
logger.warning(f"{prefix} Failed to refresh lock: {e}")
last_refresh = now
logger.info(
f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
f"{prefix} {len(self.active_tasks)} tasks still active, waiting..."
)
time.sleep(10.0)
# Stop message consumers
if self.active_tasks:
logger.warning(
f"{prefix} {len(self.active_tasks)} tasks still running after "
f"{GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s — process exit will "
f"abandon them; RabbitMQ redelivery handles the message."
)
# 3. Stop message consumer threads
if self._run_thread:
self._stop_message_consumers(
self._run_thread, self.run_client, "[cleanup][run]"
self._run_thread, self.run_client, f"{prefix} [run]"
)
if self._cancel_thread:
self._stop_message_consumers(
self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
self._cancel_thread, self.cancel_client, f"{prefix} [cancel]"
)
# Clean up worker threads (closes per-loop workspace storage sessions)
# 4. Worker cleanup + executor shutdown
if self._executor:
from .processor import cleanup_worker
logger.info(f"[cleanup {pid}] Cleaning up workers...")
logger.info(f"{prefix} Cleaning up workers...")
futures = []
for _ in range(self._executor._max_workers):
futures.append(self._executor.submit(cleanup_worker))
@@ -177,22 +202,20 @@ class CoPilotExecutor(AppProcess):
try:
f.result(timeout=10)
except Exception as e:
logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
logger.warning(f"{prefix} Worker cleanup error: {e}")
logger.info(f"[cleanup {pid}] Shutting down executor...")
logger.info(f"{prefix} Shutting down executor...")
self._executor.shutdown(wait=False)
# Release any remaining locks
# 5. Release any cluster locks still held
for session_id, lock in list(self._task_locks.items()):
try:
lock.release()
logger.info(f"[cleanup {pid}] Released lock for {session_id}")
logger.info(f"{prefix} Released lock for {session_id}")
except Exception as e:
logger.error(
f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
)
logger.error(f"{prefix} Failed to release lock for {session_id}: {e}")
logger.info(f"[cleanup {pid}] Graceful shutdown completed")
logger.info(f"{prefix} Graceful shutdown completed")
# ============ RabbitMQ Consumer Methods ============ #
@@ -387,13 +410,12 @@ class CoPilotExecutor(AppProcess):
# Execute the task
try:
self._task_locks[session_id] = cluster_lock
logger.info(
f"Acquired cluster lock for {session_id}, "
f"executor_id={self.executor_id}"
)
self._task_locks[session_id] = cluster_lock
cancel_event = threading.Event()
future = self.executor.submit(
execute_copilot_turn, entry, cancel_event, cluster_lock
@@ -425,7 +447,6 @@ class CoPilotExecutor(AppProcess):
error_msg = str(e) or type(e).__name__
logger.exception(f"Error in run completion callback: {error_msg}")
finally:
# Release the cluster lock
if session_id in self._task_locks:
logger.info(f"Releasing cluster lock for {session_id}")
self._task_locks[session_id].release()

View File

@@ -5,6 +5,7 @@ in a thread-local context, following the graph executor pattern.
"""
import asyncio
import concurrent.futures
import logging
import os
import subprocess
@@ -30,6 +31,87 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata
logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")
SHUTDOWN_ERROR_MESSAGE = (
"Copilot executor shut down before this turn finished. Please retry."
)
# Max time execute() blocks after calling future.cancel() / when draining a
# soon-to-be-cancelled future. Gives _execute_async's own finally a chance to
# publish the accurate terminal state over the Redis CAS; long enough to let
# an in-flight Redis call settle, short enough that shutdown doesn't stall.
_CANCEL_GRACE_SECONDS = 5.0
# Max time the sync safety net itself spends on a single Redis CAS. Without
# this bound the whole point of ``sync_fail_close_session`` is defeated —
# ``mark_session_completed`` would hang on the same broken Redis that caused
# the original failure. On timeout we give up silently; worst case the
# session stays ``running`` until the stale-session watchdog reaps it, but
# at least the pool worker thread isn't blocked forever.
_FAIL_CLOSE_REDIS_TIMEOUT = 10.0
# Module-level symbol preserved for backward-compat with callers that import
# ``sync_fail_close_session``; the real implementation now lives on
# ``CoPilotProcessor`` so it can reuse ``self.execution_loop`` (same
# pattern as ``backend.executor.manager``'s ``node_execution_loop`` bridge
# at :meth:`ExecutionProcessor.on_graph_execution`).
def sync_fail_close_session(
session_id: str,
log: "CoPilotLogMetadata | TruncatedLogger",
execution_loop: asyncio.AbstractEventLoop,
) -> None:
"""Synchronously mark *session_id* as failed from the pool worker thread.
Submits the CAS coroutine to the long-lived *execution_loop* via
``run_coroutine_threadsafe`` — the same shape agent-executor uses at
:meth:`backend.executor.manager.ExecutionProcessor.on_graph_execution`
to reach its ``node_execution_loop`` from the pool worker. Reusing the
persistent loop means:
* no fresh TCP connection per turn (the ``@thread_cached``
``AsyncRedis`` on the execution thread stays bound to the same loop
and is reused across every turn);
* no loop-teardown overhead;
* no ``clear_cache()`` gymnastics to dodge the "loop is closed" pitfall.
``mark_session_completed`` is an atomic CAS on ``status == "running"``,
so when the async path already wrote a terminal state the sync call is
a cheap no-op. The inner ``asyncio.wait_for`` bounds the Redis call so
a wedged Redis can't hang the safety net for the full redis-py default
TCP timeout; the outer ``.result(timeout=...)`` is a belt-and-braces
upper bound for the cross-thread wait.
"""
async def _bounded() -> None:
await asyncio.wait_for(
stream_registry.mark_session_completed(
session_id, error_message=SHUTDOWN_ERROR_MESSAGE
),
timeout=_FAIL_CLOSE_REDIS_TIMEOUT,
)
try:
future = asyncio.run_coroutine_threadsafe(_bounded(), execution_loop)
except RuntimeError as e:
# execution_loop is closed — happens if cleanup() already ran the
# per-worker teardown. Nothing we can do; let the stale-session
# watchdog reap it.
log.warning(f"sync fail-close skipped (execution_loop closed): {e}")
return
try:
future.result(timeout=_FAIL_CLOSE_REDIS_TIMEOUT + 2)
except concurrent.futures.TimeoutError:
log.warning(
f"sync fail-close timed out after {_FAIL_CLOSE_REDIS_TIMEOUT}s "
f"(session={session_id})"
)
future.cancel()
except Exception as e:
log.warning(f"sync fail-close mark_session_completed failed: {e}")
# ============ Mode Routing ============ #
@@ -252,12 +334,13 @@ class CoPilotProcessor:
):
"""Execute a CoPilot turn.
Runs the async logic in the worker's event loop and handles errors.
Args:
entry: The turn payload containing session and message info
cancel: Threading event to signal cancellation
cluster_lock: Distributed lock to prevent duplicate execution
Thin wrapper around :meth:`_execute`. The ``try/finally`` here
guarantees :func:`sync_fail_close_session` runs on every exit
path — normal completion, exception, or a wedged event loop
that escapes via :data:`_CANCEL_GRACE_SECONDS` timeout.
``mark_session_completed`` is an atomic CAS on
``status == "running"``, so when the async path already wrote a
terminal state the sync call is a cheap no-op.
"""
log = CoPilotLogMetadata(
logging.getLogger(__name__),
@@ -265,10 +348,28 @@ class CoPilotProcessor:
user_id=entry.user_id,
)
log.info("Starting execution")
start_time = time.monotonic()
try:
self._execute(entry, cancel, cluster_lock, log)
finally:
sync_fail_close_session(entry.session_id, log, self.execution_loop)
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
# Run the async execution in our event loop
def _execute(
self,
entry: CoPilotExecutionEntry,
cancel: threading.Event,
cluster_lock: ClusterLock,
log: CoPilotLogMetadata,
):
"""Submit the async turn to ``self.execution_loop`` and drive it.
Handles the sync/async boundary (cancel-event checks, cluster-lock
refresh, bounded waits) without any Redis-state cleanup logic —
that lives in :func:`sync_fail_close_session` which the outer
:meth:`execute` always invokes on exit.
"""
future = asyncio.run_coroutine_threadsafe(
self._execute_async(entry, cancel, cluster_lock, log),
self.execution_loop,
@@ -282,16 +383,27 @@ class CoPilotProcessor:
if cancel.is_set():
log.info("Cancellation requested")
future.cancel()
break
# Refresh cluster lock to maintain ownership
# Give _execute_async's own finally a short window to
# publish its accurate terminal state before the outer
# sync safety net fires.
try:
future.result(timeout=_CANCEL_GRACE_SECONDS)
except BaseException:
pass
return
cluster_lock.refresh()
if not future.cancelled():
# Get result to propagate any exceptions
future.result()
elapsed = time.monotonic() - start_time
log.info(f"Execution completed in {elapsed:.2f}s")
# Bounded timeout so a wedged event loop can't trap us here —
# on timeout we escape to execute()'s finally and the sync
# safety net fires.
try:
future.result(timeout=_CANCEL_GRACE_SECONDS)
except concurrent.futures.TimeoutError:
log.warning(
"Future did not complete within grace window; "
"falling through to sync fail-close"
)
async def _execute_async(
self,

View File

@@ -10,6 +10,8 @@ the real production helpers from ``processor.py`` so the routing logic
has meaningful coverage.
"""
import asyncio
import concurrent.futures
import logging
import threading
from unittest.mock import AsyncMock, MagicMock, patch
@@ -20,6 +22,7 @@ from backend.copilot.executor.processor import (
CoPilotProcessor,
resolve_effective_mode,
resolve_use_sdk_for_mode,
sync_fail_close_session,
)
from backend.copilot.executor.utils import CoPilotExecutionEntry, CoPilotLogMetadata
@@ -275,3 +278,221 @@ class TestExecuteAsyncAclose:
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
assert published.aclose_called is True
@pytest.fixture
def exec_loop():
"""Long-lived asyncio loop on a daemon thread — mirrors the layout
``CoPilotProcessor`` sets up (``execution_loop`` + ``execution_thread``)
so ``sync_fail_close_session`` has a real cross-thread loop to submit
into via ``run_coroutine_threadsafe``."""
loop = asyncio.new_event_loop()
thread = threading.Thread(target=loop.run_forever, daemon=True)
thread.start()
try:
yield loop
finally:
loop.call_soon_threadsafe(loop.stop)
thread.join(timeout=5)
loop.close()
class TestSyncFailCloseSession:
"""``sync_fail_close_session`` is the last-line-of-defense invoked from
``CoPilotProcessor.execute``'s ``finally``. It must call
``mark_session_completed`` via the processor's long-lived
``execution_loop`` (cross-thread submit) and must swallow Redis
failures so a transient outage doesn't propagate out of the finally."""
def test_invokes_mark_session_completed_with_shutdown_message(
self, exec_loop
) -> None:
mock_mark = AsyncMock()
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
sync_fail_close_session("sess-1", _make_log(), exec_loop)
mock_mark.assert_awaited_once()
assert mock_mark.await_args is not None
assert mock_mark.await_args.args[0] == "sess-1"
assert "shut down" in mock_mark.await_args.kwargs["error_message"].lower()
def test_swallows_redis_error(self, exec_loop) -> None:
# Raising from the mock ensures the helper catches the exception
# instead of propagating it back into execute()'s finally block.
mock_mark = AsyncMock(side_effect=RuntimeError("redis down"))
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
sync_fail_close_session("sess-2", _make_log(), exec_loop) # must not raise
mock_mark.assert_awaited_once()
def test_closed_execution_loop_skipped_cleanly(self) -> None:
"""If cleanup_worker has already stopped the execution_loop by the
time the safety net fires, ``run_coroutine_threadsafe`` raises
RuntimeError. Expected behavior: log + return without propagating."""
dead_loop = asyncio.new_event_loop()
dead_loop.close()
mock_mark = AsyncMock()
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
# Must not raise even though the loop is closed
sync_fail_close_session("sess-closed-loop", _make_log(), dead_loop)
# mark_session_completed was never scheduled because the loop was dead
mock_mark.assert_not_awaited()
def test_bounded_timeout_when_redis_hangs(self, exec_loop) -> None:
"""Scenario D: Redis unreachable — the inner ``asyncio.wait_for``
must fire and the helper must return without blocking the worker.
Simulates a wedged Redis by sleeping past the 10s fail-close budget.
The helper must return within the configured grace (+ a small
scheduler margin) and must not re-raise.
"""
import time as _time
from backend.copilot.executor.processor import _FAIL_CLOSE_REDIS_TIMEOUT
async def _hang(*_args, **_kwargs):
await asyncio.sleep(_FAIL_CLOSE_REDIS_TIMEOUT + 5)
with patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=_hang,
):
start = _time.monotonic()
sync_fail_close_session(
"sess-hang", _make_log(), exec_loop
) # must not raise
elapsed = _time.monotonic() - start
# wait_for fires at _FAIL_CLOSE_REDIS_TIMEOUT; outer future.result
# has +2s slack. If the timeout is missing/broken the helper would
# block the full sleep duration (~15s).
assert elapsed < _FAIL_CLOSE_REDIS_TIMEOUT + 4.0, (
f"sync_fail_close_session hung for {elapsed:.1f}s — bounded "
f"timeout did not fire"
)
# ---------------------------------------------------------------------------
# End-to-end execute() safety-net coverage — the PR's core invariant
# ---------------------------------------------------------------------------
class TestExecuteSafetyNet:
"""``CoPilotProcessor.execute`` must always invoke
``sync_fail_close_session`` in its ``finally`` so a session never stays
``status=running`` in Redis.
Validates the four deploy-time scenarios the PR targets:
* A — SIGTERM mid-turn: ``cancel`` event fires, ``_execute`` returns,
safety net still runs.
* B — happy path: normal completion, safety net runs (cheap CAS no-op).
* C — zombie Redis state: the async ``mark_session_completed`` in
``_execute_async`` blows up, but the outer safety net marks the
session failed anyway.
* D — covered by ``TestSyncFailCloseSession::test_bounded_timeout…``.
"""
def _attach_exec_loop(self, proc: CoPilotProcessor, loop) -> None:
"""``execute`` dispatches the safety net onto ``self.execution_loop``.
Tests don't call ``on_executor_start`` (which spawns the real
per-worker loop), so wire the shared fixture loop in directly."""
proc.execution_loop = loop
def _run_execute_in_thread(self, proc: CoPilotProcessor, cancel: threading.Event):
"""``CoPilotProcessor.execute`` expects to be called from a pool
worker thread that has *no* running event loop, so we always run
it off the main thread to preserve that invariant. Returns the
future so callers can inspect both result and exception paths."""
pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
try:
fut = pool.submit(proc.execute, _make_entry(), cancel, MagicMock())
# Block until execute() returns (or raises) so the safety net
# has run by the time we inspect mocks.
try:
fut.result(timeout=30)
except BaseException:
pass
return fut
finally:
pool.shutdown(wait=True)
def test_happy_path_invokes_safety_net(self, exec_loop) -> None:
"""Scenario B: normal completion still runs the sync safety net.
Proves the ``finally`` always fires, even when nothing went wrong —
``mark_session_completed``'s atomic CAS makes this a cheap no-op
in production."""
mock_mark = AsyncMock()
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
with patch.object(proc, "_execute"), patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
self._run_execute_in_thread(proc, threading.Event())
mock_mark.assert_awaited_once()
assert mock_mark.await_args is not None
assert mock_mark.await_args.args[0] == "sess-1"
def test_sigterm_mid_turn_invokes_safety_net(self, exec_loop) -> None:
"""Scenario A: worker raises (simulating future.cancel + grace
timeout escaping ``_execute``); ``execute`` must still reach the
safety net in its ``finally`` and mark the session failed."""
mock_mark = AsyncMock()
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
with patch.object(
proc,
"_execute",
side_effect=concurrent.futures.TimeoutError("grace expired"),
), patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=mock_mark,
):
self._run_execute_in_thread(proc, threading.Event())
mock_mark.assert_awaited_once()
def test_zombie_redis_async_path_still_marks_session_failed(
self, exec_loop
) -> None:
"""Scenario C: ``_execute_async``'s own ``mark_session_completed``
call is broken (simulating the exact async-Redis hiccup that caused
the original zombie sessions). The outer ``sync_fail_close_session``
runs on the processor's long-lived ``execution_loop`` and succeeds
where the async path failed."""
call_log: list[str] = []
async def _ok(*args, **kwargs):
call_log.append("sync-ok")
def _broken_execute(entry, cancel, cluster_lock, log):
# Simulate the async path raising because its Redis client is
# wedged (the pre-fix zombie-session scenario).
raise RuntimeError("async Redis client broken")
proc = CoPilotProcessor()
self._attach_exec_loop(proc, exec_loop)
with patch.object(proc, "_execute", side_effect=_broken_execute), patch(
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
new=_ok,
):
self._run_execute_in_thread(proc, threading.Event())
# The sync safety net must have fired despite the async path
# blowing up — this is the core guarantee of the PR.
assert call_log == [
"sync-ok"
], f"expected sync_fail_close_session to run once, got {call_log!r}"

View File

@@ -89,11 +89,16 @@ def get_session_lock_key(session_id: str) -> str:
# CoPilot operations can include extended thinking and agent generation
# which may take 30+ minutes to complete
COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour
# which may take several hours to complete. Matches the pod's
# terminationGracePeriodSeconds in the helm chart so a rolling deploy can let
# the longest legitimate turn finish. Also bounds the stale-session auto-
# complete watchdog in stream_registry (consumer_timeout + 5min buffer).
COPILOT_CONSUMER_TIMEOUT_SECONDS = 6 * 60 * 60 # 6 hours
# Graceful shutdown timeout - allow in-flight operations to complete
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes
# Graceful shutdown timeout - must match COPILOT_CONSUMER_TIMEOUT_SECONDS so
# cleanup can let the longest legitimate turn complete before the pod is
# SIGKILL'd by kubelet.
GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = COPILOT_CONSUMER_TIMEOUT_SECONDS
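# Sanity check: 6 h == 21_600 s == 21_600_000 ms, the same figure used in the
# ``rabbitmqctl set_policy`` example below, so the code-declared
# ``x-consumer-timeout`` and any pre-deploy policy stay in lockstep.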
def create_copilot_queue_config() -> RabbitMQConfig:
@@ -113,9 +118,27 @@ def create_copilot_queue_config() -> RabbitMQConfig:
durable=True,
auto_delete=False,
arguments={
# Extended consumer timeout for long-running LLM operations
# Default 30-minute timeout is insufficient for extended thinking
# and agent generation which can take 30+ minutes
# Consumer timeout matches the pod graceful-shutdown window so a
# rolling deploy never forces redelivery of a turn that the pod
# is still legitimately finishing.
#
# Deploy note: RabbitMQ (verified on 4.1.4) does NOT strictly
# compare ``x-consumer-timeout`` on queue redeclaration, so this
# value can change between deploys without triggering
# PRECONDITION_FAILED. To update the *effective* timeout on an
# already-running queue before the new code deploys (so pods
# mid-shutdown don't have their consumer cancelled at the old
# limit), apply a policy:
#
# rabbitmqctl set_policy copilot-consumer-timeout \
# "^copilot_execution_queue$" \
# '{"consumer-timeout": 21600000}' \
# --apply-to queues
#
# The policy takes effect immediately. Once the policy is set
# to match the code's value the policy is redundant for new
# pods and can be removed after a stable deploy if desired —
# but it's harmless to leave in place.
"x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
* 1000,
},

View File

@@ -0,0 +1,104 @@
"""LaunchDarkly-aware model selection for the copilot.
Each cell of the ``(mode, tier)`` matrix has a static default baked into
``ChatConfig`` (see ``copilot/config.py``) and a matching LaunchDarkly
string-valued feature flag that can override it per-user. This module
centralises the lookup so both the baseline and SDK paths agree on the
selection rule and so A/B experiments can target a single cell without
shipping a config change.
Matrix:
+----------+-------------------------------------+-------------------------------------+
| | standard | advanced |
+----------+-------------------------------------+-------------------------------------+
| fast | copilot-fast-standard-model | copilot-fast-advanced-model |
| thinking | copilot-thinking-standard-model | copilot-thinking-advanced-model |
+----------+-------------------------------------+-------------------------------------+
LD flag values are arbitrary strings (model identifiers, e.g.
``"anthropic/claude-sonnet-4-6"`` or ``"moonshotai/kimi-k2.6"``). Empty
or non-string values fall back to the config default.
"""
from __future__ import annotations
import logging
from typing import Literal
from backend.copilot.config import ChatConfig
from backend.util.feature_flag import Flag, get_feature_flag_value
logger = logging.getLogger(__name__)
ModelMode = Literal["fast", "thinking"]
ModelTier = Literal["standard", "advanced"]
_FLAG_BY_CELL: dict[tuple[ModelMode, ModelTier], Flag] = {
("fast", "standard"): Flag.COPILOT_FAST_STANDARD_MODEL,
("fast", "advanced"): Flag.COPILOT_FAST_ADVANCED_MODEL,
("thinking", "standard"): Flag.COPILOT_THINKING_STANDARD_MODEL,
("thinking", "advanced"): Flag.COPILOT_THINKING_ADVANCED_MODEL,
}
def _config_default(config: ChatConfig, mode: ModelMode, tier: ModelTier) -> str:
if mode == "fast":
return (
config.fast_advanced_model
if tier == "advanced"
else config.fast_standard_model
)
return (
config.thinking_advanced_model
if tier == "advanced"
else config.thinking_standard_model
)
async def resolve_model(
mode: ModelMode,
tier: ModelTier,
user_id: str | None,
*,
config: ChatConfig,
) -> str:
"""Return the model identifier for a ``(mode, tier)`` cell.
Consults the matching LaunchDarkly flag for *user_id* first and
falls back to the ``ChatConfig`` default on missing user, missing
flag, or non-string flag value. Passing *config* explicitly keeps
the resolver cheap to unit-test.
"""
fallback = _config_default(config, mode, tier).strip()
if not user_id:
return fallback
flag = _FLAG_BY_CELL[(mode, tier)]
try:
value = await get_feature_flag_value(flag.value, user_id, default=fallback)
except Exception:
logger.warning(
"[model_router] LD lookup failed for %s — using config default %s",
flag.value,
fallback,
exc_info=True,
)
return fallback
if isinstance(value, str) and value.strip():
return value.strip()
if value != fallback:
reason = (
"empty string"
if isinstance(value, str)
else f"non-string ({type(value).__name__})"
)
logger.warning(
"[model_router] LD flag %s returned %s — using config default %s",
flag.value,
reason,
fallback,
)
return fallback
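# Typical call-site shape (sketch; mirrors what the resolver tests exercise.
# ``config`` is assumed to be the process-wide ChatConfig instance, and a
# missing tier maps to "standard"):
#
#     model = await resolve_model("fast", tier or "standard", user_id, config=config)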

View File

@@ -0,0 +1,166 @@
"""Tests for the LD-aware model resolver."""
from unittest.mock import AsyncMock, patch
import pytest
from backend.copilot.config import ChatConfig
from backend.copilot.model_router import _FLAG_BY_CELL, _config_default, resolve_model
def _make_config() -> ChatConfig:
"""Build a config with the canonical defaults so tests read naturally."""
return ChatConfig(
fast_standard_model="anthropic/claude-sonnet-4-6",
fast_advanced_model="anthropic/claude-opus-4.7",
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
)
class TestConfigDefault:
def test_fast_standard(self):
cfg = _make_config()
assert _config_default(cfg, "fast", "standard") == cfg.fast_standard_model
def test_fast_advanced(self):
cfg = _make_config()
assert _config_default(cfg, "fast", "advanced") == cfg.fast_advanced_model
def test_thinking_standard(self):
cfg = _make_config()
assert (
_config_default(cfg, "thinking", "standard") == cfg.thinking_standard_model
)
def test_thinking_advanced(self):
cfg = _make_config()
assert (
_config_default(cfg, "thinking", "advanced") == cfg.thinking_advanced_model
)
class TestResolveModel:
@pytest.mark.asyncio
async def test_missing_user_returns_fallback(self):
"""Without user_id there's no LD context — skip the lookup entirely."""
cfg = _make_config()
with patch("backend.copilot.model_router.get_feature_flag_value") as mock_flag:
result = await resolve_model("fast", "standard", None, config=cfg)
assert result == cfg.fast_standard_model
mock_flag.assert_not_called()
@pytest.mark.asyncio
async def test_missing_user_strips_whitespace_from_fallback(self):
"""Sentry MEDIUM: the anonymous-user branch returned an unstripped
config value. If ``CHAT_*_MODEL`` env carries trailing whitespace
the downstream ``resolved == tier_default`` check in
``_resolve_sdk_model_for_request`` would diverge from the
whitespace-stripped LD side, bypassing subscription mode for
every anonymous request. Strip at the source."""
cfg = ChatConfig(
fast_standard_model="anthropic/claude-sonnet-4-6 ", # trailing ws
fast_advanced_model="anthropic/claude-opus-4.7",
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
)
result = await resolve_model("fast", "standard", None, config=cfg)
assert result == "anthropic/claude-sonnet-4-6"
@pytest.mark.asyncio
async def test_ld_string_override_wins(self):
"""LD-returned model string replaces the config default."""
cfg = _make_config()
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
):
result = await resolve_model("fast", "standard", "user-1", config=cfg)
assert result == "moonshotai/kimi-k2.6"
@pytest.mark.asyncio
async def test_whitespace_is_stripped(self):
cfg = _make_config()
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value=" xai/grok-4 "),
):
result = await resolve_model("thinking", "advanced", "user-1", config=cfg)
assert result == "xai/grok-4"
@pytest.mark.asyncio
async def test_non_string_value_falls_back_with_type_in_warning(self, caplog):
"""LD misconfigured as a boolean flag — don't try to use ``True`` as a
model name; return the config default. Warning must say
'non-string' (not 'empty string') so the LD operator knows the
flag type is wrong, not just missing a value."""
import logging
cfg = _make_config()
with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"):
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value=True),
):
result = await resolve_model("fast", "advanced", "user-1", config=cfg)
assert result == cfg.fast_advanced_model
assert any("non-string" in r.message for r in caplog.records)
@pytest.mark.asyncio
async def test_empty_string_falls_back_with_empty_in_warning(self, caplog):
"""When LD returns ``""`` the warning must say 'empty string'
not 'non-string' — so the operator doesn't chase a type bug
when the flag is simply unset to an empty value."""
import logging
cfg = _make_config()
with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"):
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(return_value=""),
):
result = await resolve_model("fast", "standard", "user-1", config=cfg)
assert result == cfg.fast_standard_model
messages = [r.message for r in caplog.records]
assert any("empty string" in m for m in messages)
assert not any("non-string" in m for m in messages)
@pytest.mark.asyncio
async def test_ld_exception_falls_back(self):
"""LD client throws (network blip, SDK init race) — serve the default
instead of failing the whole request."""
cfg = _make_config()
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(side_effect=RuntimeError("LD down")),
):
result = await resolve_model("fast", "standard", "user-1", config=cfg)
assert result == cfg.fast_standard_model
@pytest.mark.asyncio
async def test_all_four_cells_hit_distinct_flags(self):
"""Each (mode, tier) cell must route to its own flag — regression
guard against copy-paste bugs in the _FLAG_BY_CELL map."""
cfg = _make_config()
calls: list[str] = []
async def _capture(flag_key, user_id, default):
calls.append(flag_key)
return default
with patch(
"backend.copilot.model_router.get_feature_flag_value",
new=AsyncMock(side_effect=_capture),
):
await resolve_model("fast", "standard", "u", config=cfg)
await resolve_model("fast", "advanced", "u", config=cfg)
await resolve_model("thinking", "standard", "u", config=cfg)
await resolve_model("thinking", "advanced", "u", config=cfg)
assert calls == [
_FLAG_BY_CELL[("fast", "standard")].value,
_FLAG_BY_CELL[("fast", "advanced")].value,
_FLAG_BY_CELL[("thinking", "standard")].value,
_FLAG_BY_CELL[("thinking", "advanced")].value,
]
assert len(set(calls)) == 4

View File

@@ -0,0 +1,147 @@
"""Moonshot-specific pricing and cache-control helpers.
Moonshot's Kimi K2.x family is routed through OpenRouter's Anthropic-compat
shim — it speaks Anthropic's API shape but its pricing and cache behaviour
diverge from Anthropic in ways the Claude Agent SDK CLI and our baseline
cache-control gating don't handle on their own:
* **Rate card** — NOT the canonical cost source. The authoritative number
for every OpenRouter-routed turn is the reconcile task
(:mod:`openrouter_cost`), which reads ``total_cost`` directly from
``/api/v1/generation`` post-turn. This module exists purely so the
CLI's in-turn ``ResultMessage.total_cost_usd`` (which silently bills
Moonshot at Sonnet rates, ~5x the real Moonshot price because the CLI
pricing table only knows Anthropic) isn't left wildly wrong before the
reconcile fires AND so the reconcile's lookup-fail fallback records a
plausible Moonshot estimate rather than a Sonnet-rate overcharge.
Signal authority: reconcile >> this module's rate card >> CLI.
* **Cache-control** — Anthropic and Moonshot both accept the
``cache_control: {type: ephemeral}`` breakpoint on message blocks, but
our baseline path currently gates cache markers on an
``anthropic/`` / ``claude`` name match because non-Anthropic providers
(OpenAI, Grok, Gemini) 400 on the unknown field. Moonshot's
Anthropic-compat endpoint silently accepts and honours the marker —
empirically boosts cache hit rate on continuation turns — but was
caught in the non-Anthropic branch of the original gate.
:func:`moonshot_supports_cache_control` lets callers widen the gate
to include Moonshot without weakening the ``false`` answer for
OpenAI et al. (The predicate is intentionally narrow — Moonshot-only
— so callers combine it with an explicit Anthropic check at the call
site; see ``baseline/service.py::_supports_prompt_cache_markers``.)
Detection is prefix-based (``moonshotai/``). Moonshot routes every Kimi
SKU through the same Anthropic-compat surface and currently prices them
identically, so a new ``moonshotai/kimi-k3.0`` slug transparently
inherits both the rate card and the cache-control gate without editing
this file. Per-slug overrides are in :data:`_RATE_OVERRIDES_USD_PER_MTOK`
for when Moonshot eventually splits prices.
"""
from __future__ import annotations
# All Moonshot slugs share these rates as of April 2026 — Moonshot prices
# every Kimi K2.x SKU at $0.60/$2.80 per million (input/output) via
# OpenRouter. Cache-read / cache-write discounts are NOT applied here:
# OpenRouter currently exposes only a single input price per Moonshot
# endpoint; the real billed amount (with cache savings) lands via the
# reconcile path. Keep in sync with https://platform.moonshot.ai/docs/pricing.
_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK: tuple[float, float] = (0.60, 2.80)
# Per-slug overrides for when Moonshot splits pricing across SKUs. Empty
# today — every slug matching ``moonshotai/`` falls back to
# :data:`_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK`.
_RATE_OVERRIDES_USD_PER_MTOK: dict[str, tuple[float, float]] = {}
# Vendor prefix — matches any OpenRouter slug Moonshot ships. Keep as a
# module constant so the prefix check stays in exactly one place.
_MOONSHOT_PREFIX = "moonshotai/"
def is_moonshot_model(model: str | None) -> bool:
"""True when *model* is a Moonshot OpenRouter slug.
Prefix match against ``moonshotai/`` covers every Kimi SKU Moonshot
ships today (``kimi-k2``, ``kimi-k2.5``, ``kimi-k2.6``,
``kimi-k2-thinking``) plus any future SKU Moonshot publishes under
the same namespace. Used by both pricing and cache-control gating.
"""
return isinstance(model, str) and model.startswith(_MOONSHOT_PREFIX)
def rate_card_usd(model: str | None) -> tuple[float, float] | None:
"""Return (input, output) $/Mtok for *model* or None if non-Moonshot.
Looks up a per-slug override first, falling back to the shared
default for anything under ``moonshotai/``. Returns None for
non-Moonshot slugs (including ``None``) so callers can skip the
override without a preflight guard.
"""
if not is_moonshot_model(model):
return None
# ``is_moonshot_model`` narrowed ``model`` to str; dict.get is
# type-safe here despite the wider param annotation above.
assert model is not None
return _RATE_OVERRIDES_USD_PER_MTOK.get(model, _DEFAULT_MOONSHOT_RATE_USD_PER_MTOK)
def override_cost_usd(
*,
model: str | None,
sdk_reported_usd: float,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int,
cache_creation_tokens: int,
) -> float:
"""Recompute SDK turn cost from the Moonshot rate card.
Not the canonical cost source — the OpenRouter ``/generation``
reconcile (:mod:`openrouter_cost`) lands the authoritative billed
amount post-turn. This helper exists only to improve the CLI's
in-turn ``ResultMessage.total_cost_usd``:
1. So the ``cost_usd`` the client sees before the reconcile completes
isn't wildly wrong (the CLI would otherwise ship a Sonnet-rate
estimate, ~5x the real Moonshot bill).
2. So the reconcile's own lookup-fail fallback records a plausible
Moonshot estimate rather than the CLI's Sonnet number.
For Moonshot slugs we compute cost from the reported token counts;
for anything else (including Anthropic) we return the SDK number
unchanged — Anthropic slugs are priced accurately by the CLI.
Cache read / creation tokens are folded into ``prompt_tokens`` at
the full input rate because Moonshot's rate card doesn't distinguish
them at the OpenRouter surface; the reconcile has the authoritative
discount accounting for turns where Moonshot's cache engaged.
"""
if model is None:
return sdk_reported_usd
rates = rate_card_usd(model)
if rates is None:
return sdk_reported_usd
input_rate, output_rate = rates
total_prompt = prompt_tokens + cache_read_tokens + cache_creation_tokens
return (total_prompt * input_rate + completion_tokens * output_rate) / 1_000_000
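# Worked example (illustrative comment only, mirroring the numbers in the unit
# tests): a 29,564-token Kimi prompt with 78 completion tokens recomputes to
#   (29_564 * 0.60 + 78 * 2.80) / 1_000_000 ~= 0.01796 USD,
# versus the CLI's ~0.0899 USD Sonnet-rate estimate, i.e. roughly the 5x gap
# the module docstring describes.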
def moonshot_supports_cache_control(model: str | None) -> bool:
"""True when a Moonshot *model* accepts Anthropic-style ``cache_control``.
Narrow, Moonshot-specific predicate — callers that need the full
"does this route accept cache markers" answer combine this with an
Anthropic check (see ``baseline/service.py::_supports_prompt_cache_markers``).
Named ``moonshot_*`` deliberately so the call site can't mistake it
for a universal predicate that answers correctly for Anthropic
(which also supports cache_control — this function would return
False for Anthropic slugs).
Moonshot's Anthropic-compat endpoint honours the marker. Without
it Moonshot falls back to its own automatic prefix caching, which
drifts more readily between turns (internal testing saw 0/4 cache
hits across two continuation sessions). With explicit
``cache_control`` the upstream cache hit rate rises to the same
ballpark as Anthropic's ~60-95% on continuations.
"""
return is_moonshot_model(model)
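A minimal sketch of the call-site combination described in the docstrings above. The real gate lives in `baseline/service.py::_supports_prompt_cache_markers`; the helper bodies below are assumptions for illustration, not copies of that code.

```python
from backend.copilot.moonshot import moonshot_supports_cache_control


def _is_anthropic_model(model: str | None) -> bool:
    # Stand-in for the callers' own Anthropic check (assumed prefix match).
    return isinstance(model, str) and model.startswith("anthropic/")


def _supports_prompt_cache_markers(model: str | None) -> bool:
    # Anthropic keeps its existing True answer, Moonshot is widened in, and
    # every other vendor (OpenAI, Gemini, Grok, ...) still gets False, so no
    # cache_control field is attached to its message blocks.
    return _is_anthropic_model(model) or moonshot_supports_cache_control(model)
```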

View File

@@ -0,0 +1,173 @@
"""Unit tests for Moonshot pricing and cache-control helpers."""
from __future__ import annotations
import pytest
from backend.copilot.moonshot import (
is_moonshot_model,
moonshot_supports_cache_control,
override_cost_usd,
rate_card_usd,
)
class TestIsMoonshotModel:
"""Prefix detection covers every Moonshot SKU without a slug list."""
@pytest.mark.parametrize(
"model",
[
"moonshotai/kimi-k2.6",
"moonshotai/kimi-k2-thinking",
"moonshotai/kimi-k2.5",
"moonshotai/kimi-k2",
"moonshotai/kimi-k3.0", # Future SKU must match transparently.
],
)
def test_moonshot_slugs_match(self, model: str) -> None:
assert is_moonshot_model(model) is True
@pytest.mark.parametrize(
"model",
[
"anthropic/claude-sonnet-4.6",
"anthropic/claude-opus-4.7",
"openai/gpt-4o",
"google/gemini-2.5-flash",
"xai/grok-4",
"deepseek/deepseek-v3",
"", # Empty string — not Moonshot.
],
)
def test_non_moonshot_slugs_do_not_match(self, model: str) -> None:
assert is_moonshot_model(model) is False
@pytest.mark.parametrize("model", [None, 123, ["moonshotai/kimi-k2.6"]])
def test_non_string_returns_false(self, model) -> None:
# Type-robust: never raise on unexpected types; callers pass None.
assert is_moonshot_model(model) is False
class TestRateCardUsd:
"""Rate card defaults to the shared Moonshot price for every SKU."""
def test_moonshot_default_rate(self) -> None:
assert rate_card_usd("moonshotai/kimi-k2.6") == (0.60, 2.80)
def test_future_moonshot_sku_inherits_default(self) -> None:
# Verifies the prefix-based fallback — new SKUs don't need a code
# edit to get a reasonable rate card.
assert rate_card_usd("moonshotai/kimi-k3.0") == (0.60, 2.80)
def test_non_moonshot_returns_none(self) -> None:
assert rate_card_usd("anthropic/claude-sonnet-4.6") is None
assert rate_card_usd("openai/gpt-4o") is None
class TestOverrideCostUsd:
"""Rate-card override replaces the CLI's Sonnet-rate estimate for
Moonshot turns; Anthropic and unknown slugs pass through unchanged."""
def test_moonshot_recomputes_from_rate_card(self) -> None:
"""A 29.5K-prompt Kimi turn should land at ~$0.018 on the
Moonshot rate card, not the CLI's $0.09 Sonnet-rate estimate."""
recomputed = override_cost_usd(
model="moonshotai/kimi-k2.6",
sdk_reported_usd=0.089862, # What the CLI reported (Sonnet price).
prompt_tokens=29564,
completion_tokens=78,
cache_read_tokens=0,
cache_creation_tokens=0,
)
expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000
assert recomputed == pytest.approx(expected, rel=1e-9)
assert 0.017 < recomputed < 0.019 # Sanity against Moonshot's rate card.
def test_anthropic_passes_through(self) -> None:
"""Anthropic slugs are priced accurately by the CLI already —
the override returns the SDK number unchanged."""
assert (
override_cost_usd(
model="anthropic/claude-sonnet-4.6",
sdk_reported_usd=0.089862,
prompt_tokens=29564,
completion_tokens=78,
cache_read_tokens=0,
cache_creation_tokens=0,
)
== 0.089862
)
def test_unknown_non_moonshot_passes_through(self) -> None:
"""A non-Moonshot, non-Anthropic slug falls back to the SDK value
— best-effort rather than leaking a zero or a wrong rate card."""
assert (
override_cost_usd(
model="deepseek/deepseek-v3",
sdk_reported_usd=0.05,
prompt_tokens=10_000,
completion_tokens=500,
cache_read_tokens=0,
cache_creation_tokens=0,
)
== 0.05
)
def test_none_model_passes_through(self) -> None:
"""Subscription mode sets model=None — return the SDK value."""
assert (
override_cost_usd(
model=None,
sdk_reported_usd=0.07,
prompt_tokens=100,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
)
== 0.07
)
def test_cache_tokens_priced_at_input_rate(self) -> None:
"""OpenRouter's Moonshot endpoints don't expose a discounted
cached-input price — cache_read / cache_creation tokens are
priced at the full input rate. The reconcile path has the
authoritative discount for turns where Moonshot's cache engaged."""
recomputed = override_cost_usd(
model="moonshotai/kimi-k2.6",
sdk_reported_usd=0.5,
prompt_tokens=1000,
completion_tokens=0,
cache_read_tokens=5000,
cache_creation_tokens=2000,
)
expected = (1000 + 5000 + 2000) * 0.60 / 1_000_000
assert recomputed == pytest.approx(expected, rel=1e-9)
class TestSupportsCacheControl:
"""Gate for emitting ``cache_control: {type: ephemeral}`` on message
blocks. True for Moonshot (Anthropic-compat endpoint accepts it)
and False for everything else this module knows about — Anthropic
callers use their own ``_is_anthropic_model`` check which is
combined with this one into a wider gate."""
def test_moonshot_supports_cache_control(self) -> None:
assert moonshot_supports_cache_control("moonshotai/kimi-k2.6") is True
def test_future_moonshot_sku_supports_cache_control(self) -> None:
assert moonshot_supports_cache_control("moonshotai/kimi-k3.0") is True
@pytest.mark.parametrize(
"model",
[
"openai/gpt-4o",
"google/gemini-2.5-flash",
"xai/grok-4",
"deepseek/deepseek-v3",
"",
None,
],
)
def test_non_moonshot_does_not_support_cache_control(self, model) -> None:
assert moonshot_supports_cache_control(model) is False

View File

@@ -240,16 +240,15 @@ async def peek_pending_messages(session_id: str) -> list[PendingMessage]:
return messages
async def _clear_pending_messages_unsafe(session_id: str) -> None:
async def clear_pending_messages_unsafe(session_id: str) -> None:
"""Drop the session's pending buffer — **not** the normal turn cleanup.
Named ``_unsafe`` because reaching for this at turn end drops queued
follow-ups on the floor instead of running them (the bug fixed by
commit b64be73). The atomic ``LPOP`` drain at turn start is the
primary consumer; anything pushed after the drain window belongs to
the next turn by definition. Retained only as an operator/debug
escape hatch for manually clearing a stuck session and as a fixture
in the unit tests.
The ``_unsafe`` suffix warns: reaching for this at turn end drops queued
follow-ups on the floor instead of running them (the bug fixed by commit
b64be73). The atomic ``LPOP`` drain at turn start is the primary consumer;
anything pushed after the drain window belongs to the next turn by
definition. Retained only as an operator/debug escape hatch for manually
clearing a stuck session and as a fixture in the unit tests.
"""
redis = await get_redis_async()
await redis.delete(_buffer_key(session_id))

View File

@@ -16,7 +16,7 @@ from backend.copilot.pending_messages import (
MAX_PENDING_MESSAGES,
PendingMessage,
PendingMessageContext,
_clear_pending_messages_unsafe,
clear_pending_messages_unsafe,
drain_and_format_for_injection,
drain_pending_for_persist,
drain_pending_messages,
@@ -208,15 +208,15 @@ async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
await push_pending_message("sess4", PendingMessage(content="x"))
await push_pending_message("sess4", PendingMessage(content="y"))
await _clear_pending_messages_unsafe("sess4")
await clear_pending_messages_unsafe("sess4")
assert await peek_pending_count("sess4") == 0
@pytest.mark.asyncio
async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
# Clearing an already-empty buffer should not raise
await _clear_pending_messages_unsafe("sess_empty")
await _clear_pending_messages_unsafe("sess_empty")
await clear_pending_messages_unsafe("sess_empty")
await clear_pending_messages_unsafe("sess_empty")
# ── Publish hook ────────────────────────────────────────────────────

View File

@@ -52,10 +52,15 @@ is at most as permissive as the parent:
from __future__ import annotations
import re
from typing import Literal, get_args
from typing import TYPE_CHECKING, Literal, get_args
from pydantic import BaseModel, PrivateAttr
if TYPE_CHECKING:
from collections.abc import Iterable
from backend.copilot.tools import ToolGroup
# ---------------------------------------------------------------------------
# Constants — single source of truth for all accepted tool names
# ---------------------------------------------------------------------------
@@ -66,7 +71,6 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
# Platform tools (must match keys in TOOL_REGISTRY)
"add_understanding",
"ask_question",
"bash_exec",
"browser_act",
"browser_navigate",
@@ -125,9 +129,16 @@ ToolName = Literal[
# Frozen set of all valid tool names — derived from the Literal.
ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))
# SDK built-in tool names — uppercase-initial names are SDK built-ins.
# SDK built-in tool names — tools provided by the Claude Code CLI that our
# code does not implement directly. ``TodoWrite`` is DELIBERATELY excluded:
# baseline mode ships an MCP-wrapped platform version
# (``tools/todo_write.py``), while SDK mode still uses the CLI-native
# original via ``_SDK_BUILTIN_ALWAYS`` in ``sdk/tool_adapter.py`` — the
# MCP copy is filtered out there. ``Task`` remains an SDK-only built-in
# (for queue-backed context-isolation on baseline, use ``run_sub_session``
# instead).
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
n for n in ALL_TOOL_NAMES if n[0].isupper()
{"Agent", "Edit", "Glob", "Grep", "Read", "Task", "WebSearch", "Write"}
)
# Platform tool names — everything that isn't an SDK built-in.
@@ -364,13 +375,17 @@ def apply_tool_permissions(
permissions: CopilotPermissions,
*,
use_e2b: bool = False,
disabled_groups: Iterable[ToolGroup] = (),
) -> tuple[list[str], list[str]]:
"""Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`.
Takes the base allowed/disallowed lists from
:func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` /
:func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and
applies *permissions* on top.
applies *permissions* on top. Tools belonging to any *disabled_groups*
are hidden from the base allowed list — use this to gate capability
groups (e.g. ``"graphiti"`` when the memory backend is off for the
current user).
Returns:
``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the
@@ -380,13 +395,16 @@ def apply_tool_permissions(
"""
from backend.copilot.sdk.tool_adapter import (
_READ_TOOL_NAME,
BASELINE_ONLY_MCP_TOOLS,
MCP_TOOL_PREFIX,
get_copilot_tool_names,
get_sdk_disallowed_tools,
)
from backend.copilot.tools import TOOL_REGISTRY
base_allowed = get_copilot_tool_names(use_e2b=use_e2b)
base_allowed = get_copilot_tool_names(
use_e2b=use_e2b, disabled_groups=disabled_groups
)
base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
if permissions.is_empty():
@@ -420,7 +438,14 @@ def apply_tool_permissions(
# keeping only those present in the original base_allowed list.
def to_sdk_names(short: str) -> list[str]:
names: list[str] = []
if short in TOOL_REGISTRY:
if short in BASELINE_ONLY_MCP_TOOLS:
# Baseline ships MCP versions of these (Task/TodoWrite) for
# model-flexibility parity, but SDK mode uses the CLI-native
# originals. Permissions target the CLI built-in here so
# ``base_allowed`` (which excludes the MCP wrappers) still
# matches.
names.append(short)
elif short in TOOL_REGISTRY:
names.append(f"{MCP_TOOL_PREFIX}{short}")
elif short in _SDK_TO_MCP:
# Map SDK built-in file tool to its MCP equivalent.
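A hedged usage sketch of the widened signature. The import path and the `CopilotPermissions()` constructor shape are assumptions for illustration; `"graphiti"` is the capability group named in the docstring above.

```python
# Assumed import location: the module defining ToolName / apply_tool_permissions.
from backend.copilot.tool_schema import CopilotPermissions, apply_tool_permissions  # path assumed

permissions = CopilotPermissions()  # default/empty permissions object (constructor shape assumed)
allowed, extra_disallowed = apply_tool_permissions(
    permissions,
    use_e2b=False,
    disabled_groups=("graphiti",),  # hide the graphiti group when memory is off
)
# allowed: the base allowed list with the graphiti tools removed
# extra_disallowed: extra names destined for the SDK disallowed list
```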

View File

@@ -582,6 +582,11 @@ class TestApplyToolPermissions:
class TestSdkBuiltinToolNames:
def test_expected_builtins_present(self):
# ``TodoWrite`` is DELIBERATELY absent: baseline ships an MCP-wrapped
# platform version for model-flexibility parity, so it appears in
# PLATFORM_TOOL_NAMES / TOOL_REGISTRY instead. ``Task`` remains
# SDK-only — baseline uses ``run_sub_session`` for the equivalent
# context-isolation role.
expected = {
"Agent",
"Read",
@@ -591,9 +596,9 @@ class TestSdkBuiltinToolNames:
"Grep",
"Task",
"WebSearch",
"TodoWrite",
}
assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)
assert "TodoWrite" not in SDK_BUILTIN_TOOL_NAMES
def test_platform_names_match_tool_registry(self):
"""PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""

View File

@@ -145,31 +145,12 @@ When the user asks to interact with a service or API, follow this order:
**Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.
### Sub-agent tasks
- When using the Task tool, NEVER set `run_in_background` to true.
All tasks must run in the foreground.
### Delegating to another autopilot (sub-autopilot pattern)
Use the **`run_sub_session`** tool to delegate a task to a fresh
sub-AutoPilot. The sub has its own full tool set and can perform
multi-step work autonomously.
- `prompt` (required): the task description.
- `system_context` (optional): extra context prepended to the prompt.
- `sub_autopilot_session_id` (optional): continue an existing
sub-AutoPilot — pass the `sub_autopilot_session_id` returned by a
previous completed run.
- `wait_for_result` (default 60, max 300): seconds to wait inline. If
the sub isn't done by then you get `status="running"` + a
`sub_session_id` — call **`get_sub_session_result`** with that id
(wait up to 300s more per call) until it returns `completed` or
`error`. Works across turns — safe to reconnect in a later message.
Use this when a task is complex enough to benefit from a separate
autopilot context, e.g. "research X and write a report" while the
parent autopilot handles orchestration. Do NOT invoke `AutoPilotBlock`
via `run_block` — it's hidden from `run_block` by design because the
dedicated tool handles the async lifecycle correctly.
### Complex multi-step work
- Use `TodoWrite` to track the plan once the job has 3+ distinct steps.
- Delegate self-contained subtasks to `run_sub_session` to keep their
intermediate tool calls out of the parent context.
- Do NOT invoke `AutoPilotBlock` via `run_block`; use `run_sub_session`
instead.
"""
@@ -182,14 +163,18 @@ sandbox so `bash_exec` can access it for further processing.
The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.
### GitHub CLI (`gh`) and git
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before running `connect_integration(provider="github")`, which will ask the user to connect their GitHub regardless of whether it's already connected.

- If the user has connected their GitHub account, both `gh` and `git` are
pre-authenticated — use them directly without any manual login step.
`git` HTTPS operations (clone, push, pull) work automatically.
- If the token changes mid-session (e.g. user reconnects with a new token),
run `gh auth setup-git` to re-register the credential helper.
- If `gh` or `git` fails with an authentication error (e.g. "authentication
required", "could not read Username", or exit code 128), call
- **MANDATORY:** You MUST run `gh auth status` before EVER calling
`connect_integration(provider="github")`. If it shows `Logged in`,
proceed directly — no integration connection needed. Never skip this check.
- If `gh auth status` shows NOT logged in, or `gh`/`git` fails with an
authentication error (e.g. "authentication required", "could not read
Username", or exit code 128), THEN call
`connect_integration(provider="github")` to surface the GitHub credentials
setup card so the user can connect their account. Once connected, retry
the operation.

View File

@@ -1,7 +1,6 @@
"""Tests for agent generation guide — verifies clarification section."""
"""Tests for prompting helpers."""
import importlib
from pathlib import Path
from backend.copilot import prompting
@@ -31,28 +30,3 @@ class TestGetSdkSupplementStaticPlaceholder:
def test_e2b_mode_has_no_session_placeholder(self):
result = prompting.get_sdk_supplement(use_e2b=True)
assert "<session-id>" not in result
class TestAgentGenerationGuideContainsClarifySection:
"""The agent generation guide must include the clarification section."""
def test_guide_includes_clarify_section(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
assert "Before or During Building" in content
def test_guide_mentions_find_block_for_clarification(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "find_block" in clarify_section
def test_guide_mentions_ask_question_tool(self):
guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
content = guide_path.read_text(encoding="utf-8")
clarify_section = content.split("Before or During Building")[1].split(
"### Workflow"
)[0]
assert "ask_question" in clarify_section

View File

@@ -39,6 +39,7 @@ Before running the workflow below, ALWAYS decompose the goal first:
For simple goals (1-2 blocks), keep steps brief (2-3 steps).
For complex goals, use as many steps as needed.
### Workflow for Creating/Editing Agents
1. **If editing**: First narrow to the specific agent by UUID, then fetch its

View File

@@ -13,12 +13,19 @@ from backend.copilot.config import ChatConfig
def _make_config(**overrides) -> ChatConfig:
"""Create a ChatConfig with safe defaults, applying *overrides*."""
"""Create a ChatConfig with safe defaults, applying *overrides*.
SDK model fields are pinned to anthropic/* so the
``_validate_sdk_model_vendor_compatibility`` model_validator allows
construction with ``use_openrouter=False`` (the default here).
"""
defaults = {
"use_claude_code_subscription": False,
"use_openrouter": False,
"api_key": None,
"base_url": None,
"thinking_standard_model": "anthropic/claude-sonnet-4-6",
"thinking_advanced_model": "anthropic/claude-opus-4-7",
}
defaults.update(overrides)
return ChatConfig(**defaults)

View File

@@ -0,0 +1,399 @@
"""Authoritative per-turn cost for OpenRouter-routed SDK generations.
The Claude Agent SDK CLI's ``ResultMessage.total_cost_usd`` is computed
from a static Anthropic pricing table baked into the binary. For
non-Anthropic models routed through OpenRouter (e.g. Kimi K2.6) the CLI
silently falls back to Sonnet rates — empirically ~5x too high. Even
after a rate-card override the estimate is still ~37% off in practice
because OpenRouter's own tokenizer counts, reasoning-token rollup, and
dated-snapshot pricing tiers can't be reconstructed from what the SDK
exposes locally.
This module provides :func:`record_turn_cost_from_openrouter` — an
``asyncio.create_task``-able coroutine that:
1. Queries ``https://openrouter.ai/api/v1/generation?id=<gen-id>`` for
each generation ID captured during the turn.
2. Sums the authoritative ``total_cost`` across all rounds.
3. Calls :func:`persist_and_record_usage` **once** with the real number,
updating both the cost-analytics row and the rate-limit counter.
If every lookup fails (404 / timeout / parse error), the caller's
``fallback_cost_usd`` is recorded instead — keeps the rate-limit counter
populated with the best available estimate rather than leaving the turn
uncharged.
"""
from __future__ import annotations
import asyncio
import logging
import os
from pathlib import Path
from typing import TYPE_CHECKING
import httpx
from backend.copilot.token_tracking import persist_and_record_usage
from backend.util import json
if TYPE_CHECKING:
from backend.copilot.model import ChatSession
logger = logging.getLogger(__name__)
# OpenRouter docs:
# https://openrouter.ai/docs/api-reference/get-a-generation
_GENERATION_URL = "https://openrouter.ai/api/v1/generation"
# OpenRouter's generation endpoint indexes the billing row a few seconds
# after the SSE stream closes — observed ~8-12s in practice. Retry with
# progressive backoff for up to ~30s total before giving up, so the typical
# indexing window (~10s) fits inside the retry envelope. Backoff values
# in seconds summed: 0.5 + 1 + 2 + 4 + 8 + 15 = 30.5.
_MAX_RETRIES = 7
_BACKOFF_SECONDS = (0.5, 1.0, 2.0, 4.0, 8.0, 15.0)
_REQUEST_TIMEOUT = 10.0
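# Invariant sketch (illustration only, not enforced at runtime): the schedule
# above is meant to keep
#   len(_BACKOFF_SECONDS) == _MAX_RETRIES - 1   (one sleep between attempts)
#   sum(_BACKOFF_SECONDS) == 30.5               (the ~30s retry envelope)
# so attempt 1 fires immediately and attempts 2..7 follow the listed delays.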
async def _fetch_generation_cost(
client: httpx.AsyncClient,
gen_id: str,
api_key: str,
log_prefix: str,
) -> float | None:
"""Fetch the ``total_cost`` for one generation, with retries.
Retries only on transient conditions:
* HTTP 404 — row not yet indexed server-side (typical ~5-10s lag
after the SSE stream closes)
* HTTP 408 / 429 — timeout / rate limit
* HTTP 5xx — transient OpenRouter outage
* Network / ``httpx`` exceptions — transport-level retryable
Fails fast on permanent client errors (401 Unauthorized,
403 Forbidden, 400 Bad Request, etc.) since they can't recover
within the retry window and would just burn API quota.
Returns ``None`` when the endpoint reports no data, on a permanent
failure, or when every retry attempt hits a transient error.
"""
headers = {"Authorization": f"Bearer {api_key}"}
params = {"id": gen_id}
last_error: Exception | None = None
for attempt in range(_MAX_RETRIES):
if attempt > 0:
await asyncio.sleep(_BACKOFF_SECONDS[attempt - 1])
try:
resp = await client.get(
_GENERATION_URL,
params=params,
headers=headers,
timeout=_REQUEST_TIMEOUT,
)
status = resp.status_code
# Fast-fail on permanent client errors — retrying 401/403/400
# just burns API quota and delays the fallback.
if status in (400, 401, 403):
logger.warning(
"%s OpenRouter /generation permanent error %d for %s"
"not retrying (check API key / request shape)",
log_prefix,
status,
gen_id,
)
return None
# Transient retryable: 404 (indexing lag), 408 (timeout),
# 429 (rate limit), 5xx (server error).
if status == 404 or status == 408 or status == 429 or status >= 500:
last_error = RuntimeError(f"HTTP {status} on attempt {attempt + 1}")
continue
# Any other 4xx — treat as permanent.
if status >= 400:
logger.warning(
"%s OpenRouter /generation unexpected status %d for %s"
"not retrying",
log_prefix,
status,
gen_id,
)
return None
payload = resp.json().get("data")
if not isinstance(payload, dict):
logger.warning(
"%s OpenRouter /generation returned no data for %s",
log_prefix,
gen_id,
)
return None
cost = payload.get("total_cost")
if cost is None:
logger.warning(
"%s OpenRouter /generation response missing total_cost "
"for %s (keys=%s)",
log_prefix,
gen_id,
sorted(payload.keys())[:10],
)
return None
return float(cost)
except Exception as exc: # noqa: BLE001
# Network / transport errors are retryable.
last_error = exc
continue
logger.warning(
"%s OpenRouter /generation lookup failed for %s after %d attempts: %s",
log_prefix,
gen_id,
_MAX_RETRIES,
last_error,
)
return None
def _gen_ids_from_jsonl(path: Path) -> set[str]:
"""Extract ``gen-`` message IDs from every assistant entry in a
Claude CLI JSONL file.
Tolerant of malformed lines: a single bad JSON object doesn't block
the whole file. Also reads ``redacted_thinking`` / ``thinking``
entries that share an ID with their parent (via ``jq -u`` in the
CLI); deduplication against already-known IDs is left to the caller.
"""
ids: set[str] = set()
try:
with path.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
entry = json.loads(line, fallback=None)
if not isinstance(entry, dict):
continue
if entry.get("type") != "assistant":
continue
message = entry.get("message")
if not isinstance(message, dict):
continue
msg_id = message.get("id")
if isinstance(msg_id, str) and msg_id.startswith("gen-"):
ids.add(msg_id)
except (OSError, UnicodeDecodeError) as exc:
logger.debug(
"Failed to scan JSONL for gen-IDs: path=%s err=%s",
path,
exc,
)
return ids
def _discover_turn_subagent_gen_ids(
project_dir: Path,
session_id: str,
turn_start_ts: float,
known: list[str],
) -> list[str]:
"""Gen-IDs from this session's subagents created during this turn.
Main-turn LLM rounds (incl. fallback retries) arrive on the live
stream as ``AssistantMessage`` and land on ``known`` via
``message_id``. What's NOT on ``known`` is the CLI's subagent LLM
calls — chiefly auto-compaction, which spawns a fresh JSONL under
``<project_dir>/<session_id>/subagents/agent-acompact-*.jsonl``
whose gen-IDs never touch our main adapter. OpenRouter bills them
anyway, so without this sweep compaction turns under-report cost.
Scoping: ONLY the current session's subagent dir
(``<project_dir>/<session_id>/subagents/agent-*.jsonl``) and ONLY
files whose ``mtime >= turn_start_ts``. Without both guards we'd
merge prior turns' gen-IDs (main JSONL accumulates forever) and
foreign sessions' gen-IDs (the project dir contains every session
for this cwd), double-billing the user.
Also covers non-compaction subagents (Task tool etc.) when the CLI
spawns them — their live-stream visibility depends on SDK version,
so the sweep is a safety net. The dedup against ``known`` means
anything already captured live doesn't double count.
Preserves ``known`` ordering so main-turn IDs stay first; only
appends truly new IDs from the sweep.
"""
merged: list[str] = list(known)
seen = set(merged)
subagents_dir = project_dir / session_id / "subagents"
if not subagents_dir.exists():
return merged
try:
for jsonl in subagents_dir.glob("agent-*.jsonl"):
try:
if jsonl.stat().st_mtime < turn_start_ts:
continue
except OSError:
continue
for gen_id in _gen_ids_from_jsonl(jsonl):
if gen_id not in seen:
seen.add(gen_id)
merged.append(gen_id)
except OSError as exc:
logger.debug("Failed to walk subagents dir=%s: %s", subagents_dir, exc)
return merged
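# Directory layout the sweep assumes (illustration):
#
#   <project_dir>/
#     <cli_session_id>/
#       subagents/
#         agent-*.jsonl            <- swept only when mtime >= turn_start_ts
#     <another-session-id>/
#       subagents/agent-*.jsonl    <- foreign session, never swept this turn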
async def record_turn_cost_from_openrouter(
*,
session: "ChatSession",
user_id: str | None,
model: str | None,
prompt_tokens: int,
completion_tokens: int,
cache_read_tokens: int,
cache_creation_tokens: int,
generation_ids: list[str],
cli_project_dir: str | None,
cli_session_id: str | None,
turn_start_ts: float | None,
fallback_cost_usd: float | None,
api_key: str | None,
log_prefix: str,
) -> None:
"""Persist turn cost from OpenRouter's authoritative ``/generation``.
Writes a single cost-analytics row via :func:`persist_and_record_usage`
— same method used for the Anthropic-direct sync path — so the
cost-log append and rate-limit counter stay consistent. No double
counting: the caller skips its own sync persist for non-Anthropic
OpenRouter turns and defers entirely to this task.
Launched via ``asyncio.create_task`` from the stream ``finally`` block
so the ~500-2000ms ``/generation`` indexing delay doesn't add latency
to the turn. During that window the rate-limit counter is briefly
unaware of the turn's cost; back-to-back turns fired inside that
gap see a stale counter. Acceptable tradeoff — the alternative
(writing a possibly-wrong estimate synchronously) creates a
double-count when the reconcile delta arrives.
Fallback semantics: if every generation lookup fails, records
``fallback_cost_usd`` instead so the rate-limit counter isn't left
completely empty. Keeps behaviour at-worst equivalent to the
rate-card estimate that came before this task existed.
"""
if not api_key:
logger.debug(
"%s OpenRouter cost record skipped: no API key available",
log_prefix,
)
return
# Merge in any gen-IDs from CLI subagent JSONLs the live stream
# didn't surface — chiefly SDK-internal compaction, which spawns a
# summarisation LLM call under
# ``<project_dir>/<cli_session_id>/subagents/...`` that OpenRouter
# bills but doesn't emit via our main adapter. Safe no-op when no
# compaction happened (no subagent files created this turn) or the
# CLI wrote nothing there.
#
# The sweep is SESSION-scoped (``<cli_session_id>/subagents/``, not
# the whole project dir) and TURN-scoped (mtime >= turn_start_ts).
# Both guards are load-bearing: the project dir contains every
# session for this cwd, and subagent files persist across turns,
# so an unscoped sweep would re-bill prior turns and foreign
# sessions' gen-IDs.
if cli_project_dir and cli_session_id and turn_start_ts is not None:
merged_ids = _discover_turn_subagent_gen_ids(
Path(os.path.expanduser(cli_project_dir)),
cli_session_id,
turn_start_ts,
generation_ids,
)
if len(merged_ids) != len(generation_ids):
logger.info(
"%s[cost-record] discovered %d additional gen-IDs in "
"session subagents (compaction / Task) — reconcile "
"covers all",
log_prefix,
len(merged_ids) - len(generation_ids),
)
generation_ids = merged_ids
if not generation_ids:
return
try:
async with httpx.AsyncClient() as client:
tasks = [
_fetch_generation_cost(client, gen_id, api_key, log_prefix)
for gen_id in generation_ids
]
results = await asyncio.gather(*tasks, return_exceptions=False)
except Exception as exc: # noqa: BLE001
logger.warning(
"%s OpenRouter cost record failed to fetch any generation "
"(falling back to rate-card estimate): %s",
log_prefix,
exc,
)
results = []
fetched = [r for r in results if isinstance(r, (int, float))]
if fetched and len(fetched) == len(generation_ids):
real_cost: float | None = sum(fetched)
# Log real (OpenRouter billed) vs CLI rate-card estimate so an
# operator can spot divergence without querying OpenRouter by
# hand. Under-count typically means a gen-ID source we don't
# capture live (e.g. title model, background LLM calls running
# outside the main stream); over-count means the CLI's rate
# table is stale vs. OpenRouter's current pricing.
delta_pct: float | None = None
if fallback_cost_usd and fallback_cost_usd > 0:
delta_pct = (real_cost - fallback_cost_usd) / fallback_cost_usd * 100
logger.info(
"%s[cost-record] OpenRouter real=$%.6f cli_estimate=$%s "
"delta=%s (gen_ids=%d)",
log_prefix,
real_cost,
f"{fallback_cost_usd:.6f}" if fallback_cost_usd is not None else "?",
f"{delta_pct:+.1f}%" if delta_pct is not None else "n/a",
len(generation_ids),
)
else:
real_cost = fallback_cost_usd
if fetched:
# Partial success: some lookups returned a cost, others didn't.
# Trusting the partial sum would under-report; fall back to the
# estimate so rate-limit enforcement stays conservative.
logger.warning(
"%s[cost-record] OpenRouter partial lookup (%d/%d) — "
"using fallback estimate=$%s",
log_prefix,
len(fetched),
len(generation_ids),
real_cost,
)
else:
logger.warning(
"%s[cost-record] OpenRouter lookup failed for all gens — "
"using fallback estimate=$%s",
log_prefix,
real_cost,
)
try:
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=cache_read_tokens,
cache_creation_tokens=cache_creation_tokens,
log_prefix=f"{log_prefix}[cost-record]",
cost_usd=real_cost,
model=model,
provider="open_router",
)
except Exception as exc: # noqa: BLE001
logger.warning(
"%s[cost-record] failed to persist: %s",
log_prefix,
exc,
)
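A minimal sketch of the fire-and-forget launch the docstring describes. Everything other than the two `backend.copilot` imports and the keyword names is an assumption for illustration; the real call site in the SDK service wires these values from its own stream state.

```python
import asyncio

from backend.copilot.model import ChatSession
from backend.copilot.sdk.openrouter_cost import record_turn_cost_from_openrouter


def launch_cost_reconcile(
    session: ChatSession,
    user_id: str,
    model: str,
    generation_ids: list[str],
    prompt_tokens: int,
    completion_tokens: int,
    fallback_cost_usd: float,
    api_key: str,
    turn_start_ts: float,
    cli_project_dir: str | None = None,
    cli_session_id: str | None = None,
) -> "asyncio.Task[None]":
    # Must be called while the event loop is running (e.g. from the stream's
    # finally block). Fire-and-forget so the /generation indexing delay never
    # blocks the turn; the coroutine persists exactly one cost row, either the
    # real OpenRouter total or fallback_cost_usd when every lookup fails.
    return asyncio.create_task(
        record_turn_cost_from_openrouter(
            session=session,
            user_id=user_id,
            model=model,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cache_read_tokens=0,
            cache_creation_tokens=0,
            generation_ids=generation_ids,
            cli_project_dir=cli_project_dir,
            cli_session_id=cli_session_id,
            turn_start_ts=turn_start_ts,
            fallback_cost_usd=fallback_cost_usd,
            api_key=api_key,
            log_prefix="[cost-sketch]",
        )
    )
```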

View File

@@ -0,0 +1,520 @@
"""Unit tests for SDK-path OpenRouter cost recording."""
from __future__ import annotations
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import httpx
import pytest
from backend.copilot.model import ChatSession
from backend.copilot.sdk.openrouter_cost import record_turn_cost_from_openrouter
def _session() -> ChatSession:
now = datetime.now(UTC)
return ChatSession(
session_id="sess-1",
user_id="user-1",
usage=[],
started_at=now,
updated_at=now,
messages=[],
)
def _mock_generation_response(cost: float) -> dict:
return {
"data": {
"total_cost": cost,
"native_tokens_prompt": 1000,
"native_tokens_completion": 50,
"tokens_prompt": 1100,
"tokens_completion": 60,
}
}
class TestRecordTurnCostFromOpenRouter:
"""Single-write semantics: the cost + rate-limit counter is updated
exactly once per turn via this background task. The sync path at
the call site is already skipped for non-Anthropic OpenRouter turns,
so there's no double-counting path even on partial failure."""
@pytest.mark.asyncio
async def test_empty_generation_ids_no_op(self):
"""Direct-Anthropic turn produces no gen-IDs — task is a no-op."""
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get,
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="anthropic/claude-sonnet-4.6",
prompt_tokens=10,
completion_tokens=5,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=[],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=0.05,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_not_called()
mock_get.assert_not_called()
@pytest.mark.asyncio
async def test_missing_api_key_no_op(self):
"""Without an OpenRouter API key we can't query the endpoint — skip."""
with patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist:
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=10,
completion_tokens=5,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-1"],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=0.02,
api_key=None,
log_prefix="[test]",
)
mock_persist.assert_not_called()
@pytest.mark.asyncio
async def test_single_generation_records_real_cost(self):
"""Authoritative cost from OpenRouter is the value recorded — no
reliance on the fallback estimate."""
real_cost = 0.02900595
async def _get(self, url, **kwargs): # noqa: ARG001
return httpx.Response(200, json=_mock_generation_response(real_cost))
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=29669,
completion_tokens=280,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-1776842410"],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=0.01858, # rate-card estimate, deliberately wrong
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
kwargs = mock_persist.call_args.kwargs
assert kwargs["cost_usd"] == pytest.approx(real_cost, rel=1e-9)
assert kwargs["prompt_tokens"] == 29669
assert kwargs["completion_tokens"] == 280
assert kwargs["provider"] == "open_router"
assert kwargs["model"] == "moonshotai/kimi-k2.6"
@pytest.mark.asyncio
async def test_multi_round_turn_sums_costs(self):
"""Tool-use turn has N generation IDs; the real cost is the sum
of ``total_cost`` across all rounds — recorded in a single row."""
costs_by_id = {"gen-a": 0.029, "gen-b": 0.030}
async def _get(self, url, **kwargs): # noqa: ARG001
gen_id = kwargs.get("params", {}).get("id")
return httpx.Response(
200, json=_mock_generation_response(costs_by_id[gen_id])
)
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=60000,
completion_tokens=600,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-a", "gen-b"],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=0.037,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
cost = mock_persist.call_args.kwargs["cost_usd"]
assert cost == pytest.approx(sum(costs_by_id.values()), rel=1e-9)
@pytest.mark.asyncio
async def test_partial_lookup_falls_back_to_estimate(self):
"""If only some gen-IDs resolve, summing them would under-report.
Fall back to the caller's estimate and log — the rate-limit
counter stays populated with the best available number."""
fallback = 0.05
seq = iter(
[
httpx.Response(200, json=_mock_generation_response(0.03)),
httpx.Response(404, text="not found"),
httpx.Response(404, text="not found"),
httpx.Response(404, text="not found"),
httpx.Response(404, text="not found"),
]
)
async def _get(self, *args, **kwargs): # noqa: ARG001
return next(seq)
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=1000,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-a", "gen-b"],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=fallback,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
assert mock_persist.call_args.kwargs["cost_usd"] == fallback
@pytest.mark.asyncio
async def test_fast_fail_on_401_no_retries(self):
"""Permanent client errors (401/403/400) must not retry — burning
the 30s retry window on an unauthenticated request wastes API
quota and delays the fallback."""
call_count = {"n": 0}
async def _get(self, *args, **kwargs): # noqa: ARG001
call_count["n"] += 1
return httpx.Response(401, text="unauthorized")
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=1000,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-a"],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=0.02,
api_key="sk-bad",
log_prefix="[test]",
)
# Only one call — no retries.
assert call_count["n"] == 1
# Fallback was recorded (lookup failed → keep rate-limit counter live).
mock_persist.assert_called_once()
assert mock_persist.call_args.kwargs["cost_usd"] == 0.02
@pytest.mark.asyncio
async def test_retries_on_404_then_succeeds(self):
"""Indexing lag: endpoint returns 404 initially, then 200 once the
billing row is indexed. Retry budget should exhaust transient
states rather than giving up on first 404."""
seq = iter(
[
httpx.Response(404, text="not found"),
httpx.Response(200, json=_mock_generation_response(0.025)),
]
)
async def _get(self, *args, **kwargs): # noqa: ARG001
return next(seq)
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=1000,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-a"],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=0.05,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(0.025)
@pytest.mark.asyncio
async def test_complete_lookup_failure_falls_back_to_estimate(self):
"""Every lookup fails → record the estimate so the rate-limit
counter isn't left empty. At-worst parity with the pre-task
behaviour."""
fallback = 0.02
async def _get(self, *args, **kwargs): # noqa: ARG001
raise httpx.ConnectError("no network")
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=1000,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-a"],
cli_project_dir=None,
cli_session_id=None,
turn_start_ts=None,
fallback_cost_usd=fallback,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
assert mock_persist.call_args.kwargs["cost_usd"] == fallback
@pytest.mark.asyncio
async def test_compaction_subagent_gen_ids_are_swept(self, tmp_path):
"""CLI-internal compaction spawns a subagent JSONL under
``<project_dir>/<session_id>/subagents/agent-acompact-*.jsonl``
whose gen-IDs the live adapter never surfaces. When
``cli_project_dir`` + ``cli_session_id`` + ``turn_start_ts``
are supplied the reconcile walks only THIS session's subagents
and discovers the compaction IDs."""
session_id = "sess-abc"
sub_dir = tmp_path / session_id / "subagents"
sub_dir.mkdir(parents=True)
(sub_dir / "agent-acompact-xyz.jsonl").write_text(
'{"type":"assistant","message":{"id":"gen-compact-1","content":[]}}\n'
'{"type":"assistant","message":{"id":"gen-compact-2","content":[]}}\n'
)
costs_by_id = {
"gen-main-1": 0.020,
"gen-compact-1": 0.005,
"gen-compact-2": 0.003,
}
async def _get(self, *args, **kwargs): # noqa: ARG001
gen_id = kwargs.get("params", {}).get("id")
return httpx.Response(
200, json=_mock_generation_response(costs_by_id[gen_id])
)
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="anthropic/claude-opus-4.7",
prompt_tokens=1000,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-main-1"],
cli_project_dir=str(tmp_path),
cli_session_id=session_id,
turn_start_ts=0.0,
fallback_cost_usd=0.05,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(
sum(costs_by_id.values()), rel=1e-9
)
@pytest.mark.asyncio
async def test_compaction_sweep_no_subagents_is_noop(self, tmp_path):
"""No compaction happened → reconcile uses only the caller's
gen-IDs, same as when cli_project_dir is None."""
session_id = "sess-none"
(tmp_path / session_id).mkdir()
async def _get(self, *args, **kwargs): # noqa: ARG001
return httpx.Response(200, json=_mock_generation_response(0.02))
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=1000,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-main-1"],
cli_project_dir=str(tmp_path),
cli_session_id=session_id,
turn_start_ts=0.0,
fallback_cost_usd=0.05,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(0.02)
@pytest.mark.asyncio
async def test_compaction_sweep_ignores_prior_turn_and_foreign_sessions(
self, tmp_path
):
"""Scoping guards the sweep from double-billing: a stale subagent
file from a prior turn (mtime before ``turn_start_ts``) and any
subagent from a foreign session (different session_id folder)
must BOTH be skipped. Without either guard, a long-running
session with past compactions would re-bill every prior turn,
and a second chat session in the same cwd would inherit the
first session's compaction cost."""
import os
import time
this_session = "sess-current"
other_session = "sess-other"
this_subagents = tmp_path / this_session / "subagents"
this_subagents.mkdir(parents=True)
other_subagents = tmp_path / other_session / "subagents"
other_subagents.mkdir(parents=True)
# Prior-turn compaction file — same session, stale mtime.
stale_file = this_subagents / "agent-acompact-stale.jsonl"
stale_file.write_text(
'{"type":"assistant","message":{"id":"gen-stale-1","content":[]}}\n'
)
# Foreign session's compaction file.
foreign_file = other_subagents / "agent-acompact-foreign.jsonl"
foreign_file.write_text(
'{"type":"assistant","message":{"id":"gen-foreign-1","content":[]}}\n'
)
# Current-turn compaction file — fresh.
fresh_file = this_subagents / "agent-acompact-fresh.jsonl"
fresh_file.write_text(
'{"type":"assistant","message":{"id":"gen-fresh-1","content":[]}}\n'
)
# turn_start_ts lies between the stale and fresh mtimes.
past = time.time() - 3600
os.utime(stale_file, (past, past))
os.utime(foreign_file, (past, past))
turn_start_ts = time.time() - 60 # 1 min ago
fresh_now = time.time()
os.utime(fresh_file, (fresh_now, fresh_now))
costs_by_id = {
"gen-main-1": 0.010,
"gen-fresh-1": 0.004,
}
async def _get(self, *args, **kwargs): # noqa: ARG001
gen_id = kwargs.get("params", {}).get("id")
# If the sweep leaks a stale/foreign ID, the test fails here
# with a KeyError rather than silently over-billing.
assert gen_id in costs_by_id, f"sweep leaked out-of-scope gen_id {gen_id}"
return httpx.Response(
200, json=_mock_generation_response(costs_by_id[gen_id])
)
with (
patch(
"backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
new_callable=AsyncMock,
) as mock_persist,
patch("httpx.AsyncClient.get", new=_get),
):
await record_turn_cost_from_openrouter(
session=_session(),
user_id="u1",
model="moonshotai/kimi-k2.6",
prompt_tokens=1000,
completion_tokens=10,
cache_read_tokens=0,
cache_creation_tokens=0,
generation_ids=["gen-main-1"],
cli_project_dir=str(tmp_path),
cli_session_id=this_session,
turn_start_ts=turn_start_ts,
fallback_cost_usd=0.05,
api_key="sk-or-test",
log_prefix="[test]",
)
mock_persist.assert_called_once()
# Exactly the current-turn main + fresh compaction — no stale,
# no foreign.
assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(
sum(costs_by_id.values()), rel=1e-9
)

View File

@@ -10,12 +10,19 @@ from backend.copilot.constants import is_transient_api_error
def _make_config(**overrides) -> ChatConfig:
"""Create a ChatConfig with safe defaults, applying *overrides*."""
"""Create a ChatConfig with safe defaults, applying *overrides*.
SDK model fields are pinned to anthropic/* so the
``_validate_sdk_model_vendor_compatibility`` model_validator allows
construction with ``use_openrouter=False`` (the default here).
"""
defaults = {
"use_claude_code_subscription": False,
"use_openrouter": False,
"api_key": None,
"base_url": None,
"thinking_standard_model": "anthropic/claude-sonnet-4-6",
"thinking_advanced_model": "anthropic/claude-opus-4-7",
}
defaults.update(overrides)
return ChatConfig(**defaults)
@@ -39,8 +46,11 @@ class TestResolveFallbackModel:
assert _resolve_fallback_model() is None
def test_strips_provider_prefix(self):
"""OpenRouter-style 'anthropic/claude-sonnet-4-...' is stripped."""
def test_keeps_full_slug_when_openrouter_active(self):
"""OpenRouter routes by ``vendor/model`` slug — _normalize_model_name
now preserves the prefix when openrouter_active is True so non-
Anthropic vendors stay routable. Anthropic slugs are passed
through unchanged in this mode (PR #12878)."""
cfg = _make_config(
claude_agent_fallback_model="anthropic/claude-sonnet-4-20250514",
use_openrouter=True,
@@ -52,8 +62,7 @@ class TestResolveFallbackModel:
result = _resolve_fallback_model()
assert result == "claude-sonnet-4-20250514"
assert "/" not in result
assert result == "anthropic/claude-sonnet-4-20250514"
def test_dots_replaced_for_direct_anthropic(self):
"""Direct Anthropic requires hyphen-separated versions."""
@@ -714,11 +723,16 @@ class TestDoTransientBackoff:
mock_sleep.assert_called_once_with(7)
async def test_replaces_adapter_with_new_instance(self):
"""state.adapter is replaced with a new SDKResponseAdapter after yield."""
"""state.adapter is replaced with a new SDKResponseAdapter after yield,
and ``render_reasoning_in_ui`` is threaded from the SDK service config
(not hardcoded) so ``CHAT_RENDER_REASONING_IN_UI=false`` at runtime
flips the reconstruction consistently with the rest of the path."""
from unittest.mock import AsyncMock, MagicMock, patch
from backend.copilot.sdk.service import _do_transient_backoff
cfg = _make_config(render_reasoning_in_ui=False)
original_adapter = MagicMock()
state = MagicMock()
state.adapter = original_adapter
@@ -726,6 +740,7 @@ class TestDoTransientBackoff:
with (
patch("asyncio.sleep", new=AsyncMock()),
patch(f"{_SVC}.config", cfg),
patch("backend.copilot.sdk.service.SDKResponseAdapter") as mock_cls,
):
new_adapter = MagicMock()
@@ -733,7 +748,11 @@ class TestDoTransientBackoff:
async for _ in _do_transient_backoff(3, state, "msg-1", "sess-1"):
pass
mock_cls.assert_called_once_with(message_id="msg-1", session_id="sess-1")
mock_cls.assert_called_once_with(
message_id="msg-1",
session_id="sess-1",
render_reasoning_in_ui=False,
)
assert state.adapter is new_adapter
async def test_resets_usage_after_yield(self):

View File

@@ -7,12 +7,15 @@ the frontend expects.
import json
import logging
import time
import uuid
from typing import Any
from claude_agent_sdk import (
AssistantMessage,
Message,
ResultMessage,
StreamEvent,
SystemMessage,
TextBlock,
ThinkingBlock,
@@ -46,6 +49,16 @@ from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output
logger = logging.getLogger(__name__)
# Coalescing thresholds for ``thinking_delta`` events on the SDK partial
# stream — matches the baseline window (see
# ``baseline/reasoning.py::_COALESCE_MIN_CHARS``). Anthropic's extended-
# thinking channel emits ~1 event per token (~4,700 per Kimi K2.6 turn);
# a 64-char / 50 ms window halves the event rate vs 32/40 while staying
# well under the ~100 ms perceptual threshold.
_THINKING_COALESCE_MIN_CHARS = 64
_THINKING_COALESCE_MAX_INTERVAL_MS = 50.0
class SDKResponseAdapter:
"""Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
@@ -53,7 +66,13 @@ class SDKResponseAdapter:
text blocks, tool calls, and message lifecycle.
"""
def __init__(self, message_id: str | None = None, session_id: str | None = None):
def __init__(
self,
message_id: str | None = None,
session_id: str | None = None,
*,
render_reasoning_in_ui: bool = True,
):
self.message_id = message_id or str(uuid.uuid4())
self.session_id = session_id
self.text_block_id = str(uuid.uuid4())
@@ -62,6 +81,7 @@ class SDKResponseAdapter:
self.reasoning_block_id = str(uuid.uuid4())
self.has_started_reasoning = False
self.has_ended_reasoning = True
self.render_reasoning_in_ui = render_reasoning_in_ui
self.current_tool_calls: dict[str, dict[str, str]] = {}
self.resolved_tool_calls: set[str] = set()
self.step_open = False
@@ -74,6 +94,36 @@ class SDKResponseAdapter:
# case so the turn renders as cleanly complete.
self._text_since_last_tool_result = False
self._any_tool_results_seen = False
# --- Partial-message streaming state (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES)
# When ``include_partial_messages=True`` is set on
# ``ClaudeAgentOptions``, the CLI emits raw Anthropic streaming
# events (``content_block_start`` / ``content_block_delta`` /
# ``content_block_stop``) as ``StreamEvent`` messages ahead of
# each summary ``AssistantMessage``. We consume those for
# per-token wire emission and reconcile against the summary to
# catch any tail content the partial stream missed (short blocks
# the CLI emits summary-only, OpenRouter proxy quirks, encrypted
# thinking).
#
self._block_types_by_index: dict[int, str] = {}
# Running partial-stream buffers. Summary AssistantMessages can
# arrive *before* the corresponding ``content_block_stop`` event
# (the CLI flushes the summary as soon as the block is complete
# on the provider side, with the stop event following as a
# separate frame). Reconcile-by-index therefore can't rely on
# completed-block queues — instead we maintain running buffers
# of all partial output of each type, and each summary block of
# that type consumes its prefix. This also trivially handles
# Kimi K2.6's pattern of emitting each content block as its own
# summary AssistantMessage: Python list indices don't align
# with Anthropic content_block indices, but per-type order does.
self._partial_text_buffer: str = ""
self._partial_thinking_buffer: str = ""
# Coalescing buffer for ``thinking_delta`` — text_delta is
# naturally coarser so we let it through unbuffered.
self._pending_thinking_delta: str = ""
self._pending_thinking_index: int | None = None
self._last_thinking_flush_monotonic: float = 0.0
@property
def has_unresolved_tool_calls(self) -> bool:
@@ -99,6 +149,15 @@ class SDKResponseAdapter:
# produced (task_progress events were previously silent).
responses.append(StreamHeartbeat())
elif isinstance(sdk_message, StreamEvent):
# Raw Anthropic streaming events — only delivered when
# ``include_partial_messages=True`` is set on
# ``ClaudeAgentOptions`` (gated by
# ``config.sdk_include_partial_messages``). Drives per-token
# emission of text + thinking; tool_use and other structural
# events stay on the ``AssistantMessage`` path.
self._handle_stream_event(sdk_message, responses)
elif isinstance(sdk_message, AssistantMessage):
# Flush any SDK built-in tool calls that didn't get a UserMessage
# result (e.g. WebSearch, Read handled internally by the CLI).
@@ -115,18 +174,43 @@ class SDKResponseAdapter:
responses.append(StreamStartStep())
self.step_open = True
for block in sdk_message.content:
# Hoist ThinkingBlocks to the front of the iteration so the UI
# sees reasoning *before* the answer it produced — that's the
# natural reading order and the way Anthropic models emit them.
# OpenRouter passthrough providers (Moonshot/Kimi, DeepSeek)
# often place ``reasoning`` after the visible text in the
# response, which would make ``ReasoningCollapse`` render under
# the assistant message instead of above it. ToolUse and other
# block types stay in their original relative order so tool
# call sequences remain coherent.
#
# Note: when ``include_partial_messages=True`` is active the
# per-token stream already emitted reasoning + text in their
# natural on-the-wire order via ``_handle_stream_event``. The
# summary walk below falls through to ``_emit_text_tail`` /
# ``_emit_thinking_tail`` which emit only the diff, preserving
# that ordering without duplicating content.
blocks_with_idx = sorted(
enumerate(sdk_message.content),
key=lambda pair: 0 if isinstance(pair[1], ThinkingBlock) else 1,
)
for block_index, block in blocks_with_idx:
if isinstance(block, TextBlock):
if block.text:
# Reasoning and text are distinct UI parts; close
# any open reasoning block before opening text so
# the AI SDK transport doesn't merge them.
# Reasoning and text are distinct UI parts; close any
# open reasoning block before opening text so the AI
# SDK transport doesn't merge them.
tail = self._text_tail_for_summary_block(block.text)
if tail:
self._end_reasoning_if_open(responses)
self._ensure_text_started(responses)
responses.append(
StreamTextDelta(id=self.text_block_id, delta=block.text)
StreamTextDelta(id=self.text_block_id, delta=tail)
)
self._text_since_last_tool_result = True
elif block.text:
# Partial stream already emitted the full text.
self._text_since_last_tool_result = True
elif isinstance(block, ThinkingBlock):
# Stream extended_thinking content as a reasoning
@@ -142,13 +226,38 @@ class SDKResponseAdapter:
# it live, extended_thinking turns that end
# thinking-only left the UI stuck on "Thought for Xs"
# with nothing rendered until a page refresh.
if block.thinking:
#
# When ``render_reasoning_in_ui=False`` the three
# reasoning helpers below (and the append) no-op, so
# the frontend sees a text-only stream AND no
# ``ChatMessage(role='reasoning')`` row is persisted
# (the row is only created by ``_dispatch_response``
# when ``StreamReasoningStart`` arrives, which is
# suppressed here). Persistence of the thinking text
# into the SDK transcript via
# ``_format_sdk_content_blocks`` is unaffected — that
# feeds ``--resume`` continuity, not the UI.
#
# Flush any pending coalesce buffer to the wire BEFORE
# computing the tail — otherwise a summary that
# arrives between the last partial delta and the
# ``content_block_stop`` event (race: summary is
# flushed by the CLI as soon as the block is complete
# provider-side, with stop lagging as a separate
# frame) would see ``_partial_thinking_buffer``
# missing the pending prefix, and
# ``_thinking_tail_for_summary_block`` would emit the
# full block — duplicating the tail that
# ``_end_reasoning_if_open`` still drains on stop.
self._flush_pending_thinking(responses)
tail = self._thinking_tail_for_summary_block(block.thinking)
if tail:
self._end_text_if_open(responses)
self._ensure_reasoning_started(responses)
responses.append(
StreamReasoningDelta(
id=self.reasoning_block_id,
delta=block.thinking,
delta=tail,
)
)
@@ -158,7 +267,7 @@ class SDKResponseAdapter:
# Strip MCP prefix so frontend sees "find_block"
# instead of "mcp__copilot__find_block".
tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
tool_name = block.name.strip().removeprefix(MCP_TOOL_PREFIX)
responses.append(
StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
@@ -347,8 +456,12 @@ class SDKResponseAdapter:
"""Start (or restart) a reasoning block if needed.
Each ``ThinkingBlock`` the SDK emits gets its own streaming block
on the wire so the frontend can render a new ``Reasoning`` part
per LLM turn (rather than concatenating across the whole session).
so the frontend can render a new ``Reasoning`` part per LLM turn
(rather than concatenating across the whole session). Events
are emitted unconditionally — the caller filters them out of the
SSE wire when ``render_reasoning_in_ui=False`` but still feeds
them through ``_dispatch_response`` so the session transcript
keeps a ``role='reasoning'`` row.
"""
if not self.has_started_reasoning or self.has_ended_reasoning:
if self.has_ended_reasoning:
@@ -358,11 +471,238 @@ class SDKResponseAdapter:
self.has_started_reasoning = True
def _end_reasoning_if_open(self, responses: list[StreamBaseResponse]) -> None:
"""End the current reasoning block if one is open."""
"""End the current reasoning block if one is open.
Drains any buffered thinking_delta text so the tail isn't lost
when the block closes before the coalesce window elapses.
"""
if self.has_started_reasoning and not self.has_ended_reasoning:
if self._pending_thinking_delta:
responses.append(
StreamReasoningDelta(
id=self.reasoning_block_id,
delta=self._pending_thinking_delta,
)
)
self._partial_thinking_buffer += self._pending_thinking_delta
self._pending_thinking_delta = ""
self._pending_thinking_index = None
responses.append(StreamReasoningEnd(id=self.reasoning_block_id))
self.has_ended_reasoning = True
# ------------------------------------------------------------------
# Partial-message streaming (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES)
# ------------------------------------------------------------------
def _reset_partial_stream_state(self) -> None:
"""Clear per-message partial-stream state.
Anthropic's ``content_block`` indices are scoped to a single
message — when a fresh ``message_start`` event arrives (new
``AssistantMessage`` turn) the maps must reset so indices from
the previous message don't suppress genuine content in the new
one.
Also clears ``_partial_*_buffer``: multi-round turns (tool use)
emit a ``message_start`` per LLM round, and leftover prefix
content from round N would cause the summary walk in round N+1
to either match the wrong prefix (silently dropping new content)
or diverge and fall back to re-emitting the whole block.
"""
self._block_types_by_index = {}
self._partial_text_buffer = ""
self._partial_thinking_buffer = ""
self._pending_thinking_delta = ""
self._pending_thinking_index = None
def _text_tail_for_summary_block(self, full_text: str) -> str:
"""Reconcile the next summary ``TextBlock`` against the running
partial-stream buffer.
The CLI can emit the summary ``AssistantMessage`` before the
matching ``content_block_stop`` event, so we can't rely on a
queue of completed blocks. Instead we maintain
``_partial_text_buffer`` — the concatenation of every
``text_delta`` chunk that hasn't been claimed by a summary
block yet — and consume ``full_text`` as a prefix from it.
Summary blocks with no partial backing (buffer empty) emit their
full text; blocks the partial stream already covered wholly are
silent; blocks with a partial prefix plus a summary tail emit
only the tail. Kimi K2.6's pattern of emitting each content
block as its own summary ``AssistantMessage`` is handled
automatically because block order is preserved across both
streams.
"""
if not full_text:
return ""
if not self._partial_text_buffer:
return full_text
if full_text.startswith(self._partial_text_buffer):
tail = full_text[len(self._partial_text_buffer) :]
self._partial_text_buffer = ""
return tail
if self._partial_text_buffer.startswith(full_text):
# Partial already emitted this whole block plus more — the
# "more" belongs to a later summary block. Consume only the
# prefix matching this block and leave the rest buffered.
self._partial_text_buffer = self._partial_text_buffer[len(full_text) :]
return ""
logger.warning(
"SDK partial/summary text diverged "
"(partial_buf=%d chars, summary=%d chars) — emitting summary, "
"clearing partial buffer to recover",
len(self._partial_text_buffer),
len(full_text),
)
self._partial_text_buffer = ""
return full_text
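# Minimal standalone sketch of the prefix-reconcile rule above (illustrative
# only; the helper below is not part of the adapter's API). It returns the
# tail to emit plus the leftover partial buffer, making the three outcomes
# concrete:
def _sketch_reconcile(partial_buf: str, summary: str) -> tuple[str, str]:
    if not summary:
        return "", partial_buf
    if not partial_buf:
        return summary, ""  # summary-only block: emit everything
    if summary.startswith(partial_buf):
        return summary[len(partial_buf):], ""  # emit only the unseen tail
    if partial_buf.startswith(summary):
        return "", partial_buf[len(summary):]  # partial already covered it
    return summary, ""  # divergence: summary wins, buffer cleared

assert _sketch_reconcile("", "short answer") == ("short answer", "")
assert _sketch_reconcile("Hello ", "Hello world") == ("world", "")
assert _sketch_reconcile("Hello world", "Hello") == ("", " world")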
def _thinking_tail_for_summary_block(self, full_thinking: str) -> str:
"""Same as :meth:`_text_tail_for_summary_block` for reasoning."""
if not full_thinking:
return ""
if not self._partial_thinking_buffer:
return full_thinking
if full_thinking.startswith(self._partial_thinking_buffer):
tail = full_thinking[len(self._partial_thinking_buffer) :]
self._partial_thinking_buffer = ""
return tail
if self._partial_thinking_buffer.startswith(full_thinking):
self._partial_thinking_buffer = self._partial_thinking_buffer[
len(full_thinking) :
]
return ""
logger.warning(
"SDK partial/summary thinking diverged "
"(partial_buf=%d chars, summary=%d chars) — emitting summary, "
"clearing partial buffer to recover",
len(self._partial_thinking_buffer),
len(full_thinking),
)
self._partial_thinking_buffer = ""
return full_thinking
def _handle_stream_event(
self, evt: StreamEvent, responses: list[StreamBaseResponse]
) -> None:
"""Translate raw Anthropic streaming events into wire events.
Handles four event types; everything else (``message_delta``
stop reasons, ``signature_delta``, ``input_json_delta``,
``ping``, ...) is ignored because the summary ``AssistantMessage``
carries their effects.
* ``message_start`` — new message boundary, reset per-index maps
* ``content_block_start`` — open text / reasoning block on the
wire and remember the block type at that index
* ``content_block_delta`` — forward ``text_delta`` immediately
and coalesce ``thinking_delta`` (64-char / 50 ms window)
* ``content_block_stop`` — drain any buffered thinking and close
the corresponding wire block
"""
raw: dict[str, Any] = evt.event or {}
event_type = raw.get("type")
if event_type == "message_start":
self._reset_partial_stream_state()
return
if event_type == "content_block_start":
block = raw.get("content_block") or {}
index = raw.get("index")
block_type = block.get("type")
if not isinstance(index, int) or not isinstance(block_type, str):
return
self._block_types_by_index[index] = block_type
if block_type == "text":
self._end_reasoning_if_open(responses)
self._ensure_text_started(responses)
# Seed any preamble the block_start carries.
seed = block.get("text") or ""
if seed:
responses.append(StreamTextDelta(id=self.text_block_id, delta=seed))
self._partial_text_buffer += seed
self._text_since_last_tool_result = True
elif block_type == "thinking":
self._end_text_if_open(responses)
self._ensure_reasoning_started(responses)
self._last_thinking_flush_monotonic = time.monotonic()
# tool_use / server_tool_use / redacted_thinking blocks stay
# on the ``AssistantMessage`` path — the frontend widgets
# need the final ``input`` payload which only arrives in the
# summary.
return
if event_type == "content_block_delta":
index = raw.get("index")
if not isinstance(index, int):
return
delta = raw.get("delta") or {}
delta_type = delta.get("type")
if delta_type == "text_delta":
chunk = delta.get("text") or ""
if not chunk:
return
self._ensure_text_started(responses)
responses.append(StreamTextDelta(id=self.text_block_id, delta=chunk))
self._partial_text_buffer += chunk
self._text_since_last_tool_result = True
elif delta_type == "thinking_delta":
chunk = delta.get("thinking") or ""
if not chunk:
return
self._ensure_reasoning_started(responses)
# Flush the coalesce buffer if the index changed — shouldn't
# happen in practice but guard against interleaved indices.
if (
self._pending_thinking_index is not None
and self._pending_thinking_index != index
):
self._flush_pending_thinking(responses)
self._pending_thinking_delta += chunk
self._pending_thinking_index = index
now = time.monotonic()
elapsed_ms = (now - self._last_thinking_flush_monotonic) * 1000.0
if (
len(self._pending_thinking_delta) >= _THINKING_COALESCE_MIN_CHARS
or elapsed_ms >= _THINKING_COALESCE_MAX_INTERVAL_MS
):
self._flush_pending_thinking(responses)
self._last_thinking_flush_monotonic = now
# Other delta types (``signature_delta``, ``input_json_delta``)
# are CLI / tool-dispatch plumbing — not surfaced on the wire.
return
if event_type == "content_block_stop":
index = raw.get("index")
if not isinstance(index, int):
return
block_type = self._block_types_by_index.pop(index, None)
if block_type == "text":
self._end_text_if_open(responses)
elif block_type == "thinking":
self._end_reasoning_if_open(responses)
return
def _flush_pending_thinking(self, responses: list[StreamBaseResponse]) -> None:
"""Drain the coalesce buffer into a ``StreamReasoningDelta``.
Separate from ``_end_reasoning_if_open`` because the coalesce
window can flush mid-block (threshold hit) without closing the
reasoning block.
"""
if not self._pending_thinking_delta:
return
responses.append(
StreamReasoningDelta(
id=self.reasoning_block_id,
delta=self._pending_thinking_delta,
)
)
self._partial_thinking_buffer += self._pending_thinking_delta
self._pending_thinking_delta = ""
self._pending_thinking_index = None
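# Minimal sketch of the coalesce decision above, with the thresholds written
# as plain values (64 chars / 50 ms, the window named in the
# _handle_stream_event docstring). Illustrative only; the real adapter also
# drains the buffer when the reasoning block closes.
def _sketch_should_flush(pending: str, elapsed_ms: float) -> bool:
    return len(pending) >= 64 or elapsed_ms >= 50.0

assert _sketch_should_flush("x" * 80, 5.0)    # char threshold trips
assert _sketch_should_flush("tiny", 60.0)     # interval threshold trips
assert not _sketch_should_flush("tiny", 5.0)  # stays buffered until block stop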
def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
"""Emit outputs for tool calls that didn't receive a UserMessage result.

View File

@@ -6,6 +6,7 @@ import pytest
from claude_agent_sdk import (
AssistantMessage,
ResultMessage,
StreamEvent,
SystemMessage,
TextBlock,
ThinkingBlock,
@@ -21,6 +22,7 @@ from backend.copilot.response_model import (
StreamFinishStep,
StreamHeartbeat,
StreamReasoningDelta,
StreamReasoningEnd,
StreamStart,
StreamStartStep,
StreamTextDelta,
@@ -136,6 +138,29 @@ def test_tool_use_emits_input_start_and_available():
assert results[2].input == {"q": "x"}
def test_tool_use_strips_whitespace_in_tool_name():
adapter = _adapter()
msg = AssistantMessage(
content=[
ToolUseBlock(
id="tool-1",
name=f" {MCP_TOOL_PREFIX}find_block",
input={},
)
],
model="test",
)
results = adapter.convert_message(msg)
tool_events = [
r
for r in results
if isinstance(r, (StreamToolInputStart, StreamToolInputAvailable))
]
assert tool_events, "expected tool input events"
for event in tool_events:
assert event.toolName == "find_block"
def test_text_then_tool_ends_text_block():
adapter = _adapter()
text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
@@ -298,6 +323,36 @@ def test_text_after_thinking_closes_reasoning_and_opens_text():
assert re_idx < ts_idx
def test_thinking_after_text_in_same_message_renders_reasoning_first():
"""Kimi K2.6 (and other non-Anthropic OpenRouter providers) place
``reasoning`` AFTER the visible text in the response, so the SDK
builds an ``AssistantMessage`` with content = [TextBlock, ThinkingBlock].
Without reordering, the UI would show the answer first and the
reasoning panel below it — the opposite of the natural reading
order Anthropic models produce. response_adapter must hoist
ThinkingBlocks to the front so ``reasoning-start/delta/end`` events
hit the SSE stream BEFORE the ``text-*`` events."""
adapter = _adapter()
msg = AssistantMessage(
content=[
TextBlock(text="63"),
ThinkingBlock(thinking="7 times 9 is 63", signature=""),
],
model="test",
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
# ReasoningStart must land before TextStart in the emitted stream
assert "StreamReasoningStart" in types
assert "StreamTextStart" in types
assert types.index("StreamReasoningStart") < types.index("StreamTextStart")
# ReasoningDelta payload is intact
assert any(
isinstance(r, StreamReasoningDelta) and r.delta == "7 times 9 is 63"
for r in results
)
def test_tool_use_after_thinking_closes_reasoning():
"""Opening a tool also closes an open reasoning block."""
adapter = _adapter()
@@ -331,6 +386,69 @@ def test_empty_thinking_block_is_ignored():
assert [type(r).__name__ for r in results] == ["StreamStartStep"]
def test_render_reasoning_in_ui_false_still_emits_adapter_events():
"""With the persist/render decoupling the adapter is flag-agnostic:
it always emits ``StreamReasoning*`` so the session transcript keeps a
durable reasoning record. Wire-level suppression when
``render_reasoning_in_ui=False`` happens at the SDK service yield
boundary, not here — see
``backend/copilot/sdk/service.py::_filter_reasoning_events``.
"""
adapter = SDKResponseAdapter(
message_id="m",
session_id="s",
render_reasoning_in_ui=False,
)
msg = AssistantMessage(
content=[ThinkingBlock(thinking="plan", signature="sig")],
model="test",
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamReasoningStart" in types
assert "StreamReasoningDelta" in types
def test_render_reasoning_off_text_after_thinking_still_closes_reasoning():
"""Adapter still emits a ``StreamReasoningEnd`` when text follows a
thinking block — decoupled from the render flag. The service layer
drops the reasoning events at yield time; the adapter's structural
open/close pairing must not depend on the flag or downstream filters
would see orphan reasoning starts on the persisted transcript.
"""
adapter = SDKResponseAdapter(
message_id="m",
session_id="s",
render_reasoning_in_ui=False,
)
adapter.convert_message(
AssistantMessage(
content=[ThinkingBlock(thinking="warming up", signature="sig")],
model="test",
)
)
results = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="hello")], model="test")
)
types = [type(r).__name__ for r in results]
assert "StreamReasoningEnd" in types
assert "StreamTextStart" in types
assert "StreamTextDelta" in types
def test_render_reasoning_on_is_default():
"""Default is True — existing callers keep emitting reasoning events."""
adapter = SDKResponseAdapter(message_id="m", session_id="s")
msg = AssistantMessage(
content=[ThinkingBlock(thinking="plan", signature="sig")],
model="test",
)
results = adapter.convert_message(msg)
types = [type(r).__name__ for r in results]
assert "StreamReasoningStart" in types
assert "StreamReasoningDelta" in types
def test_result_success_synthesizes_fallback_text_when_final_turn_is_thinking_only():
"""If the model's last LLM call after a tool_result produced only a
ThinkingBlock (no TextBlock), the UI would hang on the tool output
@@ -992,3 +1110,318 @@ def test_end_text_if_open_no_op_after_text_already_ended():
second: list[StreamBaseResponse] = []
adapter._end_text_if_open(second)
assert second == []
# ---------------------------------------------------------------------------
# Partial-message streaming (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES)
# Covers the 10 scenarios in docs/sdk-per-token-streaming-followup.md
# ---------------------------------------------------------------------------
def _stream_event(payload: dict) -> StreamEvent:
"""Convenience constructor for a raw Anthropic StreamEvent payload."""
return StreamEvent(
uuid="stream-evt",
session_id="session-1",
parent_tool_use_id=None,
event=payload,
)
def _message_start() -> StreamEvent:
return _stream_event({"type": "message_start"})
def _text_block_start(index: int) -> StreamEvent:
return _stream_event(
{
"type": "content_block_start",
"index": index,
"content_block": {"type": "text", "text": ""},
}
)
def _text_delta(index: int, text: str) -> StreamEvent:
return _stream_event(
{
"type": "content_block_delta",
"index": index,
"delta": {"type": "text_delta", "text": text},
}
)
def _thinking_block_start(index: int) -> StreamEvent:
return _stream_event(
{
"type": "content_block_start",
"index": index,
"content_block": {"type": "thinking", "thinking": ""},
}
)
def _thinking_delta(index: int, text: str) -> StreamEvent:
return _stream_event(
{
"type": "content_block_delta",
"index": index,
"delta": {"type": "thinking_delta", "thinking": text},
}
)
def _block_stop(index: int) -> StreamEvent:
return _stream_event({"type": "content_block_stop", "index": index})
def _collect_text_deltas(responses):
return "".join(r.delta for r in responses if isinstance(r, StreamTextDelta))
def _collect_reasoning_deltas(responses):
return "".join(r.delta for r in responses if isinstance(r, StreamReasoningDelta))
class TestPartialMessageStreaming:
"""Scenarios 1-10 from sdk-per-token-streaming-followup.md.
The adapter runs unconditionally in partial-aware mode — when the
flag ``CHAT_SDK_INCLUDE_PARTIAL_MESSAGES`` is off the CLI simply
never emits ``StreamEvent`` messages and the diff maps stay empty
(so the tail logic degrades to "emit the full summary content"
which is the pre-partial behaviour).
"""
def test_partial_and_summary_agree_no_duplicate(self):
"""Scenario 1: partial streams full text, summary matches exactly.
No duplicate emission, no truncation — full content reaches the
wire once."""
adapter = _adapter()
full = "Hello world"
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_text_block_start(0), responses)
for chunk in ("Hello", " ", "world"):
adapter._handle_stream_event(_text_delta(0, chunk), responses)
adapter._handle_stream_event(_block_stop(0), responses)
# Summary arrives with the same full text
summary = adapter.convert_message(
AssistantMessage(content=[TextBlock(text=full)], model="test")
)
combined = responses + summary
assert _collect_text_deltas(combined) == full
def test_partial_short_summary_long_tail_emitted(self):
"""Scenario 2 (the truncation bug we saw): partial emitted a
prefix of the real answer; summary has the full text. The
adapter must emit only the tail so no content is lost."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_text_block_start(0), responses)
for chunk in ("The user ", "seems confused. They sent"):
adapter._handle_stream_event(_text_delta(0, chunk), responses)
# Summary has the full, un-truncated content
full = (
"The user seems confused. They sent a short greeting. "
"Let me offer them concrete options."
)
summary = adapter.convert_message(
AssistantMessage(content=[TextBlock(text=full)], model="test")
)
combined = responses + summary
assert _collect_text_deltas(combined) == full
def test_partial_empty_summary_only(self):
"""Scenario 3: no partial deltas (CLI emitted the block entirely
in the summary — short blocks, proxy buffering, encrypted
content). Summary carries the full text."""
adapter = _adapter()
summary = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="short answer")], model="test")
)
assert _collect_text_deltas(summary) == "short answer"
def test_partial_long_summary_matches_no_double_emit(self):
"""Scenario 4 (most common): partial streams everything, summary
repeats the same content. No duplication on the wire."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
full = "Here is a long paragraph with several words in it."
adapter._handle_stream_event(_text_block_start(0), responses)
# Partition into chunks that *exactly* reconstruct ``full`` — a
# word-split with trailing spaces would emit more content than
# the summary carries and the reconcile would correctly flag
# divergence.
chunks = [full[:13], full[13:25], full[25:]]
assert "".join(chunks) == full
for chunk in chunks:
adapter._handle_stream_event(_text_delta(0, chunk), responses)
adapter._handle_stream_event(_block_stop(0), responses)
assert _collect_text_deltas(responses) == full
summary = adapter.convert_message(
AssistantMessage(content=[TextBlock(text=full)], model="test")
)
# Summary must not add any TextDelta since partial already covered it
assert _collect_text_deltas(summary) == ""
def test_partial_diverges_summary_wins(self):
"""Scenario 5: partial content isn't a prefix of the summary.
Defensive path emits the full summary content — content must
not silently disappear."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_text_block_start(0), responses)
adapter._handle_stream_event(_text_delta(0, "first draft"), responses)
# Summary has totally different content (proxy rewrote it)
summary = adapter.convert_message(
AssistantMessage(
content=[TextBlock(text="final polished answer")],
model="test",
)
)
# The summary's text must reach the wire even though partial
# already emitted "first draft" (which was the proxy's draft).
assert "final polished answer" in _collect_text_deltas(responses + summary)
def test_thinking_only_partial_coalesced(self):
"""Scenario 6a (thinking-only permutation): a run of
``thinking_delta`` events below the coalesce threshold flushes
at ``content_block_stop`` so the reasoning tail isn't lost."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_thinking_block_start(0), responses)
# Each chunk is well under the 64-char threshold
for chunk in ("Let ", "me ", "think"):
adapter._handle_stream_event(_thinking_delta(0, chunk), responses)
# At stop, the pending buffer drains
adapter._handle_stream_event(_block_stop(0), responses)
assert _collect_reasoning_deltas(responses) == "Let me think"
# Block closed
assert any(isinstance(r, StreamReasoningEnd) for r in responses)
def test_text_only_via_partial_and_summary(self):
"""Scenario 6b (text-only permutation): partial fills a block,
summary matches — see scenario 4 for no-double-emit assertion."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_text_block_start(0), responses)
adapter._handle_stream_event(_text_delta(0, "hi"), responses)
adapter._handle_stream_event(_block_stop(0), responses)
assert _collect_text_deltas(responses) == "hi"
def test_mixed_text_then_thinking_partial_preserves_order(self):
"""Scenario 6c (mixed, Anthropic order — reasoning then text).
When partial emits blocks in natural order and summary matches,
the wire order is identical to emission order."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
# Anthropic-shape: thinking index 0, text index 1
adapter._handle_stream_event(_thinking_block_start(0), responses)
adapter._handle_stream_event(
_thinking_delta(0, "X" * 80), responses
) # over threshold
adapter._handle_stream_event(_block_stop(0), responses)
adapter._handle_stream_event(_text_block_start(1), responses)
adapter._handle_stream_event(_text_delta(1, "answer"), responses)
adapter._handle_stream_event(_block_stop(1), responses)
types = [type(r).__name__ for r in responses]
# ReasoningStart must come before TextStart — partial streams in
# the CLI's natural order, which is also the UI's desired order.
assert types.index("StreamReasoningStart") < types.index("StreamTextStart")
def test_multi_message_turn_resets_per_index_maps(self):
"""Scenario 7: tool-use loop creates multiple AssistantMessages
per turn. Anthropic content-block indices are scoped to a single
message — ``message_start`` must reset the diff maps so the next
message's index-0 text isn't silently suppressed."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
# First message at index 0 = "first"
adapter._handle_stream_event(_message_start(), responses)
adapter._handle_stream_event(_text_block_start(0), responses)
adapter._handle_stream_event(_text_delta(0, "first"), responses)
adapter._handle_stream_event(_block_stop(0), responses)
# New message starts — index 0 now refers to a fresh block
adapter._handle_stream_event(_message_start(), responses)
adapter._handle_stream_event(_text_block_start(0), responses)
adapter._handle_stream_event(_text_delta(0, "second"), responses)
adapter._handle_stream_event(_block_stop(0), responses)
# Both texts must land on the wire
assert _collect_text_deltas(responses) == "firstsecond"
def test_empty_thinking_with_signature_emits_nothing(self):
"""Scenario 8: encrypted / empty thinking block. Partial emits
nothing, summary carries ``block.thinking == ""`` with a
signature — the adapter must not open a reasoning block."""
adapter = _adapter()
summary = adapter.convert_message(
AssistantMessage(
content=[ThinkingBlock(thinking="", signature="sig")],
model="test",
)
)
# No reasoning events should be emitted for empty thinking
reasoning_events = [
r
for r in summary
if isinstance(r, StreamReasoningDelta)
or type(r).__name__ in ("StreamReasoningStart", "StreamReasoningEnd")
]
assert reasoning_events == []
def test_thinking_tail_drains_on_block_stop(self):
"""Scenario 10: a thinking_delta chunk smaller than the 64-char
threshold arrives, then ``content_block_stop``. The tail text
must emit in a final ``StreamReasoningDelta`` BEFORE
``StreamReasoningEnd``."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_thinking_block_start(0), responses)
# One small chunk well under 64 chars
adapter._handle_stream_event(_thinking_delta(0, "tiny chunk"), responses)
# Block stop must flush the pending buffer
adapter._handle_stream_event(_block_stop(0), responses)
types = [type(r).__name__ for r in responses]
# The final ReasoningDelta must precede ReasoningEnd
rd_idx = types.index("StreamReasoningDelta")
re_idx = types.index("StreamReasoningEnd")
assert rd_idx < re_idx
assert _collect_reasoning_deltas(responses) == "tiny chunk"
def test_thinking_coalesces_on_char_threshold(self):
"""Extra: thinking_delta accumulating past 64 chars flushes
mid-block without waiting for block_stop (coalesce threshold)."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_thinking_block_start(0), responses)
# One 80-char chunk trips the threshold on a single event
adapter._handle_stream_event(_thinking_delta(0, "x" * 80), responses)
# A ReasoningDelta must already have been emitted (not buffered
# until block_stop).
assert any(isinstance(r, StreamReasoningDelta) for r in responses)
# ---------------------------------------------------------------------------
# Partial/summary reconcile — summary walk must not duplicate partial content
# ---------------------------------------------------------------------------
def test_summary_walk_skips_fully_streamed_text():
"""If the partial stream delivered the entire TextBlock, the summary
walk must not emit a second ``StreamTextDelta`` for the same block."""
adapter = _adapter()
responses: list[StreamBaseResponse] = []
adapter._handle_stream_event(_text_block_start(0), responses)
adapter._handle_stream_event(_text_delta(0, "complete answer"), responses)
adapter._handle_stream_event(_block_stop(0), responses)
# Summary arrives with matching content
summary = adapter.convert_message(
AssistantMessage(content=[TextBlock(text="complete answer")], model="test")
)
# Partial path emitted exactly one StreamTextDelta
partial_deltas = [r for r in responses if isinstance(r, StreamTextDelta)]
summary_deltas = [r for r in summary if isinstance(r, StreamTextDelta)]
assert len(partial_deltas) == 1
assert summary_deltas == []

View File

@@ -27,6 +27,7 @@ from claude_agent_sdk import (
ClaudeAgentOptions,
ClaudeSDKClient,
ResultMessage,
StreamEvent,
TextBlock,
ThinkingBlock,
ToolResultBlock,
@@ -56,6 +57,11 @@ from ..constants import (
from ..session_cleanup import prune_orphan_tool_calls
from ..context import encode_cwd_for_cli, get_workspace_manager
from ..graphiti.config import is_enabled_for_user
from ..model_router import resolve_model
from ..moonshot import (
is_moonshot_model as _is_moonshot_model,
override_cost_usd as _override_cost_for_moonshot,
)
from ..model import (
ChatMessage,
ChatSession,
@@ -109,6 +115,7 @@ from ..service import (
)
from ..thinking_stripper import ThinkingStripper
from ..token_tracking import persist_and_record_usage
from ..tools import ToolGroup
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
@@ -129,6 +136,7 @@ from ..transcript import (
from ..transcript_builder import TranscriptBuilder
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env # noqa: F401 — re-export for backward compat
from .openrouter_cost import record_turn_cost_from_openrouter
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
@@ -364,6 +372,20 @@ class _RetryState:
# ``detect_gap`` picks them up as gap-fill entries instead of assuming the
# JSONL already covers them.
midturn_user_rows: int = 0
# OpenRouter generation IDs collected across all attempts of this turn.
# Populated from ``AssistantMessage.message_id`` when routed via
# OpenRouter (``gen-...`` prefix). Consumed by the finally block to
# fire ``record_turn_cost_from_openrouter`` for non-Anthropic models —
# the CLI's static-Anthropic-priced estimate is replaced with the
# authoritative ``/generation`` total_cost. Lives on ``_RetryState``
# (not per-attempt ``_StreamAccumulator``) so it survives retries.
generation_ids: list[str] = dataclass_field(default_factory=list)
# The *actually executed* model observed on ``AssistantMessage.model`` —
# differs from ``state.options.model`` (the requested primary) when
# ``_resolve_fallback_model`` swaps to a fallback mid-attempt. The
# Moonshot cost override gates on this so a Moonshot-→-Anthropic
# fallback doesn't get mis-billed at Moonshot rates, and vice versa.
observed_model: str | None = None
@dataclass
@@ -678,35 +700,51 @@ async def _iter_sdk_messages(
def _normalize_model_name(raw_model: str) -> str:
"""Normalize a model name for the current routing configuration.
Applies two transformations shared by both the primary and fallback
model resolution paths:
Two routing modes:
1. **Strip provider prefix** — OpenRouter-style names like
``"anthropic/claude-opus-4.6"`` are reduced to ``"claude-opus-4.6"``.
2. **Dot-to-hyphen conversion** — when *not* routing through OpenRouter
the direct Anthropic API requires hyphen-separated versions
(``"claude-opus-4-6"``), so dots are replaced with hyphens.
1. **OpenRouter active** — the canonical OpenRouter slug is
``"<vendor>/<model>"`` (e.g. ``"anthropic/claude-opus-4.6"``,
``"moonshotai/kimi-k2.6"``). Pass the prefixed name through
unchanged so OpenRouter can route to the correct provider. Anthropic
names happen to also resolve when stripped, but non-Anthropic vendors
(Moonshot, Google, etc.) do not — keeping the prefix is the only form
that works for every model in the catalog.
2. **Direct Anthropic** — strip the OpenRouter ``anthropic/`` prefix
and convert dots to hyphens (``"claude-opus-4.6"`` →
``"claude-opus-4-6"``) since the Anthropic Messages API rejects
both the prefix and dot-separated versions. Raises ``ValueError``
when a non-Anthropic vendor slug is paired with direct-Anthropic
mode — silently stripping ``moonshotai/`` would send ``kimi-k2.6``
to the Anthropic API and produce an opaque ``model_not_found``
error far from the misconfiguration source.
"""
if config.openrouter_active:
return raw_model
model = raw_model
if "/" in model:
model = model.split("/", 1)[1]
# OpenRouter uses dots in versions (claude-opus-4.6) but the direct
# Anthropic API requires hyphens (claude-opus-4-6). Only normalise
# when NOT routing through OpenRouter.
if not config.openrouter_active:
model = model.replace(".", "-")
return model
vendor, model = model.split("/", 1)
if vendor != "anthropic":
raise ValueError(
f"Direct-Anthropic mode (use_openrouter=False or missing "
f"OpenRouter credentials) requires an Anthropic model, got "
f"vendor={vendor!r} from model={raw_model!r}. Set "
f"CHAT_THINKING_STANDARD_MODEL/CHAT_THINKING_ADVANCED_MODEL "
f"to an anthropic/* slug, or enable OpenRouter."
)
return model.replace(".", "-")
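# Standalone sketch of the routing rule above, with the config flag passed
# explicitly (the real function reads ``config.openrouter_active``; the error
# message is shortened here):
def _sketch_normalize(raw_model: str, openrouter_active: bool) -> str:
    if openrouter_active:
        return raw_model
    model = raw_model
    if "/" in model:
        vendor, model = model.split("/", 1)
        if vendor != "anthropic":
            raise ValueError(f"non-Anthropic vendor {vendor!r} in direct-Anthropic mode")
    return model.replace(".", "-")

assert _sketch_normalize("anthropic/claude-opus-4.6", True) == "anthropic/claude-opus-4.6"
assert _sketch_normalize("moonshotai/kimi-k2.6", True) == "moonshotai/kimi-k2.6"
assert _sketch_normalize("anthropic/claude-opus-4.6", False) == "claude-opus-4-6"
assert _sketch_normalize("claude-sonnet-4-20250514", False) == "claude-sonnet-4-20250514"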
def _resolve_sdk_model() -> str | None:
"""Resolve the model name for the Claude Agent SDK CLI.
"""Resolve the SDK-CLI model name from static config (no LD lookup).
Uses `config.claude_agent_model` if set, otherwise derives from
`config.thinking_standard_model` via :func:`_normalize_model_name`.
``config.claude_agent_model`` is an explicit override that wins
unconditionally. When the Claude Code subscription is enabled and no
override is set, returns ``None`` so the CLI picks the model for the
user's subscription plan. Otherwise derives from
``config.thinking_standard_model``.
When `use_claude_code_subscription` is enabled and no explicit
`claude_agent_model` is set, returns `None` so the CLI uses the
default model for the user's subscription plan.
For per-user routing (LaunchDarkly overrides), see
:func:`_resolve_sdk_model_for_request`.
"""
if config.claude_agent_model:
return config.claude_agent_model
@@ -715,6 +753,18 @@ def _resolve_sdk_model() -> str | None:
return _normalize_model_name(config.thinking_standard_model)
async def _resolve_thinking_model_for_user(
tier: "CopilotLlmModel",
user_id: str | None,
) -> str:
"""LD-aware thinking-tier model pick for a specific user.
Consults ``copilot-thinking-{tier}-model`` and falls back to the
``ChatConfig`` default on missing user / missing flag.
"""
return await resolve_model("thinking", tier, user_id, config=config)
def _resolve_fallback_model() -> str | None:
"""Resolve the fallback model name via :func:`_normalize_model_name`.
@@ -729,37 +779,94 @@ def _resolve_fallback_model() -> str | None:
async def _resolve_sdk_model_for_request(
model: "CopilotLlmModel | None",
session_id: str,
user_id: str | None = None,
) -> str | None:
"""Resolve the SDK model string for a turn.
Priority (highest first):
1. Explicit per-request ``model`` tier from the frontend toggle.
2. Global config default (``_resolve_sdk_model()``).
Returns ``None`` when the Claude Code subscription default applies.
Rate-limit accounting no longer applies a multiplier — the real turn
cost (reported by the SDK) already reflects model-pricing differences.
1. ``config.claude_agent_model`` — unconditional override, bypasses LD.
2. LaunchDarkly ``copilot-thinking-{tier}-model`` if it serves a value
different from the config default for *user_id*. An LD-served
override wins over subscription mode so admins can route specific
users to a specific model without flipping subscription on/off.
3. ``config.use_claude_code_subscription`` on the standard tier —
returns ``None`` so the CLI picks the subscription default (this
branch fires when LD has no opinion, i.e. the value equals the
config default).
4. ``ChatConfig`` static default for the tier.
"""
if model == "advanced":
sdk_model = _normalize_model_name(config.thinking_advanced_model)
if config.claude_agent_model:
return config.claude_agent_model
tier_name: "CopilotLlmModel" = "advanced" if model == "advanced" else "standard"
# Strip at read time so a stray trailing space in ``CHAT_*_MODEL`` (a
# common ``.env`` pitfall) doesn't make the ``resolved == tier_default``
# comparison below spuriously diverge — ``resolve_model`` already strips
# the LD side, so both halves must end up whitespace-normalised to stay
# equal when they're semantically equal. Downstream ``_normalize_model_name``
# also benefits from the strip.
tier_default = (
config.thinking_advanced_model
if tier_name == "advanced"
else config.thinking_standard_model
).strip()
resolved = await _resolve_thinking_model_for_user(tier_name, user_id)
# Subscription mode on standard tier only wins when LD has no opinion
# (value == config default ⇒ admin hasn't explicitly pointed this
# user somewhere). Any LD override — even to the same value with
# stripped whitespace normalised — is an explicit admin choice that
# must be honoured. Without this, a subscription-mode deployment
# silently ignores the ``copilot-thinking-standard-model`` flag
# entirely, which defeats the point of cohort-based routing.
ld_overrides_default = resolved != tier_default
if (
not ld_overrides_default
and tier_name == "standard"
and config.use_claude_code_subscription
):
logger.info(
"[SDK] [%s] Per-request model override: advanced (%s)",
"[SDK] [%s] Subscription default (tier=standard, LD unset)",
session_id[:12] if session_id else "?",
)
return None
try:
sdk_model = _normalize_model_name(resolved)
except ValueError as exc:
# The per-user LD value didn't pass ``_normalize_model_name``'s
# vendor check (most commonly: a ``moonshotai/kimi-*`` slug on a
# direct-Anthropic deployment that has no OpenRouter route). Fail
# soft to the TIER-SPECIFIC config default — using the generic
# ``_resolve_sdk_model()`` here would pin advanced-tier requests to
# ``thinking_standard_model`` (Sonnet) whenever LD misconfigures
# the advanced cell, silently downgrading the user's chosen tier.
try:
sdk_model = _normalize_model_name(tier_default)
except ValueError:
# Config default is *also* invalid for the active routing
# mode — this is a deployment-level misconfig that the
# ``model_validator`` should catch at startup. Re-raise the
# original LD error so the issue surfaces loudly rather than
# returning something misleading.
raise exc
logger.warning(
"[SDK] [%s] LD model %r rejected for tier=%s (%s); falling "
"back to tier default %s",
session_id[:12] if session_id else "?",
resolved,
tier_name,
exc,
sdk_model,
)
return sdk_model
if model == "standard":
# Reset to config default — respects subscription mode (None = CLI default).
sdk_model = _resolve_sdk_model()
logger.info(
"[SDK] [%s] Per-request model override: standard (%s)",
session_id[:12] if session_id else "?",
sdk_model or "subscription-default",
)
return sdk_model
return _resolve_sdk_model()
logger.info(
"[SDK] [%s] Resolved model for tier=%s: %s",
session_id[:12] if session_id else "?",
tier_name,
sdk_model,
)
return sdk_model
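# Compact sketch of the override-priority rule documented above, with all
# inputs passed explicitly (the real function consults config and LaunchDarkly
# and also runs _normalize_model_name with a fail-soft fallback, both omitted
# here):
def _sketch_pick_model(
    explicit_override: str | None,
    ld_value: str,
    tier_default: str,
    tier: str,
    subscription: bool,
) -> str | None:
    if explicit_override:
        return explicit_override      # 1. config.claude_agent_model wins outright
    if ld_value != tier_default:
        return ld_value               # 2. LD override differs from the default
    if tier == "standard" and subscription:
        return None                   # 3. CLI picks the subscription default
    return tier_default               # 4. static config default for the tier

assert _sketch_pick_model(None, "anthropic/claude-sonnet-4.6",
                          "anthropic/claude-sonnet-4.6", "standard", True) is None
assert _sketch_pick_model(None, "moonshotai/kimi-k2.6",
                          "anthropic/claude-sonnet-4.6", "standard", True) == "moonshotai/kimi-k2.6"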
_MAX_TRANSIENT_BACKOFF_SECONDS = 30
@@ -823,7 +930,11 @@ async def _do_transient_backoff(
"""
yield StreamStatus(message=f"Connection interrupted, retrying in {backoff}s…")
await asyncio.sleep(backoff)
state.adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
state.adapter = SDKResponseAdapter(
message_id=message_id,
session_id=session_id,
render_reasoning_in_ui=config.render_reasoning_in_ui,
)
state.usage.reset()
@@ -2084,6 +2195,33 @@ async def _run_stream_attempt(
len(state.adapter.resolved_tool_calls),
)
# Capture OpenRouter generation IDs from each
# ``AssistantMessage.message_id`` — when routed via OpenRouter
# these are ``gen-...`` slugs we can use post-turn to query
# ``/api/v1/generation?id=`` for the authoritative per-turn
# cost and token counts (the CLI's ``total_cost_usd`` is
# computed from a static Anthropic pricing table that
# silently over-bills non-Anthropic routes). Direct-Anthropic
# turns produce ``msg_...`` IDs which the generation endpoint
# doesn't know about — harmlessly ignored at reconcile time.
if isinstance(sdk_msg, AssistantMessage):
msg_id = sdk_msg.message_id
if (
msg_id is not None
and msg_id.startswith("gen-")
and msg_id not in state.generation_ids
):
state.generation_ids.append(msg_id)
# Track the model the SDK actually used — when a fallback
# activates, this differs from ``state.options.model``.
# Consumed by the Moonshot cost-override decision so we
# don't mis-bill a fallback-Anthropic response at
# Moonshot rates (or a fallback-Moonshot at Anthropic
# rates).
observed = getattr(sdk_msg, "model", None)
if isinstance(observed, str) and observed:
state.observed_model = observed
# Log AssistantMessage API errors (e.g. invalid_request)
# so we can debug Anthropic API 400s surfaced by the CLI.
sdk_error = getattr(sdk_msg, "error", None)
@@ -2252,7 +2390,37 @@ async def _run_stream_attempt(
state.usage.completion_tokens,
)
if sdk_msg.total_cost_usd is not None:
state.usage.cost_usd = sdk_msg.total_cost_usd
# Default: trust the CLI-reported value. Accurate for
# Anthropic models (the CLI's bundled pricing table is
# Anthropic-authored), and becomes the sync-path cost
# when the reconcile is disabled or fails.
# Prefer the ACTUALLY executed model
# (``state.observed_model`` from ``AssistantMessage.model``)
# over the requested primary (``state.options.model``)
# so a fallback activation doesn't mis-route pricing.
active_model = state.observed_model or getattr(
state.options, "model", None
)
if _is_moonshot_model(active_model):
# Moonshot slug — the CLI doesn't know Moonshot's
# rate card and silently bills at Sonnet rates
# (~5x over-charge). Replace with the rate-card
# estimate so the in-stream ``cost_usd`` and the
# reconcile's lookup-fail fallback reflect
# reality. Reconcile
# (``record_turn_cost_from_openrouter``) still
# overrides this value when every gen-ID lookup
# succeeds.
state.usage.cost_usd = _override_cost_for_moonshot(
model=active_model,
sdk_reported_usd=sdk_msg.total_cost_usd,
prompt_tokens=state.usage.prompt_tokens,
completion_tokens=state.usage.completion_tokens,
cache_read_tokens=state.usage.cache_read_tokens,
cache_creation_tokens=state.usage.cache_creation_tokens,
)
else:
state.usage.cost_usd = sdk_msg.total_cost_usd
# Emit compaction end if SDK finished compacting.
# Sync TranscriptBuilder with the CLI's active context.
@@ -2374,6 +2542,18 @@ async def _run_stream_attempt(
skip_strip=response is tail_delta,
)
if dispatched is not None:
# Persistence (via _dispatch_response) always runs so the
# session transcript keeps role='reasoning' rows; the
# wire is gated so UI can suppress rendering.
if not state.adapter.render_reasoning_in_ui and isinstance(
dispatched,
(
StreamReasoningStart,
StreamReasoningDelta,
StreamReasoningEnd,
),
):
continue
yield dispatched
# Mid-turn follow-up persistence: the MCP tool wrapper drains
@@ -2434,16 +2614,36 @@ async def _run_stream_attempt(
# flush the assistant message before tool_calls are set on it
# (text and tool_use arrive as separate SDK events), the
# tool_calls update is lost — the next flush starts past it.
_msgs_since_flush += 1
#
# With ``include_partial_messages=True`` the CLI delivers
# hundreds of ``StreamEvent`` messages per turn — incrementing
# ``_msgs_since_flush`` on each one trips the threshold long
# before the assistant text is complete, saving a truncated
# prefix that subsequent deltas can never extend (append-only).
# Count only messages that produce a persisted row boundary
# (AssistantMessage, UserMessage, ResultMessage) and skip
# raw StreamEvents. Also skip when text or reasoning is
# still in-flight on the adapter: the row is live and a flush
# would lock it at its current length.
if not isinstance(sdk_msg, StreamEvent):
_msgs_since_flush += 1
now = time.monotonic()
has_pending_tools = (
acc.has_appended_assistant
and acc.accumulated_tool_calls
and not acc.has_tool_results
)
if not has_pending_tools and (
_msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
adapter = state.adapter
has_open_block = (
adapter.has_started_text and not adapter.has_ended_text
) or (adapter.has_started_reasoning and not adapter.has_ended_reasoning)
if (
not has_pending_tools
and not has_open_block
and (
_msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
)
):
try:
await asyncio.shield(upsert_chat_session(ctx.session))
@@ -2920,6 +3120,12 @@ async def stream_chat_completion_sdk(
# Defaults ensure the finally block can always reference these safely even when
# an early return (e.g. sdk_cwd error) skips their normal assignment below.
sdk_model: str | None = None
# Wall-clock timestamp captured before the CLI runs so the
# OpenRouter reconcile can filter subagent JSONLs by mtime — only
# files created during THIS turn contribute gen-IDs. Without this
# the sweep would pick up prior turns' compaction files that persist
# under ``<session_id>/subagents/``, double-billing the user.
turn_start_ts = time.time()
# Make sure there is no more code between the lock acquisition and try-block.
try:
@@ -3043,8 +3249,8 @@ async def stream_chat_completion_sdk(
mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)
# Resolve model (request tier → config default).
sdk_model = await _resolve_sdk_model_for_request(model, session_id)
# Resolve model (request tier → LD per-user override → config default).
sdk_model = await _resolve_sdk_model_for_request(model, session_id, user_id)
# Track SDK-internal compaction (PreCompact hook → start, next msg → end)
compaction = CompactionTracker()
@@ -3056,10 +3262,18 @@ async def stream_chat_completion_sdk(
on_compact=compaction.on_compact,
)
disabled_tool_groups: list[ToolGroup] = []
if not graphiti_enabled:
disabled_tool_groups.append("graphiti")
if permissions is not None:
allowed, disallowed = apply_tool_permissions(permissions, use_e2b=use_e2b)
allowed, disallowed = apply_tool_permissions(
permissions, use_e2b=use_e2b, disabled_groups=disabled_tool_groups
)
else:
allowed = get_copilot_tool_names(use_e2b=use_e2b)
allowed = get_copilot_tool_names(
use_e2b=use_e2b, disabled_groups=disabled_tool_groups
)
disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)
def _on_stderr(line: str) -> None:
@@ -3129,6 +3343,17 @@ async def stream_chat_completion_sdk(
sdk_options_kwargs["effort"] = config.claude_agent_thinking_effort
if sdk_model:
sdk_options_kwargs["model"] = sdk_model
if config.sdk_include_partial_messages:
# Opt into per-token streaming — the CLI emits raw Anthropic
# ``content_block_delta`` events as ``StreamEvent`` messages
# ahead of each summary ``AssistantMessage`` so reasoning and
# text land on the wire token-by-token (matching the baseline
# path's UX shipped in #12873). ``SDKResponseAdapter`` consumes
# the partial stream via ``_handle_stream_event`` and emits
# only the tail diff from the subsequent summary, so content
# never double-emits and a summary-only short block still
# reaches the UI.
sdk_options_kwargs["include_partial_messages"] = True
if sdk_env:
sdk_options_kwargs["env"] = sdk_env
@@ -3160,7 +3385,11 @@ async def stream_chat_completion_sdk(
options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] # dynamic kwargs
adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
adapter = SDKResponseAdapter(
message_id=message_id,
session_id=session_id,
render_reasoning_in_ui=config.render_reasoning_in_ui,
)
# Propagate user_id/session_id as OTEL context attributes so the
# langsmith tracing integration attaches them to every span. This
@@ -3494,7 +3723,9 @@ async def stream_chat_completion_sdk(
session, user_id, is_user_message, state.query_message
)
state.adapter = SDKResponseAdapter(
message_id=message_id, session_id=session_id
message_id=message_id,
session_id=session_id,
render_reasoning_in_ui=config.render_reasoning_in_ui,
)
# Reset token accumulators so a failed attempt's partial
# usage is not double-counted in the successful attempt.
@@ -3854,18 +4085,103 @@ async def stream_chat_completion_sdk(
# --- Persist token usage to session + rate-limit counters ---
# Both must live in finally so they stay consistent even when an
# exception interrupts the try block after StreamUsage was yielded.
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
model=sdk_model or config.thinking_standard_model,
provider="anthropic",
effective_model = sdk_model or config.thinking_standard_model
# ``state`` is populated lazily inside the retry loop; when the
# turn exits before the first attempt runs (e.g. very early
# validation error) it's still None, so ``generation_ids`` is
# empty by definition.
collected_gen_ids: list[str] = (
list(state.generation_ids) if state is not None else []
)
_use_openrouter_reconcile = bool(
config.openrouter_active
and config.sdk_reconcile_openrouter_cost
and collected_gen_ids
)
# CLI project dir — used by the reconcile task to sweep for
# compaction subagents' gen-IDs. ``sdk_cwd`` is the per-session
# CLI working directory; the CLI encodes it into the project-dir
# name the same way ``encode_cwd_for_cli`` does, and writes
# the main transcript + any ``subagents/`` alongside it under
# ``~/.claude/projects/<encoded>/``. Empty when sdk_cwd isn't
# set (shouldn't happen in practice for SDK turns).
cli_project_dir: str | None = None
if sdk_cwd:
cli_project_dir = os.path.join(
os.path.expanduser("~/.claude/projects"),
encode_cwd_for_cli(sdk_cwd),
)
if _use_openrouter_reconcile:
# Defer the single cost-and-rate-limit write to a background
# task that queries OpenRouter's authoritative
# ``/generation?id=`` for every round in this turn. Covers
# all vendors:
#
# * Non-Anthropic (Kimi et al): the CLI's ``total_cost_usd``
# is computed from a static Anthropic rate table that
# doesn't know the model — silently over-bills by ~5x.
# The reconcile replaces it with OpenRouter's real bill.
# * Anthropic via OpenRouter: the CLI's number matches
# Anthropic's own rates penny-for-penny in the common
# case, but the reconcile catches any rate change the
# CLI binary hasn't picked up and any OpenRouter-side
# divergence (cache-discount accounting, promo pricing).
#
# The task calls ``persist_and_record_usage`` exactly once
# per turn — same method as the sync path, so append-only
# cost-log + rate-limit counter update together. The sync
# path below is skipped entirely when the reconcile fires,
# so no double-counting. Kill-switch:
# ``CHAT_SDK_RECONCILE_OPENROUTER_COST=false``.
#
# Brief window (~0.5-2s) where the rate-limit counter is
# unaware of this turn — back-to-back turns in that window
# see a stale counter.
asyncio.create_task(
record_turn_cost_from_openrouter(
session=session,
user_id=user_id,
model=effective_model,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
generation_ids=collected_gen_ids,
cli_project_dir=cli_project_dir,
cli_session_id=session_id,
turn_start_ts=turn_start_ts,
fallback_cost_usd=turn_cost_usd,
api_key=config.api_key,
log_prefix=log_prefix,
)
)
else:
# Reconcile disabled, OpenRouter inactive, or subscription
# path (no gen-IDs). Record the SDK CLI's
# ``total_cost_usd`` synchronously: accurate for Anthropic
# (same rate card as billing); for non-Anthropic it's the
# rate-card estimate that ``_override_cost_for_non_anthropic``
# caps (still 1.5-2x off vs real OpenRouter bill, but much
# closer than the ~5x Sonnet-rate fallback).
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=turn_prompt_tokens,
completion_tokens=turn_completion_tokens,
cache_read_tokens=turn_cache_read_tokens,
cache_creation_tokens=turn_cache_creation_tokens,
log_prefix=log_prefix,
cost_usd=turn_cost_usd,
model=effective_model,
# ``provider`` labels the cost-analytics row; the cost
# value still comes from the SDK-reported number.
# Tracks the actual upstream so the row matches reality:
# OpenRouter when ``openrouter_active``, Anthropic
# otherwise.
provider=("open_router" if config.openrouter_active else "anthropic"),
)
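# Illustrative sketch of the per-generation lookup the reconcile task performs.
# Only the endpoint path and the ``total_cost`` field come from the comments
# above; the response shape, auth header, and the helper name are assumptions,
# and real error handling / retries are simplified away.
import httpx

async def _sketch_fetch_generation_cost(gen_id: str, api_key: str) -> float | None:
    async with httpx.AsyncClient(timeout=10) as client:
        resp = await client.get(
            "https://openrouter.ai/api/v1/generation",
            params={"id": gen_id},
            headers={"Authorization": f"Bearer {api_key}"},
        )
    if resp.status_code != 200:
        return None  # lookup failed; caller keeps the CLI-reported estimate
    data = resp.json().get("data") or {}
    cost = data.get("total_cost")
    return float(cost) if cost is not None else None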
# --- Persist session messages ---
# This MUST run in finally to persist messages even when the generator

View File

@@ -366,38 +366,84 @@ class TestNormalizeModelName:
The per-request model toggle calls _normalize_model_name with either
``config.thinking_advanced_model`` (for 'advanced') or
``config.thinking_standard_model`` (for 'standard'). These tests verify
the OpenRouter/provider-prefix stripping that keeps the value compatible
with the Claude CLI.
the OpenRouter/direct-Anthropic split: OpenRouter routes by full
``vendor/model`` slug, while direct-Anthropic strips the prefix and
converts dots to hyphens.
"""
def test_strips_anthropic_prefix(self):
@pytest.fixture
def _direct_anthropic_config(self, monkeypatch: pytest.MonkeyPatch):
"""Force ``config.openrouter_active = False`` for prefix-strip tests.
Pins the SDK model fields to anthropic/* so the new
``_validate_sdk_model_vendor_compatibility`` model_validator
permits ChatConfig construction.
"""
from backend.copilot import config as cfg_mod
cfg = cfg_mod.ChatConfig(
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
@pytest.fixture
def _openrouter_config(self, monkeypatch: pytest.MonkeyPatch):
"""Force ``config.openrouter_active = True`` for slug-preservation tests."""
from backend.copilot import config as cfg_mod
cfg = cfg_mod.ChatConfig(
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=False,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
def test_strips_anthropic_prefix(self, _direct_anthropic_config):
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
def test_strips_openai_prefix(self):
assert _normalize_model_name("openai/gpt-4o") == "gpt-4o"
def test_rejects_non_anthropic_vendor_in_direct_mode(
self, _direct_anthropic_config
):
"""Direct-Anthropic mode must fail loudly on non-Anthropic vendor
slugs — silent strip would send e.g. ``gpt-4o`` to the Anthropic
API and produce an opaque model_not_found error."""
with pytest.raises(ValueError, match="requires an Anthropic model"):
_normalize_model_name("openai/gpt-4o")
with pytest.raises(ValueError, match="requires an Anthropic model"):
_normalize_model_name("moonshotai/kimi-k2.6")
with pytest.raises(ValueError, match="requires an Anthropic model"):
_normalize_model_name("google/gemini-2.5-flash")
def test_strips_google_prefix(self):
assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash"
def test_already_normalized_unchanged(self):
def test_already_normalized_unchanged(self, _direct_anthropic_config):
assert (
_normalize_model_name("claude-sonnet-4-20250514")
== "claude-sonnet-4-20250514"
)
def test_empty_string_unchanged(self):
def test_empty_string_unchanged(self, _direct_anthropic_config):
assert _normalize_model_name("") == ""
def test_opus_model_roundtrip(self):
"""The exact string used for the 'opus' toggle strips correctly."""
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
def test_opus_model_dot_to_hyphen(self, _direct_anthropic_config):
"""Direct-Anthropic mode: dots in versions become hyphens."""
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6"
def test_sonnet_openrouter_model(self):
"""Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly."""
def test_openrouter_keeps_anthropic_slug(self, _openrouter_config):
"""OpenRouter routes by full slug — keep prefix and dots intact."""
assert (
_normalize_model_name("anthropic/claude-sonnet-4-6") == "claude-sonnet-4-6"
_normalize_model_name("anthropic/claude-sonnet-4.6")
== "anthropic/claude-sonnet-4.6"
)
def test_openrouter_keeps_kimi_slug(self, _openrouter_config):
"""Non-Anthropic vendors (Moonshot) require the prefix to route."""
assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6"
# ---------------------------------------------------------------------------
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)

View File

@@ -4,6 +4,7 @@ import asyncio
import base64
import os
from dataclasses import dataclass
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -17,6 +18,7 @@ from .service import (
_normalize_model_name,
_prepare_file_attachments,
_resolve_sdk_model,
_resolve_sdk_model_for_request,
_safe_close_sdk_client,
)
@@ -355,11 +357,16 @@ class TestNormalizeModelName:
api_key=None,
base_url=None,
use_claude_code_subscription=False,
# Pin SDK slugs to anthropic/* so the new
# _validate_sdk_model_vendor_compatibility allows construction.
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6"
def test_dots_preserved_for_openrouter(self, monkeypatch, _clean_config_env):
def test_openrouter_keeps_full_slug(self, monkeypatch, _clean_config_env):
"""OpenRouter routes by ``vendor/model`` slug — keep prefix and dots."""
from backend.copilot import config as cfg_mod
cfg = cfg_mod.ChatConfig(
@@ -369,7 +376,11 @@ class TestNormalizeModelName:
use_claude_code_subscription=False,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4.6"
assert (
_normalize_model_name("anthropic/claude-opus-4.6")
== "anthropic/claude-opus-4.6"
)
assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6"
def test_no_prefix_no_dots(self, monkeypatch, _clean_config_env):
from backend.copilot import config as cfg_mod
@@ -379,6 +390,8 @@ class TestNormalizeModelName:
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert (
@@ -390,8 +403,9 @@ class TestNormalizeModelName:
class TestResolveSdkModel:
"""Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""
def test_openrouter_active_keeps_dots(self, monkeypatch, _clean_config_env):
"""When OpenRouter is fully active, model keeps dot-separated version."""
def test_openrouter_active_keeps_full_slug(self, monkeypatch, _clean_config_env):
"""When OpenRouter is fully active, the canonical vendor/model slug
is preserved so OpenRouter can route to the correct provider."""
from backend.copilot import config as cfg_mod
cfg = cfg_mod.ChatConfig(
@@ -403,7 +417,23 @@ class TestResolveSdkModel:
use_claude_code_subscription=False,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert _resolve_sdk_model() == "claude-opus-4.6"
assert _resolve_sdk_model() == "anthropic/claude-opus-4.6"
def test_openrouter_active_kimi_slug(self, monkeypatch, _clean_config_env):
"""Non-Anthropic models (Kimi via Moonshot) require the prefix to
survive OpenRouter routing — strip would leave an unroutable slug."""
from backend.copilot import config as cfg_mod
cfg = cfg_mod.ChatConfig(
thinking_standard_model="moonshotai/kimi-k2.6",
claude_agent_model=None,
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=False,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
assert _resolve_sdk_model() == "moonshotai/kimi-k2.6"
def test_openrouter_disabled_normalizes_to_hyphens(
self, monkeypatch, _clean_config_env
@@ -488,6 +518,213 @@ class TestResolveSdkModel:
assert _resolve_sdk_model() == "claude-opus-4-6"
class TestResolveSdkModelForRequestLdFallback:
"""``_resolve_sdk_model_for_request`` must fail soft when the LD value
can't be normalised for the active routing mode — flagged as MAJOR by
CodeRabbit + HIGH by Sentry when it was a hard ValueError."""
@pytest.mark.asyncio
async def test_direct_anthropic_mode_rejects_kimi_ld_value_and_falls_back(
self, monkeypatch, _clean_config_env
):
"""LD serves ``moonshotai/kimi-k2.6`` but we're on direct-Anthropic
(no OpenRouter key). ``_normalize_model_name`` raises; the
resolver must log + return the config-default path instead of
500-ing the turn."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
claude_agent_model=None,
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
):
resolved = await _resolve_sdk_model_for_request(
model="standard", session_id="sess-abc", user_id="user-1"
)
# Fallback == tier-specific config default (thinking_standard_model
# normalised to hyphen-form for direct-Anthropic mode).
assert resolved == "claude-sonnet-4-6"
@pytest.mark.asyncio
async def test_openrouter_mode_accepts_ld_kimi_value(
self, monkeypatch, _clean_config_env
):
"""On OpenRouter the Kimi slug is legitimate — no fallback,
value returned as-is."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
claude_agent_model=None,
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=False,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
):
resolved = await _resolve_sdk_model_for_request(
model="standard", session_id="sess-abc", user_id="user-1"
)
assert resolved == "moonshotai/kimi-k2.6"
@pytest.mark.asyncio
async def test_advanced_tier_fallback_uses_advanced_default_not_standard(
self, monkeypatch, _clean_config_env
):
"""An LD-rejected ADVANCED slug must fall back to the advanced
config default (Opus) — not the standard default (Sonnet).
Using ``_resolve_sdk_model()`` as the fallback silently
downgraded the user's chosen tier. Flagged MAJOR by CodeRabbit
+ HIGH by Sentry on the first fail-soft commit."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
claude_agent_model=None,
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=False,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
):
resolved = await _resolve_sdk_model_for_request(
model="advanced", session_id="sess-adv", user_id="user-1"
)
# Direct-Anthropic normalises anthropic/claude-opus-4.7 → claude-opus-4-7
assert resolved == "claude-opus-4-7"
@pytest.mark.asyncio
async def test_standard_ld_override_wins_over_subscription(
self, monkeypatch, _clean_config_env
):
"""Bug reported in local test: subscription mode + LD serving Kimi
on ``copilot-thinking-standard-model`` returned ``None`` (CLI
picked subscription default Opus), silently ignoring the LD
override. An LD value different from the config default is an
explicit admin decision and must win."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
claude_agent_model=None,
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=True,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
):
resolved = await _resolve_sdk_model_for_request(
model="standard", session_id="sess-std-sub", user_id="user-1"
)
# Expect LD-served Kimi, NOT None (the old subscription-default bypass)
assert resolved == "moonshotai/kimi-k2.6"
@pytest.mark.asyncio
async def test_standard_subscription_survives_trailing_whitespace_in_env(
self, monkeypatch, _clean_config_env
):
"""``_resolve_thinking_model_for_user`` strips whitespace from the LD
side; the config tier default must be stripped too, otherwise a
stray trailing space in ``CHAT_THINKING_STANDARD_MODEL`` makes
``resolved == tier_default`` spuriously False and bypasses
subscription-default mode. Sentry HIGH on L856."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6 ", # trailing spaces
claude_agent_model=None,
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=True,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"),
):
resolved = await _resolve_sdk_model_for_request(
model="standard", session_id="sess-ws", user_id="user-1"
)
assert resolved is None, (
"LD value semantically matches the whitespace-padded config "
"default — subscription mode must still win and return None"
)
@pytest.mark.asyncio
async def test_standard_subscription_default_honoured_when_ld_matches_config(
self, monkeypatch, _clean_config_env
):
"""When LD serves the SAME value as the config default (i.e. the
flag is effectively unset / no override), subscription mode still
wins and we return ``None`` so the CLI uses the subscription
default model."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
claude_agent_model=None,
use_openrouter=False,
api_key=None,
base_url=None,
use_claude_code_subscription=True,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"),
):
resolved = await _resolve_sdk_model_for_request(
model="standard", session_id="sess-std-nop", user_id="user-1"
)
assert resolved is None
@pytest.mark.asyncio
async def test_advanced_tier_consults_ld_under_subscription(
self, monkeypatch, _clean_config_env
):
"""Subscription mode bypasses LD only on the standard tier —
the advanced tier always consults LD because the user explicitly
asked for the premium path. A subscription + advanced request
with LD-served Opus must return Opus (not ``None``)."""
cfg = cfg_mod.ChatConfig(
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4.7",
claude_agent_model=None,
use_openrouter=True,
api_key="or-key",
base_url="https://openrouter.ai/api/v1",
use_claude_code_subscription=True,
)
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
with patch(
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
new=AsyncMock(return_value="anthropic/claude-opus-4.7"),
):
resolved = await _resolve_sdk_model_for_request(
model="advanced", session_id="sess-adv-sub", user_id="user-1"
)
assert resolved == "anthropic/claude-opus-4.7"
# ---------------------------------------------------------------------------
# _is_sdk_disconnect_error — classify client disconnect cleanup errors
# ---------------------------------------------------------------------------
@@ -651,6 +888,8 @@ class TestSystemPromptPreset:
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
)
assert cfg.claude_agent_cross_user_prompt_cache is True
@@ -662,6 +901,8 @@ class TestSystemPromptPreset:
api_key=None,
base_url=None,
use_claude_code_subscription=False,
thinking_standard_model="anthropic/claude-sonnet-4-6",
thinking_advanced_model="anthropic/claude-opus-4-7",
)
assert cfg.claude_agent_cross_user_prompt_cache is False
@@ -674,3 +915,225 @@ class TestIdleTimeoutConstant:
def test_idle_timeout_is_10_min(self):
assert _IDLE_TIMEOUT_SECONDS == 10 * 60
# ---------------------------------------------------------------------------
# _RetryState.observed_model — Moonshot cost-override input
# ---------------------------------------------------------------------------
class TestRetryStateObservedModel:
"""Regression guards for the ``observed_model`` field added to
``_RetryState``. The Moonshot cost override reads this — when a
fallback model activates mid-attempt, the requested primary
(``state.options.model``) no longer matches what actually ran."""
def _make_state(self, *, options_model: str | None = "primary/model"):
"""Build a minimally-valid ``_RetryState``. All the heavy
collaborators are ``MagicMock()`` — the field we care about is
a plain Optional[str], so the surrounding scaffolding just needs
to let the dataclass instantiate."""
from .service import _RetryState, _TokenUsage
options = MagicMock()
options.model = options_model
return _RetryState(
options=options,
query_message="",
was_compacted=False,
use_resume=False,
resume_file=None,
transcript_msg_count=0,
adapter=MagicMock(),
transcript_builder=MagicMock(),
usage=_TokenUsage(),
)
def test_default_is_none(self):
state = self._make_state()
assert state.observed_model is None
def test_assigned_from_assistant_message_model(self):
"""Simulates the population path in ``_run_stream_attempt``:
``observed`` is pulled off the ``AssistantMessage.model`` attr
and assigned onto ``state.observed_model`` when it's a non-empty
string."""
state = self._make_state()
# Simulates the inline assignment the generator does on each
# AssistantMessage — a non-empty string lands on state.
assistant_like = SimpleNamespace(model="anthropic/claude-sonnet-4-6")
observed = getattr(assistant_like, "model", None)
if isinstance(observed, str) and observed:
state.observed_model = observed
assert state.observed_model == "anthropic/claude-sonnet-4-6"
def test_empty_string_model_is_not_assigned(self):
"""Guard against overwriting a real observed value with an
empty-string model (the generator's ``and observed`` check)."""
state = self._make_state()
state.observed_model = "moonshotai/kimi-k2.6" # seeded from a prior msg
assistant_like = SimpleNamespace(model="")
observed = getattr(assistant_like, "model", None)
if isinstance(observed, str) and observed:
state.observed_model = observed
assert state.observed_model == "moonshotai/kimi-k2.6"
def test_missing_model_attr_leaves_observed_untouched(self):
state = self._make_state()
state.observed_model = "moonshotai/kimi-k2.6"
# AssistantMessage may not carry ``.model`` on older SDK rels.
assistant_like = SimpleNamespace() # no ``.model`` attr
observed = getattr(assistant_like, "model", None)
if isinstance(observed, str) and observed:
state.observed_model = observed
assert state.observed_model == "moonshotai/kimi-k2.6"
# ---------------------------------------------------------------------------
# Moonshot cost-override gate — decision logic at the call site
# ---------------------------------------------------------------------------
class TestMoonshotCostOverrideGate:
"""Regression guards for the decision logic in
``_run_stream_attempt`` that picks between the CLI-reported cost
and the Moonshot rate-card override. The code:
active_model = state.observed_model or getattr(state.options, "model", None)
if _is_moonshot_model(active_model):
state.usage.cost_usd = _override_cost_for_moonshot(...)
else:
state.usage.cost_usd = sdk_msg.total_cost_usd
is critical-path billing logic — make sure observed_model wins over
the requested primary, and Anthropic turns pass through untouched."""
def _decide_cost(
self,
*,
observed_model: str | None,
options_model: str | None,
sdk_reported_usd: float,
prompt_tokens: int = 0,
completion_tokens: int = 0,
) -> float:
"""Mirror of the real decision block — lets us assert the gate
without constructing the whole 1000-line generator."""
from .service import _is_moonshot_model, _override_cost_for_moonshot
active_model = observed_model or options_model
if _is_moonshot_model(active_model):
return _override_cost_for_moonshot(
model=active_model,
sdk_reported_usd=sdk_reported_usd,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cache_read_tokens=0,
cache_creation_tokens=0,
)
return sdk_reported_usd
def test_anthropic_turn_passes_sdk_cost_through(self):
"""Anthropic — the CLI's pricing table is authoritative, so
``state.usage.cost_usd`` is set to ``sdk_msg.total_cost_usd``
unchanged."""
cost = self._decide_cost(
observed_model="anthropic/claude-sonnet-4-6",
options_model="anthropic/claude-sonnet-4-6",
sdk_reported_usd=0.123,
)
assert cost == 0.123
def test_moonshot_turn_uses_rate_card_override(self):
"""Moonshot — the CLI would silently bill at Sonnet rates, so
the override recomputes from the Moonshot rate card."""
cost = self._decide_cost(
observed_model="moonshotai/kimi-k2.6",
options_model="moonshotai/kimi-k2.6",
sdk_reported_usd=0.089862, # CLI's Sonnet-priced estimate.
prompt_tokens=29564,
completion_tokens=78,
)
expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000
assert cost == pytest.approx(expected, rel=1e-9)
# Sanity: ~5x cheaper than the CLI's Sonnet-priced number.
assert cost < 0.089862 / 4
def test_observed_model_wins_over_options_primary(self):
"""The whole point of ``observed_model``: a Moonshot-primary
request that fell back to Anthropic must NOT get Moonshot
pricing applied. The gate follows the observed model, not the
requested primary."""
cost = self._decide_cost(
observed_model="anthropic/claude-sonnet-4-6",
options_model="moonshotai/kimi-k2.6", # what we ASKED for
sdk_reported_usd=0.123,
prompt_tokens=1000,
completion_tokens=100,
)
# Observed == Anthropic → CLI-reported cost passes through unchanged.
assert cost == 0.123
def test_anthropic_to_moonshot_fallback_uses_override(self):
"""The inverse: an Anthropic-primary request that fell back to
Moonshot must get the Moonshot override applied — the CLI is
still billing at Sonnet rates for the fallback response."""
cost = self._decide_cost(
observed_model="moonshotai/kimi-k2.6",
options_model="anthropic/claude-sonnet-4-6",
sdk_reported_usd=0.089862,
prompt_tokens=29564,
completion_tokens=78,
)
expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000
assert cost == pytest.approx(expected, rel=1e-9)
def test_no_observed_falls_back_to_options_model(self):
"""First AssistantMessage hasn't arrived yet (or the SDK didn't
emit ``.model``) — the gate falls back to the requested primary."""
cost = self._decide_cost(
observed_model=None,
options_model="moonshotai/kimi-k2.6",
sdk_reported_usd=0.089862,
prompt_tokens=100,
completion_tokens=10,
)
expected = (100 * 0.60 + 10 * 2.80) / 1_000_000
assert cost == pytest.approx(expected, rel=1e-9)
def test_both_none_passes_sdk_cost_through(self):
"""Subscription mode — ``options.model`` may be None and no
AssistantMessage has arrived yet. ``None`` is not a Moonshot
slug so the SDK number lands unchanged."""
cost = self._decide_cost(
observed_model=None,
options_model=None,
sdk_reported_usd=0.05,
)
assert cost == 0.05
# ---------------------------------------------------------------------------
# Moonshot helper re-exports — keep imports stable for call-site code
# ---------------------------------------------------------------------------
class TestMoonshotHelperReexports:
"""``sdk/service.py`` imports the Moonshot helpers under local
aliases (``_is_moonshot_model``, ``_override_cost_for_moonshot``).
Regression guard so a refactor doesn't silently break the import
path the hot-loop code relies on."""
def test_is_moonshot_model_aliased(self):
from backend.copilot.moonshot import is_moonshot_model as canonical
from .service import _is_moonshot_model
assert _is_moonshot_model is canonical
def test_override_cost_for_moonshot_aliased(self):
from backend.copilot.moonshot import override_cost_usd as canonical
from .service import _override_cost_for_moonshot
assert _override_cost_for_moonshot is canonical
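The expected-value arithmetic in `TestMoonshotCostOverrideGate` implies per-million-token rates of $0.60 (prompt) and $2.80 (completion) for `moonshotai/kimi-k2.6`. A worked sketch — the rates are inferred from the test expectations, and the real `override_cost_usd` also accepts cache-read / cache-creation token counts:

```python
_KIMI_PROMPT_USD_PER_MTOK = 0.60       # inferred from the test expectations
_KIMI_COMPLETION_USD_PER_MTOK = 2.80   # inferred from the test expectations


def kimi_cost_sketch(prompt_tokens: int, completion_tokens: int) -> float:
    return (
        prompt_tokens * _KIMI_PROMPT_USD_PER_MTOK
        + completion_tokens * _KIMI_COMPLETION_USD_PER_MTOK
    ) / 1_000_000


# Worked example from the test: 29,564 prompt + 78 completion tokens
# ≈ $0.0180, versus the CLI's Sonnet-priced $0.089862 — roughly 5x cheaper.
assert abs(kimi_cost_sketch(29_564, 78) - 0.0179568) < 1e-9
```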

View File

@@ -10,6 +10,7 @@ import json
import logging
import os
import uuid
from collections.abc import Iterable
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any
@@ -33,7 +34,7 @@ from backend.copilot.sdk.file_ref import (
expand_file_refs_in_args,
read_file_bytes,
)
from backend.copilot.tools import TOOL_REGISTRY
from backend.copilot.tools import TOOL_REGISTRY, ToolGroup, tool_names_in_groups
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate
@@ -853,9 +854,29 @@ DANGEROUS_PATTERNS = [
r"subprocess",
]
# Platform-tool names whose MCP wrappers must NOT be exposed to SDK mode.
# Baseline ships an MCP ``TodoWrite`` for model-flexibility parity; SDK mode
# keeps using the CLI-native built-in listed in ``_SDK_BUILTIN_ALWAYS`` so
# there is no double exposure. Public (no leading underscore) so a future
# refactor renaming it is visible at both call sites —
# ``permissions.apply_tool_permissions`` maps short tool names back to the
# CLI built-in form for SDK mode.
BASELINE_ONLY_MCP_TOOLS: frozenset[str] = frozenset({"TodoWrite"})
def _registry_mcp_tools(*, hidden: frozenset[str] = frozenset()) -> list[str]:
return [
f"{MCP_TOOL_PREFIX}{name}"
for name in TOOL_REGISTRY.keys()
if name not in BASELINE_ONLY_MCP_TOOLS and name not in hidden
]
# Static tool name list for the non-E2B case (backward compatibility).
# Includes all capability-gated tools; per-user filtering happens in
# ``get_copilot_tool_names`` when the caller passes ``disabled_groups``.
COPILOT_TOOL_NAMES = [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
*_registry_mcp_tools(),
f"{MCP_TOOL_PREFIX}{WRITE_TOOL_NAME}",
f"{MCP_TOOL_PREFIX}{READ_TOOL_NAME}",
f"{MCP_TOOL_PREFIX}{EDIT_TOOL_NAME}",
@@ -864,20 +885,31 @@ COPILOT_TOOL_NAMES = [
]
def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
def get_copilot_tool_names(
*,
use_e2b: bool = False,
disabled_groups: Iterable[ToolGroup] = (),
) -> list[str]:
"""Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.
When *use_e2b* is True the SDK built-in file tools are replaced by MCP
equivalents that route to the E2B sandbox.
equivalents that route to the E2B sandbox. Tools belonging to any of
*disabled_groups* are filtered out — see ``ToolGroup`` / ``TOOL_GROUPS``
in ``backend.copilot.tools`` for the full list.
"""
hidden_short_names = tool_names_in_groups(disabled_groups)
hidden_mcp_names = {f"{MCP_TOOL_PREFIX}{n}" for n in hidden_short_names}
if not use_e2b:
return list(COPILOT_TOOL_NAMES)
if not hidden_mcp_names:
return list(COPILOT_TOOL_NAMES)
return [n for n in COPILOT_TOOL_NAMES if n not in hidden_mcp_names]
# In E2B mode, Write/Edit are NOT registered (E2B uses write_file/edit_file
# from E2B_FILE_TOOLS instead), so don't include them here.
# _READ_TOOL_NAME is still needed for SDK tool-result reads.
return [
*[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
*_registry_mcp_tools(hidden=hidden_short_names),
f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
*[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
*_SDK_BUILTIN_ALWAYS,
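A call-site sketch for the new `disabled_groups` parameter. The defining module's import path is not shown in this excerpt, `"graphiti"` is the only group registered in `TOOL_GROUPS`, and the assertion assumes the `memory_*` short names are present in `TOOL_REGISTRY`:

```python
# Hypothetical usage — import path and registry contents assumed.
allowed = get_copilot_tool_names()                                  # non-E2B, nothing hidden
no_memory = get_copilot_tool_names(disabled_groups=("graphiti",))   # hide Graphiti-backed tools
e2b_no_memory = get_copilot_tool_names(use_e2b=True, disabled_groups=("graphiti",))

# The hidden entries are the MCP-prefixed forms of the memory_* short names.
assert not any(name.endswith("memory_store") for name in no_memory)
```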

View File

@@ -1249,7 +1249,14 @@ class TestStripStaleThinkingBlocks:
new_asst = self._asst_entry(
"msg_new",
[
{"type": "thinking", "thinking": "latest thoughts"},
# Anthropic-shape thinking block (has signature) — preserved
# on the last turn. Signature-less variant covered by
# ``test_strips_signatureless_last_turn_thinking``.
{
"type": "thinking",
"thinking": "latest thoughts",
"signature": "anthropic-sig",
},
{"type": "text", "text": "world"},
],
uuid="a2",
@@ -1271,11 +1278,16 @@ class TestStripStaleThinkingBlocks:
assert new_content[1]["type"] == "text"
def test_preserves_last_assistant_thinking(self) -> None:
"""The last assistant entry's thinking blocks must be preserved."""
"""The last assistant entry's thinking blocks must be preserved
when they carry an Anthropic ``signature``. Signature-less
blocks (e.g. from Kimi K2.6 via OpenRouter) are stripped to
prevent ``Invalid `signature` in `thinking` block`` errors on
a subsequent Anthropic-model turn — covered by
``test_strips_signatureless_last_turn_thinking``."""
entry = self._asst_entry(
"msg_only",
[
{"type": "thinking", "thinking": "must keep"},
{"type": "thinking", "thinking": "must keep", "signature": "sig"},
{"type": "text", "text": "response"},
],
)
@@ -1284,6 +1296,47 @@ class TestStripStaleThinkingBlocks:
lines = [json.loads(ln) for ln in result.strip().split("\n")]
assert len(lines[0]["message"]["content"]) == 2
def test_strips_signatureless_last_turn_thinking(self) -> None:
"""Cross-model fix (PR #12878): a signature-less thinking block
on the last assistant turn must be stripped before a subsequent
Anthropic-model dispatch tries to replay it."""
entry = self._asst_entry(
"msg_kimi",
[
# No signature → non-Anthropic provider
{"type": "thinking", "thinking": "kimi reasoning"},
{"type": "text", "text": "answer"},
],
)
content = _make_jsonl(entry)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
types = [b["type"] for b in lines[0]["message"]["content"]]
assert "thinking" not in types
assert "text" in types
def test_preserves_redacted_thinking_on_last_turn(self) -> None:
"""``redacted_thinking`` blocks are signature-less by design
(encrypted ``data`` field instead). Stripping them on the last
turn would violate Anthropic's value-identity requirement for
multi-turn replay. The signature rule only applies to plain
``thinking`` blocks."""
entry = self._asst_entry(
"msg_anthropic",
[
# Anthropic-emitted redacted_thinking: has ``data``,
# never has ``signature``.
{"type": "redacted_thinking", "data": "encrypted_blob"},
{"type": "text", "text": "response"},
],
)
content = _make_jsonl(entry)
result = strip_stale_thinking_blocks(content)
lines = [json.loads(ln) for ln in result.strip().split("\n")]
types = [b["type"] for b in lines[0]["message"]["content"]]
assert "redacted_thinking" in types
assert "text" in types
def test_no_assistant_entries_returns_unchanged(self) -> None:
"""Transcripts with only user entries should pass through unchanged."""
user = self._user_entry("hello")
@@ -1294,12 +1347,13 @@ class TestStripStaleThinkingBlocks:
assert strip_stale_thinking_blocks("") == ""
def test_multiple_turns_strips_all_but_last(self) -> None:
"""With 3 assistant turns, only the last keeps thinking blocks."""
"""With 3 assistant turns, only the last keeps thinking blocks
(and only when those blocks carry an Anthropic ``signature``)."""
entries = [
self._asst_entry(
"msg_1",
[
{"type": "thinking", "thinking": "t1"},
{"type": "thinking", "thinking": "t1", "signature": "s1"},
{"type": "text", "text": "a1"},
],
uuid="a1",
@@ -1308,7 +1362,7 @@ class TestStripStaleThinkingBlocks:
self._asst_entry(
"msg_2",
[
{"type": "thinking", "thinking": "t2"},
{"type": "thinking", "thinking": "t2", "signature": "s2"},
{"type": "text", "text": "a2"},
],
uuid="a2",
@@ -1318,7 +1372,7 @@ class TestStripStaleThinkingBlocks:
self._asst_entry(
"msg_3",
[
{"type": "thinking", "thinking": "t3"},
{"type": "thinking", "thinking": "t3", "signature": "s3"},
{"type": "text", "text": "a3"},
],
uuid="a3",
@@ -1339,16 +1393,18 @@ class TestStripStaleThinkingBlocks:
assert lines[4]["message"]["content"][0]["type"] == "thinking"
def test_same_msg_id_multi_entry_turn(self) -> None:
"""Multiple entries sharing the same message.id (same turn) are preserved."""
"""Multiple entries sharing the same message.id (same turn) are
preserved when their thinking blocks carry an Anthropic
``signature``."""
entries = [
self._asst_entry(
"msg_old",
[{"type": "thinking", "thinking": "old"}],
[{"type": "thinking", "thinking": "old", "signature": "old_sig"}],
uuid="a1",
),
self._asst_entry(
"msg_last",
[{"type": "thinking", "thinking": "t_part1"}],
[{"type": "thinking", "thinking": "t_part1", "signature": "p1_sig"}],
uuid="a2",
parent="a1",
),
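The keep/strip rule these tests pin down for the last assistant turn, as a standalone sketch. The real `strip_stale_thinking_blocks` operates on whole JSONL transcripts and additionally drops thinking blocks from every earlier turn:

```python
def keep_on_last_turn(block: dict) -> bool:
    """Hypothetical per-block predicate inferred from the assertions above."""
    if block.get("type") == "thinking":
        # Only Anthropic-emitted thinking blocks carry a signature; a
        # signature-less block (e.g. Kimi K2.6 via OpenRouter) would fail
        # Anthropic's signature validation when replayed, so it is stripped.
        return bool(block.get("signature"))
    # redacted_thinking is signature-less by design (encrypted "data") and is
    # preserved; all other block types pass through untouched.
    return True


assert keep_on_last_turn({"type": "thinking", "thinking": "t", "signature": "sig"})
assert not keep_on_last_turn({"type": "thinking", "thinking": "t"})
assert keep_on_last_turn({"type": "redacted_thinking", "data": "blob"})
assert keep_on_last_turn({"type": "text", "text": "hi"})
```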

View File

@@ -17,6 +17,7 @@ from langfuse import get_client
from langfuse.openai import (
AsyncOpenAI as LangfuseAsyncOpenAI, # pyright: ignore[reportPrivateImportUsage]
)
from openai.types.chat import ChatCompletion
from backend.data.db_accessors import chat_db, understanding_db
from backend.data.understanding import (
@@ -34,6 +35,7 @@ from .model import (
update_session_title,
upsert_chat_session,
)
from .token_tracking import persist_and_record_usage
logger = logging.getLogger(__name__)
@@ -495,20 +497,31 @@ async def _generate_session_title(
message: str,
user_id: str | None = None,
session_id: str | None = None,
) -> str | None:
) -> tuple[str | None, ChatCompletion | None]:
"""Generate a concise title for a chat session based on the first message.
Returns ``(title, response)``. The caller is responsible for
persisting the title AND recording the title call's cost — keeping
them as separate concerns in the caller means a cost-tracking hiccup
doesn't lose the title, and a title-persist failure still records
the cost (we paid for the LLM call either way).
Args:
message: The first user message in the session
user_id: User ID for OpenRouter tracing (optional)
session_id: Session ID for OpenRouter tracing (optional)
Returns:
A short title (3-6 words) or None if generation fails
``(title, response)`` on success; ``(None, None)`` if the LLM
call raised. ``response`` is returned even when ``title`` is
empty so the caller can still record the (paid-for) cost.
"""
try:
# Build extra_body for OpenRouter tracing and PostHog analytics
extra_body: dict[str, Any] = {}
# Build extra_body for OpenRouter tracing and PostHog analytics.
# ``usage: {"include": True}`` asks OR to embed the real billed
# cost into the final usage chunk — matches the baseline path's
# ``_OPENROUTER_INCLUDE_USAGE_COST`` pattern, same read path.
extra_body: dict[str, Any] = {"usage": {"include": True}}
if user_id:
extra_body["user"] = user_id[:128] # OpenRouter limit
extra_body["posthogDistinctId"] = user_id
@@ -534,18 +547,113 @@ async def _generate_session_title(
max_tokens=20,
extra_body=extra_body,
)
title = response.choices[0].message.content
if title:
# Clean up the title
title = title.strip().strip("\"'")
# Limit length
if len(title) > 50:
title = title[:47] + "..."
return title
return None
except Exception as e:
logger.warning(f"Failed to generate session title: {e}")
return None
return None, None
# Robust against an empty ``choices`` list OR a choice whose
# ``message`` is missing ``content`` (shouldn't happen on the OpenAI
# SDK typing, but belt-and-suspenders — the background task would
# otherwise die on ``IndexError`` and lose the (paid-for) cost
# recording the caller performs with the returned response).
title: str | None = None
if response.choices:
msg = response.choices[0].message
title = msg.content if msg is not None else None
if title:
title = title.strip().strip("\"'")
if len(title) > 50:
title = title[:47] + "..."
return title, response
def _title_usage_from_response(
response: ChatCompletion,
) -> tuple[int, int, float | None]:
"""Extract ``(prompt_tokens, completion_tokens, cost_usd)`` from a
title-generation chat-completion response.
Returns zeros / ``None`` for missing fields — the OpenAI SDK's
``CompletionUsage`` doesn't declare OpenRouter's ``cost`` extension,
so we read it off ``model_extra`` (pydantic v2 extras container).
Absent for non-OR routes; returned as ``None`` in that case.
"""
usage = response.usage
if usage is None:
return 0, 0, None
prompt_tokens = usage.prompt_tokens or 0
completion_tokens = usage.completion_tokens or 0
extras = usage.model_extra or {}
cost_raw = extras.get("cost") if isinstance(extras, dict) else None
if isinstance(cost_raw, (int, float)):
cost_usd: float | None = float(cost_raw)
else:
cost_usd = None
return prompt_tokens, completion_tokens, cost_usd
async def _record_title_generation_cost(
*,
response: ChatCompletion,
user_id: str | None,
session_id: str | None,
) -> None:
"""Persist the title LLM call's cost to ``PlatformCostLog``.
Title generation runs in a background task per-session — low cost
(~$0.0001 per title) but 100% of sessions pay it. Without this the
admin dashboard under-reports total provider spend by the aggregate
of those calls. Separate ``block_name="copilot:title"`` so the row
is clearly distinguishable from the turn's main ``copilot:SDK`` /
``copilot:baseline`` attributions.
Invariants enforced by the caller:
* ``response`` is a completed ``ChatCompletion`` (the create call
didn't raise) — so ``response.usage`` shape is SDK-contractual.
* Exceptions are NOT suppressed — the caller runs this AFTER
title persistence so a persist failure here doesn't lose the
title, and a real DB / Prisma outage surfaces in the caller's
single background-task warning handler.
"""
prompt_tokens, completion_tokens, cost_usd = _title_usage_from_response(response)
# Nothing meaningful to record — skip the DB roundtrip entirely
# rather than writing a zero-valued row. Covers the non-OR route
# (no ``usage.cost`` field) and the degenerate zero-tokens case.
if cost_usd is None and prompt_tokens == 0 and completion_tokens == 0:
return
# Provider label is derived from the configured ``base_url`` (title
# LLM uses the shared copilot OpenAI client whose base URL mirrors
# ``ChatConfig.base_url``). This lets a deployment that points
# title generation at a non-OR endpoint still get the correct
# ``provider`` on the cost-log row.
provider = (
"open_router"
if (config.base_url and "openrouter.ai" in config.base_url)
else "openai"
)
# Intentionally pass ``session=None``. ``persist_and_record_usage``
# would otherwise append a ``Usage`` entry to the live session
# object, but this background task holds no reference to the
# request-scoped session — we'd have to ``get_chat_session`` +
# ``upsert_chat_session`` round-trip the mutation back, and the
# turn's main ``persist_and_record_usage`` already owns the session
# usage-list mirror for the originating turn. Title cost is
# recorded into ``PlatformCostLog`` (admin dashboard) and the
# microdollar rate-limit counter — those are the two places that
# actually matter for this call.
await persist_and_record_usage(
session=None,
user_id=user_id,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
log_prefix="[title]",
cost_usd=cost_usd,
model=config.title_model,
provider=provider,
)
async def _update_title_async(
@@ -553,15 +661,29 @@ async def _update_title_async(
) -> None:
"""Generate and persist a session title in the background.
Shared by both the SDK and baseline execution paths.
Shared by both the SDK and baseline execution paths. Title
persistence and cost recording are run as independent best-effort
steps — a failure in one does not cancel the other, so a flaky
Prisma call on cost recording never costs us the generated title.
"""
try:
title = await _generate_session_title(message, user_id, session_id)
if title and user_id:
title, response = await _generate_session_title(message, user_id, session_id)
if title and user_id:
try:
await update_session_title(session_id, user_id, title, only_if_empty=True)
logger.debug("Generated title for session %s", session_id)
except Exception as e:
logger.warning("Failed to update session title for %s: %s", session_id, e)
except Exception as e:
logger.warning("Failed to persist session title for %s: %s", session_id, e)
if response is not None:
try:
await _record_title_generation_cost(
response=response, user_id=user_id, session_id=session_id
)
except Exception as e:
logger.warning(
"Failed to record title generation cost for %s: %s", session_id, e
)
async def assign_user_to_session(

View File

@@ -0,0 +1,541 @@
"""Unit tests for title-generation cost tracking helpers.
Covers the new code added in PR #12882:
* ``_title_usage_from_response`` — shape-robust OR ``usage.cost`` extraction
* ``_record_title_generation_cost`` — provider-label + zero-tokens gate
* ``_update_title_async`` — independent title / cost persistence try blocks
* ``_generate_session_title`` — tuple return + robustness against empty choices
Mocks ``persist_and_record_usage`` / ``update_session_title`` at the boundary
where the code under test imports them (``backend.copilot.service.*``).
"""
from __future__ import annotations
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage
from backend.copilot.service import (
_generate_session_title,
_record_title_generation_cost,
_title_usage_from_response,
_update_title_async,
)
def _build_completion(
*,
content: str | None = "Hello Title",
usage: CompletionUsage | None = None,
choices: list[Choice] | None = None,
) -> ChatCompletion:
if choices is None:
msg = ChatCompletionMessage(role="assistant", content=content)
choices = [Choice(index=0, message=msg, finish_reason="stop")]
return ChatCompletion(
id="cmpl-1",
choices=choices,
created=0,
model="anthropic/claude-haiku",
object="chat.completion",
usage=usage,
)
def _usage_with_cost(cost: object | None) -> CompletionUsage:
"""Return a CompletionUsage whose ``model_extra`` carries ``cost``.
Uses ``model_validate`` so OpenRouter's ``cost`` extension lands in
the pydantic ``model_extra`` dict the helper reads from.
"""
payload: dict[str, object] = {
"prompt_tokens": 12,
"completion_tokens": 3,
"total_tokens": 15,
}
if cost is not None:
payload["cost"] = cost
return CompletionUsage.model_validate(payload)
class TestTitleUsageFromResponse:
"""``_title_usage_from_response`` returns sensible zeros/Nones when
optional fields are absent or of unexpected shape."""
def test_usage_none_returns_all_zero(self):
resp = _build_completion(usage=None)
prompt, completion, cost = _title_usage_from_response(resp)
assert prompt == 0
assert completion == 0
assert cost is None
def test_missing_cost_field_returns_none_cost(self):
resp = _build_completion(usage=_usage_with_cost(None))
prompt, completion, cost = _title_usage_from_response(resp)
assert prompt == 12
assert completion == 3
assert cost is None
def test_cost_as_int_is_coerced_to_float(self):
resp = _build_completion(usage=_usage_with_cost(2))
_, _, cost = _title_usage_from_response(resp)
assert isinstance(cost, float)
assert cost == 2.0
def test_cost_as_float_is_returned_as_is(self):
resp = _build_completion(usage=_usage_with_cost(0.000123))
_, _, cost = _title_usage_from_response(resp)
assert cost == pytest.approx(0.000123)
def test_cost_as_non_numeric_string_returns_none(self):
resp = _build_completion(usage=_usage_with_cost("free"))
_, _, cost = _title_usage_from_response(resp)
assert cost is None
def test_empty_model_extra_returns_none_cost(self):
# ``model_extra`` is empty for non-OR routes where pydantic didn't
# receive any extras — prompt/completion still flow through.
usage = CompletionUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7)
resp = _build_completion(usage=usage)
prompt, completion, cost = _title_usage_from_response(resp)
assert (prompt, completion, cost) == (5, 2, None)
def test_zero_prompt_and_completion_tokens(self):
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
resp = _build_completion(usage=usage)
prompt, completion, cost = _title_usage_from_response(resp)
assert (prompt, completion, cost) == (0, 0, None)
class TestRecordTitleGenerationCost:
"""``_record_title_generation_cost`` persists cost + picks the right
provider label and skips the DB roundtrip when nothing's meaningful
to record."""
@pytest.mark.asyncio
async def test_openrouter_base_url_uses_open_router_provider(self):
resp = _build_completion(usage=_usage_with_cost(0.0002))
persist = AsyncMock(return_value=0)
with (
patch(
"backend.copilot.service.persist_and_record_usage",
new=persist,
),
patch(
"backend.copilot.service.config",
MagicMock(
base_url="https://openrouter.ai/api/v1",
title_model="anthropic/claude-haiku",
),
),
):
await _record_title_generation_cost(
response=resp, user_id="u", session_id="s"
)
persist.assert_awaited_once()
kwargs = persist.await_args.kwargs
assert kwargs["provider"] == "open_router"
assert kwargs["model"] == "anthropic/claude-haiku"
assert kwargs["prompt_tokens"] == 12
assert kwargs["completion_tokens"] == 3
assert kwargs["cost_usd"] == pytest.approx(0.0002)
assert kwargs["log_prefix"] == "[title]"
assert kwargs["session"] is None
@pytest.mark.asyncio
async def test_non_openrouter_base_url_uses_openai_provider(self):
resp = _build_completion(usage=_usage_with_cost(0.0002))
persist = AsyncMock(return_value=0)
with (
patch(
"backend.copilot.service.persist_and_record_usage",
new=persist,
),
patch(
"backend.copilot.service.config",
MagicMock(
base_url="https://api.openai.com/v1",
title_model="gpt-4o-mini",
),
),
):
await _record_title_generation_cost(
response=resp, user_id="u", session_id="s"
)
persist.assert_awaited_once()
assert persist.await_args.kwargs["provider"] == "openai"
@pytest.mark.asyncio
async def test_empty_base_url_uses_openai_provider(self):
resp = _build_completion(usage=_usage_with_cost(0.0001))
persist = AsyncMock(return_value=0)
with (
patch(
"backend.copilot.service.persist_and_record_usage",
new=persist,
),
patch(
"backend.copilot.service.config",
MagicMock(base_url=None, title_model="gpt-4o-mini"),
),
):
await _record_title_generation_cost(
response=resp, user_id=None, session_id=None
)
persist.assert_awaited_once()
assert persist.await_args.kwargs["provider"] == "openai"
@pytest.mark.asyncio
async def test_zero_tokens_zero_cost_skips_persist(self):
"""No cost, no tokens — the early return avoids a worthless
``PlatformCostLog`` row."""
usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
resp = _build_completion(usage=usage)
persist = AsyncMock(return_value=0)
with (
patch(
"backend.copilot.service.persist_and_record_usage",
new=persist,
),
patch(
"backend.copilot.service.config",
MagicMock(
base_url="https://openrouter.ai/api/v1",
title_model="x",
),
),
):
await _record_title_generation_cost(
response=resp, user_id="u", session_id="s"
)
persist.assert_not_awaited()
@pytest.mark.asyncio
async def test_usage_none_skips_persist(self):
"""``usage`` absent on the response == provider didn't report —
still short-circuits to avoid writing a zero-valued row."""
resp = _build_completion(usage=None)
persist = AsyncMock(return_value=0)
with (
patch(
"backend.copilot.service.persist_and_record_usage",
new=persist,
),
patch(
"backend.copilot.service.config",
MagicMock(
base_url="https://openrouter.ai/api/v1",
title_model="x",
),
),
):
await _record_title_generation_cost(
response=resp, user_id="u", session_id="s"
)
persist.assert_not_awaited()
@pytest.mark.asyncio
async def test_tokens_without_cost_still_records(self):
"""Tokens present but ``cost`` missing (non-OR route) still
records a row so token counts are captured — ``cost_usd=None``
is accepted by ``persist_and_record_usage``."""
usage = CompletionUsage(prompt_tokens=8, completion_tokens=2, total_tokens=10)
resp = _build_completion(usage=usage)
persist = AsyncMock(return_value=0)
with (
patch(
"backend.copilot.service.persist_and_record_usage",
new=persist,
),
patch(
"backend.copilot.service.config",
MagicMock(base_url=None, title_model="m"),
),
):
await _record_title_generation_cost(
response=resp, user_id="u", session_id="s"
)
persist.assert_awaited_once()
assert persist.await_args.kwargs["cost_usd"] is None
assert persist.await_args.kwargs["prompt_tokens"] == 8
class TestUpdateTitleAsync:
"""``_update_title_async`` runs title persistence and cost recording
as independent best-effort steps — a failure in one does NOT
cancel the other."""
@pytest.mark.asyncio
async def test_title_success_cost_success(self):
resp = _build_completion(usage=_usage_with_cost(0.0001))
gen = AsyncMock(return_value=("My Title", resp))
update = AsyncMock(return_value=True)
record = AsyncMock()
with (
patch(
"backend.copilot.service._generate_session_title",
new=gen,
),
patch(
"backend.copilot.service.update_session_title",
new=update,
),
patch(
"backend.copilot.service._record_title_generation_cost",
new=record,
),
):
await _update_title_async("sess-1", "hello", user_id="u1")
update.assert_awaited_once_with("sess-1", "u1", "My Title", only_if_empty=True)
record.assert_awaited_once()
assert record.await_args.kwargs["response"] is resp
assert record.await_args.kwargs["user_id"] == "u1"
assert record.await_args.kwargs["session_id"] == "sess-1"
@pytest.mark.asyncio
async def test_title_persist_fails_but_cost_still_recorded(self):
resp = _build_completion(usage=_usage_with_cost(0.0001))
gen = AsyncMock(return_value=("Title", resp))
update = AsyncMock(side_effect=RuntimeError("prisma boom"))
record = AsyncMock()
with (
patch(
"backend.copilot.service._generate_session_title",
new=gen,
),
patch(
"backend.copilot.service.update_session_title",
new=update,
),
patch(
"backend.copilot.service._record_title_generation_cost",
new=record,
),
):
# Must NOT raise — persist failure is swallowed.
await _update_title_async("sess-2", "msg", user_id="u")
update.assert_awaited_once()
record.assert_awaited_once()
@pytest.mark.asyncio
async def test_cost_record_fails_but_title_was_persisted(self):
resp = _build_completion(usage=_usage_with_cost(0.0001))
gen = AsyncMock(return_value=("Title", resp))
update = AsyncMock(return_value=True)
record = AsyncMock(side_effect=RuntimeError("cost record boom"))
with (
patch(
"backend.copilot.service._generate_session_title",
new=gen,
),
patch(
"backend.copilot.service.update_session_title",
new=update,
),
patch(
"backend.copilot.service._record_title_generation_cost",
new=record,
),
):
# Must NOT raise — cost-recording failure is swallowed.
await _update_title_async("sess-3", "msg", user_id="u")
update.assert_awaited_once()
record.assert_awaited_once()
@pytest.mark.asyncio
async def test_no_user_id_skips_title_persist_but_records_cost(self):
"""Anonymous sessions skip the user-scoped title write, but we
still paid for the LLM call — cost recording runs regardless."""
resp = _build_completion(usage=_usage_with_cost(0.0001))
gen = AsyncMock(return_value=("Title", resp))
update = AsyncMock()
record = AsyncMock()
with (
patch(
"backend.copilot.service._generate_session_title",
new=gen,
),
patch(
"backend.copilot.service.update_session_title",
new=update,
),
patch(
"backend.copilot.service._record_title_generation_cost",
new=record,
),
):
await _update_title_async("sess-4", "msg", user_id=None)
update.assert_not_awaited()
record.assert_awaited_once()
@pytest.mark.asyncio
async def test_generation_returns_none_response_skips_cost(self):
"""``_generate_session_title`` swallows exceptions and returns
``(None, None)`` — no response means no cost to record."""
gen = AsyncMock(return_value=(None, None))
update = AsyncMock()
record = AsyncMock()
with (
patch(
"backend.copilot.service._generate_session_title",
new=gen,
),
patch(
"backend.copilot.service.update_session_title",
new=update,
),
patch(
"backend.copilot.service._record_title_generation_cost",
new=record,
),
):
await _update_title_async("sess-5", "msg", user_id="u")
update.assert_not_awaited()
record.assert_not_awaited()
@pytest.mark.asyncio
async def test_empty_title_with_response_still_records_cost(self):
"""Title came back empty but we still paid for the LLM call —
cost recording runs even though the title write is skipped."""
resp = _build_completion(usage=_usage_with_cost(0.0001))
gen = AsyncMock(return_value=(None, resp))
update = AsyncMock()
record = AsyncMock()
with (
patch(
"backend.copilot.service._generate_session_title",
new=gen,
),
patch(
"backend.copilot.service.update_session_title",
new=update,
),
patch(
"backend.copilot.service._record_title_generation_cost",
new=record,
),
):
await _update_title_async("sess-6", "msg", user_id="u")
update.assert_not_awaited()
record.assert_awaited_once()
class TestGenerateSessionTitle:
"""``_generate_session_title`` returns ``(title, response)`` — the
caller owns both the persist and the cost-record decisions."""
@pytest.mark.asyncio
async def test_valid_response_returns_cleaned_title_and_response(self):
# Code strips whitespace, then strips ``"'`` — whitespace inside
# the quotes survives on purpose (titles like ``My Agent`` read
# better than ``MyAgent``). Test keeps the outer quotes + inner
# whitespace distinct so the ordering is pinned.
resp = _build_completion(content='"Clean Me" ')
client = MagicMock()
client.chat.completions.create = AsyncMock(return_value=resp)
with patch(
"backend.copilot.service._get_openai_client",
return_value=client,
):
title, response = await _generate_session_title(
"first message", user_id="u", session_id="s"
)
assert title == "Clean Me"
assert response is resp
@pytest.mark.asyncio
async def test_long_title_truncated_with_ellipsis(self):
"""Titles >50 chars get truncated to 47 + '...'."""
long_title = "A" * 80
resp = _build_completion(content=long_title)
client = MagicMock()
client.chat.completions.create = AsyncMock(return_value=resp)
with patch(
"backend.copilot.service._get_openai_client",
return_value=client,
):
title, _ = await _generate_session_title("x", user_id=None)
assert title is not None
assert len(title) == 50
assert title.endswith("...")
@pytest.mark.asyncio
async def test_empty_choices_returns_none_title_with_response(self):
"""No ``choices`` on the response (shouldn't happen per SDK
typing) must not raise IndexError — response is preserved so the
caller can still record the paid-for cost."""
resp = _build_completion(choices=[])
client = MagicMock()
client.chat.completions.create = AsyncMock(return_value=resp)
with patch(
"backend.copilot.service._get_openai_client",
return_value=client,
):
title, response = await _generate_session_title("x")
assert title is None
assert response is resp
@pytest.mark.asyncio
async def test_missing_message_returns_none_title(self):
"""A choice whose ``.message`` is absent produces a None title
but the response still lands on the caller."""
fake_choice = SimpleNamespace(message=None)
fake_response = SimpleNamespace(choices=[fake_choice])
client = MagicMock()
client.chat.completions.create = AsyncMock(return_value=fake_response)
with patch(
"backend.copilot.service._get_openai_client",
return_value=client,
):
title, response = await _generate_session_title("x")
assert title is None
assert response is fake_response
@pytest.mark.asyncio
async def test_llm_call_raises_returns_none_none(self):
"""Network / API errors on the create call are swallowed;
``(None, None)`` ensures the caller skips both title and cost
without crashing the background task."""
client = MagicMock()
client.chat.completions.create = AsyncMock(
side_effect=RuntimeError("connection reset")
)
with patch(
"backend.copilot.service._get_openai_client",
return_value=client,
):
title, response = await _generate_session_title("x")
assert title is None
assert response is None
@pytest.mark.asyncio
async def test_create_receives_usage_include_extra_body(self):
"""PR adds ``usage: {'include': True}`` so OpenRouter embeds the
real billed cost into the final usage chunk."""
resp = _build_completion(content="Title")
client = MagicMock()
client.chat.completions.create = AsyncMock(return_value=resp)
with patch(
"backend.copilot.service._get_openai_client",
return_value=client,
):
await _generate_session_title(
"hello world", user_id="user-abc", session_id="sess-abc"
)
client.chat.completions.create.assert_awaited_once()
extra_body = client.chat.completions.create.await_args.kwargs["extra_body"]
assert extra_body["usage"] == {"include": True}
assert extra_body["user"] == "user-abc"
assert extra_body["session_id"] == "sess-abc"

View File

@@ -485,9 +485,11 @@ async def subscribe_to_session(
subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
stream_key = _get_turn_stream_key(session.turn_id)
# Step 1: Replay messages from Redis Stream
# Replay batch capped by ``stream_replay_count``.
xread_start = time.perf_counter()
messages = await redis.xread({stream_key: last_message_id}, block=None, count=1000)
messages = await redis.xread(
{stream_key: last_message_id}, block=None, count=config.stream_replay_count
)
xread_time = (time.perf_counter() - xread_start) * 1000
logger.info(
f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={session_status}",
@@ -1024,8 +1026,8 @@ async def get_active_session(
# Check if session is stale (running beyond tool timeout + buffer).
# Auto-complete it to prevent infinite polling loops.
# Synchronous tools can run up to COPILOT_CONSUMER_TIMEOUT_SECONDS (1 hour),
# so we add a 5-minute buffer to avoid false positives during legitimate operations.
# A turn can legitimately run up to COPILOT_CONSUMER_TIMEOUT_SECONDS, so we
# add a 5-minute buffer to avoid false positives during legitimate operations.
created_at_str = meta.get("created_at")
if created_at_str:
try:

View File

@@ -1,7 +1,8 @@
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Literal
from openai.types.chat import ChatCompletionToolParam
@@ -10,7 +11,6 @@ from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .ask_question import AskQuestionTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
@@ -44,6 +44,7 @@ from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .run_sub_session import RunSubSessionTool
from .search_docs import SearchDocsTool
from .todo_write import TodoWriteTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .web_search import WebSearchTool
@@ -63,7 +64,6 @@ logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
"add_understanding": AddUnderstandingTool(),
"ask_question": AskQuestionTool(),
"create_agent": CreateAgentTool(),
"customize_agent": CustomizeAgentTool(),
"decompose_goal": DecomposeGoalTool(),
@@ -88,6 +88,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
"continue_run_block": ContinueRunBlockTool(),
"run_sub_session": RunSubSessionTool(),
"get_sub_session_result": GetSubSessionResultTool(),
"TodoWrite": TodoWriteTool(),
"run_mcp_tool": RunMCPToolTool(),
"get_mcp_guide": GetMCPGuideTool(),
"view_agent_output": AgentOutputTool(),
@@ -123,15 +124,45 @@ find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]
def get_available_tools() -> list[ChatCompletionToolParam]:
# Capability groups a tool may belong to. The service layer can hide all
# tools in a group when the backing capability isn't available to this user
# (e.g. Graphiti memory behind a feature flag), so the model doesn't reach
# for tools whose backend is off and then hit opaque runtime errors. Add
# a new group by extending ``ToolGroup`` and registering its members in
# ``TOOL_GROUPS`` below.
ToolGroup = Literal["graphiti"]
TOOL_GROUPS: dict[str, ToolGroup] = {
"memory_store": "graphiti",
"memory_search": "graphiti",
"memory_forget_search": "graphiti",
"memory_forget_confirm": "graphiti",
}
def tool_names_in_groups(groups: Iterable[ToolGroup]) -> frozenset[str]:
"""Return the set of tool short-names belonging to any of *groups*."""
group_set = frozenset(groups)
return frozenset(name for name, g in TOOL_GROUPS.items() if g in group_set)
def get_available_tools(
*,
disabled_groups: Iterable[ToolGroup] = (),
) -> list[ChatCompletionToolParam]:
"""Return OpenAI tool schemas for tools available in the current environment.
Called per-request so that env-var or binary availability is evaluated
fresh each time (e.g. browser_* tools are excluded when agent-browser
CLI is not installed).
CLI is not installed). Tools belonging to any *disabled_groups* are
also filtered out — use this to hide capability-gated tools (e.g.
``graphiti`` when the memory backend is off for the current user).
"""
hidden = tool_names_in_groups(disabled_groups)
return [
tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
tool.as_openai_tool()
for name, tool in TOOL_REGISTRY.items()
if tool.is_available and name not in hidden
]
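A usage sketch for the baseline path. The import path follows the `backend.copilot.tools` references above; the assertion assumes the standard OpenAI tool-param layout (`{"type": "function", "function": {...}}`) for the values `as_openai_tool()` returns:

```python
from backend.copilot.tools import get_available_tools, tool_names_in_groups

hidden = tool_names_in_groups(["graphiti"])
# frozenset({"memory_store", "memory_search", "memory_forget_search",
#            "memory_forget_confirm"})

# Hide the Graphiti-backed memory tools when the backend is off for this user,
# so the model never sees schemas it cannot successfully call.
tools = get_available_tools(disabled_groups=["graphiti"])
assert not any(t["function"]["name"] in hidden for t in tools)
```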

View File

@@ -181,7 +181,9 @@ async def execute_block(
# (e.g., "42" → 42, string booleans → bool, enum defaults applied).
coerce_inputs_to_schema(input_data, block.input_schema)
outputs: dict[str, list[Any]] = defaultdict(list)
async for output_name, output_data in simulate_block(block, input_data):
async for output_name, output_data in simulate_block(
block, input_data, user_id=user_id
):
outputs[output_name].append(output_data)
# simulator signals internal failure via ("error", "[SIMULATOR ERROR …]")
sim_error = outputs.get("error", [])

View File

@@ -91,6 +91,9 @@ class ResponseType(str, Enum):
MEMORY_FORGET_CANDIDATES = "memory_forget_candidates"
MEMORY_FORGET_CONFIRM = "memory_forget_confirm"
# Planning
TODO_WRITE = "todo_write"
# Base response model
class ToolResponseBase(BaseModel):
@@ -605,11 +608,13 @@ class WebSearchResponse(ToolResponseBase):
type: ResponseType = ResponseType.WEB_SEARCH
query: str
# Web-grounded synthesised answer the search provider wrote from
# fresh page content. The LLM caller should read this directly
# instead of re-fetching each citation URL — many sites are
# bot-protected and ``web_fetch`` won't get through. Empty string
# when the provider returned only citations.
answer: str = ""
results: list[WebSearchResult] = Field(default_factory=list)
# Backend-reported usage for this call (copied from Anthropic's
# ``usage.server_tool_use``). Surfaces as metadata for frontend
# debug panels but is also what drives rate-limit / cost tracking
# via ``persist_and_record_usage(provider="anthropic")``.
search_requests: int = 0
@@ -896,3 +901,36 @@ class MemoryForgetConfirmResponse(ToolResponseBase):
type: ResponseType = ResponseType.MEMORY_FORGET_CONFIRM
deleted_uuids: list[str] = Field(default_factory=list)
failed_uuids: list[str] = Field(default_factory=list)
# --- Planning ---
class TodoItem(BaseModel):
"""One entry in a ``TodoWrite`` checklist.
Mirrors the schema used by Claude Code's built-in ``TodoWrite`` tool so
the frontend's ``GenericTool`` accordion renders baseline-emitted todos
identically to SDK-emitted ones.
"""
content: str = Field(description="Imperative description of the task.")
activeForm: str = Field(
description="Present-continuous form shown while the task is running.",
)
status: Literal["pending", "in_progress", "completed"] = Field(
default="pending",
)
class TodoWriteResponse(ToolResponseBase):
"""Ack returned by ``TodoWrite``.
The tool is effectively stateless — the authoritative task list lives in
the assistant's latest tool-call arguments, which are replayed from the
transcript on each turn. The tool output only needs to confirm that the
update was accepted so the model can proceed.
"""
type: ResponseType = ResponseType.TODO_WRITE
todos: list[TodoItem] = Field(default_factory=list)
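
A minimal sketch of the payload shape these models validate; field names come straight from `TodoItem` / `TodoWriteResponse` above, and the session id is an arbitrary example value.

```python
from backend.copilot.tools.models import ResponseType, TodoItem, TodoWriteResponse

todos = [
    TodoItem(content="Run tests", activeForm="Running tests", status="in_progress"),
    TodoItem(content="Open PR", activeForm="Opening PR"),  # status defaults to "pending"
]
resp = TodoWriteResponse(message="Task list updated.", session_id="sess-123", todos=todos)
assert resp.type == ResponseType.TODO_WRITE
```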

View File

@@ -237,7 +237,7 @@ async def test_execute_block_dry_run_skips_real_execution():
mock_block = make_mock_block()
mock_block.execute = AsyncMock() # should NOT be called
async def fake_simulate(block, input_data):
async def fake_simulate(block, input_data, **_kwargs):
yield "result", "simulated"
# Patching at helpers.simulate_block works because helpers.py imports
@@ -267,7 +267,7 @@ async def test_execute_block_dry_run_response_format():
"""Dry-run response should look like a normal success (no dry-run signal to LLM)."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
async def fake_simulate(block, input_data, **_kwargs):
yield "result", "simulated"
with patch(
@@ -331,7 +331,7 @@ async def test_execute_block_real_execution_unchanged():
# Just verify simulate_block is NOT called.
simulate_called = False
async def fake_simulate(block, input_data):
async def fake_simulate(block, input_data, **_kwargs):
nonlocal simulate_called
simulate_called = True
yield "result", "should not happen"
@@ -455,7 +455,7 @@ async def test_execute_block_dry_run_no_empty_error_from_simulator():
"""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
async def fake_simulate(block, input_data, **_kwargs):
# Simulator now omits empty error pins at source
yield "result", "simulated output"
@@ -485,7 +485,7 @@ async def test_execute_block_dry_run_keeps_nonempty_error_pin():
"""Dry-run should keep the 'error' pin when it contains a real error message."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
async def fake_simulate(block, input_data, **_kwargs):
yield "result", ""
yield "error", "API rate limit exceeded"
@@ -515,7 +515,7 @@ async def test_execute_block_dry_run_message_includes_completed_status():
"""Dry-run message should clearly indicate COMPLETED status."""
mock_block = make_mock_block()
async def fake_simulate(block, input_data):
async def fake_simulate(block, input_data, **_kwargs):
yield "result", "simulated"
with patch(
@@ -541,7 +541,7 @@ async def test_execute_block_dry_run_simulator_error_returns_error_response():
"""When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
mock_block = make_mock_block()
async def fake_simulate_error(block, input_data):
async def fake_simulate_error(block, input_data, **_kwargs):
yield (
"error",
"[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key).",

View File

@@ -0,0 +1,120 @@
"""Task-list tool for baseline copilot mode.
Mirrors the schema and UX of Claude Code's built-in ``TodoWrite`` tool so
the frontend's generic tool renderer draws baseline-emitted checklists the
same way it draws SDK-emitted ones. The tool is stateless: the model's
latest ``todos`` argument IS the canonical list, replayed from transcript
on subsequent turns.
Baseline needs this as a platform tool because OpenAI-compatible providers
(Kimi, GPT, Grok, Gemini) do not ship a built-in equivalent. The SDK path
continues to use the CLI's native ``TodoWrite`` — the MCP-wrapped version
of this tool is filtered out of SDK's allowed_tools list (see
``sdk/tool_adapter.py``) to avoid name shadowing.
"""
from __future__ import annotations
import logging
from typing import Any
from backend.copilot.model import ChatSession
from .base import BaseTool
from .models import ErrorResponse, TodoItem, TodoWriteResponse, ToolResponseBase
logger = logging.getLogger(__name__)
class TodoWriteTool(BaseTool):
"""Maintain a step-by-step task checklist visible to the user."""
@property
def name(self) -> str:
# Capitalised to match the frontend's switch on ``"TodoWrite"``
# (see ``copilot/tools/GenericTool/helpers.ts``).
return "TodoWrite"
@property
def description(self) -> str:
return (
"Plan and track multi-step work as a visible checklist. Send "
"the full list every call; exactly one item in_progress at a time."
)
@property
def parameters(self) -> dict[str, Any]:
return {
"type": "object",
"properties": {
"todos": {
"type": "array",
"description": "Full updated task list (not a delta).",
"items": {
"type": "object",
"properties": {
"content": {
"type": "string",
"description": "Imperative (e.g. 'Run tests').",
},
"activeForm": {
"type": "string",
"description": (
"Present-continuous (e.g. 'Running tests')."
),
},
"status": {
"type": "string",
"enum": ["pending", "in_progress", "completed"],
"default": "pending",
},
},
"required": ["content", "activeForm"],
},
},
},
"required": ["todos"],
}
async def _execute(
self,
user_id: str | None,
session: ChatSession,
**kwargs: Any,
) -> ToolResponseBase:
del user_id
raw_todos = kwargs.get("todos")
if raw_todos is None:
return ErrorResponse(
message="`todos` is required.",
session_id=session.session_id,
)
if not isinstance(raw_todos, list):
return ErrorResponse(
message="`todos` must be an array.",
session_id=session.session_id,
)
try:
parsed = [TodoItem.model_validate(item) for item in raw_todos]
except Exception as exc:
return ErrorResponse(
message=f"Invalid todo entry: {exc}",
session_id=session.session_id,
)
in_progress = sum(1 for t in parsed if t.status == "in_progress")
if in_progress > 1:
return ErrorResponse(
message=(
"Only one todo may be 'in_progress' at a time "
f"(found {in_progress})."
),
session_id=session.session_id,
)
return TodoWriteResponse(
message="Task list updated.",
session_id=session.session_id,
todos=parsed,
)

View File

@@ -0,0 +1,125 @@
"""Tests for TodoWriteTool."""
import pytest
from backend.copilot.model import ChatSession
from backend.copilot.tools.models import ErrorResponse, TodoItem, TodoWriteResponse
from backend.copilot.tools.todo_write import TodoWriteTool
@pytest.fixture()
def tool() -> TodoWriteTool:
return TodoWriteTool()
@pytest.fixture()
def session() -> ChatSession:
return ChatSession.new(user_id="test-user", dry_run=False)
@pytest.mark.asyncio
async def test_valid_todo_list(tool: TodoWriteTool, session: ChatSession):
result = await tool._execute(
user_id=None,
session=session,
todos=[
{
"content": "Write tests",
"activeForm": "Writing tests",
"status": "pending",
},
{
"content": "Ship PR",
"activeForm": "Shipping PR",
"status": "in_progress",
},
],
)
assert isinstance(result, TodoWriteResponse)
assert result.session_id == session.session_id
assert len(result.todos) == 2
assert result.todos[0] == TodoItem(
content="Write tests",
activeForm="Writing tests",
status="pending",
)
assert result.todos[1].status == "in_progress"
@pytest.mark.asyncio
async def test_default_status_is_pending(tool: TodoWriteTool, session: ChatSession):
result = await tool._execute(
user_id=None,
session=session,
todos=[{"content": "Write tests", "activeForm": "Writing tests"}],
)
assert isinstance(result, TodoWriteResponse)
assert result.todos[0].status == "pending"
@pytest.mark.asyncio
async def test_missing_todos_returns_error(tool: TodoWriteTool, session: ChatSession):
result = await tool._execute(user_id=None, session=session)
assert isinstance(result, ErrorResponse)
assert "todos" in result.message.lower()
@pytest.mark.asyncio
async def test_non_list_todos_returns_error(tool: TodoWriteTool, session: ChatSession):
result = await tool._execute(user_id=None, session=session, todos="not a list")
assert isinstance(result, ErrorResponse)
@pytest.mark.asyncio
async def test_invalid_item_returns_error(tool: TodoWriteTool, session: ChatSession):
# Missing required `activeForm` field.
result = await tool._execute(
user_id=None,
session=session,
todos=[{"content": "Missing active form"}],
)
assert isinstance(result, ErrorResponse)
@pytest.mark.asyncio
async def test_multiple_in_progress_rejected(tool: TodoWriteTool, session: ChatSession):
"""Exactly one item should be in_progress at a time — SDK parity rule."""
result = await tool._execute(
user_id=None,
session=session,
todos=[
{
"content": "A",
"activeForm": "Doing A",
"status": "in_progress",
},
{
"content": "B",
"activeForm": "Doing B",
"status": "in_progress",
},
],
)
assert isinstance(result, ErrorResponse)
assert "in_progress" in result.message
def test_openai_schema_shape(tool: TodoWriteTool):
schema = tool.as_openai_tool()
assert schema["type"] == "function"
assert schema["function"]["name"] == "TodoWrite"
params = schema["function"]["parameters"]
assert params["required"] == ["todos"]
items = params["properties"]["todos"]["items"]
assert items["required"] == ["content", "activeForm"]
assert items["properties"]["status"]["enum"] == [
"pending",
"in_progress",
"completed",
]

View File

@@ -21,7 +21,9 @@ from backend.copilot.tools import TOOL_REGISTRY
# response shape carries) and the dry_run description. Keeps the
# regression gate effective while accepting a deliberate ~120-token
# spend on LLM-decision-critical copy.
_CHAR_BUDGET = 34_000
# Bumped to 35000 to accommodate the decompose_goal tool + web_search +
# TodoWrite tool descriptions.
_CHAR_BUDGET = 35_000
@pytest.fixture(scope="module")
@@ -111,9 +113,10 @@ def test_total_schema_char_budget() -> None:
This locks in the 34% token reduction from #12398 and prevents future
description bloat from eroding the gains. Uses character count with a
~4 chars/token heuristic (budget of 32000 chars ≈ 8000 tokens).
Character count is tokenizer-agnostic — no dependency on GPT or Claude
tokenizers — while still providing a stable regression gate.
~4 chars/token heuristic; see ``_CHAR_BUDGET`` above for the current
value and its change history. Character count is tokenizer-agnostic
— no dependency on GPT or Claude tokenizers — while still providing a
stable regression gate.
"""
schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()]
serialized = json.dumps(schemas)
@@ -123,3 +126,60 @@ def test_total_schema_char_budget() -> None:
f"exceeding budget of {_CHAR_BUDGET} chars (~{_CHAR_BUDGET // 4} tokens). "
f"Description bloat detected — trim descriptions or raise the budget intentionally."
)
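
When the gate trips locally, a quick way to see how close the registry sits to the budget is to mirror the test's own measurement; this is a sketch that reuses only names already imported in this test module.

```python
# Mirrors the test's measurement: serialize every registered tool schema and
# apply the ~4 chars/token heuristic against the 35000-char budget.
import json
from backend.copilot.tools import TOOL_REGISTRY

size = len(json.dumps([t.as_openai_tool() for t in TOOL_REGISTRY.values()]))
print(f"{size} chars (~{size // 4} tokens) of 35000 budget")
```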
# ── Capability-group filtering (ToolGroup / disabled_groups) ───────────
def test_get_available_tools_hides_graphiti_when_disabled() -> None:
"""When the ``graphiti`` group is disabled, the memory_* tools must
not appear in the OpenAI schema list — they'd just confuse the model
and produce opaque runtime errors."""
from backend.copilot.tools import get_available_tools
memory_tool_names = {
"memory_store",
"memory_search",
"memory_forget_search",
"memory_forget_confirm",
}
default = {t["function"]["name"] for t in get_available_tools()}
assert memory_tool_names.issubset(
default
), "sanity: memory_* tools should be present when no groups disabled"
filtered = {
t["function"]["name"] for t in get_available_tools(disabled_groups=["graphiti"])
}
assert not (
memory_tool_names & filtered
), f"graphiti disabled but memory_* still present: {memory_tool_names & filtered}"
# Non-graphiti tools stay visible.
assert "find_block" in filtered
assert "TodoWrite" in filtered
def test_get_copilot_tool_names_hides_graphiti_when_disabled() -> None:
"""Same invariant for the SDK tool-name list."""
from backend.copilot.sdk.tool_adapter import MCP_TOOL_PREFIX, get_copilot_tool_names
memory_mcp_names = {
f"{MCP_TOOL_PREFIX}memory_store",
f"{MCP_TOOL_PREFIX}memory_search",
f"{MCP_TOOL_PREFIX}memory_forget_search",
f"{MCP_TOOL_PREFIX}memory_forget_confirm",
}
default = set(get_copilot_tool_names())
assert memory_mcp_names.issubset(default)
filtered = set(get_copilot_tool_names(disabled_groups=["graphiti"]))
assert not (
memory_mcp_names & filtered
), f"graphiti disabled but memory MCP names still present: {memory_mcp_names & filtered}"
# E2B path stays consistent.
filtered_e2b = set(
get_copilot_tool_names(use_e2b=True, disabled_groups=["graphiti"])
)
assert not (memory_mcp_names & filtered_e2b)

View File

@@ -1,29 +1,66 @@
"""Web search tool — wraps Anthropic's server-side ``web_search`` beta.
"""Web search tool — Perplexity Sonar via OpenRouter.
Single entry point for web search on both SDK and baseline paths. The
``web_search_20250305`` tool is server-side on Anthropic, so we call
the Messages API directly regardless of which LLM invoked the copilot
tool — OpenRouter can't proxy server-side tool execution.
One provider, two tiers, one billing path:
* ``deep=False`` (default) — ``perplexity/sonar``. Searches the web
natively and returns citation annotations in a single inference pass.
* ``deep=True`` — ``perplexity/sonar-deep-research``. Multi-step
agentic research; slower and costlier.
Why Sonar and not the ``openrouter:web_search`` server tool + dispatch
model? The server tool feeds all search-result page content back into
the dispatch model for a second inference pass — one observed call was
74K input tokens at Gemini Flash rates, billing $0.072. Sonar
searches natively in one pass, returns annotations typed as
``ChatCompletionMessage.annotations`` in ``openai.types``, and at
$1 / MTok base pricing lands ~$0.01 / call at our default shape.
``resp.usage.cost`` carries the real billed value via OpenRouter's
``include: true`` extension; the value flows through
``persist_and_record_usage(provider='open_router')`` into the daily /
weekly microdollar rate-limit counter on the same rails as every other
OpenRouter turn — no separate provider ledger line, no estimation
drift. ``_extract_cost_usd`` mirrors the baseline service's
``_extract_usage_cost`` logic; keep the two in sync if one changes.
"""
import logging
import math
from typing import Any
from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.token_tracking import persist_and_record_usage
from backend.util.settings import Settings
from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, WebSearchResponse, WebSearchResult
logger = logging.getLogger(__name__)
_WEB_SEARCH_DISPATCH_MODEL = "claude-haiku-4-5"
_MAX_DISPATCH_TOKENS = 512
_chat_config = ChatConfig()
_QUICK_MODEL = "perplexity/sonar"
# Sonar base can emit up to ~4K output; cap at the provider ceiling so the
# model stops when the answer is complete rather than when our budget trips.
_QUICK_MAX_TOKENS = 4096
_DEEP_MODEL = "perplexity/sonar-deep-research"
# Deep runs can produce long structured writeups — ~4x the quick ceiling
# is enough headroom for multi-source comparisons without uncapping.
_DEEP_MAX_TOKENS = _QUICK_MAX_TOKENS * 4
_DEFAULT_MAX_RESULTS = 5
_HARD_MAX_RESULTS = 20
_SNIPPET_MAX_CHARS = 500
# OpenRouter-specific extra_body flag that embeds the real generation
# cost into the response usage object. Same dict shape the baseline
# service uses — keep the two aligned.
_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}}
class WebSearchTool(BaseTool):
@@ -36,9 +73,13 @@ class WebSearchTool(BaseTool):
@property
def description(self) -> str:
return (
"Search the web for live info (news, recent docs). Returns "
"{title, url, snippet}; use web_fetch to deep-dive a URL. "
"Prefer one targeted query over many reformulations."
"Search the web for live info (news, recent docs). Returns a "
"synthesised answer grounded in fresh page content plus "
"{title, url, snippet} citations — read the answer first "
"before reaching for web_fetch. Set deep=true when the user "
"asks for research / comparison / in-depth analysis; leave "
"deep=false for quick fact lookups. Prefer one targeted "
"query over many reformulations."
)
@property
@@ -58,6 +99,18 @@ class WebSearchTool(BaseTool):
),
"default": _DEFAULT_MAX_RESULTS,
},
"deep": {
"type": "boolean",
"description": (
"Only set true when the user EXPLICITLY asks for "
"research, comparison, or in-depth investigation "
"across many sources — it is ~100x more expensive "
"and much slower than a normal search. Default "
"false; do not flip it for ordinary fact lookups "
"or fresh-news questions."
),
"default": False,
},
},
"required": ["query"],
}
@@ -68,7 +121,7 @@ class WebSearchTool(BaseTool):
@property
def is_available(self) -> bool:
return bool(Settings().secrets.anthropic_api_key)
return bool(_chat_config.api_key and _chat_config.base_url)
async def _execute(
self,
@@ -76,6 +129,7 @@ class WebSearchTool(BaseTool):
session: ChatSession,
query: str = "",
max_results: int = _DEFAULT_MAX_RESULTS,
deep: bool = False,
**kwargs: Any,
) -> ToolResponseBase:
query = (query or "").strip()
@@ -93,44 +147,35 @@ class WebSearchTool(BaseTool):
max_results = _DEFAULT_MAX_RESULTS
max_results = max(1, min(max_results, _HARD_MAX_RESULTS))
api_key = Settings().secrets.anthropic_api_key
if not api_key:
if not _chat_config.api_key or not _chat_config.base_url:
return ErrorResponse(
message=(
"Web search is unavailable — the deployment has no "
"Anthropic API key configured."
"OpenRouter credentials configured."
),
error="web_search_not_configured",
session_id=session_id,
)
client = AsyncAnthropic(api_key=api_key)
client = AsyncOpenAI(
api_key=_chat_config.api_key, base_url=_chat_config.base_url
)
model_used = _DEEP_MODEL if deep else _QUICK_MODEL
max_tokens = _DEEP_MAX_TOKENS if deep else _QUICK_MAX_TOKENS
try:
resp = await client.messages.create(
model=_WEB_SEARCH_DISPATCH_MODEL,
max_tokens=_MAX_DISPATCH_TOKENS,
tools=[
{
"type": "web_search_20250305",
"name": "web_search",
"max_uses": 1,
}
],
messages=[
{
"role": "user",
"content": (
f"Use the web_search tool exactly once with the "
f"query {query!r} and then stop. Do not "
f"summarise — the caller parses the raw "
f"tool_result."
),
}
],
resp = await client.chat.completions.create(
model=model_used,
max_tokens=max_tokens,
messages=[{"role": "user", "content": query}],
extra_body=_OPENROUTER_INCLUDE_USAGE_COST,
)
except Exception as exc:
logger.warning(
"[web_search] Anthropic call failed for query=%r: %s", query, exc
"[web_search] OpenRouter call failed (deep=%s) for query=%r: %s",
deep,
query,
exc,
)
return ErrorResponse(
message=f"Web search failed: {exc}",
@@ -138,20 +183,20 @@ class WebSearchTool(BaseTool):
session_id=session_id,
)
results, search_requests = _extract_results(resp, limit=max_results)
answer = _extract_answer(resp)
results = _extract_results(resp, limit=max_results)
cost_usd = _extract_cost_usd(resp.usage)
cost_usd = _estimate_cost_usd(resp, search_requests=search_requests)
try:
usage = getattr(resp, "usage", None)
await persist_and_record_usage(
session=session,
user_id=user_id,
prompt_tokens=getattr(usage, "input_tokens", 0) or 0,
completion_tokens=getattr(usage, "output_tokens", 0) or 0,
prompt_tokens=resp.usage.prompt_tokens if resp.usage else 0,
completion_tokens=resp.usage.completion_tokens if resp.usage else 0,
log_prefix="[web_search]",
cost_usd=cost_usd,
model=_WEB_SEARCH_DISPATCH_MODEL,
provider="anthropic",
model=model_used,
provider="open_router",
)
except Exception as exc:
logger.warning("[web_search] usage tracking failed: %s", exc)
@@ -159,66 +204,92 @@ class WebSearchTool(BaseTool):
return WebSearchResponse(
message=f"Found {len(results)} result(s) for {query!r}.",
query=query,
answer=answer,
results=results,
search_requests=search_requests,
search_requests=1 if results else 0,
session_id=session_id,
)
def _extract_results(resp: Any, *, limit: int) -> tuple[list[WebSearchResult], int]:
"""Pull results + server-side request count from an Anthropic response."""
results: list[WebSearchResult] = []
search_requests = 0
def _extract_answer(resp: ChatCompletion) -> str:
"""Return the synthesised answer text from Sonar's response.
for block in getattr(resp, "content", []) or []:
btype = getattr(block, "type", None)
if btype == "web_search_tool_result":
content = getattr(block, "content", []) or []
for item in content:
if getattr(item, "type", None) != "web_search_result":
continue
if len(results) >= limit:
break
# Anthropic's ``web_search_result`` exposes only
# ``title``/``url``/``page_age`` plus an opaque
# ``encrypted_content`` blob that is meant for citation
# round-tripping, not for display — it is base64-ish
# binary and would show as gibberish if surfaced to the
# model or the frontend. There is no plain-text snippet
# field in the current beta; callers get the readable
# text via the model's ``text`` blocks with citations,
# not via this list. Leave ``snippet`` empty.
results.append(
WebSearchResult(
title=getattr(item, "title", "") or "",
url=getattr(item, "url", "") or "",
snippet="",
page_age=getattr(item, "page_age", None),
)
)
usage = getattr(resp, "usage", None)
server_tool_use = getattr(usage, "server_tool_use", None) if usage else None
if server_tool_use is not None:
search_requests = getattr(server_tool_use, "web_search_requests", 0) or 0
return results, search_requests
Sonar reads every page it cites and writes a web-grounded synthesis
into ``choices[0].message.content`` on the same call we pay for.
Surfacing it saves the agent from re-fetching citation URLs — many
are bot-protected and ``web_fetch`` can't reach them.
"""
if not resp.choices:
return ""
content = resp.choices[0].message.content
return content or ""
# Update when Anthropic revises pricing.
_COST_PER_SEARCH_USD = 0.010 # $10 per 1,000 web_search requests
_HAIKU_INPUT_USD_PER_MTOK = 1.0
_HAIKU_OUTPUT_USD_PER_MTOK = 5.0
def _extract_results(resp: ChatCompletion, *, limit: int) -> list[WebSearchResult]:
"""Pull ``url_citation`` annotations from the response.
Shared across both tiers — OpenRouter normalises the annotation
schema across Perplexity's sonar models into
``Annotation.url_citation`` (typed in ``openai.types.chat``). The
``content`` snippet is an OpenRouter extension on the otherwise-
typed ``AnnotationURLCitation``; pydantic stashes unknown fields in
``model_extra``, which we read there rather than via ``getattr``.
"""
if not resp.choices:
return []
annotations = resp.choices[0].message.annotations or []
out: list[WebSearchResult] = []
for ann in annotations:
if len(out) >= limit:
break
if ann.type != "url_citation":
continue
citation = ann.url_citation
extras = citation.model_extra or {}
snippet_raw = extras.get("content")
snippet = (snippet_raw or "")[:_SNIPPET_MAX_CHARS] if snippet_raw else ""
out.append(
WebSearchResult(
title=citation.title,
url=citation.url,
snippet=snippet,
page_age=None,
)
)
return out
def _estimate_cost_usd(resp: Any, *, search_requests: int) -> float:
"""Per-search fee × count + Haiku dispatch tokens."""
usage = getattr(resp, "usage", None)
input_tokens = getattr(usage, "input_tokens", 0) if usage else 0
output_tokens = getattr(usage, "output_tokens", 0) if usage else 0
def _extract_cost_usd(usage: CompletionUsage | None) -> float | None:
"""Return the provider-reported USD cost off the response usage.
search_cost = search_requests * _COST_PER_SEARCH_USD
inference_cost = (input_tokens / 1_000_000) * _HAIKU_INPUT_USD_PER_MTOK + (
output_tokens / 1_000_000
) * _HAIKU_OUTPUT_USD_PER_MTOK
return round(search_cost + inference_cost, 6)
OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible
usage object when the request body includes
``usage: {"include": True}``. The OpenAI SDK's typed
``CompletionUsage`` does not declare it, so we read it off
``model_extra`` (the pydantic v2 container for extras) to keep
access fully typed — no ``getattr``. Mirrors the baseline service
``_extract_usage_cost``; keep the two in sync.
Returns ``None`` when the field is absent, null, non-numeric,
non-finite, or negative. Invalid values log at error level because
they indicate a provider bug worth chasing; plain absences are
silent so the caller can dedupe the "missing cost" warning.
"""
if usage is None:
return None
extras = usage.model_extra or {}
if "cost" not in extras:
return None
raw = extras["cost"]
if raw is None:
logger.error("[web_search] usage.cost is present but null")
return None
try:
val = float(raw)
except (TypeError, ValueError):
logger.error("[web_search] usage.cost is not numeric: %r", raw)
return None
if not math.isfinite(val) or val < 0:
logger.error("[web_search] usage.cost is non-finite or negative: %r", val)
return None
return val
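
For reference, a minimal sketch of the billing path the module docstring describes: ask OpenRouter to embed the real generation cost, then read it off the typed usage object's `model_extra`. The env var name and query are illustrative; in production the tool builds the client from `_chat_config` rather than env vars.

```python
# Illustrative sketch, not the tool itself. OPENROUTER_API_KEY is an assumed
# env var; the extra_body flag and model_extra["cost"] access mirror the code above.
import os
from openai import AsyncOpenAI

async def sonar_cost_demo() -> float | None:
    client = AsyncOpenAI(
        api_key=os.environ["OPENROUTER_API_KEY"],
        base_url="https://openrouter.ai/api/v1",
    )
    resp = await client.chat.completions.create(
        model="perplexity/sonar",
        max_tokens=256,
        messages=[{"role": "user", "content": "latest stable Python release?"}],
        extra_body={"usage": {"include": True}},  # OpenRouter cost extension
    )
    extras = (resp.usage.model_extra or {}) if resp.usage else {}
    return float(extras["cost"]) if "cost" in extras else None
```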

View File

@@ -1,212 +1,289 @@
"""Tests for the ``web_search`` copilot tool.
Covers the result extractor + cost estimator as pure units (fed with
synthetic Anthropic response objects), plus light integration tests that
mock ``AsyncAnthropic.messages.create`` and confirm the handler plumbs
through to ``persist_and_record_usage`` with the right provider tag.
Covers the annotation extractor + cost extractor as pure units (fed
with real ``openai`` SDK types — no duck-typed ``SimpleNamespace``
stand-ins), plus integration tests exercising both the quick
(``perplexity/sonar``) and deep (``perplexity/sonar-deep-research``)
paths — mocking ``AsyncOpenAI.chat.completions.create`` and confirming
the handler plumbs through to ``persist_and_record_usage`` with
``provider='open_router'`` and the real ``usage.cost`` value.
"""
from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import (
Annotation,
AnnotationURLCitation,
ChatCompletionMessage,
)
from backend.copilot.model import ChatSession
from .models import ErrorResponse, WebSearchResponse, WebSearchResult
from .models import ErrorResponse, WebSearchResponse
from .web_search import (
_COST_PER_SEARCH_USD,
WebSearchTool,
_estimate_cost_usd,
_extract_answer,
_extract_cost_usd,
_extract_results,
)
def _fake_anthropic_response(
def _usage(
*,
results: list[dict] | None = None,
search_requests: int = 1,
input_tokens: int = 120,
output_tokens: int = 40,
) -> SimpleNamespace:
"""Build a synthetic Anthropic Messages response.
prompt_tokens: int = 120,
completion_tokens: int = 40,
cost: object = 0.01,
) -> CompletionUsage:
"""Typed ``CompletionUsage`` with OpenRouter's ``cost`` extension
parked in ``model_extra`` — the same channel the production code
reads it from. ``model_construct`` preserves unknown fields;
``model_validate`` would drop them because ``CompletionUsage``
treats the schema as strict."""
payload: dict[str, Any] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
if cost is not None:
payload["cost"] = cost
return CompletionUsage.model_construct(None, **payload)
Matches the shape produced by ``client.messages.create`` when the
response includes a ``web_search_tool_result`` content block and
``usage.server_tool_use.web_search_requests`` on the turn meter.
"""
content = []
if results is not None:
content.append(
SimpleNamespace(
type="web_search_tool_result",
content=[
SimpleNamespace(
type="web_search_result",
title=r.get("title", "untitled"),
url=r.get("url", ""),
encrypted_content=r.get("snippet", ""),
page_age=r.get("page_age"),
)
for r in results
],
)
def _citation(*, url: str, title: str, content: str | None = None) -> Annotation:
"""Typed ``Annotation`` for a URL citation. ``content`` is an
OpenRouter extension on the otherwise-typed schema — goes into
``url_citation.model_extra`` when model_construct preserves it."""
payload: dict[str, Any] = {
"url": url,
"title": title,
"start_index": 0,
"end_index": len(title),
}
if content is not None:
payload["content"] = content
url_citation = AnnotationURLCitation.model_construct(None, **payload)
return Annotation(type="url_citation", url_citation=url_citation)
def _fake_response(
*,
citations: list[dict] | None = None,
answer: str = "ok",
prompt_tokens: int = 120,
completion_tokens: int = 40,
cost: object = 0.01,
) -> ChatCompletion:
"""Build a typed ``ChatCompletion`` shaped like an OpenRouter
response — typed end-to-end so the production code's attribute
access runs under the real SDK types in tests."""
annotations = [
_citation(
url=c.get("url", ""),
title=c.get("title", "untitled"),
content=c.get("content"),
)
usage = SimpleNamespace(
input_tokens=input_tokens,
output_tokens=output_tokens,
server_tool_use=SimpleNamespace(web_search_requests=search_requests),
for c in citations or []
]
message = ChatCompletionMessage.model_construct(
None,
role="assistant",
content=answer,
annotations=annotations,
)
choice = Choice.model_construct(
None,
index=0,
finish_reason="stop",
message=message,
)
return ChatCompletion.model_construct(
None,
id="cmpl-test",
object="chat.completion",
created=0,
model="perplexity/sonar",
choices=[choice],
usage=_usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
cost=cost,
),
)
return SimpleNamespace(content=content, usage=usage)
class TestExtractResults:
"""The extractor is the only Anthropic-response-shape contact point;
pin its behaviour so an API shape change surfaces here first."""
"""Pin the annotation shape — a schema bump in the OpenAI SDK or
OpenRouter surfaces here first. Same extractor serves both tiers
because OpenRouter normalises annotations across models."""
def test_extracts_title_url_page_age_and_drops_encrypted_snippet(self):
# Anthropic's ``web_search_result`` ships an opaque
# ``encrypted_content`` blob that is not safe to surface —
# the extractor must drop it (snippet=="") regardless of
# whether the blob is non-empty.
resp = _fake_anthropic_response(
results=[
def test_extracts_title_url_and_content_snippet(self):
resp = _fake_response(
citations=[
{
"title": "Kimi K2.6 launch",
"url": "https://example.com/kimi",
"snippet": "EiJjbGF1ZGUtZW5jcnlwdGVkLWJsb2I=",
"page_age": "1 day",
"content": "Moonshot released K2.6 on 2026-04-20.",
},
{
"title": "OpenRouter pricing",
"url": "https://openrouter.ai/moonshotai/kimi-k2.6",
"snippet": "",
},
]
)
out, requests = _extract_results(resp, limit=10)
assert requests == 1
out = _extract_results(resp, limit=10)
assert len(out) == 2
assert out[0].title == "Kimi K2.6 launch"
assert out[0].url == "https://example.com/kimi"
assert out[0].snippet == ""
assert out[0].page_age == "1 day"
assert out[0].snippet.startswith("Moonshot released")
# Missing ``content`` extension → empty snippet rather than crash.
assert out[1].snippet == ""
def test_limit_caps_returned_results(self):
resp = _fake_anthropic_response(
results=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)]
resp = _fake_response(
citations=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)]
)
out, _ = _extract_results(resp, limit=3)
out = _extract_results(resp, limit=3)
assert len(out) == 3
assert [r.title for r in out] == ["r0", "r1", "r2"]
def test_missing_content_returns_empty(self):
resp = SimpleNamespace(content=[], usage=None)
out, requests = _extract_results(resp, limit=10)
assert out == []
assert requests == 0
def test_non_search_blocks_are_ignored(self):
resp = SimpleNamespace(
content=[
SimpleNamespace(type="text", text="Here's what I found..."),
SimpleNamespace(
type="web_search_tool_result",
content=[
SimpleNamespace(
type="web_search_result",
title="real",
url="https://real.example",
encrypted_content="body",
page_age=None,
)
],
),
],
usage=None,
def test_missing_choices_returns_empty(self):
resp = ChatCompletion.model_construct(
None,
id="cmpl-test",
object="chat.completion",
created=0,
model="perplexity/sonar",
choices=[],
usage=_usage(),
)
out, _ = _extract_results(resp, limit=10)
assert len(out) == 1 and out[0].title == "real"
assert _extract_results(resp, limit=10) == []
class TestEstimateCostUsd:
"""Pin the per-search fee + Haiku inference math — the pricing
constants in ``web_search.py`` are hard-coded (no live lookup) so a
drift between Anthropic's schedule and our constants must surface
in this test for the next reader to notice."""
def test_zero_searches_still_charges_inference(self):
resp = _fake_anthropic_response(results=[], search_requests=0)
cost = _estimate_cost_usd(resp, search_requests=0)
# Haiku at 1000 input / 5000 output tokens = tiny but non-zero.
assert 0 < cost < 0.001
def test_single_search_fee_dominates(self):
resp = _fake_anthropic_response(
results=[{"title": "x", "url": "https://e"}],
search_requests=1,
input_tokens=100,
output_tokens=20,
def test_extract_answer_returns_message_content(self):
resp = _fake_response(
answer="Sonar's synthesised, web-grounded answer text.",
citations=[{"title": "t", "url": "https://e"}],
)
cost = _estimate_cost_usd(resp, search_requests=1)
# ~$0.010 search + trivial inference — total still ~1 cent.
assert cost >= _COST_PER_SEARCH_USD
assert cost < _COST_PER_SEARCH_USD + 0.001
assert _extract_answer(resp) == "Sonar's synthesised, web-grounded answer text."
def test_three_searches_linear_in_count(self):
resp = _fake_anthropic_response(
results=[], search_requests=3, input_tokens=0, output_tokens=0
def test_extract_answer_returns_empty_when_no_choices(self):
resp = ChatCompletion.model_construct(
None,
id="cmpl-test",
object="chat.completion",
created=0,
model="perplexity/sonar",
choices=[],
usage=_usage(),
)
cost = _estimate_cost_usd(resp, search_requests=3)
assert cost == pytest.approx(3 * _COST_PER_SEARCH_USD)
assert _extract_answer(resp) == ""
def test_snippet_clamped_to_max_chars(self):
long_body = "x" * 5000
resp = _fake_response(
citations=[{"title": "t", "url": "https://e", "content": long_body}]
)
out = _extract_results(resp, limit=1)
assert len(out) == 1
assert len(out[0].snippet) == 500
class TestExtractCostUsd:
"""Read real ``usage.cost`` via typed ``model_extra`` — no
hard-coded rates, so a future provider price change is reflected
automatically. Error handling mirrors the baseline service's
``_extract_usage_cost``."""
def test_returns_cost_value(self):
assert _extract_cost_usd(_usage(cost=0.023456)) == pytest.approx(0.023456)
def test_returns_none_when_usage_missing(self):
assert _extract_cost_usd(None) is None
def test_returns_none_when_cost_field_missing(self):
assert _extract_cost_usd(_usage(cost=None)) is None
def test_returns_none_when_cost_is_explicit_null(self):
usage = CompletionUsage.model_construct(
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None
)
assert _extract_cost_usd(usage) is None
def test_returns_none_when_cost_is_negative(self):
usage = CompletionUsage.model_construct(
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-1.0
)
assert _extract_cost_usd(usage) is None
def test_accepts_numeric_string(self):
usage = CompletionUsage.model_construct(
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017"
)
assert _extract_cost_usd(usage) == pytest.approx(0.017)
class TestWebSearchToolDispatch:
"""Lightweight integration test: mock the Anthropic client, confirm
the handler returns a ``WebSearchResponse`` and the usage tracker is
called with ``provider='anthropic'`` (not 'open_router', even on the
baseline path — server-side web_search bills Anthropic regardless of
the calling LLM's route)."""
"""Integration test: mock the OpenAI client, confirm both paths
dispatch the right Sonar model + track cost."""
def _session(self) -> ChatSession:
s = ChatSession.new("test-user", dry_run=False)
s.session_id = "sess-1"
return s
@pytest.mark.asyncio
async def test_returns_response_with_results_and_tracks_cost(self, monkeypatch):
fake_resp = _fake_anthropic_response(
results=[
{
"title": "hello",
"url": "https://example.com",
"snippet": "greeting",
}
],
search_requests=1,
)
mock_client = type(
def _mock_client(self, fake_resp: ChatCompletion) -> Any:
return type(
"MC",
(),
{
"messages": type(
"M", (), {"create": AsyncMock(return_value=fake_resp)}
"chat": type(
"C",
(),
{
"completions": type(
"CC",
(),
{"create": AsyncMock(return_value=fake_resp)},
)()
},
)()
},
)()
# Stub the Anthropic API key so ``is_available`` is True.
@pytest.mark.asyncio
async def test_quick_path_uses_sonar_base(self, monkeypatch):
fake_resp = _fake_response(
citations=[
{
"title": "hello",
"url": "https://example.com",
"content": "greeting",
}
],
answer="Kimi K2.6 launched 2026-04-20 [1].",
cost=0.01,
)
mock_client = self._mock_client(fake_resp)
monkeypatch.setattr(
"backend.copilot.tools.web_search.Settings",
lambda: SimpleNamespace(
secrets=SimpleNamespace(anthropic_api_key="sk-test")
),
"backend.copilot.tools.web_search._chat_config",
type(
"C",
(),
{
"api_key": "sk-test",
"base_url": "https://openrouter.ai/api/v1",
},
)(),
)
with (
patch(
"backend.copilot.tools.web_search.AsyncAnthropic",
"backend.copilot.tools.web_search.AsyncOpenAI",
return_value=mock_client,
),
patch(
@@ -220,35 +297,88 @@ class TestWebSearchToolDispatch:
session=self._session(),
query="kimi k2.6 launch",
max_results=5,
deep=False,
)
assert isinstance(result, WebSearchResponse)
assert result.query == "kimi k2.6 launch"
assert result.answer == "Kimi K2.6 launched 2026-04-20 [1]."
assert len(result.results) == 1
assert isinstance(result.results[0], WebSearchResult)
assert result.search_requests == 1
assert result.results[0].snippet == "greeting"
create_call = mock_client.chat.completions.create.call_args
assert create_call.kwargs["model"] == "perplexity/sonar"
# Sonar searches natively — no server-tool extras.
assert create_call.kwargs["extra_body"] == {"usage": {"include": True}}
# Cost tracker must have been called with provider="anthropic".
assert mock_track.await_count == 1
kwargs = mock_track.await_args.kwargs
assert kwargs["provider"] == "anthropic"
assert kwargs["model"] == "claude-haiku-4-5"
assert kwargs["user_id"] == "u1"
assert kwargs["cost_usd"] >= _COST_PER_SEARCH_USD
assert kwargs["provider"] == "open_router"
assert kwargs["model"] == "perplexity/sonar"
assert kwargs["cost_usd"] == pytest.approx(0.01)
@pytest.mark.asyncio
async def test_missing_api_key_returns_error_without_calling_anthropic(
self, monkeypatch
):
monkeypatch.setattr(
"backend.copilot.tools.web_search.Settings",
lambda: SimpleNamespace(secrets=SimpleNamespace(anthropic_api_key="")),
async def test_deep_path_uses_sonar_deep_research(self, monkeypatch):
fake_resp = _fake_response(
citations=[
{
"title": "deep find",
"url": "https://example.com/deep",
"content": "research body",
}
],
cost=0.087,
)
anthropic_stub = AsyncMock()
mock_client = self._mock_client(fake_resp)
monkeypatch.setattr(
"backend.copilot.tools.web_search._chat_config",
type(
"C",
(),
{
"api_key": "sk-test",
"base_url": "https://openrouter.ai/api/v1",
},
)(),
)
with (
patch(
"backend.copilot.tools.web_search.AsyncAnthropic",
return_value=anthropic_stub,
"backend.copilot.tools.web_search.AsyncOpenAI",
return_value=mock_client,
),
patch(
"backend.copilot.tools.web_search.persist_and_record_usage",
new=AsyncMock(return_value=160),
) as mock_track,
):
tool = WebSearchTool()
result = await tool._execute(
user_id="u1",
session=self._session(),
query="research question",
deep=True,
)
assert isinstance(result, WebSearchResponse)
create_call = mock_client.chat.completions.create.call_args
assert create_call.kwargs["model"] == "perplexity/sonar-deep-research"
kwargs = mock_track.await_args.kwargs
assert kwargs["provider"] == "open_router"
assert kwargs["model"] == "perplexity/sonar-deep-research"
assert kwargs["cost_usd"] == pytest.approx(0.087)
@pytest.mark.asyncio
async def test_missing_credentials_returns_error(self, monkeypatch):
monkeypatch.setattr(
"backend.copilot.tools.web_search._chat_config",
type("C", (), {"api_key": "", "base_url": ""})(),
)
openai_stub = AsyncMock()
with (
patch(
"backend.copilot.tools.web_search.AsyncOpenAI",
return_value=openai_stub,
),
patch(
"backend.copilot.tools.web_search.persist_and_record_usage",
@@ -264,21 +394,26 @@ class TestWebSearchToolDispatch:
)
assert isinstance(result, ErrorResponse)
assert result.error == "web_search_not_configured"
anthropic_stub.messages.create.assert_not_called()
openai_stub.chat.completions.create.assert_not_called()
mock_track.assert_not_called()
@pytest.mark.asyncio
async def test_empty_query_rejected_without_api_call(self, monkeypatch):
monkeypatch.setattr(
"backend.copilot.tools.web_search.Settings",
lambda: SimpleNamespace(
secrets=SimpleNamespace(anthropic_api_key="sk-test")
),
"backend.copilot.tools.web_search._chat_config",
type(
"C",
(),
{
"api_key": "sk-test",
"base_url": "https://openrouter.ai/api/v1",
},
)(),
)
anthropic_stub = AsyncMock()
openai_stub = AsyncMock()
with patch(
"backend.copilot.tools.web_search.AsyncAnthropic",
return_value=anthropic_stub,
"backend.copilot.tools.web_search.AsyncOpenAI",
return_value=openai_stub,
):
tool = WebSearchTool()
result = await tool._execute(
@@ -286,13 +421,13 @@ class TestWebSearchToolDispatch:
)
assert isinstance(result, ErrorResponse)
assert result.error == "missing_query"
anthropic_stub.messages.create.assert_not_called()
openai_stub.chat.completions.create.assert_not_called()
class TestToolRegistryIntegration:
"""The tool must be registered under the ``web_search`` name so the
MCP layer exposes it as ``mcp__copilot__web_search`` — which is
what the SDK path now dispatches to (see
what the SDK path dispatches to (see
``sdk/tool_adapter.py::SDK_DISALLOWED_TOOLS`` which blocks the CLI's
native ``WebSearch`` in favour of the MCP route)."""

View File

@@ -195,16 +195,17 @@ def strip_stale_thinking_blocks(content: str) -> str:
is_last_turn = (
last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id
) or (last_asst_msg_id is None and i == last_asst_idx)
if (
msg.get("role") == "assistant"
and not is_last_turn
and isinstance(msg.get("content"), list)
):
if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
content_blocks = msg["content"]
producing_model = msg.get("model") if isinstance(msg, dict) else None
filtered = [
b
for b in content_blocks
if not (isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES)
if not _should_strip_thinking_block(
b,
is_last_turn=is_last_turn,
producing_model=producing_model,
)
]
if len(filtered) < len(content_blocks):
stripped_count += len(content_blocks) - len(filtered)
@@ -310,23 +311,30 @@ def strip_for_upload(content: str) -> str:
if uid in reparented:
needs_reserialize = True
# Strip stale thinking blocks from non-last assistant entries
# Strip stale thinking blocks from non-last assistant entries.
# Also strip *signature-less* thinking blocks from the last entry —
# those come from non-Anthropic providers (e.g. Kimi K2.6 via
# OpenRouter) and are rejected with ``Invalid `signature` in
# `thinking` block`` if a subsequent turn is dispatched to an
# Anthropic model that re-validates them. Anthropic-emitted
# thinking blocks always carry a non-empty ``signature`` field, so
# this filter is a no-op on Sonnet/Opus turns and only kicks in
# when the prior turn ran on a non-Anthropic vendor.
if last_asst_idx is not None:
msg = entry.get("message", {})
is_last_turn = (
last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id
) or (last_asst_msg_id is None and i == last_asst_idx)
if (
msg.get("role") == "assistant"
and not is_last_turn
and isinstance(msg.get("content"), list)
):
if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
content_blocks = msg["content"]
producing_model = msg.get("model") if isinstance(msg, dict) else None
filtered = [
b
for b in content_blocks
if not (
isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES
if not _should_strip_thinking_block(
b,
is_last_turn=is_last_turn,
producing_model=producing_model,
)
]
if len(filtered) < len(content_blocks):
@@ -951,6 +959,92 @@ ENTRY_TYPE_MESSAGE = "message"
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
def _is_anthropic_model(model: str | None) -> bool:
"""True when *model* is an Anthropic-issued slug.
Used to decide whether a thinking block's signature is
cryptographically valid for Anthropic replay. Non-Anthropic vendors
routed through OpenRouter's Anthropic-compat shim (Kimi K2.6,
DeepSeek, GPT-OSS) sometimes emit thinking blocks with a
placeholder signature — it passes a non-empty string check but
fails Anthropic's cryptographic validation, producing the opaque
``Invalid signature in thinking block`` 400 on the next turn
whenever the model toggle switches to Sonnet/Opus.
"""
return isinstance(model, str) and model.startswith("anthropic/")
def _should_strip_thinking_block(
block: object,
*,
is_last_turn: bool,
producing_model: str | None = None,
) -> bool:
"""Return True when *block* is a thinking block that should be removed
from a transcript entry before upload.
Strip only when the block CAN'T be replayed safely. Never strip a
valid Anthropic-issued thinking block — it carries real reasoning
state that preserves context continuity on ``--resume``.
Strip rules (first match wins):
1. **Non-Anthropic producer (any position)** — thinking blocks from
Kimi / DeepSeek / GPT-OSS via OpenRouter's Anthropic-compat shim
carry either no signature or a placeholder string that passes a
non-empty check but fails Anthropic's cryptographic validation.
Strip unconditionally; they also add low-value tokens to the
replay context.
2. **Malformed ``thinking`` (any position, Anthropic producer,
empty signature)** — shouldn't happen in practice, but if the
signature is missing / empty the block can't be validated.
Safer to drop than to 400 the next turn.
3. **Stale non-last entry with unknown producer** — when the
caller doesn't wire ``producing_model`` through (legacy paths /
older tests) we can't tell if the block is safe to keep; fall
back to the old behaviour of dropping non-last thinking blocks
to avoid replaying an unverifiable block to Anthropic.
Preserved:
* Anthropic ``thinking`` with non-empty signature — at any
position, last OR non-last. Keeping prior-turn reasoning
chains helps continuity on multi-round SDK resumes without any
risk of signature rejection.
* Anthropic ``redacted_thinking`` — carries an encrypted ``data``
payload instead of a ``signature``; by design signature-less,
but Anthropic-issued and safely replayable.
"""
if not isinstance(block, dict):
return False
btype = block.get("type")
if btype not in _THINKING_BLOCK_TYPES:
return False
# Legacy call sites pass producing_model=None — preserve the old
# "strip-all-non-last-thinking" heuristic for those so we don't
# regress callers that haven't been updated.
if producing_model is None:
if not is_last_turn:
return True
if btype != "thinking":
return False
signature = block.get("signature")
return not (isinstance(signature, str) and signature)
# Non-Anthropic producer — strip at any position. These blocks
# CAN'T be cryptographically validated by Anthropic on replay.
if not _is_anthropic_model(producing_model):
return True
# Anthropic producer, redacted_thinking: always preserve — the
# ``data`` field is the signature analog.
if btype == "redacted_thinking":
return False
# Anthropic producer, ``thinking``: keep iff it has a real
# (non-empty) signature. Empty-signature Anthropic thinking
# shouldn't happen but guard against it anyway.
signature = block.get("signature")
return not (isinstance(signature, str) and signature)
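
A quick spot-check of the strip rules above, written as a minimal sketch: the block dicts are hand-written, and the import path is a guess, since this diff does not show which module defines `_should_strip_thinking_block`.

```python
# The import path is an assumption; only the function name and its keyword
# arguments come from this diff.
from backend.copilot.sdk.transcript import _should_strip_thinking_block  # hypothetical path

kimi = {"type": "thinking", "thinking": "r", "signature": "PLACEHOLDER_SHIM_SIG"}
opus = {"type": "thinking", "thinking": "r", "signature": "REAL_ANTHROPIC_SIG"}
redacted = {"type": "redacted_thinking", "data": "encrypted-blob"}

# Rule 1: non-Anthropic producer is stripped even on the last turn.
assert _should_strip_thinking_block(kimi, is_last_turn=True, producing_model="moonshotai/kimi-k2.6")
# Anthropic producer with a real signature: kept, even on non-last turns.
assert not _should_strip_thinking_block(opus, is_last_turn=False, producing_model="anthropic/claude-4.7-opus")
# Anthropic redacted_thinking: signature-less by design, still kept.
assert not _should_strip_thinking_block(redacted, is_last_turn=False, producing_model="anthropic/claude-4.7-opus")
# Rule 3: unknown producer on a non-last turn falls back to the old strip-all behaviour.
assert _should_strip_thinking_block(opus, is_last_turn=False)
```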
def _flatten_assistant_content(blocks: list) -> str:
"""Flatten assistant content blocks into a single plain-text string.

View File

@@ -591,7 +591,16 @@ class TestStripForUpload:
"role": "assistant",
"id": "msg_new",
"content": [
{"type": "thinking", "thinking": "fresh thinking"},
# Anthropic-style thinking block — has a signature so
# ``_should_strip_thinking_block`` preserves it on the
# last turn. Without the signature (e.g. emitted by
# Kimi K2.6 via OpenRouter) it would be stripped — see
# ``test_strips_signatureless_thinking_from_last_turn``.
{
"type": "thinking",
"thinking": "fresh thinking",
"signature": "anthropic-signed-blob",
},
{"type": "text", "text": "new answer"},
],
},
@@ -624,6 +633,224 @@ class TestStripForUpload:
new_types = [b["type"] for b in new_content if isinstance(b, dict)]
assert "thinking" in new_types # last assistant preserved
def test_strips_signatureless_thinking_from_last_turn(self):
"""Kimi K2.6 (and other non-Anthropic OpenRouter providers) emit
thinking blocks without the Anthropic ``signature`` field. When
a subsequent advanced-tier toggle replays the transcript to Opus,
Anthropic's API rejects the signature-less block with ``Invalid
`signature` in `thinking` block`` — so strip_for_upload must drop
them from the LAST assistant entry too, not just stale ones."""
user = {
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "hi"},
}
# Last (and only) assistant entry with a Kimi-shape thinking block
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_kimi",
"content": [
# No ``signature`` field → non-Anthropic provider
{"type": "thinking", "thinking": "kimi reasoning"},
{"type": "text", "text": "answer"},
],
},
}
content = _make_jsonl(user, asst)
result = strip_for_upload(content)
entries = [json.loads(line) for line in result.strip().split("\n")]
asst_entry = next(
e for e in entries if e.get("message", {}).get("id") == "msg_kimi"
)
types = [
b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict)
]
assert "thinking" not in types, (
"Signature-less thinking block on last turn must be stripped "
"to prevent Anthropic API rejection on model-switch replay"
)
assert "text" in types, "Text content must survive stripping"
def test_strips_non_anthropic_thinking_with_placeholder_signature(self):
"""OpenRouter's Anthropic-compat shim can emit thinking blocks
from non-Anthropic producers (Kimi K2.6, DeepSeek) with a
PLACEHOLDER signature string that passes the "non-empty string"
check but fails Anthropic's cryptographic validation on replay.
Observed in session 864a55ba after model-toggle from standard
(Kimi) to advanced (Opus): the CLI session upload included a
thinking block with ``signature="ANTHROPIC_SHIM_PLACEHOLDER"``
(or similar), Opus 4.7 rejected with 400 ``Invalid `signature`
in `thinking` block``. Fix: strip thinking blocks from the
LAST assistant turn whenever the producing ``model`` isn't an
``anthropic/*`` slug, regardless of signature presence."""
user = {
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "hi"},
}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_kimi_shim",
"model": "moonshotai/kimi-k2.6-20260420",
"content": [
# Placeholder signature — non-empty but cryptographically
# invalid for Anthropic. Legacy strip (signature-only)
# would KEEP this block.
{
"type": "thinking",
"thinking": "shimmed reasoning",
"signature": "PLACEHOLDER_SHIM_SIG_abc123",
},
{"type": "text", "text": "answer"},
],
},
}
content = _make_jsonl(user, asst)
result = strip_for_upload(content)
entries = [json.loads(line) for line in result.strip().split("\n")]
asst_entry = next(
e for e in entries if e.get("message", {}).get("id") == "msg_kimi_shim"
)
types = [
b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict)
]
assert "thinking" not in types, (
"Non-Anthropic thinking block must be stripped even when it "
"carries a placeholder signature — replay-to-Opus otherwise "
"400s with Invalid signature"
)
assert "text" in types
def test_preserves_anthropic_thinking_on_non_last_turn(self):
"""Anthropic ``thinking`` blocks on NON-last turns carry real
reasoning state that helps context continuity on ``--resume``.
Keep them when the producing model is known-Anthropic with a
valid signature; strip only when we can't validate safely
(legacy callers with no model info — falls through to the
old stale-strip rule).
"""
user = {
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "first"},
}
asst1 = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_opus_prev",
"model": "anthropic/claude-4.7-opus-20260416",
"content": [
{
"type": "thinking",
"thinking": "first-turn reasoning",
"signature": "ANTHROPIC_SIG_1",
},
{"type": "text", "text": "first answer"},
],
},
}
user2 = {
"type": "user",
"uuid": "u2",
"parentUuid": "a1",
"message": {"role": "user", "content": "second"},
}
asst2 = {
"type": "assistant",
"uuid": "a2",
"parentUuid": "u2",
"message": {
"role": "assistant",
"id": "msg_opus_last",
"model": "anthropic/claude-4.7-opus-20260416",
"content": [
{
"type": "thinking",
"thinking": "last-turn reasoning",
"signature": "ANTHROPIC_SIG_2",
},
{"type": "text", "text": "last answer"},
],
},
}
content = _make_jsonl(user, asst1, user2, asst2)
result = strip_for_upload(content)
entries = [json.loads(line) for line in result.strip().split("\n")]
# Prior Opus turn's thinking must survive — valid Anthropic
# block with signature.
prev = next(
e for e in entries if e.get("message", {}).get("id") == "msg_opus_prev"
)
prev_types = [b["type"] for b in prev["message"]["content"]]
assert "thinking" in prev_types, (
"Anthropic thinking block on a non-last turn must be "
"preserved — it carries real reasoning state"
)
# Last turn's thinking also preserved.
last = next(
e for e in entries if e.get("message", {}).get("id") == "msg_opus_last"
)
last_types = [b["type"] for b in last["message"]["content"]]
assert "thinking" in last_types
def test_preserves_anthropic_thinking_with_valid_signature(self):
"""Sanity: an Anthropic-issued thinking block with a real
signature on the last turn must NOT be stripped — Anthropic
requires value-identity on replay."""
user = {
"type": "user",
"uuid": "u1",
"parentUuid": "",
"message": {"role": "user", "content": "hi"},
}
asst = {
"type": "assistant",
"uuid": "a1",
"parentUuid": "u1",
"message": {
"role": "assistant",
"id": "msg_opus",
"model": "anthropic/claude-4.7-opus-20260416",
"content": [
{
"type": "thinking",
"thinking": "reasoning",
"signature": "REAL_ANTHROPIC_SIG_blob",
},
{"type": "text", "text": "answer"},
],
},
}
content = _make_jsonl(user, asst)
result = strip_for_upload(content)
entries = [json.loads(line) for line in result.strip().split("\n")]
asst_entry = next(
e for e in entries if e.get("message", {}).get("id") == "msg_opus"
)
types = [
b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict)
]
assert (
"thinking" in types
), "Anthropic-signed thinking on last turn must survive strip"
assert "text" in types
def test_empty_content(self):
result = strip_for_upload("")
# Empty string produces a single empty line after split, resulting in "\n"

View File

@@ -14,6 +14,21 @@ HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)
# Default socket timeouts so a wedged Redis endpoint can't hang callers
# indefinitely — long-running code paths (cluster_lock refresh in particular)
# rely on these to fail-fast instead of blocking on no-response TCP. Override
# via env if a specific deployment needs a different budget.
#
# 30s matches the convention in ``backend.data.rabbitmq`` and leaves ~6x
# headroom over the largest ``xread(block=5000)`` wait in stream_registry.
# The connect timeout is shorter (5s) because initial connects should be
# fast; a slow connect usually means the endpoint is genuinely unreachable.
SOCKET_TIMEOUT = float(os.getenv("REDIS_SOCKET_TIMEOUT", "30"))
SOCKET_CONNECT_TIMEOUT = float(os.getenv("REDIS_SOCKET_CONNECT_TIMEOUT", "5"))
# How often redis-py sends a PING on idle connections to detect half-open
# sockets; cheap and avoids waiting for the OS TCP keepalive (~2h default).
HEALTH_CHECK_INTERVAL = int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30"))
logger = logging.getLogger(__name__)
@@ -24,6 +39,10 @@ def connect() -> Redis:
port=PORT,
password=PASSWORD,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
socket_keepalive=True,
health_check_interval=HEALTH_CHECK_INTERVAL,
)
c.ping()
return c
@@ -46,6 +65,10 @@ async def connect_async() -> AsyncRedis:
port=PORT,
password=PASSWORD,
decode_responses=True,
socket_timeout=SOCKET_TIMEOUT,
socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
socket_keepalive=True,
health_check_interval=HEALTH_CHECK_INTERVAL,
)
await c.ping()
return c
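
With `socket_timeout` in place, a command sent to a wedged endpoint now fails after roughly `SOCKET_TIMEOUT` seconds with `redis.exceptions.TimeoutError` instead of blocking indefinitely. A minimal sketch of the fail-fast behaviour callers can rely on — the stream name and the lock-refresh framing are illustrative, not lifted from the codebase:

```python
import os

from redis import Redis
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import TimeoutError as RedisTimeoutError

# Same knobs as backend.data.redis; override via env per deployment.
client = Redis(
    host=os.getenv("REDIS_HOST", "localhost"),
    port=int(os.getenv("REDIS_PORT", "6379")),
    socket_timeout=float(os.getenv("REDIS_SOCKET_TIMEOUT", "30")),
    socket_connect_timeout=float(os.getenv("REDIS_SOCKET_CONNECT_TIMEOUT", "5")),
    socket_keepalive=True,
    health_check_interval=int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30")),
    decode_responses=True,
)

try:
    # A blocking read (think: a lock-refresh loop) now waits at most ~30s for a
    # response from an unresponsive server before redis-py raises.
    client.xread({"example-stream": "$"}, block=5000)
except (RedisTimeoutError, RedisConnectionError):
    # Fail fast: the caller can retry, drop the lock, or surface the error,
    # rather than hanging on a half-open socket.
    pass
```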

View File

@@ -366,7 +366,7 @@ async def execute_node(
try:
if execution_context.dry_run and _dry_run_input is None:
block_iter = simulate_block(node_block, input_data)
block_iter = simulate_block(node_block, input_data, user_id=user_id)
else:
block_iter = node_block.execute(input_data, **extra_exec_kwargs)

View File

@@ -31,21 +31,31 @@ Inspired by https://github.com/Significant-Gravitas/agent-simulator
import inspect
import json
import logging
import math
from collections.abc import AsyncGenerator
from typing import Any
from openai.types import CompletionUsage
from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.copilot.token_tracking import persist_and_record_usage
from backend.util.clients import get_openai_client
logger = logging.getLogger(__name__)
# Default simulator model — Gemini 2.5 Flash via OpenRouter (fast, cheap, good at
# JSON generation). Configurable via ChatConfig.simulation_model
# (CHAT_SIMULATION_MODEL env var).
_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash"
# Default simulator model — Gemini 2.5 Flash-Lite via OpenRouter. Same provider
# as Flash ($0.10 / $0.40 per MTok vs $0.30 / $1.20 — ~3× cheaper) with JSON-mode
# reliability that's more than enough for dry-run shape-matching. Configurable
# via ChatConfig.simulation_model (CHAT_SIMULATION_MODEL env var).
_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash-lite"
# OpenRouter-specific extra_body flag that embeds the real generation cost on
# the response usage object. Same shape used by the baseline copilot service
# and web_search tool — keep the three aligned.
_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}}
def _simulator_model() -> str:
@@ -105,10 +115,15 @@ async def _call_llm_for_simulation(
user_prompt: str,
*,
label: str = "simulate",
user_id: str | None = None,
) -> dict[str, Any]:
"""Send a simulation prompt to the LLM and return the parsed JSON dict.
Handles client acquisition, retries on invalid JSON, and logging.
Handles client acquisition, retries on invalid JSON, logging, and platform
cost tracking. The dry-run simulator calls OpenRouter on the platform's
key rather than a user's own API credentials, so every successful call is
recorded against the triggering ``user_id``'s rate-limit counter via
``persist_and_record_usage`` (same rails as every copilot turn).
Raises:
RuntimeError: If no LLM client is available.
@@ -133,6 +148,7 @@ async def _call_llm_for_simulation(
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
extra_body=_OPENROUTER_INCLUDE_USAGE_COST,
)
if not response.choices:
raise ValueError("LLM returned empty choices array")
@@ -141,13 +157,21 @@ async def _call_llm_for_simulation(
if not isinstance(parsed, dict):
raise ValueError(f"LLM returned non-object JSON: {raw[:200]}")
logger.debug(
"simulate(%s): attempt=%d tokens=%s/%s",
label,
attempt + 1,
getattr(getattr(response, "usage", None), "prompt_tokens", "?"),
getattr(getattr(response, "usage", None), "completion_tokens", "?"),
)
usage = response.usage
if usage is not None:
logger.debug(
"simulate(%s): attempt=%d tokens=%d/%d",
label,
attempt + 1,
usage.prompt_tokens,
usage.completion_tokens,
)
else:
logger.debug(
"simulate(%s): attempt=%d usage unavailable", label, attempt + 1
)
await _track_simulator_cost(usage=usage, user_id=user_id, model=model)
return parsed
except (json.JSONDecodeError, ValueError) as e:
@@ -174,6 +198,69 @@ async def _call_llm_for_simulation(
raise ValueError(msg)
def _extract_cost_usd(usage: CompletionUsage | None) -> float | None:
"""Return the provider-reported USD cost on the response usage object.
OpenRouter attaches a ``cost`` field to the OpenAI-compatible usage object
when the request body includes ``usage: {"include": True}``. The typed
``CompletionUsage`` does not declare it, so we read it off ``model_extra``
(pydantic v2's container for extras) to keep access fully typed — no
``getattr``. Mirrors ``backend.copilot.tools.web_search._extract_cost_usd``
and ``backend.copilot.baseline.service._extract_usage_cost``; keep the
three in sync.
"""
if usage is None:
return None
extras = usage.model_extra or {}
if "cost" not in extras:
return None
raw = extras["cost"]
if raw is None:
logger.error("[simulator] usage.cost is present but null")
return None
try:
val = float(raw)
except (TypeError, ValueError):
logger.error("[simulator] usage.cost is not numeric: %r", raw)
return None
if not math.isfinite(val) or val < 0:
logger.error("[simulator] usage.cost is non-finite or negative: %r", val)
return None
return val
async def _track_simulator_cost(
*,
usage: CompletionUsage | None,
user_id: str | None,
model: str,
) -> None:
"""Record platform cost for a single simulator LLM call.
The simulator runs outside a copilot ``ChatSession`` — pass ``session=None``
so ``persist_and_record_usage`` skips the session append but still charges
the user's rate-limit counter and writes a ``PlatformCostLog`` entry. No
user_id means no tracking (e.g. in-process tests that don't plumb one
through); rate-limit accounting silently no-ops in that case.
"""
if usage is None:
return
cost_usd = _extract_cost_usd(usage)
try:
await persist_and_record_usage(
session=None,
user_id=user_id,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
log_prefix="[simulator]",
cost_usd=cost_usd,
model=model,
provider="open_router",
)
except Exception as exc:
logger.warning("[simulator] usage tracking failed: %s", exc)
# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------
@@ -393,12 +480,18 @@ def _default_for_input_result(result_schema: dict[str, Any], name: str | None) -
async def simulate_block(
block: Any,
input_data: dict[str, Any],
*,
user_id: str | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
"""Simulate block execution using an LLM.
All block types (including MCPToolBlock) use the same generic LLM prompt
which includes the block's run() source code for accurate simulation.
``user_id`` is threaded through to platform cost tracking — every dry-run
LLM call hits the platform's OpenRouter key and is charged against the
triggering user's rate-limit counter, same rails as copilot turns.
Note: callers should check ``prepare_dry_run(block, input_data)`` first.
OrchestratorBlock and AgentExecutorBlock execute for real in dry-run mode
(see manager.py).
@@ -462,7 +555,9 @@ async def simulate_block(
label = getattr(block, "name", "?")
try:
parsed = await _call_llm_for_simulation(system_prompt, user_prompt, label=label)
parsed = await _call_llm_for_simulation(
system_prompt, user_prompt, label=label, user_id=user_id
)
# Track which pins were yielded so we can fill in missing required
# ones afterwards — downstream nodes connected to unyielded pins

View File

@@ -5,6 +5,7 @@ Covers:
- Input/output block passthrough
- prepare_dry_run routing
- simulate_block output-pin filling
- Default simulator model + OpenRouter cost tracking
"""
from __future__ import annotations
@@ -13,8 +14,14 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from backend.executor.simulator import (
_DEFAULT_SIMULATOR_MODEL,
_extract_cost_usd,
_truncate_input_values,
_truncate_value,
build_simulation_prompt,
@@ -511,3 +518,217 @@ class TestSimulateBlockPassthrough:
assert len(outputs) == 1
assert outputs[0][0] == "error"
assert "No client" in outputs[0][1]
# ---------------------------------------------------------------------------
# Default model + OpenRouter cost tracking
# ---------------------------------------------------------------------------
def _sim_usage(
*,
prompt_tokens: int = 1200,
completion_tokens: int = 300,
cost: object = 0.000157,
) -> CompletionUsage:
"""Typed ``CompletionUsage`` carrying OpenRouter's ``cost`` extension
via ``model_extra`` — same pattern as
``copilot/tools/web_search_test.py::_usage``. ``model_construct``
preserves unknown fields; ``model_validate`` would drop them."""
payload: dict[str, Any] = {
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": prompt_tokens + completion_tokens,
}
if cost is not None:
payload["cost"] = cost
return CompletionUsage.model_construct(None, **payload)
def _sim_completion(*, content: str, usage: CompletionUsage) -> ChatCompletion:
"""Typed ``ChatCompletion`` shaped like an OpenRouter simulator
response so the production code runs under real SDK types."""
message = ChatCompletionMessage.model_construct(
None, role="assistant", content=content
)
choice = Choice.model_construct(
None, index=0, finish_reason="stop", message=message
)
return ChatCompletion.model_construct(
None,
id="cmpl-sim",
object="chat.completion",
created=0,
model=_DEFAULT_SIMULATOR_MODEL,
choices=[choice],
usage=usage,
)
class TestDefaultSimulatorModel:
"""Pin the default model — anyone flipping this without a cost review
trips the test."""
def test_default_is_flash_lite(self) -> None:
assert _DEFAULT_SIMULATOR_MODEL == "google/gemini-2.5-flash-lite"
class TestExtractCostUsd:
"""Provider-reported USD cost via typed ``model_extra`` — mirrors
``copilot.tools.web_search._extract_cost_usd`` and
``copilot.baseline.service._extract_usage_cost``."""
def test_returns_cost_value(self) -> None:
assert _extract_cost_usd(_sim_usage(cost=0.000157)) == pytest.approx(0.000157)
def test_returns_none_when_usage_missing(self) -> None:
assert _extract_cost_usd(None) is None
def test_returns_none_when_cost_field_missing(self) -> None:
assert _extract_cost_usd(_sim_usage(cost=None)) is None
def test_returns_none_when_cost_is_explicit_null(self) -> None:
usage = CompletionUsage.model_construct(
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None
)
assert _extract_cost_usd(usage) is None
def test_returns_none_when_cost_is_negative(self) -> None:
usage = CompletionUsage.model_construct(
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-0.5
)
assert _extract_cost_usd(usage) is None
def test_accepts_numeric_string(self) -> None:
usage = CompletionUsage.model_construct(
None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017"
)
assert _extract_cost_usd(usage) == pytest.approx(0.017)
class TestSimulatorCostTracking:
"""Integration: mock the OpenAI client, confirm the simulator sends
the flash-lite default + extra_body, then plumbs through to
``persist_and_record_usage`` with ``provider='open_router'`` and the
real ``usage.cost`` pulled off ``model_extra``."""
def _mock_client(self, fake_resp: ChatCompletion) -> tuple[Any, AsyncMock]:
"""Build a fake ``AsyncOpenAI`` client. Same nested-type pattern as
``copilot/tools/web_search_test.py::_mock_client`` — avoids
MagicMock's auto-child-attr behaviour so the exact ``create`` call
surface is what gets invoked."""
create_mock = AsyncMock(return_value=fake_resp)
client = type(
"MC",
(),
{
"chat": type(
"C",
(),
{"completions": type("CC", (), {"create": create_mock})()},
)()
},
)()
return client, create_mock
@pytest.mark.asyncio
async def test_passes_default_model_and_tracks_cost(self) -> None:
block = _make_block()
fake_resp = _sim_completion(
content='{"result": "simulated"}',
usage=_sim_usage(prompt_tokens=1100, completion_tokens=220, cost=0.000189),
)
client, create_mock = self._mock_client(fake_resp)
with (
patch(
"backend.executor.simulator.get_openai_client",
return_value=client,
),
patch(
"backend.executor.simulator.persist_and_record_usage",
new=AsyncMock(return_value=1320),
) as mock_track,
):
outputs = []
async for name, data in simulate_block(
block, {"query": "hello"}, user_id="user-42"
):
outputs.append((name, data))
assert ("result", "simulated") in outputs
create_kwargs = create_mock.await_args.kwargs
assert create_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL
assert create_kwargs["extra_body"] == {"usage": {"include": True}}
track_kwargs = mock_track.await_args.kwargs
assert track_kwargs["provider"] == "open_router"
assert track_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL
assert track_kwargs["user_id"] == "user-42"
assert track_kwargs["prompt_tokens"] == 1100
assert track_kwargs["completion_tokens"] == 220
assert track_kwargs["cost_usd"] == pytest.approx(0.000189)
assert track_kwargs["session"] is None
assert track_kwargs["log_prefix"] == "[simulator]"
@pytest.mark.asyncio
async def test_tracks_even_when_cost_absent(self) -> None:
"""Provider may omit ``cost`` (e.g. non-OpenRouter proxies). We
still record token counts — ``persist_and_record_usage`` logs the
turn and skips the rate-limit ledger when cost is ``None``."""
block = _make_block()
fake_resp = _sim_completion(
content='{"result": "ok"}',
usage=_sim_usage(prompt_tokens=100, completion_tokens=20, cost=None),
)
client, _ = self._mock_client(fake_resp)
with (
patch(
"backend.executor.simulator.get_openai_client",
return_value=client,
),
patch(
"backend.executor.simulator.persist_and_record_usage",
new=AsyncMock(return_value=120),
) as mock_track,
):
async for _name, _data in simulate_block(
block, {"query": "x"}, user_id="user-7"
):
pass
track_kwargs = mock_track.await_args.kwargs
assert track_kwargs["cost_usd"] is None
assert track_kwargs["user_id"] == "user-7"
assert track_kwargs["provider"] == "open_router"
@pytest.mark.asyncio
async def test_tracking_failure_does_not_break_simulation(self) -> None:
"""Cost-tracking failures are warnings, not simulation failures —
the block output must still flow to the caller."""
block = _make_block()
fake_resp = _sim_completion(
content='{"result": "simulated"}',
usage=_sim_usage(),
)
client, _ = self._mock_client(fake_resp)
with (
patch(
"backend.executor.simulator.get_openai_client",
return_value=client,
),
patch(
"backend.executor.simulator.persist_and_record_usage",
new=AsyncMock(side_effect=RuntimeError("redis down")),
),
):
outputs = []
async for name, data in simulate_block(
block, {"query": "hello"}, user_id="user-42"
):
outputs.append((name, data))
assert ("result", "simulated") in outputs

View File

@@ -48,6 +48,16 @@ class Flag(str, Enum):
STRIPE_PRICE_BUSINESS = "stripe-price-id-business"
GRAPHITI_MEMORY = "graphiti-memory"
# Copilot model routing — string-valued, returns the model identifier
# (e.g. ``"anthropic/claude-sonnet-4-6"`` or ``"moonshotai/kimi-k2.6"``)
# to use for each cell of the (mode, tier) matrix. Falls back to the
# ``CHAT_*_MODEL`` env/config default when the flag is unset or LD is
# unavailable. Evaluated per user_id so cohorts can be targeted.
COPILOT_FAST_STANDARD_MODEL = "copilot-fast-standard-model"
COPILOT_FAST_ADVANCED_MODEL = "copilot-fast-advanced-model"
COPILOT_THINKING_STANDARD_MODEL = "copilot-thinking-standard-model"
COPILOT_THINKING_ADVANCED_MODEL = "copilot-thinking-advanced-model"
def is_configured() -> bool:
"""Check if LaunchDarkly is configured with an SDK key."""

View File

@@ -0,0 +1,177 @@
import { act, renderHook } from "@testing-library/react";
import type { UIMessage } from "ai";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";
import { useCopilotStream } from "../useCopilotStream";
// Capture the args passed to ``useChat`` so tests can invoke onFinish/onError
// directly — that's the only way to drive handleReconnect without a real SSE.
let lastUseChatArgs: {
onFinish?: (args: { isDisconnect?: boolean; isAbort?: boolean }) => void;
onError?: (err: Error) => void;
} | null = null;
const resumeStreamMock = vi.fn();
const sdkStopMock = vi.fn();
const sdkSendMessageMock = vi.fn();
const setMessagesMock = vi.fn();
function resetSdkMocks() {
lastUseChatArgs = null;
resumeStreamMock.mockReset();
sdkStopMock.mockReset();
sdkSendMessageMock.mockReset();
setMessagesMock.mockReset();
}
vi.mock("@ai-sdk/react", () => ({
useChat: (args: unknown) => {
lastUseChatArgs = args as typeof lastUseChatArgs;
return {
messages: [] as UIMessage[],
sendMessage: sdkSendMessageMock,
stop: sdkStopMock,
status: "ready" as const,
error: undefined,
setMessages: setMessagesMock,
resumeStream: resumeStreamMock,
};
},
}));
vi.mock("ai", async () => {
const actual = await vi.importActual<typeof import("ai")>("ai");
return {
...actual,
DefaultChatTransport: class {
constructor(public opts: unknown) {}
},
};
});
vi.mock("@tanstack/react-query", () => ({
useQueryClient: () => ({ invalidateQueries: vi.fn() }),
}));
vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
getGetV2GetCopilotUsageQueryKey: () => ["copilot-usage"],
getGetV2GetSessionQueryKey: (id: string) => ["session", id],
postV2CancelSessionTask: vi.fn(),
deleteV2DisconnectSessionStream: vi.fn().mockResolvedValue(undefined),
}));
vi.mock("@/components/molecules/Toast/use-toast", () => ({
toast: vi.fn(),
}));
vi.mock("@/services/environment", () => ({
environment: {
getAGPTServerBaseUrl: () => "http://localhost",
},
}));
vi.mock("../helpers", async () => {
const actual =
await vi.importActual<typeof import("../helpers")>("../helpers");
return {
...actual,
getCopilotAuthHeaders: vi.fn().mockResolvedValue({}),
disconnectSessionStream: vi.fn(),
};
});
vi.mock("../useHydrateOnStreamEnd", () => ({
useHydrateOnStreamEnd: () => undefined,
}));
function renderStream() {
return renderHook(() =>
useCopilotStream({
sessionId: "sess-1",
hydratedMessages: [],
hasActiveStream: false,
refetchSession: vi.fn().mockResolvedValue({ data: undefined }),
copilotMode: undefined,
copilotModel: undefined,
}),
);
}
describe("useCopilotStream — reconnect debounce", () => {
beforeEach(() => {
resetSdkMocks();
vi.useFakeTimers();
// Pin Date.now so sinceLastResume math is deterministic. The hook reads
// Date.now() both when stashing lastReconnectResumeAtRef and when
// deciding whether to debounce.
vi.setSystemTime(new Date(2025, 0, 1, 12, 0, 0));
});
afterEach(() => {
vi.useRealTimers();
});
it("coalesces a burst of disconnect events into one resumeStream call", async () => {
renderStream();
// First disconnect — schedules a reconnect at the exponential backoff
// delay (1000ms for attempt #1).
await act(async () => {
await lastUseChatArgs!.onFinish!({ isDisconnect: true });
});
// Fire the scheduled timer → resumeStream runs once and stamps
// lastReconnectResumeAtRef.current = Date.now().
await act(async () => {
await vi.advanceTimersByTimeAsync(1_000);
});
expect(resumeStreamMock).toHaveBeenCalledTimes(1);
// A second disconnect arrives immediately after (still inside the
// 1500ms debounce window) — the debounce path must fire and queue a
// coalesced timer, NOT a fresh resume.
await act(async () => {
await lastUseChatArgs!.onFinish!({ isDisconnect: true });
});
expect(resumeStreamMock).toHaveBeenCalledTimes(1);
// The coalesced timer fires at the window boundary and reschedules a
// real reconnect. Advance past the window AND past the second
// reconnect's backoff (attempt #2 = 2000ms) so resume runs.
await act(async () => {
await vi.advanceTimersByTimeAsync(1_500);
});
await act(async () => {
await vi.advanceTimersByTimeAsync(2_000);
});
expect(resumeStreamMock).toHaveBeenCalledTimes(2);
});
it("does not debounce a reconnect that arrives after the window closes", async () => {
renderStream();
// First reconnect cycle.
await act(async () => {
await lastUseChatArgs!.onFinish!({ isDisconnect: true });
});
await act(async () => {
await vi.advanceTimersByTimeAsync(1_000);
});
expect(resumeStreamMock).toHaveBeenCalledTimes(1);
// Wait past the debounce window before the next disconnect.
await act(async () => {
await vi.advanceTimersByTimeAsync(2_000);
});
// Now a fresh disconnect should go through the normal path (NOT the
// debounce branch) and schedule a backoff of 2000ms (attempt #2).
await act(async () => {
await lastUseChatArgs!.onFinish!({ isDisconnect: true });
});
await act(async () => {
await vi.advanceTimersByTimeAsync(2_000);
});
expect(resumeStreamMock).toHaveBeenCalledTimes(2);
});
});

View File

@@ -7,6 +7,7 @@ import {
formatNotificationTitle,
getSendSuppressionReason,
parseSessionIDs,
shouldDebounceReconnect,
shouldSuppressDuplicateSend,
} from "./helpers";
@@ -466,3 +467,88 @@ describe("deduplicateMessages", () => {
expect(result).toHaveLength(2); // duplicate step-start messages are deduped
});
});
describe("shouldDebounceReconnect", () => {
const WINDOW_MS = 1_500;
it("returns null for the first reconnect (lastResumeAt === 0)", () => {
expect(shouldDebounceReconnect(0, 10_000, WINDOW_MS)).toBeNull();
});
it("returns null for a negative lastResumeAt sentinel", () => {
// Defensive: a negative value is still treated as "no reconnect yet".
expect(shouldDebounceReconnect(-1, 10_000, WINDOW_MS)).toBeNull();
});
it("returns the remaining delay when now is inside the window", () => {
// 500ms since the last resume — the caller must wait another 1000ms
// before the storm cap reopens.
const remaining = shouldDebounceReconnect(1_000, 1_500, WINDOW_MS);
expect(remaining).toBe(1_000);
});
it("coalesces a reconnect that arrives immediately after the previous resume", () => {
// now === lastResumeAt → sinceLastResume === 0, so the full window remains.
const remaining = shouldDebounceReconnect(5_000, 5_000, WINDOW_MS);
expect(remaining).toBe(WINDOW_MS);
});
it("returns null when exactly on the window boundary", () => {
// sinceLastResume === windowMs is NOT inside the window — the next
// reconnect should fire immediately.
expect(shouldDebounceReconnect(1_000, 2_500, WINDOW_MS)).toBeNull();
});
it("returns null when the window has elapsed", () => {
expect(shouldDebounceReconnect(1_000, 5_000, WINDOW_MS)).toBeNull();
});
it("returns a small remaining delay at the far edge of the window", () => {
// 1ms before the window closes → 1ms left.
const remaining = shouldDebounceReconnect(1_000, 2_499, WINDOW_MS);
expect(remaining).toBe(1);
});
it("collapses a burst of reconnects into one debounced scheduling", () => {
// Simulates the browser tab-throttle storm: three reconnect calls fire
// within a single second after the last resume. Only the first slot
// would actually run; subsequent calls must always be coalesced.
const lastResumeAt = 10_000;
const firstCallRemaining = shouldDebounceReconnect(
lastResumeAt,
10_100,
WINDOW_MS,
);
const secondCallRemaining = shouldDebounceReconnect(
lastResumeAt,
10_200,
WINDOW_MS,
);
const thirdCallRemaining = shouldDebounceReconnect(
lastResumeAt,
10_300,
WINDOW_MS,
);
expect(firstCallRemaining).toBe(1_400);
expect(secondCallRemaining).toBe(1_300);
expect(thirdCallRemaining).toBe(1_200);
});
it("allows a reconnect to fire immediately once the window has passed", () => {
// After the window expires, a retry that came in earlier can now fire
// rather than stalling the loop. Guards against the regression that
// motivated the coalesce-instead-of-drop fix.
const lastResumeAt = 10_000;
expect(
shouldDebounceReconnect(lastResumeAt, 10_500, WINDOW_MS),
).not.toBeNull();
expect(shouldDebounceReconnect(lastResumeAt, 11_500, WINDOW_MS)).toBeNull();
});
it("honours a custom windowMs value", () => {
// Shouldn't hard-code 1500 anywhere: the helper is generic over the
// window.
expect(shouldDebounceReconnect(1_000, 1_500, 2_000)).toBe(1_500);
expect(shouldDebounceReconnect(1_000, 3_500, 2_000)).toBeNull();
});
});

View File

@@ -184,6 +184,28 @@ export function disconnectSessionStream(sessionId: string): void {
deleteV2DisconnectSessionStream(sessionId).catch(() => {});
}
/**
* Decide whether a reconnect request must be coalesced onto the debounce
* window boundary, rather than firing immediately.
*
* Returns the remaining milliseconds until the window closes (so the caller
* can schedule a `setTimeout` for that delay) when the previous resume
* happened inside the window, or `null` to let the reconnect proceed now.
*
* `lastResumeAt === 0` signals "no reconnect has fired yet in this session"
* — the first reconnect always passes through regardless of `now`.
*/
export function shouldDebounceReconnect(
lastResumeAt: number,
now: number,
windowMs: number,
): number | null {
if (lastResumeAt <= 0) return null;
const sinceLastResume = now - lastResumeAt;
if (sinceLastResume >= windowMs) return null;
return windowMs - sinceLastResume;
}
/**
* Deduplicate messages by ID and by consecutive content fingerprint.
*

View File

@@ -313,11 +313,19 @@ function getWebAccordionData(
: null;
if (results) {
const deep = inp.deep === true;
const noun = deep ? "research source" : "search result";
const answer = getStringField(output, "answer");
return {
title: `${results.length} search result${results.length === 1 ? "" : "s"}`,
title: `${results.length} ${noun}${results.length === 1 ? "" : "s"}`,
description: query ? truncate(query, 80) : undefined,
content: (
<div className="space-y-3">
{answer && (
<div className="whitespace-pre-wrap rounded-md bg-slate-50 p-3 text-sm text-slate-800">
{answer}
</div>
)}
{results.map((r, i) => {
const title = getStringField(r, "title") ?? "(untitled)";
const href = getStringField(r, "url") ?? "";

View File

@@ -141,6 +141,7 @@ describe("GenericTool", () => {
function makeWebSearchPart(
results: Array<Record<string, unknown>>,
query = "kimi k2.6",
answer = "",
): ToolUIPart {
return {
type: "tool-web_search",
@@ -149,6 +150,7 @@ describe("GenericTool", () => {
input: { query },
output: {
type: "web_search_response",
answer,
results,
query,
search_requests: 1,
@@ -254,6 +256,25 @@ describe("GenericTool", () => {
expect(normalized).toContain('Searched "kimi k2.6"');
});
it("renders the synthesised answer above the citations when present", () => {
render(
<GenericTool
part={makeWebSearchPart(
[
{ title: "Citation 1", url: "https://example.com/one" },
{ title: "Citation 2", url: "https://example.com/two" },
],
"kimi k2.6 launch",
"Kimi K2.6 launched on 2026-04-20 with SWE-Bench parity to Opus.",
)}
/>,
);
fireEvent.click(screen.getByRole("button", { expanded: false }));
expect(
screen.getByText(/Kimi K2\.6 launched on 2026-04-20/),
).not.toBeNull();
});
it("uses '(untitled)' when a search result has no title", () => {
render(
<GenericTool

View File

@@ -205,6 +205,14 @@ export function humanizeFileName(filePath: string): string {
/* Animation text */
/* ------------------------------------------------------------------ */
// web_search accepts a ``deep`` arg that dispatches to a multi-step
// research model; render a distinct verb ("Researching"/"Researched"/
// "Research failed") so users know the call takes longer.
function _isDeepWebSearch(part: ToolUIPart): boolean {
const input = part.input as Record<string, unknown> | undefined;
return input?.deep === true;
}
export function getAnimationText(
part: ToolUIPart,
category: ToolCategory,
@@ -223,9 +231,11 @@ export function getAnimationText(
: "Running command\u2026";
case "web":
if (toolName === "WebSearch" || toolName === "web_search") {
const deep = _isDeepWebSearch(part);
const verb = deep ? "Researching" : "Searching";
return shortSummary
? `Searching "${shortSummary}"`
: "Searching the web\u2026";
? `${verb} "${shortSummary}"`
: `${verb} the web\u2026`;
}
return shortSummary
? `Fetching ${shortSummary}`
@@ -285,9 +295,12 @@ export function getAnimationText(
return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
case "web":
if (toolName === "WebSearch" || toolName === "web_search") {
return shortSummary
? `Searched "${shortSummary}"`
const deep = _isDeepWebSearch(part);
const verb = deep ? "Researched" : "Searched";
const completed = deep
? "Web research completed"
: "Web search completed";
return shortSummary ? `${verb} "${shortSummary}"` : completed;
}
return shortSummary
? `Fetched ${shortSummary}`
@@ -354,9 +367,10 @@ export function getAnimationText(
case "bash":
return "Command failed";
case "web":
return toolName === "WebSearch" || toolName === "web_search"
? "Search failed"
: "Fetch failed";
if (toolName === "WebSearch" || toolName === "web_search") {
return _isDeepWebSearch(part) ? "Research failed" : "Search failed";
}
return "Fetch failed";
case "browser":
return "Browser action failed";
default:

View File

@@ -18,6 +18,7 @@ import {
resolveInProgressTools,
getSendSuppressionReason,
disconnectSessionStream,
shouldDebounceReconnect,
} from "./helpers";
import type { CopilotLlmModel, CopilotMode } from "./store";
import { useHydrateOnStreamEnd } from "./useHydrateOnStreamEnd";
@@ -25,6 +26,19 @@ import { useHydrateOnStreamEnd } from "./useHydrateOnStreamEnd";
const RECONNECT_BASE_DELAY_MS = 1_000;
const RECONNECT_MAX_ATTEMPTS = 3;
/**
* Minimum spacing between successive reconnect attempts.
* `isReconnectScheduledRef` already prevents OVERLAPPING reconnects, but
* tab-throttle / visibility wake bursts can fire `onFinish(isDisconnect)`
* several times inside a single second — each one would schedule a fresh
* reconnect the moment the previous timer cleared the ref. Requests that
* arrive inside this window since the last reconnect's resume are COALESCED:
* scheduled to run at the window boundary rather than dropped, so a
* fast-failing resume (e.g. a 502 on GET /stream that trips `onError` inside
* 500 ms) still retries instead of stalling the retry loop.
*/
const RECONNECT_DEBOUNCE_MS = 1_500;
/** Minimum time the page must have been hidden to trigger a wake re-sync. */
const WAKE_RESYNC_THRESHOLD_MS = 30_000;
@@ -110,6 +124,11 @@ export function useCopilotStream({
const isReconnectScheduledRef = useRef(false);
const [isReconnectScheduled, setIsReconnectScheduled] = useState(false);
const reconnectTimerRef = useRef<ReturnType<typeof setTimeout>>();
// Timestamp of the last reconnect's actual resume call — used together
// with RECONNECT_DEBOUNCE_MS to coalesce rapid duplicate reconnect requests
// onto the window boundary (e.g. visibility throttle firing
// onFinish(isDisconnect) several times in the same second). 0 = no
// reconnect has fired yet in this session.
const lastReconnectResumeAtRef = useRef(0);
const hasShownDisconnectToast = useRef(false);
// Set when the user explicitly clicks stop — prevents onError from
// triggering a reconnect cycle for the resulting AbortError.
@@ -127,6 +146,32 @@ export function useCopilotStream({
function handleReconnect(sid: string) {
if (isReconnectScheduledRef.current || !sid) return;
// Debounce: if the previous reconnect resumed within the last
// RECONNECT_DEBOUNCE_MS, COALESCE this request onto the window boundary
// rather than dropping it. Browser tab-throttle bursts can fire
// onFinish(isDisconnect) 2-3 times in a second; without the debounce,
// each fires its own GET /stream, each one replays the Redis stream,
// and the flicker storm is back. Dropping the request silently (the
// previous behaviour) stalled the retry loop when a resume failed
// quickly — e.g. a 502 on GET /stream that trips onError inside 500 ms
// while the 1500 ms window is still open. Scheduling the retry for
// the remaining window preserves both the storm cap and the retry.
const remainingDelay = shouldDebounceReconnect(
lastReconnectResumeAtRef.current,
Date.now(),
RECONNECT_DEBOUNCE_MS,
);
if (remainingDelay !== null) {
isReconnectScheduledRef.current = true;
setIsReconnectScheduled(true);
reconnectTimerRef.current = setTimeout(() => {
isReconnectScheduledRef.current = false;
setIsReconnectScheduled(false);
handleReconnect(sid);
}, remainingDelay);
return;
}
const nextAttempt = reconnectAttemptsRef.current + 1;
if (nextAttempt > RECONNECT_MAX_ATTEMPTS) {
setReconnectExhausted(true);
@@ -163,6 +208,7 @@ export function useCopilotStream({
}
return prev;
});
lastReconnectResumeAtRef.current = Date.now();
resumeStreamRef.current();
}, delay);
}
@@ -469,6 +515,7 @@ export function useCopilotStream({
setRateLimitMessage(null);
hasShownDisconnectToast.current = false;
lastSubmittedMsgRef.current = null;
lastReconnectResumeAtRef.current = 0;
setReconnectExhausted(false);
setIsSyncing(false);
hasResumedRef.current.clear();

View File

@@ -2119,7 +2119,8 @@
},
{
"$ref": "#/components/schemas/MemoryForgetConfirmResponse"
}
},
{ "$ref": "#/components/schemas/TodoWriteResponse" }
],
"title": "Response Getv2[Dummy] Tool Response Type Export For Codegen"
}
@@ -14648,7 +14649,8 @@
"memory_store",
"memory_search",
"memory_forget_candidates",
"memory_forget_confirm"
"memory_forget_confirm",
"todo_write"
],
"title": "ResponseType",
"description": "Types of tool responses."
@@ -16810,6 +16812,52 @@
"required": ["timezone"],
"title": "TimezoneResponse"
},
"TodoItem": {
"properties": {
"content": {
"type": "string",
"title": "Content",
"description": "Imperative description of the task."
},
"activeForm": {
"type": "string",
"title": "Activeform",
"description": "Present-continuous form shown while the task is running."
},
"status": {
"type": "string",
"enum": ["pending", "in_progress", "completed"],
"title": "Status",
"default": "pending"
}
},
"type": "object",
"required": ["content", "activeForm"],
"title": "TodoItem",
"description": "One entry in a ``TodoWrite`` checklist.\n\nMirrors the schema used by Claude Code's built-in ``TodoWrite`` tool so\nthe frontend's ``GenericTool`` accordion renders baseline-emitted todos\nidentically to SDK-emitted ones."
},
"TodoWriteResponse": {
"properties": {
"type": {
"$ref": "#/components/schemas/ResponseType",
"default": "todo_write"
},
"message": { "type": "string", "title": "Message" },
"session_id": {
"anyOf": [{ "type": "string" }, { "type": "null" }],
"title": "Session Id"
},
"todos": {
"items": { "$ref": "#/components/schemas/TodoItem" },
"type": "array",
"title": "Todos"
}
},
"type": "object",
"required": ["message"],
"title": "TodoWriteResponse",
"description": "Ack returned by ``TodoWrite``.\n\nThe tool is effectively stateless — the authoritative task list lives in\nthe assistant's latest tool-call arguments, which are replayed from the\ntranscript on each turn. The tool output only needs to confirm that the\nupdate was accepted so the model can proceed."
},
"TokenIntrospectionResult": {
"properties": {
"active": { "type": "boolean", "title": "Active" },