Mirror of https://github.com/Significant-Gravitas/AutoGPT.git, synced 2026-04-30 03:00:41 -04:00

Merge remote-tracking branch 'origin/dev' into feat/task-decomposition-copilot

# Conflicts:
#   autogpt_platform/backend/backend/api/features/chat/routes.py
#   autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md
#   autogpt_platform/backend/backend/copilot/tools/tool_schema_test.py

245 .claude/skills/pr-polish/SKILL.md (new file)
@@ -0,0 +1,245 @@
---
name: pr-polish
description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds.
user-invocable: true
argument-hint: "[PR number or URL] — if omitted, finds PR for current branch."
metadata:
  author: autogpt-team
  version: "1.0.0"
---

# PR Polish

**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold:

1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body).
2. Every inline review thread reachable via GraphQL reports `isResolved: true`.
3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned.
4. Every non-bot, non-author issue comment has been acknowledged (replied-to).
5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending.
6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient.
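
Conditions 1–6 collapse into a single stop predicate. A minimal sketch, assuming hypothetical names (`fetch_unresolved_threads`, `snapshot_state`, `diff_state`, `fetch_check_runs`) for the queries defined in later sections of this skill:

```python
# Hedged sketch of the combined stop condition; the helpers are
# placeholder names for the gh / GraphQL queries shown under
# "Snapshotting state", "Diffing after a review", and "Polish polling".
CLEAN_CONCLUSIONS = {"success", "skipped", "neutral"}


def is_merge_ready(pr: int, baseline: dict, clean_polls: int) -> bool:
    if fetch_unresolved_threads(pr):                 # condition 2
        return False
    findings = diff_state(snapshot_state(pr), baseline)
    if any(findings.values()):                       # conditions 1, 3, 4
        return False
    conclusions = [run["conclusion"] for run in fetch_check_runs(pr)]
    if any(c not in CLEAN_CONCLUSIONS for c in conclusions):
        return False                                 # condition 5 (None = still pending)
    return clean_polls >= 2                          # condition 6
```

The two-poll requirement is what keeps the predicate stable against late bot posts.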

**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 in practice.

## TodoWrite

Before starting, write two todos so the user can see the loop progression:

- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration.
- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round.

Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved).

## Find the PR

```bash
ARG_PR="${ARG:-}"
# Normalize URL → numeric ID if the skill arg is a pull-request URL.
if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then
  ARG_PR="${BASH_REMATCH[1]}"
fi
PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}"
if [ -z "$PR" ] || [ "$PR" = "null" ]; then
  echo "No PR found for current branch. Provide a PR number or URL as the skill arg."
  exit 1
fi
echo "Polishing PR #$PR"
```

## The outer loop

```text
round = 0
while round < _MAX_ROUNDS:
    round += 1
    baseline = snapshot_state(PR)       # see "Snapshotting state" below
    invoke_skill("pr-review", PR)       # posts findings as inline comments / top-level review
    findings = diff_state(PR, baseline)
    if findings.total == 0:
        break                           # no new findings → go to polish polling
    invoke_skill("pr-address", PR)      # resolves every unresolved thread + CI failure

# Post-loop: polish polling (see below).
polish_polling(PR)
```

### Snapshotting state

Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads):

```bash
# Inline threads — total count + latest databaseId per thread
gh api graphql -f query="
{
  repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") {
    pullRequest(number: ${PR}) {
      reviewThreads(first: 100) {
        totalCount
        nodes {
          id
          isResolved
          comments(last: 1) { nodes { databaseId } }
        }
      }
    }
  }
}" > /tmp/baseline_threads.json

# Top-level reviews — count + latest id per non-empty review
gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \
  --jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \
  > /tmp/baseline_reviews.json

# Issue comments — count + latest id per non-bot, non-author comment.
# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot
# accounts like coderabbitai, github-actions, sentry-io). The author is
# filtered by comparing login to the PR author — interpolate it into the
# jq filter (gh's --jq flag has no --arg support).
AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login')
gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate \
  --jq "[.[] | select(.user.type != \"Bot\" and .user.login != \"$AUTHOR\")
             | {id, user: .user.login, created_at}]" \
  > /tmp/baseline_issue_comments.json
```
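
For the pseudocode's `snapshot_state(PR)`, one minimal way to gather the three captures is to run the block above and read the files back (a sketch, assuming the same `/tmp` paths):

```python
import json


def snapshot_state(pr: int) -> dict:
    # Sketch only: assumes the bash block above has just run for this PR
    # (pr itself is what that block was invoked with) and written the
    # three baseline files to the paths used there.
    snapshot = {}
    for key, path in [
        ("threads", "/tmp/baseline_threads.json"),        # raw GraphQL response
        ("reviews", "/tmp/baseline_reviews.json"),        # jq-filtered array
        ("issue_comments", "/tmp/baseline_issue_comments.json"),
    ]:
        with open(path) as fh:
            snapshot[key] = json.load(fh)
    return snapshot
```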

### Diffing after a review

After `/pr-review` runs, any of the following counts as a "new finding" and means another address round is needed:

- A new inline thread `id` not in the baseline.
- An existing thread whose latest comment `databaseId` is higher than the baseline's (a new reply on an old thread).
- A new top-level review `id` with a non-empty body.
- A new issue comment `id` from a non-bot, non-author user.

If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop.
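
A sketch of the four-bucket diff over two `snapshot_state` dicts (shapes follow the queries above; field names per the GitHub APIs):

```python
def diff_state(current: dict, baseline: dict) -> dict:
    """Return the four 'new findings' buckets; all empty means the round is done."""

    def latest_by_thread(snapshot: dict) -> dict:
        nodes = snapshot["threads"]["data"]["repository"]["pullRequest"][
            "reviewThreads"
        ]["nodes"]
        # Map thread id -> databaseId of its most recent comment.
        return {n["id"]: n["comments"]["nodes"][-1]["databaseId"] for n in nodes}

    base, cur = latest_by_thread(baseline), latest_by_thread(current)
    old_review_ids = {r["id"] for r in baseline["reviews"]}
    old_comment_ids = {c["id"] for c in baseline["issue_comments"]}
    return {
        "new_threads": [t for t in cur if t not in base],
        "new_replies": [t for t in cur if t in base and cur[t] > base[t]],
        "new_reviews": [
            r for r in current["reviews"] if r["id"] not in old_review_ids
        ],
        "new_issue_comments": [
            c for c in current["issue_comments"] if c["id"] not in old_comment_ids
        ],
    }
```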

## Polish polling

Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals:

```text
NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"}
clean_polls = 0
required_clean = 2
while clean_polls < required_clean:
    # 1. CI gate — any terminal non-success conclusion (not just "failure")
    #    must trigger /pr-address. "success", "skipped", "neutral" are clean;
    #    anything else (including cancelled, timed_out, action_required) is a
    #    blocker that won't self-resolve.
    ci = fetch_check_runs(PR)
    if any ci.conclusion in NON_SUCCESS_TERMINAL:
        invoke_skill("pr-address", PR)  # address failures + any new comments
        baseline = snapshot_state(PR)   # reset — push during address invalidates old baseline
        clean_polls = 0
        continue
    if any ci.conclusion is None (still in_progress):
        sleep 60; continue              # wait without counting this as clean

    # 2. Comment / thread gate
    threads = fetch_unresolved_threads(PR)
    new_issue_comments = diff_against_baseline(issue_comments)
    new_reviews = diff_against_baseline(reviews)
    if threads or new_issue_comments or new_reviews:
        invoke_skill("pr-address", PR)
        baseline = snapshot_state(PR)   # reset — the address loop just dealt with these,
                                        # otherwise they stay "new" relative to the old baseline forever
        clean_polls = 0
        continue

    # 3. Mergeability gate
    mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable'
    if mergeable == false (CONFLICTING):
        resolve_conflicts(PR)           # see pr-address skill
        clean_polls = 0
        continue
    if mergeable is null (UNKNOWN):
        sleep 60; continue

    clean_polls += 1
    sleep 60
```
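
`fetch_check_runs(PR)` above is not pinned to an implementation; one plausible reading, via the REST check-runs endpoint (a sketch; only the first page is fetched here, so paginate for PRs with many checks):

```python
import json
import subprocess

REPO = "Significant-Gravitas/AutoGPT"


def fetch_check_runs(pr: int) -> list[dict]:
    """Sketch of the CI gate's data source: resolve the head SHA, list its check runs."""

    def gh(*args: str) -> str:
        return subprocess.run(
            ["gh", *args], capture_output=True, text=True, check=True
        ).stdout.strip()

    head_sha = gh(
        "pr", "view", str(pr), "--repo", REPO,
        "--json", "headRefOid", "--jq", ".headRefOid",
    )
    raw = gh(
        "api", f"repos/{REPO}/commits/{head_sha}/check-runs",
        "--jq", "[.check_runs[] | {name, status, conclusion}]",
    )
    return json.loads(raw)  # conclusion is None while a run is in_progress
```

`gh pr checks` can serve the same purpose; the REST form is shown because its `conclusion` field maps directly onto the polling pseudocode.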

Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`.

### Why 2 clean polls, not 1

A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming.

### Why check every source each poll

`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch:

- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI-green cycles).
- Issue comments from human reviewers (not caught by inline thread polling).
- Sentry bug predictions that land on new line numbers post-push.
- Merge conflicts introduced by a race between your push and a merge to `dev`.

## Invocation pattern

Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently.

```python
Skill(skill="pr-review", args=pr_url)
Skill(skill="pr-address", args=pr_url)
```

After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section.

### **Auto-continue: do NOT end your response between child skills**

`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you:

- Do NOT summarize and stop.
- Do NOT wait for user confirmation to continue.
- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep.

The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error).

If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions.

## GitHub rate limits

This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off:

- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section.
- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile.
- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits.
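
A sketch of the backoff gate (threshold as above; adjust to taste):

```python
import subprocess

GRAPHQL_FLOOR = 200  # remaining-points threshold from the paragraph above


def graphql_budget_ok() -> bool:
    """True while it is safe to keep issuing GraphQL reads/writes."""
    remaining = subprocess.run(
        ["gh", "api", "rate_limit", "--jq", ".resources.graphql.remaining"],
        capture_output=True, text=True, check=True,
    ).stdout.strip()
    return int(remaining) >= GRAPHQL_FLOOR
```

When it returns False, read via the REST endpoints listed above and hold GraphQL-only thread resolutions until the budget resets.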

## Safety valves

- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment.
- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is a CI `failure` that will never self-resolve.
- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved.

## Reporting

When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary:

```
PR #{N} polish complete ({rounds_completed} rounds):
- {X} inline threads opened and resolved
- {Y} CI failures fixed
- {Z} new commits pushed
Final state: CI green, {total} threads all resolved, mergeable.
```

If exiting via `_MAX_ROUNDS`, flag explicitly:

```
PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready:
- {N} threads still unresolved: {titles}
- CI status: {summary}
Needs human review.
```

## When to use this skill

Use when the user says any of:

- "polish this PR"
- "keep reviewing and addressing until it's mergeable"
- "loop /pr-review + /pr-address until done"
- "make sure the PR is actually merge-ready"

Do **not** use when:

- User wants just one review pass (→ `/pr-review`).
- User wants to address already-posted comments without further self-review (→ `/pr-address`).
- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging.

@@ -260,6 +260,32 @@ Use a `trap` so release runs even on `exit 1`:

```bash
trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM
```

### **Release the lock AS SOON AS the test run is done**

The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if:

- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually.
- You're leaving docker containers up.
- You're tailing logs for a minute or two.

Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone.

Concretely, the sequence at the end of every `/pr-test` run (success or failure) is:

```bash
# 1. Write the final report + post PR comment — done above in Step 5/6.
# 2. Release the lock right now, even if the app is still up.
kill "$HEARTBEAT_PID" 2>/dev/null
rm -f "$LOCK" /tmp/pr-test-heartbeat.pid
echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \
  >> /Users/majdyz/Code/AutoGPT/.ign.testing.log
# 3. Optionally leave the app running and note it so the user knows:
echo "Native stack still running on :3000 / :8006 for manual poking. Kill with:"
echo "  pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'"
```

If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test.

### Shared status log

`/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes:

@@ -755,6 +781,19 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations

**CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as an inline `![alt](url)` image embed in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes.

**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path.

**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`):

```bash
# Reject any local-looking absolute path or home-dir shortcut in the body
if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then
  echo "ABORT: local filesystem paths detected in PR comment body."
  echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting."
  exit 1
fi
```

```bash
# Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely)
REPO="Significant-Gravitas/AutoGPT"
```

@@ -76,6 +76,7 @@ from backend.copilot.tools.models import (
    SetupRequirementsResponse,
    SuggestedGoalResponse,
    TaskDecompositionResponse,
    TodoWriteResponse,
    UnderstandingUpdatedResponse,
)
from backend.copilot.tracking import track_user_message

@@ -1443,6 +1444,7 @@ ToolResponseUnion = (
    | MemorySearchResponse
    | MemoryForgetCandidatesResponse
    | MemoryForgetConfirmResponse
    | TodoWriteResponse
)

@@ -50,13 +50,13 @@ _VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"})
# (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without
# coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React
# re-render of the non-virtualised chat list, which paint-storms the browser
# main thread and freezes the UI. Batching into ~32-char / ~40 ms windows
# cuts the event rate ~100x while staying snappy enough that the Reasoning
# main thread and freezes the UI. Batching into ~64-char / ~50 ms windows
# cuts the event rate ~150x while staying snappy enough that the Reasoning
# collapse still feels live (well under the ~100 ms perceptual threshold).
# Per-delta persistence to ``session.messages`` stays granular — we only
# coalesce the *wire* emission.
_COALESCE_MIN_CHARS = 32
_COALESCE_MAX_INTERVAL_MS = 40.0
_COALESCE_MIN_CHARS = 64
_COALESCE_MAX_INTERVAL_MS = 50.0


class ReasoningDetail(BaseModel):
@@ -242,6 +242,12 @@ class BaselineReasoningEmitter:
    fresh ``ChatMessage(role="reasoning")`` is appended and mutated
    in-place as further deltas arrive; :meth:`close` drops the reference
    but leaves the appended row intact.

    ``render_in_ui=False`` suppresses only the live wire events
    (``StreamReasoning*``); the ``role='reasoning'`` persistence row is
    still appended so ``convertChatSessionToUiMessages.ts`` can hydrate
    the reasoning bubble on reload. The state machine advances
    identically either way.
    """

    def __init__(

@@ -250,21 +256,19 @@ class BaselineReasoningEmitter:
        *,
        coalesce_min_chars: int = _COALESCE_MIN_CHARS,
        coalesce_max_interval_ms: float = _COALESCE_MAX_INTERVAL_MS,
        render_in_ui: bool = True,
    ) -> None:
        self._block_id: str = str(uuid.uuid4())
        self._open: bool = False
        self._session_messages = session_messages
        self._current_row: ChatMessage | None = None
        # Coalescing state — ``_pending_delta`` accumulates reasoning text
        # between wire flushes. Providers like Kimi K2.6 emit very fine-
        # grained chunks; batching them reduces Redis ``xadd`` + SSE + React
        # re-render load by ~100x for equivalent text output. Tuning knobs
        # are kwargs so tests can disable coalescing (``=0``) for
        # deterministic event assertions.
        # Coalescing state — tests can disable (``=0``) for deterministic
        # event assertions.
        self._coalesce_min_chars = coalesce_min_chars
        self._coalesce_max_interval_ms = coalesce_max_interval_ms
        self._pending_delta: str = ""
        self._last_flush_monotonic: float = 0.0
        self._render_in_ui = render_in_ui

    @property
    def is_open(self) -> bool:
@@ -296,8 +300,9 @@ class BaselineReasoningEmitter:
        # syscalls off the hot path without changing semantics.
        now = time.monotonic()
        if not self._open:
            events.append(StreamReasoningStart(id=self._block_id))
            events.append(StreamReasoningDelta(id=self._block_id, delta=text))
            if self._render_in_ui:
                events.append(StreamReasoningStart(id=self._block_id))
                events.append(StreamReasoningDelta(id=self._block_id, delta=text))
            self._open = True
            self._last_flush_monotonic = now
            if self._session_messages is not None:

@@ -305,17 +310,15 @@ class BaselineReasoningEmitter:
                self._session_messages.append(self._current_row)
            return events

        # Persist per-delta (no coalescing here — the session snapshot stays
        # consistent at every chunk boundary, independent of the wire
        # coalesce window).
        if self._current_row is not None:
            self._current_row.content = (self._current_row.content or "") + text

        self._pending_delta += text
        if self._should_flush_pending(now):
            events.append(
                StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
            )
            if self._render_in_ui:
                events.append(
                    StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
                )
            self._pending_delta = ""
            self._last_flush_monotonic = now
        return events

@@ -348,12 +351,13 @@ class BaselineReasoningEmitter:
        if not self._open:
            return []
        events: list[StreamBaseResponse] = []
        if self._pending_delta:
            events.append(
                StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
            )
            self._pending_delta = ""
        events.append(StreamReasoningEnd(id=self._block_id))
        if self._render_in_ui:
            if self._pending_delta:
                events.append(
                    StreamReasoningDelta(id=self._block_id, delta=self._pending_delta)
                )
            events.append(StreamReasoningEnd(id=self._block_id))
        self._pending_delta = ""
        self._open = False
        self._block_id = str(uuid.uuid4())
        self._current_row = None

@@ -452,3 +452,63 @@ class TestReasoningPersistence:
        events = emitter.on_delta(_delta(reasoning="pure wire"))
        assert len(events) == 2  # start + delta, no crash
        # Nothing else to assert — just proves None session is supported.


class TestBaselineReasoningEmitterRenderFlag:
    """``render_in_ui=False`` must silence the ``StreamReasoning*`` wire
    events while still persisting the ``role="reasoning"`` row — the
    operator hides the live collapse without dropping the audit trail,
    and the frontend gates rendering of persisted rows separately.
    These tests pin the contract in both directions so future refactors
    can't flip only one half."""

    def test_render_off_suppresses_start_and_delta(self):
        emitter = BaselineReasoningEmitter(render_in_ui=False)
        events = emitter.on_delta(_delta(reasoning="hidden"))
        # No wire events, but state advanced (is_open == True) so close()
        # below has something to rotate.
        assert events == []
        assert emitter.is_open is True

    def test_render_off_suppresses_close_end(self):
        emitter = BaselineReasoningEmitter(render_in_ui=False)
        emitter.on_delta(_delta(reasoning="hidden"))
        events = emitter.close()
        assert events == []
        assert emitter.is_open is False

    def test_render_off_still_persists(self):
        """Persistence is decoupled from the render flag — session
        transcript always keeps the ``role="reasoning"`` row so audit
        and ``--resume``-equivalent replay never lose thinking text.
        The frontend gates rendering separately."""
        session: list[ChatMessage] = []
        emitter = BaselineReasoningEmitter(session, render_in_ui=False)

        emitter.on_delta(_delta(reasoning="part one "))
        emitter.on_delta(_delta(reasoning="part two"))
        emitter.close()

        assert len(session) == 1
        assert session[0].role == "reasoning"
        assert session[0].content == "part one part two"

    def test_render_off_rotates_block_id_between_sessions(self):
        """Even with wire events silenced the block id must rotate on close,
        otherwise a hypothetical mid-session flip would reuse a stale id."""
        emitter = BaselineReasoningEmitter(render_in_ui=False)
        emitter.on_delta(_delta(reasoning="first"))
        first_block_id = emitter._block_id
        emitter.close()
        emitter.on_delta(_delta(reasoning="second"))
        assert emitter._block_id != first_block_id

    def test_render_on_is_default(self):
        """Defaulting to True preserves backward compat — existing callers
        that don't pass the kwarg keep emitting wire events as before."""
        emitter = BaselineReasoningEmitter()
        events = emitter.on_delta(_delta(reasoning="hello"))
        assert len(events) == 2
        assert isinstance(events[0], StreamReasoningStart)
        assert isinstance(events[1], StreamReasoningDelta)

@@ -15,7 +15,7 @@ import re
import shutil
import tempfile
import uuid
from collections.abc import AsyncGenerator, Mapping, Sequence
from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence
from dataclasses import dataclass, field
from functools import partial
from typing import TYPE_CHECKING, Any, cast

@@ -45,6 +45,8 @@ from backend.copilot.model import (
    maybe_append_user_message,
    upsert_chat_session,
)
from backend.copilot.model_router import resolve_model
from backend.copilot.moonshot import is_moonshot_model
from backend.copilot.pending_message_helpers import (
    combine_pending_with_current,
    drain_pending_safe,

@@ -82,7 +84,7 @@ from backend.copilot.service import (
from backend.copilot.session_cleanup import prune_orphan_tool_calls
from backend.copilot.thinking_stripper import ThinkingStripper as _ThinkingStripper
from backend.copilot.token_tracking import persist_and_record_usage
from backend.copilot.tools import execute_tool, get_available_tools
from backend.copilot.tools import ToolGroup, execute_tool, get_available_tools
from backend.copilot.tracking import track_user_message
from backend.copilot.transcript import (
    STOP_REASON_END_TURN,

@@ -318,20 +320,17 @@ def _filter_tools_by_permissions(
    ]


def _resolve_baseline_model(tier: CopilotLlmModel | None) -> str:
async def _resolve_baseline_model(
    tier: CopilotLlmModel | None, user_id: str | None
) -> str:
    """Pick the model for the baseline path based on the per-request tier.

    Baseline resolves independently of SDK via the ``fast_*_model`` cells
    of the (path, tier) matrix. ``'standard'`` / ``None`` picks Kimi
    K2.6 by default (cheap + OpenRouter ``reasoning`` support);
    ``'advanced'`` picks Opus by default so the advanced tier is a clean
    A/B against the SDK advanced tier — same model, different path —
    isolating reasoning-wire + cache differences from model capability.
    Both defaults are overridable per ``CHAT_FAST_*_MODEL`` env vars.
    Delegates to :func:`copilot.model_router.resolve_model` so the
    ``(fast, tier)`` cell is LD-overridable per user. ``None`` tier
    maps to ``"standard"``.
    """
    if tier == "advanced":
        return config.fast_advanced_model
    return config.fast_standard_model
    tier_name = "advanced" if tier == "advanced" else "standard"
    return await resolve_model("fast", tier_name, user_id, config=config)


@dataclass

@@ -343,7 +342,21 @@ class _BaselineStreamState:
    """

    model: str = ""
    pending_events: list[StreamBaseResponse] = field(default_factory=list)
    # Live delivery channel drained concurrently by ``stream_chat_completion_baseline``
    # so reasoning / text / tool events reach the SSE wire **during** the upstream
    # LLM stream, not after ``_baseline_llm_caller`` returns. Before this was a
    # ``list`` drained per ``tool_call_loop`` iteration, so any model with
    # extended thinking (Anthropic via OpenRouter, Moonshot, future reasoning
    # routes) froze the UI for the entire duration of each LLM round before
    # flushing the backlog in one burst. The queue is single-producer (the
    # streaming loop) / single-consumer (the outer async-gen yield loop);
    # ``None`` is the close sentinel.
    pending_events: asyncio.Queue[StreamBaseResponse | None] = field(
        default_factory=asyncio.Queue
    )
    # Mirror of every event put on ``pending_events`` — kept for unit tests that
    # inspect post-hoc what was emitted. Not consumed by production code.
    emitted_events: list[StreamBaseResponse] = field(default_factory=list)
    assistant_text: str = ""
    text_block_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    text_started: bool = False

@@ -382,31 +395,71 @@ class _BaselineStreamState:
        # frontend's ``convertChatSessionToUiMessages`` relies on these
        # rows to render the Reasoning collapse after the AI SDK's
        # stream-end hydrate swaps in the DB-backed message list.
        self.reasoning_emitter = BaselineReasoningEmitter(self.session_messages)
        # ``render_in_ui`` is sourced from ``config.render_reasoning_in_ui``
        # so the operator can silence the reasoning collapse globally
        # without dropping the persisted audit trail.
        self.reasoning_emitter = BaselineReasoningEmitter(
            self.session_messages,
            render_in_ui=config.render_reasoning_in_ui,
        )


def _emit(state: "_BaselineStreamState", event: StreamBaseResponse) -> None:
    """Queue *event* for the live SSE wire AND mirror into ``emitted_events``.

    Single helper so every streaming producer (LLM stream loop, tool executor,
    conversation updater) posts to the same single-consumer queue. The mirror
    list is read-only from production code — it exists so unit tests can assert
    on the full sequence emitted during one call.
    """
    state.pending_events.put_nowait(event)
    state.emitted_events.append(event)


def _emit_all(
    state: "_BaselineStreamState", events: Iterable[StreamBaseResponse]
) -> None:
    """Queue *events* in order — convenience for emitter batches."""
    for event in events:
        _emit(state, event)


def _is_anthropic_model(model: str) -> bool:
    """Return True if *model* routes to Anthropic (native or via OpenRouter).

    Cache-control markers on message content + the ``anthropic-beta`` header
    are Anthropic-specific. OpenAI rejects the unknown ``cache_control``
    field with a 400 ("Extra inputs are not permitted") and Grok / other
    providers behave similarly. OpenRouter strips unknown headers but
    passes through ``cache_control`` on the body regardless of provider —
    which would also fail when OpenRouter routes to a non-Anthropic model.

    Examples that return True:
    - ``anthropic/claude-sonnet-4-6`` (OpenRouter route)
    - ``claude-3-5-sonnet-20241022`` (direct Anthropic API)
    - ``anthropic.claude-3-5-sonnet`` (Bedrock-style)

    False for ``openai/gpt-4o``, ``google/gemini-2.5-pro``, ``xai/grok-4``
    etc.
    etc. Moonshot is False here too even though Moonshot's
    Anthropic-compat endpoint honours ``cache_control`` — use
    :func:`_supports_prompt_cache_markers` for the cache-gating decision,
    which also allows Moonshot routes. This function stays scoped to
    "genuinely Anthropic" so callers that need the stricter check (e.g.
    ``anthropic-beta`` header emission) keep their existing semantics.
    """
    lowered = model.lower()
    return "claude" in lowered or lowered.startswith("anthropic")


def _supports_prompt_cache_markers(model: str) -> bool:
    """Return True when *model* accepts Anthropic-style ``cache_control``.

    Superset of :func:`_is_anthropic_model` — also allows Moonshot
    (``moonshotai/*``), whose OpenRouter Anthropic-compat endpoint
    honours the marker and empirically lifts cache hit rate on
    continuation turns from near-zero (Moonshot's own automatic prefix
    cache, which drifts readily) to the 60-95% Anthropic ballpark.

    OpenAI / Grok / Gemini still 400 on ``cache_control``, so this
    function returns False for those providers — add new vendors here
    only after verifying their endpoint accepts the field.
    """
    return _is_anthropic_model(model) or is_moonshot_model(model)


def _fresh_ephemeral_cache_control() -> dict[str, str]:
    """Return a FRESH ephemeral ``cache_control`` dict each call.

@@ -519,7 +572,7 @@ async def _baseline_llm_caller(

    Extracted from ``stream_chat_completion_baseline`` for readability.
    """
    state.pending_events.append(StreamStartStep())
    _emit(state, StreamStartStep())
    # Fresh thinking-strip state per round so a malformed unclosed
    # block in one LLM call cannot silently drop content in the next.
    state.thinking_stripper = _ThinkingStripper()

@@ -527,19 +580,24 @@ async def _baseline_llm_caller(
    round_text = ""
    try:
        client = _get_openai_client()
        # Cache markers are Anthropic-specific. For OpenAI/Grok/other
        # providers, leaving them on would trigger a 400 ("Extra inputs
        # are not permitted" on cache_control). Tools were precomputed
        # in stream_chat_completion_baseline via _mark_tools_with_cache_control
        # (only when the model was Anthropic), so on non-Anthropic routes
        # tools ship without cache_control on the last entry too.
        # Cache markers are accepted by Anthropic AND Moonshot (via OR's
        # Anthropic-compat endpoint). OpenAI/Grok/Gemini 400 on the
        # unknown ``cache_control`` field — tools were precomputed in
        # stream_chat_completion_baseline via _mark_tools_with_cache_control
        # with the same gate, so on unsupported routes tools ship
        # unmarked too.
        #
        # `extra_body` `usage.include=true` asks OpenRouter to embed the real
        # generation cost into the final usage chunk — required by the
        # cost-based rate limiter in routes.py. Separate from the Anthropic
        # The ``anthropic-beta`` header is only emitted for genuinely
        # Anthropic routes (see :func:`_is_anthropic_model`) — Moonshot
        # doesn't need the beta header; sending it is a no-op but we
        # keep the check strict for clarity.
        #
        # `extra_body` `usage.include=true` asks OpenRouter to embed the
        # real generation cost into the final usage chunk — required by
        # the cost-based rate limiter in routes.py. Separate from the
        # caching headers, always sent.
        is_anthropic = _is_anthropic_model(state.model)
        if is_anthropic:
        supports_cache = _supports_prompt_cache_markers(state.model)
        if supports_cache:
            # Build the cached system dict once per session and splice it in
            # on each round. The full ``messages`` list grows with every
            # tool call, so copying the entire list just to mutate index 0

@@ -555,7 +613,11 @@ async def _baseline_llm_caller(
                final_messages = [state.cached_system_message, *messages[1:]]
            else:
                final_messages = messages
            extra_headers = _fresh_anthropic_caching_headers()
            extra_headers = (
                _fresh_anthropic_caching_headers()
                if _is_anthropic_model(state.model)
                else None
            )
        else:
            final_messages = messages
            extra_headers = None

@@ -621,31 +683,30 @@ async def _baseline_llm_caller(
                if not delta:
                    continue

                state.pending_events.extend(state.reasoning_emitter.on_delta(delta))
                _emit_all(state, state.reasoning_emitter.on_delta(delta))

                if delta.content:
                    # Text and reasoning must not interleave on the wire — the
                    # AI SDK maps distinct start/end pairs to distinct UI
                    # parts. Close any open reasoning block before emitting
                    # the first text delta of this run.
                    state.pending_events.extend(state.reasoning_emitter.close())
                    _emit_all(state, state.reasoning_emitter.close())
                    emit = state.thinking_stripper.process(delta.content)
                    if emit:
                        if not state.text_started:
                            state.pending_events.append(
                                StreamTextStart(id=state.text_block_id)
                            )
                            _emit(state, StreamTextStart(id=state.text_block_id))
                            state.text_started = True
                        round_text += emit
                        state.pending_events.append(
                            StreamTextDelta(id=state.text_block_id, delta=emit)
                        _emit(
                            state,
                            StreamTextDelta(id=state.text_block_id, delta=emit),
                        )

                if delta.tool_calls:
                    # Same rule as the text branch: close any open reasoning
                    # block before a tool_use starts so the AI SDK treats
                    # reasoning and tool-use as distinct parts.
                    state.pending_events.extend(state.reasoning_emitter.close())
                    _emit_all(state, state.reasoning_emitter.close())
                    for tc in delta.tool_calls:
                        idx = tc.index
                        if idx not in tool_calls_by_index:

@@ -676,19 +737,17 @@ async def _baseline_llm_caller(
        # ``async for chunk in response`` would otherwise leave reasoning
        # and/or text unterminated and only ``StreamFinishStep`` emitted —
        # the Reasoning / Text collapses would never finalise.
        state.pending_events.extend(state.reasoning_emitter.close())
        _emit_all(state, state.reasoning_emitter.close())
        # Flush any buffered text held back by the thinking stripper.
        tail = state.thinking_stripper.flush()
        if tail:
            if not state.text_started:
                state.pending_events.append(StreamTextStart(id=state.text_block_id))
                _emit(state, StreamTextStart(id=state.text_block_id))
                state.text_started = True
            round_text += tail
            state.pending_events.append(
                StreamTextDelta(id=state.text_block_id, delta=tail)
            )
            _emit(state, StreamTextDelta(id=state.text_block_id, delta=tail))
        if state.text_started:
            state.pending_events.append(StreamTextEnd(id=state.text_block_id))
            _emit(state, StreamTextEnd(id=state.text_block_id))
            state.text_started = False
            state.text_block_id = str(uuid.uuid4())
        # Always persist partial text so the session history stays consistent,

@@ -696,7 +755,7 @@ async def _baseline_llm_caller(
        state.assistant_text += round_text
        # Always emit StreamFinishStep to match the StreamStartStep,
        # even if an exception occurred during streaming.
        state.pending_events.append(StreamFinishStep())
        _emit(state, StreamFinishStep())

    # Convert to shared format
    llm_tool_calls = [

@@ -738,13 +797,14 @@ async def _baseline_tool_executor(
    except orjson.JSONDecodeError as parse_err:
        parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}"
        logger.warning("[Baseline] %s", parse_error)
        state.pending_events.append(
        _emit(
            state,
            StreamToolOutputAvailable(
                toolCallId=tool_call_id,
                toolName=tool_name,
                output=parse_error,
                success=False,
            )
            ),
        )
        return ToolCallResult(
            tool_call_id=tool_call_id,

@@ -753,15 +813,17 @@ async def _baseline_tool_executor(
            is_error=True,
        )

    state.pending_events.append(
        StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name)
    _emit(
        state,
        StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name),
    )
    state.pending_events.append(
    _emit(
        state,
        StreamToolInputAvailable(
            toolCallId=tool_call_id,
            toolName=tool_name,
            input=tool_args,
        )
        ),
    )

    # Announce the tool call to the session so in-turn guards like

@@ -785,7 +847,7 @@ async def _baseline_tool_executor(
        session=session,
        tool_call_id=tool_call_id,
    )
    state.pending_events.append(result)
    _emit(state, result)
    tool_output = (
        result.output if isinstance(result.output, str) else str(result.output)
    )

@@ -802,13 +864,14 @@ async def _baseline_tool_executor(
            error_output,
            exc_info=True,
        )
        state.pending_events.append(
        _emit(
            state,
            StreamToolOutputAvailable(
                toolCallId=tool_call_id,
                toolName=tool_name,
                output=error_output,
                success=False,
            )
            ),
        )
        return ToolCallResult(
            tool_call_id=tool_call_id,

@@ -1317,7 +1380,7 @@ async def stream_chat_completion_baseline(
    # Select model based on the per-request tier toggle (standard / advanced).
    # The path (fast vs extended_thinking) is already decided — we're in the
    # baseline (fast) path; ``mode`` is accepted for logging parity only.
    active_model = _resolve_baseline_model(model)
    active_model = await _resolve_baseline_model(model, user_id)

    # --- E2B sandbox setup (feature parity with SDK path) ---
    e2b_sandbox = None

@@ -1583,7 +1646,10 @@ async def stream_chat_completion_baseline(
            openai_messages[i]["content"] = text
            break

    tools = get_available_tools()
    disabled_tool_groups: list[ToolGroup] = []
    if not graphiti_enabled:
        disabled_tool_groups.append("graphiti")
    tools = get_available_tools(disabled_groups=disabled_tool_groups)

    # --- Permission filtering ---
    if permissions is not None:

@@ -1594,9 +1660,10 @@ async def stream_chat_completion_baseline(
    # _baseline_llm_caller) avoids re-copying ~43 tool dicts on every LLM
    # round of the tool-call loop.
    #
    # Only apply to Anthropic routes — OpenAI/Grok/other providers would
    # 400 on the unknown ``cache_control`` field inside tool definitions.
    if _is_anthropic_model(active_model):
    # Applies to Anthropic AND Moonshot routes — OpenAI/Grok/Gemini 400
    # on the unknown ``cache_control`` field inside tool definitions, so
    # the gate stays narrow (see :func:`_supports_prompt_cache_markers`).
    if _supports_prompt_cache_markers(active_model):
        tools = cast(
            list[ChatCompletionToolParam], _mark_tools_with_cache_control(tools)
        )
@@ -1660,139 +1727,172 @@ async def stream_chat_completion_baseline(
|
||||
state=state,
|
||||
)
|
||||
|
||||
try:
|
||||
loop_result = None
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_bound_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
# Drain buffered events after each iteration (real-time streaming)
|
||||
for evt in state.pending_events:
|
||||
yield evt
|
||||
state.pending_events.clear()
|
||||
# Run the tool-call loop concurrently with the event consumer so
|
||||
# ``StreamReasoning*`` / ``StreamText*`` deltas emitted inside
|
||||
# ``_baseline_llm_caller`` reach the SSE wire DURING the upstream LLM
|
||||
# stream instead of only at iteration boundaries. Any reasoning route
|
||||
# that streams for several minutes per round (extended thinking on
|
||||
# Anthropic / Moonshot / future providers) would otherwise freeze the
|
||||
# UI for the whole window before flushing the backlog in one burst.
|
||||
loop_result_holder: list[Any] = [None]
|
||||
loop_task: asyncio.Task[None] | None = None
|
||||
|
||||
# Inject any messages the user queued while the turn was
|
||||
# running. ``tool_call_loop`` mutates ``openai_messages``
|
||||
# in-place, so appending here means the model sees the new
|
||||
# messages on its next LLM call.
|
||||
#
|
||||
# IMPORTANT: skip when the loop has already finished (no
|
||||
# more LLM calls are coming). ``tool_call_loop`` yields
|
||||
# a final ``ToolCallLoopResult`` on both paths:
|
||||
# - natural finish: ``finished_naturally=True``
|
||||
# - hit max_iterations: ``finished_naturally=False``
|
||||
# and ``iterations >= max_iterations``
|
||||
# In either case the loop is about to return on the next
|
||||
# ``async for`` step, so draining here would silently
|
||||
# lose the message (the user sees 202 but the model never
|
||||
# reads the text). Those messages stay in the buffer and
|
||||
# get picked up at the start of the next turn.
|
||||
is_final_yield = (
|
||||
loop_result.finished_naturally
|
||||
or loop_result.iterations >= _MAX_TOOL_ROUNDS
|
||||
)
|
||||
if is_final_yield:
|
||||
continue
|
||||
try:
|
||||
pending = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[Baseline] mid-loop drain_pending_messages failed for session %s",
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
pending = []
|
||||
if pending:
|
||||
# Flush any buffered assistant/tool messages from completed
|
||||
# rounds into session.messages BEFORE appending the pending
|
||||
# user message. ``_baseline_conversation_updater`` only
|
||||
# records assistant+tool rounds into ``state.session_messages``
|
||||
# — they are normally batch-flushed in the finally block.
|
||||
# Without this in-order flush, the mid-loop pending user
|
||||
# message lands before the preceding round's assistant/tool
|
||||
# entries, producing chronologically-wrong session.messages
|
||||
# on persist (user interposed between an assistant tool_call
|
||||
# and its tool-result), which breaks OpenAI tool-call ordering
|
||||
# invariants on the next turn's replay.
|
||||
async def _run_tool_call_loop() -> None:
|
||||
# Read/write the current session via ``_session_holder`` so this
|
||||
# closure doesn't need to ``nonlocal session`` — pyright can't narrow
|
||||
# the outer ``session: ChatSession | None`` through a nested scope,
|
||||
# but the holder is typed non-optional after the preflight guard
|
||||
# above.
|
||||
try:
|
||||
async for loop_result in tool_call_loop(
|
||||
messages=openai_messages,
|
||||
tools=tools,
|
||||
llm_call=_bound_llm_caller,
|
||||
execute_tool=_bound_tool_executor,
|
||||
update_conversation=_bound_conversation_updater,
|
||||
max_iterations=_MAX_TOOL_ROUNDS,
|
||||
):
|
||||
loop_result_holder[0] = loop_result
|
||||
# Inject any messages the user queued while the turn was
|
||||
# running. ``tool_call_loop`` mutates ``openai_messages``
|
||||
# in-place, so appending here means the model sees the new
|
||||
# messages on its next LLM call.
|
||||
#
|
||||
# Also persist any assistant text from text-only rounds (rounds
|
||||
# with no tool calls, which ``_baseline_conversation_updater``
|
||||
# does NOT record in session_messages). If we only update
|
||||
# ``_flushed_assistant_text_len`` without persisting the text,
|
||||
# that text is silently lost: the finally block only appends
|
||||
# assistant_text[_flushed_assistant_text_len:], so text generated
|
||||
# before this drain never reaches session.messages.
|
||||
recorded_text = "".join(
|
||||
m.content or ""
|
||||
for m in state.session_messages
|
||||
if m.role == "assistant"
|
||||
# IMPORTANT: skip when the loop has already finished (no
|
||||
# more LLM calls are coming). ``tool_call_loop`` yields
|
||||
# a final ``ToolCallLoopResult`` on both paths:
|
||||
# - natural finish: ``finished_naturally=True``
|
||||
# - hit max_iterations: ``finished_naturally=False``
|
||||
# and ``iterations >= max_iterations``
|
||||
# In either case the loop is about to return on the next
|
||||
# ``async for`` step, so draining here would silently
|
||||
# lose the message (the user sees 202 but the model never
|
||||
# reads the text). Those messages stay in the buffer and
|
||||
# get picked up at the start of the next turn.
|
||||
is_final_yield = (
|
||||
loop_result.finished_naturally
|
||||
or loop_result.iterations >= _MAX_TOOL_ROUNDS
|
||||
)
|
||||
unflushed_text = state.assistant_text[
|
||||
state._flushed_assistant_text_len :
|
||||
]
|
||||
text_only_text = (
|
||||
unflushed_text[len(recorded_text) :]
|
||||
if unflushed_text.startswith(recorded_text)
|
||||
else unflushed_text
|
||||
)
|
||||
if text_only_text.strip():
|
||||
session.messages.append(
|
||||
ChatMessage(role="assistant", content=text_only_text)
|
||||
if is_final_yield:
|
||||
continue
|
||||
try:
|
||||
pending = await drain_pending_messages(session_id)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[Baseline] mid-loop drain_pending_messages failed for "
|
||||
"session %s",
|
||||
session_id,
|
||||
exc_info=True,
|
||||
)
|
||||
for _buffered in state.session_messages:
|
||||
session.messages.append(_buffered)
|
||||
state.session_messages.clear()
|
||||
# Record how much assistant_text has been covered by the
|
||||
# structured entries just flushed, so the finally block's
|
||||
# final-text dedup doesn't re-append rounds already persisted.
|
||||
state._flushed_assistant_text_len = len(state.assistant_text)
|
||||
pending = []
|
||||
if pending:
|
||||
# Flush any buffered assistant/tool messages from completed
|
||||
# rounds into session.messages BEFORE appending the pending
|
||||
# user message. ``_baseline_conversation_updater`` only
|
||||
# records assistant+tool rounds into ``state.session_messages``
|
||||
# — they are normally batch-flushed in the finally block.
|
||||
# Without this in-order flush, the mid-loop pending user
|
||||
# message lands before the preceding round's assistant/tool
|
||||
# entries, producing chronologically-wrong session.messages
|
||||
# on persist (user interposed between an assistant tool_call
|
||||
# and its tool-result), which breaks OpenAI tool-call ordering
|
||||
# invariants on the next turn's replay.
|
||||
#
|
||||
# Also persist any assistant text from text-only rounds (rounds
|
||||
# with no tool calls, which ``_baseline_conversation_updater``
|
||||
# does NOT record in session_messages). If we only update
|
||||
# ``_flushed_assistant_text_len`` without persisting the text,
|
||||
# that text is silently lost: the finally block only appends
|
||||
# assistant_text[_flushed_assistant_text_len:], so text generated
|
||||
# before this drain never reaches session.messages.
|
||||
recorded_text = "".join(
|
||||
m.content or ""
|
||||
for m in state.session_messages
|
||||
if m.role == "assistant"
|
||||
)
|
||||
unflushed_text = state.assistant_text[
|
||||
state._flushed_assistant_text_len :
|
||||
]
|
||||
text_only_text = (
|
||||
unflushed_text[len(recorded_text) :]
|
||||
if unflushed_text.startswith(recorded_text)
|
||||
else unflushed_text
|
||||
)
|
||||
current_session = _session_holder[0]
|
||||
if text_only_text.strip():
|
||||
current_session.messages.append(
|
||||
ChatMessage(role="assistant", content=text_only_text)
|
||||
)
|
||||
for _buffered in state.session_messages:
|
||||
current_session.messages.append(_buffered)
|
||||
state.session_messages.clear()
|
||||
# Record how much assistant_text has been covered by the
|
||||
# structured entries just flushed, so the finally block's
|
||||
# final-text dedup doesn't re-append rounds already persisted.
|
||||
state._flushed_assistant_text_len = len(state.assistant_text)
|
||||
|
||||
# Persist the assistant/tool flush BEFORE the pending append
|
||||
            # Persist the assistant/tool flush BEFORE the pending append
            # so a later pending-persist failure can roll back the
            # pending rows without also discarding LLM output.
-           session = await persist_session_safe(session, "[Baseline]")
+           current_session = await persist_session_safe(
+               current_session, "[Baseline]"
+           )
            # ``upsert_chat_session`` may return a *new* ``ChatSession``
            # instance (e.g. when a concurrent title update has written a
            # newer title to Redis, it returns ``session.model_copy``).
            # Keep ``_session_holder`` in sync so subsequent tool rounds
            # executed via ``_bound_tool_executor`` see the fresh session
            # — any tool-side mutations on the stale object would be
            # discarded when the new one is persisted in the ``finally``.
-           _session_holder[0] = session
+           _session_holder[0] = current_session

            # ``format_pending_as_user_message`` embeds file attachments
            # and context URL/page content into the content string so
            # the in-session transcript is a faithful copy of what the
            # model actually saw. We also mirror each push into
            # ``openai_messages`` so the model's next LLM round sees it.
            #
            # Pre-compute the formatted dicts once so both the openai
            # messages append and the content_of lookup inside the
            # shared helper use the same string — and so ``on_rollback``
            # can trim ``openai_messages`` to the recorded anchor.
            formatted_by_pm = {
                id(pm): format_pending_as_user_message(pm) for pm in pending
            }
            _openai_anchor = len(openai_messages)
            for pm in pending:
                openai_messages.append(formatted_by_pm[id(pm)])

            def _trim_openai_on_rollback(_session_anchor: int) -> None:
                del openai_messages[_openai_anchor:]

            await persist_pending_as_user_rows(
-               session,
+               current_session,
                transcript_builder,
                pending,
                log_prefix="[Baseline]",
                content_of=lambda pm: formatted_by_pm[id(pm)]["content"],
                on_rollback=_trim_openai_on_rollback,
            )
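The anchor-and-trim pattern above keeps ``openai_messages`` and the persisted rows in lockstep. A minimal sketch of the idea with hypothetical names — only the ``on_rollback`` hook shape is taken from the diff; the storage call is a stand-in:

```python
async def append_with_rollback(buffer: list[dict], rows: list[dict], persist) -> None:
    # Record the anchor before mutating the in-memory buffer.
    anchor = len(buffer)
    buffer.extend(rows)
    try:
        await persist(rows)  # stand-in for the durable write
    except Exception:
        # Trim everything appended since the anchor so the in-memory
        # transcript never drifts ahead of what was actually persisted.
        del buffer[anchor:]
        raise
```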
        finally:
            # Always post the sentinel so the outer consumer exits — even if
            # ``tool_call_loop`` raised. ``_baseline_llm_caller``'s own
            # finally block has already pushed ``StreamReasoningEnd`` /
            # ``StreamTextEnd`` / ``StreamFinishStep`` at this point, so the
            # sentinel only terminates the consumer; it does not suppress
            # any still-unflushed events.
            state.pending_events.put_nowait(None)

    loop_task = asyncio.create_task(_run_tool_call_loop())
    try:
        while True:
            evt = await state.pending_events.get()
            if evt is None:
                break
            yield evt
        # Sentinel received — surface any exception the inner task hit.
        await loop_task
        loop_result = loop_result_holder[0]
        if loop_result and not loop_result.finished_naturally:
            limit_msg = (
                f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds "
@@ -1803,25 +1903,34 @@ async def stream_chat_completion_baseline(
                errorText=limit_msg,
                code="baseline_tool_round_limit",
            )

    except Exception as e:
        _stream_error = True
        error_msg = str(e) or type(e).__name__
        logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True)
        # ``_baseline_llm_caller``'s finally block closes any open
        # reasoning / text blocks and appends ``StreamFinishStep`` on
        # both normal and exception paths, so pending_events already has
        # the correct protocol ordering:
        #   StreamStartStep -> StreamReasoningStart -> ...deltas... ->
        #   StreamReasoningEnd -> StreamTextStart -> ...deltas... ->
        #   StreamTextEnd -> StreamFinishStep
-       # Just drain what's buffered, then yield the error.
-       for evt in state.pending_events:
-           yield evt
-       state.pending_events.clear()
+       # Drain any queued tail events (reasoning/text close + finish step)
+       # that ``_baseline_llm_caller``'s finally block pushed before the
+       # sentinel arrived — without this the frontend would be missing the
+       # matching end / finish parts for the partial round.
+       while not state.pending_events.empty():
+           evt = state.pending_events.get_nowait()
+           if evt is not None:
+               yield evt
        yield StreamError(errorText=error_msg, code="baseline_error")
        # Still persist whatever we got
    finally:
        # Cancel the inner task if we're unwinding early (client disconnect,
        # unexpected error in the consumer) so it doesn't keep streaming
        # tokens into a dead queue.
        if loop_task is not None and not loop_task.done():
            loop_task.cancel()
            try:
                await loop_task
            except (asyncio.CancelledError, Exception):
                pass
        # Re-sync the outer ``session`` binding in case the inner task
        # reassigned it via a mid-loop ``persist_session_safe`` call.
        session = _session_holder[0]

        # In-flight tool-call announcements are only meaningful for the
        # current turn; clear at the top of the outer finally so the next
        # turn starts with a clean scratch buffer even if one of the
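The producer/consumer shape above — inner task pushes events, a ``None`` sentinel terminates the consumer, and the consumer's ``finally`` cancels the producer on early unwind — is a standard asyncio pattern. A stripped-down, self-contained sketch with generic names (not the production types):

```python
import asyncio

async def produce(q: asyncio.Queue) -> None:
    try:
        for i in range(3):
            await q.put(f"event-{i}")
    finally:
        # The sentinel ALWAYS goes out, even if the body raised, so the
        # consumer below can never block forever on q.get().
        q.put_nowait(None)

async def consume() -> None:
    q: asyncio.Queue = asyncio.Queue()
    task = asyncio.create_task(produce(q))
    try:
        while (evt := await q.get()) is not None:
            print(evt)
        await task  # surface any exception the producer hit
    finally:
        if not task.done():
            task.cancel()  # don't keep streaming into a dead queue

asyncio.run(consume())
```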
@@ -21,6 +21,7 @@ from backend.copilot.baseline.service import (
     _is_anthropic_model,
     _mark_system_message_with_cache_control,
     _mark_tools_with_cache_control,
+    _supports_prompt_cache_markers,
 )
 from backend.copilot.model import ChatMessage
 from backend.copilot.response_model import (
@@ -39,7 +40,10 @@ from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallRe

 class TestBaselineStreamState:
     def test_defaults(self):
         state = _BaselineStreamState()
-        assert state.pending_events == []
+        # ``pending_events`` is an asyncio.Queue now (live SSE channel).
+        # The durable inspection view is ``emitted_events``.
+        assert state.pending_events.empty()
+        assert state.emitted_events == []
         assert state.assistant_text == ""
         assert state.text_started is False
         assert state.turn_prompt_tokens == 0
@@ -1687,7 +1691,7 @@ class TestBaselineReasoningStreaming:
            state=state,
        )

-        types = [type(e).__name__ for e in state.pending_events]
+        types = [type(e).__name__ for e in state.emitted_events]
        assert "StreamReasoningStart" in types
        assert "StreamReasoningDelta" in types
        assert "StreamReasoningEnd" in types
@@ -1702,14 +1706,14 @@ class TestBaselineReasoningStreaming:
        # a fresh id after the reasoning-end rotation.
        reasoning_ids = {
            e.id
-            for e in state.pending_events
+            for e in state.emitted_events
            if isinstance(
                e, (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd)
            )
        }
        text_ids = {
            e.id
-            for e in state.pending_events
+            for e in state.emitted_events
            if isinstance(e, (StreamTextStart, StreamTextDelta, StreamTextEnd))
        }
        assert len(reasoning_ids) == 1
@@ -1717,7 +1721,7 @@ class TestBaselineReasoningStreaming:
        assert reasoning_ids.isdisjoint(text_ids)

        combined = "".join(
-            e.delta for e in state.pending_events if isinstance(e, StreamReasoningDelta)
+            e.delta for e in state.emitted_events if isinstance(e, StreamReasoningDelta)
        )
        assert combined == "thinking... more"

@@ -1759,7 +1763,7 @@ class TestBaselineReasoningStreaming:

        # A reasoning-end must have been emitted — this is the tool_calls
        # branch's responsibility, not the stream-end cleanup.
-        types = [type(e).__name__ for e in state.pending_events]
+        types = [type(e).__name__ for e in state.emitted_events]
        assert "StreamReasoningStart" in types
        assert "StreamReasoningEnd" in types

@@ -1802,7 +1806,7 @@ class TestBaselineReasoningStreaming:
            state=state,
        )

-        types = [type(e).__name__ for e in state.pending_events]
+        types = [type(e).__name__ for e in state.emitted_events]
        # The reasoning block was opened, the exception fired, and the
        # finally block must have closed it before emitting the finish
        # step.
@@ -1861,12 +1865,14 @@ class TestBaselineReasoningStreaming:
        assert "reasoning" not in extra_body

    @pytest.mark.asyncio
-    async def test_kimi_route_sends_reasoning_but_no_cache_control(self):
-        """Kimi K2.6 is the default fast_model and sends ``reasoning`` via
-        OpenRouter's unified extension. It must NOT receive ``cache_control``
-        markers or the ``anthropic-beta`` header — Moonshot uses its own
-        auto-caching and those Anthropic-only fields would either get
-        silently dropped or (worst case) 400 on a future provider change."""
+    async def test_kimi_route_sends_reasoning_and_cache_control(self):
+        """Kimi K2.6 (Moonshot via OpenRouter's Anthropic-compat endpoint)
+        accepts ``cache_control: {type: ephemeral}`` on the system block
+        and the last tool — the endpoint honours the marker and lifts
+        cache hit rate on continuation turns from near-zero (Moonshot's
+        auto-caching drifts) to the Anthropic ~60-95% ballpark. The
+        ``anthropic-beta`` header stays off because Moonshot doesn't need
+        it; OpenRouter would strip the unknown header anyway."""
        state = _BaselineStreamState(model="moonshotai/kimi-k2.6")

        mock_client = MagicMock()
@@ -1898,15 +1904,29 @@ class TestBaselineReasoningStreaming:
        # cheap-but-still-reasoning-capable path.
        assert "reasoning" in extra_body
        assert extra_body["reasoning"]["max_tokens"] > 0
-        # Anthropic-only fields stay off.
-        assert "extra_headers" not in call_kwargs
+        # No ``anthropic-beta`` header — that beta is specifically for
+        # native Anthropic endpoints; Moonshot's shim accepts
+        # ``cache_control`` without it, and sending it would be wasted
+        # bytes (OR strips it before forwarding to Moonshot).
+        assert "extra_headers" not in call_kwargs or not call_kwargs.get(
+            "extra_headers"
+        )
+        # System block MUST carry ``cache_control`` so Moonshot's cache
+        # breakpoint is honoured. The cached system-message builder
+        # emits list-shape content with the marker on the first (and
+        # only) block — assert on that shape.
        sys_msg = call_kwargs["messages"][0]
        sys_content = sys_msg.get("content")
-        if isinstance(sys_content, list):
-            assert all("cache_control" not in block for block in sys_content)
-        tools = call_kwargs.get("tools", [])
-        for t in tools:
-            assert "cache_control" not in t
+        assert isinstance(
+            sys_content, list
+        ), "Cached system message should be a list-shape content block"
+        assert any(
+            "cache_control" in block for block in sys_content if isinstance(block, dict)
+        ), "Kimi system message should now carry cache_control markers"
+        # Tool-level cache marking is applied by ``stream_chat_completion_baseline``
+        # (see ``_mark_tools_with_cache_control``) before tools reach
+        # ``_baseline_llm_caller``, so this unit test doesn't exercise
+        # that path — covered by the outer integration test.
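``_mark_system_message_with_cache_control`` itself isn't shown in this diff; a plausible shape for such a helper, consistent with the list-shape assertion above (the function name is real per the import hunk, but this body is an assumption):

```python
def mark_system_message_with_cache_control(message: dict) -> dict:
    """Convert string content to list-shape and tag it as a cache breakpoint."""
    content = message.get("content")
    if isinstance(content, str):
        # Anthropic-compatible endpoints expect block-shaped content when
        # per-block metadata like cache_control is attached.
        content = [{"type": "text", "text": content}]
    if content and isinstance(content[0], dict):
        content[0]["cache_control"] = {"type": "ephemeral"}
    return {**message, "content": content}
```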
    @pytest.mark.asyncio
    async def test_reasoning_only_stream_still_closes_block(self):
@@ -1935,7 +1955,7 @@ class TestBaselineReasoningStreaming:
            state=state,
        )

-        types = [type(e).__name__ for e in state.pending_events]
+        types = [type(e).__name__ for e in state.emitted_events]
        assert "StreamReasoningStart" in types
        assert "StreamReasoningEnd" in types
        # No text was produced — no text events should be emitted.
@@ -2006,3 +2026,55 @@ class TestBaselineReasoningStreaming:
        reasoning_rows = [m for m in state.session_messages if m.role == "reasoning"]
        assert len(reasoning_rows) == 1
        assert reasoning_rows[0].content == "first thought"


+class TestSupportsPromptCacheMarkers:
+    """``_supports_prompt_cache_markers`` is the widened gate for
+    emitting ``cache_control`` markers on message content. It's a
+    superset of ``_is_anthropic_model`` that ALSO admits Moonshot
+    (whose Anthropic-compat endpoint honours the marker) while keeping
+    the False answer for OpenAI / Grok / Gemini (which 400 on the
+    unknown field)."""
+
+    @pytest.mark.parametrize(
+        "model",
+        [
+            "anthropic/claude-sonnet-4-6",
+            "claude-3-5-sonnet-20241022",
+            "anthropic.claude-3-5-sonnet",
+            "ANTHROPIC/Claude-Opus",
+        ],
+    )
+    def test_anthropic_routes_are_supported(self, model):
+        assert _supports_prompt_cache_markers(model) is True
+
+    @pytest.mark.parametrize(
+        "model",
+        [
+            "moonshotai/kimi-k2.6",
+            "moonshotai/kimi-k2-thinking",
+            "moonshotai/kimi-k2.5",
+            "moonshotai/kimi-k3.0",  # future SKU
+        ],
+    )
+    def test_moonshot_routes_are_supported(self, model):
+        """The whole reason this predicate exists — Moonshot must be
+        True even though ``_is_anthropic_model`` is False for it."""
+        assert _supports_prompt_cache_markers(model) is True
+        # Verify this is strictly wider than the anthropic-only check.
+        assert _is_anthropic_model(model) is False
+
+    @pytest.mark.parametrize(
+        "model",
+        [
+            "openai/gpt-4o",
+            "google/gemini-2.5-pro",
+            "xai/grok-4",
+            "meta-llama/llama-3.3-70b-instruct",
+            "deepseek/deepseek-v3",
+        ],
+    )
+    def test_other_providers_still_rejected(self, model):
+        """Regression guard: OpenAI/Grok/Gemini still 400 on
+        ``cache_control``, so the widened gate must keep them out."""
+        assert _supports_prompt_cache_markers(model) is False
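A predicate satisfying these tests could be as small as a vendor-prefix check. This sketch is an assumption — the real ``_supports_prompt_cache_markers`` lives in ``backend.copilot.baseline.service`` and is not shown in this diff:

```python
def supports_prompt_cache_markers(model: str) -> bool:
    """True when the route honours ``cache_control`` markers."""
    slug = model.lower()
    # Anthropic routes: "anthropic/...", "anthropic....", bare "claude-...".
    if slug.startswith(("anthropic/", "anthropic.", "claude")):
        return True
    # Moonshot via OpenRouter's Anthropic-compat endpoint also honours the
    # marker, even though _is_anthropic_model() stays False for it.
    return slug.startswith("moonshotai/")
```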
@@ -67,34 +67,40 @@ class TestResolveBaselineModel:

    Baseline reads the ``fast_*_model`` cells of the (path, tier) matrix
    and never falls through to the SDK-side ``thinking_*_model`` cells.
-    Default routing:
-    - ``standard`` / ``None`` → ``config.fast_standard_model`` (Kimi K2.6)
-    - ``advanced`` → ``config.fast_advanced_model`` (Opus — same as SDK's
-      advanced tier, so the advanced A/B isolates path differences)
+    Without a user_id (so no LD context) the resolver returns the
+    ``ChatConfig`` static default; per-user overrides are exercised in
+    ``copilot/model_router_test.py``.
    """

-    def test_advanced_tier_selects_fast_advanced_model(self):
-        assert _resolve_baseline_model("advanced") == config.fast_advanced_model
+    @pytest.mark.asyncio
+    async def test_advanced_tier_selects_fast_advanced_model(self):
+        assert (
+            await _resolve_baseline_model("advanced", None)
+            == config.fast_advanced_model
+        )

-    def test_standard_tier_selects_fast_standard_model(self):
-        assert _resolve_baseline_model("standard") == config.fast_standard_model
+    @pytest.mark.asyncio
+    async def test_standard_tier_selects_fast_standard_model(self):
+        assert (
+            await _resolve_baseline_model("standard", None)
+            == config.fast_standard_model
+        )

-    def test_none_tier_selects_fast_standard_model(self):
-        """Baseline users without a tier get the cheap fast-standard default."""
-        assert _resolve_baseline_model(None) == config.fast_standard_model
+    @pytest.mark.asyncio
+    async def test_none_tier_selects_fast_standard_model(self):
+        """Baseline users without a tier get the fast-standard default."""
+        assert await _resolve_baseline_model(None, None) == config.fast_standard_model

-    def test_fast_standard_default_is_kimi(self):
-        """Shipped default: Kimi K2.6 on the baseline standard cell.
-
-        Asserts the declared ``Field`` default — env-independent — so a
-        deploy-time ``CHAT_FAST_STANDARD_MODEL`` rollback override
-        doesn't fail CI while still pinning the shipped default.
-        """
+    def test_fast_standard_default_is_sonnet(self):
+        """Shipped default: Sonnet on the baseline standard cell — the
+        non-Anthropic routes ship via the LD flag instead of a config
+        change. Asserts the declared ``Field`` default so a deploy-time
+        ``CHAT_FAST_STANDARD_MODEL`` override doesn't flake CI."""
        from backend.copilot.config import ChatConfig

        assert (
            ChatConfig.model_fields["fast_standard_model"].default
-            == "moonshotai/kimi-k2.6"
+            == "anthropic/claude-sonnet-4-6"
        )

    def test_fast_advanced_default_is_opus(self):
@@ -108,18 +114,6 @@ class TestResolveBaselineModel:
            == "anthropic/claude-opus-4.7"
        )

-    def test_standard_cells_diverge_across_paths(self):
-        """The whole point of the split: baseline cheap (Kimi) vs SDK
-        Anthropic-only (Sonnet). If the shipped standard defaults ever
-        collapse to the same value someone lost the cost savings.
-        Checked against ``Field`` defaults, not the env-backed singleton."""
-        from backend.copilot.config import ChatConfig
-
-        assert (
-            ChatConfig.model_fields["thinking_standard_model"].default
-            != ChatConfig.model_fields["fast_standard_model"].default
-        )

    def test_standard_and_advanced_cells_differ_on_fast(self):
        """Advanced tier defaults to a different model than standard on
        the baseline path. Checked against declared ``Field`` defaults
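The second ``None`` argument in the converted tests is the new ``user_id`` parameter — the resolver became async so it can consult LaunchDarkly per user. A hedged sketch of what such a resolver might look like; the real ``_resolve_baseline_model`` and flag client are not shown in this diff, so everything here except the flag names (which appear in the Field descriptions below) is an assumption:

```python
from typing import Literal

Tier = Literal["standard", "advanced"] | None

async def resolve_baseline_model(tier: Tier, user_id: str | None) -> str:
    # Static ChatConfig defaults are the fallback cell of the matrix.
    default = (
        config.fast_advanced_model if tier == "advanced" else config.fast_standard_model
    )
    if user_id is None:
        # No user -> no LD context -> static default (what the tests assert).
        return default
    flag = (
        "copilot-fast-advanced-model"
        if tier == "advanced"
        else "copilot-fast-standard-model"
    )
    # ld_client is assumed; real override plumbing lives in copilot/model_router.py.
    return await ld_client.variation(flag, user_id, default)
```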
@@ -3,7 +3,7 @@

 import os
 from typing import Literal

-from pydantic import AliasChoices, Field, field_validator
+from pydantic import AliasChoices, Field, field_validator, model_validator
 from pydantic_settings import BaseSettings

 from backend.util.clients import OPENROUTER_BASE_URL
@@ -43,26 +43,20 @@ class ChatConfig(BaseSettings):
    # ``CHAT_FAST_MODEL``) are preserved via ``validation_alias`` so
    # existing deployments continue to override the same effective cell.
    fast_standard_model: str = Field(
-        default="moonshotai/kimi-k2.6",
+        default="anthropic/claude-sonnet-4-6",
        validation_alias=AliasChoices(
            "CHAT_FAST_STANDARD_MODEL",
            "CHAT_FAST_MODEL",
        ),
-        description="Baseline path, 'standard' / ``None`` tier. Kimi K2.6 "
-        "by default: ~5x cheaper input and ~5.4x cheaper output than Sonnet, "
-        "SWE-Bench Verified parity with Opus, and OpenRouter advertises the "
-        "``reasoning`` + ``include_reasoning`` extension params on the "
-        "Moonshot endpoints — so the baseline reasoning plumbing lights up "
-        "without provider-specific code. Roll back to the Anthropic route "
-        "via ``CHAT_FAST_STANDARD_MODEL=anthropic/claude-sonnet-4-6`` (then "
-        "``cache_control`` breakpoints reactivate via "
-        "``_is_anthropic_model``).",
+        description="Baseline path, 'standard' / ``None`` tier. Per-user "
+        "overrides flow through the ``copilot-fast-standard-model`` LD flag "
+        "(see ``copilot/model_router.py``); this value is the fallback.",
    )
    fast_advanced_model: str = Field(
        default="anthropic/claude-opus-4.7",
        validation_alias=AliasChoices("CHAT_FAST_ADVANCED_MODEL"),
-        description="Baseline path, 'advanced' tier. Opus by default. "
-        "Override via ``CHAT_FAST_ADVANCED_MODEL``.",
+        description="Baseline path, 'advanced' tier. LD override: "
+        "``copilot-fast-advanced-model``.",
    )
    thinking_standard_model: str = Field(
        default="anthropic/claude-sonnet-4-6",
@@ -71,11 +65,7 @@ class ChatConfig(BaseSettings):
            "CHAT_MODEL",
        ),
        description="SDK (extended-thinking) path, 'standard' / ``None`` "
-        "tier. Sonnet by default: the Claude Agent SDK CLI only speaks to "
-        "Anthropic endpoints, so the standard SDK tier has to stay on an "
-        "Anthropic model regardless of what the baseline path runs. "
-        "Override via ``CHAT_THINKING_STANDARD_MODEL`` (legacy "
-        "``CHAT_MODEL`` still honored).",
+        "tier. LD override: ``copilot-thinking-standard-model``.",
    )
    thinking_advanced_model: str = Field(
        default="anthropic/claude-opus-4.7",
@@ -83,17 +73,18 @@ class ChatConfig(BaseSettings):
            "CHAT_THINKING_ADVANCED_MODEL",
            "CHAT_ADVANCED_MODEL",
        ),
-        description="SDK (extended-thinking) path, 'advanced' tier. Opus "
-        "by default. Override via ``CHAT_THINKING_ADVANCED_MODEL`` "
-        "(legacy ``CHAT_ADVANCED_MODEL`` still honored).",
+        description="SDK (extended-thinking) path, 'advanced' tier. LD "
+        "override: ``copilot-thinking-advanced-model``.",
    )
    title_model: str = Field(
        default="openai/gpt-4o-mini",
        description="Model to use for generating session titles (should be fast/cheap)",
    )
    simulation_model: str = Field(
-        default="google/gemini-2.5-flash",
-        description="Model for dry-run block simulation (should be fast/cheap with good JSON output)",
+        default="google/gemini-2.5-flash-lite",
+        description="Model for dry-run block simulation (should be fast/cheap with good JSON output). "
+        "Gemini 2.5 Flash-Lite is ~3x cheaper than Flash ($0.10/$0.40 vs $0.30/$1.20 per MTok) "
+        "with JSON-mode reliability adequate for shape-matching block outputs.",
    )
    api_key: str | None = Field(default=None, description="OpenAI API key")
    base_url: str | None = Field(
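``AliasChoices`` is what keeps the legacy environment variables working: pydantic-settings tries each alias in order when populating the field. A self-contained illustration of the mechanism, using toy names rather than the production config:

```python
import os

from pydantic import AliasChoices, Field
from pydantic_settings import BaseSettings

class DemoConfig(BaseSettings):
    # New deployments set DEMO_FAST_STANDARD_MODEL; old ones that still
    # export DEMO_FAST_MODEL keep overriding the same effective cell.
    fast_standard_model: str = Field(
        default="anthropic/claude-sonnet-4-6",
        validation_alias=AliasChoices("DEMO_FAST_STANDARD_MODEL", "DEMO_FAST_MODEL"),
    )

os.environ["DEMO_FAST_MODEL"] = "moonshotai/kimi-k2.6"
assert DemoConfig().fast_standard_model == "moonshotai/kimi-k2.6"
```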
@@ -249,13 +240,28 @@ class ChatConfig(BaseSettings):
        "``max_thinking_tokens`` kwarg so the CLI falls back to model default "
        "(which, without the flag, leaves extended thinking off).",
    )
+    render_reasoning_in_ui: bool = Field(
+        default=True,
+        description="Render reasoning as live UI parts "
+        "(``StreamReasoning*`` wire events). False suppresses the live "
+        "wire events only; ``role='reasoning'`` rows are always persisted "
+        "so the reasoning bubble hydrates on reload. Tokens are billed "
+        "upstream regardless.",
+    )
+    stream_replay_count: int = Field(
+        default=200,
+        ge=1,
+        le=10000,
+        description="Max Redis stream entries replayed on SSE reconnect.",
+    )
    claude_agent_thinking_effort: Literal["low", "medium", "high", "max"] | None = (
        Field(
            default=None,
            description="Thinking effort level: 'low', 'medium', 'high', 'max', or None. "
-            "Only applies to models with extended thinking (Opus). "
-            "Sonnet doesn't have extended thinking — setting effort on Sonnet "
-            "can cause <internal_reasoning> tag leaks. "
+            "Applies to models that emit a reasoning channel — Opus (extended "
+            "thinking) and Kimi K2.6 (OpenRouter ``reasoning`` extension lit "
+            "up by #12871). Sonnet does not have extended thinking — setting "
+            "effort on Sonnet can cause <internal_reasoning> tag leaks. "
            "None = let the model decide. Override via CHAT_CLAUDE_AGENT_THINKING_EFFORT.",
        )
    )
@@ -287,6 +293,31 @@ class ChatConfig(BaseSettings):
        "(24h, permanent) TTL option — see "
        "https://platform.claude.com/docs/en/build-with-claude/prompt-caching.",
    )
+    sdk_include_partial_messages: bool = Field(
+        default=True,
+        description="Stream SDK responses token-by-token instead of in "
+        "one lump at the end. Set to False if the SDK path starts "
+        "double-writing text or dropping the tail of long messages.",
+    )
+    sdk_reconcile_openrouter_cost: bool = Field(
+        default=True,
+        description="Query OpenRouter's ``/api/v1/generation?id=`` after each "
+        "SDK turn and record the authoritative ``total_cost`` instead of the "
+        "Claude Agent SDK CLI's estimate. Covers every OpenRouter-routed "
+        "SDK turn regardless of vendor — the CLI's static Anthropic pricing "
+        "table is accurate for Anthropic models (Sonnet/Opus via OpenRouter "
+        "bill at Anthropic's own rates, penny-for-penny), but the reconcile "
+        "catches any future rate change the CLI hasn't picked up and makes "
+        "non-Anthropic cost (Kimi et al) correct — real billed amount, "
+        "matching the baseline path's ``usage.cost`` read since #12864. "
+        "Kill-switch for emergencies: set ``CHAT_SDK_RECONCILE_OPENROUTER_COST"
+        "=false`` to fall back to the CLI's ``total_cost_usd`` reported "
+        "synchronously (accurate-for-Anthropic / over-billed-for-Kimi). "
+        "Tradeoff: 0.5-2s window between turn end and cost write; rate-limit "
+        "counter briefly unaware, back-to-back turns in that window see "
+        "stale state. The alternative (writing an estimate sync then a "
+        "correction delta) would double-count the rate limit.",
+    )
    claude_agent_cli_path: str | None = Field(
        default=None,
        description="Optional explicit path to a Claude Code CLI binary. "
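The reconcile described above boils down to one authenticated GET against OpenRouter's generation endpoint. A minimal sketch of that lookup — the endpoint path comes from the description; the surrounding function, client choice, and response-field handling are assumptions:

```python
import httpx

async def fetch_openrouter_cost(generation_id: str, api_key: str) -> float | None:
    """Return the authoritative total_cost for one generation, or None on miss."""
    async with httpx.AsyncClient(base_url="https://openrouter.ai") as client:
        resp = await client.get(
            "/api/v1/generation",
            params={"id": generation_id},
            headers={"Authorization": f"Bearer {api_key}"},
        )
        if resp.status_code != 200:
            return None  # caller falls back to the CLI's estimate
        return resp.json().get("data", {}).get("total_cost")
```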
@@ -457,6 +488,59 @@ class ChatConfig(BaseSettings):
            )
        return v

+    @model_validator(mode="after")
+    def _validate_sdk_model_vendor_compatibility(self) -> "ChatConfig":
+        """Fail at config load when an SDK model slug is incompatible with
+        explicit direct-Anthropic mode.
+
+        The SDK path's ``_normalize_model_name`` raises ``ValueError`` when
+        a non-Anthropic vendor slug (e.g. ``moonshotai/kimi-k2.6``) is paired
+        with direct-Anthropic mode — but that fires inside the request loop,
+        so a misconfigured deployment would surface a 500 to every user
+        instead of failing visibly at boot.
+
+        Only the **explicit** opt-out (``use_openrouter=False``) is checked
+        here, not the credential-missing path. Build environments and
+        OpenAPI-schema export jobs construct ``ChatConfig()`` without any
+        OpenRouter credentials in the env — that's not a misconfiguration,
+        it's "config loads ok, but no SDK turn will succeed until creds are
+        wired". The runtime guard in ``_normalize_model_name`` still
+        catches the credential-missing path on the first SDK turn.
+
+        Covers all three SDK fields that flow through
+        ``_normalize_model_name``: primary tier
+        (``thinking_standard_model``), advanced tier
+        (``thinking_advanced_model``), and fallback model
+        (``claude_agent_fallback_model`` via ``_resolve_fallback_model``).
+
+        Skipped when ``use_claude_code_subscription=True`` because the
+        subscription path resolves the model to ``None`` (CLI default)
+        and never calls ``_normalize_model_name``. Empty fallback strings
+        are also skipped (no fallback configured).
+        """
+        if self.use_claude_code_subscription:
+            return self
+        if self.use_openrouter:
+            return self
+        for field_name in (
+            "thinking_standard_model",
+            "thinking_advanced_model",
+            "claude_agent_fallback_model",
+        ):
+            value: str = getattr(self, field_name)
+            if not value or "/" not in value:
+                continue
+            if value.split("/", 1)[0] != "anthropic":
+                raise ValueError(
+                    f"Direct-Anthropic mode (use_openrouter=False) "
+                    f"requires an Anthropic model for {field_name}, got "
+                    f"{value!r}. Set CHAT_THINKING_STANDARD_MODEL / "
+                    f"CHAT_THINKING_ADVANCED_MODEL / "
+                    f"CHAT_CLAUDE_AGENT_FALLBACK_MODEL to an anthropic/* "
+                    f"slug, or set CHAT_USE_OPENROUTER=true."
+                )
+        return self

 # Prompt paths for different contexts
 PROMPT_PATHS: dict[str, str] = {
     "default": "prompts/chat_system.md",
@@ -5,12 +5,17 @@ import pytest

 from .config import ChatConfig

 # Env vars that the ChatConfig validators read — must be cleared so they don't
-# override the explicit constructor values we pass in each test.
+# override the explicit constructor values we pass in each test. Includes the
+# SDK/baseline model aliases so a leftover ``CHAT_MODEL=...`` in the developer
+# or CI environment can't change whether
+# ``_validate_sdk_model_vendor_compatibility`` raises.
 _ENV_VARS_TO_CLEAR = (
     "CHAT_USE_E2B_SANDBOX",
     "CHAT_E2B_API_KEY",
     "E2B_API_KEY",
     "CHAT_USE_OPENROUTER",
     "CHAT_USE_CLAUDE_AGENT_SDK",
     "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION",
     "CHAT_API_KEY",
     "OPEN_ROUTER_API_KEY",
     "OPENAI_API_KEY",
@@ -19,6 +24,16 @@ _ENV_VARS_TO_CLEAR = (
     "OPENAI_BASE_URL",
     "CHAT_CLAUDE_AGENT_CLI_PATH",
     "CLAUDE_AGENT_CLI_PATH",
+    "CHAT_FAST_STANDARD_MODEL",
+    "CHAT_FAST_MODEL",
+    "CHAT_FAST_ADVANCED_MODEL",
+    "CHAT_THINKING_STANDARD_MODEL",
+    "CHAT_THINKING_ADVANCED_MODEL",
+    "CHAT_MODEL",
+    "CHAT_ADVANCED_MODEL",
+    "CHAT_CLAUDE_AGENT_FALLBACK_MODEL",
+    "CHAT_RENDER_REASONING_IN_UI",
+    "CHAT_STREAM_REPLAY_COUNT",
 )

@@ -28,6 +43,22 @@ def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None:
         monkeypatch.delenv(var, raising=False)

+def _make_direct_safe_config(**kwargs) -> ChatConfig:
+    """Build a ``ChatConfig`` for tests that pass ``use_openrouter=False``
+    but aren't exercising the SDK vendor-compatibility validator.
+
+    Pins ``thinking_standard_model``/``thinking_advanced_model`` to anthropic/*
+    so the construction passes ``_validate_sdk_model_vendor_compatibility``
+    without each test having to repeat the override.
+    """
+    defaults: dict = {
+        "thinking_standard_model": "anthropic/claude-sonnet-4-6",
+        "thinking_advanced_model": "anthropic/claude-opus-4-7",
+    }
+    defaults.update(kwargs)
+    return ChatConfig(**defaults)

 class TestOpenrouterActive:
     """Tests for the openrouter_active property."""

@@ -48,7 +79,7 @@ class TestOpenrouterActive:
         assert cfg.openrouter_active is False

     def test_disabled_returns_false_despite_credentials(self):
-        cfg = ChatConfig(
+        cfg = _make_direct_safe_config(
             use_openrouter=False,
             api_key="or-key",
             base_url="https://openrouter.ai/api/v1",
@@ -164,3 +195,133 @@ class TestClaudeAgentCliPathEnvFallback:
         monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", str(tmp_path))
         with pytest.raises(Exception, match="not a regular file"):
             ChatConfig()

+class TestSdkModelVendorCompatibility:
+    """``model_validator`` that fails fast on SDK model vs routing-mode
+    mismatch — see PR #12878 iteration-2 review. Mirrors the runtime
+    guard in ``_normalize_model_name`` so misconfig surfaces at boot
+    instead of as a 500 on the first SDK turn."""
+
+    def test_direct_anthropic_with_kimi_override_raises(self):
+        """A non-Anthropic SDK model must fail at config load when the
+        deployment has no OpenRouter credentials."""
+        with pytest.raises(Exception, match="requires an Anthropic model"):
+            ChatConfig(
+                use_openrouter=False,
+                api_key=None,
+                base_url=None,
+                use_claude_code_subscription=False,
+                thinking_standard_model="moonshotai/kimi-k2.6",
+            )
+
+    def test_direct_anthropic_with_anthropic_default_succeeds(self):
+        """Direct-Anthropic mode is fine when both SDK slugs are anthropic/*
+        — which is the default after the LD-routed model rollout."""
+        cfg = ChatConfig(
+            use_openrouter=False,
+            api_key=None,
+            base_url=None,
+            use_claude_code_subscription=False,
+        )
+        assert cfg.thinking_standard_model == "anthropic/claude-sonnet-4-6"
+
+    def test_openrouter_with_kimi_override_succeeds(self):
+        """Kimi slug round-trips cleanly when OpenRouter is on — exercised
+        via the LD-flag override path in production."""
+        cfg = ChatConfig(
+            use_openrouter=True,
+            api_key="or-key",
+            base_url="https://openrouter.ai/api/v1",
+            use_claude_code_subscription=False,
+            thinking_standard_model="moonshotai/kimi-k2.6",
+        )
+        assert cfg.thinking_standard_model == "moonshotai/kimi-k2.6"
+
+    def test_subscription_mode_skips_check(self):
+        """Subscription path resolves the model to None and bypasses
+        ``_normalize_model_name``, so the slug check is skipped."""
+        cfg = ChatConfig(
+            use_openrouter=False,
+            api_key=None,
+            base_url=None,
+            use_claude_code_subscription=True,
+        )
+        assert cfg.use_claude_code_subscription is True
+
+    def test_advanced_tier_also_validated(self):
+        """Both standard and advanced SDK slugs are checked."""
+        with pytest.raises(Exception, match="thinking_advanced_model"):
+            ChatConfig(
+                use_openrouter=False,
+                api_key=None,
+                base_url=None,
+                use_claude_code_subscription=False,
+                thinking_standard_model="anthropic/claude-sonnet-4-6",
+                thinking_advanced_model="moonshotai/kimi-k2.6",
+            )
+
+    def test_fallback_model_also_validated(self):
+        """``claude_agent_fallback_model`` flows through
+        ``_normalize_model_name`` via ``_resolve_fallback_model`` so the
+        same direct-Anthropic guard applies."""
+        with pytest.raises(Exception, match="claude_agent_fallback_model"):
+            ChatConfig(
+                use_openrouter=False,
+                api_key=None,
+                base_url=None,
+                use_claude_code_subscription=False,
+                thinking_standard_model="anthropic/claude-sonnet-4-6",
+                thinking_advanced_model="anthropic/claude-opus-4-7",
+                claude_agent_fallback_model="moonshotai/kimi-k2.6",
+            )
+
+    def test_empty_fallback_skipped(self):
+        """Empty ``claude_agent_fallback_model`` (no fallback configured)
+        must not trip the validator — the fallback-disabled state is
+        intentional and shouldn't require a placeholder anthropic/* slug."""
+        cfg = ChatConfig(
+            use_openrouter=False,
+            api_key=None,
+            base_url=None,
+            use_claude_code_subscription=False,
+            thinking_standard_model="anthropic/claude-sonnet-4-6",
+            thinking_advanced_model="anthropic/claude-opus-4-7",
+            claude_agent_fallback_model="",
+        )
+        assert cfg.claude_agent_fallback_model == ""
+
+
+class TestRenderReasoningInUi:
+    """``render_reasoning_in_ui`` gates reasoning wire events globally."""
+
+    def test_defaults_to_true(self):
+        """Default must stay True — flipping it silences the reasoning
+        collapse for every user, which is an opt-in operator decision."""
+        cfg = ChatConfig()
+        assert cfg.render_reasoning_in_ui is True
+
+    def test_env_override_false(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setenv("CHAT_RENDER_REASONING_IN_UI", "false")
+        cfg = ChatConfig()
+        assert cfg.render_reasoning_in_ui is False
+
+
+class TestStreamReplayCount:
+    """``stream_replay_count`` caps the SSE reconnect replay batch size."""
+
+    def test_default_is_200(self):
+        """200 covers a full Kimi turn after coalescing (~150 events) while
+        bounding the replay storm from 1000+ chunks."""
+        cfg = ChatConfig()
+        assert cfg.stream_replay_count == 200
+
+    def test_env_override(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        monkeypatch.setenv("CHAT_STREAM_REPLAY_COUNT", "500")
+        cfg = ChatConfig()
+        assert cfg.stream_replay_count == 500
+
+    def test_zero_rejected(self):
+        """count=0 would make XREAD replay nothing — rejected via ge=1."""
+        with pytest.raises(Exception):
+            ChatConfig(stream_replay_count=0)
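``stream_replay_count`` caps how much history a reconnecting SSE client gets back from the Redis stream. A hedged sketch of the replay read it governs — the stream key scheme and helper name are assumptions; only the setting and the XREAD mention come from the diff:

```python
from redis.asyncio import Redis

async def replay_stream(redis: Redis, session_id: str, count: int = 200) -> list:
    """Read up to ``count`` buffered events from the start of the stream."""
    # A COUNT-bounded XREAD from id 0-0 replays a window of history without
    # re-sending the 1000+ raw chunks a long Kimi turn can produce.
    return await redis.xread({f"copilot:stream:{session_id}": "0-0"}, count=count)
```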
@@ -105,25 +105,46 @@ class CoPilotExecutor(AppProcess):
        time.sleep(1e5)

    def cleanup(self):
-        """Graceful shutdown with active execution waiting."""
-        pid = os.getpid()
-        logger.info(f"[cleanup {pid}] Starting graceful shutdown...")
+        """Graceful shutdown — mirrors ``backend.executor.manager`` pattern.

-        # Signal the consumer thread to stop
+        1. Stop consumer immediately (both the Python flag that gates
+           ``_handle_run_message`` and ``channel.stop_consuming()`` at
+           the broker), so no new work enters.
+        2. Passively wait for ``active_tasks`` to drain — each turn's
+           own ``finally`` publishes its terminal state via
+           ``mark_session_completed``. When a turn exits, ``on_run_done``
+           removes it from ``active_tasks`` and releases its cluster lock.
+        3. Stop message consumer threads + disconnect pika clients.
+        4. Shut down the thread-pool executor (cancels pending, leaves
+           running threads alone — process exit handles them).
+        5. Release any cluster locks still held (defensive — on_run_done's
+           finally should have already released them).
+
+        The zombie-session bug this PR targets is handled inside each
+        turn's own lifecycle by :func:`sync_fail_close_session`, NOT by
+        cleanup — so cleanup can stay as a simple "wait, then teardown"
+        and matches agent-executor's proven pattern.
+        """
+        pid = os.getpid()
+        prefix = f"[cleanup {pid}]"
+        logger.info(f"{prefix} Starting graceful shutdown...")

+        # 1. Stop consumer — flag AND broker-side
        try:
            self.stop_consuming.set()
            run_channel = self.run_client.get_channel()
            run_channel.connection.add_callback_threadsafe(
                lambda: run_channel.stop_consuming()
            )
-            logger.info(f"[cleanup {pid}] Consumer has been signaled to stop")
+            logger.info(f"{prefix} Consumer has been signaled to stop")
        except Exception as e:
-            logger.error(f"[cleanup {pid}] Error stopping consumer: {e}")
+            logger.error(f"{prefix} Error stopping consumer: {e}")

-        # Wait for active executions to complete
+        # 2. Wait for in-flight turns to finish naturally
        if self.active_tasks:
            logger.info(
-                f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
+                f"{prefix} Waiting for {len(self.active_tasks)} active tasks "
+                f"to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..."
            )

            start_time = time.monotonic()
@@ -138,38 +159,42 @@ class CoPilotExecutor(AppProcess):
                if not self.active_tasks:
                    break

                # Refresh cluster locks periodically
-                current_time = time.monotonic()
-                if current_time - last_refresh >= lock_refresh_interval:
+                now = time.monotonic()
+                if now - last_refresh >= lock_refresh_interval:
                    for lock in list(self._task_locks.values()):
                        try:
                            lock.refresh()
                        except Exception as e:
-                            logger.warning(
-                                f"[cleanup {pid}] Failed to refresh lock: {e}"
-                            )
-                    last_refresh = current_time
+                            logger.warning(f"{prefix} Failed to refresh lock: {e}")
+                    last_refresh = now

                logger.info(
-                    f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..."
+                    f"{prefix} {len(self.active_tasks)} tasks still active, waiting..."
                )
                time.sleep(10.0)

-        # Stop message consumers
+        if self.active_tasks:
+            logger.warning(
+                f"{prefix} {len(self.active_tasks)} tasks still running after "
+                f"{GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s — process exit will "
+                f"abandon them; RabbitMQ redelivery handles the message."
+            )
+
+        # 3. Stop message consumer threads
        if self._run_thread:
            self._stop_message_consumers(
-                self._run_thread, self.run_client, "[cleanup][run]"
+                self._run_thread, self.run_client, f"{prefix} [run]"
            )
        if self._cancel_thread:
            self._stop_message_consumers(
-                self._cancel_thread, self.cancel_client, "[cleanup][cancel]"
+                self._cancel_thread, self.cancel_client, f"{prefix} [cancel]"
            )

-        # Clean up worker threads (closes per-loop workspace storage sessions)
+        # 4. Worker cleanup + executor shutdown
        if self._executor:
            from .processor import cleanup_worker

-            logger.info(f"[cleanup {pid}] Cleaning up workers...")
+            logger.info(f"{prefix} Cleaning up workers...")
            futures = []
            for _ in range(self._executor._max_workers):
                futures.append(self._executor.submit(cleanup_worker))
@@ -177,22 +202,20 @@ class CoPilotExecutor(AppProcess):
                try:
                    f.result(timeout=10)
                except Exception as e:
-                    logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}")
+                    logger.warning(f"{prefix} Worker cleanup error: {e}")

-            logger.info(f"[cleanup {pid}] Shutting down executor...")
+            logger.info(f"{prefix} Shutting down executor...")
            self._executor.shutdown(wait=False)

-        # Release any remaining locks
+        # 5. Release any cluster locks still held
        for session_id, lock in list(self._task_locks.items()):
            try:
                lock.release()
-                logger.info(f"[cleanup {pid}] Released lock for {session_id}")
+                logger.info(f"{prefix} Released lock for {session_id}")
            except Exception as e:
-                logger.error(
-                    f"[cleanup {pid}] Failed to release lock for {session_id}: {e}"
-                )
+                logger.error(f"{prefix} Failed to release lock for {session_id}: {e}")

-        logger.info(f"[cleanup {pid}] Graceful shutdown completed")
+        logger.info(f"{prefix} Graceful shutdown completed")

    # ============ RabbitMQ Consumer Methods ============ #
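Step 1 of the shutdown relies on pika's one thread-safe entry point: ``connection.add_callback_threadsafe``. Calling ``stop_consuming()`` directly from the cleanup thread would race the consumer thread's I/O loop. Distilled into a standalone sketch with a generic handler (the flag-plus-callback pairing is the same as in the diff above):

```python
import threading
import pika

stop_consuming = threading.Event()

def request_stop(channel: pika.adapters.blocking_connection.BlockingChannel) -> None:
    # The flag gates the message handler, so no new work is started...
    stop_consuming.set()
    # ...and the broker-side stop is marshalled onto the consumer thread's
    # connection loop — the only thread allowed to touch the connection.
    channel.connection.add_callback_threadsafe(channel.stop_consuming)
```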
@@ -387,13 +410,12 @@ class CoPilotExecutor(AppProcess):

        # Execute the task
        try:
-            self._task_locks[session_id] = cluster_lock
-
            logger.info(
                f"Acquired cluster lock for {session_id}, "
                f"executor_id={self.executor_id}"
            )
+            self._task_locks[session_id] = cluster_lock
            cancel_event = threading.Event()
            future = self.executor.submit(
                execute_copilot_turn, entry, cancel_event, cluster_lock
@@ -425,7 +447,6 @@ class CoPilotExecutor(AppProcess):
            error_msg = str(e) or type(e).__name__
            logger.exception(f"Error in run completion callback: {error_msg}")
        finally:
            # Release the cluster lock
            if session_id in self._task_locks:
-                logger.info(f"Releasing cluster lock for {session_id}")
                self._task_locks[session_id].release()

@@ -5,6 +5,7 @@ in a thread-local context, following the graph executor pattern.
 """

 import asyncio
+import concurrent.futures
 import logging
 import os
 import subprocess
@@ -30,6 +31,87 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata

 logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]")

+
+SHUTDOWN_ERROR_MESSAGE = (
+    "Copilot executor shut down before this turn finished. Please retry."
+)
+
+# Max time execute() blocks after calling future.cancel() / when draining a
+# soon-to-be-cancelled future. Gives _execute_async's own finally a chance to
+# publish the accurate terminal state over the Redis CAS; long enough to let
+# an in-flight Redis call settle, short enough that shutdown doesn't stall.
+_CANCEL_GRACE_SECONDS = 5.0
+
+# Max time the sync safety net itself spends on a single Redis CAS. Without
+# this bound the whole point of ``sync_fail_close_session`` is defeated —
+# ``mark_session_completed`` would hang on the same broken Redis that caused
+# the original failure. On timeout we give up silently; worst case the
+# session stays ``running`` until the stale-session watchdog reaps it, but
+# at least the pool worker thread isn't blocked forever.
+_FAIL_CLOSE_REDIS_TIMEOUT = 10.0
+
+
+# Module-level symbol preserved for backward-compat with callers that import
+# ``sync_fail_close_session``; the real implementation now lives on
+# ``CoPilotProcessor`` so it can reuse ``self.execution_loop`` (same
+# pattern as ``backend.executor.manager``'s ``node_execution_loop`` bridge
+# at :meth:`ExecutionProcessor.on_graph_execution`).
+
+
+def sync_fail_close_session(
+    session_id: str,
+    log: "CoPilotLogMetadata | TruncatedLogger",
+    execution_loop: asyncio.AbstractEventLoop,
+) -> None:
+    """Synchronously mark *session_id* as failed from the pool worker thread.
+
+    Submits the CAS coroutine to the long-lived *execution_loop* via
+    ``run_coroutine_threadsafe`` — the same shape agent-executor uses at
+    :meth:`backend.executor.manager.ExecutionProcessor.on_graph_execution`
+    to reach its ``node_execution_loop`` from the pool worker. Reusing the
+    persistent loop means:
+
+    * no fresh TCP connection per turn (the ``@thread_cached``
+      ``AsyncRedis`` on the execution thread stays bound to the same loop
+      and is reused across every turn);
+    * no loop-teardown overhead;
+    * no ``clear_cache()`` gymnastics to dodge the "loop is closed" pitfall.
+
+    ``mark_session_completed`` is an atomic CAS on ``status == "running"``,
+    so when the async path already wrote a terminal state the sync call is
+    a cheap no-op. The inner ``asyncio.wait_for`` bounds the Redis call so
+    a wedged Redis can't hang the safety net for the full redis-py default
+    TCP timeout; the outer ``.result(timeout=...)`` is a belt-and-braces
+    upper bound for the cross-thread wait.
+    """
+
+    async def _bounded() -> None:
+        await asyncio.wait_for(
+            stream_registry.mark_session_completed(
+                session_id, error_message=SHUTDOWN_ERROR_MESSAGE
+            ),
+            timeout=_FAIL_CLOSE_REDIS_TIMEOUT,
+        )
+
+    try:
+        future = asyncio.run_coroutine_threadsafe(_bounded(), execution_loop)
+    except RuntimeError as e:
+        # execution_loop is closed — happens if cleanup() already ran the
+        # per-worker teardown. Nothing we can do; let the stale-session
+        # watchdog reap it.
+        log.warning(f"sync fail-close skipped (execution_loop closed): {e}")
+        return
+    try:
+        future.result(timeout=_FAIL_CLOSE_REDIS_TIMEOUT + 2)
+    except concurrent.futures.TimeoutError:
+        log.warning(
+            f"sync fail-close timed out after {_FAIL_CLOSE_REDIS_TIMEOUT}s "
+            f"(session={session_id})"
+        )
+        future.cancel()
+    except Exception as e:
+        log.warning(f"sync fail-close mark_session_completed failed: {e}")

 # ============ Mode Routing ============ #
@@ -252,12 +334,13 @@ class CoPilotProcessor:
    ):
        """Execute a CoPilot turn.

-        Runs the async logic in the worker's event loop and handles errors.
-
-        Args:
-            entry: The turn payload containing session and message info
-            cancel: Threading event to signal cancellation
-            cluster_lock: Distributed lock to prevent duplicate execution
+        Thin wrapper around :meth:`_execute`. The ``try/finally`` here
+        guarantees :func:`sync_fail_close_session` runs on every exit
+        path — normal completion, exception, or a wedged event loop
+        that escapes via :data:`_CANCEL_GRACE_SECONDS` timeout.
+        ``mark_session_completed`` is an atomic CAS on
+        ``status == "running"``, so when the async path already wrote a
+        terminal state the sync call is a cheap no-op.
        """
        log = CoPilotLogMetadata(
            logging.getLogger(__name__),
@@ -265,10 +348,28 @@ class CoPilotProcessor:
            user_id=entry.user_id,
        )
        log.info("Starting execution")

        start_time = time.monotonic()
+        try:
+            self._execute(entry, cancel, cluster_lock, log)
+        finally:
+            sync_fail_close_session(entry.session_id, log, self.execution_loop)
+            elapsed = time.monotonic() - start_time
+            log.info(f"Execution completed in {elapsed:.2f}s")

-        # Run the async execution in our event loop
+    def _execute(
+        self,
+        entry: CoPilotExecutionEntry,
+        cancel: threading.Event,
+        cluster_lock: ClusterLock,
+        log: CoPilotLogMetadata,
+    ):
+        """Submit the async turn to ``self.execution_loop`` and drive it.
+
+        Handles the sync/async boundary (cancel-event checks, cluster-lock
+        refresh, bounded waits) without any Redis-state cleanup logic —
+        that lives in :func:`sync_fail_close_session` which the outer
+        :meth:`execute` always invokes on exit.
+        """
        future = asyncio.run_coroutine_threadsafe(
            self._execute_async(entry, cancel, cluster_lock, log),
            self.execution_loop,
@@ -282,16 +383,27 @@ class CoPilotProcessor:
            if cancel.is_set():
                log.info("Cancellation requested")
                future.cancel()
-                break
-            # Refresh cluster lock to maintain ownership
+                # Give _execute_async's own finally a short window to
+                # publish its accurate terminal state before the outer
+                # sync safety net fires.
+                try:
+                    future.result(timeout=_CANCEL_GRACE_SECONDS)
+                except BaseException:
+                    pass
+                return
            cluster_lock.refresh()

-        if not future.cancelled():
-            # Get result to propagate any exceptions
-            future.result()
-
-        elapsed = time.monotonic() - start_time
-        log.info(f"Execution completed in {elapsed:.2f}s")
+        # Bounded timeout so a wedged event loop can't trap us here —
+        # on timeout we escape to execute()'s finally and the sync
+        # safety net fires.
+        try:
+            future.result(timeout=_CANCEL_GRACE_SECONDS)
+        except concurrent.futures.TimeoutError:
+            log.warning(
+                "Future did not complete within grace window; "
+                "falling through to sync fail-close"
+            )

    async def _execute_async(
        self,
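The sync/async boundary in ``_execute`` repeats one idiom: submit to the long-lived loop, wait with a bound, and on timeout fall through so the outer safety net can act. In isolation, with generic names:

```python
import asyncio
import concurrent.futures

def drive_with_grace(coro, loop: asyncio.AbstractEventLoop, grace: float = 5.0) -> bool:
    """Run *coro* on *loop* from a worker thread; True if it finished in time."""
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        future.result(timeout=grace)
        return True
    except concurrent.futures.TimeoutError:
        # A wedged loop can't trap the worker here — the caller's own
        # finally block (the safety net) still gets to run.
        return False
```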
@@ -10,6 +10,8 @@ the real production helpers from ``processor.py`` so the routing logic
|
||||
has meaningful coverage.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import threading
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
@@ -20,6 +22,7 @@ from backend.copilot.executor.processor import (
|
||||
CoPilotProcessor,
|
||||
resolve_effective_mode,
|
||||
resolve_use_sdk_for_mode,
|
||||
sync_fail_close_session,
|
||||
)
|
||||
from backend.copilot.executor.utils import CoPilotExecutionEntry, CoPilotLogMetadata
|
||||
|
||||
@@ -275,3 +278,221 @@ class TestExecuteAsyncAclose:
|
||||
await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log())
|
||||
|
||||
assert published.aclose_called is True
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def exec_loop():
|
||||
"""Long-lived asyncio loop on a daemon thread — mirrors the layout
|
||||
``CoPilotProcessor`` sets up (``execution_loop`` + ``execution_thread``)
|
||||
so ``sync_fail_close_session`` has a real cross-thread loop to submit
|
||||
into via ``run_coroutine_threadsafe``."""
|
||||
loop = asyncio.new_event_loop()
|
||||
thread = threading.Thread(target=loop.run_forever, daemon=True)
|
||||
thread.start()
|
||||
try:
|
||||
yield loop
|
||||
finally:
|
||||
loop.call_soon_threadsafe(loop.stop)
|
||||
thread.join(timeout=5)
|
||||
loop.close()
|
||||
|
||||
|
||||
class TestSyncFailCloseSession:
|
||||
"""``sync_fail_close_session`` is the last-line-of-defense invoked from
|
||||
``CoPilotProcessor.execute``'s ``finally``. It must call
|
||||
``mark_session_completed`` via the processor's long-lived
|
||||
``execution_loop`` (cross-thread submit) and must swallow Redis
|
||||
failures so a transient outage doesn't propagate out of the finally."""
|
||||
|
||||
def test_invokes_mark_session_completed_with_shutdown_message(
|
||||
self, exec_loop
|
||||
) -> None:
|
||||
mock_mark = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
|
||||
new=mock_mark,
|
||||
):
|
||||
sync_fail_close_session("sess-1", _make_log(), exec_loop)
|
||||
|
||||
mock_mark.assert_awaited_once()
|
||||
assert mock_mark.await_args is not None
|
||||
assert mock_mark.await_args.args[0] == "sess-1"
|
||||
assert "shut down" in mock_mark.await_args.kwargs["error_message"].lower()
|
||||
|
||||
def test_swallows_redis_error(self, exec_loop) -> None:
|
||||
# Raising from the mock ensures the helper catches the exception
|
||||
# instead of propagating it back into execute()'s finally block.
|
||||
mock_mark = AsyncMock(side_effect=RuntimeError("redis down"))
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
|
||||
new=mock_mark,
|
||||
):
|
||||
sync_fail_close_session("sess-2", _make_log(), exec_loop) # must not raise
|
||||
|
||||
mock_mark.assert_awaited_once()
|
||||
|
||||
def test_closed_execution_loop_skipped_cleanly(self) -> None:
|
||||
"""If cleanup_worker has already stopped the execution_loop by the
|
||||
time the safety net fires, ``run_coroutine_threadsafe`` raises
|
||||
RuntimeError. Expected behavior: log + return without propagating."""
|
||||
dead_loop = asyncio.new_event_loop()
|
||||
dead_loop.close()
|
||||
|
||||
mock_mark = AsyncMock()
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
|
||||
new=mock_mark,
|
||||
):
|
||||
# Must not raise even though the loop is closed
|
||||
sync_fail_close_session("sess-closed-loop", _make_log(), dead_loop)
|
||||
|
||||
# mark_session_completed was never scheduled because the loop was dead
|
||||
mock_mark.assert_not_awaited()
|
||||
|
||||
def test_bounded_timeout_when_redis_hangs(self, exec_loop) -> None:
|
||||
"""Scenario D: Redis unreachable — the inner ``asyncio.wait_for``
|
||||
must fire and the helper must return without blocking the worker.
|
||||
|
||||
Simulates a wedged Redis by sleeping past the 10s fail-close budget.
|
||||
The helper must return within the configured grace (+ a small
|
||||
scheduler margin) and must not re-raise.
|
||||
"""
|
||||
import time as _time
|
||||
|
||||
from backend.copilot.executor.processor import _FAIL_CLOSE_REDIS_TIMEOUT
|
||||
|
||||
async def _hang(*_args, **_kwargs):
|
||||
await asyncio.sleep(_FAIL_CLOSE_REDIS_TIMEOUT + 5)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.executor.processor.stream_registry.mark_session_completed",
|
||||
new=_hang,
|
||||
):
|
||||
start = _time.monotonic()
|
||||
sync_fail_close_session(
|
||||
"sess-hang", _make_log(), exec_loop
|
||||
) # must not raise
|
||||
        elapsed = _time.monotonic() - start

        # wait_for fires at _FAIL_CLOSE_REDIS_TIMEOUT; outer future.result
        # has +2s slack. If the timeout is missing/broken the helper would
        # block the full sleep duration (~15s).
        assert elapsed < _FAIL_CLOSE_REDIS_TIMEOUT + 4.0, (
            f"sync_fail_close_session hung for {elapsed:.1f}s — bounded "
            f"timeout did not fire"
        )


# ---------------------------------------------------------------------------
# End-to-end execute() safety-net coverage — the PR's core invariant
# ---------------------------------------------------------------------------


class TestExecuteSafetyNet:
    """``CoPilotProcessor.execute`` must always invoke
    ``sync_fail_close_session`` in its ``finally`` so a session never stays
    ``status=running`` in Redis.

    Validates the four deploy-time scenarios the PR targets:

    * A — SIGTERM mid-turn: ``cancel`` event fires, ``_execute`` returns,
      safety net still runs.
    * B — happy path: normal completion, safety net runs (cheap CAS no-op).
    * C — zombie Redis state: the async ``mark_session_completed`` in
      ``_execute_async`` blows up, but the outer safety net marks the
      session failed anyway.
    * D — covered by ``TestSyncFailCloseSession::test_bounded_timeout…``.
    """

    def _attach_exec_loop(self, proc: CoPilotProcessor, loop) -> None:
        """``execute`` dispatches the safety net onto ``self.execution_loop``.

        Tests don't call ``on_executor_start`` (which spawns the real
        per-worker loop), so wire the shared fixture loop in directly."""
        proc.execution_loop = loop

    def _run_execute_in_thread(self, proc: CoPilotProcessor, cancel: threading.Event):
        """``CoPilotProcessor.execute`` expects to be called from a pool
        worker thread that has *no* running event loop, so we always run
        it off the main thread to preserve that invariant. Returns the
        future so callers can inspect both result and exception paths."""
        pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            fut = pool.submit(proc.execute, _make_entry(), cancel, MagicMock())
            # Block until execute() returns (or raises) so the safety net
            # has run by the time we inspect mocks.
            try:
                fut.result(timeout=30)
            except BaseException:
                pass
            return fut
        finally:
            pool.shutdown(wait=True)

    def test_happy_path_invokes_safety_net(self, exec_loop) -> None:
        """Scenario B: normal completion still runs the sync safety net.

        Proves the ``finally`` always fires, even when nothing went wrong —
        ``mark_session_completed``'s atomic CAS makes this a cheap no-op
        in production."""
        mock_mark = AsyncMock()
        proc = CoPilotProcessor()
        self._attach_exec_loop(proc, exec_loop)
        with patch.object(proc, "_execute"), patch(
            "backend.copilot.executor.processor.stream_registry.mark_session_completed",
            new=mock_mark,
        ):
            self._run_execute_in_thread(proc, threading.Event())

        mock_mark.assert_awaited_once()
        assert mock_mark.await_args is not None
        assert mock_mark.await_args.args[0] == "sess-1"

    def test_sigterm_mid_turn_invokes_safety_net(self, exec_loop) -> None:
        """Scenario A: worker raises (simulating future.cancel + grace
        timeout escaping ``_execute``); ``execute`` must still reach the
        safety net in its ``finally`` and mark the session failed."""
        mock_mark = AsyncMock()
        proc = CoPilotProcessor()
        self._attach_exec_loop(proc, exec_loop)
        with patch.object(
            proc,
            "_execute",
            side_effect=concurrent.futures.TimeoutError("grace expired"),
        ), patch(
            "backend.copilot.executor.processor.stream_registry.mark_session_completed",
            new=mock_mark,
        ):
            self._run_execute_in_thread(proc, threading.Event())

        mock_mark.assert_awaited_once()

    def test_zombie_redis_async_path_still_marks_session_failed(
        self, exec_loop
    ) -> None:
        """Scenario C: ``_execute_async``'s own ``mark_session_completed``
        call is broken (simulating the exact async-Redis hiccup that caused
        the original zombie sessions). The outer ``sync_fail_close_session``
        runs on the processor's long-lived ``execution_loop`` and succeeds
        where the async path failed."""
        call_log: list[str] = []

        async def _ok(*args, **kwargs):
            call_log.append("sync-ok")

        def _broken_execute(entry, cancel, cluster_lock, log):
            # Simulate the async path raising because its Redis client is
            # wedged (the pre-fix zombie-session scenario).
            raise RuntimeError("async Redis client broken")

        proc = CoPilotProcessor()
        self._attach_exec_loop(proc, exec_loop)
        with patch.object(proc, "_execute", side_effect=_broken_execute), patch(
            "backend.copilot.executor.processor.stream_registry.mark_session_completed",
            new=_ok,
        ):
            self._run_execute_in_thread(proc, threading.Event())

        # The sync safety net must have fired despite the async path
        # blowing up — this is the core guarantee of the PR.
        assert call_log == [
            "sync-ok"
        ], f"expected sync_fail_close_session to run once, got {call_log!r}"
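For orientation, here is a minimal sketch of the invariant the class above pins down. It is illustrative only — the real `execute` lives in the processor module under test, and the exact signature of `sync_fail_close_session` is an assumption inferred from how the tests wire it up:

```python
# Hedged sketch, NOT the shipped processor code. Assumed names:
# _execute, execution_loop, sync_fail_close_session, logger.
def execute(self, entry, cancel_event, cluster_lock):
    try:
        return self._execute(entry, cancel_event, cluster_lock, logger)
    finally:
        # Scenarios A-C: whatever _execute did (returned, raised, or its
        # async completion path broke), fail-close the session record in
        # Redis so it can never stay status=running. The helper itself is
        # time-bounded, per TestSyncFailCloseSession above.
        sync_fail_close_session(entry.session_id, loop=self.execution_loop)
```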
@@ -89,11 +89,16 @@ def get_session_lock_key(session_id: str) -> str:


 # CoPilot operations can include extended thinking and agent generation
-# which may take 30+ minutes to complete
-COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60  # 1 hour
+# which may take several hours to complete. Matches the pod's
+# terminationGracePeriodSeconds in the helm chart so a rolling deploy can let
+# the longest legitimate turn finish. Also bounds the stale-session auto-
+# complete watchdog in stream_registry (consumer_timeout + 5min buffer).
+COPILOT_CONSUMER_TIMEOUT_SECONDS = 6 * 60 * 60  # 6 hours

-# Graceful shutdown timeout - allow in-flight operations to complete
-GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60  # 30 minutes
+# Graceful shutdown timeout - must match COPILOT_CONSUMER_TIMEOUT_SECONDS so
+# cleanup can let the longest legitimate turn complete before the pod is
+# SIGKILL'd by kubelet.
+GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = COPILOT_CONSUMER_TIMEOUT_SECONDS


 def create_copilot_queue_config() -> RabbitMQConfig:
@@ -113,9 +118,27 @@ def create_copilot_queue_config() -> RabbitMQConfig:
         durable=True,
         auto_delete=False,
         arguments={
-            # Extended consumer timeout for long-running LLM operations
-            # Default 30-minute timeout is insufficient for extended thinking
-            # and agent generation which can take 30+ minutes
+            # Consumer timeout matches the pod graceful-shutdown window so a
+            # rolling deploy never forces redelivery of a turn that the pod
+            # is still legitimately finishing.
+            #
+            # Deploy note: RabbitMQ (verified on 4.1.4) does NOT strictly
+            # compare ``x-consumer-timeout`` on queue redeclaration, so this
+            # value can change between deploys without triggering
+            # PRECONDITION_FAILED. To update the *effective* timeout on an
+            # already-running queue before the new code deploys (so pods
+            # mid-shutdown don't have their consumer cancelled at the old
+            # limit), apply a policy:
+            #
+            #   rabbitmqctl set_policy copilot-consumer-timeout \
+            #     "^copilot_execution_queue$" \
+            #     '{"consumer-timeout": 21600000}' \
+            #     --apply-to queues
+            #
+            # The policy takes effect immediately. Once the policy is set
+            # to match the code's value the policy is redundant for new
+            # pods and can be removed after a stable deploy if desired —
+            # but it's harmless to leave in place.
             "x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS
             * 1000,
         },
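As a quick sanity check on the numbers above: 6 × 60 × 60 × 1000 = 21,600,000 ms, so the `21600000` in the example policy matches the code's new `x-consumer-timeout`. On a live broker, `rabbitmqctl list_policies` should show the `copilot-consumer-timeout` policy once applied, and `rabbitmqctl list_queues name arguments` shows the queue's declared `x-consumer-timeout` in milliseconds (hedged — verify the exact column names against your RabbitMQ version).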
autogpt_platform/backend/backend/copilot/model_router.py (new file, 104 lines)
@@ -0,0 +1,104 @@
"""LaunchDarkly-aware model selection for the copilot.

Each cell of the ``(mode, tier)`` matrix has a static default baked into
``ChatConfig`` (see ``copilot/config.py``) and a matching LaunchDarkly
string-valued feature flag that can override it per-user. This module
centralises the lookup so both the baseline and SDK paths agree on the
selection rule and so A/B experiments can target a single cell without
shipping a config change.

Matrix:

+----------+---------------------------------+---------------------------------+
|          | standard                        | advanced                        |
+----------+---------------------------------+---------------------------------+
| fast     | copilot-fast-standard-model     | copilot-fast-advanced-model     |
| thinking | copilot-thinking-standard-model | copilot-thinking-advanced-model |
+----------+---------------------------------+---------------------------------+

LD flag values are arbitrary strings (model identifiers, e.g.
``"anthropic/claude-sonnet-4-6"`` or ``"moonshotai/kimi-k2.6"``). Empty
or non-string values fall back to the config default.
"""

from __future__ import annotations

import logging
from typing import Literal

from backend.copilot.config import ChatConfig
from backend.util.feature_flag import Flag, get_feature_flag_value

logger = logging.getLogger(__name__)

ModelMode = Literal["fast", "thinking"]
ModelTier = Literal["standard", "advanced"]


_FLAG_BY_CELL: dict[tuple[ModelMode, ModelTier], Flag] = {
    ("fast", "standard"): Flag.COPILOT_FAST_STANDARD_MODEL,
    ("fast", "advanced"): Flag.COPILOT_FAST_ADVANCED_MODEL,
    ("thinking", "standard"): Flag.COPILOT_THINKING_STANDARD_MODEL,
    ("thinking", "advanced"): Flag.COPILOT_THINKING_ADVANCED_MODEL,
}


def _config_default(config: ChatConfig, mode: ModelMode, tier: ModelTier) -> str:
    if mode == "fast":
        return (
            config.fast_advanced_model
            if tier == "advanced"
            else config.fast_standard_model
        )
    return (
        config.thinking_advanced_model
        if tier == "advanced"
        else config.thinking_standard_model
    )


async def resolve_model(
    mode: ModelMode,
    tier: ModelTier,
    user_id: str | None,
    *,
    config: ChatConfig,
) -> str:
    """Return the model identifier for a ``(mode, tier)`` cell.

    Consults the matching LaunchDarkly flag for *user_id* first and
    falls back to the ``ChatConfig`` default on missing user, missing
    flag, or non-string flag value. Passing *config* explicitly keeps
    the resolver cheap to unit-test.
    """
    fallback = _config_default(config, mode, tier).strip()
    if not user_id:
        return fallback

    flag = _FLAG_BY_CELL[(mode, tier)]
    try:
        value = await get_feature_flag_value(flag.value, user_id, default=fallback)
    except Exception:
        logger.warning(
            "[model_router] LD lookup failed for %s — using config default %s",
            flag.value,
            fallback,
            exc_info=True,
        )
        return fallback

    if isinstance(value, str) and value.strip():
        return value.strip()
    if value != fallback:
        reason = (
            "empty string"
            if isinstance(value, str)
            else f"non-string ({type(value).__name__})"
        )
        logger.warning(
            "[model_router] LD flag %s returned %s — using config default %s",
            flag.value,
            reason,
            fallback,
        )
    return fallback
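A minimal sketch of how a caller might use the resolver — the bare `ChatConfig()` construction is an assumption (your environment may require explicit model fields), and the flag behaviour follows the docstring above:

```python
import asyncio

from backend.copilot.config import ChatConfig
from backend.copilot.model_router import resolve_model


async def _demo() -> None:
    config = ChatConfig()  # assumed to validate with your env defaults
    # Anonymous request: LD is skipped entirely, config default comes back.
    print(await resolve_model("fast", "standard", None, config=config))
    # Known user: the copilot-thinking-advanced-model flag can override.
    print(await resolve_model("thinking", "advanced", "user-123", config=config))


asyncio.run(_demo())
```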
autogpt_platform/backend/backend/copilot/model_router_test.py (new file, 166 lines)
@@ -0,0 +1,166 @@
"""Tests for the LD-aware model resolver."""

from unittest.mock import AsyncMock, patch

import pytest

from backend.copilot.config import ChatConfig
from backend.copilot.model_router import _FLAG_BY_CELL, _config_default, resolve_model


def _make_config() -> ChatConfig:
    """Build a config with the canonical defaults so tests read naturally."""
    return ChatConfig(
        fast_standard_model="anthropic/claude-sonnet-4-6",
        fast_advanced_model="anthropic/claude-opus-4.7",
        thinking_standard_model="anthropic/claude-sonnet-4-6",
        thinking_advanced_model="anthropic/claude-opus-4.7",
    )


class TestConfigDefault:
    def test_fast_standard(self):
        cfg = _make_config()
        assert _config_default(cfg, "fast", "standard") == cfg.fast_standard_model

    def test_fast_advanced(self):
        cfg = _make_config()
        assert _config_default(cfg, "fast", "advanced") == cfg.fast_advanced_model

    def test_thinking_standard(self):
        cfg = _make_config()
        assert (
            _config_default(cfg, "thinking", "standard") == cfg.thinking_standard_model
        )

    def test_thinking_advanced(self):
        cfg = _make_config()
        assert (
            _config_default(cfg, "thinking", "advanced") == cfg.thinking_advanced_model
        )


class TestResolveModel:
    @pytest.mark.asyncio
    async def test_missing_user_returns_fallback(self):
        """Without user_id there's no LD context — skip the lookup entirely."""
        cfg = _make_config()
        with patch("backend.copilot.model_router.get_feature_flag_value") as mock_flag:
            result = await resolve_model("fast", "standard", None, config=cfg)
        assert result == cfg.fast_standard_model
        mock_flag.assert_not_called()

    @pytest.mark.asyncio
    async def test_missing_user_strips_whitespace_from_fallback(self):
        """Sentry MEDIUM: the anonymous-user branch returned an unstripped
        config value. If ``CHAT_*_MODEL`` env carries trailing whitespace
        the downstream ``resolved == tier_default`` check in
        ``_resolve_sdk_model_for_request`` would diverge from the
        whitespace-stripped LD side, bypassing subscription mode for
        every anonymous request. Strip at the source."""
        cfg = ChatConfig(
            fast_standard_model="anthropic/claude-sonnet-4-6 ",  # trailing ws
            fast_advanced_model="anthropic/claude-opus-4.7",
            thinking_standard_model="anthropic/claude-sonnet-4-6",
            thinking_advanced_model="anthropic/claude-opus-4.7",
        )
        result = await resolve_model("fast", "standard", None, config=cfg)
        assert result == "anthropic/claude-sonnet-4-6"

    @pytest.mark.asyncio
    async def test_ld_string_override_wins(self):
        """LD-returned model string replaces the config default."""
        cfg = _make_config()
        with patch(
            "backend.copilot.model_router.get_feature_flag_value",
            new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
        ):
            result = await resolve_model("fast", "standard", "user-1", config=cfg)
        assert result == "moonshotai/kimi-k2.6"

    @pytest.mark.asyncio
    async def test_whitespace_is_stripped(self):
        cfg = _make_config()
        with patch(
            "backend.copilot.model_router.get_feature_flag_value",
            new=AsyncMock(return_value=" xai/grok-4 "),
        ):
            result = await resolve_model("thinking", "advanced", "user-1", config=cfg)
        assert result == "xai/grok-4"

    @pytest.mark.asyncio
    async def test_non_string_value_falls_back_with_type_in_warning(self, caplog):
        """LD misconfigured as a boolean flag — don't try to use ``True`` as a
        model name; return the config default. Warning must say
        'non-string' (not 'empty string') so the LD operator knows the
        flag type is wrong, not just missing a value."""
        import logging

        cfg = _make_config()
        with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"):
            with patch(
                "backend.copilot.model_router.get_feature_flag_value",
                new=AsyncMock(return_value=True),
            ):
                result = await resolve_model("fast", "advanced", "user-1", config=cfg)
        assert result == cfg.fast_advanced_model
        assert any("non-string" in r.message for r in caplog.records)

    @pytest.mark.asyncio
    async def test_empty_string_falls_back_with_empty_in_warning(self, caplog):
        """When LD returns ``""`` the warning must say 'empty string' —
        not 'non-string' — so the operator doesn't chase a type bug
        when the flag is simply unset to an empty value."""
        import logging

        cfg = _make_config()
        with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"):
            with patch(
                "backend.copilot.model_router.get_feature_flag_value",
                new=AsyncMock(return_value=""),
            ):
                result = await resolve_model("fast", "standard", "user-1", config=cfg)
        assert result == cfg.fast_standard_model
        messages = [r.message for r in caplog.records]
        assert any("empty string" in m for m in messages)
        assert not any("non-string" in m for m in messages)

    @pytest.mark.asyncio
    async def test_ld_exception_falls_back(self):
        """LD client throws (network blip, SDK init race) — serve the default
        instead of failing the whole request."""
        cfg = _make_config()
        with patch(
            "backend.copilot.model_router.get_feature_flag_value",
            new=AsyncMock(side_effect=RuntimeError("LD down")),
        ):
            result = await resolve_model("fast", "standard", "user-1", config=cfg)
        assert result == cfg.fast_standard_model

    @pytest.mark.asyncio
    async def test_all_four_cells_hit_distinct_flags(self):
        """Each (mode, tier) cell must route to its own flag — regression
        guard against copy-paste bugs in the _FLAG_BY_CELL map."""
        cfg = _make_config()
        calls: list[str] = []

        async def _capture(flag_key, user_id, default):
            calls.append(flag_key)
            return default

        with patch(
            "backend.copilot.model_router.get_feature_flag_value",
            new=AsyncMock(side_effect=_capture),
        ):
            await resolve_model("fast", "standard", "u", config=cfg)
            await resolve_model("fast", "advanced", "u", config=cfg)
            await resolve_model("thinking", "standard", "u", config=cfg)
            await resolve_model("thinking", "advanced", "u", config=cfg)

        assert calls == [
            _FLAG_BY_CELL[("fast", "standard")].value,
            _FLAG_BY_CELL[("fast", "advanced")].value,
            _FLAG_BY_CELL[("thinking", "standard")].value,
            _FLAG_BY_CELL[("thinking", "advanced")].value,
        ]
        assert len(set(calls)) == 4
autogpt_platform/backend/backend/copilot/moonshot.py (new file, 147 lines)
@@ -0,0 +1,147 @@
"""Moonshot-specific pricing and cache-control helpers.

Moonshot's Kimi K2.x family is routed through OpenRouter's Anthropic-compat
shim — it speaks Anthropic's API shape but its pricing and cache behaviour
diverge from Anthropic in ways the Claude Agent SDK CLI and our baseline
cache-control gating don't handle on their own:

* **Rate card** — NOT the canonical cost source. The authoritative number
  for every OpenRouter-routed turn is the reconcile task
  (:mod:`openrouter_cost`), which reads ``total_cost`` directly from
  ``/api/v1/generation`` post-turn. This module exists purely so the
  CLI's in-turn ``ResultMessage.total_cost_usd`` (which silently bills
  Moonshot at Sonnet rates, ~5x the real Moonshot price, because the CLI
  pricing table only knows Anthropic) isn't left wildly wrong before the
  reconcile fires, AND so the reconcile's lookup-fail fallback records a
  plausible Moonshot estimate rather than a Sonnet-rate overcharge.
  Signal authority: reconcile >> this module's rate card >> CLI.

* **Cache-control** — Anthropic and Moonshot both accept the
  ``cache_control: {type: ephemeral}`` breakpoint on message blocks, but
  our baseline path currently gates cache markers on an
  ``anthropic/`` / ``claude`` name match because non-Anthropic providers
  (OpenAI, Grok, Gemini) 400 on the unknown field. Moonshot's
  Anthropic-compat endpoint silently accepts and honours the marker —
  empirically it boosts cache hit rate on continuation turns — but was
  caught in the non-Anthropic branch of the original gate.
  :func:`moonshot_supports_cache_control` lets callers widen the gate
  to include Moonshot without weakening the ``false`` answer for
  OpenAI et al. (The predicate is intentionally narrow — Moonshot-only
  — so callers combine it with an explicit Anthropic check at the call
  site; see ``baseline/service.py::_supports_prompt_cache_markers``.)

Detection is prefix-based (``moonshotai/``). Moonshot routes every Kimi
SKU through the same Anthropic-compat surface and currently prices them
identically, so a new ``moonshotai/kimi-k3.0`` slug transparently
inherits both the rate card and the cache-control gate without editing
this file. Per-slug overrides are in :data:`_RATE_OVERRIDES_USD_PER_MTOK`
for when Moonshot eventually splits prices.
"""

from __future__ import annotations

# All Moonshot slugs share these rates as of April 2026 — Moonshot prices
# every Kimi K2.x SKU at $0.60/$2.80 per million (input/output) via
# OpenRouter. Cache-read / cache-write discounts are NOT applied here:
# OpenRouter currently exposes only a single input price per Moonshot
# endpoint; the real billed amount (with cache savings) lands via the
# reconcile path. Keep in sync with https://platform.moonshot.ai/docs/pricing.
_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK: tuple[float, float] = (0.60, 2.80)

# Per-slug overrides for when Moonshot splits pricing across SKUs. Empty
# today — every slug matching ``moonshotai/`` falls back to
# :data:`_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK`.
_RATE_OVERRIDES_USD_PER_MTOK: dict[str, tuple[float, float]] = {}

# Vendor prefix — matches any OpenRouter slug Moonshot ships. Keep as a
# module constant so the prefix check stays in exactly one place.
_MOONSHOT_PREFIX = "moonshotai/"


def is_moonshot_model(model: str | None) -> bool:
    """True when *model* is a Moonshot OpenRouter slug.

    Prefix match against ``moonshotai/`` covers every Kimi SKU Moonshot
    ships today (``kimi-k2``, ``kimi-k2.5``, ``kimi-k2.6``,
    ``kimi-k2-thinking``) plus any future SKU Moonshot publishes under
    the same namespace. Used by both pricing and cache-control gating.
    """
    return isinstance(model, str) and model.startswith(_MOONSHOT_PREFIX)


def rate_card_usd(model: str | None) -> tuple[float, float] | None:
    """Return (input, output) $/Mtok for *model* or None if non-Moonshot.

    Looks up a per-slug override first, falling back to the shared
    default for anything under ``moonshotai/``. Returns None for
    non-Moonshot slugs (including ``None``) so callers can skip the
    override without a preflight guard.
    """
    if not is_moonshot_model(model):
        return None
    # ``is_moonshot_model`` narrowed ``model`` to str; dict.get is
    # type-safe here despite the wider param annotation above.
    assert model is not None
    return _RATE_OVERRIDES_USD_PER_MTOK.get(model, _DEFAULT_MOONSHOT_RATE_USD_PER_MTOK)


def override_cost_usd(
    *,
    model: str | None,
    sdk_reported_usd: float,
    prompt_tokens: int,
    completion_tokens: int,
    cache_read_tokens: int,
    cache_creation_tokens: int,
) -> float:
    """Recompute SDK turn cost from the Moonshot rate card.

    Not the canonical cost source — the OpenRouter ``/generation``
    reconcile (:mod:`openrouter_cost`) lands the authoritative billed
    amount post-turn. This helper exists only to improve the CLI's
    in-turn ``ResultMessage.total_cost_usd``:

    1. So the ``cost_usd`` the client sees before the reconcile completes
       isn't wildly wrong (the CLI would otherwise ship a Sonnet-rate
       estimate, ~5x the real Moonshot bill).
    2. So the reconcile's own lookup-fail fallback records a plausible
       Moonshot estimate rather than the CLI's Sonnet number.

    For Moonshot slugs we compute cost from the reported token counts;
    for anything else (including Anthropic) we return the SDK number
    unchanged — Anthropic slugs are priced accurately by the CLI.

    Cache read / creation tokens are folded into ``prompt_tokens`` at
    the full input rate because Moonshot's rate card doesn't distinguish
    them at the OpenRouter surface; the reconcile has the authoritative
    discount accounting for turns where Moonshot's cache engaged.
    """
    if model is None:
        return sdk_reported_usd
    rates = rate_card_usd(model)
    if rates is None:
        return sdk_reported_usd
    input_rate, output_rate = rates
    total_prompt = prompt_tokens + cache_read_tokens + cache_creation_tokens
    return (total_prompt * input_rate + completion_tokens * output_rate) / 1_000_000


def moonshot_supports_cache_control(model: str | None) -> bool:
    """True when a Moonshot *model* accepts Anthropic-style ``cache_control``.

    Narrow, Moonshot-specific predicate — callers that need the full
    "does this route accept cache markers" answer combine this with an
    Anthropic check (see ``baseline/service.py::_supports_prompt_cache_markers``).
    Named ``moonshot_*`` deliberately so the call site can't mistake it
    for a universal predicate that answers correctly for Anthropic
    (which also supports cache_control — this function would return
    False for Anthropic slugs).

    Moonshot's Anthropic-compat endpoint honours the marker. Without
    it Moonshot falls back to its own automatic prefix caching, which
    drifts more readily between turns (internal testing saw 0/4 cache
    hits across two continuation sessions). With explicit
    ``cache_control`` the upstream cache hit rate rises to the same
    ballpark as Anthropic's ~60-95% on continuations.
    """
    return is_moonshot_model(model)
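The widened cache gate the docstrings point at could look roughly like this — a sketch only, since the real `_supports_prompt_cache_markers` lives in `baseline/service.py` and the `_is_anthropic_model` helper name is assumed here:

```python
def _supports_prompt_cache_markers(model: str | None) -> bool:
    # Anthropic slugs already passed the original anthropic/claude gate;
    # Moonshot's Anthropic-compat endpoint also honours cache_control,
    # while OpenAI / Grok / Gemini would 400 on the unknown field.
    return _is_anthropic_model(model) or moonshot_supports_cache_control(model)
```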
autogpt_platform/backend/backend/copilot/moonshot_test.py (new file, 173 lines)
@@ -0,0 +1,173 @@
"""Unit tests for Moonshot pricing and cache-control helpers."""

from __future__ import annotations

import pytest

from backend.copilot.moonshot import (
    is_moonshot_model,
    moonshot_supports_cache_control,
    override_cost_usd,
    rate_card_usd,
)


class TestIsMoonshotModel:
    """Prefix detection covers every Moonshot SKU without a slug list."""

    @pytest.mark.parametrize(
        "model",
        [
            "moonshotai/kimi-k2.6",
            "moonshotai/kimi-k2-thinking",
            "moonshotai/kimi-k2.5",
            "moonshotai/kimi-k2",
            "moonshotai/kimi-k3.0",  # Future SKU must match transparently.
        ],
    )
    def test_moonshot_slugs_match(self, model: str) -> None:
        assert is_moonshot_model(model) is True

    @pytest.mark.parametrize(
        "model",
        [
            "anthropic/claude-sonnet-4.6",
            "anthropic/claude-opus-4.7",
            "openai/gpt-4o",
            "google/gemini-2.5-flash",
            "xai/grok-4",
            "deepseek/deepseek-v3",
            "",  # Empty string — not Moonshot.
        ],
    )
    def test_non_moonshot_slugs_do_not_match(self, model: str) -> None:
        assert is_moonshot_model(model) is False

    @pytest.mark.parametrize("model", [None, 123, ["moonshotai/kimi-k2.6"]])
    def test_non_string_returns_false(self, model) -> None:
        # Type-robust: never raise on unexpected types; callers pass None.
        assert is_moonshot_model(model) is False


class TestRateCardUsd:
    """Rate card defaults to the shared Moonshot price for every SKU."""

    def test_moonshot_default_rate(self) -> None:
        assert rate_card_usd("moonshotai/kimi-k2.6") == (0.60, 2.80)

    def test_future_moonshot_sku_inherits_default(self) -> None:
        # Verifies the prefix-based fallback — new SKUs don't need a code
        # edit to get a reasonable rate card.
        assert rate_card_usd("moonshotai/kimi-k3.0") == (0.60, 2.80)

    def test_non_moonshot_returns_none(self) -> None:
        assert rate_card_usd("anthropic/claude-sonnet-4.6") is None
        assert rate_card_usd("openai/gpt-4o") is None


class TestOverrideCostUsd:
    """Rate-card override replaces the CLI's Sonnet-rate estimate for
    Moonshot turns; Anthropic and unknown slugs pass through unchanged."""

    def test_moonshot_recomputes_from_rate_card(self) -> None:
        """A 29.5K-prompt Kimi turn should land at ~$0.018 on the
        Moonshot rate card, not the CLI's $0.09 Sonnet-rate estimate."""
        recomputed = override_cost_usd(
            model="moonshotai/kimi-k2.6",
            sdk_reported_usd=0.089862,  # What the CLI reported (Sonnet price).
            prompt_tokens=29564,
            completion_tokens=78,
            cache_read_tokens=0,
            cache_creation_tokens=0,
        )
        expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000
        assert recomputed == pytest.approx(expected, rel=1e-9)
        assert 0.017 < recomputed < 0.019  # Sanity against Moonshot's rate card.

    def test_anthropic_passes_through(self) -> None:
        """Anthropic slugs are priced accurately by the CLI already —
        the override returns the SDK number unchanged."""
        assert (
            override_cost_usd(
                model="anthropic/claude-sonnet-4.6",
                sdk_reported_usd=0.089862,
                prompt_tokens=29564,
                completion_tokens=78,
                cache_read_tokens=0,
                cache_creation_tokens=0,
            )
            == 0.089862
        )

    def test_unknown_non_moonshot_passes_through(self) -> None:
        """A non-Moonshot, non-Anthropic slug falls back to the SDK value
        — best-effort rather than leaking a zero or a wrong rate card."""
        assert (
            override_cost_usd(
                model="deepseek/deepseek-v3",
                sdk_reported_usd=0.05,
                prompt_tokens=10_000,
                completion_tokens=500,
                cache_read_tokens=0,
                cache_creation_tokens=0,
            )
            == 0.05
        )

    def test_none_model_passes_through(self) -> None:
        """Subscription mode sets model=None — return the SDK value."""
        assert (
            override_cost_usd(
                model=None,
                sdk_reported_usd=0.07,
                prompt_tokens=100,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
            )
            == 0.07
        )

    def test_cache_tokens_priced_at_input_rate(self) -> None:
        """OpenRouter's Moonshot endpoints don't expose a discounted
        cached-input price — cache_read / cache_creation tokens are
        priced at the full input rate. The reconcile path has the
        authoritative discount for turns where Moonshot's cache engaged."""
        recomputed = override_cost_usd(
            model="moonshotai/kimi-k2.6",
            sdk_reported_usd=0.5,
            prompt_tokens=1000,
            completion_tokens=0,
            cache_read_tokens=5000,
            cache_creation_tokens=2000,
        )
        expected = (1000 + 5000 + 2000) * 0.60 / 1_000_000
        assert recomputed == pytest.approx(expected, rel=1e-9)


class TestSupportsCacheControl:
    """Gate for emitting ``cache_control: {type: ephemeral}`` on message
    blocks. True for Moonshot (Anthropic-compat endpoint accepts it)
    and False for everything else this module knows about — Anthropic
    callers use their own ``_is_anthropic_model`` check which is
    combined with this one into a wider gate."""

    def test_moonshot_supports_cache_control(self) -> None:
        assert moonshot_supports_cache_control("moonshotai/kimi-k2.6") is True

    def test_future_moonshot_sku_supports_cache_control(self) -> None:
        assert moonshot_supports_cache_control("moonshotai/kimi-k3.0") is True

    @pytest.mark.parametrize(
        "model",
        [
            "openai/gpt-4o",
            "google/gemini-2.5-flash",
            "xai/grok-4",
            "deepseek/deepseek-v3",
            "",
            None,
        ],
    )
    def test_non_moonshot_does_not_support_cache_control(self, model) -> None:
        assert moonshot_supports_cache_control(model) is False
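Worked check of the recompute in `test_moonshot_recomputes_from_rate_card`: (29,564 × 0.60 + 78 × 2.80) / 1,000,000 = (17,738.40 + 218.40) / 1,000,000 ≈ $0.017957 — inside the test's 0.017–0.019 sanity band, and roughly one fifth of the CLI's $0.089862 Sonnet-rate figure, which is where the "~5x" claim in `moonshot.py` comes from.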
@@ -240,16 +240,15 @@ async def peek_pending_messages(session_id: str) -> list[PendingMessage]:
     return messages


-async def _clear_pending_messages_unsafe(session_id: str) -> None:
+async def clear_pending_messages_unsafe(session_id: str) -> None:
     """Drop the session's pending buffer — **not** the normal turn cleanup.

-    Named ``_unsafe`` because reaching for this at turn end drops queued
-    follow-ups on the floor instead of running them (the bug fixed by
-    commit b64be73). The atomic ``LPOP`` drain at turn start is the
-    primary consumer; anything pushed after the drain window belongs to
-    the next turn by definition. Retained only as an operator/debug
-    escape hatch for manually clearing a stuck session and as a fixture
-    in the unit tests.
+    The ``_unsafe`` suffix warns: reaching for this at turn end drops queued
+    follow-ups on the floor instead of running them (the bug fixed by commit
+    b64be73). The atomic ``LPOP`` drain at turn start is the primary consumer;
+    anything pushed after the drain window belongs to the next turn by
+    definition. Retained only as an operator/debug escape hatch for manually
+    clearing a stuck session and as a fixture in the unit tests.
     """
     redis = await get_redis_async()
     await redis.delete(_buffer_key(session_id))

@@ -16,7 +16,7 @@ from backend.copilot.pending_messages import (
     MAX_PENDING_MESSAGES,
     PendingMessage,
     PendingMessageContext,
-    _clear_pending_messages_unsafe,
+    clear_pending_messages_unsafe,
     drain_and_format_for_injection,
     drain_pending_for_persist,
     drain_pending_messages,
@@ -208,15 +208,15 @@ async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None:
 async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None:
     await push_pending_message("sess4", PendingMessage(content="x"))
     await push_pending_message("sess4", PendingMessage(content="y"))
-    await _clear_pending_messages_unsafe("sess4")
+    await clear_pending_messages_unsafe("sess4")
     assert await peek_pending_count("sess4") == 0


 @pytest.mark.asyncio
 async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None:
     # Clearing an already-empty buffer should not raise
-    await _clear_pending_messages_unsafe("sess_empty")
-    await _clear_pending_messages_unsafe("sess_empty")
+    await clear_pending_messages_unsafe("sess_empty")
+    await clear_pending_messages_unsafe("sess_empty")


 # ── Publish hook ────────────────────────────────────────────────────
@@ -52,10 +52,15 @@ is at most as permissive as the parent:
 from __future__ import annotations

 import re
-from typing import Literal, get_args
+from typing import TYPE_CHECKING, Literal, get_args

 from pydantic import BaseModel, PrivateAttr

+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+    from backend.copilot.tools import ToolGroup
+
 # ---------------------------------------------------------------------------
 # Constants — single source of truth for all accepted tool names
 # ---------------------------------------------------------------------------
@@ -66,7 +71,6 @@ from pydantic import BaseModel, PrivateAttr
ToolName = Literal[
    # Platform tools (must match keys in TOOL_REGISTRY)
    "add_understanding",
    "ask_question",
    "bash_exec",
    "browser_act",
    "browser_navigate",
@@ -125,9 +129,16 @@ ToolName = Literal[
 # Frozen set of all valid tool names — derived from the Literal.
 ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName))

-# SDK built-in tool names — uppercase-initial names are SDK built-ins.
+# SDK built-in tool names — tools provided by the Claude Code CLI that our
+# code does not implement directly. ``TodoWrite`` is DELIBERATELY excluded:
+# baseline mode ships an MCP-wrapped platform version
+# (``tools/todo_write.py``), while SDK mode still uses the CLI-native
+# original via ``_SDK_BUILTIN_ALWAYS`` in ``sdk/tool_adapter.py`` — the
+# MCP copy is filtered out there. ``Task`` remains an SDK-only built-in
+# (for queue-backed context-isolation on baseline, use ``run_sub_session``
+# instead).
 SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset(
-    n for n in ALL_TOOL_NAMES if n[0].isupper()
+    {"Agent", "Edit", "Glob", "Grep", "Read", "Task", "WebSearch", "Write"}
 )

 # Platform tool names — everything that isn't an SDK built-in.
@@ -364,13 +375,17 @@ def apply_tool_permissions(
     permissions: CopilotPermissions,
     *,
     use_e2b: bool = False,
+    disabled_groups: Iterable[ToolGroup] = (),
 ) -> tuple[list[str], list[str]]:
     """Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`.

     Takes the base allowed/disallowed lists from
     :func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` /
     :func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and
-    applies *permissions* on top.
+    applies *permissions* on top. Tools belonging to any *disabled_groups*
+    are hidden from the base allowed list — use this to gate capability
+    groups (e.g. ``"graphiti"`` when the memory backend is off for the
+    current user).

     Returns:
         ``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the
@@ -380,13 +395,16 @@ def apply_tool_permissions(
     """
     from backend.copilot.sdk.tool_adapter import (
         _READ_TOOL_NAME,
+        BASELINE_ONLY_MCP_TOOLS,
         MCP_TOOL_PREFIX,
         get_copilot_tool_names,
         get_sdk_disallowed_tools,
     )
     from backend.copilot.tools import TOOL_REGISTRY

-    base_allowed = get_copilot_tool_names(use_e2b=use_e2b)
+    base_allowed = get_copilot_tool_names(
+        use_e2b=use_e2b, disabled_groups=disabled_groups
+    )
     base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)

     if permissions.is_empty():
@@ -420,7 +438,14 @@ def apply_tool_permissions(
     # keeping only those present in the original base_allowed list.
     def to_sdk_names(short: str) -> list[str]:
         names: list[str] = []
-        if short in TOOL_REGISTRY:
+        if short in BASELINE_ONLY_MCP_TOOLS:
+            # Baseline ships MCP versions of these (Task/TodoWrite) for
+            # model-flexibility parity, but SDK mode uses the CLI-native
+            # originals. Permissions target the CLI built-in here so
+            # ``base_allowed`` (which excludes the MCP wrappers) still
+            # matches.
+            names.append(short)
+        elif short in TOOL_REGISTRY:
             names.append(f"{MCP_TOOL_PREFIX}{short}")
         elif short in _SDK_TO_MCP:
             # Map SDK built-in file tool to its MCP equivalent.
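A hedged sketch of the new call shape after this change — the `CopilotPermissions` value and the group name are illustrative (the `"graphiti"` example comes from the docstring in the hunk above):

```python
allowed_tools, extra_disallowed = apply_tool_permissions(
    permissions,                    # an existing CopilotPermissions instance
    use_e2b=False,
    disabled_groups=("graphiti",),  # hide memory tools when the backend is off
)
```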
@@ -582,6 +582,11 @@ class TestApplyToolPermissions:

 class TestSdkBuiltinToolNames:
     def test_expected_builtins_present(self):
+        # ``TodoWrite`` is DELIBERATELY absent: baseline ships an MCP-wrapped
+        # platform version for model-flexibility parity, so it appears in
+        # PLATFORM_TOOL_NAMES / TOOL_REGISTRY instead. ``Task`` remains
+        # SDK-only — baseline uses ``run_sub_session`` for the equivalent
+        # context-isolation role.
         expected = {
             "Agent",
             "Read",
@@ -591,9 +596,9 @@ class TestSdkBuiltinToolNames:
             "Grep",
             "Task",
             "WebSearch",
-            "TodoWrite",
         }
         assert expected.issubset(SDK_BUILTIN_TOOL_NAMES)
+        assert "TodoWrite" not in SDK_BUILTIN_TOOL_NAMES

     def test_platform_names_match_tool_registry(self):
         """PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys."""
@@ -145,31 +145,12 @@ When the user asks to interact with a service or API, follow this order:

 **Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls.

-### Sub-agent tasks
-- When using the Task tool, NEVER set `run_in_background` to true.
-  All tasks must run in the foreground.
-
-### Delegating to another autopilot (sub-autopilot pattern)
-Use the **`run_sub_session`** tool to delegate a task to a fresh
-sub-AutoPilot. The sub has its own full tool set and can perform
-multi-step work autonomously.
-
-- `prompt` (required): the task description.
-- `system_context` (optional): extra context prepended to the prompt.
-- `sub_autopilot_session_id` (optional): continue an existing
-  sub-AutoPilot — pass the `sub_autopilot_session_id` returned by a
-  previous completed run.
-- `wait_for_result` (default 60, max 300): seconds to wait inline. If
-  the sub isn't done by then you get `status="running"` + a
-  `sub_session_id` — call **`get_sub_session_result`** with that id
-  (wait up to 300s more per call) until it returns `completed` or
-  `error`. Works across turns — safe to reconnect in a later message.
-
-Use this when a task is complex enough to benefit from a separate
-autopilot context, e.g. "research X and write a report" while the
-parent autopilot handles orchestration. Do NOT invoke `AutoPilotBlock`
-via `run_block` — it's hidden from `run_block` by design because the
-dedicated tool handles the async lifecycle correctly.
+### Complex multi-step work
+- Use `TodoWrite` to track the plan once the job has 3+ distinct steps.
+- Delegate self-contained subtasks to `run_sub_session` to keep their
+  intermediate tool calls out of the parent context.
+- Do NOT invoke `AutoPilotBlock` via `run_block`; use `run_sub_session`
+  instead.

 """
@@ -182,14 +163,18 @@ sandbox so `bash_exec` can access it for further processing.
 The exact sandbox path is shown in the `[Sandbox copy available at ...]` note.

 ### GitHub CLI (`gh`) and git
-- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it.
+- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before running `connect_integration(provider="github")`, which will ask the user to connect their GitHub regardless of whether it's already connected.
 - If the user has connected their GitHub account, both `gh` and `git` are
   pre-authenticated — use them directly without any manual login step.
   `git` HTTPS operations (clone, push, pull) work automatically.
 - If the token changes mid-session (e.g. user reconnects with a new token),
   run `gh auth setup-git` to re-register the credential helper.
-- If `gh` or `git` fails with an authentication error (e.g. "authentication
-  required", "could not read Username", or exit code 128), call
+- **MANDATORY:** You MUST run `gh auth status` before EVER calling
+  `connect_integration(provider="github")`. If it shows `Logged in`,
+  proceed directly — no integration connection needed. Never skip this check.
+- If `gh auth status` shows NOT logged in, or `gh`/`git` fails with an
+  authentication error (e.g. "authentication required", "could not read
+  Username", or exit code 128), THEN call
   `connect_integration(provider="github")` to surface the GitHub credentials
   setup card so the user can connect their account. Once connected, retry
   the operation.
@@ -1,7 +1,6 @@
-"""Tests for agent generation guide — verifies clarification section."""
+"""Tests for prompting helpers."""

 import importlib
-from pathlib import Path

 from backend.copilot import prompting

@@ -31,28 +30,3 @@ class TestGetSdkSupplementStaticPlaceholder:
     def test_e2b_mode_has_no_session_placeholder(self):
         result = prompting.get_sdk_supplement(use_e2b=True)
         assert "<session-id>" not in result
-
-
-class TestAgentGenerationGuideContainsClarifySection:
-    """The agent generation guide must include the clarification section."""
-
-    def test_guide_includes_clarify_section(self):
-        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
-        content = guide_path.read_text(encoding="utf-8")
-        assert "Before or During Building" in content
-
-    def test_guide_mentions_find_block_for_clarification(self):
-        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
-        content = guide_path.read_text(encoding="utf-8")
-        clarify_section = content.split("Before or During Building")[1].split(
-            "### Workflow"
-        )[0]
-        assert "find_block" in clarify_section
-
-    def test_guide_mentions_ask_question_tool(self):
-        guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md"
-        content = guide_path.read_text(encoding="utf-8")
-        clarify_section = content.split("Before or During Building")[1].split(
-            "### Workflow"
-        )[0]
-        assert "ask_question" in clarify_section
@@ -39,6 +39,7 @@ Before running the workflow below, ALWAYS decompose the goal first:
For simple goals (1-2 blocks), keep steps brief (2-3 steps).
For complex goals, use as many steps as needed.


### Workflow for Creating/Editing Agents

1. **If editing**: First narrow to the specific agent by UUID, then fetch its
@@ -13,12 +13,19 @@ from backend.copilot.config import ChatConfig


 def _make_config(**overrides) -> ChatConfig:
-    """Create a ChatConfig with safe defaults, applying *overrides*."""
+    """Create a ChatConfig with safe defaults, applying *overrides*.
+
+    SDK model fields are pinned to anthropic/* so the
+    ``_validate_sdk_model_vendor_compatibility`` model_validator allows
+    construction with ``use_openrouter=False`` (the default here).
+    """
     defaults = {
         "use_claude_code_subscription": False,
         "use_openrouter": False,
         "api_key": None,
         "base_url": None,
+        "thinking_standard_model": "anthropic/claude-sonnet-4-6",
+        "thinking_advanced_model": "anthropic/claude-opus-4-7",
     }
     defaults.update(overrides)
     return ChatConfig(**defaults)
autogpt_platform/backend/backend/copilot/sdk/openrouter_cost.py (new file, 399 lines)
@@ -0,0 +1,399 @@
|
||||
"""Authoritative per-turn cost for OpenRouter-routed SDK generations.
|
||||
|
||||
The Claude Agent SDK CLI's ``ResultMessage.total_cost_usd`` is computed
|
||||
from a static Anthropic pricing table baked into the binary. For
|
||||
non-Anthropic models routed through OpenRouter (e.g. Kimi K2.6) the CLI
|
||||
silently falls back to Sonnet rates — empirically ~5x too high. Even
|
||||
after a rate-card override the estimate is still ~37% off in practice
|
||||
because OpenRouter's own tokenizer counts, reasoning-token rollup, and
|
||||
dated-snapshot pricing tiers can't be reconstructed from what the SDK
|
||||
exposes locally.
|
||||
|
||||
This module provides :func:`record_turn_cost_from_openrouter` — an
|
||||
``asyncio.create_task``-able coroutine that:
|
||||
|
||||
1. Queries ``https://openrouter.ai/api/v1/generation?id=<gen-id>`` for
|
||||
each generation ID captured during the turn.
|
||||
2. Sums the authoritative ``total_cost`` across all rounds.
|
||||
3. Calls :func:`persist_and_record_usage` **once** with the real number,
|
||||
updating both the cost-analytics row and the rate-limit counter.
|
||||
|
||||
If every lookup fails (404 / timeout / parse error), the caller's
|
||||
``fallback_cost_usd`` is recorded instead — keeps the rate-limit counter
|
||||
populated with the best available estimate rather than leaving the turn
|
||||
uncharged.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import httpx
|
||||
|
||||
from backend.copilot.token_tracking import persist_and_record_usage
|
||||
from backend.util import json
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from backend.copilot.model import ChatSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# OpenRouter docs:
|
||||
# https://openrouter.ai/docs/api-reference/get-a-generation
|
||||
_GENERATION_URL = "https://openrouter.ai/api/v1/generation"
|
||||
|
||||
# OpenRouter's generation endpoint indexes the billing row a few seconds
|
||||
# after the SSE stream closes — observed ~8-12s in practice. Retry with
|
||||
# progressive backoff for up to ~30s total before giving up, so the typical
|
||||
# indexing window (~10s) fits inside the retry envelope. Backoff values
|
||||
# in seconds summed: 0.5 + 1 + 2 + 4 + 8 + 15 = 30.5.
|
||||
_MAX_RETRIES = 7
|
||||
_BACKOFF_SECONDS = (0.5, 1.0, 2.0, 4.0, 8.0, 15.0)
|
||||
_REQUEST_TIMEOUT = 10.0
|
||||
|
||||
|
||||
async def _fetch_generation_cost(
|
||||
client: httpx.AsyncClient,
|
||||
gen_id: str,
|
||||
api_key: str,
|
||||
log_prefix: str,
|
||||
) -> float | None:
|
||||
"""Fetch the ``total_cost`` for one generation, with retries.
|
||||
|
||||
Retries only on transient conditions:
|
||||
|
||||
* HTTP 404 — row not yet indexed server-side (typical ~5-10s lag
|
||||
after the SSE stream closes)
|
||||
* HTTP 408 / 429 — timeout / rate limit
|
||||
* HTTP 5xx — transient OpenRouter outage
|
||||
* Network / ``httpx`` exceptions — transport-level retryable
|
||||
|
||||
Fails fast on permanent client errors (401 Unauthorized,
|
||||
403 Forbidden, 400 Bad Request, etc.) since they can't recover
|
||||
within the retry window and would just burn API quota.
|
||||
|
||||
Returns ``None`` when the endpoint reports no data, on a permanent
|
||||
failure, or when every retry attempt hits a transient error.
|
||||
"""
|
||||
headers = {"Authorization": f"Bearer {api_key}"}
|
||||
params = {"id": gen_id}
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(_MAX_RETRIES):
|
||||
if attempt > 0:
|
||||
await asyncio.sleep(_BACKOFF_SECONDS[attempt - 1])
|
||||
try:
|
||||
resp = await client.get(
|
||||
_GENERATION_URL,
|
||||
params=params,
|
||||
headers=headers,
|
||||
timeout=_REQUEST_TIMEOUT,
|
||||
)
|
||||
status = resp.status_code
|
||||
# Fast-fail on permanent client errors — retrying 401/403/400
|
||||
# just burns API quota and delays the fallback.
|
||||
if status in (400, 401, 403):
|
||||
logger.warning(
|
||||
"%s OpenRouter /generation permanent error %d for %s — "
|
||||
"not retrying (check API key / request shape)",
|
||||
log_prefix,
|
||||
status,
|
||||
gen_id,
|
||||
)
|
||||
return None
|
||||
# Transient retryable: 404 (indexing lag), 408 (timeout),
|
||||
# 429 (rate limit), 5xx (server error).
|
||||
if status == 404 or status == 408 or status == 429 or status >= 500:
|
||||
last_error = RuntimeError(f"HTTP {status} on attempt {attempt + 1}")
|
||||
continue
|
||||
# Any other 4xx — treat as permanent.
|
||||
if status >= 400:
|
||||
logger.warning(
|
||||
"%s OpenRouter /generation unexpected status %d for %s — "
|
||||
"not retrying",
|
||||
log_prefix,
|
||||
status,
|
||||
gen_id,
|
||||
)
|
||||
return None
|
||||
payload = resp.json().get("data")
|
||||
if not isinstance(payload, dict):
|
||||
logger.warning(
|
||||
"%s OpenRouter /generation returned no data for %s",
|
||||
log_prefix,
|
||||
gen_id,
|
||||
)
|
||||
return None
|
||||
cost = payload.get("total_cost")
|
||||
if cost is None:
|
||||
logger.warning(
|
||||
"%s OpenRouter /generation response missing total_cost "
|
||||
"for %s (keys=%s)",
|
||||
log_prefix,
|
||||
gen_id,
|
||||
sorted(payload.keys())[:10],
|
||||
)
|
||||
return None
|
||||
return float(cost)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# Network / transport errors are retryable.
|
||||
last_error = exc
|
||||
continue
|
||||
logger.warning(
|
||||
"%s OpenRouter /generation lookup failed for %s after %d attempts: %s",
|
||||
log_prefix,
|
||||
gen_id,
|
||||
_MAX_RETRIES,
|
||||
last_error,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _gen_ids_from_jsonl(path: Path) -> set[str]:
|
||||
"""Extract ``gen-`` message IDs from every assistant entry in a
|
||||
Claude CLI JSONL file.
|
||||
|
||||
Tolerant of malformed lines: single bad JSON object doesn't block
|
||||
the whole file. Also reads ``redacted_thinking`` / ``thinking``
|
||||
entries that share an ID with their parent (via ``jq -u`` in the
|
||||
CLI) and dedups by caller.
|
||||
"""
|
||||
ids: set[str] = set()
|
||||
try:
|
||||
with path.open("r", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
entry = json.loads(line, fallback=None)
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("type") != "assistant":
|
||||
continue
|
||||
message = entry.get("message")
|
||||
if not isinstance(message, dict):
|
||||
continue
|
||||
msg_id = message.get("id")
|
||||
if isinstance(msg_id, str) and msg_id.startswith("gen-"):
|
||||
ids.add(msg_id)
|
||||
except (OSError, UnicodeDecodeError) as exc:
|
||||
logger.debug(
|
||||
"Failed to scan JSONL for gen-IDs: path=%s err=%s",
|
||||
path,
|
||||
exc,
|
||||
)
|
||||
return ids
|
||||
|
||||
|
||||
def _discover_turn_subagent_gen_ids(
|
||||
project_dir: Path,
|
||||
session_id: str,
|
||||
turn_start_ts: float,
|
||||
known: list[str],
|
||||
) -> list[str]:
|
||||
"""Gen-IDs from this session's subagents created during this turn.
|
||||
|
||||
Main-turn LLM rounds (incl. fallback retries) arrive on the live
|
||||
stream as ``AssistantMessage`` and land on ``known`` via
|
||||
``message_id``. What's NOT on ``known`` is the CLI's subagent LLM
|
||||
calls — chiefly auto-compaction, which spawns a fresh JSONL under
|
||||
``<project_dir>/<session_id>/subagents/agent-acompact-*.jsonl``
|
||||
whose gen-IDs never touch our main adapter. OpenRouter bills them
|
||||
anyway, so without this sweep compaction turns under-report cost.
|
||||
|
||||
Scoping: ONLY the current session's subagent dir
|
||||
(``<project_dir>/<session_id>/subagents/agent-*.jsonl``) and ONLY
|
||||
files whose ``mtime >= turn_start_ts``. Without both guards we'd
|
||||
merge prior turns' gen-IDs (main JSONL accumulates forever) and
|
||||
foreign sessions' gen-IDs (the project dir contains every session
|
||||
for this cwd), double-billing the user.
|
||||
|
||||
Also covers non-compaction subagents (Task tool etc.) when the CLI
|
||||
spawns them — their live-stream visibility depends on SDK version,
|
||||
    so the sweep is a safety net. The dedup against ``known`` means
    anything already captured live doesn't double count.

    Preserves ``known`` ordering so main-turn IDs stay first; only
    appends truly new IDs from the sweep.
    """
    merged: list[str] = list(known)
    seen = set(merged)
    subagents_dir = project_dir / session_id / "subagents"
    if not subagents_dir.exists():
        return merged
    try:
        for jsonl in subagents_dir.glob("agent-*.jsonl"):
            try:
                if jsonl.stat().st_mtime < turn_start_ts:
                    continue
            except OSError:
                continue
            for gen_id in _gen_ids_from_jsonl(jsonl):
                if gen_id not in seen:
                    seen.add(gen_id)
                    merged.append(gen_id)
    except OSError as exc:
        logger.debug("Failed to walk subagents dir=%s: %s", subagents_dir, exc)
    return merged


async def record_turn_cost_from_openrouter(
    *,
    session: "ChatSession",
    user_id: str | None,
    model: str | None,
    prompt_tokens: int,
    completion_tokens: int,
    cache_read_tokens: int,
    cache_creation_tokens: int,
    generation_ids: list[str],
    cli_project_dir: str | None,
    cli_session_id: str | None,
    turn_start_ts: float | None,
    fallback_cost_usd: float | None,
    api_key: str | None,
    log_prefix: str,
) -> None:
    """Persist turn cost from OpenRouter's authoritative ``/generation``.

    Writes a single cost-analytics row via :func:`persist_and_record_usage`
    — same method used for the Anthropic-direct sync path — so the
    cost-log append and rate-limit counter stay consistent. No double
    counting: the caller skips its own sync persist for non-Anthropic
    OpenRouter turns and defers entirely to this task.

    Launched via ``asyncio.create_task`` from the stream ``finally`` block
    so the ~500-2000ms ``/generation`` indexing delay doesn't add latency
    to the turn. During that window the rate-limit counter is briefly
    unaware of the turn's cost; back-to-back turns in that sub-second
    gap see a stale counter. Acceptable tradeoff — the alternative
    (writing a possibly-wrong estimate synchronously) creates a
    double-count when the reconcile delta arrives.

    Fallback semantics: if every generation lookup fails, records
    ``fallback_cost_usd`` instead so the rate-limit counter isn't left
    completely empty. Keeps behaviour at-worst equivalent to the
    rate-card estimate that came before this task existed.
    """
    if not api_key:
        logger.debug(
            "%s OpenRouter cost record skipped: no API key available",
            log_prefix,
        )
        return

    # Merge in any gen-IDs from CLI subagent JSONLs the live stream
    # didn't surface — chiefly SDK-internal compaction, which spawns a
    # summarisation LLM call under
    # ``<project_dir>/<cli_session_id>/subagents/...`` that OpenRouter
    # bills but doesn't emit via our main adapter. Safe no-op when no
    # compaction happened (no subagent files created this turn) or the
    # CLI wrote nothing there.
    #
    # The sweep is SESSION-scoped (``<cli_session_id>/subagents/``, not
    # the whole project dir) and TURN-scoped (mtime >= turn_start_ts).
    # Both guards are load-bearing: the project dir contains every
    # session for this cwd, and subagent files persist across turns,
    # so an unscoped sweep would re-bill prior turns and foreign
    # sessions' gen-IDs.
    if cli_project_dir and cli_session_id and turn_start_ts is not None:
        merged_ids = _discover_turn_subagent_gen_ids(
            Path(os.path.expanduser(cli_project_dir)),
            cli_session_id,
            turn_start_ts,
            generation_ids,
        )
        if len(merged_ids) != len(generation_ids):
            logger.info(
                "%s[cost-record] discovered %d additional gen-IDs in "
                "session subagents (compaction / Task) — reconcile "
                "covers all",
                log_prefix,
                len(merged_ids) - len(generation_ids),
            )
            generation_ids = merged_ids

    if not generation_ids:
        return

    try:
        async with httpx.AsyncClient() as client:
            tasks = [
                _fetch_generation_cost(client, gen_id, api_key, log_prefix)
                for gen_id in generation_ids
            ]
            results = await asyncio.gather(*tasks, return_exceptions=False)
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "%s OpenRouter cost record failed to fetch any generation "
            "(falling back to rate-card estimate): %s",
            log_prefix,
            exc,
        )
        results = []

    fetched = [r for r in results if isinstance(r, (int, float))]
    if fetched and len(fetched) == len(generation_ids):
        real_cost: float | None = sum(fetched)
        # Log real (OpenRouter billed) vs CLI rate-card estimate so an
        # operator can spot divergence without querying OpenRouter by
        # hand. Under-count typically means a gen-ID source we don't
        # capture live (e.g. title model, background LLM calls running
        # outside the main stream); over-count means the CLI's rate
        # table is stale vs. OpenRouter's current pricing.
        delta_pct: float | None = None
        if fallback_cost_usd and fallback_cost_usd > 0:
            delta_pct = (real_cost - fallback_cost_usd) / fallback_cost_usd * 100
        logger.info(
            "%s[cost-record] OpenRouter real=$%.6f cli_estimate=$%s "
            "delta=%s (gen_ids=%d)",
            log_prefix,
            real_cost,
            f"{fallback_cost_usd:.6f}" if fallback_cost_usd is not None else "?",
            f"{delta_pct:+.1f}%" if delta_pct is not None else "n/a",
            len(generation_ids),
        )
    else:
        real_cost = fallback_cost_usd
        if fetched:
            # Partial success: some lookups returned a cost, others didn't.
            # Trusting the partial sum would under-report; fall back to the
            # estimate so rate-limit enforcement stays conservative.
            logger.warning(
                "%s[cost-record] OpenRouter partial lookup (%d/%d) — "
                "using fallback estimate=$%s",
                log_prefix,
                len(fetched),
                len(generation_ids),
                real_cost,
            )
        else:
            logger.warning(
                "%s[cost-record] OpenRouter lookup failed for all gens — "
                "using fallback estimate=$%s",
                log_prefix,
                real_cost,
            )

    try:
        await persist_and_record_usage(
            session=session,
            user_id=user_id,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cache_read_tokens=cache_read_tokens,
            cache_creation_tokens=cache_creation_tokens,
            log_prefix=f"{log_prefix}[cost-record]",
            cost_usd=real_cost,
            model=model,
            provider="open_router",
        )
    except Exception as exc:  # noqa: BLE001
        logger.warning(
            "%s[cost-record] failed to persist: %s",
            log_prefix,
            exc,
        )
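The call site that launches this task sits outside the hunk above, so for orientation, here is a minimal sketch of the fire-and-forget pattern the docstring describes. The `stream_turn` shape, the `state` fields, and the `background_tasks` set are illustrative assumptions, not this repo's actual call site; only `record_turn_cost_from_openrouter` and its keyword arguments come from the code above.

```python
import asyncio
import time


async def stream_turn(state) -> None:
    """Hypothetical stream driver showing the fire-and-forget launch."""
    turn_start_ts = time.time()
    try:
        async for event in state.adapter_events():
            await state.emit(event)
    finally:
        # Off the hot path: the ~0.5-2s /generation indexing delay never
        # delays the end of the user's turn.
        task = asyncio.create_task(
            record_turn_cost_from_openrouter(
                session=state.session,
                user_id=state.user_id,
                model=state.model,
                prompt_tokens=state.usage.prompt_tokens,
                completion_tokens=state.usage.completion_tokens,
                cache_read_tokens=state.usage.cache_read_tokens,
                cache_creation_tokens=state.usage.cache_creation_tokens,
                generation_ids=list(state.generation_ids),
                cli_project_dir=state.cli_project_dir,
                cli_session_id=state.cli_session_id,
                turn_start_ts=turn_start_ts,
                fallback_cost_usd=state.estimated_cost_usd,
                api_key=state.openrouter_api_key,
                log_prefix=f"[sess={state.session.session_id}]",
            )
        )
        # Hold a reference so the task isn't garbage-collected mid-flight.
        state.background_tasks.add(task)
        task.add_done_callback(state.background_tasks.discard)
```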
@@ -0,0 +1,520 @@
"""Unit tests for SDK-path OpenRouter cost recording."""

from __future__ import annotations

from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch

import httpx
import pytest

from backend.copilot.model import ChatSession
from backend.copilot.sdk.openrouter_cost import record_turn_cost_from_openrouter


def _session() -> ChatSession:
    now = datetime.now(UTC)
    return ChatSession(
        session_id="sess-1",
        user_id="user-1",
        usage=[],
        started_at=now,
        updated_at=now,
        messages=[],
    )


def _mock_generation_response(cost: float) -> dict:
    return {
        "data": {
            "total_cost": cost,
            "native_tokens_prompt": 1000,
            "native_tokens_completion": 50,
            "tokens_prompt": 1100,
            "tokens_completion": 60,
        }
    }


class TestRecordTurnCostFromOpenRouter:
    """Single-write semantics: the cost + rate-limit counter is updated
    exactly once per turn via this background task. The sync path at
    the call site is already skipped for non-Anthropic OpenRouter turns,
    so there's no double-counting path even on partial failure."""

    @pytest.mark.asyncio
    async def test_empty_generation_ids_no_op(self):
        """Direct-Anthropic turn produces no gen-IDs — task is a no-op."""
        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get,
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="anthropic/claude-sonnet-4.6",
                prompt_tokens=10,
                completion_tokens=5,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=[],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=0.05,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_not_called()
        mock_get.assert_not_called()

    @pytest.mark.asyncio
    async def test_missing_api_key_no_op(self):
        """Without an OpenRouter API key we can't query the endpoint — skip."""
        with patch(
            "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
            new_callable=AsyncMock,
        ) as mock_persist:
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=10,
                completion_tokens=5,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-1"],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=0.02,
                api_key=None,
                log_prefix="[test]",
            )
        mock_persist.assert_not_called()

    @pytest.mark.asyncio
    async def test_single_generation_records_real_cost(self):
        """Authoritative cost from OpenRouter is the value recorded — no
        reliance on the fallback estimate."""
        real_cost = 0.02900595

        async def _get(self, url, **kwargs):  # noqa: ARG001
            return httpx.Response(200, json=_mock_generation_response(real_cost))

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=29669,
                completion_tokens=280,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-1776842410"],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=0.01858,  # rate-card estimate, deliberately wrong
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        kwargs = mock_persist.call_args.kwargs
        assert kwargs["cost_usd"] == pytest.approx(real_cost, rel=1e-9)
        assert kwargs["prompt_tokens"] == 29669
        assert kwargs["completion_tokens"] == 280
        assert kwargs["provider"] == "open_router"
        assert kwargs["model"] == "moonshotai/kimi-k2.6"

    @pytest.mark.asyncio
    async def test_multi_round_turn_sums_costs(self):
        """Tool-use turn has N generation IDs; the real cost is the sum
        of ``total_cost`` across all rounds — recorded in a single row."""
        costs_by_id = {"gen-a": 0.029, "gen-b": 0.030}

        async def _get(self, url, **kwargs):  # noqa: ARG001
            gen_id = kwargs.get("params", {}).get("id")
            return httpx.Response(
                200, json=_mock_generation_response(costs_by_id[gen_id])
            )

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=60000,
                completion_tokens=600,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-a", "gen-b"],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=0.037,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        cost = mock_persist.call_args.kwargs["cost_usd"]
        assert cost == pytest.approx(sum(costs_by_id.values()), rel=1e-9)

    @pytest.mark.asyncio
    async def test_partial_lookup_falls_back_to_estimate(self):
        """If only some gen-IDs resolve, summing them would under-report.
        Fall back to the caller's estimate and log — the rate-limit
        counter stays populated with the best available number."""
        fallback = 0.05
        seq = iter(
            [
                httpx.Response(200, json=_mock_generation_response(0.03)),
                httpx.Response(404, text="not found"),
                httpx.Response(404, text="not found"),
                httpx.Response(404, text="not found"),
                httpx.Response(404, text="not found"),
            ]
        )

        async def _get(self, *args, **kwargs):  # noqa: ARG001
            return next(seq)

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=1000,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-a", "gen-b"],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=fallback,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        assert mock_persist.call_args.kwargs["cost_usd"] == fallback
    @pytest.mark.asyncio
    async def test_fast_fail_on_401_no_retries(self):
        """Permanent client errors (401/403/400) must not be retried —
        burning the 30s retry window on an unauthenticated request wastes
        API quota and delays the fallback."""
        call_count = {"n": 0}

        async def _get(self, *args, **kwargs):  # noqa: ARG001
            call_count["n"] += 1
            return httpx.Response(401, text="unauthorized")

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=1000,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-a"],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=0.02,
                api_key="sk-bad",
                log_prefix="[test]",
            )
        # Only one call — no retries.
        assert call_count["n"] == 1
        # Fallback was recorded (lookup failed → keep rate-limit counter live).
        mock_persist.assert_called_once()
        assert mock_persist.call_args.kwargs["cost_usd"] == 0.02

    @pytest.mark.asyncio
    async def test_retries_on_404_then_succeeds(self):
        """Indexing lag: endpoint returns 404 initially, then 200 once the
        billing row is indexed. The retry budget should be spent riding
        out transient states rather than giving up on the first 404."""
        seq = iter(
            [
                httpx.Response(404, text="not found"),
                httpx.Response(200, json=_mock_generation_response(0.025)),
            ]
        )

        async def _get(self, *args, **kwargs):  # noqa: ARG001
            return next(seq)

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=1000,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-a"],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=0.05,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(0.025)

    @pytest.mark.asyncio
    async def test_complete_lookup_failure_falls_back_to_estimate(self):
        """Every lookup fails → record the estimate so the rate-limit
        counter isn't left empty. At-worst parity with the pre-task
        behaviour."""
        fallback = 0.02

        async def _get(self, *args, **kwargs):  # noqa: ARG001
            raise httpx.ConnectError("no network")

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=1000,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-a"],
                cli_project_dir=None,
                cli_session_id=None,
                turn_start_ts=None,
                fallback_cost_usd=fallback,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        assert mock_persist.call_args.kwargs["cost_usd"] == fallback

    @pytest.mark.asyncio
    async def test_compaction_subagent_gen_ids_are_swept(self, tmp_path):
        """CLI-internal compaction spawns a subagent JSONL under
        ``<project_dir>/<session_id>/subagents/agent-acompact-*.jsonl``
        whose gen-IDs the live adapter never surfaces. When
        ``cli_project_dir`` + ``cli_session_id`` + ``turn_start_ts``
        are supplied the reconcile walks only THIS session's subagents
        and discovers the compaction IDs."""
        session_id = "sess-abc"
        sub_dir = tmp_path / session_id / "subagents"
        sub_dir.mkdir(parents=True)
        (sub_dir / "agent-acompact-xyz.jsonl").write_text(
            '{"type":"assistant","message":{"id":"gen-compact-1","content":[]}}\n'
            '{"type":"assistant","message":{"id":"gen-compact-2","content":[]}}\n'
        )

        costs_by_id = {
            "gen-main-1": 0.020,
            "gen-compact-1": 0.005,
            "gen-compact-2": 0.003,
        }

        async def _get(self, *args, **kwargs):  # noqa: ARG001
            gen_id = kwargs.get("params", {}).get("id")
            return httpx.Response(
                200, json=_mock_generation_response(costs_by_id[gen_id])
            )

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="anthropic/claude-opus-4.7",
                prompt_tokens=1000,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-main-1"],
                cli_project_dir=str(tmp_path),
                cli_session_id=session_id,
                turn_start_ts=0.0,
                fallback_cost_usd=0.05,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(
            sum(costs_by_id.values()), rel=1e-9
        )

    @pytest.mark.asyncio
    async def test_compaction_sweep_no_subagents_is_noop(self, tmp_path):
        """No compaction happened → reconcile uses only the caller's
        gen-IDs, same as when cli_project_dir is None."""
        session_id = "sess-none"
        (tmp_path / session_id).mkdir()

        async def _get(self, *args, **kwargs):  # noqa: ARG001
            return httpx.Response(200, json=_mock_generation_response(0.02))

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=1000,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-main-1"],
                cli_project_dir=str(tmp_path),
                cli_session_id=session_id,
                turn_start_ts=0.0,
                fallback_cost_usd=0.05,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(0.02)

    @pytest.mark.asyncio
    async def test_compaction_sweep_ignores_prior_turn_and_foreign_sessions(
        self, tmp_path
    ):
        """Scoping guards the sweep from double-billing: a stale subagent
        file from a prior turn (mtime before ``turn_start_ts``) and any
        subagent from a foreign session (different session_id folder)
        must BOTH be skipped. Without either guard, a long-running
        session with past compactions would re-bill every prior turn,
        and a second chat session in the same cwd would inherit the
        first session's compaction cost."""
        import os
        import time

        this_session = "sess-current"
        other_session = "sess-other"

        this_subagents = tmp_path / this_session / "subagents"
        this_subagents.mkdir(parents=True)
        other_subagents = tmp_path / other_session / "subagents"
        other_subagents.mkdir(parents=True)

        # Prior-turn compaction file — same session, stale mtime.
        stale_file = this_subagents / "agent-acompact-stale.jsonl"
        stale_file.write_text(
            '{"type":"assistant","message":{"id":"gen-stale-1","content":[]}}\n'
        )
        # Foreign session's compaction file.
        foreign_file = other_subagents / "agent-acompact-foreign.jsonl"
        foreign_file.write_text(
            '{"type":"assistant","message":{"id":"gen-foreign-1","content":[]}}\n'
        )
        # Current-turn compaction file — fresh.
        fresh_file = this_subagents / "agent-acompact-fresh.jsonl"
        fresh_file.write_text(
            '{"type":"assistant","message":{"id":"gen-fresh-1","content":[]}}\n'
        )

        # turn_start_ts lies between the stale and fresh mtimes.
        past = time.time() - 3600
        os.utime(stale_file, (past, past))
        os.utime(foreign_file, (past, past))
        turn_start_ts = time.time() - 60  # 1 min ago
        fresh_now = time.time()
        os.utime(fresh_file, (fresh_now, fresh_now))

        costs_by_id = {
            "gen-main-1": 0.010,
            "gen-fresh-1": 0.004,
        }

        async def _get(self, *args, **kwargs):  # noqa: ARG001
            gen_id = kwargs.get("params", {}).get("id")
            # If the sweep leaks a stale/foreign ID, the assert below fails
            # with a clear message (rather than a bare KeyError from the
            # cost lookup) instead of silently over-billing.
            assert gen_id in costs_by_id, f"sweep leaked out-of-scope gen_id {gen_id}"
            return httpx.Response(
                200, json=_mock_generation_response(costs_by_id[gen_id])
            )

        with (
            patch(
                "backend.copilot.sdk.openrouter_cost.persist_and_record_usage",
                new_callable=AsyncMock,
            ) as mock_persist,
            patch("httpx.AsyncClient.get", new=_get),
        ):
            await record_turn_cost_from_openrouter(
                session=_session(),
                user_id="u1",
                model="moonshotai/kimi-k2.6",
                prompt_tokens=1000,
                completion_tokens=10,
                cache_read_tokens=0,
                cache_creation_tokens=0,
                generation_ids=["gen-main-1"],
                cli_project_dir=str(tmp_path),
                cli_session_id=this_session,
                turn_start_ts=turn_start_ts,
                fallback_cost_usd=0.05,
                api_key="sk-or-test",
                log_prefix="[test]",
            )
        mock_persist.assert_called_once()
        # Exactly the current-turn main + fresh compaction — no stale,
        # no foreign.
        assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(
            sum(costs_by_id.values()), rel=1e-9
        )
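The retry contract pinned down by `test_fast_fail_on_401_no_retries` and `test_retries_on_404_then_succeeds` belongs to `_fetch_generation_cost`, whose body is not part of this hunk. Below is a hedged sketch of one policy that would satisfy those tests; the budget and interval constants are illustrative assumptions, and only the `/generation` endpoint and `total_cost` field are taken from the tests above.

```python
import asyncio
import time

import httpx

_PERMANENT = {400, 401, 403}  # fast-fail statuses: retrying cannot succeed
_BUDGET_S = 30.0              # illustrative overall retry budget
_INTERVAL_S = 2.0             # illustrative pause between attempts


async def fetch_generation_cost_sketch(
    client: httpx.AsyncClient, gen_id: str, api_key: str
) -> float | None:
    """Illustrative policy: 404 is /generation indexing lag (retry);
    permanent 4xx fails fast; exhaustion returns None so the caller
    falls back to the rate-card estimate."""
    deadline = time.monotonic() + _BUDGET_S
    while True:
        resp = await client.get(
            "https://openrouter.ai/api/v1/generation",
            params={"id": gen_id},
            headers={"Authorization": f"Bearer {api_key}"},
        )
        if resp.status_code == 200:
            return resp.json()["data"]["total_cost"]
        if resp.status_code in _PERMANENT or time.monotonic() >= deadline:
            return None
        await asyncio.sleep(_INTERVAL_S)  # 404 / 5xx: try again until the deadline
```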
@@ -10,12 +10,19 @@ from backend.copilot.constants import is_transient_api_error


def _make_config(**overrides) -> ChatConfig:
    """Create a ChatConfig with safe defaults, applying *overrides*."""
    """Create a ChatConfig with safe defaults, applying *overrides*.

    SDK model fields are pinned to anthropic/* so the
    ``_validate_sdk_model_vendor_compatibility`` model_validator allows
    construction with ``use_openrouter=False`` (the default here).
    """
    defaults = {
        "use_claude_code_subscription": False,
        "use_openrouter": False,
        "api_key": None,
        "base_url": None,
        "thinking_standard_model": "anthropic/claude-sonnet-4-6",
        "thinking_advanced_model": "anthropic/claude-opus-4-7",
    }
    defaults.update(overrides)
    return ChatConfig(**defaults)
@@ -39,8 +46,11 @@ class TestResolveFallbackModel:

        assert _resolve_fallback_model() is None

    def test_strips_provider_prefix(self):
        """OpenRouter-style 'anthropic/claude-sonnet-4-...' is stripped."""
    def test_keeps_full_slug_when_openrouter_active(self):
        """OpenRouter routes by ``vendor/model`` slug — _normalize_model_name
        now preserves the prefix when openrouter_active is True so non-
        Anthropic vendors stay routable. Anthropic slugs are passed
        through unchanged in this mode (PR #12878)."""
        cfg = _make_config(
            claude_agent_fallback_model="anthropic/claude-sonnet-4-20250514",
            use_openrouter=True,
@@ -52,8 +62,7 @@ class TestResolveFallbackModel:

        result = _resolve_fallback_model()

        assert result == "claude-sonnet-4-20250514"
        assert "/" not in result
        assert result == "anthropic/claude-sonnet-4-20250514"

    def test_dots_replaced_for_direct_anthropic(self):
        """Direct Anthropic requires hyphen-separated versions."""
@@ -714,11 +723,16 @@ class TestDoTransientBackoff:
        mock_sleep.assert_called_once_with(7)

    async def test_replaces_adapter_with_new_instance(self):
        """state.adapter is replaced with a new SDKResponseAdapter after yield."""
        """state.adapter is replaced with a new SDKResponseAdapter after yield,
        and ``render_reasoning_in_ui`` is threaded from the SDK service config
        (not hardcoded) so ``CHAT_RENDER_REASONING_IN_UI=false`` at runtime
        flips the reconstruction consistently with the rest of the path."""
        from unittest.mock import AsyncMock, MagicMock, patch

        from backend.copilot.sdk.service import _do_transient_backoff

        cfg = _make_config(render_reasoning_in_ui=False)

        original_adapter = MagicMock()
        state = MagicMock()
        state.adapter = original_adapter
@@ -726,6 +740,7 @@ class TestDoTransientBackoff:

        with (
            patch("asyncio.sleep", new=AsyncMock()),
            patch(f"{_SVC}.config", cfg),
            patch("backend.copilot.sdk.service.SDKResponseAdapter") as mock_cls,
        ):
            new_adapter = MagicMock()
@@ -733,7 +748,11 @@ class TestDoTransientBackoff:
            async for _ in _do_transient_backoff(3, state, "msg-1", "sess-1"):
                pass

        mock_cls.assert_called_once_with(message_id="msg-1", session_id="sess-1")
        mock_cls.assert_called_once_with(
            message_id="msg-1",
            session_id="sess-1",
            render_reasoning_in_ui=False,
        )
        assert state.adapter is new_adapter

    async def test_resets_usage_after_yield(self):
@@ -7,12 +7,15 @@ the frontend expects.

import json
import logging
import time
import uuid
from typing import Any

from claude_agent_sdk import (
    AssistantMessage,
    Message,
    ResultMessage,
    StreamEvent,
    SystemMessage,
    TextBlock,
    ThinkingBlock,
@@ -46,6 +49,16 @@ from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output
logger = logging.getLogger(__name__)


# Coalescing thresholds for ``thinking_delta`` events on the SDK partial
# stream — matches the baseline window (see
# ``baseline/reasoning.py::_COALESCE_MIN_CHARS``). Anthropic's extended-
# thinking channel emits ~1 event per token (~4,700 per Kimi K2.6 turn);
# a 64-char / 50 ms window halves the event rate vs 32/40 while staying
# well under the ~100 ms perceptual threshold.
_THINKING_COALESCE_MIN_CHARS = 64
_THINKING_COALESCE_MAX_INTERVAL_MS = 50.0


class SDKResponseAdapter:
    """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format.
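A quick sanity check of the rate arithmetic in that comment; every figure here is the comment's own estimate (plus an assumed average token length), not a measurement from this repo.

```python
# Back-of-envelope for the coalescing window (figures from the comment above).
deltas_per_turn = 4_700   # ~1 thinking_delta per token on a large Kimi K2.6 turn
chars_per_delta = 4       # assumed average token length in characters
window_chars = 64

wire_events = deltas_per_turn * chars_per_delta / window_chars
print(f"{deltas_per_turn} raw deltas -> ~{wire_events:.0f} coalesced flushes")
# A 32-char window would flush twice as often (~588 vs ~294 wire events),
# which is the "halves the event rate vs 32/40" claim.
```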
@@ -53,7 +66,13 @@ class SDKResponseAdapter:
    text blocks, tool calls, and message lifecycle.
    """

    def __init__(self, message_id: str | None = None, session_id: str | None = None):
    def __init__(
        self,
        message_id: str | None = None,
        session_id: str | None = None,
        *,
        render_reasoning_in_ui: bool = True,
    ):
        self.message_id = message_id or str(uuid.uuid4())
        self.session_id = session_id
        self.text_block_id = str(uuid.uuid4())
@@ -62,6 +81,7 @@ class SDKResponseAdapter:
        self.reasoning_block_id = str(uuid.uuid4())
        self.has_started_reasoning = False
        self.has_ended_reasoning = True
        self.render_reasoning_in_ui = render_reasoning_in_ui
        self.current_tool_calls: dict[str, dict[str, str]] = {}
        self.resolved_tool_calls: set[str] = set()
        self.step_open = False
@@ -74,6 +94,36 @@ class SDKResponseAdapter:
        # case so the turn renders as cleanly complete.
        self._text_since_last_tool_result = False
        self._any_tool_results_seen = False
        # --- Partial-message streaming state (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES)
        # When ``include_partial_messages=True`` is set on
        # ``ClaudeAgentOptions``, the CLI emits raw Anthropic streaming
        # events (``content_block_start`` / ``content_block_delta`` /
        # ``content_block_stop``) as ``StreamEvent`` messages ahead of
        # each summary ``AssistantMessage``. We consume those for
        # per-token wire emission and reconcile against the summary to
        # catch any tail content the partial stream missed (short blocks
        # the CLI emits summary-only, OpenRouter proxy quirks, encrypted
        # thinking).
        #
        self._block_types_by_index: dict[int, str] = {}
        # Running partial-stream buffers. Summary AssistantMessages can
        # arrive *before* the corresponding ``content_block_stop`` event
        # (the CLI flushes the summary as soon as the block is complete
        # on the provider side, with the stop event following as a
        # separate frame). Reconcile-by-index therefore can't rely on
        # completed-block queues — instead we maintain running buffers
        # of all partial output of each type, and each summary block of
        # that type consumes its prefix. This also trivially handles
        # Kimi K2.6's pattern of emitting each content block as its own
        # summary AssistantMessage: Python list indices don't align
        # with Anthropic content_block indices, but per-type order does.
        self._partial_text_buffer: str = ""
        self._partial_thinking_buffer: str = ""
        # Coalescing buffer for ``thinking_delta`` — text_delta is
        # naturally coarser so we let it through unbuffered.
        self._pending_thinking_delta: str = ""
        self._pending_thinking_index: int | None = None
        self._last_thinking_flush_monotonic: float = 0.0

    @property
    def has_unresolved_tool_calls(self) -> bool:
@@ -99,6 +149,15 @@ class SDKResponseAdapter:
            # produced (task_progress events were previously silent).
            responses.append(StreamHeartbeat())

        elif isinstance(sdk_message, StreamEvent):
            # Raw Anthropic streaming events — only delivered when
            # ``include_partial_messages=True`` is set on
            # ``ClaudeAgentOptions`` (gated by
            # ``config.sdk_include_partial_messages``). Drives per-token
            # emission of text + thinking; tool_use and other structural
            # events stay on the ``AssistantMessage`` path.
            self._handle_stream_event(sdk_message, responses)

        elif isinstance(sdk_message, AssistantMessage):
            # Flush any SDK built-in tool calls that didn't get a UserMessage
            # result (e.g. WebSearch, Read handled internally by the CLI).
@@ -115,18 +174,43 @@ class SDKResponseAdapter:
                responses.append(StreamStartStep())
                self.step_open = True

            for block in sdk_message.content:
            # Hoist ThinkingBlocks to the front of the iteration so the UI
            # sees reasoning *before* the answer it produced — that's the
            # natural reading order and the way Anthropic models emit them.
            # OpenRouter passthrough providers (Moonshot/Kimi, DeepSeek)
            # often place ``reasoning`` after the visible text in the
            # response, which would make ``ReasoningCollapse`` render under
            # the assistant message instead of above it. ToolUse and other
            # block types stay in their original relative order so tool
            # call sequences remain coherent.
            #
            # Note: when ``include_partial_messages=True`` is active the
            # per-token stream already emitted reasoning + text in their
            # natural on-the-wire order via ``_handle_stream_event``. The
            # summary walk below falls through to ``_emit_text_tail`` /
            # ``_emit_thinking_tail`` which emit only the diff, preserving
            # that ordering without duplicating content.
            blocks_with_idx = sorted(
                enumerate(sdk_message.content),
                key=lambda pair: 0 if isinstance(pair[1], ThinkingBlock) else 1,
            )

            for block_index, block in blocks_with_idx:
                if isinstance(block, TextBlock):
                    if block.text:
                        # Reasoning and text are distinct UI parts; close
                        # any open reasoning block before opening text so
                        # the AI SDK transport doesn't merge them.
                    # Reasoning and text are distinct UI parts; close any
                    # open reasoning block before opening text so the AI
                    # SDK transport doesn't merge them.
                    tail = self._text_tail_for_summary_block(block.text)
                    if tail:
                        self._end_reasoning_if_open(responses)
                        self._ensure_text_started(responses)
                        responses.append(
                            StreamTextDelta(id=self.text_block_id, delta=block.text)
                            StreamTextDelta(id=self.text_block_id, delta=tail)
                        )
                        self._text_since_last_tool_result = True
                    elif block.text:
                        # Partial stream already emitted the full text.
                        self._text_since_last_tool_result = True

                elif isinstance(block, ThinkingBlock):
                    # Stream extended_thinking content as a reasoning
@@ -142,13 +226,38 @@ class SDKResponseAdapter:
                    # it live, extended_thinking turns that end
                    # thinking-only left the UI stuck on "Thought for Xs"
                    # with nothing rendered until a page refresh.
                    if block.thinking:
                    #
                    # When ``render_reasoning_in_ui=False`` the three
                    # reasoning helpers below (and the append) no-op, so
                    # the frontend sees a text-only stream AND no
                    # ``ChatMessage(role='reasoning')`` row is persisted
                    # (the row is only created by ``_dispatch_response``
                    # when ``StreamReasoningStart`` arrives, which is
                    # suppressed here). Persistence of the thinking text
                    # into the SDK transcript via
                    # ``_format_sdk_content_blocks`` is unaffected — that
                    # feeds ``--resume`` continuity, not the UI.
                    #
                    # Flush any pending coalesce buffer to the wire BEFORE
                    # computing the tail — otherwise a summary that
                    # arrives between the last partial delta and the
                    # ``content_block_stop`` event (race: summary is
                    # flushed by the CLI as soon as the block is complete
                    # provider-side, with stop lagging as a separate
                    # frame) would see ``_partial_thinking_buffer``
                    # missing the pending prefix, and
                    # ``_thinking_tail_for_summary_block`` would emit the
                    # full block — duplicating the tail that
                    # ``_end_reasoning_if_open`` still drains on stop.
                    self._flush_pending_thinking(responses)
                    tail = self._thinking_tail_for_summary_block(block.thinking)
                    if tail:
                        self._end_text_if_open(responses)
                        self._ensure_reasoning_started(responses)
                        responses.append(
                            StreamReasoningDelta(
                                id=self.reasoning_block_id,
                                delta=block.thinking,
                                delta=tail,
                            )
                        )

@@ -158,7 +267,7 @@ class SDKResponseAdapter:

                    # Strip MCP prefix so frontend sees "find_block"
                    # instead of "mcp__copilot__find_block".
                    tool_name = block.name.removeprefix(MCP_TOOL_PREFIX)
                    tool_name = block.name.strip().removeprefix(MCP_TOOL_PREFIX)

                    responses.append(
                        StreamToolInputStart(toolCallId=block.id, toolName=tool_name)
@@ -347,8 +456,12 @@ class SDKResponseAdapter:
        """Start (or restart) a reasoning block if needed.

        Each ``ThinkingBlock`` the SDK emits gets its own streaming block
        on the wire so the frontend can render a new ``Reasoning`` part
        per LLM turn (rather than concatenating across the whole session).
        so the frontend can render a new ``Reasoning`` part per LLM turn
        (rather than concatenating across the whole session). Events
        are emitted unconditionally — the caller filters them out of the
        SSE wire when ``render_reasoning_in_ui=False`` but still feeds
        them through ``_dispatch_response`` so the session transcript
        keeps a ``role='reasoning'`` row.
        """
        if not self.has_started_reasoning or self.has_ended_reasoning:
            if self.has_ended_reasoning:
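The wire-level suppression that this docstring and the later adapter tests point at lives in `backend/copilot/sdk/service.py::_filter_reasoning_events`, which is not included in this diff. Here is a hedged sketch of the shape such a filter could take; the event classes are the ones the adapter emits, everything else is assumption.

```python
from backend.copilot.response_model import (
    StreamBaseResponse,
    StreamReasoningDelta,
    StreamReasoningEnd,
    StreamReasoningStart,
)

_REASONING_EVENTS = (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd)


def _filter_reasoning_events_sketch(
    events: list[StreamBaseResponse], render_reasoning_in_ui: bool
) -> list[StreamBaseResponse]:
    """Drop reasoning events from the SSE wire only; persistence of the
    reasoning row happens upstream, before this filter runs."""
    if render_reasoning_in_ui:
        return events
    return [e for e in events if not isinstance(e, _REASONING_EVENTS)]
```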
@@ -358,11 +471,238 @@ class SDKResponseAdapter:
            self.has_started_reasoning = True

    def _end_reasoning_if_open(self, responses: list[StreamBaseResponse]) -> None:
        """End the current reasoning block if one is open."""
        """End the current reasoning block if one is open.

        Drains any buffered thinking_delta text so the tail isn't lost
        when the block closes before the coalesce window elapses.
        """
        if self.has_started_reasoning and not self.has_ended_reasoning:
            if self._pending_thinking_delta:
                responses.append(
                    StreamReasoningDelta(
                        id=self.reasoning_block_id,
                        delta=self._pending_thinking_delta,
                    )
                )
                self._partial_thinking_buffer += self._pending_thinking_delta
                self._pending_thinking_delta = ""
                self._pending_thinking_index = None
            responses.append(StreamReasoningEnd(id=self.reasoning_block_id))
            self.has_ended_reasoning = True

    # ------------------------------------------------------------------
    # Partial-message streaming (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES)
    # ------------------------------------------------------------------

    def _reset_partial_stream_state(self) -> None:
        """Clear per-message partial-stream state.

        Anthropic's ``content_block`` indices are scoped to a single
        message — when a fresh ``message_start`` event arrives (new
        ``AssistantMessage`` turn) the maps must reset so indices from
        the previous message don't suppress genuine content in the new
        one.

        Also clears ``_partial_*_buffer``: multi-round turns (tool use)
        emit a ``message_start`` per LLM round, and leftover prefix
        content from round N would cause the summary walk in round N+1
        to either match the wrong prefix (silently dropping new content)
        or diverge and fall back to re-emitting the whole block.
        """
        self._block_types_by_index = {}
        self._partial_text_buffer = ""
        self._partial_thinking_buffer = ""
        self._pending_thinking_delta = ""
        self._pending_thinking_index = None

    def _text_tail_for_summary_block(self, full_text: str) -> str:
        """Reconcile the next summary ``TextBlock`` against the running
        partial-stream buffer.

        The CLI can emit the summary ``AssistantMessage`` before the
        matching ``content_block_stop`` event, so we can't rely on a
        queue of completed blocks. Instead we maintain
        ``_partial_text_buffer`` — the concatenation of every
        ``text_delta`` chunk that hasn't been claimed by a summary
        block yet — and consume ``full_text`` as a prefix from it.
        Summary blocks that have no partial backing (buffer empty)
        emit their full text; blocks that partial covered wholly are
        silent; blocks with a partial prefix + a summary tail emit
        only the tail. Kimi K2.6's pattern of emitting each content
        block as its own summary ``AssistantMessage`` is handled
        automatically because block order is preserved across both
        streams.
        """
        if not full_text:
            return ""
        if not self._partial_text_buffer:
            return full_text
        if full_text.startswith(self._partial_text_buffer):
            tail = full_text[len(self._partial_text_buffer) :]
            self._partial_text_buffer = ""
            return tail
        if self._partial_text_buffer.startswith(full_text):
            # Partial already emitted this whole block plus more — the
            # "more" belongs to a later summary block. Consume only the
            # prefix matching this block and leave the rest buffered.
            self._partial_text_buffer = self._partial_text_buffer[len(full_text) :]
            return ""
        logger.warning(
            "SDK partial/summary text diverged "
            "(partial_buf=%d chars, summary=%d chars) — emitting summary, "
            "clearing partial buffer to recover",
            len(self._partial_text_buffer),
            len(full_text),
        )
        self._partial_text_buffer = ""
        return full_text

    def _thinking_tail_for_summary_block(self, full_thinking: str) -> str:
        """Same as :meth:`_text_tail_for_summary_block` for reasoning."""
        if not full_thinking:
            return ""
        if not self._partial_thinking_buffer:
            return full_thinking
        if full_thinking.startswith(self._partial_thinking_buffer):
            tail = full_thinking[len(self._partial_thinking_buffer) :]
            self._partial_thinking_buffer = ""
            return tail
        if self._partial_thinking_buffer.startswith(full_thinking):
            self._partial_thinking_buffer = self._partial_thinking_buffer[
                len(full_thinking) :
            ]
            return ""
        logger.warning(
            "SDK partial/summary thinking diverged "
            "(partial_buf=%d chars, summary=%d chars) — emitting summary, "
            "clearing partial buffer to recover",
            len(self._partial_thinking_buffer),
            len(full_thinking),
        )
        self._partial_thinking_buffer = ""
        return full_thinking

    def _handle_stream_event(
        self, evt: StreamEvent, responses: list[StreamBaseResponse]
    ) -> None:
        """Translate raw Anthropic streaming events into wire events.

        Handles four event types; everything else (``message_delta``
        stop reasons, ``signature_delta``, ``input_json_delta``,
        ``ping``, ...) is ignored because the summary ``AssistantMessage``
        carries their effects.

        * ``message_start`` — new message boundary, reset per-index maps
        * ``content_block_start`` — open text / reasoning block on the
          wire and remember the block type at that index
        * ``content_block_delta`` — forward ``text_delta`` immediately
          and coalesce ``thinking_delta`` (64-char / 50 ms window)
        * ``content_block_stop`` — drain any buffered thinking and close
          the corresponding wire block
        """
        raw: dict[str, Any] = evt.event or {}
        event_type = raw.get("type")

        if event_type == "message_start":
            self._reset_partial_stream_state()
            return

        if event_type == "content_block_start":
            block = raw.get("content_block") or {}
            index = raw.get("index")
            block_type = block.get("type")
            if not isinstance(index, int) or not isinstance(block_type, str):
                return
            self._block_types_by_index[index] = block_type
            if block_type == "text":
                self._end_reasoning_if_open(responses)
                self._ensure_text_started(responses)
                # Seed any preamble the block_start carries.
                seed = block.get("text") or ""
                if seed:
                    responses.append(StreamTextDelta(id=self.text_block_id, delta=seed))
                    self._partial_text_buffer += seed
                    self._text_since_last_tool_result = True
            elif block_type == "thinking":
                self._end_text_if_open(responses)
                self._ensure_reasoning_started(responses)
                self._last_thinking_flush_monotonic = time.monotonic()
            # tool_use / server_tool_use / redacted_thinking blocks stay
            # on the ``AssistantMessage`` path — the frontend widgets
            # need the final ``input`` payload which only arrives in the
            # summary.
            return

        if event_type == "content_block_delta":
            index = raw.get("index")
            if not isinstance(index, int):
                return
            delta = raw.get("delta") or {}
            delta_type = delta.get("type")
            if delta_type == "text_delta":
                chunk = delta.get("text") or ""
                if not chunk:
                    return
                self._ensure_text_started(responses)
                responses.append(StreamTextDelta(id=self.text_block_id, delta=chunk))
                self._partial_text_buffer += chunk
                self._text_since_last_tool_result = True
            elif delta_type == "thinking_delta":
                chunk = delta.get("thinking") or ""
                if not chunk:
                    return
                self._ensure_reasoning_started(responses)
                # Flush the coalesce buffer if the index changed — shouldn't
                # happen in practice but guard against interleaved indices.
                if (
                    self._pending_thinking_index is not None
                    and self._pending_thinking_index != index
                ):
                    self._flush_pending_thinking(responses)
                self._pending_thinking_delta += chunk
                self._pending_thinking_index = index
                now = time.monotonic()
                elapsed_ms = (now - self._last_thinking_flush_monotonic) * 1000.0
                if (
                    len(self._pending_thinking_delta) >= _THINKING_COALESCE_MIN_CHARS
                    or elapsed_ms >= _THINKING_COALESCE_MAX_INTERVAL_MS
                ):
                    self._flush_pending_thinking(responses)
                    self._last_thinking_flush_monotonic = now
            # Other delta types (``signature_delta``, ``input_json_delta``)
            # are CLI / tool-dispatch plumbing — not surfaced on the wire.
            return

        if event_type == "content_block_stop":
            index = raw.get("index")
            if not isinstance(index, int):
                return
            block_type = self._block_types_by_index.pop(index, None)
            if block_type == "text":
                self._end_text_if_open(responses)
            elif block_type == "thinking":
                self._end_reasoning_if_open(responses)
            return

    def _flush_pending_thinking(self, responses: list[StreamBaseResponse]) -> None:
        """Drain the coalesce buffer into a ``StreamReasoningDelta``.

        Separate from ``_end_reasoning_if_open`` because the coalesce
        window can flush mid-block (threshold hit) without closing the
        reasoning block.
        """
        if not self._pending_thinking_delta:
            return
        responses.append(
            StreamReasoningDelta(
                id=self.reasoning_block_id,
                delta=self._pending_thinking_delta,
            )
        )
        self._partial_thinking_buffer += self._pending_thinking_delta
        self._pending_thinking_delta = ""
        self._pending_thinking_index = None

    def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None:
        """Emit outputs for tool calls that didn't receive a UserMessage result.
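Before the adapter-level tests below, it can help to see the prefix-consumption reconcile from `_text_tail_for_summary_block` / `_thinking_tail_for_summary_block` in isolation. This standalone model is a simplification for reasoning about the cases, not code from the repo.

```python
def tail_for_summary(partial_buffer: str, full_block: str) -> tuple[str, str]:
    """Return (tail_to_emit, remaining_buffer) for one summary block.

    Mirrors the adapter's reconcile cases:
    empty buffer: emit the whole summary block;
    buffer is a prefix of the block: emit only the unseen tail;
    block is a prefix of the buffer: emit nothing, keep the surplus;
    divergence: emit the summary and drop the buffer to recover.
    """
    if not partial_buffer:
        return full_block, ""
    if full_block.startswith(partial_buffer):
        return full_block[len(partial_buffer):], ""
    if partial_buffer.startswith(full_block):
        return "", partial_buffer[len(full_block):]
    return full_block, ""  # diverged: summary wins


# Partial streamed a prefix; the summary carries the full text.
assert tail_for_summary("Hello, wo", "Hello, world!") == ("rld!", "")
# Kimi-style per-block summaries consuming one shared buffer.
assert tail_for_summary("AB", "A") == ("", "B")
```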
@@ -6,6 +6,7 @@ import pytest
from claude_agent_sdk import (
    AssistantMessage,
    ResultMessage,
    StreamEvent,
    SystemMessage,
    TextBlock,
    ThinkingBlock,
@@ -21,6 +22,7 @@ from backend.copilot.response_model import (
    StreamFinishStep,
    StreamHeartbeat,
    StreamReasoningDelta,
    StreamReasoningEnd,
    StreamStart,
    StreamStartStep,
    StreamTextDelta,
@@ -136,6 +138,29 @@ def test_tool_use_emits_input_start_and_available():
    assert results[2].input == {"q": "x"}


def test_tool_use_strips_whitespace_in_tool_name():
    adapter = _adapter()
    msg = AssistantMessage(
        content=[
            ToolUseBlock(
                id="tool-1",
                name=f" {MCP_TOOL_PREFIX}find_block",
                input={},
            )
        ],
        model="test",
    )
    results = adapter.convert_message(msg)
    tool_events = [
        r
        for r in results
        if isinstance(r, (StreamToolInputStart, StreamToolInputAvailable))
    ]
    assert tool_events, "expected tool input events"
    for event in tool_events:
        assert event.toolName == "find_block"


def test_text_then_tool_ends_text_block():
    adapter = _adapter()
    text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test")
@@ -298,6 +323,36 @@ def test_text_after_thinking_closes_reasoning_and_opens_text():
    assert re_idx < ts_idx


def test_thinking_after_text_in_same_message_renders_reasoning_first():
    """Kimi K2.6 (and other non-Anthropic OpenRouter providers) place
    ``reasoning`` AFTER the visible text in the response, so the SDK
    builds an ``AssistantMessage`` with content = [TextBlock, ThinkingBlock].
    Without reordering, the UI would show the answer first and the
    reasoning panel below it — the opposite of the natural reading
    order Anthropic models produce. response_adapter must hoist
    ThinkingBlocks to the front so ``reasoning-start/delta/end`` events
    hit the SSE stream BEFORE the ``text-*`` events."""
    adapter = _adapter()
    msg = AssistantMessage(
        content=[
            TextBlock(text="63"),
            ThinkingBlock(thinking="7 times 9 is 63", signature=""),
        ],
        model="test",
    )
    results = adapter.convert_message(msg)
    types = [type(r).__name__ for r in results]
    # ReasoningStart must land before TextStart in the emitted stream
    assert "StreamReasoningStart" in types
    assert "StreamTextStart" in types
    assert types.index("StreamReasoningStart") < types.index("StreamTextStart")
    # ReasoningDelta payload is intact
    assert any(
        isinstance(r, StreamReasoningDelta) and r.delta == "7 times 9 is 63"
        for r in results
    )


def test_tool_use_after_thinking_closes_reasoning():
    """Opening a tool also closes an open reasoning block."""
    adapter = _adapter()
@@ -331,6 +386,69 @@ def test_empty_thinking_block_is_ignored():
    assert [type(r).__name__ for r in results] == ["StreamStartStep"]


def test_render_reasoning_in_ui_false_still_emits_adapter_events():
    """With the persist/render decoupling the adapter is flag-agnostic:
    it always emits ``StreamReasoning*`` so the session transcript keeps a
    durable reasoning record. Wire-level suppression when
    ``render_reasoning_in_ui=False`` happens at the SDK service yield
    boundary, not here — see
    ``backend/copilot/sdk/service.py::_filter_reasoning_events``.
    """
    adapter = SDKResponseAdapter(
        message_id="m",
        session_id="s",
        render_reasoning_in_ui=False,
    )
    msg = AssistantMessage(
        content=[ThinkingBlock(thinking="plan", signature="sig")],
        model="test",
    )
    results = adapter.convert_message(msg)
    types = [type(r).__name__ for r in results]
    assert "StreamReasoningStart" in types
    assert "StreamReasoningDelta" in types


def test_render_reasoning_off_text_after_thinking_still_closes_reasoning():
    """Adapter still emits a ``StreamReasoningEnd`` when text follows a
    thinking block — decoupled from the render flag. The service layer
    drops the reasoning events at yield time; the adapter's structural
    open/close pairing must not depend on the flag or downstream filters
    would see orphan reasoning starts on the persisted transcript.
    """
    adapter = SDKResponseAdapter(
        message_id="m",
        session_id="s",
        render_reasoning_in_ui=False,
    )
    adapter.convert_message(
        AssistantMessage(
            content=[ThinkingBlock(thinking="warming up", signature="sig")],
            model="test",
        )
    )
    results = adapter.convert_message(
        AssistantMessage(content=[TextBlock(text="hello")], model="test")
    )
    types = [type(r).__name__ for r in results]
    assert "StreamReasoningEnd" in types
    assert "StreamTextStart" in types
    assert "StreamTextDelta" in types


def test_render_reasoning_on_is_default():
    """Default is True — existing callers keep emitting reasoning events."""
    adapter = SDKResponseAdapter(message_id="m", session_id="s")
    msg = AssistantMessage(
        content=[ThinkingBlock(thinking="plan", signature="sig")],
        model="test",
    )
    results = adapter.convert_message(msg)
    types = [type(r).__name__ for r in results]
    assert "StreamReasoningStart" in types
    assert "StreamReasoningDelta" in types


def test_result_success_synthesizes_fallback_text_when_final_turn_is_thinking_only():
    """If the model's last LLM call after a tool_result produced only a
    ThinkingBlock (no TextBlock), the UI would hang on the tool output
@@ -992,3 +1110,318 @@ def test_end_text_if_open_no_op_after_text_already_ended():
    second: list[StreamBaseResponse] = []
    adapter._end_text_if_open(second)
    assert second == []


# ---------------------------------------------------------------------------
# Partial-message streaming (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES)
# Covers the 10 scenarios in docs/sdk-per-token-streaming-followup.md
# ---------------------------------------------------------------------------


def _stream_event(payload: dict) -> StreamEvent:
    """Convenience constructor for a raw Anthropic StreamEvent payload."""
    return StreamEvent(
        uuid="stream-evt",
        session_id="session-1",
        parent_tool_use_id=None,
        event=payload,
    )


def _message_start() -> StreamEvent:
    return _stream_event({"type": "message_start"})


def _text_block_start(index: int) -> StreamEvent:
    return _stream_event(
        {
            "type": "content_block_start",
            "index": index,
            "content_block": {"type": "text", "text": ""},
        }
    )


def _text_delta(index: int, text: str) -> StreamEvent:
    return _stream_event(
        {
            "type": "content_block_delta",
            "index": index,
            "delta": {"type": "text_delta", "text": text},
        }
    )


def _thinking_block_start(index: int) -> StreamEvent:
    return _stream_event(
        {
            "type": "content_block_start",
            "index": index,
            "content_block": {"type": "thinking", "thinking": ""},
        }
    )


def _thinking_delta(index: int, text: str) -> StreamEvent:
    return _stream_event(
        {
            "type": "content_block_delta",
            "index": index,
            "delta": {"type": "thinking_delta", "thinking": text},
        }
    )


def _block_stop(index: int) -> StreamEvent:
    return _stream_event({"type": "content_block_stop", "index": index})


def _collect_text_deltas(responses):
    return "".join(r.delta for r in responses if isinstance(r, StreamTextDelta))


def _collect_reasoning_deltas(responses):
    return "".join(r.delta for r in responses if isinstance(r, StreamReasoningDelta))


class TestPartialMessageStreaming:
    """Scenarios 1-10 from sdk-per-token-streaming-followup.md.

    The adapter runs unconditionally in partial-aware mode — when the
    flag ``CHAT_SDK_INCLUDE_PARTIAL_MESSAGES`` is off the CLI simply
    never emits ``StreamEvent`` messages and the diff maps stay empty
    (so the tail logic degrades to "emit the full summary content"
    which is the pre-partial behaviour).
    """

    def test_partial_and_summary_agree_no_duplicate(self):
        """Scenario 1: partial streams full text, summary matches exactly.
        No duplicate emission, no truncation — full content reaches the
        wire once."""
        adapter = _adapter()
        full = "Hello world"
        responses: list[StreamBaseResponse] = []
        adapter._handle_stream_event(_text_block_start(0), responses)
        for chunk in ("Hello", " ", "world"):
            adapter._handle_stream_event(_text_delta(0, chunk), responses)
        adapter._handle_stream_event(_block_stop(0), responses)
        # Summary arrives with the same full text
        summary = adapter.convert_message(
            AssistantMessage(content=[TextBlock(text=full)], model="test")
        )
        combined = responses + summary
        assert _collect_text_deltas(combined) == full

    def test_partial_short_summary_long_tail_emitted(self):
        """Scenario 2 (the truncation bug we saw): partial emitted a
        prefix of the real answer; summary has the full text. The
        adapter must emit only the tail so no content is lost."""
        adapter = _adapter()
        responses: list[StreamBaseResponse] = []
        adapter._handle_stream_event(_text_block_start(0), responses)
        for chunk in ("The user ", "seems confused. They sent"):
            adapter._handle_stream_event(_text_delta(0, chunk), responses)
        # Summary has the full, un-truncated content
        full = (
            "The user seems confused. They sent a short greeting. "
            "Let me offer them concrete options."
        )
        summary = adapter.convert_message(
            AssistantMessage(content=[TextBlock(text=full)], model="test")
        )
        combined = responses + summary
        assert _collect_text_deltas(combined) == full

    def test_partial_empty_summary_only(self):
        """Scenario 3: no partial deltas (CLI emitted the block entirely
        in the summary — short blocks, proxy buffering, encrypted
        content). Summary carries the full text."""
        adapter = _adapter()
        summary = adapter.convert_message(
            AssistantMessage(content=[TextBlock(text="short answer")], model="test")
        )
        assert _collect_text_deltas(summary) == "short answer"

    def test_partial_long_summary_matches_no_double_emit(self):
        """Scenario 4 (most common): partial streams everything, summary
        repeats the same content. No duplication on the wire."""
        adapter = _adapter()
        responses: list[StreamBaseResponse] = []
        full = "Here is a long paragraph with several words in it."
        adapter._handle_stream_event(_text_block_start(0), responses)
        # Partition into chunks that *exactly* reconstruct ``full`` — a
        # word-split with trailing spaces would emit more content than
        # the summary carries and the reconcile would correctly flag
        # divergence.
        chunks = [full[:13], full[13:25], full[25:]]
        assert "".join(chunks) == full
        for chunk in chunks:
            adapter._handle_stream_event(_text_delta(0, chunk), responses)
        adapter._handle_stream_event(_block_stop(0), responses)
        assert _collect_text_deltas(responses) == full

        summary = adapter.convert_message(
            AssistantMessage(content=[TextBlock(text=full)], model="test")
        )
        # Summary must not add any TextDelta since partial already covered it
        assert _collect_text_deltas(summary) == ""

    def test_partial_diverges_summary_wins(self):
        """Scenario 5: partial content isn't a prefix of the summary.
        Defensive path emits the full summary content — content must
        not silently disappear."""
        adapter = _adapter()
        responses: list[StreamBaseResponse] = []
        adapter._handle_stream_event(_text_block_start(0), responses)
        adapter._handle_stream_event(_text_delta(0, "first draft"), responses)
        # Summary has totally different content (proxy rewrote it)
        summary = adapter.convert_message(
            AssistantMessage(
                content=[TextBlock(text="final polished answer")],
                model="test",
            )
        )
        # The summary's text must reach the wire even though partial
        # already emitted "first draft" (which was the proxy's draft).
        assert "final polished answer" in _collect_text_deltas(responses + summary)

    def test_thinking_only_partial_coalesced(self):
        """Scenario 6a (thinking-only permutation): a run of
        ``thinking_delta`` events below the coalesce threshold flushes
        at ``content_block_stop`` so the reasoning tail isn't lost."""
        adapter = _adapter()
        responses: list[StreamBaseResponse] = []
        adapter._handle_stream_event(_thinking_block_start(0), responses)
        # Each chunk is well under the 64-char threshold
        for chunk in ("Let ", "me ", "think"):
            adapter._handle_stream_event(_thinking_delta(0, chunk), responses)
        # At stop, the pending buffer drains
        adapter._handle_stream_event(_block_stop(0), responses)
        assert _collect_reasoning_deltas(responses) == "Let me think"
        # Block closed
        assert any(isinstance(r, StreamReasoningEnd) for r in responses)

    def test_text_only_via_partial_and_summary(self):
        """Scenario 6b (text-only permutation): partial fills a block,
        summary matches — see scenario 4 for no-double-emit assertion."""
        adapter = _adapter()
        responses: list[StreamBaseResponse] = []
        adapter._handle_stream_event(_text_block_start(0), responses)
        adapter._handle_stream_event(_text_delta(0, "hi"), responses)
|
||||
adapter._handle_stream_event(_block_stop(0), responses)
|
||||
assert _collect_text_deltas(responses) == "hi"
|
||||
|
||||
def test_mixed_text_then_thinking_partial_preserves_order(self):
|
||||
"""Scenario 6c (mixed, Anthropic order — reasoning then text).
|
||||
When partial emits blocks in natural order and summary matches,
|
||||
the wire order is identical to emission order."""
|
||||
adapter = _adapter()
|
||||
responses: list[StreamBaseResponse] = []
|
||||
# Anthropic-shape: thinking index 0, text index 1
|
||||
adapter._handle_stream_event(_thinking_block_start(0), responses)
|
||||
adapter._handle_stream_event(
|
||||
_thinking_delta(0, "X" * 80), responses
|
||||
) # over threshold
|
||||
adapter._handle_stream_event(_block_stop(0), responses)
|
||||
adapter._handle_stream_event(_text_block_start(1), responses)
|
||||
adapter._handle_stream_event(_text_delta(1, "answer"), responses)
|
||||
adapter._handle_stream_event(_block_stop(1), responses)
|
||||
types = [type(r).__name__ for r in responses]
|
||||
# ReasoningStart must come before TextStart — partial streams in
|
||||
# the CLI's natural order, which is also the UI's desired order.
|
||||
assert types.index("StreamReasoningStart") < types.index("StreamTextStart")
|
||||
|
||||
def test_multi_message_turn_resets_per_index_maps(self):
|
||||
"""Scenario 7: tool-use loop creates multiple AssistantMessages
|
||||
per turn. Anthropic content-block indices are scoped to a single
|
||||
message — ``message_start`` must reset the diff maps so the next
|
||||
message's index-0 text isn't silently suppressed."""
|
||||
adapter = _adapter()
|
||||
responses: list[StreamBaseResponse] = []
|
||||
# First message at index 0 = "first"
|
||||
adapter._handle_stream_event(_message_start(), responses)
|
||||
adapter._handle_stream_event(_text_block_start(0), responses)
|
||||
adapter._handle_stream_event(_text_delta(0, "first"), responses)
|
||||
adapter._handle_stream_event(_block_stop(0), responses)
|
||||
# New message starts — index 0 now refers to a fresh block
|
||||
adapter._handle_stream_event(_message_start(), responses)
|
||||
adapter._handle_stream_event(_text_block_start(0), responses)
|
||||
adapter._handle_stream_event(_text_delta(0, "second"), responses)
|
||||
adapter._handle_stream_event(_block_stop(0), responses)
|
||||
# Both texts must land on the wire
|
||||
assert _collect_text_deltas(responses) == "firstsecond"
|
||||
|
||||
def test_empty_thinking_with_signature_emits_nothing(self):
|
||||
"""Scenario 8: encrypted / empty thinking block. Partial emits
|
||||
nothing, summary carries ``block.thinking == ""`` with a
|
||||
signature — the adapter must not open a reasoning block."""
|
||||
adapter = _adapter()
|
||||
summary = adapter.convert_message(
|
||||
AssistantMessage(
|
||||
content=[ThinkingBlock(thinking="", signature="sig")],
|
||||
model="test",
|
||||
)
|
||||
)
|
||||
# No reasoning events should be emitted for empty thinking
|
||||
reasoning_events = [
|
||||
r
|
||||
for r in summary
|
||||
if isinstance(r, StreamReasoningDelta)
|
||||
or type(r).__name__ in ("StreamReasoningStart", "StreamReasoningEnd")
|
||||
]
|
||||
assert reasoning_events == []
|
||||
|
||||
def test_thinking_tail_drains_on_block_stop(self):
|
||||
"""Scenario 10: a thinking_delta chunk smaller than the 64-char
|
||||
threshold arrives, then ``content_block_stop``. The tail text
|
||||
must emit in a final ``StreamReasoningDelta`` BEFORE
|
||||
``StreamReasoningEnd``."""
|
||||
adapter = _adapter()
|
||||
responses: list[StreamBaseResponse] = []
|
||||
adapter._handle_stream_event(_thinking_block_start(0), responses)
|
||||
# One small chunk well under 64 chars
|
||||
adapter._handle_stream_event(_thinking_delta(0, "tiny chunk"), responses)
|
||||
# Block stop must flush the pending buffer
|
||||
adapter._handle_stream_event(_block_stop(0), responses)
|
||||
types = [type(r).__name__ for r in responses]
|
||||
# The final ReasoningDelta must precede ReasoningEnd
|
||||
rd_idx = types.index("StreamReasoningDelta")
|
||||
re_idx = types.index("StreamReasoningEnd")
|
||||
assert rd_idx < re_idx
|
||||
assert _collect_reasoning_deltas(responses) == "tiny chunk"
|
||||
|
||||
def test_thinking_coalesces_on_char_threshold(self):
|
||||
"""Extra: thinking_delta accumulating past 64 chars flushes
|
||||
mid-block without waiting for block_stop (coalesce threshold)."""
|
||||
adapter = _adapter()
|
||||
responses: list[StreamBaseResponse] = []
|
||||
adapter._handle_stream_event(_thinking_block_start(0), responses)
|
||||
# One 80-char chunk trips the threshold on a single event
|
||||
adapter._handle_stream_event(_thinking_delta(0, "x" * 80), responses)
|
||||
# A ReasoningDelta must already have been emitted (not buffered
|
||||
# until block_stop).
|
||||
assert any(isinstance(r, StreamReasoningDelta) for r in responses)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Partial/summary reconcile — summary walk must not duplicate partial content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_summary_walk_skips_fully_streamed_text():
|
||||
"""If the partial stream delivered the entire TextBlock, the summary
|
||||
walk must not emit a second ``StreamTextDelta`` for the same block."""
|
||||
adapter = _adapter()
|
||||
responses: list[StreamBaseResponse] = []
|
||||
adapter._handle_stream_event(_text_block_start(0), responses)
|
||||
adapter._handle_stream_event(_text_delta(0, "complete answer"), responses)
|
||||
adapter._handle_stream_event(_block_stop(0), responses)
|
||||
# Summary arrives with matching content
|
||||
summary = adapter.convert_message(
|
||||
AssistantMessage(content=[TextBlock(text="complete answer")], model="test")
|
||||
)
|
||||
# Partial path emitted exactly one StreamTextDelta
|
||||
partial_deltas = [r for r in responses if isinstance(r, StreamTextDelta)]
|
||||
summary_deltas = [r for r in summary if isinstance(r, StreamTextDelta)]
|
||||
assert len(partial_deltas) == 1
|
||||
assert summary_deltas == []
|
||||
|
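The scenarios above all reduce to one prefix/tail comparison between what the partial stream already emitted and what the summary carries. A minimal sketch of that rule, using a hypothetical `reconcile_tail` helper (the adapter's real method and field names may differ):

```python
def reconcile_tail(streamed: str, summary: str) -> str | None:
    """Return the text still owed to the wire after the summary arrives.

    - streamed == summary   -> nothing left to emit (scenarios 1, 4)
    - streamed is a prefix  -> emit only the missing tail (scenario 2)
    - streamed is empty     -> emit the full summary (scenario 3)
    - streamed diverged     -> defensively emit the full summary (scenario 5)
    """
    if summary.startswith(streamed):
        tail = summary[len(streamed):]
        return tail or None
    return summary


assert reconcile_tail("The user ", "The user seems confused.") == "seems confused."
assert reconcile_tail("Hello world", "Hello world") is None
assert reconcile_tail("", "short answer") == "short answer"
assert reconcile_tail("first draft", "final polished answer") == "final polished answer"
```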
@@ -27,6 +27,7 @@ from claude_agent_sdk import (
    ClaudeAgentOptions,
    ClaudeSDKClient,
    ResultMessage,
    StreamEvent,
    TextBlock,
    ThinkingBlock,
    ToolResultBlock,
@@ -56,6 +57,11 @@ from ..constants import (
from ..session_cleanup import prune_orphan_tool_calls
from ..context import encode_cwd_for_cli, get_workspace_manager
from ..graphiti.config import is_enabled_for_user
from ..model_router import resolve_model
from ..moonshot import (
    is_moonshot_model as _is_moonshot_model,
    override_cost_usd as _override_cost_for_moonshot,
)
from ..model import (
    ChatMessage,
    ChatSession,
@@ -109,6 +115,7 @@ from ..service import (
)
from ..thinking_stripper import ThinkingStripper
from ..token_tracking import persist_and_record_usage
from ..tools import ToolGroup
from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct
from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path
from ..tracking import track_user_message
@@ -129,6 +136,7 @@ from ..transcript import (
from ..transcript_builder import TranscriptBuilder
from .compaction import CompactionTracker, filter_compaction_messages
from .env import build_sdk_env  # noqa: F401 — re-export for backward compat
from .openrouter_cost import record_turn_cost_from_openrouter
from .response_adapter import SDKResponseAdapter
from .security_hooks import create_security_hooks
from .tool_adapter import (
@@ -364,6 +372,20 @@ class _RetryState:
    # ``detect_gap`` picks them up as gap-fill entries instead of assuming the
    # JSONL already covers them.
    midturn_user_rows: int = 0
    # OpenRouter generation IDs collected across all attempts of this turn.
    # Populated from ``AssistantMessage.message_id`` when routed via
    # OpenRouter (``gen-...`` prefix). Consumed by the finally block to
    # fire ``record_turn_cost_from_openrouter`` for non-Anthropic models —
    # the CLI's static-Anthropic-priced estimate is replaced with the
    # authoritative ``/generation`` total_cost. Lives on ``_RetryState``
    # (not per-attempt ``_StreamAccumulator``) so it survives retries.
    generation_ids: list[str] = dataclass_field(default_factory=list)
    # The *actually executed* model observed on ``AssistantMessage.model`` —
    # differs from ``state.options.model`` (the requested primary) when
    # ``_resolve_fallback_model`` swaps to a fallback mid-attempt. The
    # Moonshot cost override gates on this so a Moonshot-→-Anthropic
    # fallback doesn't get mis-billed at Moonshot rates, and vice versa.
    observed_model: str | None = None


@dataclass
@@ -678,35 +700,51 @@ async def _iter_sdk_messages(
def _normalize_model_name(raw_model: str) -> str:
    """Normalize a model name for the current routing configuration.

    Applies two transformations shared by both the primary and fallback
    model resolution paths:
    Two routing modes:

    1. **Strip provider prefix** — OpenRouter-style names like
       ``"anthropic/claude-opus-4.6"`` are reduced to ``"claude-opus-4.6"``.
    2. **Dot-to-hyphen conversion** — when *not* routing through OpenRouter
       the direct Anthropic API requires hyphen-separated versions
       (``"claude-opus-4-6"``), so dots are replaced with hyphens.
    1. **OpenRouter active** — the canonical OpenRouter slug is
       ``"<vendor>/<model>"`` (e.g. ``"anthropic/claude-opus-4.6"``,
       ``"moonshotai/kimi-k2.6"``). Pass the prefixed name through
       unchanged so OpenRouter can route to the correct provider. Anthropic
       names happen to also resolve when stripped, but non-Anthropic vendors
       (Moonshot, Google, etc.) do not — keeping the prefix is the only form
       that works for every model in the catalog.
    2. **Direct Anthropic** — strip the OpenRouter ``anthropic/`` prefix
       and convert dots to hyphens (``"claude-opus-4.6"`` →
       ``"claude-opus-4-6"``) since the Anthropic Messages API rejects
       both the prefix and dot-separated versions. Raises ``ValueError``
       when a non-Anthropic vendor slug is paired with direct-Anthropic
       mode — silently stripping ``moonshotai/`` would send ``kimi-k2.6``
       to the Anthropic API and produce an opaque ``model_not_found``
       error far from the misconfiguration source.
    """
    if config.openrouter_active:
        return raw_model
    model = raw_model
    if "/" in model:
        model = model.split("/", 1)[1]
    # OpenRouter uses dots in versions (claude-opus-4.6) but the direct
    # Anthropic API requires hyphens (claude-opus-4-6). Only normalise
    # when NOT routing through OpenRouter.
    if not config.openrouter_active:
        model = model.replace(".", "-")
    return model
        vendor, model = model.split("/", 1)
        if vendor != "anthropic":
            raise ValueError(
                f"Direct-Anthropic mode (use_openrouter=False or missing "
                f"OpenRouter credentials) requires an Anthropic model, got "
                f"vendor={vendor!r} from model={raw_model!r}. Set "
                f"CHAT_THINKING_STANDARD_MODEL/CHAT_THINKING_ADVANCED_MODEL "
                f"to an anthropic/* slug, or enable OpenRouter."
            )
    return model.replace(".", "-")
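Concretely, the two routing modes map the same configured slug to different wire names. A quick reference, mirroring the docstring and the tests further down (assumes `config.openrouter_active` is flipped between the two halves):

```python
# OpenRouter active: the full vendor/model slug passes through untouched.
#   _normalize_model_name("anthropic/claude-sonnet-4.6")  -> "anthropic/claude-sonnet-4.6"
#   _normalize_model_name("moonshotai/kimi-k2.6")         -> "moonshotai/kimi-k2.6"
#
# Direct Anthropic: prefix stripped, dots hyphenated, foreign vendors rejected.
#   _normalize_model_name("anthropic/claude-opus-4.6")    -> "claude-opus-4-6"
#   _normalize_model_name("moonshotai/kimi-k2.6")         -> ValueError
```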


def _resolve_sdk_model() -> str | None:
    """Resolve the model name for the Claude Agent SDK CLI.
    """Resolve the SDK-CLI model name from static config (no LD lookup).

    Uses `config.claude_agent_model` if set, otherwise derives from
    `config.thinking_standard_model` via :func:`_normalize_model_name`.
    ``config.claude_agent_model`` is an explicit override that wins
    unconditionally. When the Claude Code subscription is enabled and no
    override is set, returns ``None`` so the CLI picks the model for the
    user's subscription plan. Otherwise derives from
    ``config.thinking_standard_model``.

    When `use_claude_code_subscription` is enabled and no explicit
    `claude_agent_model` is set, returns `None` so the CLI uses the
    default model for the user's subscription plan.
    For per-user routing (LaunchDarkly overrides), see
    :func:`_resolve_sdk_model_for_request`.
    """
    if config.claude_agent_model:
        return config.claude_agent_model
@@ -715,6 +753,18 @@ def _resolve_sdk_model() -> str | None:
    return _normalize_model_name(config.thinking_standard_model)


async def _resolve_thinking_model_for_user(
    tier: "CopilotLlmModel",
    user_id: str | None,
) -> str:
    """LD-aware thinking-tier model pick for a specific user.

    Consults ``copilot-thinking-{tier}-model`` and falls back to the
    ``ChatConfig`` default on missing user / missing flag.
    """
    return await resolve_model("thinking", tier, user_id, config=config)


def _resolve_fallback_model() -> str | None:
    """Resolve the fallback model name via :func:`_normalize_model_name`.

@@ -729,37 +779,94 @@ def _resolve_fallback_model() -> str | None:
async def _resolve_sdk_model_for_request(
    model: "CopilotLlmModel | None",
    session_id: str,
    user_id: str | None = None,
) -> str | None:
    """Resolve the SDK model string for a turn.

    Priority (highest first):
    1. Explicit per-request ``model`` tier from the frontend toggle.
    2. Global config default (``_resolve_sdk_model()``).

    Returns ``None`` when the Claude Code subscription default applies.
    Rate-limit accounting no longer applies a multiplier — the real turn
    cost (reported by the SDK) already reflects model-pricing differences.
    1. ``config.claude_agent_model`` — unconditional override, bypasses LD.
    2. LaunchDarkly ``copilot-thinking-{tier}-model`` if it serves a value
       different from the config default for *user_id*. An LD-served
       override wins over subscription mode so admins can route specific
       users to a specific model without flipping subscription on/off.
    3. ``config.use_claude_code_subscription`` on the standard tier —
       returns ``None`` so the CLI picks the subscription default (this
       branch fires when LD has no opinion, i.e. the value equals the
       config default).
    4. ``ChatConfig`` static default for the tier.
    """
    if model == "advanced":
        sdk_model = _normalize_model_name(config.thinking_advanced_model)
    if config.claude_agent_model:
        return config.claude_agent_model

    tier_name: "CopilotLlmModel" = "advanced" if model == "advanced" else "standard"
    # Strip at read time so a stray trailing space in ``CHAT_*_MODEL`` (a
    # common ``.env`` pitfall) doesn't make the ``resolved == tier_default``
    # comparison below spuriously diverge — ``resolve_model`` already strips
    # the LD side, so both halves must end up whitespace-normalised to stay
    # equal when they're semantically equal. Downstream ``_normalize_model_name``
    # also benefits from the strip.
    tier_default = (
        config.thinking_advanced_model
        if tier_name == "advanced"
        else config.thinking_standard_model
    ).strip()

    resolved = await _resolve_thinking_model_for_user(tier_name, user_id)

    # Subscription mode on standard tier only wins when LD has no opinion
    # (value == config default ⇒ admin hasn't explicitly pointed this
    # user somewhere). Any LD override — even to the same value with
    # stripped whitespace normalised — is an explicit admin choice that
    # must be honoured. Without this, a subscription-mode deployment
    # silently ignores the ``copilot-thinking-standard-model`` flag
    # entirely, which defeats the point of cohort-based routing.
    ld_overrides_default = resolved != tier_default
    if (
        not ld_overrides_default
        and tier_name == "standard"
        and config.use_claude_code_subscription
    ):
        logger.info(
            "[SDK] [%s] Per-request model override: advanced (%s)",
            "[SDK] [%s] Subscription default (tier=standard, LD unset)",
            session_id[:12] if session_id else "?",
        )
        return None
    try:
        sdk_model = _normalize_model_name(resolved)
    except ValueError as exc:
        # The per-user LD value didn't pass ``_normalize_model_name``'s
        # vendor check (most commonly: a ``moonshotai/kimi-*`` slug on a
        # direct-Anthropic deployment that has no OpenRouter route). Fail
        # soft to the TIER-SPECIFIC config default — using the generic
        # ``_resolve_sdk_model()`` here would pin advanced-tier requests to
        # ``thinking_standard_model`` (Sonnet) whenever LD misconfigures
        # the advanced cell, silently downgrading the user's chosen tier.
        try:
            sdk_model = _normalize_model_name(tier_default)
        except ValueError:
            # Config default is *also* invalid for the active routing
            # mode — this is a deployment-level misconfig that the
            # ``model_validator`` should catch at startup. Re-raise the
            # original LD error so the issue surfaces loudly rather than
            # returning something misleading.
            raise exc
        logger.warning(
            "[SDK] [%s] LD model %r rejected for tier=%s (%s); falling "
            "back to tier default %s",
            session_id[:12] if session_id else "?",
            resolved,
            tier_name,
            exc,
            sdk_model,
        )
        return sdk_model

    if model == "standard":
        # Reset to config default — respects subscription mode (None = CLI default).
        sdk_model = _resolve_sdk_model()
        logger.info(
            "[SDK] [%s] Per-request model override: standard (%s)",
            session_id[:12] if session_id else "?",
            sdk_model or "subscription-default",
        )
        return sdk_model

    return _resolve_sdk_model()
    logger.info(
        "[SDK] [%s] Resolved model for tier=%s: %s",
        session_id[:12] if session_id else "?",
        tier_name,
        sdk_model,
    )
    return sdk_model
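The four-step priority is easier to see flattened into a single function. This is a sketch only, with the fail-soft error handling omitted; the real path threads through the helpers above:

```python
# Sketch only — flattens the documented priority; not the production code path.
async def pick_model_sketch(tier: str, user_id: str | None) -> str | None:
    if config.claude_agent_model:                       # 1. hard override
        return config.claude_agent_model
    tier_default = (
        config.thinking_advanced_model
        if tier == "advanced"
        else config.thinking_standard_model
    ).strip()
    ld_value = await _resolve_thinking_model_for_user(tier, user_id)
    if ld_value != tier_default:                        # 2. explicit LD override
        return _normalize_model_name(ld_value)
    if tier == "standard" and config.use_claude_code_subscription:
        return None                                     # 3. subscription default
    return _normalize_model_name(tier_default)          # 4. static tier default
```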


_MAX_TRANSIENT_BACKOFF_SECONDS = 30
@@ -823,7 +930,11 @@ async def _do_transient_backoff(
    """
    yield StreamStatus(message=f"Connection interrupted, retrying in {backoff}s…")
    await asyncio.sleep(backoff)
    state.adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
    state.adapter = SDKResponseAdapter(
        message_id=message_id,
        session_id=session_id,
        render_reasoning_in_ui=config.render_reasoning_in_ui,
    )
    state.usage.reset()


@@ -2084,6 +2195,33 @@ async def _run_stream_attempt(
                        len(state.adapter.resolved_tool_calls),
                    )

                # Capture OpenRouter generation IDs from each
                # ``AssistantMessage.message_id`` — when routed via OpenRouter
                # these are ``gen-...`` slugs we can use post-turn to query
                # ``/api/v1/generation?id=`` for the authoritative per-turn
                # cost and token counts (the CLI's ``total_cost_usd`` is
                # computed from a static Anthropic pricing table that
                # silently over-bills non-Anthropic routes). Direct-Anthropic
                # turns produce ``msg_...`` IDs which the generation endpoint
                # doesn't know about — harmlessly ignored at reconcile time.
                if isinstance(sdk_msg, AssistantMessage):
                    msg_id = sdk_msg.message_id
                    if (
                        msg_id is not None
                        and msg_id.startswith("gen-")
                        and msg_id not in state.generation_ids
                    ):
                        state.generation_ids.append(msg_id)
                    # Track the model the SDK actually used — when a fallback
                    # activates, this differs from ``state.options.model``.
                    # Consumed by the Moonshot cost-override decision so we
                    # don't mis-bill a fallback-Anthropic response at
                    # Moonshot rates (or a fallback-Moonshot at Anthropic
                    # rates).
                    observed = getattr(sdk_msg, "model", None)
                    if isinstance(observed, str) and observed:
                        state.observed_model = observed

                # Log AssistantMessage API errors (e.g. invalid_request)
                # so we can debug Anthropic API 400s surfaced by the CLI.
                sdk_error = getattr(sdk_msg, "error", None)
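For reference, the capture is order-preserving and idempotent per ID, so a turn with two OpenRouter rounds plus a retried round still yields exactly one entry per generation:

```python
ids: list[str] = []
for msg_id in ("gen-aaa", "gen-bbb", "gen-bbb", "msg_anthropic"):
    # Same gate as above: OpenRouter slugs only, no duplicates.
    if msg_id.startswith("gen-") and msg_id not in ids:
        ids.append(msg_id)
assert ids == ["gen-aaa", "gen-bbb"]
```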
@@ -2252,7 +2390,37 @@ async def _run_stream_attempt(
                        state.usage.completion_tokens,
                    )
                    if sdk_msg.total_cost_usd is not None:
                        state.usage.cost_usd = sdk_msg.total_cost_usd
                        # Default: trust the CLI-reported value. Accurate for
                        # Anthropic models (the CLI's bundled pricing table is
                        # Anthropic-authored), and becomes the sync-path cost
                        # when the reconcile is disabled or fails.
                        # Prefer the ACTUALLY executed model
                        # (``state.observed_model`` from ``AssistantMessage.model``)
                        # over the requested primary (``state.options.model``)
                        # so a fallback activation doesn't mis-route pricing.
                        active_model = state.observed_model or getattr(
                            state.options, "model", None
                        )
                        if _is_moonshot_model(active_model):
                            # Moonshot slug — the CLI doesn't know Moonshot's
                            # rate card and silently bills at Sonnet rates
                            # (~5x over-charge). Replace with the rate-card
                            # estimate so the in-stream ``cost_usd`` and the
                            # reconcile's lookup-fail fallback reflect
                            # reality. Reconcile
                            # (``record_turn_cost_from_openrouter``) still
                            # overrides this value when every gen-ID lookup
                            # succeeds.
                            state.usage.cost_usd = _override_cost_for_moonshot(
                                model=active_model,
                                sdk_reported_usd=sdk_msg.total_cost_usd,
                                prompt_tokens=state.usage.prompt_tokens,
                                completion_tokens=state.usage.completion_tokens,
                                cache_read_tokens=state.usage.cache_read_tokens,
                                cache_creation_tokens=state.usage.cache_creation_tokens,
                            )
                        else:
                            state.usage.cost_usd = sdk_msg.total_cost_usd

                # Emit compaction end if SDK finished compacting.
                # Sync TranscriptBuilder with the CLI's active context.
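The override decision in isolation, with the imported helpers stubbed so the gate is visible. A sketch under the assumption that Moonshot slugs carry a `moonshotai/` prefix (the real check lives in `_is_moonshot_model`):

```python
def effective_cost_sketch(
    observed_model: str | None,
    requested_model: str | None,
    cli_cost_usd: float,
    rate_card_usd: float,
) -> float:
    # The actually executed model wins over the requested primary, so a
    # mid-attempt fallback swap re-routes pricing along with it.
    active = observed_model or requested_model
    # Stand-in for _is_moonshot_model(); the real predicate may differ.
    if active is not None and active.startswith("moonshotai/"):
        return rate_card_usd  # CLI billed at Sonnet rates; replace
    return cli_cost_usd


assert effective_cost_sketch("moonshotai/kimi-k2.6", None, 5.0, 1.0) == 1.0
assert effective_cost_sketch(None, "anthropic/claude-sonnet-4-6", 5.0, 1.0) == 5.0
```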
@@ -2374,6 +2542,18 @@ async def _run_stream_attempt(
                    skip_strip=response is tail_delta,
                )
                if dispatched is not None:
                    # Persistence (via _dispatch_response) always runs so the
                    # session transcript keeps role='reasoning' rows; the
                    # wire is gated so UI can suppress rendering.
                    if not state.adapter.render_reasoning_in_ui and isinstance(
                        dispatched,
                        (
                            StreamReasoningStart,
                            StreamReasoningDelta,
                            StreamReasoningEnd,
                        ),
                    ):
                        continue
                    yield dispatched

            # Mid-turn follow-up persistence: the MCP tool wrapper drains
@@ -2434,16 +2614,36 @@ async def _run_stream_attempt(
            # flush the assistant message before tool_calls are set on it
            # (text and tool_use arrive as separate SDK events), the
            # tool_calls update is lost — the next flush starts past it.
            _msgs_since_flush += 1
            #
            # With ``include_partial_messages=True`` the CLI delivers
            # hundreds of ``StreamEvent`` messages per turn — incrementing
            # ``_msgs_since_flush`` on each one trips the threshold long
            # before the assistant text is complete, saving a truncated
            # prefix that subsequent deltas can never extend (append-only).
            # Count only messages that produce a persisted row boundary
            # (AssistantMessage, UserMessage, ResultMessage) and skip
            # raw StreamEvents. Also skip when text or reasoning is
            # still in-flight on the adapter: the row is live and a flush
            # would lock it at its current length.
            if not isinstance(sdk_msg, StreamEvent):
                _msgs_since_flush += 1
            now = time.monotonic()
            has_pending_tools = (
                acc.has_appended_assistant
                and acc.accumulated_tool_calls
                and not acc.has_tool_results
            )
            if not has_pending_tools and (
                _msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
                or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
            adapter = state.adapter
            has_open_block = (
                adapter.has_started_text and not adapter.has_ended_text
            ) or (adapter.has_started_reasoning and not adapter.has_ended_reasoning)
            if (
                not has_pending_tools
                and not has_open_block
                and (
                    _msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD
                    or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS
                )
            ):
                try:
                    await asyncio.shield(upsert_chat_session(ctx.session))
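Reduced to a pure predicate, the new flush gate is the old threshold check plus two vetoes. A sketch; the real thresholds are `_FLUSH_MESSAGE_THRESHOLD` and `_FLUSH_INTERVAL_SECONDS`, whose values are not shown in this diff:

```python
def should_flush(
    *,
    has_pending_tools: bool,   # assistant row still waiting for tool_calls
    has_open_block: bool,      # text/reasoning block still streaming
    msgs_since_flush: int,
    seconds_since_flush: float,
    msg_threshold: int,        # _FLUSH_MESSAGE_THRESHOLD in the real code
    interval_s: float,         # _FLUSH_INTERVAL_SECONDS in the real code
) -> bool:
    return (
        not has_pending_tools
        and not has_open_block
        and (
            msgs_since_flush >= msg_threshold
            or seconds_since_flush >= interval_s
        )
    )
```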
@@ -2920,6 +3120,12 @@ async def stream_chat_completion_sdk(
    # Defaults ensure the finally block can always reference these safely even when
    # an early return (e.g. sdk_cwd error) skips their normal assignment below.
    sdk_model: str | None = None
    # Wall-clock timestamp captured before the CLI runs so the
    # OpenRouter reconcile can filter subagent JSONLs by mtime — only
    # files created during THIS turn contribute gen-IDs. Without this
    # the sweep would pick up prior turns' compaction files that persist
    # under ``<session_id>/subagents/``, double-billing the user.
    turn_start_ts = time.time()

    # Make sure there is no more code between the lock acquisition and try-block.
    try:
@@ -3043,8 +3249,8 @@ async def stream_chat_completion_sdk(

        mcp_server = create_copilot_mcp_server(use_e2b=use_e2b)

        # Resolve model (request tier → config default).
        sdk_model = await _resolve_sdk_model_for_request(model, session_id)
        # Resolve model (request tier → LD per-user override → config default).
        sdk_model = await _resolve_sdk_model_for_request(model, session_id, user_id)

        # Track SDK-internal compaction (PreCompact hook → start, next msg → end)
        compaction = CompactionTracker()
@@ -3056,10 +3262,18 @@ async def stream_chat_completion_sdk(
            on_compact=compaction.on_compact,
        )

        disabled_tool_groups: list[ToolGroup] = []
        if not graphiti_enabled:
            disabled_tool_groups.append("graphiti")

        if permissions is not None:
            allowed, disallowed = apply_tool_permissions(permissions, use_e2b=use_e2b)
            allowed, disallowed = apply_tool_permissions(
                permissions, use_e2b=use_e2b, disabled_groups=disabled_tool_groups
            )
        else:
            allowed = get_copilot_tool_names(use_e2b=use_e2b)
            allowed = get_copilot_tool_names(
                use_e2b=use_e2b, disabled_groups=disabled_tool_groups
            )
            disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b)

        def _on_stderr(line: str) -> None:
@@ -3129,6 +3343,17 @@
            sdk_options_kwargs["effort"] = config.claude_agent_thinking_effort
        if sdk_model:
            sdk_options_kwargs["model"] = sdk_model
        if config.sdk_include_partial_messages:
            # Opt into per-token streaming — the CLI emits raw Anthropic
            # ``content_block_delta`` events as ``StreamEvent`` messages
            # ahead of each summary ``AssistantMessage`` so reasoning and
            # text land on the wire token-by-token (matching the baseline
            # path's UX shipped in #12873). ``SDKResponseAdapter`` consumes
            # the partial stream via ``_handle_stream_event`` and emits
            # only the tail diff from the subsequent summary, so content
            # never double-emits and a summary-only short block still
            # reaches the UI.
            sdk_options_kwargs["include_partial_messages"] = True

        if sdk_env:
            sdk_options_kwargs["env"] = sdk_env
@@ -3160,7 +3385,11 @@

        options = ClaudeAgentOptions(**sdk_options_kwargs)  # type: ignore[arg-type]  # dynamic kwargs

        adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id)
        adapter = SDKResponseAdapter(
            message_id=message_id,
            session_id=session_id,
            render_reasoning_in_ui=config.render_reasoning_in_ui,
        )

        # Propagate user_id/session_id as OTEL context attributes so the
        # langsmith tracing integration attaches them to every span. This
@@ -3494,7 +3723,9 @@
                    session, user_id, is_user_message, state.query_message
                )
                state.adapter = SDKResponseAdapter(
                    message_id=message_id, session_id=session_id
                    message_id=message_id,
                    session_id=session_id,
                    render_reasoning_in_ui=config.render_reasoning_in_ui,
                )
                # Reset token accumulators so a failed attempt's partial
                # usage is not double-counted in the successful attempt.
@@ -3854,18 +4085,103 @@
        # --- Persist token usage to session + rate-limit counters ---
        # Both must live in finally so they stay consistent even when an
        # exception interrupts the try block after StreamUsage was yielded.
        await persist_and_record_usage(
            session=session,
            user_id=user_id,
            prompt_tokens=turn_prompt_tokens,
            completion_tokens=turn_completion_tokens,
            cache_read_tokens=turn_cache_read_tokens,
            cache_creation_tokens=turn_cache_creation_tokens,
            log_prefix=log_prefix,
            cost_usd=turn_cost_usd,
            model=sdk_model or config.thinking_standard_model,
            provider="anthropic",
        effective_model = sdk_model or config.thinking_standard_model
        # ``state`` is populated lazily inside the retry loop; when the
        # turn exits before the first attempt runs (e.g. very early
        # validation error) it's still None, so ``generation_ids`` is
        # empty by definition.
        collected_gen_ids: list[str] = (
            list(state.generation_ids) if state is not None else []
        )
        _use_openrouter_reconcile = bool(
            config.openrouter_active
            and config.sdk_reconcile_openrouter_cost
            and collected_gen_ids
        )

        # CLI project dir — used by the reconcile task to sweep for
        # compaction subagents' gen-IDs. ``sdk_cwd`` is the per-session
        # CLI working directory; the CLI encodes it into the project-dir
        # name the same way ``encode_cwd_for_cli`` does, and writes
        # the main transcript + any ``subagents/`` alongside it under
        # ``~/.claude/projects/<encoded>/``. Empty when sdk_cwd isn't
        # set (shouldn't happen in practice for SDK turns).
        cli_project_dir: str | None = None
        if sdk_cwd:
            cli_project_dir = os.path.join(
                os.path.expanduser("~/.claude/projects"),
                encode_cwd_for_cli(sdk_cwd),
            )

        if _use_openrouter_reconcile:
            # Defer the single cost-and-rate-limit write to a background
            # task that queries OpenRouter's authoritative
            # ``/generation?id=`` for every round in this turn. Covers
            # all vendors:
            #
            # * Non-Anthropic (Kimi et al): the CLI's ``total_cost_usd``
            #   is computed from a static Anthropic rate table that
            #   doesn't know the model — silently over-bills by ~5x.
            #   The reconcile replaces it with OpenRouter's real bill.
            # * Anthropic via OpenRouter: the CLI's number matches
            #   Anthropic's own rates penny-for-penny in the common
            #   case, but the reconcile catches any rate change the
            #   CLI binary hasn't picked up and any OpenRouter-side
            #   divergence (cache-discount accounting, promo pricing).
            #
            # The task calls ``persist_and_record_usage`` exactly once
            # per turn — same method as the sync path, so append-only
            # cost-log + rate-limit counter update together. The sync
            # path below is skipped entirely when the reconcile fires,
            # so no double-counting. Kill-switch:
            # ``CHAT_SDK_RECONCILE_OPENROUTER_COST=false``.
            #
            # Brief window (~0.5-2s) where the rate-limit counter is
            # unaware of this turn — back-to-back turns in that window
            # see a stale counter.
            asyncio.create_task(
                record_turn_cost_from_openrouter(
                    session=session,
                    user_id=user_id,
                    model=effective_model,
                    prompt_tokens=turn_prompt_tokens,
                    completion_tokens=turn_completion_tokens,
                    cache_read_tokens=turn_cache_read_tokens,
                    cache_creation_tokens=turn_cache_creation_tokens,
                    generation_ids=collected_gen_ids,
                    cli_project_dir=cli_project_dir,
                    cli_session_id=session_id,
                    turn_start_ts=turn_start_ts,
                    fallback_cost_usd=turn_cost_usd,
                    api_key=config.api_key,
                    log_prefix=log_prefix,
                )
            )
        else:
            # Reconcile disabled, OpenRouter inactive, or subscription
            # path (no gen-IDs). Record the SDK CLI's
            # ``total_cost_usd`` synchronously: accurate for Anthropic
            # (same rate card as billing); for non-Anthropic it's the
            # rate-card estimate that ``_override_cost_for_non_anthropic``
            # caps (still 1.5-2x off vs real OpenRouter bill, but much
            # closer than the ~5x Sonnet-rate fallback).
            await persist_and_record_usage(
                session=session,
                user_id=user_id,
                prompt_tokens=turn_prompt_tokens,
                completion_tokens=turn_completion_tokens,
                cache_read_tokens=turn_cache_read_tokens,
                cache_creation_tokens=turn_cache_creation_tokens,
                log_prefix=log_prefix,
                cost_usd=turn_cost_usd,
                model=effective_model,
                # ``provider`` labels the cost-analytics row; the cost
                # value still comes from the SDK-reported number.
                # Tracks the actual upstream so the row matches reality:
                # OpenRouter when ``openrouter_active``, Anthropic
                # otherwise.
                provider=("open_router" if config.openrouter_active else "anthropic"),
            )

        # --- Persist session messages ---
        # This MUST run in finally to persist messages even when the generator
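`record_turn_cost_from_openrouter` itself is out of frame in this diff, but the per-ID lookup it performs is roughly the following httpx sketch. The response shape (`data.total_cost`) is an assumption about OpenRouter's generation endpoint, and retry/backoff handling is omitted:

```python
import httpx


async def fetch_generation_cost(gen_id: str, api_key: str) -> float | None:
    """Fetch OpenRouter's authoritative cost for one ``gen-...`` ID."""
    async with httpx.AsyncClient(timeout=10.0) as client:
        resp = await client.get(
            "https://openrouter.ai/api/v1/generation",
            params={"id": gen_id},
            headers={"Authorization": f"Bearer {api_key}"},
        )
    if resp.status_code != 200:
        return None  # caller falls back to the CLI-reported estimate
    # Assumed response shape: {"data": {"total_cost": <usd>, ...}}
    return resp.json().get("data", {}).get("total_cost")
```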
||||
@@ -366,38 +366,84 @@ class TestNormalizeModelName:
|
||||
The per-request model toggle calls _normalize_model_name with either
|
||||
``config.thinking_advanced_model`` (for 'advanced') or
|
||||
``config.thinking_standard_model`` (for 'standard'). These tests verify
|
||||
the OpenRouter/provider-prefix stripping that keeps the value compatible
|
||||
with the Claude CLI.
|
||||
the OpenRouter/direct-Anthropic split: OpenRouter routes by full
|
||||
``vendor/model`` slug, while direct-Anthropic strips the prefix and
|
||||
converts dots to hyphens.
|
||||
"""
|
||||
|
||||
def test_strips_anthropic_prefix(self):
|
||||
@pytest.fixture
|
||||
def _direct_anthropic_config(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Force ``config.openrouter_active = False`` for prefix-strip tests.
|
||||
|
||||
Pins the SDK model fields to anthropic/* so the new
|
||||
``_validate_sdk_model_vendor_compatibility`` model_validator
|
||||
permits ChatConfig construction.
|
||||
"""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
thinking_advanced_model="anthropic/claude-opus-4-7",
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
@pytest.fixture
|
||||
def _openrouter_config(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Force ``config.openrouter_active = True`` for slug-preservation tests."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
def test_strips_anthropic_prefix(self, _direct_anthropic_config):
|
||||
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
|
||||
|
||||
def test_strips_openai_prefix(self):
|
||||
assert _normalize_model_name("openai/gpt-4o") == "gpt-4o"
|
||||
def test_rejects_non_anthropic_vendor_in_direct_mode(
|
||||
self, _direct_anthropic_config
|
||||
):
|
||||
"""Direct-Anthropic mode must fail loudly on non-Anthropic vendor
|
||||
slugs — silent strip would send e.g. ``gpt-4o`` to the Anthropic
|
||||
API and produce an opaque model_not_found error."""
|
||||
with pytest.raises(ValueError, match="requires an Anthropic model"):
|
||||
_normalize_model_name("openai/gpt-4o")
|
||||
with pytest.raises(ValueError, match="requires an Anthropic model"):
|
||||
_normalize_model_name("moonshotai/kimi-k2.6")
|
||||
with pytest.raises(ValueError, match="requires an Anthropic model"):
|
||||
_normalize_model_name("google/gemini-2.5-flash")
|
||||
|
||||
def test_strips_google_prefix(self):
|
||||
assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash"
|
||||
|
||||
def test_already_normalized_unchanged(self):
|
||||
def test_already_normalized_unchanged(self, _direct_anthropic_config):
|
||||
assert (
|
||||
_normalize_model_name("claude-sonnet-4-20250514")
|
||||
== "claude-sonnet-4-20250514"
|
||||
)
|
||||
|
||||
def test_empty_string_unchanged(self):
|
||||
def test_empty_string_unchanged(self, _direct_anthropic_config):
|
||||
assert _normalize_model_name("") == ""
|
||||
|
||||
def test_opus_model_roundtrip(self):
|
||||
"""The exact string used for the 'opus' toggle strips correctly."""
|
||||
assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6"
|
||||
def test_opus_model_dot_to_hyphen(self, _direct_anthropic_config):
|
||||
"""Direct-Anthropic mode: dots in versions become hyphens."""
|
||||
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6"
|
||||
|
||||
def test_sonnet_openrouter_model(self):
|
||||
"""Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly."""
|
||||
def test_openrouter_keeps_anthropic_slug(self, _openrouter_config):
|
||||
"""OpenRouter routes by full slug — keep prefix and dots intact."""
|
||||
assert (
|
||||
_normalize_model_name("anthropic/claude-sonnet-4-6") == "claude-sonnet-4-6"
|
||||
_normalize_model_name("anthropic/claude-sonnet-4.6")
|
||||
== "anthropic/claude-sonnet-4.6"
|
||||
)
|
||||
|
||||
def test_openrouter_keeps_kimi_slug(self, _openrouter_config):
|
||||
"""Non-Anthropic vendors (Moonshot) require the prefix to route."""
|
||||
assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug)
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import base64
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -17,6 +18,7 @@ from .service import (
|
||||
_normalize_model_name,
|
||||
_prepare_file_attachments,
|
||||
_resolve_sdk_model,
|
||||
_resolve_sdk_model_for_request,
|
||||
_safe_close_sdk_client,
|
||||
)
|
||||
|
||||
@@ -355,11 +357,16 @@ class TestNormalizeModelName:
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
# Pin SDK slugs to anthropic/* so the new
|
||||
# _validate_sdk_model_vendor_compatibility allows construction.
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
thinking_advanced_model="anthropic/claude-opus-4-7",
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6"
|
||||
|
||||
def test_dots_preserved_for_openrouter(self, monkeypatch, _clean_config_env):
|
||||
def test_openrouter_keeps_full_slug(self, monkeypatch, _clean_config_env):
|
||||
"""OpenRouter routes by ``vendor/model`` slug — keep prefix and dots."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
@@ -369,7 +376,11 @@ class TestNormalizeModelName:
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4.6"
|
||||
assert (
|
||||
_normalize_model_name("anthropic/claude-opus-4.6")
|
||||
== "anthropic/claude-opus-4.6"
|
||||
)
|
||||
assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6"
|
||||
|
||||
def test_no_prefix_no_dots(self, monkeypatch, _clean_config_env):
|
||||
from backend.copilot import config as cfg_mod
|
||||
@@ -379,6 +390,8 @@ class TestNormalizeModelName:
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
thinking_advanced_model="anthropic/claude-opus-4-7",
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert (
|
||||
@@ -390,8 +403,9 @@ class TestNormalizeModelName:
|
||||
class TestResolveSdkModel:
|
||||
"""Tests for _resolve_sdk_model — model ID resolution for the SDK CLI."""
|
||||
|
||||
def test_openrouter_active_keeps_dots(self, monkeypatch, _clean_config_env):
|
||||
"""When OpenRouter is fully active, model keeps dot-separated version."""
|
||||
def test_openrouter_active_keeps_full_slug(self, monkeypatch, _clean_config_env):
|
||||
"""When OpenRouter is fully active, the canonical vendor/model slug
|
||||
is preserved so OpenRouter can route to the correct provider."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
@@ -403,7 +417,23 @@ class TestResolveSdkModel:
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "claude-opus-4.6"
|
||||
assert _resolve_sdk_model() == "anthropic/claude-opus-4.6"
|
||||
|
||||
def test_openrouter_active_kimi_slug(self, monkeypatch, _clean_config_env):
|
||||
"""Non-Anthropic models (Kimi via Moonshot) require the prefix to
|
||||
survive OpenRouter routing — strip would leave an unroutable slug."""
|
||||
from backend.copilot import config as cfg_mod
|
||||
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="moonshotai/kimi-k2.6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
assert _resolve_sdk_model() == "moonshotai/kimi-k2.6"
|
||||
|
||||
def test_openrouter_disabled_normalizes_to_hyphens(
|
||||
self, monkeypatch, _clean_config_env
|
||||
@@ -488,6 +518,213 @@ class TestResolveSdkModel:
|
||||
assert _resolve_sdk_model() == "claude-opus-4-6"
|
||||
|
||||
|
||||
class TestResolveSdkModelForRequestLdFallback:
|
||||
"""``_resolve_sdk_model_for_request`` must fail soft when the LD value
|
||||
can't be normalised for the active routing mode — flagged as MAJOR by
|
||||
CodeRabbit + HIGH by Sentry when it was a hard ValueError."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_anthropic_mode_rejects_kimi_ld_value_and_falls_back(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""LD serves ``moonshotai/kimi-k2.6`` but we're on direct-Anthropic
|
||||
(no OpenRouter key). ``_normalize_model_name`` raises; the
|
||||
resolver must log + return the config-default path instead of
|
||||
500-ing the turn."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
|
||||
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
|
||||
):
|
||||
resolved = await _resolve_sdk_model_for_request(
|
||||
model="standard", session_id="sess-abc", user_id="user-1"
|
||||
)
|
||||
|
||||
# Fallback == tier-specific config default (thinking_standard_model
|
||||
# normalised to hyphen-form for direct-Anthropic mode).
|
||||
assert resolved == "claude-sonnet-4-6"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_openrouter_mode_accepts_ld_kimi_value(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""On OpenRouter the Kimi slug is legitimate — no fallback,
|
||||
value returned as-is."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
|
||||
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
|
||||
):
|
||||
resolved = await _resolve_sdk_model_for_request(
|
||||
model="standard", session_id="sess-abc", user_id="user-1"
|
||||
)
|
||||
assert resolved == "moonshotai/kimi-k2.6"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_advanced_tier_fallback_uses_advanced_default_not_standard(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""An LD-rejected ADVANCED slug must fall back to the advanced
|
||||
config default (Opus) — not the standard default (Sonnet).
|
||||
Using ``_resolve_sdk_model()`` as the fallback silently
|
||||
downgraded the user's chosen tier. Flagged MAJOR by CodeRabbit
|
||||
+ HIGH by Sentry on the first fail-soft commit."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
thinking_advanced_model="anthropic/claude-opus-4.7",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
|
||||
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
|
||||
):
|
||||
resolved = await _resolve_sdk_model_for_request(
|
||||
model="advanced", session_id="sess-adv", user_id="user-1"
|
||||
)
|
||||
|
||||
# Direct-Anthropic normalises anthropic/claude-opus-4.7 → claude-opus-4-7
|
||||
assert resolved == "claude-opus-4-7"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_standard_ld_override_wins_over_subscription(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""Bug reported in local test: subscription mode + LD serving Kimi
|
||||
on ``copilot-thinking-standard-model`` returned ``None`` (CLI
|
||||
picked subscription default Opus), silently ignoring the LD
|
||||
override. An LD value different from the config default is an
|
||||
explicit admin decision and must win."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=True,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
|
||||
new=AsyncMock(return_value="moonshotai/kimi-k2.6"),
|
||||
):
|
||||
resolved = await _resolve_sdk_model_for_request(
|
||||
model="standard", session_id="sess-std-sub", user_id="user-1"
|
||||
)
|
||||
# Expect LD-served Kimi, NOT None (the old subscription-default bypass)
|
||||
assert resolved == "moonshotai/kimi-k2.6"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_standard_subscription_survives_trailing_whitespace_in_env(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""``_resolve_thinking_model_for_user`` strips whitespace from the LD
|
||||
side; the config tier default must be stripped too, otherwise a
|
||||
stray trailing space in ``CHAT_THINKING_STANDARD_MODEL`` makes
|
||||
``resolved == tier_default`` spuriously False and bypasses
|
||||
subscription-default mode. Sentry HIGH on L856."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6 ", # trailing spaces
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=True,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
|
||||
new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"),
|
||||
):
|
||||
resolved = await _resolve_sdk_model_for_request(
|
||||
model="standard", session_id="sess-ws", user_id="user-1"
|
||||
)
|
||||
assert resolved is None, (
|
||||
"LD value semantically matches the whitespace-padded config "
|
||||
"default — subscription mode must still win and return None"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_standard_subscription_default_honoured_when_ld_matches_config(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""When LD serves the SAME value as the config default (i.e. the
|
||||
flag is effectively unset / no override), subscription mode still
|
||||
wins and we return ``None`` so the CLI uses the subscription
|
||||
default model."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=False,
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=True,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
|
||||
new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"),
|
||||
):
|
||||
resolved = await _resolve_sdk_model_for_request(
|
||||
model="standard", session_id="sess-std-nop", user_id="user-1"
|
||||
)
|
||||
assert resolved is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_advanced_tier_consults_ld_under_subscription(
|
||||
self, monkeypatch, _clean_config_env
|
||||
):
|
||||
"""Subscription mode bypasses LD only on the standard tier —
|
||||
the advanced tier always consults LD because the user explicitly
|
||||
asked for the premium path. A subscription + advanced request
|
||||
with LD-served Opus must return Opus (not ``None``)."""
|
||||
cfg = cfg_mod.ChatConfig(
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
thinking_advanced_model="anthropic/claude-opus-4.7",
|
||||
claude_agent_model=None,
|
||||
use_openrouter=True,
|
||||
api_key="or-key",
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
use_claude_code_subscription=True,
|
||||
)
|
||||
monkeypatch.setattr("backend.copilot.sdk.service.config", cfg)
|
||||
|
||||
with patch(
|
||||
"backend.copilot.sdk.service._resolve_thinking_model_for_user",
|
||||
new=AsyncMock(return_value="anthropic/claude-opus-4.7"),
|
||||
):
|
||||
resolved = await _resolve_sdk_model_for_request(
|
||||
model="advanced", session_id="sess-adv-sub", user_id="user-1"
|
||||
)
|
||||
assert resolved == "anthropic/claude-opus-4.7"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_sdk_disconnect_error — classify client disconnect cleanup errors
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -651,6 +888,8 @@ class TestSystemPromptPreset:
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
thinking_advanced_model="anthropic/claude-opus-4-7",
|
||||
)
|
||||
assert cfg.claude_agent_cross_user_prompt_cache is True
|
||||
|
||||
@@ -662,6 +901,8 @@ class TestSystemPromptPreset:
|
||||
api_key=None,
|
||||
base_url=None,
|
||||
use_claude_code_subscription=False,
|
||||
thinking_standard_model="anthropic/claude-sonnet-4-6",
|
||||
thinking_advanced_model="anthropic/claude-opus-4-7",
|
||||
)
|
||||
assert cfg.claude_agent_cross_user_prompt_cache is False
|
||||
|
||||
@@ -674,3 +915,225 @@ class TestIdleTimeoutConstant:
|
||||
|
||||
def test_idle_timeout_is_10_min(self):
|
||||
assert _IDLE_TIMEOUT_SECONDS == 10 * 60
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
# _RetryState.observed_model — Moonshot cost-override input
# ---------------------------------------------------------------------------


class TestRetryStateObservedModel:
    """Regression guards for the ``observed_model`` field added to
    ``_RetryState``. The Moonshot cost override reads this — when a
    fallback model activates mid-attempt, the requested primary
    (``state.options.model``) no longer matches what actually ran."""

    def _make_state(self, *, options_model: str | None = "primary/model"):
        """Build a minimally-valid ``_RetryState``. All the heavy
        collaborators are ``MagicMock()`` — the field we care about is
        a plain Optional[str], so the surrounding scaffolding just needs
        to let the dataclass instantiate."""
        from .service import _RetryState, _TokenUsage

        options = MagicMock()
        options.model = options_model
        return _RetryState(
            options=options,
            query_message="",
            was_compacted=False,
            use_resume=False,
            resume_file=None,
            transcript_msg_count=0,
            adapter=MagicMock(),
            transcript_builder=MagicMock(),
            usage=_TokenUsage(),
        )

    def test_default_is_none(self):
        state = self._make_state()
        assert state.observed_model is None

    def test_assigned_from_assistant_message_model(self):
        """Simulates the population path in ``_run_stream_attempt``:
        ``observed`` is pulled off the ``AssistantMessage.model`` attr
        and assigned onto ``state.observed_model`` when it's a non-empty
        string."""
        state = self._make_state()
        # Simulates the inline assignment the generator does on each
        # AssistantMessage — a non-empty string lands on state.
        assistant_like = SimpleNamespace(model="anthropic/claude-sonnet-4-6")
        observed = getattr(assistant_like, "model", None)
        if isinstance(observed, str) and observed:
            state.observed_model = observed
        assert state.observed_model == "anthropic/claude-sonnet-4-6"

    def test_empty_string_model_is_not_assigned(self):
        """Guard against overwriting a real observed value with an
        empty-string model (the generator's ``and observed`` check)."""
        state = self._make_state()
        state.observed_model = "moonshotai/kimi-k2.6"  # seeded from a prior msg
        assistant_like = SimpleNamespace(model="")
        observed = getattr(assistant_like, "model", None)
        if isinstance(observed, str) and observed:
            state.observed_model = observed
        assert state.observed_model == "moonshotai/kimi-k2.6"

    def test_missing_model_attr_leaves_observed_untouched(self):
        state = self._make_state()
        state.observed_model = "moonshotai/kimi-k2.6"
        # AssistantMessage may not carry ``.model`` on older SDK releases.
        assistant_like = SimpleNamespace()  # no ``.model`` attr
        observed = getattr(assistant_like, "model", None)
        if isinstance(observed, str) and observed:
            state.observed_model = observed
        assert state.observed_model == "moonshotai/kimi-k2.6"


# ---------------------------------------------------------------------------
# Moonshot cost-override gate — decision logic at the call site
# ---------------------------------------------------------------------------


class TestMoonshotCostOverrideGate:
    """Regression guards for the decision logic in
    ``_run_stream_attempt`` that picks between the CLI-reported cost
    and the Moonshot rate-card override. The code:

        active_model = state.observed_model or getattr(state.options, "model", None)
        if _is_moonshot_model(active_model):
            state.usage.cost_usd = _override_cost_for_moonshot(...)
        else:
            state.usage.cost_usd = sdk_msg.total_cost_usd

    is critical-path billing logic — make sure observed_model wins over
    the requested primary, and Anthropic turns pass through untouched."""

    def _decide_cost(
        self,
        *,
        observed_model: str | None,
        options_model: str | None,
        sdk_reported_usd: float,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
    ) -> float:
        """Mirror of the real decision block — lets us assert the gate
        without constructing the whole 1000-line generator."""
        from .service import _is_moonshot_model, _override_cost_for_moonshot

        active_model = observed_model or options_model
        if _is_moonshot_model(active_model):
            return _override_cost_for_moonshot(
                model=active_model,
                sdk_reported_usd=sdk_reported_usd,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                cache_read_tokens=0,
                cache_creation_tokens=0,
            )
        return sdk_reported_usd

    def test_anthropic_turn_passes_sdk_cost_through(self):
        """Anthropic — the CLI's pricing table is authoritative, so
        ``state.usage.cost_usd`` is set to ``sdk_msg.total_cost_usd``
        unchanged."""
        cost = self._decide_cost(
            observed_model="anthropic/claude-sonnet-4-6",
            options_model="anthropic/claude-sonnet-4-6",
            sdk_reported_usd=0.123,
        )
        assert cost == 0.123

    def test_moonshot_turn_uses_rate_card_override(self):
        """Moonshot — the CLI would silently bill at Sonnet rates, so
        the override recomputes from the Moonshot rate card."""
        cost = self._decide_cost(
            observed_model="moonshotai/kimi-k2.6",
            options_model="moonshotai/kimi-k2.6",
            sdk_reported_usd=0.089862,  # CLI's Sonnet-priced estimate.
            prompt_tokens=29564,
            completion_tokens=78,
        )
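        # Rate card used by the expected value below (USD per 1M tokens):
        # 0.60 prompt, 2.80 completion — the figures this test assumes
        # the Moonshot override applies; taken from the arithmetic here,
        # not from any published price list.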
        expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000
        assert cost == pytest.approx(expected, rel=1e-9)
        # Sanity: ~5x cheaper than the CLI's Sonnet-priced number.
        assert cost < 0.089862 / 4

    def test_observed_model_wins_over_options_primary(self):
        """The whole point of ``observed_model``: a Moonshot-primary
        request that fell back to Anthropic must NOT get Moonshot
        pricing applied. The gate follows the observed model, not the
        requested primary."""
        cost = self._decide_cost(
            observed_model="anthropic/claude-sonnet-4-6",
            options_model="moonshotai/kimi-k2.6",  # what we ASKED for
            sdk_reported_usd=0.123,
            prompt_tokens=1000,
            completion_tokens=100,
        )
        # Observed == Anthropic → CLI-reported cost passes through unchanged.
        assert cost == 0.123

    def test_anthropic_to_moonshot_fallback_uses_override(self):
        """The inverse: an Anthropic-primary request that fell back to
        Moonshot must get the Moonshot override applied — the CLI is
        still billing at Sonnet rates for the fallback response."""
        cost = self._decide_cost(
            observed_model="moonshotai/kimi-k2.6",
            options_model="anthropic/claude-sonnet-4-6",
            sdk_reported_usd=0.089862,
            prompt_tokens=29564,
            completion_tokens=78,
        )
        expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000
        assert cost == pytest.approx(expected, rel=1e-9)

    def test_no_observed_falls_back_to_options_model(self):
        """First AssistantMessage hasn't arrived yet (or the SDK didn't
        emit ``.model``) — the gate falls back to the requested primary."""
        cost = self._decide_cost(
            observed_model=None,
            options_model="moonshotai/kimi-k2.6",
            sdk_reported_usd=0.089862,
            prompt_tokens=100,
            completion_tokens=10,
        )
        expected = (100 * 0.60 + 10 * 2.80) / 1_000_000
        assert cost == pytest.approx(expected, rel=1e-9)

    def test_both_none_passes_sdk_cost_through(self):
        """Subscription mode — ``options.model`` may be None and no
        AssistantMessage has arrived yet. ``None`` is not a Moonshot
        slug so the SDK number lands unchanged."""
        cost = self._decide_cost(
            observed_model=None,
            options_model=None,
            sdk_reported_usd=0.05,
        )
        assert cost == 0.05


# ---------------------------------------------------------------------------
# Moonshot helper re-exports — keep imports stable for call-site code
# ---------------------------------------------------------------------------


class TestMoonshotHelperReexports:
    """``sdk/service.py`` imports the Moonshot helpers under local
    aliases (``_is_moonshot_model``, ``_override_cost_for_moonshot``).
    Regression guard so a refactor doesn't silently break the import
    path the hot-loop code relies on."""

    def test_is_moonshot_model_aliased(self):
        from backend.copilot.moonshot import is_moonshot_model as canonical

        from .service import _is_moonshot_model

        assert _is_moonshot_model is canonical

    def test_override_cost_for_moonshot_aliased(self):
        from backend.copilot.moonshot import override_cost_usd as canonical

        from .service import _override_cost_for_moonshot

        assert _override_cost_for_moonshot is canonical
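
    # Illustrative (assumed) behaviour of the canonical helper — slug
    # prefix matching is an assumption, not asserted by these tests:
    #     is_moonshot_model("moonshotai/kimi-k2.6")        -> True
    #     is_moonshot_model("anthropic/claude-sonnet-4-6") -> False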

@@ -10,6 +10,7 @@ import json
import logging
import os
import uuid
+from collections.abc import Iterable
from contextvars import ContextVar
from typing import TYPE_CHECKING, Any

@@ -33,7 +34,7 @@ from backend.copilot.sdk.file_ref import (
    expand_file_refs_in_args,
    read_file_bytes,
)
-from backend.copilot.tools import TOOL_REGISTRY
+from backend.copilot.tools import TOOL_REGISTRY, ToolGroup, tool_names_in_groups
from backend.copilot.tools.base import BaseTool
from backend.util.truncate import truncate

@@ -853,9 +854,29 @@ DANGEROUS_PATTERNS = [
    r"subprocess",
]

# Platform-tool names whose MCP wrappers must NOT be exposed to SDK mode.
# Baseline ships an MCP ``TodoWrite`` for model-flexibility parity; SDK mode
# keeps using the CLI-native built-in listed in ``_SDK_BUILTIN_ALWAYS`` so
# there is no double exposure. Public (no leading underscore) so a future
# refactor renaming it is visible at both call sites —
# ``permissions.apply_tool_permissions`` maps short tool names back to the
# CLI built-in form for SDK mode.
BASELINE_ONLY_MCP_TOOLS: frozenset[str] = frozenset({"TodoWrite"})
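# Illustrative effect: ``_registry_mcp_tools()`` below emits an
# f"{MCP_TOOL_PREFIX}{name}" entry for every registry tool except the
# names in this set (and any caller-supplied ``hidden`` names).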


def _registry_mcp_tools(*, hidden: frozenset[str] = frozenset()) -> list[str]:
    return [
        f"{MCP_TOOL_PREFIX}{name}"
        for name in TOOL_REGISTRY.keys()
        if name not in BASELINE_ONLY_MCP_TOOLS and name not in hidden
    ]


# Static tool name list for the non-E2B case (backward compatibility).
# Includes all capability-gated tools; per-user filtering happens in
# ``get_copilot_tool_names`` when the caller passes ``disabled_groups``.
COPILOT_TOOL_NAMES = [
-    *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
+    *_registry_mcp_tools(),
    f"{MCP_TOOL_PREFIX}{WRITE_TOOL_NAME}",
    f"{MCP_TOOL_PREFIX}{READ_TOOL_NAME}",
    f"{MCP_TOOL_PREFIX}{EDIT_TOOL_NAME}",
@@ -864,20 +885,31 @@ COPILOT_TOOL_NAMES = [
]


-def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]:
+def get_copilot_tool_names(
+    *,
+    use_e2b: bool = False,
+    disabled_groups: Iterable[ToolGroup] = (),
+) -> list[str]:
    """Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`.

    When *use_e2b* is True the SDK built-in file tools are replaced by MCP
-    equivalents that route to the E2B sandbox.
+    equivalents that route to the E2B sandbox. Tools belonging to any of
+    *disabled_groups* are filtered out — see ``ToolGroup`` / ``TOOL_GROUPS``
+    in ``backend.copilot.tools`` for the full list.
    """
+    hidden_short_names = tool_names_in_groups(disabled_groups)
+    hidden_mcp_names = {f"{MCP_TOOL_PREFIX}{n}" for n in hidden_short_names}

    if not use_e2b:
-        return list(COPILOT_TOOL_NAMES)
+        if not hidden_mcp_names:
+            return list(COPILOT_TOOL_NAMES)
+        return [n for n in COPILOT_TOOL_NAMES if n not in hidden_mcp_names]

    # In E2B mode, Write/Edit are NOT registered (E2B uses write_file/edit_file
    # from E2B_FILE_TOOLS instead), so don't include them here.
    # _READ_TOOL_NAME is still needed for SDK tool-result reads.
    return [
-        *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()],
+        *_registry_mcp_tools(hidden=hidden_short_names),
        f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}",
        *[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES],
        *_SDK_BUILTIN_ALWAYS,

@@ -1249,7 +1249,14 @@ class TestStripStaleThinkingBlocks:
        new_asst = self._asst_entry(
            "msg_new",
            [
-                {"type": "thinking", "thinking": "latest thoughts"},
+                # Anthropic-shape thinking block (has signature) — preserved
+                # on the last turn. Signature-less variant covered by
+                # ``test_strips_signatureless_last_turn_thinking``.
+                {
+                    "type": "thinking",
+                    "thinking": "latest thoughts",
+                    "signature": "anthropic-sig",
+                },
                {"type": "text", "text": "world"},
            ],
            uuid="a2",
@@ -1271,11 +1278,16 @@ class TestStripStaleThinkingBlocks:
        assert new_content[1]["type"] == "text"

    def test_preserves_last_assistant_thinking(self) -> None:
-        """The last assistant entry's thinking blocks must be preserved."""
+        """The last assistant entry's thinking blocks must be preserved
+        when they carry an Anthropic ``signature``. Signature-less
+        blocks (e.g. from Kimi K2.6 via OpenRouter) are stripped to
+        prevent ``Invalid `signature` in `thinking` block`` errors on
+        a subsequent Anthropic-model turn — covered by
+        ``test_strips_signatureless_last_turn_thinking``."""
        entry = self._asst_entry(
            "msg_only",
            [
-                {"type": "thinking", "thinking": "must keep"},
+                {"type": "thinking", "thinking": "must keep", "signature": "sig"},
                {"type": "text", "text": "response"},
            ],
        )
@@ -1284,6 +1296,47 @@ class TestStripStaleThinkingBlocks:
        lines = [json.loads(ln) for ln in result.strip().split("\n")]
        assert len(lines[0]["message"]["content"]) == 2

+    def test_strips_signatureless_last_turn_thinking(self) -> None:
+        """Cross-model fix (PR #12878): a signature-less thinking block
+        on the last assistant turn must be stripped before a subsequent
+        Anthropic-model dispatch tries to replay it."""
+        entry = self._asst_entry(
+            "msg_kimi",
+            [
+                # No signature → non-Anthropic provider
+                {"type": "thinking", "thinking": "kimi reasoning"},
+                {"type": "text", "text": "answer"},
+            ],
+        )
+        content = _make_jsonl(entry)
+        result = strip_stale_thinking_blocks(content)
+        lines = [json.loads(ln) for ln in result.strip().split("\n")]
+        types = [b["type"] for b in lines[0]["message"]["content"]]
+        assert "thinking" not in types
+        assert "text" in types
+
+    def test_preserves_redacted_thinking_on_last_turn(self) -> None:
+        """``redacted_thinking`` blocks are signature-less by design
+        (encrypted ``data`` field instead). Stripping them on the last
+        turn would violate Anthropic's value-identity requirement for
+        multi-turn replay. The signature rule only applies to plain
+        ``thinking`` blocks."""
+        entry = self._asst_entry(
+            "msg_anthropic",
+            [
+                # Anthropic-emitted redacted_thinking: has ``data``,
+                # never has ``signature``.
+                {"type": "redacted_thinking", "data": "encrypted_blob"},
+                {"type": "text", "text": "response"},
+            ],
+        )
+        content = _make_jsonl(entry)
+        result = strip_stale_thinking_blocks(content)
+        lines = [json.loads(ln) for ln in result.strip().split("\n")]
+        types = [b["type"] for b in lines[0]["message"]["content"]]
+        assert "redacted_thinking" in types
+        assert "text" in types
+
    def test_no_assistant_entries_returns_unchanged(self) -> None:
        """Transcripts with only user entries should pass through unchanged."""
        user = self._user_entry("hello")
@@ -1294,12 +1347,13 @@ class TestStripStaleThinkingBlocks:
        assert strip_stale_thinking_blocks("") == ""

    def test_multiple_turns_strips_all_but_last(self) -> None:
-        """With 3 assistant turns, only the last keeps thinking blocks."""
+        """With 3 assistant turns, only the last keeps thinking blocks
+        (and only when those blocks carry an Anthropic ``signature``)."""
        entries = [
            self._asst_entry(
                "msg_1",
                [
-                    {"type": "thinking", "thinking": "t1"},
+                    {"type": "thinking", "thinking": "t1", "signature": "s1"},
                    {"type": "text", "text": "a1"},
                ],
                uuid="a1",
@@ -1308,7 +1362,7 @@ class TestStripStaleThinkingBlocks:
            self._asst_entry(
                "msg_2",
                [
-                    {"type": "thinking", "thinking": "t2"},
+                    {"type": "thinking", "thinking": "t2", "signature": "s2"},
                    {"type": "text", "text": "a2"},
                ],
                uuid="a2",
@@ -1318,7 +1372,7 @@ class TestStripStaleThinkingBlocks:
            self._asst_entry(
                "msg_3",
                [
-                    {"type": "thinking", "thinking": "t3"},
+                    {"type": "thinking", "thinking": "t3", "signature": "s3"},
                    {"type": "text", "text": "a3"},
                ],
                uuid="a3",
@@ -1339,16 +1393,18 @@ class TestStripStaleThinkingBlocks:
        assert lines[4]["message"]["content"][0]["type"] == "thinking"

    def test_same_msg_id_multi_entry_turn(self) -> None:
-        """Multiple entries sharing the same message.id (same turn) are preserved."""
+        """Multiple entries sharing the same message.id (same turn) are
+        preserved when their thinking blocks carry an Anthropic
+        ``signature``."""
        entries = [
            self._asst_entry(
                "msg_old",
-                [{"type": "thinking", "thinking": "old"}],
+                [{"type": "thinking", "thinking": "old", "signature": "old_sig"}],
                uuid="a1",
            ),
            self._asst_entry(
                "msg_last",
-                [{"type": "thinking", "thinking": "t_part1"}],
+                [{"type": "thinking", "thinking": "t_part1", "signature": "p1_sig"}],
                uuid="a2",
                parent="a1",
            ),

@@ -17,6 +17,7 @@ from langfuse import get_client
from langfuse.openai import (
    AsyncOpenAI as LangfuseAsyncOpenAI,  # pyright: ignore[reportPrivateImportUsage]
)
+from openai.types.chat import ChatCompletion

from backend.data.db_accessors import chat_db, understanding_db
from backend.data.understanding import (
@@ -34,6 +35,7 @@ from .model import (
    update_session_title,
    upsert_chat_session,
)
+from .token_tracking import persist_and_record_usage

logger = logging.getLogger(__name__)

@@ -495,20 +497,31 @@ async def _generate_session_title(
    message: str,
    user_id: str | None = None,
    session_id: str | None = None,
-) -> str | None:
+) -> tuple[str | None, ChatCompletion | None]:
    """Generate a concise title for a chat session based on the first message.

+    Returns ``(title, response)``. The caller is responsible for
+    persisting the title AND recording the title call's cost — keeping
+    them as separate concerns in the caller means a cost-tracking hiccup
+    doesn't lose the title, and a title-persist failure still records
+    the cost (we paid for the LLM call either way).
+
    Args:
        message: The first user message in the session
        user_id: User ID for OpenRouter tracing (optional)
        session_id: Session ID for OpenRouter tracing (optional)

    Returns:
-        A short title (3-6 words) or None if generation fails
+        ``(title, response)`` on success; ``(None, None)`` if the LLM
+        call raised. ``response`` is returned even when ``title`` is
+        empty so the caller can still record the (paid-for) cost.
    """
    try:
-        # Build extra_body for OpenRouter tracing and PostHog analytics
-        extra_body: dict[str, Any] = {}
+        # Build extra_body for OpenRouter tracing and PostHog analytics.
+        # ``usage: {"include": True}`` asks OR to embed the real billed
+        # cost into the final usage chunk — matches the baseline path's
+        # ``_OPENROUTER_INCLUDE_USAGE_COST`` pattern, same read path.
+        extra_body: dict[str, Any] = {"usage": {"include": True}}
        if user_id:
            extra_body["user"] = user_id[:128]  # OpenRouter limit
            extra_body["posthogDistinctId"] = user_id
@@ -534,18 +547,113 @@ async def _generate_session_title(
            max_tokens=20,
            extra_body=extra_body,
        )
-        title = response.choices[0].message.content
-        if title:
-            # Clean up the title
-            title = title.strip().strip("\"'")
-            # Limit length
-            if len(title) > 50:
-                title = title[:47] + "..."
-            return title
-        return None
    except Exception as e:
        logger.warning(f"Failed to generate session title: {e}")
-        return None
+        return None, None
+
+    # Robust against an empty ``choices`` list OR a choice whose
+    # ``message`` is missing ``content`` (shouldn't happen on the OpenAI
+    # SDK typing, but belt-and-suspenders — the background task would
+    # otherwise die on ``IndexError`` and lose the (paid-for) cost
+    # recording we're about to do below).
+    title: str | None = None
+    if response.choices:
+        msg = response.choices[0].message
+        title = msg.content if msg is not None else None
+    if title:
+        title = title.strip().strip("\"'")
+        if len(title) > 50:
+            title = title[:47] + "..."
+    return title, response


def _title_usage_from_response(
    response: ChatCompletion,
) -> tuple[int, int, float | None]:
    """Extract ``(prompt_tokens, completion_tokens, cost_usd)`` from a
    title-generation chat-completion response.

    Returns zeros / ``None`` for missing fields — the OpenAI SDK's
    ``CompletionUsage`` doesn't declare OpenRouter's ``cost`` extension,
    so we read it off ``model_extra`` (pydantic v2 extras container).
    Absent for non-OR routes; returned as ``None`` in that case.
    """
    usage = response.usage
    if usage is None:
        return 0, 0, None
    prompt_tokens = usage.prompt_tokens or 0
    completion_tokens = usage.completion_tokens or 0
    extras = usage.model_extra or {}
    cost_raw = extras.get("cost") if isinstance(extras, dict) else None
    if isinstance(cost_raw, (int, float)):
        cost_usd: float | None = float(cost_raw)
    else:
        cost_usd = None
    return prompt_tokens, completion_tokens, cost_usd
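
# Illustrative OpenRouter usage payload ``_title_usage_from_response``
# parses (the ``cost`` key is OR's extension and lands in ``model_extra``):
#     {"prompt_tokens": 12, "completion_tokens": 3,
#      "total_tokens": 15, "cost": 0.000123}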


async def _record_title_generation_cost(
    *,
    response: ChatCompletion,
    user_id: str | None,
    session_id: str | None,
) -> None:
    """Persist the title LLM call's cost to ``PlatformCostLog``.

    Title generation runs in a background task per-session — low cost
    (~$0.0001 per title) but 100% of sessions pay it. Without this the
    admin dashboard under-reports total provider spend by the aggregate
    of those calls. Separate ``block_name="copilot:title"`` so the row
    is clearly distinguishable from the turn's main ``copilot:SDK`` /
    ``copilot:baseline`` attributions.

    Invariants enforced by the caller:
      * ``response`` is a completed ``ChatCompletion`` (the create call
        didn't raise) — so ``response.usage`` shape is SDK-contractual.
      * Exceptions are NOT suppressed — the caller runs this AFTER
        title persistence so a persist failure here doesn't lose the
        title, and a real DB / Prisma outage surfaces in the caller's
        single background-task warning handler.
    """
    prompt_tokens, completion_tokens, cost_usd = _title_usage_from_response(response)

    # Nothing meaningful to record — skip the DB roundtrip entirely
    # rather than writing a zero-valued row. Covers the non-OR route
    # (no ``usage.cost`` field) and the degenerate zero-tokens case.
    if cost_usd is None and prompt_tokens == 0 and completion_tokens == 0:
        return

    # Provider label is derived from the configured ``base_url`` (title
    # LLM uses the shared copilot OpenAI client whose base URL mirrors
    # ``ChatConfig.base_url``). This lets a deployment that points
    # title generation at a non-OR endpoint still get the correct
    # ``provider`` on the cost-log row.
    provider = (
        "open_router"
        if (config.base_url and "openrouter.ai" in config.base_url)
        else "openai"
    )

    # Intentionally pass ``session=None``. ``persist_and_record_usage``
    # would otherwise append a ``Usage`` entry to the live session
    # object, but this background task holds no reference to the
    # request-scoped session — we'd have to ``get_chat_session`` +
    # ``upsert_chat_session`` round-trip the mutation back, and the
    # turn's main ``persist_and_record_usage`` already owns the session
    # usage-list mirror for the originating turn. Title cost is
    # recorded into ``PlatformCostLog`` (admin dashboard) and the
    # microdollar rate-limit counter — those are the two places that
    # actually matter for this call.
    await persist_and_record_usage(
        session=None,
        user_id=user_id,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        log_prefix="[title]",
        cost_usd=cost_usd,
        model=config.title_model,
        provider=provider,
    )


async def _update_title_async(
@@ -553,15 +661,29 @@ async def _update_title_async(
) -> None:
    """Generate and persist a session title in the background.

-    Shared by both the SDK and baseline execution paths.
+    Shared by both the SDK and baseline execution paths. Title
+    persistence and cost recording are run as independent best-effort
+    steps — a failure in one does not cancel the other, so a flaky
+    Prisma call on cost recording never costs us the generated title.
    """
-    try:
-        title = await _generate_session_title(message, user_id, session_id)
-        if title and user_id:
+    title, response = await _generate_session_title(message, user_id, session_id)
+
+    if title and user_id:
+        try:
            await update_session_title(session_id, user_id, title, only_if_empty=True)
            logger.debug("Generated title for session %s", session_id)
-    except Exception as e:
-        logger.warning("Failed to update session title for %s: %s", session_id, e)
+        except Exception as e:
+            logger.warning("Failed to persist session title for %s: %s", session_id, e)
+
+    if response is not None:
+        try:
+            await _record_title_generation_cost(
+                response=response, user_id=user_id, session_id=session_id
+            )
+        except Exception as e:
+            logger.warning(
+                "Failed to record title generation cost for %s: %s", session_id, e
+            )


async def assign_user_to_session(

541  autogpt_platform/backend/backend/copilot/service_unit_test.py  Normal file
@@ -0,0 +1,541 @@
"""Unit tests for title-generation cost tracking helpers.

Covers the new code added in PR #12882:
  * ``_title_usage_from_response`` — shape-robust OR ``usage.cost`` extraction
  * ``_record_title_generation_cost`` — provider-label + zero-tokens gate
  * ``_update_title_async`` — independent title / cost persistence try blocks
  * ``_generate_session_title`` — tuple return + robustness against empty choices

Mocks ``persist_and_record_usage`` / ``update_session_title`` at the boundary
where the code under test imports them (``backend.copilot.service.*``).
"""

from __future__ import annotations

from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.completion_usage import CompletionUsage

from backend.copilot.service import (
    _generate_session_title,
    _record_title_generation_cost,
    _title_usage_from_response,
    _update_title_async,
)


def _build_completion(
    *,
    content: str | None = "Hello Title",
    usage: CompletionUsage | None = None,
    choices: list[Choice] | None = None,
) -> ChatCompletion:
    if choices is None:
        msg = ChatCompletionMessage(role="assistant", content=content)
        choices = [Choice(index=0, message=msg, finish_reason="stop")]
    return ChatCompletion(
        id="cmpl-1",
        choices=choices,
        created=0,
        model="anthropic/claude-haiku",
        object="chat.completion",
        usage=usage,
    )


def _usage_with_cost(cost: object | None) -> CompletionUsage:
    """Return a CompletionUsage whose ``model_extra`` carries ``cost``.

    Uses ``model_validate`` so OpenRouter's ``cost`` extension lands in
    the pydantic ``model_extra`` dict the helper reads from.
    """
    payload: dict[str, object] = {
        "prompt_tokens": 12,
        "completion_tokens": 3,
        "total_tokens": 15,
    }
    if cost is not None:
        payload["cost"] = cost
    return CompletionUsage.model_validate(payload)


class TestTitleUsageFromResponse:
    """``_title_usage_from_response`` returns sensible zeros/Nones when
    optional fields are absent or of unexpected shape."""

    def test_usage_none_returns_all_zero(self):
        resp = _build_completion(usage=None)
        prompt, completion, cost = _title_usage_from_response(resp)
        assert prompt == 0
        assert completion == 0
        assert cost is None

    def test_missing_cost_field_returns_none_cost(self):
        resp = _build_completion(usage=_usage_with_cost(None))
        prompt, completion, cost = _title_usage_from_response(resp)
        assert prompt == 12
        assert completion == 3
        assert cost is None

    def test_cost_as_int_is_coerced_to_float(self):
        resp = _build_completion(usage=_usage_with_cost(2))
        _, _, cost = _title_usage_from_response(resp)
        assert isinstance(cost, float)
        assert cost == 2.0

    def test_cost_as_float_is_returned_as_is(self):
        resp = _build_completion(usage=_usage_with_cost(0.000123))
        _, _, cost = _title_usage_from_response(resp)
        assert cost == pytest.approx(0.000123)

    def test_cost_as_non_numeric_string_returns_none(self):
        resp = _build_completion(usage=_usage_with_cost("free"))
        _, _, cost = _title_usage_from_response(resp)
        assert cost is None

    def test_empty_model_extra_returns_none_cost(self):
        # ``model_extra`` is empty for non-OR routes where pydantic didn't
        # receive any extras — prompt/completion still flow through.
        usage = CompletionUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7)
        resp = _build_completion(usage=usage)
        prompt, completion, cost = _title_usage_from_response(resp)
        assert (prompt, completion, cost) == (5, 2, None)

    def test_zero_prompt_and_completion_tokens(self):
        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        resp = _build_completion(usage=usage)
        prompt, completion, cost = _title_usage_from_response(resp)
        assert (prompt, completion, cost) == (0, 0, None)


class TestRecordTitleGenerationCost:
    """``_record_title_generation_cost`` persists cost + picks the right
    provider label and skips the DB roundtrip when nothing's meaningful
    to record."""

    @pytest.mark.asyncio
    async def test_openrouter_base_url_uses_open_router_provider(self):
        resp = _build_completion(usage=_usage_with_cost(0.0002))
        persist = AsyncMock(return_value=0)
        with (
            patch(
                "backend.copilot.service.persist_and_record_usage",
                new=persist,
            ),
            patch(
                "backend.copilot.service.config",
                MagicMock(
                    base_url="https://openrouter.ai/api/v1",
                    title_model="anthropic/claude-haiku",
                ),
            ),
        ):
            await _record_title_generation_cost(
                response=resp, user_id="u", session_id="s"
            )
        persist.assert_awaited_once()
        kwargs = persist.await_args.kwargs
        assert kwargs["provider"] == "open_router"
        assert kwargs["model"] == "anthropic/claude-haiku"
        assert kwargs["prompt_tokens"] == 12
        assert kwargs["completion_tokens"] == 3
        assert kwargs["cost_usd"] == pytest.approx(0.0002)
        assert kwargs["log_prefix"] == "[title]"
        assert kwargs["session"] is None

    @pytest.mark.asyncio
    async def test_non_openrouter_base_url_uses_openai_provider(self):
        resp = _build_completion(usage=_usage_with_cost(0.0002))
        persist = AsyncMock(return_value=0)
        with (
            patch(
                "backend.copilot.service.persist_and_record_usage",
                new=persist,
            ),
            patch(
                "backend.copilot.service.config",
                MagicMock(
                    base_url="https://api.openai.com/v1",
                    title_model="gpt-4o-mini",
                ),
            ),
        ):
            await _record_title_generation_cost(
                response=resp, user_id="u", session_id="s"
            )
        persist.assert_awaited_once()
        assert persist.await_args.kwargs["provider"] == "openai"

    @pytest.mark.asyncio
    async def test_empty_base_url_uses_openai_provider(self):
        resp = _build_completion(usage=_usage_with_cost(0.0001))
        persist = AsyncMock(return_value=0)
        with (
            patch(
                "backend.copilot.service.persist_and_record_usage",
                new=persist,
            ),
            patch(
                "backend.copilot.service.config",
                MagicMock(base_url=None, title_model="gpt-4o-mini"),
            ),
        ):
            await _record_title_generation_cost(
                response=resp, user_id=None, session_id=None
            )
        persist.assert_awaited_once()
        assert persist.await_args.kwargs["provider"] == "openai"

    @pytest.mark.asyncio
    async def test_zero_tokens_zero_cost_skips_persist(self):
        """No cost, no tokens — the early return avoids a worthless
        ``PlatformCostLog`` row."""
        usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0)
        resp = _build_completion(usage=usage)
        persist = AsyncMock(return_value=0)
        with (
            patch(
                "backend.copilot.service.persist_and_record_usage",
                new=persist,
            ),
            patch(
                "backend.copilot.service.config",
                MagicMock(
                    base_url="https://openrouter.ai/api/v1",
                    title_model="x",
                ),
            ),
        ):
            await _record_title_generation_cost(
                response=resp, user_id="u", session_id="s"
            )
        persist.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_usage_none_skips_persist(self):
        """``usage`` absent on the response == provider didn't report —
        still short-circuits to avoid writing a zero-valued row."""
        resp = _build_completion(usage=None)
        persist = AsyncMock(return_value=0)
        with (
            patch(
                "backend.copilot.service.persist_and_record_usage",
                new=persist,
            ),
            patch(
                "backend.copilot.service.config",
                MagicMock(
                    base_url="https://openrouter.ai/api/v1",
                    title_model="x",
                ),
            ),
        ):
            await _record_title_generation_cost(
                response=resp, user_id="u", session_id="s"
            )
        persist.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_tokens_without_cost_still_records(self):
        """Tokens present but ``cost`` missing (non-OR route) still
        records a row so token counts are captured — ``cost_usd=None``
        is accepted by ``persist_and_record_usage``."""
        usage = CompletionUsage(prompt_tokens=8, completion_tokens=2, total_tokens=10)
        resp = _build_completion(usage=usage)
        persist = AsyncMock(return_value=0)
        with (
            patch(
                "backend.copilot.service.persist_and_record_usage",
                new=persist,
            ),
            patch(
                "backend.copilot.service.config",
                MagicMock(base_url=None, title_model="m"),
            ),
        ):
            await _record_title_generation_cost(
                response=resp, user_id="u", session_id="s"
            )
        persist.assert_awaited_once()
        assert persist.await_args.kwargs["cost_usd"] is None
        assert persist.await_args.kwargs["prompt_tokens"] == 8


class TestUpdateTitleAsync:
    """``_update_title_async`` runs title persistence and cost recording
    as independent best-effort steps — a failure in one does NOT
    cancel the other."""

    @pytest.mark.asyncio
    async def test_title_success_cost_success(self):
        resp = _build_completion(usage=_usage_with_cost(0.0001))
        gen = AsyncMock(return_value=("My Title", resp))
        update = AsyncMock(return_value=True)
        record = AsyncMock()
        with (
            patch(
                "backend.copilot.service._generate_session_title",
                new=gen,
            ),
            patch(
                "backend.copilot.service.update_session_title",
                new=update,
            ),
            patch(
                "backend.copilot.service._record_title_generation_cost",
                new=record,
            ),
        ):
            await _update_title_async("sess-1", "hello", user_id="u1")

        update.assert_awaited_once_with("sess-1", "u1", "My Title", only_if_empty=True)
        record.assert_awaited_once()
        assert record.await_args.kwargs["response"] is resp
        assert record.await_args.kwargs["user_id"] == "u1"
        assert record.await_args.kwargs["session_id"] == "sess-1"

    @pytest.mark.asyncio
    async def test_title_persist_fails_but_cost_still_recorded(self):
        resp = _build_completion(usage=_usage_with_cost(0.0001))
        gen = AsyncMock(return_value=("Title", resp))
        update = AsyncMock(side_effect=RuntimeError("prisma boom"))
        record = AsyncMock()
        with (
            patch(
                "backend.copilot.service._generate_session_title",
                new=gen,
            ),
            patch(
                "backend.copilot.service.update_session_title",
                new=update,
            ),
            patch(
                "backend.copilot.service._record_title_generation_cost",
                new=record,
            ),
        ):
            # Must NOT raise — persist failure is swallowed.
            await _update_title_async("sess-2", "msg", user_id="u")

        update.assert_awaited_once()
        record.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_cost_record_fails_but_title_was_persisted(self):
        resp = _build_completion(usage=_usage_with_cost(0.0001))
        gen = AsyncMock(return_value=("Title", resp))
        update = AsyncMock(return_value=True)
        record = AsyncMock(side_effect=RuntimeError("cost record boom"))
        with (
            patch(
                "backend.copilot.service._generate_session_title",
                new=gen,
            ),
            patch(
                "backend.copilot.service.update_session_title",
                new=update,
            ),
            patch(
                "backend.copilot.service._record_title_generation_cost",
                new=record,
            ),
        ):
            # Must NOT raise — cost-recording failure is swallowed.
            await _update_title_async("sess-3", "msg", user_id="u")

        update.assert_awaited_once()
        record.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_no_user_id_skips_title_persist_but_records_cost(self):
        """Anonymous sessions skip the user-scoped title write, but we
        still paid for the LLM call — cost recording runs regardless."""
        resp = _build_completion(usage=_usage_with_cost(0.0001))
        gen = AsyncMock(return_value=("Title", resp))
        update = AsyncMock()
        record = AsyncMock()
        with (
            patch(
                "backend.copilot.service._generate_session_title",
                new=gen,
            ),
            patch(
                "backend.copilot.service.update_session_title",
                new=update,
            ),
            patch(
                "backend.copilot.service._record_title_generation_cost",
                new=record,
            ),
        ):
            await _update_title_async("sess-4", "msg", user_id=None)

        update.assert_not_awaited()
        record.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_generation_returns_none_response_skips_cost(self):
        """``_generate_session_title`` swallows exceptions and returns
        ``(None, None)`` — no response means no cost to record."""
        gen = AsyncMock(return_value=(None, None))
        update = AsyncMock()
        record = AsyncMock()
        with (
            patch(
                "backend.copilot.service._generate_session_title",
                new=gen,
            ),
            patch(
                "backend.copilot.service.update_session_title",
                new=update,
            ),
            patch(
                "backend.copilot.service._record_title_generation_cost",
                new=record,
            ),
        ):
            await _update_title_async("sess-5", "msg", user_id="u")

        update.assert_not_awaited()
        record.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_empty_title_with_response_still_records_cost(self):
        """Title came back empty but we still paid for the LLM call —
        cost recording runs even though the title write is skipped."""
        resp = _build_completion(usage=_usage_with_cost(0.0001))
        gen = AsyncMock(return_value=(None, resp))
        update = AsyncMock()
        record = AsyncMock()
        with (
            patch(
                "backend.copilot.service._generate_session_title",
                new=gen,
            ),
            patch(
                "backend.copilot.service.update_session_title",
                new=update,
            ),
            patch(
                "backend.copilot.service._record_title_generation_cost",
                new=record,
            ),
        ):
            await _update_title_async("sess-6", "msg", user_id="u")

        update.assert_not_awaited()
        record.assert_awaited_once()


class TestGenerateSessionTitle:
    """``_generate_session_title`` returns ``(title, response)`` — the
    caller owns both the persist and the cost-record decisions."""

    @pytest.mark.asyncio
    async def test_valid_response_returns_cleaned_title_and_response(self):
        # Code strips whitespace, then strips ``"'`` — whitespace inside
        # the quotes survives on purpose (titles like ``My Agent`` read
        # better than ``MyAgent``). Test keeps the outer quotes + inner
        # whitespace distinct so the ordering is pinned.
        resp = _build_completion(content='"Clean Me" ')
        client = MagicMock()
        client.chat.completions.create = AsyncMock(return_value=resp)
        with patch(
            "backend.copilot.service._get_openai_client",
            return_value=client,
        ):
            title, response = await _generate_session_title(
                "first message", user_id="u", session_id="s"
            )
        assert title == "Clean Me"
        assert response is resp

    @pytest.mark.asyncio
    async def test_long_title_truncated_with_ellipsis(self):
        """Titles >50 chars get truncated to 47 + '...'."""
        long_title = "A" * 80
        resp = _build_completion(content=long_title)
        client = MagicMock()
        client.chat.completions.create = AsyncMock(return_value=resp)
        with patch(
            "backend.copilot.service._get_openai_client",
            return_value=client,
        ):
            title, _ = await _generate_session_title("x", user_id=None)
        assert title is not None
        assert len(title) == 50
        assert title.endswith("...")

    @pytest.mark.asyncio
    async def test_empty_choices_returns_none_title_with_response(self):
        """No ``choices`` on the response (shouldn't happen per SDK
        typing) must not raise IndexError — response is preserved so the
        caller can still record the paid-for cost."""
        resp = _build_completion(choices=[])
        client = MagicMock()
        client.chat.completions.create = AsyncMock(return_value=resp)
        with patch(
            "backend.copilot.service._get_openai_client",
            return_value=client,
        ):
            title, response = await _generate_session_title("x")
        assert title is None
        assert response is resp

    @pytest.mark.asyncio
    async def test_missing_message_returns_none_title(self):
        """A choice whose ``.message`` is absent produces a None title
        but the response still lands on the caller."""
        fake_choice = SimpleNamespace(message=None)
        fake_response = SimpleNamespace(choices=[fake_choice])
        client = MagicMock()
        client.chat.completions.create = AsyncMock(return_value=fake_response)
        with patch(
            "backend.copilot.service._get_openai_client",
            return_value=client,
        ):
            title, response = await _generate_session_title("x")
        assert title is None
        assert response is fake_response

    @pytest.mark.asyncio
    async def test_llm_call_raises_returns_none_none(self):
        """Network / API errors on the create call are swallowed;
        ``(None, None)`` ensures the caller skips both title and cost
        without crashing the background task."""
        client = MagicMock()
        client.chat.completions.create = AsyncMock(
            side_effect=RuntimeError("connection reset")
        )
        with patch(
            "backend.copilot.service._get_openai_client",
            return_value=client,
        ):
            title, response = await _generate_session_title("x")
        assert title is None
        assert response is None

    @pytest.mark.asyncio
    async def test_create_receives_usage_include_extra_body(self):
        """PR adds ``usage: {'include': True}`` so OpenRouter embeds the
        real billed cost into the final usage chunk."""
        resp = _build_completion(content="Title")
        client = MagicMock()
        client.chat.completions.create = AsyncMock(return_value=resp)
        with patch(
            "backend.copilot.service._get_openai_client",
            return_value=client,
        ):
            await _generate_session_title(
                "hello world", user_id="user-abc", session_id="sess-abc"
            )
        client.chat.completions.create.assert_awaited_once()
        extra_body = client.chat.completions.create.await_args.kwargs["extra_body"]
        assert extra_body["usage"] == {"include": True}
        assert extra_body["user"] == "user-abc"
        assert extra_body["session_id"] == "sess-abc"

@@ -485,9 +485,11 @@ async def subscribe_to_session(
    subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue()
    stream_key = _get_turn_stream_key(session.turn_id)

-    # Step 1: Replay messages from Redis Stream
+    # Replay batch capped by ``stream_replay_count``.
    xread_start = time.perf_counter()
-    messages = await redis.xread({stream_key: last_message_id}, block=None, count=1000)
+    messages = await redis.xread(
+        {stream_key: last_message_id}, block=None, count=config.stream_replay_count
+    )
    xread_time = (time.perf_counter() - xread_start) * 1000
    logger.info(
        f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={session_status}",
@@ -1024,8 +1026,8 @@ async def get_active_session(

    # Check if session is stale (running beyond tool timeout + buffer).
    # Auto-complete it to prevent infinite polling loops.
-    # Synchronous tools can run up to COPILOT_CONSUMER_TIMEOUT_SECONDS (1 hour),
-    # so we add a 5-minute buffer to avoid false positives during legitimate operations.
+    # A turn can legitimately run up to COPILOT_CONSUMER_TIMEOUT_SECONDS, so we
+    # add a 5-minute buffer to avoid false positives during legitimate operations.
    created_at_str = meta.get("created_at")
    if created_at_str:
        try:

@@ -1,7 +1,8 @@
from __future__ import annotations

import logging
-from typing import TYPE_CHECKING, Any
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, Literal

from openai.types.chat import ChatCompletionToolParam

@@ -10,7 +11,6 @@ from backend.copilot.tracking import track_tool_called
from .add_understanding import AddUnderstandingTool
from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool
from .agent_output import AgentOutputTool
from .ask_question import AskQuestionTool
from .base import BaseTool
from .bash_exec import BashExecTool
from .connect_integration import ConnectIntegrationTool
@@ -44,6 +44,7 @@ from .run_block import RunBlockTool
from .run_mcp_tool import RunMCPToolTool
from .run_sub_session import RunSubSessionTool
from .search_docs import SearchDocsTool
+from .todo_write import TodoWriteTool
from .validate_agent import ValidateAgentGraphTool
from .web_fetch import WebFetchTool
from .web_search import WebSearchTool
@@ -63,7 +64,6 @@ logger = logging.getLogger(__name__)
# Single source of truth for all tools
TOOL_REGISTRY: dict[str, BaseTool] = {
    "add_understanding": AddUnderstandingTool(),
    "ask_question": AskQuestionTool(),
    "create_agent": CreateAgentTool(),
    "customize_agent": CustomizeAgentTool(),
    "decompose_goal": DecomposeGoalTool(),
@@ -88,6 +88,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = {
    "continue_run_block": ContinueRunBlockTool(),
    "run_sub_session": RunSubSessionTool(),
    "get_sub_session_result": GetSubSessionResultTool(),
+    "TodoWrite": TodoWriteTool(),
    "run_mcp_tool": RunMCPToolTool(),
    "get_mcp_guide": GetMCPGuideTool(),
    "view_agent_output": AgentOutputTool(),
@@ -123,15 +124,45 @@ find_agent_tool = TOOL_REGISTRY["find_agent"]
run_agent_tool = TOOL_REGISTRY["run_agent"]


-def get_available_tools() -> list[ChatCompletionToolParam]:
+# Capability groups a tool may belong to. The service layer can hide all
+# tools in a group when the backing capability isn't available to this user
+# (e.g. Graphiti memory behind a feature flag), so the model doesn't reach
+# for tools whose backend is off and then hit opaque runtime errors. Add
+# a new group by extending ``ToolGroup`` and registering its members in
+# ``TOOL_GROUPS`` below.
+ToolGroup = Literal["graphiti"]
+
+TOOL_GROUPS: dict[str, ToolGroup] = {
+    "memory_store": "graphiti",
+    "memory_search": "graphiti",
+    "memory_forget_search": "graphiti",
+    "memory_forget_confirm": "graphiti",
+}
+
+
+def tool_names_in_groups(groups: Iterable[ToolGroup]) -> frozenset[str]:
+    """Return the set of tool short-names belonging to any of *groups*."""
+    group_set = frozenset(groups)
+    return frozenset(name for name, g in TOOL_GROUPS.items() if g in group_set)
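
# Illustrative: tool_names_in_groups(("graphiti",)) yields the four
# memory_* names registered above; an empty ``groups`` yields frozenset().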


def get_available_tools(
    *,
    disabled_groups: Iterable[ToolGroup] = (),
) -> list[ChatCompletionToolParam]:
    """Return OpenAI tool schemas for tools available in the current environment.

    Called per-request so that env-var or binary availability is evaluated
    fresh each time (e.g. browser_* tools are excluded when agent-browser
-    CLI is not installed).
+    CLI is not installed). Tools belonging to any *disabled_groups* are
+    also filtered out — use this to hide capability-gated tools (e.g.
+    ``graphiti`` when the memory backend is off for the current user).
    """
    hidden = tool_names_in_groups(disabled_groups)
    return [
-        tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available
+        tool.as_openai_tool()
+        for name, tool in TOOL_REGISTRY.items()
+        if tool.is_available and name not in hidden
    ]
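
# A minimal usage sketch (argument value assumed for illustration):
#     schemas = get_available_tools(disabled_groups=("graphiti",))
# omits the memory_* schemas while leaving availability checks untouched.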


@@ -181,7 +181,9 @@ async def execute_block(
    # (e.g., "42" → 42, string booleans → bool, enum defaults applied).
    coerce_inputs_to_schema(input_data, block.input_schema)
    outputs: dict[str, list[Any]] = defaultdict(list)
-    async for output_name, output_data in simulate_block(block, input_data):
+    async for output_name, output_data in simulate_block(
+        block, input_data, user_id=user_id
+    ):
        outputs[output_name].append(output_data)
    # simulator signals internal failure via ("error", "[SIMULATOR ERROR …]")
    sim_error = outputs.get("error", [])

@@ -91,6 +91,9 @@ class ResponseType(str, Enum):
    MEMORY_FORGET_CANDIDATES = "memory_forget_candidates"
    MEMORY_FORGET_CONFIRM = "memory_forget_confirm"

+    # Planning
+    TODO_WRITE = "todo_write"
+

# Base response model
class ToolResponseBase(BaseModel):
@@ -605,11 +608,13 @@ class WebSearchResponse(ToolResponseBase):

    type: ResponseType = ResponseType.WEB_SEARCH
    query: str
    # Web-grounded synthesised answer the search provider wrote from
    # fresh page content. The LLM caller should read this directly
    # instead of re-fetching each citation URL — many sites are
    # bot-protected and ``web_fetch`` won't get through. Empty string
    # when the provider returned only citations.
    answer: str = ""
    results: list[WebSearchResult] = Field(default_factory=list)
    # Backend-reported usage for this call (copied from Anthropic's
    # ``usage.server_tool_use``). Surfaces as metadata for frontend
    # debug panels but is also what drives rate-limit / cost tracking
    # via ``persist_and_record_usage(provider="anthropic")``.
    search_requests: int = 0

@@ -896,3 +901,36 @@ class MemoryForgetConfirmResponse(ToolResponseBase):
    type: ResponseType = ResponseType.MEMORY_FORGET_CONFIRM
    deleted_uuids: list[str] = Field(default_factory=list)
    failed_uuids: list[str] = Field(default_factory=list)


# --- Planning ---


class TodoItem(BaseModel):
    """One entry in a ``TodoWrite`` checklist.

    Mirrors the schema used by Claude Code's built-in ``TodoWrite`` tool so
    the frontend's ``GenericTool`` accordion renders baseline-emitted todos
    identically to SDK-emitted ones.
    """

    content: str = Field(description="Imperative description of the task.")
    activeForm: str = Field(
        description="Present-continuous form shown while the task is running.",
    )
    status: Literal["pending", "in_progress", "completed"] = Field(
        default="pending",
    )
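
    # Illustrative payload (values are made up; field names from above):
    #     {"content": "Wire up the webhook",
    #      "activeForm": "Wiring up the webhook",
    #      "status": "in_progress"}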


class TodoWriteResponse(ToolResponseBase):
    """Ack returned by ``TodoWrite``.

    The tool is effectively stateless — the authoritative task list lives in
    the assistant's latest tool-call arguments, which are replayed from the
    transcript on each turn. The tool output only needs to confirm that the
    update was accepted so the model can proceed.
    """

    type: ResponseType = ResponseType.TODO_WRITE
    todos: list[TodoItem] = Field(default_factory=list)

@@ -237,7 +237,7 @@ async def test_execute_block_dry_run_skips_real_execution():
    mock_block = make_mock_block()
    mock_block.execute = AsyncMock()  # should NOT be called

-    async def fake_simulate(block, input_data):
+    async def fake_simulate(block, input_data, **_kwargs):
        yield "result", "simulated"

    # Patching at helpers.simulate_block works because helpers.py imports
@@ -267,7 +267,7 @@ async def test_execute_block_dry_run_response_format():
    """Dry-run response should look like a normal success (no dry-run signal to LLM)."""
    mock_block = make_mock_block()

-    async def fake_simulate(block, input_data):
+    async def fake_simulate(block, input_data, **_kwargs):
        yield "result", "simulated"

    with patch(
@@ -331,7 +331,7 @@ async def test_execute_block_real_execution_unchanged():
    # Just verify simulate_block is NOT called.
    simulate_called = False

-    async def fake_simulate(block, input_data):
+    async def fake_simulate(block, input_data, **_kwargs):
        nonlocal simulate_called
        simulate_called = True
        yield "result", "should not happen"
@@ -455,7 +455,7 @@ async def test_execute_block_dry_run_no_empty_error_from_simulator():
    """
    mock_block = make_mock_block()

-    async def fake_simulate(block, input_data):
+    async def fake_simulate(block, input_data, **_kwargs):
        # Simulator now omits empty error pins at source
        yield "result", "simulated output"

@@ -485,7 +485,7 @@ async def test_execute_block_dry_run_keeps_nonempty_error_pin():
    """Dry-run should keep the 'error' pin when it contains a real error message."""
    mock_block = make_mock_block()

-    async def fake_simulate(block, input_data):
+    async def fake_simulate(block, input_data, **_kwargs):
        yield "result", ""
        yield "error", "API rate limit exceeded"

@@ -515,7 +515,7 @@ async def test_execute_block_dry_run_message_includes_completed_status():
    """Dry-run message should clearly indicate COMPLETED status."""
    mock_block = make_mock_block()

-    async def fake_simulate(block, input_data):
+    async def fake_simulate(block, input_data, **_kwargs):
        yield "result", "simulated"

    with patch(
@@ -541,7 +541,7 @@ async def test_execute_block_dry_run_simulator_error_returns_error_response():
    """When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse."""
    mock_block = make_mock_block()

-    async def fake_simulate_error(block, input_data):
+    async def fake_simulate_error(block, input_data, **_kwargs):
        yield (
            "error",
            "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key).",

120
autogpt_platform/backend/backend/copilot/tools/todo_write.py
Normal file
@@ -0,0 +1,120 @@
"""Task-list tool for baseline copilot mode.

Mirrors the schema and UX of Claude Code's built-in ``TodoWrite`` tool so
the frontend's generic tool renderer draws baseline-emitted checklists the
same way it draws SDK-emitted ones. The tool is stateless: the model's
latest ``todos`` argument IS the canonical list, replayed from transcript
on subsequent turns.

Baseline needs this as a platform tool because OpenAI-compatible providers
(Kimi, GPT, Grok, Gemini) do not ship a built-in equivalent. The SDK path
continues to use the CLI's native ``TodoWrite`` — the MCP-wrapped version
of this tool is filtered out of SDK's allowed_tools list (see
``sdk/tool_adapter.py``) to avoid name shadowing.
"""

from __future__ import annotations

import logging
from typing import Any

from backend.copilot.model import ChatSession

from .base import BaseTool
from .models import ErrorResponse, TodoItem, TodoWriteResponse, ToolResponseBase

logger = logging.getLogger(__name__)


class TodoWriteTool(BaseTool):
    """Maintain a step-by-step task checklist visible to the user."""

    @property
    def name(self) -> str:
        # Capitalised to match the frontend's switch on ``"TodoWrite"``
        # (see ``copilot/tools/GenericTool/helpers.ts``).
        return "TodoWrite"

    @property
    def description(self) -> str:
        return (
            "Plan and track multi-step work as a visible checklist. Send "
            "the full list every call; exactly one item in_progress at a time."
        )

    @property
    def parameters(self) -> dict[str, Any]:
        return {
            "type": "object",
            "properties": {
                "todos": {
                    "type": "array",
                    "description": "Full updated task list (not a delta).",
                    "items": {
                        "type": "object",
                        "properties": {
                            "content": {
                                "type": "string",
                                "description": "Imperative (e.g. 'Run tests').",
                            },
                            "activeForm": {
                                "type": "string",
                                "description": (
                                    "Present-continuous (e.g. 'Running tests')."
                                ),
                            },
                            "status": {
                                "type": "string",
                                "enum": ["pending", "in_progress", "completed"],
                                "default": "pending",
                            },
                        },
                        "required": ["content", "activeForm"],
                    },
                },
            },
            "required": ["todos"],
        }

    async def _execute(
        self,
        user_id: str | None,
        session: ChatSession,
        **kwargs: Any,
    ) -> ToolResponseBase:
        del user_id
        raw_todos = kwargs.get("todos")
        if raw_todos is None:
            return ErrorResponse(
                message="`todos` is required.",
                session_id=session.session_id,
            )
        if not isinstance(raw_todos, list):
            return ErrorResponse(
                message="`todos` must be an array.",
                session_id=session.session_id,
            )

        try:
            parsed = [TodoItem.model_validate(item) for item in raw_todos]
        except Exception as exc:
            return ErrorResponse(
                message=f"Invalid todo entry: {exc}",
                session_id=session.session_id,
            )

        in_progress = sum(1 for t in parsed if t.status == "in_progress")
        if in_progress > 1:
            return ErrorResponse(
                message=(
                    "Only one todo may be 'in_progress' at a time "
                    f"(found {in_progress})."
                ),
                session_id=session.session_id,
            )

        return TodoWriteResponse(
            message="Task list updated.",
            session_id=session.session_id,
            todos=parsed,
        )
@@ -0,0 +1,125 @@
"""Tests for TodoWriteTool."""

import pytest

from backend.copilot.model import ChatSession
from backend.copilot.tools.models import ErrorResponse, TodoItem, TodoWriteResponse
from backend.copilot.tools.todo_write import TodoWriteTool


@pytest.fixture()
def tool() -> TodoWriteTool:
    return TodoWriteTool()


@pytest.fixture()
def session() -> ChatSession:
    return ChatSession.new(user_id="test-user", dry_run=False)


@pytest.mark.asyncio
async def test_valid_todo_list(tool: TodoWriteTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        todos=[
            {
                "content": "Write tests",
                "activeForm": "Writing tests",
                "status": "pending",
            },
            {
                "content": "Ship PR",
                "activeForm": "Shipping PR",
                "status": "in_progress",
            },
        ],
    )

    assert isinstance(result, TodoWriteResponse)
    assert result.session_id == session.session_id
    assert len(result.todos) == 2
    assert result.todos[0] == TodoItem(
        content="Write tests",
        activeForm="Writing tests",
        status="pending",
    )
    assert result.todos[1].status == "in_progress"


@pytest.mark.asyncio
async def test_default_status_is_pending(tool: TodoWriteTool, session: ChatSession):
    result = await tool._execute(
        user_id=None,
        session=session,
        todos=[{"content": "Write tests", "activeForm": "Writing tests"}],
    )

    assert isinstance(result, TodoWriteResponse)
    assert result.todos[0].status == "pending"


@pytest.mark.asyncio
async def test_missing_todos_returns_error(tool: TodoWriteTool, session: ChatSession):
    result = await tool._execute(user_id=None, session=session)

    assert isinstance(result, ErrorResponse)
    assert "todos" in result.message.lower()


@pytest.mark.asyncio
async def test_non_list_todos_returns_error(tool: TodoWriteTool, session: ChatSession):
    result = await tool._execute(user_id=None, session=session, todos="not a list")

    assert isinstance(result, ErrorResponse)


@pytest.mark.asyncio
async def test_invalid_item_returns_error(tool: TodoWriteTool, session: ChatSession):
    # Missing required `activeForm` field.
    result = await tool._execute(
        user_id=None,
        session=session,
        todos=[{"content": "Missing active form"}],
    )

    assert isinstance(result, ErrorResponse)


@pytest.mark.asyncio
async def test_multiple_in_progress_rejected(tool: TodoWriteTool, session: ChatSession):
    """Exactly one item should be in_progress at a time — SDK parity rule."""
    result = await tool._execute(
        user_id=None,
        session=session,
        todos=[
            {
                "content": "A",
                "activeForm": "Doing A",
                "status": "in_progress",
            },
            {
                "content": "B",
                "activeForm": "Doing B",
                "status": "in_progress",
            },
        ],
    )

    assert isinstance(result, ErrorResponse)
    assert "in_progress" in result.message


def test_openai_schema_shape(tool: TodoWriteTool):
    schema = tool.as_openai_tool()
    assert schema["type"] == "function"
    assert schema["function"]["name"] == "TodoWrite"
    params = schema["function"]["parameters"]
    assert params["required"] == ["todos"]
    items = params["properties"]["todos"]["items"]
    assert items["required"] == ["content", "activeForm"]
    assert items["properties"]["status"]["enum"] == [
        "pending",
        "in_progress",
        "completed",
    ]
@@ -21,7 +21,9 @@ from backend.copilot.tools import TOOL_REGISTRY
# response shape carries) and the dry_run description. Keeps the
# regression gate effective while accepting a deliberate ~120-token
# spend on LLM-decision-critical copy.
_CHAR_BUDGET = 34_000
# Bumped to 35_000 to accommodate the decompose_goal tool + web_search +
# TodoWrite tool descriptions.
_CHAR_BUDGET = 35_000


@pytest.fixture(scope="module")
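For orientation, the arithmetic the budget encodes: at the ~4 chars/token heuristic, 35_000 serialized characters is roughly 8_750 tokens of tool-schema overhead shipped with every tool-enabled request. A minimal sketch of the gate (a simplified restatement of the assertion later in this file, not new production code):

```python
import json

_CHAR_BUDGET = 35_000  # ≈ 35_000 / 4 = 8_750 tokens under the heuristic

def assert_schema_budget(schemas: list[dict]) -> None:
    # Serialized length is the tokenizer-agnostic proxy for prompt cost.
    total = len(json.dumps(schemas))
    assert total <= _CHAR_BUDGET, f"{total} chars > {_CHAR_BUDGET} budget"
```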
@@ -111,9 +113,10 @@ def test_total_schema_char_budget() -> None:

    This locks in the 34% token reduction from #12398 and prevents future
    description bloat from eroding the gains. Uses character count with a
    ~4 chars/token heuristic (budget of 32000 chars ≈ 8000 tokens).
    Character count is tokenizer-agnostic — no dependency on GPT or Claude
    tokenizers — while still providing a stable regression gate.
    ~4 chars/token heuristic; see ``_CHAR_BUDGET`` above for the current
    value and its change history. Character count is tokenizer-agnostic
    — no dependency on GPT or Claude tokenizers — while still providing a
    stable regression gate.
    """
    schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()]
    serialized = json.dumps(schemas)
@@ -123,3 +126,60 @@ def test_total_schema_char_budget() -> None:
        f"exceeding budget of {_CHAR_BUDGET} chars (~{_CHAR_BUDGET // 4} tokens). "
        f"Description bloat detected — trim descriptions or raise the budget intentionally."
    )


# ── Capability-group filtering (ToolGroup / disabled_groups) ───────────


def test_get_available_tools_hides_graphiti_when_disabled() -> None:
    """When the ``graphiti`` group is disabled, the memory_* tools must
    not appear in the OpenAI schema list — they'd just confuse the model
    and produce opaque runtime errors."""
    from backend.copilot.tools import get_available_tools

    memory_tool_names = {
        "memory_store",
        "memory_search",
        "memory_forget_search",
        "memory_forget_confirm",
    }

    default = {t["function"]["name"] for t in get_available_tools()}
    assert memory_tool_names.issubset(
        default
    ), "sanity: memory_* tools should be present when no groups disabled"

    filtered = {
        t["function"]["name"] for t in get_available_tools(disabled_groups=["graphiti"])
    }
    assert not (
        memory_tool_names & filtered
    ), f"graphiti disabled but memory_* still present: {memory_tool_names & filtered}"
    # Non-graphiti tools stay visible.
    assert "find_block" in filtered
    assert "TodoWrite" in filtered


def test_get_copilot_tool_names_hides_graphiti_when_disabled() -> None:
    """Same invariant for the SDK tool-name list."""
    from backend.copilot.sdk.tool_adapter import MCP_TOOL_PREFIX, get_copilot_tool_names

    memory_mcp_names = {
        f"{MCP_TOOL_PREFIX}memory_store",
        f"{MCP_TOOL_PREFIX}memory_search",
        f"{MCP_TOOL_PREFIX}memory_forget_search",
        f"{MCP_TOOL_PREFIX}memory_forget_confirm",
    }

    default = set(get_copilot_tool_names())
    assert memory_mcp_names.issubset(default)

    filtered = set(get_copilot_tool_names(disabled_groups=["graphiti"]))
    assert not (
        memory_mcp_names & filtered
    ), f"graphiti disabled but memory MCP names still present: {memory_mcp_names & filtered}"
    # E2B path stays consistent.
    filtered_e2b = set(
        get_copilot_tool_names(use_e2b=True, disabled_groups=["graphiti"])
    )
    assert not (memory_mcp_names & filtered_e2b)
@@ -1,29 +1,66 @@
"""Web search tool — wraps Anthropic's server-side ``web_search`` beta.
"""Web search tool — Perplexity Sonar via OpenRouter.

Single entry point for web search on both SDK and baseline paths. The
``web_search_20250305`` tool is server-side on Anthropic, so we call
the Messages API directly regardless of which LLM invoked the copilot
tool — OpenRouter can't proxy server-side tool execution.
One provider, two tiers, one billing path:

* ``deep=False`` (default) — ``perplexity/sonar``. Searches the web
  natively and returns citation annotations in a single inference pass.
* ``deep=True`` — ``perplexity/sonar-deep-research``. Multi-step
  agentic research; slower and costlier.

Why Sonar and not the ``openrouter:web_search`` server tool + dispatch
model? The server tool feeds all search-result page content back into
the dispatch model for a second inference pass — one observed call was
74K input tokens at Gemini Flash rates, billing $0.072. Sonar
searches natively in one pass, returns annotations typed as
``ChatCompletionMessage.annotations`` in ``openai.types``, and at
$1 / MTok base pricing lands ~$0.01 / call at our default shape.

``resp.usage.cost`` carries the real billed value via OpenRouter's
``include: true`` extension; the value flows through
``persist_and_record_usage(provider='open_router')`` into the daily /
weekly microdollar rate-limit counter on the same rails as every other
OpenRouter turn — no separate provider ledger line, no estimation
drift. ``_extract_cost_usd`` mirrors the baseline service's
``_extract_usage_cost`` logic; keep the two in sync if one changes.
"""

import logging
import math
from typing import Any

from anthropic import AsyncAnthropic
from openai import AsyncOpenAI
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion

from backend.copilot.config import ChatConfig
from backend.copilot.model import ChatSession
from backend.copilot.token_tracking import persist_and_record_usage
from backend.util.settings import Settings

from .base import BaseTool
from .models import ErrorResponse, ToolResponseBase, WebSearchResponse, WebSearchResult

logger = logging.getLogger(__name__)

_WEB_SEARCH_DISPATCH_MODEL = "claude-haiku-4-5"
_MAX_DISPATCH_TOKENS = 512
_chat_config = ChatConfig()

_QUICK_MODEL = "perplexity/sonar"
# Sonar base can emit up to ~4K output; cap at the provider ceiling so the
# model stops when the answer is complete rather than when our budget trips.
_QUICK_MAX_TOKENS = 4096

_DEEP_MODEL = "perplexity/sonar-deep-research"
# Deep runs can produce long structured writeups — ~4x the quick ceiling
# is enough headroom for multi-source comparisons without uncapping.
_DEEP_MAX_TOKENS = _QUICK_MAX_TOKENS * 4

_DEFAULT_MAX_RESULTS = 5
_HARD_MAX_RESULTS = 20
_SNIPPET_MAX_CHARS = 500

# OpenRouter-specific extra_body flag that embeds the real generation
# cost into the response usage object. Same dict shape the baseline
# service uses — keep the two aligned.
_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}}


class WebSearchTool(BaseTool):
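A minimal sketch of the billing path the module docstring describes: OpenRouter's `usage: {"include": true}` extension returns the real billed USD cost on the usage object, and because the OpenAI SDK's typed `CompletionUsage` does not declare a `cost` field, pydantic v2 parks it in `model_extra`. The function name is illustrative; the call shape mirrors the diff below:

```python
from openai import AsyncOpenAI

async def sonar_cost_probe(api_key: str) -> float | None:
    """Hypothetical one-off call showing where OpenRouter's cost lands."""
    client = AsyncOpenAI(api_key=api_key, base_url="https://openrouter.ai/api/v1")
    resp = await client.chat.completions.create(
        model="perplexity/sonar",
        max_tokens=256,
        messages=[{"role": "user", "content": "latest openai sdk release?"}],
        # OpenRouter extension: embed the real billed cost in `usage`.
        extra_body={"usage": {"include": True}},
    )
    # `cost` is outside the typed CompletionUsage schema, so it lives in
    # model_extra rather than as a declared attribute.
    extras = (resp.usage.model_extra or {}) if resp.usage else {}
    return extras.get("cost")
```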
@@ -36,9 +73,13 @@ class WebSearchTool(BaseTool):
    @property
    def description(self) -> str:
        return (
            "Search the web for live info (news, recent docs). Returns "
            "{title, url, snippet}; use web_fetch to deep-dive a URL. "
            "Prefer one targeted query over many reformulations."
            "Search the web for live info (news, recent docs). Returns a "
            "synthesised answer grounded in fresh page content plus "
            "{title, url, snippet} citations — read the answer first "
            "before reaching for web_fetch. Set deep=true when the user "
            "asks for research / comparison / in-depth analysis; leave "
            "deep=false for quick fact lookups. Prefer one targeted "
            "query over many reformulations."
        )

    @property
@@ -58,6 +99,18 @@ class WebSearchTool(BaseTool):
                ),
                "default": _DEFAULT_MAX_RESULTS,
            },
            "deep": {
                "type": "boolean",
                "description": (
                    "Only set true when the user EXPLICITLY asks for "
                    "research, comparison, or in-depth investigation "
                    "across many sources — it is ~100x more expensive "
                    "and much slower than a normal search. Default "
                    "false; do not flip it for ordinary fact lookups "
                    "or fresh-news questions."
                ),
                "default": False,
            },
        },
        "required": ["query"],
    }
@@ -68,7 +121,7 @@ class WebSearchTool(BaseTool):

    @property
    def is_available(self) -> bool:
        return bool(Settings().secrets.anthropic_api_key)
        return bool(_chat_config.api_key and _chat_config.base_url)

    async def _execute(
        self,
@@ -76,6 +129,7 @@ class WebSearchTool(BaseTool):
        session: ChatSession,
        query: str = "",
        max_results: int = _DEFAULT_MAX_RESULTS,
        deep: bool = False,
        **kwargs: Any,
    ) -> ToolResponseBase:
        query = (query or "").strip()
@@ -93,44 +147,35 @@ class WebSearchTool(BaseTool):
            max_results = _DEFAULT_MAX_RESULTS
        max_results = max(1, min(max_results, _HARD_MAX_RESULTS))

        api_key = Settings().secrets.anthropic_api_key
        if not api_key:
        if not _chat_config.api_key or not _chat_config.base_url:
            return ErrorResponse(
                message=(
                    "Web search is unavailable — the deployment has no "
                    "Anthropic API key configured."
                    "OpenRouter credentials configured."
                ),
                error="web_search_not_configured",
                session_id=session_id,
            )

        client = AsyncAnthropic(api_key=api_key)
        client = AsyncOpenAI(
            api_key=_chat_config.api_key, base_url=_chat_config.base_url
        )
        model_used = _DEEP_MODEL if deep else _QUICK_MODEL
        max_tokens = _DEEP_MAX_TOKENS if deep else _QUICK_MAX_TOKENS

        try:
            resp = await client.messages.create(
                model=_WEB_SEARCH_DISPATCH_MODEL,
                max_tokens=_MAX_DISPATCH_TOKENS,
                tools=[
                    {
                        "type": "web_search_20250305",
                        "name": "web_search",
                        "max_uses": 1,
                    }
                ],
                messages=[
                    {
                        "role": "user",
                        "content": (
                            f"Use the web_search tool exactly once with the "
                            f"query {query!r} and then stop. Do not "
                            f"summarise — the caller parses the raw "
                            f"tool_result."
                        ),
                    }
                ],
            resp = await client.chat.completions.create(
                model=model_used,
                max_tokens=max_tokens,
                messages=[{"role": "user", "content": query}],
                extra_body=_OPENROUTER_INCLUDE_USAGE_COST,
            )
        except Exception as exc:
            logger.warning(
                "[web_search] Anthropic call failed for query=%r: %s", query, exc
                "[web_search] OpenRouter call failed (deep=%s) for query=%r: %s",
                deep,
                query,
                exc,
            )
            return ErrorResponse(
                message=f"Web search failed: {exc}",
@@ -138,20 +183,20 @@ class WebSearchTool(BaseTool):
                session_id=session_id,
            )

        results, search_requests = _extract_results(resp, limit=max_results)
        answer = _extract_answer(resp)
        results = _extract_results(resp, limit=max_results)
        cost_usd = _extract_cost_usd(resp.usage)

        cost_usd = _estimate_cost_usd(resp, search_requests=search_requests)
        try:
            usage = getattr(resp, "usage", None)
            await persist_and_record_usage(
                session=session,
                user_id=user_id,
                prompt_tokens=getattr(usage, "input_tokens", 0) or 0,
                completion_tokens=getattr(usage, "output_tokens", 0) or 0,
                prompt_tokens=resp.usage.prompt_tokens if resp.usage else 0,
                completion_tokens=resp.usage.completion_tokens if resp.usage else 0,
                log_prefix="[web_search]",
                cost_usd=cost_usd,
                model=_WEB_SEARCH_DISPATCH_MODEL,
                provider="anthropic",
                model=model_used,
                provider="open_router",
            )
        except Exception as exc:
            logger.warning("[web_search] usage tracking failed: %s", exc)
@@ -159,66 +204,92 @@ class WebSearchTool(BaseTool):
        return WebSearchResponse(
            message=f"Found {len(results)} result(s) for {query!r}.",
            query=query,
            answer=answer,
            results=results,
            search_requests=search_requests,
            search_requests=1 if results else 0,
            session_id=session_id,
        )


def _extract_results(resp: Any, *, limit: int) -> tuple[list[WebSearchResult], int]:
    """Pull results + server-side request count from an Anthropic response."""
    results: list[WebSearchResult] = []
    search_requests = 0
def _extract_answer(resp: ChatCompletion) -> str:
    """Return the synthesised answer text from Sonar's response.

    for block in getattr(resp, "content", []) or []:
        btype = getattr(block, "type", None)
        if btype == "web_search_tool_result":
            content = getattr(block, "content", []) or []
            for item in content:
                if getattr(item, "type", None) != "web_search_result":
                    continue
                if len(results) >= limit:
                    break
                # Anthropic's ``web_search_result`` exposes only
                # ``title``/``url``/``page_age`` plus an opaque
                # ``encrypted_content`` blob that is meant for citation
                # round-tripping, not for display — it is base64-ish
                # binary and would show as gibberish if surfaced to the
                # model or the frontend. There is no plain-text snippet
                # field in the current beta; callers get the readable
                # text via the model's ``text`` blocks with citations,
                # not via this list. Leave ``snippet`` empty.
                results.append(
                    WebSearchResult(
                        title=getattr(item, "title", "") or "",
                        url=getattr(item, "url", "") or "",
                        snippet="",
                        page_age=getattr(item, "page_age", None),
                    )
                )

    usage = getattr(resp, "usage", None)
    server_tool_use = getattr(usage, "server_tool_use", None) if usage else None
    if server_tool_use is not None:
        search_requests = getattr(server_tool_use, "web_search_requests", 0) or 0

    return results, search_requests
    Sonar reads every page it cites and writes a web-grounded synthesis
    into ``choices[0].message.content`` on the same call we pay for.
    Surfacing it saves the agent from re-fetching citation URLs — many
    are bot-protected and ``web_fetch`` can't reach them.
    """
    if not resp.choices:
        return ""
    content = resp.choices[0].message.content
    return content or ""
# Update when Anthropic revises pricing.
_COST_PER_SEARCH_USD = 0.010  # $10 per 1,000 web_search requests
_HAIKU_INPUT_USD_PER_MTOK = 1.0
_HAIKU_OUTPUT_USD_PER_MTOK = 5.0
def _extract_results(resp: ChatCompletion, *, limit: int) -> list[WebSearchResult]:
    """Pull ``url_citation`` annotations from the response.

    Shared across both tiers — OpenRouter normalises the annotation
    schema across Perplexity's sonar models into
    ``Annotation.url_citation`` (typed in ``openai.types.chat``). The
    ``content`` snippet is an OpenRouter extension on the otherwise-
    typed ``AnnotationURLCitation``; pydantic stashes unknown fields in
    ``model_extra``, which is where we read it rather than via ``getattr``.
    """
    if not resp.choices:
        return []
    annotations = resp.choices[0].message.annotations or []
    out: list[WebSearchResult] = []
    for ann in annotations:
        if len(out) >= limit:
            break
        if ann.type != "url_citation":
            continue
        citation = ann.url_citation
        extras = citation.model_extra or {}
        snippet_raw = extras.get("content")
        snippet = (snippet_raw or "")[:_SNIPPET_MAX_CHARS] if snippet_raw else ""
        out.append(
            WebSearchResult(
                title=citation.title,
                url=citation.url,
                snippet=snippet,
                page_age=None,
            )
        )
    return out
def _estimate_cost_usd(resp: Any, *, search_requests: int) -> float:
    """Per-search fee × count + Haiku dispatch tokens."""
    usage = getattr(resp, "usage", None)
    input_tokens = getattr(usage, "input_tokens", 0) if usage else 0
    output_tokens = getattr(usage, "output_tokens", 0) if usage else 0
def _extract_cost_usd(usage: CompletionUsage | None) -> float | None:
    """Return the provider-reported USD cost off the response usage.

    search_cost = search_requests * _COST_PER_SEARCH_USD
    inference_cost = (input_tokens / 1_000_000) * _HAIKU_INPUT_USD_PER_MTOK + (
        output_tokens / 1_000_000
    ) * _HAIKU_OUTPUT_USD_PER_MTOK
    return round(search_cost + inference_cost, 6)
    OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible
    usage object when the request body includes
    ``usage: {"include": True}``. The OpenAI SDK's typed
    ``CompletionUsage`` does not declare it, so we read it off
    ``model_extra`` (the pydantic v2 container for extras) to keep
    access fully typed — no ``getattr``. Mirrors the baseline service
    ``_extract_usage_cost``; keep the two in sync.

    Returns ``None`` when the field is absent, null, non-numeric,
    non-finite, or negative. Invalid values log at error level because
    they indicate a provider bug worth chasing; plain absences are
    silent so the caller can dedupe the "missing cost" warning.
    """
    if usage is None:
        return None
    extras = usage.model_extra or {}
    if "cost" not in extras:
        return None
    raw = extras["cost"]
    if raw is None:
        logger.error("[web_search] usage.cost is present but null")
        return None
    try:
        val = float(raw)
    except (TypeError, ValueError):
        logger.error("[web_search] usage.cost is not numeric: %r", raw)
        return None
    if not math.isfinite(val) or val < 0:
        logger.error("[web_search] usage.cost is non-finite or negative: %r", val)
        return None
    return val
@@ -1,212 +1,289 @@
"""Tests for the ``web_search`` copilot tool.

Covers the result extractor + cost estimator as pure units (fed with
synthetic Anthropic response objects), plus light integration tests that
mock ``AsyncAnthropic.messages.create`` and confirm the handler plumbs
through to ``persist_and_record_usage`` with the right provider tag.
Covers the annotation extractor + cost extractor as pure units (fed
with real ``openai`` SDK types — no duck-typed ``SimpleNamespace``
stand-ins), plus integration tests exercising both the quick
(``perplexity/sonar``) and deep (``perplexity/sonar-deep-research``)
paths — mocking ``AsyncOpenAI.chat.completions.create`` and confirming
the handler plumbs through to ``persist_and_record_usage`` with
``provider='open_router'`` and the real ``usage.cost`` value.
"""

from types import SimpleNamespace
from typing import Any
from unittest.mock import AsyncMock, patch

import pytest
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import (
    Annotation,
    AnnotationURLCitation,
    ChatCompletionMessage,
)

from backend.copilot.model import ChatSession

from .models import ErrorResponse, WebSearchResponse, WebSearchResult
from .models import ErrorResponse, WebSearchResponse
from .web_search import (
    _COST_PER_SEARCH_USD,
    WebSearchTool,
    _estimate_cost_usd,
    _extract_answer,
    _extract_cost_usd,
    _extract_results,
)
def _fake_anthropic_response(
def _usage(
    *,
    results: list[dict] | None = None,
    search_requests: int = 1,
    input_tokens: int = 120,
    output_tokens: int = 40,
) -> SimpleNamespace:
    """Build a synthetic Anthropic Messages response.
    prompt_tokens: int = 120,
    completion_tokens: int = 40,
    cost: object = 0.01,
) -> CompletionUsage:
    """Typed ``CompletionUsage`` with OpenRouter's ``cost`` extension
    parked in ``model_extra`` — the same channel the production code
    reads it from. ``model_construct`` preserves unknown fields;
    ``model_validate`` would drop them because ``CompletionUsage``
    treats the schema as strict."""
    payload: dict[str, Any] = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    if cost is not None:
        payload["cost"] = cost
    return CompletionUsage.model_construct(None, **payload)

    Matches the shape produced by ``client.messages.create`` when the
    response includes a ``web_search_tool_result`` content block and
    ``usage.server_tool_use.web_search_requests`` on the turn meter.
    """
    content = []
    if results is not None:
        content.append(
            SimpleNamespace(
                type="web_search_tool_result",
                content=[
                    SimpleNamespace(
                        type="web_search_result",
                        title=r.get("title", "untitled"),
                        url=r.get("url", ""),
                        encrypted_content=r.get("snippet", ""),
                        page_age=r.get("page_age"),
                    )
                    for r in results
                ],
            )

def _citation(*, url: str, title: str, content: str | None = None) -> Annotation:
    """Typed ``Annotation`` for a URL citation. ``content`` is an
    OpenRouter extension on the otherwise-typed schema — goes into
    ``url_citation.model_extra`` when model_construct preserves it."""
    payload: dict[str, Any] = {
        "url": url,
        "title": title,
        "start_index": 0,
        "end_index": len(title),
    }
    if content is not None:
        payload["content"] = content
    url_citation = AnnotationURLCitation.model_construct(None, **payload)
    return Annotation(type="url_citation", url_citation=url_citation)


def _fake_response(
    *,
    citations: list[dict] | None = None,
    answer: str = "ok",
    prompt_tokens: int = 120,
    completion_tokens: int = 40,
    cost: object = 0.01,
) -> ChatCompletion:
    """Build a typed ``ChatCompletion`` shaped like an OpenRouter
    response — typed end-to-end so the production code's attribute
    access runs under the real SDK types in tests."""
    annotations = [
        _citation(
            url=c.get("url", ""),
            title=c.get("title", "untitled"),
            content=c.get("content"),
        )
    usage = SimpleNamespace(
        input_tokens=input_tokens,
        output_tokens=output_tokens,
        server_tool_use=SimpleNamespace(web_search_requests=search_requests),
        for c in citations or []
    ]
    message = ChatCompletionMessage.model_construct(
        None,
        role="assistant",
        content=answer,
        annotations=annotations,
    )
    choice = Choice.model_construct(
        None,
        index=0,
        finish_reason="stop",
        message=message,
    )
    return ChatCompletion.model_construct(
        None,
        id="cmpl-test",
        object="chat.completion",
        created=0,
        model="perplexity/sonar",
        choices=[choice],
        usage=_usage(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            cost=cost,
        ),
    )
    return SimpleNamespace(content=content, usage=usage)
class TestExtractResults:
    """The extractor is the only Anthropic-response-shape contact point;
    pin its behaviour so an API shape change surfaces here first."""
    """Pin the annotation shape — a schema bump in the OpenAI SDK or
    OpenRouter surfaces here first. Same extractor serves both tiers
    because OpenRouter normalises annotations across models."""

    def test_extracts_title_url_page_age_and_drops_encrypted_snippet(self):
        # Anthropic's ``web_search_result`` ships an opaque
        # ``encrypted_content`` blob that is not safe to surface —
        # the extractor must drop it (snippet=="") regardless of
        # whether the blob is non-empty.
        resp = _fake_anthropic_response(
            results=[
    def test_extracts_title_url_and_content_snippet(self):
        resp = _fake_response(
            citations=[
                {
                    "title": "Kimi K2.6 launch",
                    "url": "https://example.com/kimi",
                    "snippet": "EiJjbGF1ZGUtZW5jcnlwdGVkLWJsb2I=",
                    "page_age": "1 day",
                    "content": "Moonshot released K2.6 on 2026-04-20.",
                },
                {
                    "title": "OpenRouter pricing",
                    "url": "https://openrouter.ai/moonshotai/kimi-k2.6",
                    "snippet": "",
                },
            ]
        )
        out, requests = _extract_results(resp, limit=10)
        assert requests == 1
        out = _extract_results(resp, limit=10)
        assert len(out) == 2
        assert out[0].title == "Kimi K2.6 launch"
        assert out[0].url == "https://example.com/kimi"
        assert out[0].snippet == ""
        assert out[0].page_age == "1 day"
        assert out[0].snippet.startswith("Moonshot released")
        # Missing ``content`` extension → empty snippet rather than crash.
        assert out[1].snippet == ""

    def test_limit_caps_returned_results(self):
        resp = _fake_anthropic_response(
            results=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)]
        resp = _fake_response(
            citations=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)]
        )
        out, _ = _extract_results(resp, limit=3)
        out = _extract_results(resp, limit=3)
        assert len(out) == 3
        assert [r.title for r in out] == ["r0", "r1", "r2"]

    def test_missing_content_returns_empty(self):
        resp = SimpleNamespace(content=[], usage=None)
        out, requests = _extract_results(resp, limit=10)
        assert out == []
        assert requests == 0

    def test_non_search_blocks_are_ignored(self):
        resp = SimpleNamespace(
            content=[
                SimpleNamespace(type="text", text="Here's what I found..."),
                SimpleNamespace(
                    type="web_search_tool_result",
                    content=[
                        SimpleNamespace(
                            type="web_search_result",
                            title="real",
                            url="https://real.example",
                            encrypted_content="body",
                            page_age=None,
                        )
                    ],
                ),
            ],
            usage=None,
    def test_missing_choices_returns_empty(self):
        resp = ChatCompletion.model_construct(
            None,
            id="cmpl-test",
            object="chat.completion",
            created=0,
            model="perplexity/sonar",
            choices=[],
            usage=_usage(),
        )
        out, _ = _extract_results(resp, limit=10)
        assert len(out) == 1 and out[0].title == "real"
        assert _extract_results(resp, limit=10) == []
class TestEstimateCostUsd:
    """Pin the per-search fee + Haiku inference math — the pricing
    constants in ``web_search.py`` are hard-coded (no live lookup) so a
    drift between Anthropic's schedule and our constants must surface
    in this test for the next reader to notice."""

    def test_zero_searches_still_charges_inference(self):
        resp = _fake_anthropic_response(results=[], search_requests=0)
        cost = _estimate_cost_usd(resp, search_requests=0)
        # Haiku at 1000 input / 5000 output tokens = tiny but non-zero.
        assert 0 < cost < 0.001

    def test_single_search_fee_dominates(self):
        resp = _fake_anthropic_response(
            results=[{"title": "x", "url": "https://e"}],
            search_requests=1,
            input_tokens=100,
            output_tokens=20,
    def test_extract_answer_returns_message_content(self):
        resp = _fake_response(
            answer="Sonar's synthesised, web-grounded answer text.",
            citations=[{"title": "t", "url": "https://e"}],
        )
        cost = _estimate_cost_usd(resp, search_requests=1)
        # ~$0.010 search + trivial inference — total still ~1 cent.
        assert cost >= _COST_PER_SEARCH_USD
        assert cost < _COST_PER_SEARCH_USD + 0.001
        assert _extract_answer(resp) == "Sonar's synthesised, web-grounded answer text."

    def test_three_searches_linear_in_count(self):
        resp = _fake_anthropic_response(
            results=[], search_requests=3, input_tokens=0, output_tokens=0
    def test_extract_answer_returns_empty_when_no_choices(self):
        resp = ChatCompletion.model_construct(
            None,
            id="cmpl-test",
            object="chat.completion",
            created=0,
            model="perplexity/sonar",
            choices=[],
            usage=_usage(),
        )
        cost = _estimate_cost_usd(resp, search_requests=3)
        assert cost == pytest.approx(3 * _COST_PER_SEARCH_USD)
        assert _extract_answer(resp) == ""

    def test_snippet_clamped_to_max_chars(self):
        long_body = "x" * 5000
        resp = _fake_response(
            citations=[{"title": "t", "url": "https://e", "content": long_body}]
        )
        out = _extract_results(resp, limit=1)
        assert len(out) == 1
        assert len(out[0].snippet) == 500


class TestExtractCostUsd:
    """Read real ``usage.cost`` via typed ``model_extra`` — no
    hard-coded rates, so a future provider price change is reflected
    automatically. Error handling mirrors the baseline service's
    ``_extract_usage_cost``."""

    def test_returns_cost_value(self):
        assert _extract_cost_usd(_usage(cost=0.023456)) == pytest.approx(0.023456)

    def test_returns_none_when_usage_missing(self):
        assert _extract_cost_usd(None) is None

    def test_returns_none_when_cost_field_missing(self):
        assert _extract_cost_usd(_usage(cost=None)) is None

    def test_returns_none_when_cost_is_explicit_null(self):
        usage = CompletionUsage.model_construct(
            None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None
        )
        assert _extract_cost_usd(usage) is None

    def test_returns_none_when_cost_is_negative(self):
        usage = CompletionUsage.model_construct(
            None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-1.0
        )
        assert _extract_cost_usd(usage) is None

    def test_accepts_numeric_string(self):
        usage = CompletionUsage.model_construct(
            None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017"
        )
        assert _extract_cost_usd(usage) == pytest.approx(0.017)
class TestWebSearchToolDispatch:
    """Lightweight integration test: mock the Anthropic client, confirm
    the handler returns a ``WebSearchResponse`` and the usage tracker is
    called with ``provider='anthropic'`` (not 'open_router', even on the
    baseline path — server-side web_search bills Anthropic regardless of
    the calling LLM's route)."""
    """Integration test: mock the OpenAI client, confirm both paths
    dispatch the right Sonar model + track cost."""

    def _session(self) -> ChatSession:
        s = ChatSession.new("test-user", dry_run=False)
        s.session_id = "sess-1"
        return s

    @pytest.mark.asyncio
    async def test_returns_response_with_results_and_tracks_cost(self, monkeypatch):
        fake_resp = _fake_anthropic_response(
            results=[
                {
                    "title": "hello",
                    "url": "https://example.com",
                    "snippet": "greeting",
                }
            ],
            search_requests=1,
        )
        mock_client = type(
    def _mock_client(self, fake_resp: ChatCompletion) -> Any:
        return type(
            "MC",
            (),
            {
                "messages": type(
                    "M", (), {"create": AsyncMock(return_value=fake_resp)}
                "chat": type(
                    "C",
                    (),
                    {
                        "completions": type(
                            "CC",
                            (),
                            {"create": AsyncMock(return_value=fake_resp)},
                        )()
                    },
                )()
            },
        )()

        # Stub the Anthropic API key so ``is_available`` is True.
    @pytest.mark.asyncio
    async def test_quick_path_uses_sonar_base(self, monkeypatch):
        fake_resp = _fake_response(
            citations=[
                {
                    "title": "hello",
                    "url": "https://example.com",
                    "content": "greeting",
                }
            ],
            answer="Kimi K2.6 launched 2026-04-20 [1].",
            cost=0.01,
        )
        mock_client = self._mock_client(fake_resp)

        monkeypatch.setattr(
            "backend.copilot.tools.web_search.Settings",
            lambda: SimpleNamespace(
                secrets=SimpleNamespace(anthropic_api_key="sk-test")
            ),
            "backend.copilot.tools.web_search._chat_config",
            type(
                "C",
                (),
                {
                    "api_key": "sk-test",
                    "base_url": "https://openrouter.ai/api/v1",
                },
            )(),
        )

        with (
            patch(
                "backend.copilot.tools.web_search.AsyncAnthropic",
                "backend.copilot.tools.web_search.AsyncOpenAI",
                return_value=mock_client,
            ),
            patch(
@@ -220,35 +297,88 @@ class TestWebSearchToolDispatch:
            session=self._session(),
            query="kimi k2.6 launch",
            max_results=5,
            deep=False,
        )

        assert isinstance(result, WebSearchResponse)
        assert result.query == "kimi k2.6 launch"
        assert result.answer == "Kimi K2.6 launched 2026-04-20 [1]."
        assert len(result.results) == 1
        assert isinstance(result.results[0], WebSearchResult)
        assert result.search_requests == 1
        assert result.results[0].snippet == "greeting"

        create_call = mock_client.chat.completions.create.call_args
        assert create_call.kwargs["model"] == "perplexity/sonar"
        # Sonar searches natively — no server-tool extras.
        assert create_call.kwargs["extra_body"] == {"usage": {"include": True}}

        # Cost tracker must have been called with provider="anthropic".
        assert mock_track.await_count == 1
        kwargs = mock_track.await_args.kwargs
        assert kwargs["provider"] == "anthropic"
        assert kwargs["model"] == "claude-haiku-4-5"
        assert kwargs["user_id"] == "u1"
        assert kwargs["cost_usd"] >= _COST_PER_SEARCH_USD
        assert kwargs["provider"] == "open_router"
        assert kwargs["model"] == "perplexity/sonar"
        assert kwargs["cost_usd"] == pytest.approx(0.01)

    @pytest.mark.asyncio
    async def test_missing_api_key_returns_error_without_calling_anthropic(
        self, monkeypatch
    ):
        monkeypatch.setattr(
            "backend.copilot.tools.web_search.Settings",
            lambda: SimpleNamespace(secrets=SimpleNamespace(anthropic_api_key="")),
    async def test_deep_path_uses_sonar_deep_research(self, monkeypatch):
        fake_resp = _fake_response(
            citations=[
                {
                    "title": "deep find",
                    "url": "https://example.com/deep",
                    "content": "research body",
                }
            ],
            cost=0.087,
        )
        anthropic_stub = AsyncMock()
        mock_client = self._mock_client(fake_resp)

        monkeypatch.setattr(
            "backend.copilot.tools.web_search._chat_config",
            type(
                "C",
                (),
                {
                    "api_key": "sk-test",
                    "base_url": "https://openrouter.ai/api/v1",
                },
            )(),
        )

        with (
            patch(
                "backend.copilot.tools.web_search.AsyncAnthropic",
                return_value=anthropic_stub,
                "backend.copilot.tools.web_search.AsyncOpenAI",
                return_value=mock_client,
            ),
            patch(
                "backend.copilot.tools.web_search.persist_and_record_usage",
                new=AsyncMock(return_value=160),
            ) as mock_track,
        ):
            tool = WebSearchTool()
            result = await tool._execute(
                user_id="u1",
                session=self._session(),
                query="research question",
                deep=True,
            )

        assert isinstance(result, WebSearchResponse)
        create_call = mock_client.chat.completions.create.call_args
        assert create_call.kwargs["model"] == "perplexity/sonar-deep-research"

        kwargs = mock_track.await_args.kwargs
        assert kwargs["provider"] == "open_router"
        assert kwargs["model"] == "perplexity/sonar-deep-research"
        assert kwargs["cost_usd"] == pytest.approx(0.087)

    @pytest.mark.asyncio
    async def test_missing_credentials_returns_error(self, monkeypatch):
        monkeypatch.setattr(
            "backend.copilot.tools.web_search._chat_config",
            type("C", (), {"api_key": "", "base_url": ""})(),
        )
        openai_stub = AsyncMock()
        with (
            patch(
                "backend.copilot.tools.web_search.AsyncOpenAI",
                return_value=openai_stub,
            ),
            patch(
                "backend.copilot.tools.web_search.persist_and_record_usage",
@@ -264,21 +394,26 @@ class TestWebSearchToolDispatch:
        )
        assert isinstance(result, ErrorResponse)
        assert result.error == "web_search_not_configured"
        anthropic_stub.messages.create.assert_not_called()
        openai_stub.chat.completions.create.assert_not_called()
        mock_track.assert_not_called()

    @pytest.mark.asyncio
    async def test_empty_query_rejected_without_api_call(self, monkeypatch):
        monkeypatch.setattr(
            "backend.copilot.tools.web_search.Settings",
            lambda: SimpleNamespace(
                secrets=SimpleNamespace(anthropic_api_key="sk-test")
            ),
            "backend.copilot.tools.web_search._chat_config",
            type(
                "C",
                (),
                {
                    "api_key": "sk-test",
                    "base_url": "https://openrouter.ai/api/v1",
                },
            )(),
        )
        anthropic_stub = AsyncMock()
        openai_stub = AsyncMock()
        with patch(
            "backend.copilot.tools.web_search.AsyncAnthropic",
            return_value=anthropic_stub,
            "backend.copilot.tools.web_search.AsyncOpenAI",
            return_value=openai_stub,
        ):
            tool = WebSearchTool()
            result = await tool._execute(
@@ -286,13 +421,13 @@ class TestWebSearchToolDispatch:
            )
        assert isinstance(result, ErrorResponse)
        assert result.error == "missing_query"
        anthropic_stub.messages.create.assert_not_called()
        openai_stub.chat.completions.create.assert_not_called()


class TestToolRegistryIntegration:
    """The tool must be registered under the ``web_search`` name so the
    MCP layer exposes it as ``mcp__copilot__web_search`` — which is
    what the SDK path now dispatches to (see
    what the SDK path dispatches to (see
    ``sdk/tool_adapter.py::SDK_DISALLOWED_TOOLS`` which blocks the CLI's
    native ``WebSearch`` in favour of the MCP route)."""
@@ -195,16 +195,17 @@ def strip_stale_thinking_blocks(content: str) -> str:
        is_last_turn = (
            last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id
        ) or (last_asst_msg_id is None and i == last_asst_idx)
        if (
            msg.get("role") == "assistant"
            and not is_last_turn
            and isinstance(msg.get("content"), list)
        ):
        if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
            content_blocks = msg["content"]
            producing_model = msg.get("model") if isinstance(msg, dict) else None
            filtered = [
                b
                for b in content_blocks
                if not (isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES)
                if not _should_strip_thinking_block(
                    b,
                    is_last_turn=is_last_turn,
                    producing_model=producing_model,
                )
            ]
            if len(filtered) < len(content_blocks):
                stripped_count += len(content_blocks) - len(filtered)
@@ -310,23 +311,30 @@ def strip_for_upload(content: str) -> str:
            if uid in reparented:
                needs_reserialize = True

    # Strip stale thinking blocks from non-last assistant entries
    # Strip stale thinking blocks from non-last assistant entries.
    # Also strip *signature-less* thinking blocks from the last entry —
    # those come from non-Anthropic providers (e.g. Kimi K2.6 via
    # OpenRouter) and are rejected with ``Invalid `signature` in
    # `thinking` block`` if a subsequent turn is dispatched to an
    # Anthropic model that re-validates them. Anthropic-emitted
    # thinking blocks always carry a non-empty ``signature`` field, so
    # this filter is a no-op on Sonnet/Opus turns and only kicks in
    # when the prior turn ran on a non-Anthropic vendor.
    if last_asst_idx is not None:
        msg = entry.get("message", {})
        is_last_turn = (
            last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id
        ) or (last_asst_msg_id is None and i == last_asst_idx)
        if (
            msg.get("role") == "assistant"
            and not is_last_turn
            and isinstance(msg.get("content"), list)
        ):
        if msg.get("role") == "assistant" and isinstance(msg.get("content"), list):
            content_blocks = msg["content"]
            producing_model = msg.get("model") if isinstance(msg, dict) else None
            filtered = [
                b
                for b in content_blocks
                if not (
                    isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES
                if not _should_strip_thinking_block(
                    b,
                    is_last_turn=is_last_turn,
                    producing_model=producing_model,
                )
            ]
            if len(filtered) < len(content_blocks):
@@ -951,6 +959,92 @@ ENTRY_TYPE_MESSAGE = "message"
|
||||
_THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"})
|
||||
|
||||
|
||||
def _is_anthropic_model(model: str | None) -> bool:
|
||||
"""True when *model* is an Anthropic-issued slug.
|
||||
|
||||
Used to decide whether a thinking block's signature is
|
||||
cryptographically valid for Anthropic replay. Non-Anthropic vendors
|
||||
    routed through OpenRouter's Anthropic-compat shim (Kimi K2.6,
    DeepSeek, GPT-OSS) sometimes emit thinking blocks with a
    placeholder signature — it passes a non-empty string check but
    fails Anthropic's cryptographic validation, producing the opaque
    ``Invalid signature in thinking block`` 400 on the next turn
    whenever the model toggle switches to Sonnet/Opus.
    """
    return isinstance(model, str) and model.startswith("anthropic/")


def _should_strip_thinking_block(
    block: object,
    *,
    is_last_turn: bool,
    producing_model: str | None = None,
) -> bool:
    """Return True when *block* is a thinking block that should be removed
    from a transcript entry before upload.

    Strip only when the block CAN'T be replayed safely. Never strip a
    valid Anthropic-issued thinking block — it carries real reasoning
    state that preserves context continuity on ``--resume``.

    Strip rules (first match wins):

    1. **Non-Anthropic producer (any position)** — thinking blocks from
       Kimi / DeepSeek / GPT-OSS via OpenRouter's Anthropic-compat shim
       carry either no signature or a placeholder string that passes a
       non-empty check but fails Anthropic's cryptographic validation.
       Strip unconditionally; they also add low-value tokens to the
       replay context.
    2. **Malformed ``thinking`` (any position, Anthropic producer,
       empty signature)** — shouldn't happen in practice, but if the
       signature is missing / empty the block can't be validated.
       Safer to drop than to 400 the next turn.
    3. **Stale non-last entry with unknown producer** — when the
       caller doesn't wire ``producing_model`` through (legacy paths /
       older tests) we can't tell if the block is safe to keep; fall
       back to the old behaviour of dropping non-last thinking blocks
       to avoid replaying an unverifiable block to Anthropic.

    Preserved:

    * Anthropic ``thinking`` with non-empty signature — at any
      position, last OR non-last. Keeping prior-turn reasoning
      chains helps continuity on multi-round SDK resumes without any
      risk of signature rejection.
    * Anthropic ``redacted_thinking`` — carries an encrypted ``data``
      payload instead of a ``signature``; by design signature-less,
      but Anthropic-issued and safely replayable.
    """
    if not isinstance(block, dict):
        return False
    btype = block.get("type")
    if btype not in _THINKING_BLOCK_TYPES:
        return False
    # Legacy call sites pass producing_model=None — preserve the old
    # "strip-all-non-last-thinking" heuristic for those so we don't
    # regress callers that haven't been updated.
    if producing_model is None:
        if not is_last_turn:
            return True
        if btype != "thinking":
            return False
        signature = block.get("signature")
        return not (isinstance(signature, str) and signature)
    # Non-Anthropic producer — strip at any position. These blocks
    # CAN'T be cryptographically validated by Anthropic on replay.
    if not _is_anthropic_model(producing_model):
        return True
    # Anthropic producer, redacted_thinking: always preserve — the
    # ``data`` field is the signature analog.
    if btype == "redacted_thinking":
        return False
    # Anthropic producer, ``thinking``: keep iff it has a real
    # (non-empty) signature. Empty-signature Anthropic thinking
    # shouldn't happen but guard against it anyway.
    signature = block.get("signature")
    return not (isinstance(signature, str) and signature)
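
The rules above collapse to a small decision table. As a minimal illustrative sketch (not part of the diff; it assumes `_should_strip_thinking_block` is importable from this module and that `_THINKING_BLOCK_TYPES` covers `thinking` and `redacted_thinking`):

```python
kimi_block = {"type": "thinking", "thinking": "...", "signature": "PLACEHOLDER"}
opus_block = {"type": "thinking", "thinking": "...", "signature": "real-sig-blob"}
redacted = {"type": "redacted_thinking", "data": "encrypted-blob"}

# Rule 1: a non-Anthropic producer is stripped at any position,
# even when a (placeholder) signature is present.
assert _should_strip_thinking_block(
    kimi_block, is_last_turn=True, producing_model="moonshotai/kimi-k2.6"
)
# Preserved: Anthropic-signed thinking, last or non-last.
assert not _should_strip_thinking_block(
    opus_block, is_last_turn=False, producing_model="anthropic/claude-opus"
)
# Preserved: redacted_thinking is signature-less by design.
assert not _should_strip_thinking_block(
    redacted, is_last_turn=True, producing_model="anthropic/claude-opus"
)
```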


def _flatten_assistant_content(blocks: list) -> str:
    """Flatten assistant content blocks into a single plain-text string.

@@ -591,7 +591,16 @@ class TestStripForUpload:
                "role": "assistant",
                "id": "msg_new",
                "content": [
                    {"type": "thinking", "thinking": "fresh thinking"},
                    # Anthropic-style thinking block — has a signature so
                    # ``_should_strip_thinking_block`` preserves it on the
                    # last turn. Without the signature (e.g. emitted by
                    # Kimi K2.6 via OpenRouter) it would be stripped — see
                    # ``test_strips_signatureless_thinking_from_last_turn``.
                    {
                        "type": "thinking",
                        "thinking": "fresh thinking",
                        "signature": "anthropic-signed-blob",
                    },
                    {"type": "text", "text": "new answer"},
                ],
            },
@@ -624,6 +633,224 @@ class TestStripForUpload:
        new_types = [b["type"] for b in new_content if isinstance(b, dict)]
        assert "thinking" in new_types  # last assistant preserved

    def test_strips_signatureless_thinking_from_last_turn(self):
        """Kimi K2.6 (and other non-Anthropic OpenRouter providers) emit
        thinking blocks without the Anthropic ``signature`` field. When
        a subsequent advanced-tier toggle replays the transcript to Opus,
        Anthropic's API rejects the signature-less block with ``Invalid
        `signature` in `thinking` block`` — so strip_for_upload must drop
        them from the LAST assistant entry too, not just stale ones."""
        user = {
            "type": "user",
            "uuid": "u1",
            "parentUuid": "",
            "message": {"role": "user", "content": "hi"},
        }
        # Last (and only) assistant entry with a Kimi-shape thinking block
        asst = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "u1",
            "message": {
                "role": "assistant",
                "id": "msg_kimi",
                "content": [
                    # No ``signature`` field → non-Anthropic provider
                    {"type": "thinking", "thinking": "kimi reasoning"},
                    {"type": "text", "text": "answer"},
                ],
            },
        }
        content = _make_jsonl(user, asst)
        result = strip_for_upload(content)
        entries = [json.loads(line) for line in result.strip().split("\n")]
        asst_entry = next(
            e for e in entries if e.get("message", {}).get("id") == "msg_kimi"
        )
        types = [
            b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict)
        ]
        assert "thinking" not in types, (
            "Signature-less thinking block on last turn must be stripped "
            "to prevent Anthropic API rejection on model-switch replay"
        )
        assert "text" in types, "Text content must survive stripping"

    def test_strips_non_anthropic_thinking_with_placeholder_signature(self):
        """OpenRouter's Anthropic-compat shim can emit thinking blocks
        from non-Anthropic producers (Kimi K2.6, DeepSeek) with a
        PLACEHOLDER signature string that passes the "non-empty string"
        check but fails Anthropic's cryptographic validation on replay.

        Observed in session 864a55ba after model-toggle from standard
        (Kimi) to advanced (Opus): the CLI session upload included a
        thinking block with ``signature="ANTHROPIC_SHIM_PLACEHOLDER"``
        (or similar), Opus 4.7 rejected with 400 ``Invalid `signature`
        in `thinking` block``. Fix: strip thinking blocks from the
        LAST assistant turn whenever the producing ``model`` isn't an
        ``anthropic/*`` slug, regardless of signature presence."""
        user = {
            "type": "user",
            "uuid": "u1",
            "parentUuid": "",
            "message": {"role": "user", "content": "hi"},
        }
        asst = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "u1",
            "message": {
                "role": "assistant",
                "id": "msg_kimi_shim",
                "model": "moonshotai/kimi-k2.6-20260420",
                "content": [
                    # Placeholder signature — non-empty but cryptographically
                    # invalid for Anthropic. Legacy strip (signature-only)
                    # would KEEP this block.
                    {
                        "type": "thinking",
                        "thinking": "shimmed reasoning",
                        "signature": "PLACEHOLDER_SHIM_SIG_abc123",
                    },
                    {"type": "text", "text": "answer"},
                ],
            },
        }
        content = _make_jsonl(user, asst)
        result = strip_for_upload(content)
        entries = [json.loads(line) for line in result.strip().split("\n")]
        asst_entry = next(
            e for e in entries if e.get("message", {}).get("id") == "msg_kimi_shim"
        )
        types = [
            b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict)
        ]
        assert "thinking" not in types, (
            "Non-Anthropic thinking block must be stripped even when it "
            "carries a placeholder signature — replay-to-Opus otherwise "
            "400s with Invalid signature"
        )
        assert "text" in types

    def test_preserves_anthropic_thinking_on_non_last_turn(self):
        """Anthropic ``thinking`` blocks on NON-last turns carry real
        reasoning state that helps context continuity on ``--resume``.
        Keep them when the producing model is known-Anthropic with a
        valid signature; strip only when we can't validate safely
        (legacy callers with no model info — falls through to the
        old stale-strip rule).
        """
        user = {
            "type": "user",
            "uuid": "u1",
            "parentUuid": "",
            "message": {"role": "user", "content": "first"},
        }
        asst1 = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "u1",
            "message": {
                "role": "assistant",
                "id": "msg_opus_prev",
                "model": "anthropic/claude-4.7-opus-20260416",
                "content": [
                    {
                        "type": "thinking",
                        "thinking": "first-turn reasoning",
                        "signature": "ANTHROPIC_SIG_1",
                    },
                    {"type": "text", "text": "first answer"},
                ],
            },
        }
        user2 = {
            "type": "user",
            "uuid": "u2",
            "parentUuid": "a1",
            "message": {"role": "user", "content": "second"},
        }
        asst2 = {
            "type": "assistant",
            "uuid": "a2",
            "parentUuid": "u2",
            "message": {
                "role": "assistant",
                "id": "msg_opus_last",
                "model": "anthropic/claude-4.7-opus-20260416",
                "content": [
                    {
                        "type": "thinking",
                        "thinking": "last-turn reasoning",
                        "signature": "ANTHROPIC_SIG_2",
                    },
                    {"type": "text", "text": "last answer"},
                ],
            },
        }
        content = _make_jsonl(user, asst1, user2, asst2)
        result = strip_for_upload(content)
        entries = [json.loads(line) for line in result.strip().split("\n")]

        # Prior Opus turn's thinking must survive — valid Anthropic
        # block with signature.
        prev = next(
            e for e in entries if e.get("message", {}).get("id") == "msg_opus_prev"
        )
        prev_types = [b["type"] for b in prev["message"]["content"]]
        assert "thinking" in prev_types, (
            "Anthropic thinking block on a non-last turn must be "
            "preserved — it carries real reasoning state"
        )
        # Last turn's thinking also preserved.
        last = next(
            e for e in entries if e.get("message", {}).get("id") == "msg_opus_last"
        )
        last_types = [b["type"] for b in last["message"]["content"]]
        assert "thinking" in last_types

    def test_preserves_anthropic_thinking_with_valid_signature(self):
        """Sanity: an Anthropic-issued thinking block with a real
        signature on the last turn must NOT be stripped — Anthropic
        requires value-identity on replay."""
        user = {
            "type": "user",
            "uuid": "u1",
            "parentUuid": "",
            "message": {"role": "user", "content": "hi"},
        }
        asst = {
            "type": "assistant",
            "uuid": "a1",
            "parentUuid": "u1",
            "message": {
                "role": "assistant",
                "id": "msg_opus",
                "model": "anthropic/claude-4.7-opus-20260416",
                "content": [
                    {
                        "type": "thinking",
                        "thinking": "reasoning",
                        "signature": "REAL_ANTHROPIC_SIG_blob",
                    },
                    {"type": "text", "text": "answer"},
                ],
            },
        }
        content = _make_jsonl(user, asst)
        result = strip_for_upload(content)
        entries = [json.loads(line) for line in result.strip().split("\n")]
        asst_entry = next(
            e for e in entries if e.get("message", {}).get("id") == "msg_opus"
        )
        types = [
            b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict)
        ]
        assert (
            "thinking" in types
        ), "Anthropic-signed thinking on last turn must survive strip"
        assert "text" in types

    def test_empty_content(self):
        result = strip_for_upload("")
        # Empty string produces a single empty line after split, resulting in "\n"

@@ -14,6 +14,21 @@ HOST = os.getenv("REDIS_HOST", "localhost")
PORT = int(os.getenv("REDIS_PORT", "6379"))
PASSWORD = os.getenv("REDIS_PASSWORD", None)

# Default socket timeouts so a wedged Redis endpoint can't hang callers
# indefinitely — long-running code paths (cluster_lock refresh in particular)
# rely on these to fail-fast instead of blocking on no-response TCP. Override
# via env if a specific deployment needs a different budget.
#
# 30s matches the convention in ``backend.data.rabbitmq`` and leaves ~6x
# headroom over the largest ``xread(block=5000)`` wait in stream_registry.
# The connect timeout is shorter (5s) because initial connects should be
# fast; a slow connect usually means the endpoint is genuinely unreachable.
SOCKET_TIMEOUT = float(os.getenv("REDIS_SOCKET_TIMEOUT", "30"))
SOCKET_CONNECT_TIMEOUT = float(os.getenv("REDIS_SOCKET_CONNECT_TIMEOUT", "5"))
# How often redis-py sends a PING on idle connections to detect half-open
# sockets; cheap and avoids waiting for the OS TCP keepalive (~2h default).
HEALTH_CHECK_INTERVAL = int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30"))

logger = logging.getLogger(__name__)


@@ -24,6 +39,10 @@ def connect() -> Redis:
        port=PORT,
        password=PASSWORD,
        decode_responses=True,
        socket_timeout=SOCKET_TIMEOUT,
        socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        socket_keepalive=True,
        health_check_interval=HEALTH_CHECK_INTERVAL,
    )
    c.ping()
    return c
@@ -46,6 +65,10 @@ async def connect_async() -> AsyncRedis:
        port=PORT,
        password=PASSWORD,
        decode_responses=True,
        socket_timeout=SOCKET_TIMEOUT,
        socket_connect_timeout=SOCKET_CONNECT_TIMEOUT,
        socket_keepalive=True,
        health_check_interval=HEALTH_CHECK_INTERVAL,
    )
    await c.ping()
    return c
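
To make the failure mode concrete: with these settings a wedged endpoint raises instead of hanging. A minimal sketch (the host and key name are hypothetical; the values mirror the defaults above):

```python
import redis  # redis-py

r = redis.Redis(
    host="redis.internal",  # hypothetical endpoint that accepts TCP but never replies
    socket_timeout=30,  # REDIS_SOCKET_TIMEOUT default
    socket_connect_timeout=5,  # REDIS_SOCKET_CONNECT_TIMEOUT default
    socket_keepalive=True,
    health_check_interval=30,
)
try:
    r.get("cluster_lock")  # hypothetical key
except redis.exceptions.TimeoutError:
    # Without socket_timeout this call could block indefinitely on a
    # half-open connection; now it fails after ~30s and the caller
    # (e.g. the cluster_lock refresh loop) can retry or surface an error.
    pass
```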

@@ -366,7 +366,7 @@ async def execute_node(

    try:
        if execution_context.dry_run and _dry_run_input is None:
            block_iter = simulate_block(node_block, input_data)
            block_iter = simulate_block(node_block, input_data, user_id=user_id)
        else:
            block_iter = node_block.execute(input_data, **extra_exec_kwargs)


@@ -31,21 +31,31 @@ Inspired by https://github.com/Significant-Gravitas/agent-simulator
import inspect
import json
import logging
import math
from collections.abc import AsyncGenerator
from typing import Any

from openai.types import CompletionUsage

from backend.blocks.agent import AgentExecutorBlock
from backend.blocks.io import AgentInputBlock, AgentOutputBlock
from backend.blocks.orchestrator import OrchestratorBlock
from backend.copilot.token_tracking import persist_and_record_usage
from backend.util.clients import get_openai_client

logger = logging.getLogger(__name__)


# Default simulator model — Gemini 2.5 Flash via OpenRouter (fast, cheap, good at
# JSON generation). Configurable via ChatConfig.simulation_model
# (CHAT_SIMULATION_MODEL env var).
_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash"
# Default simulator model — Gemini 2.5 Flash-Lite via OpenRouter. Same provider
# as Flash ($0.10 / $0.40 per MTok vs $0.30 / $1.20 — ~3× cheaper) with JSON-mode
# reliability that's more than enough for dry-run shape-matching. Configurable
# via ChatConfig.simulation_model (CHAT_SIMULATION_MODEL env var).
_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash-lite"

# OpenRouter-specific extra_body flag that embeds the real generation cost on
# the response usage object. Same shape used by the baseline copilot service
# and web_search tool — keep the three aligned.
_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}}


def _simulator_model() -> str:
@@ -105,10 +115,15 @@ async def _call_llm_for_simulation(
    user_prompt: str,
    *,
    label: str = "simulate",
    user_id: str | None = None,
) -> dict[str, Any]:
    """Send a simulation prompt to the LLM and return the parsed JSON dict.

    Handles client acquisition, retries on invalid JSON, and logging.
    Handles client acquisition, retries on invalid JSON, logging, and platform
    cost tracking. The dry-run simulator calls OpenRouter on the platform's
    key rather than a user's own API credentials, so every successful call is
    recorded against the triggering ``user_id``'s rate-limit counter via
    ``persist_and_record_usage`` (same rails as every copilot turn).

    Raises:
        RuntimeError: If no LLM client is available.
@@ -133,6 +148,7 @@ async def _call_llm_for_simulation(
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            extra_body=_OPENROUTER_INCLUDE_USAGE_COST,
        )
        if not response.choices:
            raise ValueError("LLM returned empty choices array")
@@ -141,13 +157,21 @@ async def _call_llm_for_simulation(
        if not isinstance(parsed, dict):
            raise ValueError(f"LLM returned non-object JSON: {raw[:200]}")

        logger.debug(
            "simulate(%s): attempt=%d tokens=%s/%s",
            label,
            attempt + 1,
            getattr(getattr(response, "usage", None), "prompt_tokens", "?"),
            getattr(getattr(response, "usage", None), "completion_tokens", "?"),
        )
        usage = response.usage
        if usage is not None:
            logger.debug(
                "simulate(%s): attempt=%d tokens=%d/%d",
                label,
                attempt + 1,
                usage.prompt_tokens,
                usage.completion_tokens,
            )
        else:
            logger.debug(
                "simulate(%s): attempt=%d usage unavailable", label, attempt + 1
            )

        await _track_simulator_cost(usage=usage, user_id=user_id, model=model)
        return parsed

    except (json.JSONDecodeError, ValueError) as e:
@@ -174,6 +198,69 @@ async def _call_llm_for_simulation(
    raise ValueError(msg)


def _extract_cost_usd(usage: CompletionUsage | None) -> float | None:
    """Return the provider-reported USD cost on the response usage object.

    OpenRouter attaches a ``cost`` field to the OpenAI-compatible usage object
    when the request body includes ``usage: {"include": True}``. The typed
    ``CompletionUsage`` does not declare it, so we read it off ``model_extra``
    (pydantic v2's container for extras) to keep access fully typed — no
    ``getattr``. Mirrors ``backend.copilot.tools.web_search._extract_cost_usd``
    and ``backend.copilot.baseline.service._extract_usage_cost``; keep the
    three in sync.
    """
    if usage is None:
        return None
    extras = usage.model_extra or {}
    if "cost" not in extras:
        return None
    raw = extras["cost"]
    if raw is None:
        logger.error("[simulator] usage.cost is present but null")
        return None
    try:
        val = float(raw)
    except (TypeError, ValueError):
        logger.error("[simulator] usage.cost is not numeric: %r", raw)
        return None
    if not math.isfinite(val) or val < 0:
        logger.error("[simulator] usage.cost is non-finite or negative: %r", val)
        return None
    return val


async def _track_simulator_cost(
    *,
    usage: CompletionUsage | None,
    user_id: str | None,
    model: str,
) -> None:
    """Record platform cost for a single simulator LLM call.

    The simulator runs outside a copilot ``ChatSession`` — pass ``session=None``
    so ``persist_and_record_usage`` skips the session append but still charges
    the user's rate-limit counter and writes a ``PlatformCostLog`` entry. No
    user_id means no tracking (e.g. in-process tests that don't plumb one
    through); rate-limit accounting silently no-ops in that case.
    """
    if usage is None:
        return
    cost_usd = _extract_cost_usd(usage)
    try:
        await persist_and_record_usage(
            session=None,
            user_id=user_id,
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
            log_prefix="[simulator]",
            cost_usd=cost_usd,
            model=model,
            provider="open_router",
        )
    except Exception as exc:
        logger.warning("[simulator] usage tracking failed: %s", exc)


# ---------------------------------------------------------------------------
# Prompt builders
# ---------------------------------------------------------------------------
|
||||
async def simulate_block(
|
||||
block: Any,
|
||||
input_data: dict[str, Any],
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Simulate block execution using an LLM.
|
||||
|
||||
All block types (including MCPToolBlock) use the same generic LLM prompt
|
||||
which includes the block's run() source code for accurate simulation.
|
||||
|
||||
``user_id`` is threaded through to platform cost tracking — every dry-run
|
||||
LLM call hits the platform's OpenRouter key and is charged against the
|
||||
triggering user's rate-limit counter, same rails as copilot turns.
|
||||
|
||||
Note: callers should check ``prepare_dry_run(block, input_data)`` first.
|
||||
OrchestratorBlock and AgentExecutorBlock execute for real in dry-run mode
|
||||
(see manager.py).
|
||||
@@ -462,7 +555,9 @@ async def simulate_block(
|
||||
label = getattr(block, "name", "?")
|
||||
|
||||
try:
|
||||
parsed = await _call_llm_for_simulation(system_prompt, user_prompt, label=label)
|
||||
parsed = await _call_llm_for_simulation(
|
||||
system_prompt, user_prompt, label=label, user_id=user_id
|
||||
)
|
||||
|
||||
# Track which pins were yielded so we can fill in missing required
|
||||
# ones afterwards — downstream nodes connected to unyielded pins
|
||||
|
||||
@@ -5,6 +5,7 @@ Covers:
- Input/output block passthrough
- prepare_dry_run routing
- simulate_block output-pin filling
- Default simulator model + OpenRouter cost tracking
"""

from __future__ import annotations
@@ -13,8 +14,14 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from openai.types import CompletionUsage
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_message import ChatCompletionMessage

from backend.executor.simulator import (
    _DEFAULT_SIMULATOR_MODEL,
    _extract_cost_usd,
    _truncate_input_values,
    _truncate_value,
    build_simulation_prompt,
@@ -511,3 +518,217 @@ class TestSimulateBlockPassthrough:
        assert len(outputs) == 1
        assert outputs[0][0] == "error"
        assert "No client" in outputs[0][1]


# ---------------------------------------------------------------------------
# Default model + OpenRouter cost tracking
# ---------------------------------------------------------------------------


def _sim_usage(
    *,
    prompt_tokens: int = 1200,
    completion_tokens: int = 300,
    cost: object = 0.000157,
) -> CompletionUsage:
    """Typed ``CompletionUsage`` carrying OpenRouter's ``cost`` extension
    via ``model_extra`` — same pattern as
    ``copilot/tools/web_search_test.py::_usage``. ``model_construct``
    preserves unknown fields; ``model_validate`` would drop them."""
    payload: dict[str, Any] = {
        "prompt_tokens": prompt_tokens,
        "completion_tokens": completion_tokens,
        "total_tokens": prompt_tokens + completion_tokens,
    }
    if cost is not None:
        payload["cost"] = cost
    return CompletionUsage.model_construct(None, **payload)


def _sim_completion(*, content: str, usage: CompletionUsage) -> ChatCompletion:
    """Typed ``ChatCompletion`` shaped like an OpenRouter simulator
    response so the production code runs under real SDK types."""
    message = ChatCompletionMessage.model_construct(
        None, role="assistant", content=content
    )
    choice = Choice.model_construct(
        None, index=0, finish_reason="stop", message=message
    )
    return ChatCompletion.model_construct(
        None,
        id="cmpl-sim",
        object="chat.completion",
        created=0,
        model=_DEFAULT_SIMULATOR_MODEL,
        choices=[choice],
        usage=usage,
    )


class TestDefaultSimulatorModel:
    """Pin the default model — anyone flipping this without a cost review
    trips the test."""

    def test_default_is_flash_lite(self) -> None:
        assert _DEFAULT_SIMULATOR_MODEL == "google/gemini-2.5-flash-lite"


class TestExtractCostUsd:
    """Provider-reported USD cost via typed ``model_extra`` — mirrors
    ``copilot.tools.web_search._extract_cost_usd`` and
    ``copilot.baseline.service._extract_usage_cost``."""

    def test_returns_cost_value(self) -> None:
        assert _extract_cost_usd(_sim_usage(cost=0.000157)) == pytest.approx(0.000157)

    def test_returns_none_when_usage_missing(self) -> None:
        assert _extract_cost_usd(None) is None

    def test_returns_none_when_cost_field_missing(self) -> None:
        assert _extract_cost_usd(_sim_usage(cost=None)) is None

    def test_returns_none_when_cost_is_explicit_null(self) -> None:
        usage = CompletionUsage.model_construct(
            None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None
        )
        assert _extract_cost_usd(usage) is None

    def test_returns_none_when_cost_is_negative(self) -> None:
        usage = CompletionUsage.model_construct(
            None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-0.5
        )
        assert _extract_cost_usd(usage) is None

    def test_accepts_numeric_string(self) -> None:
        usage = CompletionUsage.model_construct(
            None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017"
        )
        assert _extract_cost_usd(usage) == pytest.approx(0.017)


class TestSimulatorCostTracking:
    """Integration: mock the OpenAI client, confirm the simulator sends
    the flash-lite default + extra_body, then plumbs through to
    ``persist_and_record_usage`` with ``provider='open_router'`` and the
    real ``usage.cost`` pulled off ``model_extra``."""

    def _mock_client(self, fake_resp: ChatCompletion) -> tuple[Any, AsyncMock]:
        """Build a fake ``AsyncOpenAI`` client. Same nested-type pattern as
        ``copilot/tools/web_search_test.py::_mock_client`` — avoids
        MagicMock's auto-child-attr behaviour so the exact ``create`` call
        surface is what gets invoked."""
        create_mock = AsyncMock(return_value=fake_resp)
        client = type(
            "MC",
            (),
            {
                "chat": type(
                    "C",
                    (),
                    {"completions": type("CC", (), {"create": create_mock})()},
                )()
            },
        )()
        return client, create_mock

    @pytest.mark.asyncio
    async def test_passes_default_model_and_tracks_cost(self) -> None:
        block = _make_block()
        fake_resp = _sim_completion(
            content='{"result": "simulated"}',
            usage=_sim_usage(prompt_tokens=1100, completion_tokens=220, cost=0.000189),
        )
        client, create_mock = self._mock_client(fake_resp)

        with (
            patch(
                "backend.executor.simulator.get_openai_client",
                return_value=client,
            ),
            patch(
                "backend.executor.simulator.persist_and_record_usage",
                new=AsyncMock(return_value=1320),
            ) as mock_track,
        ):
            outputs = []
            async for name, data in simulate_block(
                block, {"query": "hello"}, user_id="user-42"
            ):
                outputs.append((name, data))

        assert ("result", "simulated") in outputs

        create_kwargs = create_mock.await_args.kwargs
        assert create_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL
        assert create_kwargs["extra_body"] == {"usage": {"include": True}}

        track_kwargs = mock_track.await_args.kwargs
        assert track_kwargs["provider"] == "open_router"
        assert track_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL
        assert track_kwargs["user_id"] == "user-42"
        assert track_kwargs["prompt_tokens"] == 1100
        assert track_kwargs["completion_tokens"] == 220
        assert track_kwargs["cost_usd"] == pytest.approx(0.000189)
        assert track_kwargs["session"] is None
        assert track_kwargs["log_prefix"] == "[simulator]"

    @pytest.mark.asyncio
    async def test_tracks_even_when_cost_absent(self) -> None:
        """Provider may omit ``cost`` (e.g. non-OpenRouter proxies). We
        still record token counts — ``persist_and_record_usage`` logs the
        turn and skips the rate-limit ledger when cost is ``None``."""
        block = _make_block()
        fake_resp = _sim_completion(
            content='{"result": "ok"}',
            usage=_sim_usage(prompt_tokens=100, completion_tokens=20, cost=None),
        )
        client, _ = self._mock_client(fake_resp)

        with (
            patch(
                "backend.executor.simulator.get_openai_client",
                return_value=client,
            ),
            patch(
                "backend.executor.simulator.persist_and_record_usage",
                new=AsyncMock(return_value=120),
            ) as mock_track,
        ):
            async for _name, _data in simulate_block(
                block, {"query": "x"}, user_id="user-7"
            ):
                pass

        track_kwargs = mock_track.await_args.kwargs
        assert track_kwargs["cost_usd"] is None
        assert track_kwargs["user_id"] == "user-7"
        assert track_kwargs["provider"] == "open_router"

    @pytest.mark.asyncio
    async def test_tracking_failure_does_not_break_simulation(self) -> None:
        """Cost-tracking failures are warnings, not simulation failures —
        the block output must still flow to the caller."""
        block = _make_block()
        fake_resp = _sim_completion(
            content='{"result": "simulated"}',
            usage=_sim_usage(),
        )
        client, _ = self._mock_client(fake_resp)

        with (
            patch(
                "backend.executor.simulator.get_openai_client",
                return_value=client,
            ),
            patch(
                "backend.executor.simulator.persist_and_record_usage",
                new=AsyncMock(side_effect=RuntimeError("redis down")),
            ),
        ):
            outputs = []
            async for name, data in simulate_block(
                block, {"query": "hello"}, user_id="user-42"
            ):
                outputs.append((name, data))

        assert ("result", "simulated") in outputs

@@ -48,6 +48,16 @@ class Flag(str, Enum):
    STRIPE_PRICE_BUSINESS = "stripe-price-id-business"
    GRAPHITI_MEMORY = "graphiti-memory"

    # Copilot model routing — string-valued, returns the model identifier
    # (e.g. ``"anthropic/claude-sonnet-4-6"`` or ``"moonshotai/kimi-k2.6"``)
    # to use for each cell of the (mode, tier) matrix. Falls back to the
    # ``CHAT_*_MODEL`` env/config default when the flag is unset or LD is
    # unavailable. Evaluated per user_id so cohorts can be targeted.
    COPILOT_FAST_STANDARD_MODEL = "copilot-fast-standard-model"
    COPILOT_FAST_ADVANCED_MODEL = "copilot-fast-advanced-model"
    COPILOT_THINKING_STANDARD_MODEL = "copilot-thinking-standard-model"
    COPILOT_THINKING_ADVANCED_MODEL = "copilot-thinking-advanced-model"
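
To make the fallback contract concrete, resolution might look roughly like the sketch below. This is hypothetical: `get_flag_variation` stands in for whatever the codebase's LaunchDarkly wrapper actually exposes, and `env_default` is the already-resolved `CHAT_*_MODEL` value.

```python
def resolve_copilot_model(flag: Flag, user_id: str, env_default: str) -> str:
    """Hypothetical sketch of the documented contract: an unset flag,
    a non-string value, or an unconfigured LD client all fall back to
    the CHAT_*_MODEL env/config default."""
    if not is_configured():
        return env_default
    value = get_flag_variation(flag.value, user_id)  # hypothetical LD wrapper
    return value if isinstance(value, str) and value else env_default
```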


def is_configured() -> bool:
    """Check if LaunchDarkly is configured with an SDK key."""

@@ -0,0 +1,177 @@
import { act, renderHook } from "@testing-library/react";
import type { UIMessage } from "ai";
import { afterEach, beforeEach, describe, expect, it, vi } from "vitest";

import { useCopilotStream } from "../useCopilotStream";

// Capture the args passed to ``useChat`` so tests can invoke onFinish/onError
// directly — that's the only way to drive handleReconnect without a real SSE.
let lastUseChatArgs: {
  onFinish?: (args: { isDisconnect?: boolean; isAbort?: boolean }) => void;
  onError?: (err: Error) => void;
} | null = null;

const resumeStreamMock = vi.fn();
const sdkStopMock = vi.fn();
const sdkSendMessageMock = vi.fn();
const setMessagesMock = vi.fn();

function resetSdkMocks() {
  lastUseChatArgs = null;
  resumeStreamMock.mockReset();
  sdkStopMock.mockReset();
  sdkSendMessageMock.mockReset();
  setMessagesMock.mockReset();
}

vi.mock("@ai-sdk/react", () => ({
  useChat: (args: unknown) => {
    lastUseChatArgs = args as typeof lastUseChatArgs;
    return {
      messages: [] as UIMessage[],
      sendMessage: sdkSendMessageMock,
      stop: sdkStopMock,
      status: "ready" as const,
      error: undefined,
      setMessages: setMessagesMock,
      resumeStream: resumeStreamMock,
    };
  },
}));

vi.mock("ai", async () => {
  const actual = await vi.importActual<typeof import("ai")>("ai");
  return {
    ...actual,
    DefaultChatTransport: class {
      constructor(public opts: unknown) {}
    },
  };
});

vi.mock("@tanstack/react-query", () => ({
  useQueryClient: () => ({ invalidateQueries: vi.fn() }),
}));

vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({
  getGetV2GetCopilotUsageQueryKey: () => ["copilot-usage"],
  getGetV2GetSessionQueryKey: (id: string) => ["session", id],
  postV2CancelSessionTask: vi.fn(),
  deleteV2DisconnectSessionStream: vi.fn().mockResolvedValue(undefined),
}));

vi.mock("@/components/molecules/Toast/use-toast", () => ({
  toast: vi.fn(),
}));

vi.mock("@/services/environment", () => ({
  environment: {
    getAGPTServerBaseUrl: () => "http://localhost",
  },
}));

vi.mock("../helpers", async () => {
  const actual =
    await vi.importActual<typeof import("../helpers")>("../helpers");
  return {
    ...actual,
    getCopilotAuthHeaders: vi.fn().mockResolvedValue({}),
    disconnectSessionStream: vi.fn(),
  };
});

vi.mock("../useHydrateOnStreamEnd", () => ({
  useHydrateOnStreamEnd: () => undefined,
}));

function renderStream() {
  return renderHook(() =>
    useCopilotStream({
      sessionId: "sess-1",
      hydratedMessages: [],
      hasActiveStream: false,
      refetchSession: vi.fn().mockResolvedValue({ data: undefined }),
      copilotMode: undefined,
      copilotModel: undefined,
    }),
  );
}

describe("useCopilotStream — reconnect debounce", () => {
  beforeEach(() => {
    resetSdkMocks();
    vi.useFakeTimers();
    // Pin Date.now so sinceLastResume math is deterministic. The hook reads
    // Date.now() both when stashing lastReconnectResumeAtRef and when
    // deciding whether to debounce.
    vi.setSystemTime(new Date(2025, 0, 1, 12, 0, 0));
  });

  afterEach(() => {
    vi.useRealTimers();
  });

  it("coalesces a burst of disconnect events into one resumeStream call", async () => {
    renderStream();

    // First disconnect — schedules a reconnect at the exponential backoff
    // delay (1000ms for attempt #1).
    await act(async () => {
      await lastUseChatArgs!.onFinish!({ isDisconnect: true });
    });

    // Fire the scheduled timer → resumeStream runs once and stamps
    // lastReconnectResumeAtRef.current = Date.now().
    await act(async () => {
      await vi.advanceTimersByTimeAsync(1_000);
    });
    expect(resumeStreamMock).toHaveBeenCalledTimes(1);

    // A second disconnect arrives immediately after (still inside the
    // 1500ms debounce window) — the debounce path must fire and queue a
    // coalesced timer, NOT a fresh resume.
    await act(async () => {
      await lastUseChatArgs!.onFinish!({ isDisconnect: true });
    });
    expect(resumeStreamMock).toHaveBeenCalledTimes(1);

    // The coalesced timer fires at the window boundary and reschedules a
    // real reconnect. Advance past the window AND past the second
    // reconnect's backoff (attempt #2 = 2000ms) so resume runs.
    await act(async () => {
      await vi.advanceTimersByTimeAsync(1_500);
    });
    await act(async () => {
      await vi.advanceTimersByTimeAsync(2_000);
    });
    expect(resumeStreamMock).toHaveBeenCalledTimes(2);
  });

  it("does not debounce a reconnect that arrives after the window closes", async () => {
    renderStream();

    // First reconnect cycle.
    await act(async () => {
      await lastUseChatArgs!.onFinish!({ isDisconnect: true });
    });
    await act(async () => {
      await vi.advanceTimersByTimeAsync(1_000);
    });
    expect(resumeStreamMock).toHaveBeenCalledTimes(1);

    // Wait past the debounce window before the next disconnect.
    await act(async () => {
      await vi.advanceTimersByTimeAsync(2_000);
    });

    // Now a fresh disconnect should go through the normal path (NOT the
    // debounce branch) and schedule a backoff of 2000ms (attempt #2).
    await act(async () => {
      await lastUseChatArgs!.onFinish!({ isDisconnect: true });
    });
    await act(async () => {
      await vi.advanceTimersByTimeAsync(2_000);
    });
    expect(resumeStreamMock).toHaveBeenCalledTimes(2);
  });
});
@@ -7,6 +7,7 @@ import {
  formatNotificationTitle,
  getSendSuppressionReason,
  parseSessionIDs,
  shouldDebounceReconnect,
  shouldSuppressDuplicateSend,
} from "./helpers";

@@ -466,3 +467,88 @@ describe("deduplicateMessages", () => {
    expect(result).toHaveLength(2); // duplicate step-start messages are deduped
  });
});

describe("shouldDebounceReconnect", () => {
  const WINDOW_MS = 1_500;

  it("returns null for the first reconnect (lastResumeAt === 0)", () => {
    expect(shouldDebounceReconnect(0, 10_000, WINDOW_MS)).toBeNull();
  });

  it("returns null for a negative lastResumeAt sentinel", () => {
    // Defensive: a negative value is still treated as "no reconnect yet".
    expect(shouldDebounceReconnect(-1, 10_000, WINDOW_MS)).toBeNull();
  });

  it("returns the remaining delay when now is inside the window", () => {
    // 500ms since the last resume — the caller must wait another 1000ms
    // before the storm cap reopens.
    const remaining = shouldDebounceReconnect(1_000, 1_500, WINDOW_MS);
    expect(remaining).toBe(1_000);
  });

  it("coalesces a reconnect that arrives immediately after the previous resume", () => {
    // now === lastResumeAt → sinceLastResume === 0, so the full window remains.
    const remaining = shouldDebounceReconnect(5_000, 5_000, WINDOW_MS);
    expect(remaining).toBe(WINDOW_MS);
  });

  it("returns null when exactly on the window boundary", () => {
    // sinceLastResume === windowMs is NOT inside the window — the next
    // reconnect should fire immediately.
    expect(shouldDebounceReconnect(1_000, 2_500, WINDOW_MS)).toBeNull();
  });

  it("returns null when the window has elapsed", () => {
    expect(shouldDebounceReconnect(1_000, 5_000, WINDOW_MS)).toBeNull();
  });

  it("returns a small remaining delay at the far edge of the window", () => {
    // 1ms before the window closes → 1ms left.
    const remaining = shouldDebounceReconnect(1_000, 2_499, WINDOW_MS);
    expect(remaining).toBe(1);
  });

  it("collapses a burst of reconnects into one debounced scheduling", () => {
    // Simulates the browser tab-throttle storm: three reconnect calls fire
    // within a single second after the last resume. Only the first slot
    // would actually run; subsequent calls must always be coalesced.
    const lastResumeAt = 10_000;
    const firstCallRemaining = shouldDebounceReconnect(
      lastResumeAt,
      10_100,
      WINDOW_MS,
    );
    const secondCallRemaining = shouldDebounceReconnect(
      lastResumeAt,
      10_200,
      WINDOW_MS,
    );
    const thirdCallRemaining = shouldDebounceReconnect(
      lastResumeAt,
      10_300,
      WINDOW_MS,
    );
    expect(firstCallRemaining).toBe(1_400);
    expect(secondCallRemaining).toBe(1_300);
    expect(thirdCallRemaining).toBe(1_200);
  });

  it("allows a reconnect to fire immediately once the window has passed", () => {
    // After the window expires, a retry that came in earlier can now fire
    // rather than stalling the loop. Guards against the regression that
    // motivated the coalesce-instead-of-drop fix.
    const lastResumeAt = 10_000;
    expect(
      shouldDebounceReconnect(lastResumeAt, 10_500, WINDOW_MS),
    ).not.toBeNull();
    expect(shouldDebounceReconnect(lastResumeAt, 11_500, WINDOW_MS)).toBeNull();
  });

  it("honours a custom windowMs value", () => {
    // Shouldn't hard-code 1500 anywhere: the helper is generic over the
    // window.
    expect(shouldDebounceReconnect(1_000, 1_500, 2_000)).toBe(1_500);
    expect(shouldDebounceReconnect(1_000, 3_500, 2_000)).toBeNull();
  });
});

@@ -184,6 +184,28 @@ export function disconnectSessionStream(sessionId: string): void {
  deleteV2DisconnectSessionStream(sessionId).catch(() => {});
}

/**
 * Decide whether a reconnect request must be coalesced onto the debounce
 * window boundary, rather than firing immediately.
 *
 * Returns the remaining milliseconds until the window closes (so the caller
 * can schedule a `setTimeout` for that delay) when the previous resume
 * happened inside the window, or `null` to let the reconnect proceed now.
 *
 * `lastResumeAt === 0` signals "no reconnect has fired yet in this session"
 * — the first reconnect always passes through regardless of `now`.
 */
export function shouldDebounceReconnect(
  lastResumeAt: number,
  now: number,
  windowMs: number,
): number | null {
  if (lastResumeAt <= 0) return null;
  const sinceLastResume = now - lastResumeAt;
  if (sinceLastResume >= windowMs) return null;
  return windowMs - sinceLastResume;
}
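
The arithmetic is small enough to check by hand. A sketch of the same contract in Python (illustrative only; the shipped helper is the TypeScript above):

```python
def should_debounce_reconnect(last_resume_at: int, now: int, window_ms: int) -> int | None:
    """None means fire now; otherwise the remaining ms to wait."""
    if last_resume_at <= 0:  # sentinel: no reconnect yet this session
        return None
    since = now - last_resume_at
    if since >= window_ms:  # window closed, proceed immediately
        return None
    return window_ms - since  # coalesce onto the window boundary

assert should_debounce_reconnect(0, 10_000, 1_500) is None      # first reconnect
assert should_debounce_reconnect(1_000, 1_500, 1_500) == 1_000  # inside window
assert should_debounce_reconnect(1_000, 2_500, 1_500) is None   # on the boundary: fire
```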

/**
 * Deduplicate messages by ID and by consecutive content fingerprint.
 *
@@ -313,11 +313,19 @@ function getWebAccordionData(
      : null;

  if (results) {
    const deep = inp.deep === true;
    const noun = deep ? "research source" : "search result";
    const answer = getStringField(output, "answer");
    return {
      title: `${results.length} search result${results.length === 1 ? "" : "s"}`,
      title: `${results.length} ${noun}${results.length === 1 ? "" : "s"}`,
      description: query ? truncate(query, 80) : undefined,
      content: (
        <div className="space-y-3">
          {answer && (
            <div className="whitespace-pre-wrap rounded-md bg-slate-50 p-3 text-sm text-slate-800">
              {answer}
            </div>
          )}
          {results.map((r, i) => {
            const title = getStringField(r, "title") ?? "(untitled)";
            const href = getStringField(r, "url") ?? "";

@@ -141,6 +141,7 @@ describe("GenericTool", () => {
  function makeWebSearchPart(
    results: Array<Record<string, unknown>>,
    query = "kimi k2.6",
    answer = "",
  ): ToolUIPart {
    return {
      type: "tool-web_search",
@@ -149,6 +150,7 @@ describe("GenericTool", () => {
      input: { query },
      output: {
        type: "web_search_response",
        answer,
        results,
        query,
        search_requests: 1,
@@ -254,6 +256,25 @@ describe("GenericTool", () => {
    expect(normalized).toContain('Searched "kimi k2.6"');
  });

  it("renders the synthesised answer above the citations when present", () => {
    render(
      <GenericTool
        part={makeWebSearchPart(
          [
            { title: "Citation 1", url: "https://example.com/one" },
            { title: "Citation 2", url: "https://example.com/two" },
          ],
          "kimi k2.6 launch",
          "Kimi K2.6 launched on 2026-04-20 with SWE-Bench parity to Opus.",
        )}
      />,
    );
    fireEvent.click(screen.getByRole("button", { expanded: false }));
    expect(
      screen.getByText(/Kimi K2\.6 launched on 2026-04-20/),
    ).not.toBeNull();
  });

  it("uses '(untitled)' when a search result has no title", () => {
    render(
      <GenericTool

@@ -205,6 +205,14 @@ export function humanizeFileName(filePath: string): string {
/* Animation text */
/* ------------------------------------------------------------------ */

// web_search accepts a ``deep`` arg that dispatches to a multi-step
// research model; render a distinct verb ("Researching"/"Researched"/
// "Research failed") so users know the call takes longer.
function _isDeepWebSearch(part: ToolUIPart): boolean {
  const input = part.input as Record<string, unknown> | undefined;
  return input?.deep === true;
}

export function getAnimationText(
  part: ToolUIPart,
  category: ToolCategory,
@@ -223,9 +231,11 @@ export function getAnimationText(
        : "Running command\u2026";
    case "web":
      if (toolName === "WebSearch" || toolName === "web_search") {
        const deep = _isDeepWebSearch(part);
        const verb = deep ? "Researching" : "Searching";
        return shortSummary
          ? `Searching "${shortSummary}"`
          : "Searching the web\u2026";
          ? `${verb} "${shortSummary}"`
          : `${verb} the web\u2026`;
      }
      return shortSummary
        ? `Fetching ${shortSummary}`
@@ -285,9 +295,12 @@ export function getAnimationText(
      return shortSummary ? `Ran: ${shortSummary}` : "Command completed";
    case "web":
      if (toolName === "WebSearch" || toolName === "web_search") {
        return shortSummary
          ? `Searched "${shortSummary}"`
        const deep = _isDeepWebSearch(part);
        const verb = deep ? "Researched" : "Searched";
        const completed = deep
          ? "Web research completed"
          : "Web search completed";
        return shortSummary ? `${verb} "${shortSummary}"` : completed;
      }
      return shortSummary
        ? `Fetched ${shortSummary}`
@@ -354,9 +367,10 @@ export function getAnimationText(
    case "bash":
      return "Command failed";
    case "web":
      return toolName === "WebSearch" || toolName === "web_search"
        ? "Search failed"
        : "Fetch failed";
      if (toolName === "WebSearch" || toolName === "web_search") {
        return _isDeepWebSearch(part) ? "Research failed" : "Search failed";
      }
      return "Fetch failed";
    case "browser":
      return "Browser action failed";
    default:

@@ -18,6 +18,7 @@ import {
  resolveInProgressTools,
  getSendSuppressionReason,
  disconnectSessionStream,
  shouldDebounceReconnect,
} from "./helpers";
import type { CopilotLlmModel, CopilotMode } from "./store";
import { useHydrateOnStreamEnd } from "./useHydrateOnStreamEnd";
@@ -25,6 +26,19 @@ import { useHydrateOnStreamEnd } from "./useHydrateOnStreamEnd";
const RECONNECT_BASE_DELAY_MS = 1_000;
const RECONNECT_MAX_ATTEMPTS = 3;

/**
 * Minimum spacing between successive reconnect attempts.
 * `isReconnectScheduledRef` already prevents OVERLAPPING reconnects, but
 * tab-throttle / visibility wake bursts can fire `onFinish(isDisconnect)`
 * several times inside a single second — each one would schedule a fresh
 * reconnect the moment the previous timer cleared the ref. Requests that
 * arrive inside this window since the last reconnect's resume are COALESCED:
 * scheduled to run at the window boundary rather than dropped, so a
 * fast-failing resume (e.g. a 502 on GET /stream that trips `onError` inside
 * 500 ms) still retries instead of stalling the retry loop.
 */
const RECONNECT_DEBOUNCE_MS = 1_500;

/** Minimum time the page must have been hidden to trigger a wake re-sync. */
const WAKE_RESYNC_THRESHOLD_MS = 30_000;

@@ -110,6 +124,11 @@ export function useCopilotStream({
  const isReconnectScheduledRef = useRef(false);
  const [isReconnectScheduled, setIsReconnectScheduled] = useState(false);
  const reconnectTimerRef = useRef<ReturnType<typeof setTimeout>>();
  // Timestamp of the last reconnect's actual resume call — used together
  // with RECONNECT_DEBOUNCE_MS to coalesce rapid duplicate reconnect requests
  // (e.g. visibility throttle firing onFinish(isDisconnect) several times
  // in the same second). 0 = no reconnect has fired yet in this session.
  const lastReconnectResumeAtRef = useRef(0);
  const hasShownDisconnectToast = useRef(false);
  // Set when the user explicitly clicks stop — prevents onError from
  // triggering a reconnect cycle for the resulting AbortError.
@@ -127,6 +146,32 @@ export function useCopilotStream({
  function handleReconnect(sid: string) {
    if (isReconnectScheduledRef.current || !sid) return;

    // Debounce: if the previous reconnect resumed within the last
    // RECONNECT_DEBOUNCE_MS, COALESCE this request onto the window boundary
    // rather than dropping it. Browser tab-throttle bursts can fire
    // onFinish(isDisconnect) 2–3 times in a second; without the debounce,
    // each fires its own GET /stream, each one replays the Redis stream,
    // and the flicker storm is back. Dropping the request silently (the
    // previous behaviour) stalled the retry loop when a resume failed
    // quickly — e.g. a 502 on GET /stream that trips onError inside 500 ms
    // while the 1500 ms window is still open. Scheduling the retry for
    // the remaining window preserves both the storm cap and the retry.
    const remainingDelay = shouldDebounceReconnect(
      lastReconnectResumeAtRef.current,
      Date.now(),
      RECONNECT_DEBOUNCE_MS,
    );
    if (remainingDelay !== null) {
      isReconnectScheduledRef.current = true;
      setIsReconnectScheduled(true);
      reconnectTimerRef.current = setTimeout(() => {
        isReconnectScheduledRef.current = false;
        setIsReconnectScheduled(false);
        handleReconnect(sid);
      }, remainingDelay);
      return;
    }

    const nextAttempt = reconnectAttemptsRef.current + 1;
    if (nextAttempt > RECONNECT_MAX_ATTEMPTS) {
      setReconnectExhausted(true);
@@ -163,6 +208,7 @@ export function useCopilotStream({
      }
      return prev;
    });
    lastReconnectResumeAtRef.current = Date.now();
    resumeStreamRef.current();
  }, delay);
}
@@ -469,6 +515,7 @@ export function useCopilotStream({
    setRateLimitMessage(null);
    hasShownDisconnectToast.current = false;
    lastSubmittedMsgRef.current = null;
    lastReconnectResumeAtRef.current = 0;
    setReconnectExhausted(false);
    setIsSyncing(false);
    hasResumedRef.current.clear();

@@ -2119,7 +2119,8 @@
            },
            {
              "$ref": "#/components/schemas/MemoryForgetConfirmResponse"
            }
            },
            { "$ref": "#/components/schemas/TodoWriteResponse" }
          ],
          "title": "Response Getv2[Dummy] Tool Response Type Export For Codegen"
        }
@@ -14648,7 +14649,8 @@
        "memory_store",
        "memory_search",
        "memory_forget_candidates",
        "memory_forget_confirm"
        "memory_forget_confirm",
        "todo_write"
      ],
      "title": "ResponseType",
      "description": "Types of tool responses."
@@ -16810,6 +16812,52 @@
      "required": ["timezone"],
      "title": "TimezoneResponse"
    },
    "TodoItem": {
      "properties": {
        "content": {
          "type": "string",
          "title": "Content",
          "description": "Imperative description of the task."
        },
        "activeForm": {
          "type": "string",
          "title": "Activeform",
          "description": "Present-continuous form shown while the task is running."
        },
        "status": {
          "type": "string",
          "enum": ["pending", "in_progress", "completed"],
          "title": "Status",
          "default": "pending"
        }
      },
      "type": "object",
      "required": ["content", "activeForm"],
      "title": "TodoItem",
      "description": "One entry in a ``TodoWrite`` checklist.\n\nMirrors the schema used by Claude Code's built-in ``TodoWrite`` tool so\nthe frontend's ``GenericTool`` accordion renders baseline-emitted todos\nidentically to SDK-emitted ones."
    },
    "TodoWriteResponse": {
      "properties": {
        "type": {
          "$ref": "#/components/schemas/ResponseType",
          "default": "todo_write"
        },
        "message": { "type": "string", "title": "Message" },
        "session_id": {
          "anyOf": [{ "type": "string" }, { "type": "null" }],
          "title": "Session Id"
        },
        "todos": {
          "items": { "$ref": "#/components/schemas/TodoItem" },
          "type": "array",
          "title": "Todos"
        }
      },
      "type": "object",
      "required": ["message"],
      "title": "TodoWriteResponse",
      "description": "Ack returned by ``TodoWrite``.\n\nThe tool is effectively stateless — the authoritative task list lives in\nthe assistant's latest tool-call arguments, which are replayed from the\ntranscript on each turn. The tool output only needs to confirm that the\nupdate was accepted so the model can proceed."
    },
    "TokenIntrospectionResult": {
      "properties": {
        "active": { "type": "boolean", "title": "Active" },