diff --git a/.claude/skills/pr-polish/SKILL.md b/.claude/skills/pr-polish/SKILL.md new file mode 100644 index 0000000000..3b36adee14 --- /dev/null +++ b/.claude/skills/pr-polish/SKILL.md @@ -0,0 +1,245 @@ +--- +name: pr-polish +description: Alternate /pr-review and /pr-address on a PR until the PR is truly mergeable — no new review findings, zero unresolved inline threads, zero unaddressed top-level reviews or issue comments, all CI checks green, and two consecutive quiet polls after CI settles. Use when the user wants a PR polished to merge-ready without setting a fixed number of rounds. +user-invocable: true +argument-hint: "[PR number or URL] — if omitted, finds PR for current branch." +metadata: + author: autogpt-team + version: "1.0.0" +--- + +# PR Polish + +**Goal.** Drive a PR to merge-ready by alternating `/pr-review` and `/pr-address` until **all** of the following hold: + +1. The most recent `/pr-review` produces **zero new findings** (no new inline comments, no new top-level reviews with a non-empty body). +2. Every inline review thread reachable via GraphQL reports `isResolved: true`. +3. Every non-bot, non-author top-level review has been acknowledged (replied-to) OR resolved via a thread it spawned. +4. Every non-bot, non-author issue comment has been acknowledged (replied-to). +5. Every CI check is `conclusion: "success"` or `"skipped"` / `"neutral"` — none `"failure"` or still pending. +6. **Two consecutive post-CI polls** (≥60s apart) stay clean — no new threads, no new non-empty reviews, no new issue comments. Bots (coderabbitai, sentry, autogpt-reviewer) frequently post late after CI settles; a single green snapshot is not sufficient. + +**Do not stop at a fixed number of rounds.** If round N introduces new comments, round N+1 is required. Cap at `_MAX_ROUNDS = 10` as a safety valve, but expect 2–5 in practice. + +## TodoWrite + +Before starting, write two todos so the user can see the loop progression: + +- `Round {current}: /pr-review + /pr-address on PR #{N}` — current iteration. +- `Final polish polling: 2 consecutive clean polls, CI green, 0 unresolved` — runs after the last non-empty review round. + +Update the `current` round counter at the start of each iteration; mark `completed` only when the round's address step finishes (all new threads addressed + resolved). + +## Find the PR + +```bash +ARG_PR="${ARG:-}" +# Normalize URL → numeric ID if the skill arg is a pull-request URL. +if [[ "$ARG_PR" =~ ^https?://github\.com/[^/]+/[^/]+/pull/([0-9]+) ]]; then + ARG_PR="${BASH_REMATCH[1]}" +fi +PR="${ARG_PR:-$(gh pr list --head "$(git branch --show-current)" --repo Significant-Gravitas/AutoGPT --json number --jq '.[0].number')}" +if [ -z "$PR" ] || [ "$PR" = "null" ]; then + echo "No PR found for current branch. Provide a PR number or URL as the skill arg." + exit 1 +fi +echo "Polishing PR #$PR" +``` + +## The outer loop + +```text +round = 0 +while round < _MAX_ROUNDS: + round += 1 + baseline = snapshot_state(PR) # see "Snapshotting state" below + invoke_skill("pr-review", PR) # posts findings as inline comments / top-level review + findings = diff_state(PR, baseline) + if findings.total == 0: + break # no new findings → go to polish polling + invoke_skill("pr-address", PR) # resolves every unresolved thread + CI failure +# Post-loop: polish polling (see below). 
+polish_polling(PR) +``` + +### Snapshotting state + +Before each `/pr-review`, capture a baseline so the diff after the review reflects **only** what the review just added (not pre-existing threads): + +```bash +# Inline threads — total count + latest databaseId per thread +gh api graphql -f query=" +{ + repository(owner: \"Significant-Gravitas\", name: \"AutoGPT\") { + pullRequest(number: ${PR}) { + reviewThreads(first: 100) { + totalCount + nodes { + id + isResolved + comments(last: 1) { nodes { databaseId } } + } + } + } + } +}" > /tmp/baseline_threads.json + +# Top-level reviews — count + latest id per non-empty review +gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}/reviews" --paginate \ + --jq '[.[] | select((.body // "") != "") | {id, user: .user.login, state, submitted_at}]' \ + > /tmp/baseline_reviews.json + +# Issue comments — count + latest id per non-bot, non-author comment. +# Bots are filtered by User.type == "Bot" (GitHub sets this for app/bot +# accounts like coderabbitai, github-actions, sentry-io). The author is +# filtered by comparing login to the PR author — export it so jq can see it. +AUTHOR=$(gh api "repos/Significant-Gravitas/AutoGPT/pulls/${PR}" --jq '.user.login') +gh api "repos/Significant-Gravitas/AutoGPT/issues/${PR}/comments" --paginate \ + --jq --arg author "$AUTHOR" \ + '[.[] | select(.user.type != "Bot" and .user.login != $author) + | {id, user: .user.login, created_at}]' \ + > /tmp/baseline_issue_comments.json +``` + +### Diffing after a review + +After `/pr-review` runs, any of these counting as "new findings" means another address round is needed: + +- New inline thread `id` not in the baseline. +- An existing thread whose latest comment `databaseId` is higher than the baseline's (new reply on an old thread). +- A new top-level review `id` with a non-empty body. +- A new issue comment `id` from a non-bot, non-author user. + +If any of the four buckets is non-empty → not done; invoke `/pr-address` and loop. + +## Polish polling + +Once `/pr-review` produces zero new findings, do **not** exit yet. Bots (coderabbitai, sentry, autogpt-reviewer) commonly post late reviews after CI settles — 30–90 seconds after the final push. Poll at 60-second intervals: + +```text +NON_SUCCESS_TERMINAL = {"failure", "cancelled", "timed_out", "action_required", "startup_failure"} +clean_polls = 0 +required_clean = 2 +while clean_polls < required_clean: + # 1. CI gate — any terminal non-success conclusion (not just "failure") + # must trigger /pr-address. "success", "skipped", "neutral" are clean; + # anything else (including cancelled, timed_out, action_required) is a + # blocker that won't self-resolve. + ci = fetch_check_runs(PR) + if any ci.conclusion in NON_SUCCESS_TERMINAL: + invoke_skill("pr-address", PR) # address failures + any new comments + baseline = snapshot_state(PR) # reset — push during address invalidates old baseline + clean_polls = 0 + continue + if any ci.conclusion is None (still in_progress): + sleep 60; continue # wait without counting this as clean + + # 2. Comment / thread gate + threads = fetch_unresolved_threads(PR) + new_issue_comments = diff_against_baseline(issue_comments) + new_reviews = diff_against_baseline(reviews) + if threads or new_issue_comments or new_reviews: + invoke_skill("pr-address", PR) + baseline = snapshot_state(PR) # reset — the address loop just dealt with these, + # otherwise they stay "new" relative to the old baseline forever + clean_polls = 0 + continue + + # 3. 
Mergeability gate + mergeable = gh api repos/.../pulls/${PR} --jq '.mergeable' + if mergeable == false (CONFLICTING): + resolve_conflicts(PR) # see pr-address skill + clean_polls = 0 + continue + if mergeable is null (UNKNOWN): + sleep 60; continue + + clean_polls += 1 + sleep 60 +``` + +Only after `clean_polls == 2` do you report `ORCHESTRATOR:DONE`. + +### Why 2 clean polls, not 1 + +A single green snapshot can be misleading — the final CI check often completes ~30s before a bot posts its delayed review. One quiet cycle does not prove the PR is stable; two consecutive cycles with no new threads, reviews, or issue comments arriving gives high confidence nothing else is incoming. + +### Why checking every source each poll + +`/pr-address` polling inside a single round already re-checks its own comments, but `/pr-polish` sits a level above and must also catch: + +- New top-level reviews (autogpt-reviewer sometimes posts structured feedback only after several CI green cycles). +- Issue comments from human reviewers (not caught by inline thread polling). +- Sentry bug predictions that land on new line numbers post-push. +- Merge conflicts introduced by a race between your push and a merge to `dev`. + +## Invocation pattern + +Delegate to existing skills with the `Skill` tool; do not re-implement the review or address logic inline. This keeps the polish loop focused on orchestration and lets the child skills evolve independently. + +```python +Skill(skill="pr-review", args=pr_url) +Skill(skill="pr-address", args=pr_url) +``` + +After each child invocation, re-query GitHub state directly — never trust a summary for the stop condition. The orchestrator's `ORCHESTRATOR:DONE` is verified against actual GraphQL / REST responses per the rules in `pr-address`'s "Verify actual count before outputting ORCHESTRATOR:DONE" section. + +### **Auto-continue: do NOT end your response between child skills** + +`/pr-polish` is a single orchestration task — one invocation drives the PR all the way to merge-ready. When a child `Skill()` call returns control to you: + +- Do NOT summarize and stop. +- Do NOT wait for user confirmation to continue. +- Immediately, in the same response, perform the next loop step: state diff → decide next action → next `Skill()` call or polling sleep. + +The child skill returning is a **loop iteration boundary**, not a conversation turn boundary. You are expected to keep going until one of the exit conditions in the opening section is met (2 consecutive clean polls, `_MAX_ROUNDS` hit, or an unrecoverable error). + +If the user needs to approve a risky action mid-loop (e.g., a force-push or a destructive git operation), pause there — but not at the routine "round N finished, round N+1 needed" boundary. Those are silent transitions. + +## GitHub rate limits + +This skill issues many GraphQL calls (one review-thread query per outer iteration plus per-poll queries inside polish polling). Expect the GraphQL budget to be tight on large PRs. When `gh api rate_limit --jq .resources.graphql.remaining` drops below ~200, back off: + +- Fall back to REST for reads (flat `/pulls/{N}/comments`, `/pulls/{N}/reviews`, `/issues/{N}/comments`) per the `pr-address` skill's GraphQL-fallback section. +- Queue thread resolutions (GraphQL-only) until the budget resets; keep making progress on fixes + REST replies meanwhile. +- `sleep 5` between any batch of ≥20 writes to avoid secondary rate limits. 
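+
+A minimal sketch of that budget check, assuming the same `gh` CLI used elsewhere in this skill — the exact threshold and the `USE_REST_FALLBACK` variable are illustrative, not part of any existing skill contract:
+
+```bash
+# Check the remaining GraphQL budget once per poll cycle and pick the read strategy.
+GRAPHQL_REMAINING=$(gh api rate_limit --jq '.resources.graphql.remaining')
+if [ "${GRAPHQL_REMAINING:-0}" -lt 200 ]; then
+  echo "GraphQL budget low (${GRAPHQL_REMAINING} calls left) — using REST reads this cycle."
+  USE_REST_FALLBACK=1   # illustrative flag the polling loop would consult for reads
+else
+  USE_REST_FALLBACK=0
+fi
+```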
+ +## Safety valves + +- `_MAX_ROUNDS = 10` — if review+address rounds exceed this, stop and escalate to the user with a summary of what's still unresolved. A PR that cannot converge in 10 rounds has systemic issues that need human judgment. +- After each commit, run `poetry run format` / `pnpm format && pnpm lint && pnpm types` per the target codebase's conventions. A failing format check is CI `failure` that will never self-resolve. +- Every `/pr-review` round checks for **duplicate** concerns first (via `pr-review`'s own "Fetch existing review comments" step) so the loop does not re-post the same finding that a prior round already resolved. + +## Reporting + +When the skill finishes (either via two clean polls or hitting `_MAX_ROUNDS`), produce a compact summary: + +``` +PR #{N} polish complete ({rounds_completed} rounds): +- {X} inline threads opened and resolved +- {Y} CI failures fixed +- {Z} new commits pushed +Final state: CI green, {total} threads all resolved, mergeable. +``` + +If exiting via `_MAX_ROUNDS`, flag explicitly: + +``` +PR #{N} polish stopped at {_MAX_ROUNDS} rounds — NOT merge-ready: +- {N} threads still unresolved: {titles} +- CI status: {summary} +Needs human review. +``` + +## When to use this skill + +Use when the user says any of: +- "polish this PR" +- "keep reviewing and addressing until it's mergeable" +- "loop /pr-review + /pr-address until done" +- "make sure the PR is actually merge-ready" + +Do **not** use when: +- User wants just one review pass (→ `/pr-review`). +- User wants to address already-posted comments without further self-review (→ `/pr-address`). +- A fixed round count is explicitly requested (e.g., "do 3 rounds") — honour the count instead of converging. diff --git a/.claude/skills/pr-test/SKILL.md b/.claude/skills/pr-test/SKILL.md index b368fb7f0d..09699ec546 100644 --- a/.claude/skills/pr-test/SKILL.md +++ b/.claude/skills/pr-test/SKILL.md @@ -260,6 +260,32 @@ Use a `trap` so release runs even on `exit 1`: trap 'kill "$HEARTBEAT_PID" 2>/dev/null; rm -f "$LOCK"' EXIT INT TERM ``` +### **Release the lock AS SOON AS the test run is done** + +The lock guards **test execution**, not **app lifecycle**. Once Step 5 (record results) and Step 6 (post PR comment) are complete, release the lock IMMEDIATELY — even if: + +- The native `poetry run app` / `pnpm dev` processes are still running so the user can keep poking at the app manually. +- You're leaving docker containers up. +- You're tailing logs for a minute or two. + +Keeping the lock held past the test run is the single most common way `/pr-test` stalls other agents. **The app staying up is orthogonal to the lock; don't conflate them.** Sibling worktrees running their own `/pr-test` will kill the stray processes and free the ports themselves (Step 3c/3e-native handle that) — they just need the lock file gone. + +Concretely, the sequence at the end of every `/pr-test` run (success or failure) is: + +```bash +# 1. Write the final report + post PR comment — done above in Step 5/6. +# 2. Release the lock right now, even if the app is still up. +kill "$HEARTBEAT_PID" 2>/dev/null +rm -f "$LOCK" /tmp/pr-test-heartbeat.pid +echo "$(date -u +%Y-%m-%dT%H:%MZ) [pr-${PR_NUMBER}] released lock (app may still be running)" \ + >> /Users/majdyz/Code/AutoGPT/.ign.testing.log +# 3. Optionally leave the app running and note it so the user knows: +echo "Native stack still running on :3000 / :8006 for manual poking. 
Kill with:" +echo " pkill -9 -f 'poetry run app'; pkill -9 -f 'next-server|next dev'" +``` + +If a sibling agent's `/pr-test` needs to take over, it'll do the kill+rebuild dance from Step 3c/3e-native on its own — your only job is to not hold the lock file past the end of your test. + ### Shared status log `/Users/majdyz/Code/AutoGPT/.ign.testing.log` is an append-only channel any agent can read/write. Use it for "I'm waiting", "I'm done, resources free", or post-run notes: @@ -755,6 +781,19 @@ Upload screenshots to the PR using the GitHub Git API (no local git operations **CRITICAL — NEVER post a bare directory link like `https://github.com/.../tree/...`.** Every screenshot MUST appear as `![name](raw_url)` inline in the PR comment so reviewers can see them without clicking any links. After posting, the verification step below greps the comment for `![` tags and exits 1 if none are found — the test run is considered incomplete until this passes. +**CRITICAL — NEVER paste absolute local paths into the PR comment.** Strings like `/Users/…`, `/home/…`, `C:\…` are useless to every reviewer except you. Before posting, grep the final body for `/Users/`, `/home/`, `/tmp/`, `/private/`, `C:\`, `~/` and either drop those lines entirely or rewrite them as repo-relative paths (`autogpt_platform/backend/…`). The PR comment is an artifact reviewers on GitHub read — it must be self-contained on github.com. Keep local paths in `$RESULTS_DIR/test-report.md` for yourself; only copy the *content* they reference (excerpts, test names, log lines) into the PR comment, not the path. + +**Pre-post sanity check** (paste after building the comment body, before `gh api ... comments`): + +```bash +# Reject any local-looking absolute path or home-dir shortcut in the body +if grep -nE '(^|[^A-Za-z])(/Users/|/home/|/tmp/|/private/|C:\\|~/)[A-Za-z0-9]' "$COMMENT_FILE" ; then + echo "ABORT: local filesystem paths detected in PR comment body." + echo "Remove or rewrite as repo-relative (autogpt_platform/...) before posting." 
+ exit 1 +fi +``` + ```bash # Upload screenshots via GitHub Git API (creates blobs, tree, commit, and ref remotely) REPO="Significant-Gravitas/AutoGPT" diff --git a/autogpt_platform/backend/backend/api/features/chat/routes.py b/autogpt_platform/backend/backend/api/features/chat/routes.py index 01020d690a..dfad86a36c 100644 --- a/autogpt_platform/backend/backend/api/features/chat/routes.py +++ b/autogpt_platform/backend/backend/api/features/chat/routes.py @@ -76,6 +76,7 @@ from backend.copilot.tools.models import ( SetupRequirementsResponse, SuggestedGoalResponse, TaskDecompositionResponse, + TodoWriteResponse, UnderstandingUpdatedResponse, ) from backend.copilot.tracking import track_user_message @@ -1443,6 +1444,7 @@ ToolResponseUnion = ( | MemorySearchResponse | MemoryForgetCandidatesResponse | MemoryForgetConfirmResponse + | TodoWriteResponse ) diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py index 0c689ed4a7..1d0da8ce7e 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/reasoning.py +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning.py @@ -50,13 +50,13 @@ _VISIBLE_REASONING_TYPES = frozenset({"reasoning.text", "reasoning.summary"}) # (~4,700 deltas per turn in one observed session, vs ~28 for Sonnet); without # coalescing, every chunk is one Redis ``xadd`` + one SSE frame + one React # re-render of the non-virtualised chat list, which paint-storms the browser -# main thread and freezes the UI. Batching into ~32-char / ~40 ms windows -# cuts the event rate ~100x while staying snappy enough that the Reasoning +# main thread and freezes the UI. Batching into ~64-char / ~50 ms windows +# cuts the event rate ~150x while staying snappy enough that the Reasoning # collapse still feels live (well under the ~100 ms perceptual threshold). # Per-delta persistence to ``session.messages`` stays granular — we only # coalesce the *wire* emission. -_COALESCE_MIN_CHARS = 32 -_COALESCE_MAX_INTERVAL_MS = 40.0 +_COALESCE_MIN_CHARS = 64 +_COALESCE_MAX_INTERVAL_MS = 50.0 class ReasoningDetail(BaseModel): @@ -242,6 +242,12 @@ class BaselineReasoningEmitter: fresh ``ChatMessage(role="reasoning")`` is appended and mutated in-place as further deltas arrive; :meth:`close` drops the reference but leaves the appended row intact. + + ``render_in_ui=False`` suppresses only the live wire events + (``StreamReasoning*``); the ``role='reasoning'`` persistence row is + still appended so ``convertChatSessionToUiMessages.ts`` can hydrate + the reasoning bubble on reload. The state machine advances + identically either way. """ def __init__( @@ -250,21 +256,19 @@ class BaselineReasoningEmitter: *, coalesce_min_chars: int = _COALESCE_MIN_CHARS, coalesce_max_interval_ms: float = _COALESCE_MAX_INTERVAL_MS, + render_in_ui: bool = True, ) -> None: self._block_id: str = str(uuid.uuid4()) self._open: bool = False self._session_messages = session_messages self._current_row: ChatMessage | None = None - # Coalescing state — ``_pending_delta`` accumulates reasoning text - # between wire flushes. Providers like Kimi K2.6 emit very fine- - # grained chunks; batching them reduces Redis ``xadd`` + SSE + React - # re-render load by ~100x for equivalent text output. Tuning knobs - # are kwargs so tests can disable coalescing (``=0``) for - # deterministic event assertions. + # Coalescing state — tests can disable (``=0``) for deterministic + # event assertions. 
self._coalesce_min_chars = coalesce_min_chars self._coalesce_max_interval_ms = coalesce_max_interval_ms self._pending_delta: str = "" self._last_flush_monotonic: float = 0.0 + self._render_in_ui = render_in_ui @property def is_open(self) -> bool: @@ -296,8 +300,9 @@ class BaselineReasoningEmitter: # syscalls off the hot path without changing semantics. now = time.monotonic() if not self._open: - events.append(StreamReasoningStart(id=self._block_id)) - events.append(StreamReasoningDelta(id=self._block_id, delta=text)) + if self._render_in_ui: + events.append(StreamReasoningStart(id=self._block_id)) + events.append(StreamReasoningDelta(id=self._block_id, delta=text)) self._open = True self._last_flush_monotonic = now if self._session_messages is not None: @@ -305,17 +310,15 @@ class BaselineReasoningEmitter: self._session_messages.append(self._current_row) return events - # Persist per-delta (no coalescing here — the session snapshot stays - # consistent at every chunk boundary, independent of the wire - # coalesce window). if self._current_row is not None: self._current_row.content = (self._current_row.content or "") + text self._pending_delta += text if self._should_flush_pending(now): - events.append( - StreamReasoningDelta(id=self._block_id, delta=self._pending_delta) - ) + if self._render_in_ui: + events.append( + StreamReasoningDelta(id=self._block_id, delta=self._pending_delta) + ) self._pending_delta = "" self._last_flush_monotonic = now return events @@ -348,12 +351,13 @@ class BaselineReasoningEmitter: if not self._open: return [] events: list[StreamBaseResponse] = [] - if self._pending_delta: - events.append( - StreamReasoningDelta(id=self._block_id, delta=self._pending_delta) - ) - self._pending_delta = "" - events.append(StreamReasoningEnd(id=self._block_id)) + if self._render_in_ui: + if self._pending_delta: + events.append( + StreamReasoningDelta(id=self._block_id, delta=self._pending_delta) + ) + events.append(StreamReasoningEnd(id=self._block_id)) + self._pending_delta = "" self._open = False self._block_id = str(uuid.uuid4()) self._current_row = None diff --git a/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py index e18c8066e4..1f5ca01845 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/reasoning_test.py @@ -452,3 +452,63 @@ class TestReasoningPersistence: events = emitter.on_delta(_delta(reasoning="pure wire")) assert len(events) == 2 # start + delta, no crash # Nothing else to assert — just proves None session is supported. + + +class TestBaselineReasoningEmitterRenderFlag: + """``render_in_ui=False`` must silence ``StreamReasoning*`` wire events + AND drop persistence of ``role="reasoning"`` rows — the operator hides + the collapse on both the live wire and on reload. Persistence is tied + to the wire events because the frontend's hydration path unconditionally + re-renders persisted reasoning rows; keeping them would make the flag a + no-op post-reload. These tests pin the contract in both directions so + future refactors can't flip only one half.""" + + def test_render_off_suppresses_start_and_delta(self): + emitter = BaselineReasoningEmitter(render_in_ui=False) + events = emitter.on_delta(_delta(reasoning="hidden")) + # No wire events, but state advanced (is_open == True) so close() + # below has something to rotate. 
+ assert events == [] + assert emitter.is_open is True + + def test_render_off_suppresses_close_end(self): + emitter = BaselineReasoningEmitter(render_in_ui=False) + emitter.on_delta(_delta(reasoning="hidden")) + events = emitter.close() + assert events == [] + assert emitter.is_open is False + + def test_render_off_still_persists(self): + """Persistence is decoupled from the render flag — session + transcript always keeps the ``role="reasoning"`` row so audit + and ``--resume``-equivalent replay never lose thinking text. + The frontend gates rendering separately.""" + session: list[ChatMessage] = [] + emitter = BaselineReasoningEmitter(session, render_in_ui=False) + + emitter.on_delta(_delta(reasoning="part one ")) + emitter.on_delta(_delta(reasoning="part two")) + emitter.close() + + assert len(session) == 1 + assert session[0].role == "reasoning" + assert session[0].content == "part one part two" + + def test_render_off_rotates_block_id_between_sessions(self): + """Even with wire events silenced the block id must rotate on close, + otherwise a hypothetical mid-session flip would reuse a stale id.""" + emitter = BaselineReasoningEmitter(render_in_ui=False) + emitter.on_delta(_delta(reasoning="first")) + first_block_id = emitter._block_id + emitter.close() + emitter.on_delta(_delta(reasoning="second")) + assert emitter._block_id != first_block_id + + def test_render_on_is_default(self): + """Defaulting to True preserves backward compat — existing callers + that don't pass the kwarg keep emitting wire events as before.""" + emitter = BaselineReasoningEmitter() + events = emitter.on_delta(_delta(reasoning="hello")) + assert len(events) == 2 + assert isinstance(events[0], StreamReasoningStart) + assert isinstance(events[1], StreamReasoningDelta) diff --git a/autogpt_platform/backend/backend/copilot/baseline/service.py b/autogpt_platform/backend/backend/copilot/baseline/service.py index 6aa88e9d41..0f1174d51e 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service.py @@ -15,7 +15,7 @@ import re import shutil import tempfile import uuid -from collections.abc import AsyncGenerator, Mapping, Sequence +from collections.abc import AsyncGenerator, Iterable, Mapping, Sequence from dataclasses import dataclass, field from functools import partial from typing import TYPE_CHECKING, Any, cast @@ -45,6 +45,8 @@ from backend.copilot.model import ( maybe_append_user_message, upsert_chat_session, ) +from backend.copilot.model_router import resolve_model +from backend.copilot.moonshot import is_moonshot_model from backend.copilot.pending_message_helpers import ( combine_pending_with_current, drain_pending_safe, @@ -82,7 +84,7 @@ from backend.copilot.service import ( from backend.copilot.session_cleanup import prune_orphan_tool_calls from backend.copilot.thinking_stripper import ThinkingStripper as _ThinkingStripper from backend.copilot.token_tracking import persist_and_record_usage -from backend.copilot.tools import execute_tool, get_available_tools +from backend.copilot.tools import ToolGroup, execute_tool, get_available_tools from backend.copilot.tracking import track_user_message from backend.copilot.transcript import ( STOP_REASON_END_TURN, @@ -318,20 +320,17 @@ def _filter_tools_by_permissions( ] -def _resolve_baseline_model(tier: CopilotLlmModel | None) -> str: +async def _resolve_baseline_model( + tier: CopilotLlmModel | None, user_id: str | None +) -> str: """Pick the model for the baseline path based on the per-request 
tier. - Baseline resolves independently of SDK via the ``fast_*_model`` cells - of the (path, tier) matrix. ``'standard'`` / ``None`` picks Kimi - K2.6 by default (cheap + OpenRouter ``reasoning`` support); - ``'advanced'`` picks Opus by default so the advanced tier is a clean - A/B against the SDK advanced tier — same model, different path — - isolating reasoning-wire + cache differences from model capability. - Both defaults are overridable per ``CHAT_FAST_*_MODEL`` env vars. + Delegates to :func:`copilot.model_router.resolve_model` so the + ``(fast, tier)`` cell is LD-overridable per user. ``None`` tier + maps to ``"standard"``. """ - if tier == "advanced": - return config.fast_advanced_model - return config.fast_standard_model + tier_name = "advanced" if tier == "advanced" else "standard" + return await resolve_model("fast", tier_name, user_id, config=config) @dataclass @@ -343,7 +342,21 @@ class _BaselineStreamState: """ model: str = "" - pending_events: list[StreamBaseResponse] = field(default_factory=list) + # Live delivery channel drained concurrently by ``stream_chat_completion_baseline`` + # so reasoning / text / tool events reach the SSE wire **during** the upstream + # LLM stream, not after ``_baseline_llm_caller`` returns. Before this was a + # ``list`` drained per ``tool_call_loop`` iteration, so any model with + # extended thinking (Anthropic via OpenRouter, Moonshot, future reasoning + # routes) froze the UI for the entire duration of each LLM round before + # flushing the backlog in one burst. The queue is single-producer (the + # streaming loop) / single-consumer (the outer async-gen yield loop); + # ``None`` is the close sentinel. + pending_events: asyncio.Queue[StreamBaseResponse | None] = field( + default_factory=asyncio.Queue + ) + # Mirror of every event put on ``pending_events`` — kept for unit tests that + # inspect post-hoc what was emitted. Not consumed by production code. + emitted_events: list[StreamBaseResponse] = field(default_factory=list) assistant_text: str = "" text_block_id: str = field(default_factory=lambda: str(uuid.uuid4())) text_started: bool = False @@ -382,31 +395,71 @@ class _BaselineStreamState: # frontend's ``convertChatSessionToUiMessages`` relies on these # rows to render the Reasoning collapse after the AI SDK's # stream-end hydrate swaps in the DB-backed message list. - self.reasoning_emitter = BaselineReasoningEmitter(self.session_messages) + # ``render_in_ui`` is sourced from ``config.render_reasoning_in_ui`` + # so the operator can silence the reasoning collapse globally + # without dropping the persisted audit trail. + self.reasoning_emitter = BaselineReasoningEmitter( + self.session_messages, + render_in_ui=config.render_reasoning_in_ui, + ) + + +def _emit(state: "_BaselineStreamState", event: StreamBaseResponse) -> None: + """Queue *event* for the live SSE wire AND mirror into ``emitted_events``. + + Single helper so every streaming producer (LLM stream loop, tool executor, + conversation updater) posts to the same single-consumer queue. The mirror + list is read-only from production code — it exists so unit tests can assert + on the full sequence emitted during one call. 
+ """ + state.pending_events.put_nowait(event) + state.emitted_events.append(event) + + +def _emit_all( + state: "_BaselineStreamState", events: Iterable[StreamBaseResponse] +) -> None: + """Queue *events* in order — convenience for emitter batches.""" + for event in events: + _emit(state, event) def _is_anthropic_model(model: str) -> bool: """Return True if *model* routes to Anthropic (native or via OpenRouter). - Cache-control markers on message content + the ``anthropic-beta`` header - are Anthropic-specific. OpenAI rejects the unknown ``cache_control`` - field with a 400 ("Extra inputs are not permitted") and Grok / other - providers behave similarly. OpenRouter strips unknown headers but - passes through ``cache_control`` on the body regardless of provider — - which would also fail when OpenRouter routes to a non-Anthropic model. - Examples that return True: - ``anthropic/claude-sonnet-4-6`` (OpenRouter route) - ``claude-3-5-sonnet-20241022`` (direct Anthropic API) - ``anthropic.claude-3-5-sonnet`` (Bedrock-style) False for ``openai/gpt-4o``, ``google/gemini-2.5-pro``, ``xai/grok-4`` - etc. + etc. Moonshot is False here too even though Moonshot's + Anthropic-compat endpoint honours ``cache_control`` — use + :func:`_supports_prompt_cache_markers` for the cache-gating decision, + which also allows Moonshot routes. This function stays scoped to + "genuinely Anthropic" so callers that need the stricter check (e.g. + ``anthropic-beta`` header emission) keep their existing semantics. """ lowered = model.lower() return "claude" in lowered or lowered.startswith("anthropic") +def _supports_prompt_cache_markers(model: str) -> bool: + """Return True when *model* accepts Anthropic-style ``cache_control``. + + Superset of :func:`_is_anthropic_model` — also allows Moonshot + (``moonshotai/*``), whose OpenRouter Anthropic-compat endpoint + honours the marker and empirically lifts cache hit rate on + continuation turns from near-zero (Moonshot's own automatic prefix + cache, which drifts readily) to the 60-95% Anthropic ballpark. + + OpenAI / Grok / Gemini still 400 on ``cache_control``, so this + function returns False for those providers — add new vendors here + only after verifying their endpoint accepts the field. + """ + return _is_anthropic_model(model) or is_moonshot_model(model) + + def _fresh_ephemeral_cache_control() -> dict[str, str]: """Return a FRESH ephemeral ``cache_control`` dict each call. @@ -519,7 +572,7 @@ async def _baseline_llm_caller( Extracted from ``stream_chat_completion_baseline`` for readability. """ - state.pending_events.append(StreamStartStep()) + _emit(state, StreamStartStep()) # Fresh thinking-strip state per round so a malformed unclosed # block in one LLM call cannot silently drop content in the next. state.thinking_stripper = _ThinkingStripper() @@ -527,19 +580,24 @@ async def _baseline_llm_caller( round_text = "" try: client = _get_openai_client() - # Cache markers are Anthropic-specific. For OpenAI/Grok/other - # providers, leaving them on would trigger a 400 ("Extra inputs - # are not permitted" on cache_control). Tools were precomputed - # in stream_chat_completion_baseline via _mark_tools_with_cache_control - # (only when the model was Anthropic), so on non-Anthropic routes - # tools ship without cache_control on the last entry too. + # Cache markers are accepted by Anthropic AND Moonshot (via OR's + # Anthropic-compat endpoint). 
OpenAI/Grok/Gemini 400 on the + # unknown ``cache_control`` field — tools were precomputed in + # stream_chat_completion_baseline via _mark_tools_with_cache_control + # with the same gate, so on unsupported routes tools ship + # unmarked too. # - # `extra_body` `usage.include=true` asks OpenRouter to embed the real - # generation cost into the final usage chunk — required by the - # cost-based rate limiter in routes.py. Separate from the Anthropic + # The ``anthropic-beta`` header is only emitted for genuinely + # Anthropic routes (see :func:`_is_anthropic_model`) — Moonshot + # doesn't need the beta header; sending it is a no-op but we + # keep the check strict for clarity. + # + # `extra_body` `usage.include=true` asks OpenRouter to embed the + # real generation cost into the final usage chunk — required by + # the cost-based rate limiter in routes.py. Separate from the # caching headers, always sent. - is_anthropic = _is_anthropic_model(state.model) - if is_anthropic: + supports_cache = _supports_prompt_cache_markers(state.model) + if supports_cache: # Build the cached system dict once per session and splice it in # on each round. The full ``messages`` list grows with every # tool call, so copying the entire list just to mutate index 0 @@ -555,7 +613,11 @@ async def _baseline_llm_caller( final_messages = [state.cached_system_message, *messages[1:]] else: final_messages = messages - extra_headers = _fresh_anthropic_caching_headers() + extra_headers = ( + _fresh_anthropic_caching_headers() + if _is_anthropic_model(state.model) + else None + ) else: final_messages = messages extra_headers = None @@ -621,31 +683,30 @@ async def _baseline_llm_caller( if not delta: continue - state.pending_events.extend(state.reasoning_emitter.on_delta(delta)) + _emit_all(state, state.reasoning_emitter.on_delta(delta)) if delta.content: # Text and reasoning must not interleave on the wire — the # AI SDK maps distinct start/end pairs to distinct UI # parts. Close any open reasoning block before emitting # the first text delta of this run. - state.pending_events.extend(state.reasoning_emitter.close()) + _emit_all(state, state.reasoning_emitter.close()) emit = state.thinking_stripper.process(delta.content) if emit: if not state.text_started: - state.pending_events.append( - StreamTextStart(id=state.text_block_id) - ) + _emit(state, StreamTextStart(id=state.text_block_id)) state.text_started = True round_text += emit - state.pending_events.append( - StreamTextDelta(id=state.text_block_id, delta=emit) + _emit( + state, + StreamTextDelta(id=state.text_block_id, delta=emit), ) if delta.tool_calls: # Same rule as the text branch: close any open reasoning # block before a tool_use starts so the AI SDK treats # reasoning and tool-use as distinct parts. - state.pending_events.extend(state.reasoning_emitter.close()) + _emit_all(state, state.reasoning_emitter.close()) for tc in delta.tool_calls: idx = tc.index if idx not in tool_calls_by_index: @@ -676,19 +737,17 @@ async def _baseline_llm_caller( # ``async for chunk in response`` would otherwise leave reasoning # and/or text unterminated and only ``StreamFinishStep`` emitted — # the Reasoning / Text collapses would never finalise. - state.pending_events.extend(state.reasoning_emitter.close()) + _emit_all(state, state.reasoning_emitter.close()) # Flush any buffered text held back by the thinking stripper. 
tail = state.thinking_stripper.flush() if tail: if not state.text_started: - state.pending_events.append(StreamTextStart(id=state.text_block_id)) + _emit(state, StreamTextStart(id=state.text_block_id)) state.text_started = True round_text += tail - state.pending_events.append( - StreamTextDelta(id=state.text_block_id, delta=tail) - ) + _emit(state, StreamTextDelta(id=state.text_block_id, delta=tail)) if state.text_started: - state.pending_events.append(StreamTextEnd(id=state.text_block_id)) + _emit(state, StreamTextEnd(id=state.text_block_id)) state.text_started = False state.text_block_id = str(uuid.uuid4()) # Always persist partial text so the session history stays consistent, @@ -696,7 +755,7 @@ async def _baseline_llm_caller( state.assistant_text += round_text # Always emit StreamFinishStep to match the StreamStartStep, # even if an exception occurred during streaming. - state.pending_events.append(StreamFinishStep()) + _emit(state, StreamFinishStep()) # Convert to shared format llm_tool_calls = [ @@ -738,13 +797,14 @@ async def _baseline_tool_executor( except orjson.JSONDecodeError as parse_err: parse_error = f"Invalid JSON arguments for tool '{tool_name}': {parse_err}" logger.warning("[Baseline] %s", parse_error) - state.pending_events.append( + _emit( + state, StreamToolOutputAvailable( toolCallId=tool_call_id, toolName=tool_name, output=parse_error, success=False, - ) + ), ) return ToolCallResult( tool_call_id=tool_call_id, @@ -753,15 +813,17 @@ async def _baseline_tool_executor( is_error=True, ) - state.pending_events.append( - StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name) + _emit( + state, + StreamToolInputStart(toolCallId=tool_call_id, toolName=tool_name), ) - state.pending_events.append( + _emit( + state, StreamToolInputAvailable( toolCallId=tool_call_id, toolName=tool_name, input=tool_args, - ) + ), ) # Announce the tool call to the session so in-turn guards like @@ -785,7 +847,7 @@ async def _baseline_tool_executor( session=session, tool_call_id=tool_call_id, ) - state.pending_events.append(result) + _emit(state, result) tool_output = ( result.output if isinstance(result.output, str) else str(result.output) ) @@ -802,13 +864,14 @@ async def _baseline_tool_executor( error_output, exc_info=True, ) - state.pending_events.append( + _emit( + state, StreamToolOutputAvailable( toolCallId=tool_call_id, toolName=tool_name, output=error_output, success=False, - ) + ), ) return ToolCallResult( tool_call_id=tool_call_id, @@ -1317,7 +1380,7 @@ async def stream_chat_completion_baseline( # Select model based on the per-request tier toggle (standard / advanced). # The path (fast vs extended_thinking) is already decided — we're in the # baseline (fast) path; ``mode`` is accepted for logging parity only. - active_model = _resolve_baseline_model(model) + active_model = await _resolve_baseline_model(model, user_id) # --- E2B sandbox setup (feature parity with SDK path) --- e2b_sandbox = None @@ -1583,7 +1646,10 @@ async def stream_chat_completion_baseline( openai_messages[i]["content"] = text break - tools = get_available_tools() + disabled_tool_groups: list[ToolGroup] = [] + if not graphiti_enabled: + disabled_tool_groups.append("graphiti") + tools = get_available_tools(disabled_groups=disabled_tool_groups) # --- Permission filtering --- if permissions is not None: @@ -1594,9 +1660,10 @@ async def stream_chat_completion_baseline( # _baseline_llm_caller) avoids re-copying ~43 tool dicts on every LLM # round of the tool-call loop. 
# - # Only apply to Anthropic routes — OpenAI/Grok/other providers would - # 400 on the unknown ``cache_control`` field inside tool definitions. - if _is_anthropic_model(active_model): + # Applies to Anthropic AND Moonshot routes — OpenAI/Grok/Gemini 400 + # on the unknown ``cache_control`` field inside tool definitions, so + # the gate stays narrow (see :func:`_supports_prompt_cache_markers`). + if _supports_prompt_cache_markers(active_model): tools = cast( list[ChatCompletionToolParam], _mark_tools_with_cache_control(tools) ) @@ -1660,139 +1727,172 @@ async def stream_chat_completion_baseline( state=state, ) - try: - loop_result = None - async for loop_result in tool_call_loop( - messages=openai_messages, - tools=tools, - llm_call=_bound_llm_caller, - execute_tool=_bound_tool_executor, - update_conversation=_bound_conversation_updater, - max_iterations=_MAX_TOOL_ROUNDS, - ): - # Drain buffered events after each iteration (real-time streaming) - for evt in state.pending_events: - yield evt - state.pending_events.clear() + # Run the tool-call loop concurrently with the event consumer so + # ``StreamReasoning*`` / ``StreamText*`` deltas emitted inside + # ``_baseline_llm_caller`` reach the SSE wire DURING the upstream LLM + # stream instead of only at iteration boundaries. Any reasoning route + # that streams for several minutes per round (extended thinking on + # Anthropic / Moonshot / future providers) would otherwise freeze the + # UI for the whole window before flushing the backlog in one burst. + loop_result_holder: list[Any] = [None] + loop_task: asyncio.Task[None] | None = None - # Inject any messages the user queued while the turn was - # running. ``tool_call_loop`` mutates ``openai_messages`` - # in-place, so appending here means the model sees the new - # messages on its next LLM call. - # - # IMPORTANT: skip when the loop has already finished (no - # more LLM calls are coming). ``tool_call_loop`` yields - # a final ``ToolCallLoopResult`` on both paths: - # - natural finish: ``finished_naturally=True`` - # - hit max_iterations: ``finished_naturally=False`` - # and ``iterations >= max_iterations`` - # In either case the loop is about to return on the next - # ``async for`` step, so draining here would silently - # lose the message (the user sees 202 but the model never - # reads the text). Those messages stay in the buffer and - # get picked up at the start of the next turn. - is_final_yield = ( - loop_result.finished_naturally - or loop_result.iterations >= _MAX_TOOL_ROUNDS - ) - if is_final_yield: - continue - try: - pending = await drain_pending_messages(session_id) - except Exception: - logger.warning( - "[Baseline] mid-loop drain_pending_messages failed for session %s", - session_id, - exc_info=True, - ) - pending = [] - if pending: - # Flush any buffered assistant/tool messages from completed - # rounds into session.messages BEFORE appending the pending - # user message. ``_baseline_conversation_updater`` only - # records assistant+tool rounds into ``state.session_messages`` - # — they are normally batch-flushed in the finally block. - # Without this in-order flush, the mid-loop pending user - # message lands before the preceding round's assistant/tool - # entries, producing chronologically-wrong session.messages - # on persist (user interposed between an assistant tool_call - # and its tool-result), which breaks OpenAI tool-call ordering - # invariants on the next turn's replay. 
+ async def _run_tool_call_loop() -> None: + # Read/write the current session via ``_session_holder`` so this + # closure doesn't need to ``nonlocal session`` — pyright can't narrow + # the outer ``session: ChatSession | None`` through a nested scope, + # but the holder is typed non-optional after the preflight guard + # above. + try: + async for loop_result in tool_call_loop( + messages=openai_messages, + tools=tools, + llm_call=_bound_llm_caller, + execute_tool=_bound_tool_executor, + update_conversation=_bound_conversation_updater, + max_iterations=_MAX_TOOL_ROUNDS, + ): + loop_result_holder[0] = loop_result + # Inject any messages the user queued while the turn was + # running. ``tool_call_loop`` mutates ``openai_messages`` + # in-place, so appending here means the model sees the new + # messages on its next LLM call. # - # Also persist any assistant text from text-only rounds (rounds - # with no tool calls, which ``_baseline_conversation_updater`` - # does NOT record in session_messages). If we only update - # ``_flushed_assistant_text_len`` without persisting the text, - # that text is silently lost: the finally block only appends - # assistant_text[_flushed_assistant_text_len:], so text generated - # before this drain never reaches session.messages. - recorded_text = "".join( - m.content or "" - for m in state.session_messages - if m.role == "assistant" + # IMPORTANT: skip when the loop has already finished (no + # more LLM calls are coming). ``tool_call_loop`` yields + # a final ``ToolCallLoopResult`` on both paths: + # - natural finish: ``finished_naturally=True`` + # - hit max_iterations: ``finished_naturally=False`` + # and ``iterations >= max_iterations`` + # In either case the loop is about to return on the next + # ``async for`` step, so draining here would silently + # lose the message (the user sees 202 but the model never + # reads the text). Those messages stay in the buffer and + # get picked up at the start of the next turn. + is_final_yield = ( + loop_result.finished_naturally + or loop_result.iterations >= _MAX_TOOL_ROUNDS ) - unflushed_text = state.assistant_text[ - state._flushed_assistant_text_len : - ] - text_only_text = ( - unflushed_text[len(recorded_text) :] - if unflushed_text.startswith(recorded_text) - else unflushed_text - ) - if text_only_text.strip(): - session.messages.append( - ChatMessage(role="assistant", content=text_only_text) + if is_final_yield: + continue + try: + pending = await drain_pending_messages(session_id) + except Exception: + logger.warning( + "[Baseline] mid-loop drain_pending_messages failed for " + "session %s", + session_id, + exc_info=True, ) - for _buffered in state.session_messages: - session.messages.append(_buffered) - state.session_messages.clear() - # Record how much assistant_text has been covered by the - # structured entries just flushed, so the finally block's - # final-text dedup doesn't re-append rounds already persisted. - state._flushed_assistant_text_len = len(state.assistant_text) + pending = [] + if pending: + # Flush any buffered assistant/tool messages from completed + # rounds into session.messages BEFORE appending the pending + # user message. ``_baseline_conversation_updater`` only + # records assistant+tool rounds into ``state.session_messages`` + # — they are normally batch-flushed in the finally block. 
+ # Without this in-order flush, the mid-loop pending user + # message lands before the preceding round's assistant/tool + # entries, producing chronologically-wrong session.messages + # on persist (user interposed between an assistant tool_call + # and its tool-result), which breaks OpenAI tool-call ordering + # invariants on the next turn's replay. + # + # Also persist any assistant text from text-only rounds (rounds + # with no tool calls, which ``_baseline_conversation_updater`` + # does NOT record in session_messages). If we only update + # ``_flushed_assistant_text_len`` without persisting the text, + # that text is silently lost: the finally block only appends + # assistant_text[_flushed_assistant_text_len:], so text generated + # before this drain never reaches session.messages. + recorded_text = "".join( + m.content or "" + for m in state.session_messages + if m.role == "assistant" + ) + unflushed_text = state.assistant_text[ + state._flushed_assistant_text_len : + ] + text_only_text = ( + unflushed_text[len(recorded_text) :] + if unflushed_text.startswith(recorded_text) + else unflushed_text + ) + current_session = _session_holder[0] + if text_only_text.strip(): + current_session.messages.append( + ChatMessage(role="assistant", content=text_only_text) + ) + for _buffered in state.session_messages: + current_session.messages.append(_buffered) + state.session_messages.clear() + # Record how much assistant_text has been covered by the + # structured entries just flushed, so the finally block's + # final-text dedup doesn't re-append rounds already persisted. + state._flushed_assistant_text_len = len(state.assistant_text) - # Persist the assistant/tool flush BEFORE the pending append - # so a later pending-persist failure can roll back the - # pending rows without also discarding LLM output. - session = await persist_session_safe(session, "[Baseline]") - # ``upsert_chat_session`` may return a *new* ``ChatSession`` - # instance (e.g. when a concurrent title update has written a - # newer title to Redis, it returns ``session.model_copy``). - # Keep ``_session_holder`` in sync so subsequent tool rounds - # executed via ``_bound_tool_executor`` see the fresh session - # — any tool-side mutations on the stale object would be - # discarded when the new one is persisted in the ``finally``. - _session_holder[0] = session + # Persist the assistant/tool flush BEFORE the pending append + # so a later pending-persist failure can roll back the + # pending rows without also discarding LLM output. + current_session = await persist_session_safe( + current_session, "[Baseline]" + ) + # ``upsert_chat_session`` may return a *new* ``ChatSession`` + # instance (e.g. when a concurrent title update has written a + # newer title to Redis, it returns ``session.model_copy``). + # Keep ``_session_holder`` in sync so subsequent tool rounds + # executed via ``_bound_tool_executor`` see the fresh session + # — any tool-side mutations on the stale object would be + # discarded when the new one is persisted in the ``finally``. + _session_holder[0] = current_session - # ``format_pending_as_user_message`` embeds file attachments - # and context URL/page content into the content string so - # the in-session transcript is a faithful copy of what the - # model actually saw. We also mirror each push into - # ``openai_messages`` so the model's next LLM round sees it. 
- # - # Pre-compute the formatted dicts once so both the openai - # messages append and the content_of lookup inside the - # shared helper use the same string — and so ``on_rollback`` - # can trim ``openai_messages`` to the recorded anchor. - formatted_by_pm = { - id(pm): format_pending_as_user_message(pm) for pm in pending - } - _openai_anchor = len(openai_messages) - for pm in pending: - openai_messages.append(formatted_by_pm[id(pm)]) + # ``format_pending_as_user_message`` embeds file attachments + # and context URL/page content into the content string so + # the in-session transcript is a faithful copy of what the + # model actually saw. We also mirror each push into + # ``openai_messages`` so the model's next LLM round sees it. + # + # Pre-compute the formatted dicts once so both the openai + # messages append and the content_of lookup inside the + # shared helper use the same string — and so ``on_rollback`` + # can trim ``openai_messages`` to the recorded anchor. + formatted_by_pm = { + id(pm): format_pending_as_user_message(pm) for pm in pending + } + _openai_anchor = len(openai_messages) + for pm in pending: + openai_messages.append(formatted_by_pm[id(pm)]) - def _trim_openai_on_rollback(_session_anchor: int) -> None: - del openai_messages[_openai_anchor:] + def _trim_openai_on_rollback(_session_anchor: int) -> None: + del openai_messages[_openai_anchor:] - await persist_pending_as_user_rows( - session, - transcript_builder, - pending, - log_prefix="[Baseline]", - content_of=lambda pm: formatted_by_pm[id(pm)]["content"], - on_rollback=_trim_openai_on_rollback, - ) + await persist_pending_as_user_rows( + current_session, + transcript_builder, + pending, + log_prefix="[Baseline]", + content_of=lambda pm: formatted_by_pm[id(pm)]["content"], + on_rollback=_trim_openai_on_rollback, + ) + finally: + # Always post the sentinel so the outer consumer exits — even if + # ``tool_call_loop`` raised. ``_baseline_llm_caller``'s own + # finally block has already pushed ``StreamReasoningEnd`` / + # ``StreamTextEnd`` / ``StreamFinishStep`` at this point, so the + # sentinel only terminates the consumer; it does not suppress + # any still-unflushed events. + state.pending_events.put_nowait(None) + loop_task = asyncio.create_task(_run_tool_call_loop()) + try: + while True: + evt = await state.pending_events.get() + if evt is None: + break + yield evt + # Sentinel received — surface any exception the inner task hit. + await loop_task + loop_result = loop_result_holder[0] if loop_result and not loop_result.finished_naturally: limit_msg = ( f"Exceeded {_MAX_TOOL_ROUNDS} tool-call rounds " @@ -1803,25 +1903,34 @@ async def stream_chat_completion_baseline( errorText=limit_msg, code="baseline_tool_round_limit", ) - except Exception as e: _stream_error = True error_msg = str(e) or type(e).__name__ logger.error("[Baseline] Streaming error: %s", error_msg, exc_info=True) - # ``_baseline_llm_caller``'s finally block closes any open - # reasoning / text blocks and appends ``StreamFinishStep`` on - # both normal and exception paths, so pending_events already has - # the correct protocol ordering: - # StreamStartStep -> StreamReasoningStart -> ...deltas... -> - # StreamReasoningEnd -> StreamTextStart -> ...deltas... -> - # StreamTextEnd -> StreamFinishStep - # Just drain what's buffered, then yield the error. 
- for evt in state.pending_events: - yield evt - state.pending_events.clear() + # Drain any queued tail events (reasoning/text close + finish step) + # that ``_baseline_llm_caller``'s finally block pushed before the + # sentinel arrived — without this the frontend would be missing the + # matching end / finish parts for the partial round. + while not state.pending_events.empty(): + evt = state.pending_events.get_nowait() + if evt is not None: + yield evt yield StreamError(errorText=error_msg, code="baseline_error") # Still persist whatever we got finally: + # Cancel the inner task if we're unwinding early (client disconnect, + # unexpected error in the consumer) so it doesn't keep streaming + # tokens into a dead queue. + if loop_task is not None and not loop_task.done(): + loop_task.cancel() + try: + await loop_task + except (asyncio.CancelledError, Exception): + pass + # Re-sync the outer ``session`` binding in case the inner task + # reassigned it via a mid-loop ``persist_session_safe`` call. + session = _session_holder[0] + # In-flight tool-call announcements are only meaningful for the # current turn; clear at the top of the outer finally so the next # turn starts with a clean scratch buffer even if one of the diff --git a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py index 5a95c5c901..3051ea5d99 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/service_unit_test.py @@ -21,6 +21,7 @@ from backend.copilot.baseline.service import ( _is_anthropic_model, _mark_system_message_with_cache_control, _mark_tools_with_cache_control, + _supports_prompt_cache_markers, ) from backend.copilot.model import ChatMessage from backend.copilot.response_model import ( @@ -39,7 +40,10 @@ from backend.util.tool_call_loop import LLMLoopResponse, LLMToolCall, ToolCallRe class TestBaselineStreamState: def test_defaults(self): state = _BaselineStreamState() - assert state.pending_events == [] + # ``pending_events`` is an asyncio.Queue now (live SSE channel). + # The durable inspection view is ``emitted_events``. + assert state.pending_events.empty() + assert state.emitted_events == [] assert state.assistant_text == "" assert state.text_started is False assert state.turn_prompt_tokens == 0 @@ -1687,7 +1691,7 @@ class TestBaselineReasoningStreaming: state=state, ) - types = [type(e).__name__ for e in state.pending_events] + types = [type(e).__name__ for e in state.emitted_events] assert "StreamReasoningStart" in types assert "StreamReasoningDelta" in types assert "StreamReasoningEnd" in types @@ -1702,14 +1706,14 @@ class TestBaselineReasoningStreaming: # a fresh id after the reasoning-end rotation. reasoning_ids = { e.id - for e in state.pending_events + for e in state.emitted_events if isinstance( e, (StreamReasoningStart, StreamReasoningDelta, StreamReasoningEnd) ) } text_ids = { e.id - for e in state.pending_events + for e in state.emitted_events if isinstance(e, (StreamTextStart, StreamTextDelta, StreamTextEnd)) } assert len(reasoning_ids) == 1 @@ -1717,7 +1721,7 @@ class TestBaselineReasoningStreaming: assert reasoning_ids.isdisjoint(text_ids) combined = "".join( - e.delta for e in state.pending_events if isinstance(e, StreamReasoningDelta) + e.delta for e in state.emitted_events if isinstance(e, StreamReasoningDelta) ) assert combined == "thinking... 
more" @@ -1759,7 +1763,7 @@ class TestBaselineReasoningStreaming: # A reasoning-end must have been emitted — this is the tool_calls # branch's responsibility, not the stream-end cleanup. - types = [type(e).__name__ for e in state.pending_events] + types = [type(e).__name__ for e in state.emitted_events] assert "StreamReasoningStart" in types assert "StreamReasoningEnd" in types @@ -1802,7 +1806,7 @@ class TestBaselineReasoningStreaming: state=state, ) - types = [type(e).__name__ for e in state.pending_events] + types = [type(e).__name__ for e in state.emitted_events] # The reasoning block was opened, the exception fired, and the # finally block must have closed it before emitting the finish # step. @@ -1861,12 +1865,14 @@ class TestBaselineReasoningStreaming: assert "reasoning" not in extra_body @pytest.mark.asyncio - async def test_kimi_route_sends_reasoning_but_no_cache_control(self): - """Kimi K2.6 is the default fast_model and sends ``reasoning`` via - OpenRouter's unified extension. It must NOT receive ``cache_control`` - markers or the ``anthropic-beta`` header — Moonshot uses its own - auto-caching and those Anthropic-only fields would either get - silently dropped or (worst case) 400 on a future provider change.""" + async def test_kimi_route_sends_reasoning_and_cache_control(self): + """Kimi K2.6 (Moonshot via OpenRouter's Anthropic-compat endpoint) + accepts ``cache_control: {type: ephemeral}`` on the system block + and the last tool — the endpoint honours the marker and lifts + cache hit rate on continuation turns from near-zero (Moonshot's + auto-caching drifts) to the Anthropic ~60-95% ballpark. The + ``anthropic-beta`` header stays off because Moonshot doesn't need + it; OpenRouter would strip the unknown header anyway.""" state = _BaselineStreamState(model="moonshotai/kimi-k2.6") mock_client = MagicMock() @@ -1898,15 +1904,29 @@ class TestBaselineReasoningStreaming: # cheap-but-still-reasoning-capable path. assert "reasoning" in extra_body assert extra_body["reasoning"]["max_tokens"] > 0 - # Anthropic-only fields stay off. - assert "extra_headers" not in call_kwargs + # No ``anthropic-beta`` header — that beta is specifically for + # native Anthropic endpoints; Moonshot's shim accepts + # ``cache_control`` without it, and sending it would be wasted + # bytes (OR strips it before forwarding to Moonshot). + assert "extra_headers" not in call_kwargs or not call_kwargs.get( + "extra_headers" + ) + # System block MUST carry ``cache_control`` so Moonshot's cache + # breakpoint is honoured. The cached system-message builder + # emits list-shape content with the marker on the first (and + # only) block — assert on that shape. sys_msg = call_kwargs["messages"][0] sys_content = sys_msg.get("content") - if isinstance(sys_content, list): - assert all("cache_control" not in block for block in sys_content) - tools = call_kwargs.get("tools", []) - for t in tools: - assert "cache_control" not in t + assert isinstance( + sys_content, list + ), "Cached system message should be a list-shape content block" + assert any( + "cache_control" in block for block in sys_content if isinstance(block, dict) + ), "Kimi system message should now carry cache_control markers" + # Tool-level cache marking is applied by ``stream_chat_completion_baseline`` + # (see ``_mark_tools_with_cache_control``) before tools reach + # ``_baseline_llm_caller``, so this unit test doesn't exercise + # that path — covered by the outer integration test. 
@pytest.mark.asyncio async def test_reasoning_only_stream_still_closes_block(self): @@ -1935,7 +1955,7 @@ class TestBaselineReasoningStreaming: state=state, ) - types = [type(e).__name__ for e in state.pending_events] + types = [type(e).__name__ for e in state.emitted_events] assert "StreamReasoningStart" in types assert "StreamReasoningEnd" in types # No text was produced — no text events should be emitted. @@ -2006,3 +2026,55 @@ class TestBaselineReasoningStreaming: reasoning_rows = [m for m in state.session_messages if m.role == "reasoning"] assert len(reasoning_rows) == 1 assert reasoning_rows[0].content == "first thought" + + +class TestSupportsPromptCacheMarkers: + """``_supports_prompt_cache_markers`` is the widened gate for + emitting ``cache_control`` markers on message content. It's a + superset of ``_is_anthropic_model`` that ALSO admits Moonshot + (whose Anthropic-compat endpoint honours the marker) while keeping + the False answer for OpenAI / Grok / Gemini (which 400 on the + unknown field).""" + + @pytest.mark.parametrize( + "model", + [ + "anthropic/claude-sonnet-4-6", + "claude-3-5-sonnet-20241022", + "anthropic.claude-3-5-sonnet", + "ANTHROPIC/Claude-Opus", + ], + ) + def test_anthropic_routes_are_supported(self, model): + assert _supports_prompt_cache_markers(model) is True + + @pytest.mark.parametrize( + "model", + [ + "moonshotai/kimi-k2.6", + "moonshotai/kimi-k2-thinking", + "moonshotai/kimi-k2.5", + "moonshotai/kimi-k3.0", # future SKU + ], + ) + def test_moonshot_routes_are_supported(self, model): + """The whole reason this predicate exists — Moonshot must be + True even though ``_is_anthropic_model`` is False for it.""" + assert _supports_prompt_cache_markers(model) is True + # Verify this is strictly wider than the anthropic-only check. + assert _is_anthropic_model(model) is False + + @pytest.mark.parametrize( + "model", + [ + "openai/gpt-4o", + "google/gemini-2.5-pro", + "xai/grok-4", + "meta-llama/llama-3.3-70b-instruct", + "deepseek/deepseek-v3", + ], + ) + def test_other_providers_still_rejected(self, model): + """Regression guard: OpenAI/Grok/Gemini still 400 on + ``cache_control``, so the widened gate must keep them out.""" + assert _supports_prompt_cache_markers(model) is False diff --git a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py index 808b06eb32..8a9e435743 100644 --- a/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py +++ b/autogpt_platform/backend/backend/copilot/baseline/transcript_integration_test.py @@ -67,34 +67,40 @@ class TestResolveBaselineModel: Baseline reads the ``fast_*_model`` cells of the (path, tier) matrix and never falls through to the SDK-side ``thinking_*_model`` cells. - Default routing: - - ``standard`` / ``None`` → ``config.fast_standard_model`` (Kimi K2.6) - - ``advanced`` → ``config.fast_advanced_model`` (Opus — same as SDK's - advanced tier, so the advanced A/B isolates path differences) + Without a user_id (so no LD context) the resolver returns the + ``ChatConfig`` static default; per-user overrides are exercised in + ``copilot/model_router_test.py``. 
""" - def test_advanced_tier_selects_fast_advanced_model(self): - assert _resolve_baseline_model("advanced") == config.fast_advanced_model + @pytest.mark.asyncio + async def test_advanced_tier_selects_fast_advanced_model(self): + assert ( + await _resolve_baseline_model("advanced", None) + == config.fast_advanced_model + ) - def test_standard_tier_selects_fast_standard_model(self): - assert _resolve_baseline_model("standard") == config.fast_standard_model + @pytest.mark.asyncio + async def test_standard_tier_selects_fast_standard_model(self): + assert ( + await _resolve_baseline_model("standard", None) + == config.fast_standard_model + ) - def test_none_tier_selects_fast_standard_model(self): - """Baseline users without a tier get the cheap fast-standard default.""" - assert _resolve_baseline_model(None) == config.fast_standard_model + @pytest.mark.asyncio + async def test_none_tier_selects_fast_standard_model(self): + """Baseline users without a tier get the fast-standard default.""" + assert await _resolve_baseline_model(None, None) == config.fast_standard_model - def test_fast_standard_default_is_kimi(self): - """Shipped default: Kimi K2.6 on the baseline standard cell. - - Asserts the declared ``Field`` default — env-independent — so a - deploy-time ``CHAT_FAST_STANDARD_MODEL`` rollback override - doesn't fail CI while still pinning the shipped default. - """ + def test_fast_standard_default_is_sonnet(self): + """Shipped default: Sonnet on the baseline standard cell — the + non-Anthropic routes ship via the LD flag instead of a config + change. Asserts the declared ``Field`` default so a deploy-time + ``CHAT_FAST_STANDARD_MODEL`` override doesn't flake CI.""" from backend.copilot.config import ChatConfig assert ( ChatConfig.model_fields["fast_standard_model"].default - == "moonshotai/kimi-k2.6" + == "anthropic/claude-sonnet-4-6" ) def test_fast_advanced_default_is_opus(self): @@ -108,18 +114,6 @@ class TestResolveBaselineModel: == "anthropic/claude-opus-4.7" ) - def test_standard_cells_diverge_across_paths(self): - """The whole point of the split: baseline cheap (Kimi) vs SDK - Anthropic-only (Sonnet). If the shipped standard defaults ever - collapse to the same value someone lost the cost savings. - Checked against ``Field`` defaults, not the env-backed singleton.""" - from backend.copilot.config import ChatConfig - - assert ( - ChatConfig.model_fields["thinking_standard_model"].default - != ChatConfig.model_fields["fast_standard_model"].default - ) - def test_standard_and_advanced_cells_differ_on_fast(self): """Advanced tier defaults to a different model than standard on the baseline path. Checked against declared ``Field`` defaults diff --git a/autogpt_platform/backend/backend/copilot/config.py b/autogpt_platform/backend/backend/copilot/config.py index d2c66a3484..64e0e92ee8 100644 --- a/autogpt_platform/backend/backend/copilot/config.py +++ b/autogpt_platform/backend/backend/copilot/config.py @@ -3,7 +3,7 @@ import os from typing import Literal -from pydantic import AliasChoices, Field, field_validator +from pydantic import AliasChoices, Field, field_validator, model_validator from pydantic_settings import BaseSettings from backend.util.clients import OPENROUTER_BASE_URL @@ -43,26 +43,20 @@ class ChatConfig(BaseSettings): # ``CHAT_FAST_MODEL``) are preserved via ``validation_alias`` so # existing deployments continue to override the same effective cell. 
fast_standard_model: str = Field( - default="moonshotai/kimi-k2.6", + default="anthropic/claude-sonnet-4-6", validation_alias=AliasChoices( "CHAT_FAST_STANDARD_MODEL", "CHAT_FAST_MODEL", ), - description="Baseline path, 'standard' / ``None`` tier. Kimi K2.6 " - "by default: ~5x cheaper input and ~5.4x cheaper output than Sonnet, " - "SWE-Bench Verified parity with Opus, and OpenRouter advertises the " - "``reasoning`` + ``include_reasoning`` extension params on the " - "Moonshot endpoints — so the baseline reasoning plumbing lights up " - "without provider-specific code. Roll back to the Anthropic route " - "via ``CHAT_FAST_STANDARD_MODEL=anthropic/claude-sonnet-4-6`` (then " - "``cache_control`` breakpoints reactivate via " - "``_is_anthropic_model``).", + description="Baseline path, 'standard' / ``None`` tier. Per-user " + "overrides flow through the ``copilot-fast-standard-model`` LD flag " + "(see ``copilot/model_router.py``); this value is the fallback.", ) fast_advanced_model: str = Field( default="anthropic/claude-opus-4.7", validation_alias=AliasChoices("CHAT_FAST_ADVANCED_MODEL"), - description="Baseline path, 'advanced' tier. Opus by default. " - "Override via ``CHAT_FAST_ADVANCED_MODEL``.", + description="Baseline path, 'advanced' tier. LD override: " + "``copilot-fast-advanced-model``.", ) thinking_standard_model: str = Field( default="anthropic/claude-sonnet-4-6", @@ -71,11 +65,7 @@ class ChatConfig(BaseSettings): "CHAT_MODEL", ), description="SDK (extended-thinking) path, 'standard' / ``None`` " - "tier. Sonnet by default: the Claude Agent SDK CLI only speaks to " - "Anthropic endpoints, so the standard SDK tier has to stay on an " - "Anthropic model regardless of what the baseline path runs. " - "Override via ``CHAT_THINKING_STANDARD_MODEL`` (legacy " - "``CHAT_MODEL`` still honored).", + "tier. LD override: ``copilot-thinking-standard-model``.", ) thinking_advanced_model: str = Field( default="anthropic/claude-opus-4.7", @@ -83,17 +73,18 @@ class ChatConfig(BaseSettings): "CHAT_THINKING_ADVANCED_MODEL", "CHAT_ADVANCED_MODEL", ), - description="SDK (extended-thinking) path, 'advanced' tier. Opus " - "by default. Override via ``CHAT_THINKING_ADVANCED_MODEL`` " - "(legacy ``CHAT_ADVANCED_MODEL`` still honored).", + description="SDK (extended-thinking) path, 'advanced' tier. LD " + "override: ``copilot-thinking-advanced-model``.", ) title_model: str = Field( default="openai/gpt-4o-mini", description="Model to use for generating session titles (should be fast/cheap)", ) simulation_model: str = Field( - default="google/gemini-2.5-flash", - description="Model for dry-run block simulation (should be fast/cheap with good JSON output)", + default="google/gemini-2.5-flash-lite", + description="Model for dry-run block simulation (should be fast/cheap with good JSON output). " + "Gemini 2.5 Flash-Lite is ~3x cheaper than Flash ($0.10/$0.40 vs $0.30/$1.20 per MTok) " + "with JSON-mode reliability adequate for shape-matching block outputs.", ) api_key: str | None = Field(default=None, description="OpenAI API key") base_url: str | None = Field( @@ -249,13 +240,28 @@ class ChatConfig(BaseSettings): "``max_thinking_tokens`` kwarg so the CLI falls back to model default " "(which, without the flag, leaves extended thinking off).", ) + render_reasoning_in_ui: bool = Field( + default=True, + description="Render reasoning as live UI parts " + "(``StreamReasoning*`` wire events). 
False suppresses the live " + "wire events only; ``role='reasoning'`` rows are always persisted " + "so the reasoning bubble hydrates on reload. Tokens are billed " + "upstream regardless.", + ) + stream_replay_count: int = Field( + default=200, + ge=1, + le=10000, + description="Max Redis stream entries replayed on SSE reconnect.", + ) claude_agent_thinking_effort: Literal["low", "medium", "high", "max"] | None = ( Field( default=None, description="Thinking effort level: 'low', 'medium', 'high', 'max', or None. " - "Only applies to models with extended thinking (Opus). " - "Sonnet doesn't have extended thinking — setting effort on Sonnet " - "can cause tag leaks. " + "Applies to models that emit a reasoning channel — Opus (extended " + "thinking) and Kimi K2.6 (OpenRouter ``reasoning`` extension lit " + "up by #12871). Sonnet does not have extended thinking — setting " + "effort on Sonnet can cause tag leaks. " "None = let the model decide. Override via CHAT_CLAUDE_AGENT_THINKING_EFFORT.", ) ) @@ -287,6 +293,31 @@ class ChatConfig(BaseSettings): "(24h, permanent) TTL option — see " "https://platform.claude.com/docs/en/build-with-claude/prompt-caching.", ) + sdk_include_partial_messages: bool = Field( + default=True, + description="Stream SDK responses token-by-token instead of in " + "one lump at the end. Set to False if the SDK path starts " + "double-writing text or dropping the tail of long messages.", + ) + sdk_reconcile_openrouter_cost: bool = Field( + default=True, + description="Query OpenRouter's ``/api/v1/generation?id=`` after each " + "SDK turn and record the authoritative ``total_cost`` instead of the " + "Claude Agent SDK CLI's estimate. Covers every OpenRouter-routed " + "SDK turn regardless of vendor — the CLI's static Anthropic pricing " + "table is accurate for Anthropic models (Sonnet/Opus via OpenRouter " + "bill at Anthropic's own rates, penny-for-penny), but the reconcile " + "catches any future rate change the CLI hasn't picked up and makes " + "non-Anthropic cost (Kimi et al) correct — real billed amount, " + "matching the baseline path's ``usage.cost`` read since #12864. " + "Kill-switch for emergencies: set ``CHAT_SDK_RECONCILE_OPENROUTER_COST" + "=false`` to fall back to the CLI's ``total_cost_usd`` reported " + "synchronously (accurate-for-Anthropic / over-billed-for-Kimi). " + "Tradeoff: 0.5-2s window between turn end and cost write; rate-limit " + "counter briefly unaware, back-to-back turns in that window see " + "stale state. The alternative (writing an estimate sync then a " + "correction delta) would double-count the rate limit.", + ) claude_agent_cli_path: str | None = Field( default=None, description="Optional explicit path to a Claude Code CLI binary. " @@ -457,6 +488,59 @@ class ChatConfig(BaseSettings): ) return v + @model_validator(mode="after") + def _validate_sdk_model_vendor_compatibility(self) -> "ChatConfig": + """Fail at config load when an SDK model slug is incompatible with + explicit direct-Anthropic mode. + + The SDK path's ``_normalize_model_name`` raises ``ValueError`` when + a non-Anthropic vendor slug (e.g. ``moonshotai/kimi-k2.6``) is paired + with direct-Anthropic mode — but that fires inside the request loop, + so a misconfigured deployment would surface a 500 to every user + instead of failing visibly at boot. + + Only the **explicit** opt-out (``use_openrouter=False``) is checked + here, not the credential-missing path. 
Build environments and + OpenAPI-schema export jobs construct ``ChatConfig()`` without any + OpenRouter credentials in the env — that's not a misconfiguration, + it's "config loads ok, but no SDK turn will succeed until creds are + wired". The runtime guard in ``_normalize_model_name`` still + catches the credential-missing path on the first SDK turn. + + Covers all three SDK fields that flow through + ``_normalize_model_name``: primary tier + (``thinking_standard_model``), advanced tier + (``thinking_advanced_model``), and fallback model + (``claude_agent_fallback_model`` via ``_resolve_fallback_model``). + + Skipped when ``use_claude_code_subscription=True`` because the + subscription path resolves the model to ``None`` (CLI default) + and never calls ``_normalize_model_name``. Empty fallback strings + are also skipped (no fallback configured). + """ + if self.use_claude_code_subscription: + return self + if self.use_openrouter: + return self + for field_name in ( + "thinking_standard_model", + "thinking_advanced_model", + "claude_agent_fallback_model", + ): + value: str = getattr(self, field_name) + if not value or "/" not in value: + continue + if value.split("/", 1)[0] != "anthropic": + raise ValueError( + f"Direct-Anthropic mode (use_openrouter=False) " + f"requires an Anthropic model for {field_name}, got " + f"{value!r}. Set CHAT_THINKING_STANDARD_MODEL / " + f"CHAT_THINKING_ADVANCED_MODEL / " + f"CHAT_CLAUDE_AGENT_FALLBACK_MODEL to an anthropic/* " + f"slug, or set CHAT_USE_OPENROUTER=true." + ) + return self + # Prompt paths for different contexts PROMPT_PATHS: dict[str, str] = { "default": "prompts/chat_system.md", diff --git a/autogpt_platform/backend/backend/copilot/config_test.py b/autogpt_platform/backend/backend/copilot/config_test.py index fe8e67b7ff..42e36bc1f4 100644 --- a/autogpt_platform/backend/backend/copilot/config_test.py +++ b/autogpt_platform/backend/backend/copilot/config_test.py @@ -5,12 +5,17 @@ import pytest from .config import ChatConfig # Env vars that the ChatConfig validators read — must be cleared so they don't -# override the explicit constructor values we pass in each test. +# override the explicit constructor values we pass in each test. Includes the +# SDK/baseline model aliases so a leftover ``CHAT_MODEL=...`` in the developer +# or CI environment can't change whether +# ``_validate_sdk_model_vendor_compatibility`` raises. _ENV_VARS_TO_CLEAR = ( "CHAT_USE_E2B_SANDBOX", "CHAT_E2B_API_KEY", "E2B_API_KEY", "CHAT_USE_OPENROUTER", + "CHAT_USE_CLAUDE_AGENT_SDK", + "CHAT_USE_CLAUDE_CODE_SUBSCRIPTION", "CHAT_API_KEY", "OPEN_ROUTER_API_KEY", "OPENAI_API_KEY", @@ -19,6 +24,16 @@ _ENV_VARS_TO_CLEAR = ( "OPENAI_BASE_URL", "CHAT_CLAUDE_AGENT_CLI_PATH", "CLAUDE_AGENT_CLI_PATH", + "CHAT_FAST_STANDARD_MODEL", + "CHAT_FAST_MODEL", + "CHAT_FAST_ADVANCED_MODEL", + "CHAT_THINKING_STANDARD_MODEL", + "CHAT_THINKING_ADVANCED_MODEL", + "CHAT_MODEL", + "CHAT_ADVANCED_MODEL", + "CHAT_CLAUDE_AGENT_FALLBACK_MODEL", + "CHAT_RENDER_REASONING_IN_UI", + "CHAT_STREAM_REPLAY_COUNT", ) @@ -28,6 +43,22 @@ def _clean_env(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv(var, raising=False) +def _make_direct_safe_config(**kwargs) -> ChatConfig: + """Build a ``ChatConfig`` for tests that pass ``use_openrouter=False`` + but aren't exercising the SDK vendor-compatibility validator. 
+ + Pins ``thinking_standard_model``/``thinking_advanced_model`` to anthropic/* + so the construction passes ``_validate_sdk_model_vendor_compatibility`` + without each test having to repeat the override. + """ + defaults: dict = { + "thinking_standard_model": "anthropic/claude-sonnet-4-6", + "thinking_advanced_model": "anthropic/claude-opus-4-7", + } + defaults.update(kwargs) + return ChatConfig(**defaults) + + class TestOpenrouterActive: """Tests for the openrouter_active property.""" @@ -48,7 +79,7 @@ class TestOpenrouterActive: assert cfg.openrouter_active is False def test_disabled_returns_false_despite_credentials(self): - cfg = ChatConfig( + cfg = _make_direct_safe_config( use_openrouter=False, api_key="or-key", base_url="https://openrouter.ai/api/v1", @@ -164,3 +195,133 @@ class TestClaudeAgentCliPathEnvFallback: monkeypatch.setenv("CLAUDE_AGENT_CLI_PATH", str(tmp_path)) with pytest.raises(Exception, match="not a regular file"): ChatConfig() + + +class TestSdkModelVendorCompatibility: + """``model_validator`` that fails fast on SDK model vs routing-mode + mismatch — see PR #12878 iteration-2 review. Mirrors the runtime + guard in ``_normalize_model_name`` so misconfig surfaces at boot + instead of as a 500 on the first SDK turn.""" + + def test_direct_anthropic_with_kimi_override_raises(self): + """A non-Anthropic SDK model must fail at config load when the + deployment has no OpenRouter credentials.""" + with pytest.raises(Exception, match="requires an Anthropic model"): + ChatConfig( + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=False, + thinking_standard_model="moonshotai/kimi-k2.6", + ) + + def test_direct_anthropic_with_anthropic_default_succeeds(self): + """Direct-Anthropic mode is fine when both SDK slugs are anthropic/* + — which is the default after the LD-routed model rollout.""" + cfg = ChatConfig( + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=False, + ) + assert cfg.thinking_standard_model == "anthropic/claude-sonnet-4-6" + + def test_openrouter_with_kimi_override_succeeds(self): + """Kimi slug round-trips cleanly when OpenRouter is on — exercised + via the LD-flag override path in production.""" + cfg = ChatConfig( + use_openrouter=True, + api_key="or-key", + base_url="https://openrouter.ai/api/v1", + use_claude_code_subscription=False, + thinking_standard_model="moonshotai/kimi-k2.6", + ) + assert cfg.thinking_standard_model == "moonshotai/kimi-k2.6" + + def test_subscription_mode_skips_check(self): + """Subscription path resolves the model to None and bypasses + ``_normalize_model_name``, so the slug check is skipped.""" + cfg = ChatConfig( + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=True, + ) + assert cfg.use_claude_code_subscription is True + + def test_advanced_tier_also_validated(self): + """Both standard and advanced SDK slugs are checked.""" + with pytest.raises(Exception, match="thinking_advanced_model"): + ChatConfig( + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=False, + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="moonshotai/kimi-k2.6", + ) + + def test_fallback_model_also_validated(self): + """``claude_agent_fallback_model`` flows through + ``_normalize_model_name`` via ``_resolve_fallback_model`` so the + same direct-Anthropic guard applies.""" + with pytest.raises(Exception, match="claude_agent_fallback_model"): + ChatConfig( + use_openrouter=False, + 
api_key=None, + base_url=None, + use_claude_code_subscription=False, + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4-7", + claude_agent_fallback_model="moonshotai/kimi-k2.6", + ) + + def test_empty_fallback_skipped(self): + """Empty ``claude_agent_fallback_model`` (no fallback configured) + must not trip the validator — the fallback-disabled state is + intentional and shouldn't require a placeholder anthropic/* slug.""" + cfg = ChatConfig( + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=False, + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4-7", + claude_agent_fallback_model="", + ) + assert cfg.claude_agent_fallback_model == "" + + +class TestRenderReasoningInUi: + """``render_reasoning_in_ui`` gates reasoning wire events globally.""" + + def test_defaults_to_true(self): + """Default must stay True — flipping it silences the reasoning + collapse for every user, which is an opt-in operator decision.""" + cfg = ChatConfig() + assert cfg.render_reasoning_in_ui is True + + def test_env_override_false(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CHAT_RENDER_REASONING_IN_UI", "false") + cfg = ChatConfig() + assert cfg.render_reasoning_in_ui is False + + +class TestStreamReplayCount: + """``stream_replay_count`` caps the SSE reconnect replay batch size.""" + + def test_default_is_200(self): + """200 covers a full Kimi turn after coalescing (~150 events) while + bounding the replay storm from 1000+ chunks.""" + cfg = ChatConfig() + assert cfg.stream_replay_count == 200 + + def test_env_override(self, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setenv("CHAT_STREAM_REPLAY_COUNT", "500") + cfg = ChatConfig() + assert cfg.stream_replay_count == 500 + + def test_zero_rejected(self): + """count=0 would make XREAD replay nothing — rejected via ge=1.""" + with pytest.raises(Exception): + ChatConfig(stream_replay_count=0) diff --git a/autogpt_platform/backend/backend/copilot/executor/manager.py b/autogpt_platform/backend/backend/copilot/executor/manager.py index 02a2913883..08baf73c05 100644 --- a/autogpt_platform/backend/backend/copilot/executor/manager.py +++ b/autogpt_platform/backend/backend/copilot/executor/manager.py @@ -105,25 +105,46 @@ class CoPilotExecutor(AppProcess): time.sleep(1e5) def cleanup(self): - """Graceful shutdown with active execution waiting.""" - pid = os.getpid() - logger.info(f"[cleanup {pid}] Starting graceful shutdown...") + """Graceful shutdown — mirrors ``backend.executor.manager`` pattern. - # Signal the consumer thread to stop + 1. Stop consumer immediately (both the Python flag that gates + ``_handle_run_message`` and ``channel.stop_consuming()`` at + the broker), so no new work enters. + 2. Passively wait for ``active_tasks`` to drain — each turn's + own ``finally`` publishes its terminal state via + ``mark_session_completed``. When a turn exits, ``on_run_done`` + removes it from ``active_tasks`` and releases its cluster lock. + 3. Shut down the thread-pool executor (cancels pending, leaves + running threads alone — process exit handles them). + 4. Release any cluster locks still held (defensive — on_run_done's + finally should have already released them). + 5. Stop message consumer threads + disconnect pika clients. 
+ + The zombie-session bug this PR targets is handled inside each + turn's own lifecycle by :func:`sync_fail_close_session`, NOT by + cleanup — so cleanup can stay as a simple "wait, then teardown" + and matches agent-executor's proven pattern. + """ + pid = os.getpid() + prefix = f"[cleanup {pid}]" + logger.info(f"{prefix} Starting graceful shutdown...") + + # 1. Stop consumer — flag AND broker-side try: self.stop_consuming.set() run_channel = self.run_client.get_channel() run_channel.connection.add_callback_threadsafe( lambda: run_channel.stop_consuming() ) - logger.info(f"[cleanup {pid}] Consumer has been signaled to stop") + logger.info(f"{prefix} Consumer has been signaled to stop") except Exception as e: - logger.error(f"[cleanup {pid}] Error stopping consumer: {e}") + logger.error(f"{prefix} Error stopping consumer: {e}") - # Wait for active executions to complete + # 2. Wait for in-flight turns to finish naturally if self.active_tasks: logger.info( - f"[cleanup {pid}] Waiting for {len(self.active_tasks)} active tasks to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..." + f"{prefix} Waiting for {len(self.active_tasks)} active tasks " + f"to complete (timeout: {GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s)..." ) start_time = time.monotonic() @@ -138,38 +159,42 @@ class CoPilotExecutor(AppProcess): if not self.active_tasks: break - # Refresh cluster locks periodically - current_time = time.monotonic() - if current_time - last_refresh >= lock_refresh_interval: + now = time.monotonic() + if now - last_refresh >= lock_refresh_interval: for lock in list(self._task_locks.values()): try: lock.refresh() except Exception as e: - logger.warning( - f"[cleanup {pid}] Failed to refresh lock: {e}" - ) - last_refresh = current_time + logger.warning(f"{prefix} Failed to refresh lock: {e}") + last_refresh = now logger.info( - f"[cleanup {pid}] {len(self.active_tasks)} tasks still active, waiting..." + f"{prefix} {len(self.active_tasks)} tasks still active, waiting..." ) time.sleep(10.0) - # Stop message consumers + if self.active_tasks: + logger.warning( + f"{prefix} {len(self.active_tasks)} tasks still running after " + f"{GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS}s — process exit will " + f"abandon them; RabbitMQ redelivery handles the message." + ) + + # 3. Stop message consumer threads if self._run_thread: self._stop_message_consumers( - self._run_thread, self.run_client, "[cleanup][run]" + self._run_thread, self.run_client, f"{prefix} [run]" ) if self._cancel_thread: self._stop_message_consumers( - self._cancel_thread, self.cancel_client, "[cleanup][cancel]" + self._cancel_thread, self.cancel_client, f"{prefix} [cancel]" ) - # Clean up worker threads (closes per-loop workspace storage sessions) + # 4. Worker cleanup + executor shutdown if self._executor: from .processor import cleanup_worker - logger.info(f"[cleanup {pid}] Cleaning up workers...") + logger.info(f"{prefix} Cleaning up workers...") futures = [] for _ in range(self._executor._max_workers): futures.append(self._executor.submit(cleanup_worker)) @@ -177,22 +202,20 @@ class CoPilotExecutor(AppProcess): try: f.result(timeout=10) except Exception as e: - logger.warning(f"[cleanup {pid}] Worker cleanup error: {e}") + logger.warning(f"{prefix} Worker cleanup error: {e}") - logger.info(f"[cleanup {pid}] Shutting down executor...") + logger.info(f"{prefix} Shutting down executor...") self._executor.shutdown(wait=False) - # Release any remaining locks + # 5. 
Release any cluster locks still held for session_id, lock in list(self._task_locks.items()): try: lock.release() - logger.info(f"[cleanup {pid}] Released lock for {session_id}") + logger.info(f"{prefix} Released lock for {session_id}") except Exception as e: - logger.error( - f"[cleanup {pid}] Failed to release lock for {session_id}: {e}" - ) + logger.error(f"{prefix} Failed to release lock for {session_id}: {e}") - logger.info(f"[cleanup {pid}] Graceful shutdown completed") + logger.info(f"{prefix} Graceful shutdown completed") # ============ RabbitMQ Consumer Methods ============ # @@ -387,13 +410,12 @@ class CoPilotExecutor(AppProcess): # Execute the task try: - self._task_locks[session_id] = cluster_lock - logger.info( f"Acquired cluster lock for {session_id}, " f"executor_id={self.executor_id}" ) + self._task_locks[session_id] = cluster_lock cancel_event = threading.Event() future = self.executor.submit( execute_copilot_turn, entry, cancel_event, cluster_lock @@ -425,7 +447,6 @@ class CoPilotExecutor(AppProcess): error_msg = str(e) or type(e).__name__ logger.exception(f"Error in run completion callback: {error_msg}") finally: - # Release the cluster lock if session_id in self._task_locks: logger.info(f"Releasing cluster lock for {session_id}") self._task_locks[session_id].release() diff --git a/autogpt_platform/backend/backend/copilot/executor/processor.py b/autogpt_platform/backend/backend/copilot/executor/processor.py index f40264b70b..3838302504 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor.py @@ -5,6 +5,7 @@ in a thread-local context, following the graph executor pattern. """ import asyncio +import concurrent.futures import logging import os import subprocess @@ -30,6 +31,87 @@ from .utils import CoPilotExecutionEntry, CoPilotLogMetadata logger = TruncatedLogger(logging.getLogger(__name__), prefix="[CoPilotExecutor]") +SHUTDOWN_ERROR_MESSAGE = ( + "Copilot executor shut down before this turn finished. Please retry." +) + +# Max time execute() blocks after calling future.cancel() / when draining a +# soon-to-be-cancelled future. Gives _execute_async's own finally a chance to +# publish the accurate terminal state over the Redis CAS; long enough to let +# an in-flight Redis call settle, short enough that shutdown doesn't stall. +_CANCEL_GRACE_SECONDS = 5.0 + +# Max time the sync safety net itself spends on a single Redis CAS. Without +# this bound the whole point of ``sync_fail_close_session`` is defeated — +# ``mark_session_completed`` would hang on the same broken Redis that caused +# the original failure. On timeout we give up silently; worst case the +# session stays ``running`` until the stale-session watchdog reaps it, but +# at least the pool worker thread isn't blocked forever. +_FAIL_CLOSE_REDIS_TIMEOUT = 10.0 + + +# Module-level symbol preserved for backward-compat with callers that import +# ``sync_fail_close_session``; the real implementation now lives on +# ``CoPilotProcessor`` so it can reuse ``self.execution_loop`` (same +# pattern as ``backend.executor.manager``'s ``node_execution_loop`` bridge +# at :meth:`ExecutionProcessor.on_graph_execution`). + + +def sync_fail_close_session( + session_id: str, + log: "CoPilotLogMetadata | TruncatedLogger", + execution_loop: asyncio.AbstractEventLoop, +) -> None: + """Synchronously mark *session_id* as failed from the pool worker thread. 
+ + Submits the CAS coroutine to the long-lived *execution_loop* via + ``run_coroutine_threadsafe`` — the same shape agent-executor uses at + :meth:`backend.executor.manager.ExecutionProcessor.on_graph_execution` + to reach its ``node_execution_loop`` from the pool worker. Reusing the + persistent loop means: + + * no fresh TCP connection per turn (the ``@thread_cached`` + ``AsyncRedis`` on the execution thread stays bound to the same loop + and is reused across every turn); + * no loop-teardown overhead; + * no ``clear_cache()`` gymnastics to dodge the "loop is closed" pitfall. + + ``mark_session_completed`` is an atomic CAS on ``status == "running"``, + so when the async path already wrote a terminal state the sync call is + a cheap no-op. The inner ``asyncio.wait_for`` bounds the Redis call so + a wedged Redis can't hang the safety net for the full redis-py default + TCP timeout; the outer ``.result(timeout=...)`` is a belt-and-braces + upper bound for the cross-thread wait. + """ + + async def _bounded() -> None: + await asyncio.wait_for( + stream_registry.mark_session_completed( + session_id, error_message=SHUTDOWN_ERROR_MESSAGE + ), + timeout=_FAIL_CLOSE_REDIS_TIMEOUT, + ) + + try: + future = asyncio.run_coroutine_threadsafe(_bounded(), execution_loop) + except RuntimeError as e: + # execution_loop is closed — happens if cleanup() already ran the + # per-worker teardown. Nothing we can do; let the stale-session + # watchdog reap it. + log.warning(f"sync fail-close skipped (execution_loop closed): {e}") + return + try: + future.result(timeout=_FAIL_CLOSE_REDIS_TIMEOUT + 2) + except concurrent.futures.TimeoutError: + log.warning( + f"sync fail-close timed out after {_FAIL_CLOSE_REDIS_TIMEOUT}s " + f"(session={session_id})" + ) + future.cancel() + except Exception as e: + log.warning(f"sync fail-close mark_session_completed failed: {e}") + + # ============ Mode Routing ============ # @@ -252,12 +334,13 @@ class CoPilotProcessor: ): """Execute a CoPilot turn. - Runs the async logic in the worker's event loop and handles errors. - - Args: - entry: The turn payload containing session and message info - cancel: Threading event to signal cancellation - cluster_lock: Distributed lock to prevent duplicate execution + Thin wrapper around :meth:`_execute`. The ``try/finally`` here + guarantees :func:`sync_fail_close_session` runs on every exit + path — normal completion, exception, or a wedged event loop + that escapes via :data:`_CANCEL_GRACE_SECONDS` timeout. + ``mark_session_completed`` is an atomic CAS on + ``status == "running"``, so when the async path already wrote a + terminal state the sync call is a cheap no-op. """ log = CoPilotLogMetadata( logging.getLogger(__name__), @@ -265,10 +348,28 @@ class CoPilotProcessor: user_id=entry.user_id, ) log.info("Starting execution") - start_time = time.monotonic() + try: + self._execute(entry, cancel, cluster_lock, log) + finally: + sync_fail_close_session(entry.session_id, log, self.execution_loop) + elapsed = time.monotonic() - start_time + log.info(f"Execution completed in {elapsed:.2f}s") - # Run the async execution in our event loop + def _execute( + self, + entry: CoPilotExecutionEntry, + cancel: threading.Event, + cluster_lock: ClusterLock, + log: CoPilotLogMetadata, + ): + """Submit the async turn to ``self.execution_loop`` and drive it. 
+ + Handles the sync/async boundary (cancel-event checks, cluster-lock + refresh, bounded waits) without any Redis-state cleanup logic — + that lives in :func:`sync_fail_close_session` which the outer + :meth:`execute` always invokes on exit. + """ future = asyncio.run_coroutine_threadsafe( self._execute_async(entry, cancel, cluster_lock, log), self.execution_loop, @@ -282,16 +383,27 @@ class CoPilotProcessor: if cancel.is_set(): log.info("Cancellation requested") future.cancel() - break - # Refresh cluster lock to maintain ownership + # Give _execute_async's own finally a short window to + # publish its accurate terminal state before the outer + # sync safety net fires. + try: + future.result(timeout=_CANCEL_GRACE_SECONDS) + except BaseException: + pass + return cluster_lock.refresh() if not future.cancelled(): - # Get result to propagate any exceptions - future.result() - - elapsed = time.monotonic() - start_time - log.info(f"Execution completed in {elapsed:.2f}s") + # Bounded timeout so a wedged event loop can't trap us here — + # on timeout we escape to execute()'s finally and the sync + # safety net fires. + try: + future.result(timeout=_CANCEL_GRACE_SECONDS) + except concurrent.futures.TimeoutError: + log.warning( + "Future did not complete within grace window; " + "falling through to sync fail-close" + ) async def _execute_async( self, diff --git a/autogpt_platform/backend/backend/copilot/executor/processor_test.py b/autogpt_platform/backend/backend/copilot/executor/processor_test.py index 5541648747..cdc393e5b1 100644 --- a/autogpt_platform/backend/backend/copilot/executor/processor_test.py +++ b/autogpt_platform/backend/backend/copilot/executor/processor_test.py @@ -10,6 +10,8 @@ the real production helpers from ``processor.py`` so the routing logic has meaningful coverage. """ +import asyncio +import concurrent.futures import logging import threading from unittest.mock import AsyncMock, MagicMock, patch @@ -20,6 +22,7 @@ from backend.copilot.executor.processor import ( CoPilotProcessor, resolve_effective_mode, resolve_use_sdk_for_mode, + sync_fail_close_session, ) from backend.copilot.executor.utils import CoPilotExecutionEntry, CoPilotLogMetadata @@ -275,3 +278,221 @@ class TestExecuteAsyncAclose: await proc._execute_async(_make_entry(), cancel, cluster_lock, _make_log()) assert published.aclose_called is True + + +@pytest.fixture +def exec_loop(): + """Long-lived asyncio loop on a daemon thread — mirrors the layout + ``CoPilotProcessor`` sets up (``execution_loop`` + ``execution_thread``) + so ``sync_fail_close_session`` has a real cross-thread loop to submit + into via ``run_coroutine_threadsafe``.""" + loop = asyncio.new_event_loop() + thread = threading.Thread(target=loop.run_forever, daemon=True) + thread.start() + try: + yield loop + finally: + loop.call_soon_threadsafe(loop.stop) + thread.join(timeout=5) + loop.close() + + +class TestSyncFailCloseSession: + """``sync_fail_close_session`` is the last-line-of-defense invoked from + ``CoPilotProcessor.execute``'s ``finally``. 
It must call + ``mark_session_completed`` via the processor's long-lived + ``execution_loop`` (cross-thread submit) and must swallow Redis + failures so a transient outage doesn't propagate out of the finally.""" + + def test_invokes_mark_session_completed_with_shutdown_message( + self, exec_loop + ) -> None: + mock_mark = AsyncMock() + with patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=mock_mark, + ): + sync_fail_close_session("sess-1", _make_log(), exec_loop) + + mock_mark.assert_awaited_once() + assert mock_mark.await_args is not None + assert mock_mark.await_args.args[0] == "sess-1" + assert "shut down" in mock_mark.await_args.kwargs["error_message"].lower() + + def test_swallows_redis_error(self, exec_loop) -> None: + # Raising from the mock ensures the helper catches the exception + # instead of propagating it back into execute()'s finally block. + mock_mark = AsyncMock(side_effect=RuntimeError("redis down")) + with patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=mock_mark, + ): + sync_fail_close_session("sess-2", _make_log(), exec_loop) # must not raise + + mock_mark.assert_awaited_once() + + def test_closed_execution_loop_skipped_cleanly(self) -> None: + """If cleanup_worker has already stopped the execution_loop by the + time the safety net fires, ``run_coroutine_threadsafe`` raises + RuntimeError. Expected behavior: log + return without propagating.""" + dead_loop = asyncio.new_event_loop() + dead_loop.close() + + mock_mark = AsyncMock() + with patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=mock_mark, + ): + # Must not raise even though the loop is closed + sync_fail_close_session("sess-closed-loop", _make_log(), dead_loop) + + # mark_session_completed was never scheduled because the loop was dead + mock_mark.assert_not_awaited() + + def test_bounded_timeout_when_redis_hangs(self, exec_loop) -> None: + """Scenario D: Redis unreachable — the inner ``asyncio.wait_for`` + must fire and the helper must return without blocking the worker. + + Simulates a wedged Redis by sleeping past the 10s fail-close budget. + The helper must return within the configured grace (+ a small + scheduler margin) and must not re-raise. + """ + import time as _time + + from backend.copilot.executor.processor import _FAIL_CLOSE_REDIS_TIMEOUT + + async def _hang(*_args, **_kwargs): + await asyncio.sleep(_FAIL_CLOSE_REDIS_TIMEOUT + 5) + + with patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=_hang, + ): + start = _time.monotonic() + sync_fail_close_session( + "sess-hang", _make_log(), exec_loop + ) # must not raise + elapsed = _time.monotonic() - start + + # wait_for fires at _FAIL_CLOSE_REDIS_TIMEOUT; outer future.result + # has +2s slack. If the timeout is missing/broken the helper would + # block the full sleep duration (~15s). + assert elapsed < _FAIL_CLOSE_REDIS_TIMEOUT + 4.0, ( + f"sync_fail_close_session hung for {elapsed:.1f}s — bounded " + f"timeout did not fire" + ) + + +# --------------------------------------------------------------------------- +# End-to-end execute() safety-net coverage — the PR's core invariant +# --------------------------------------------------------------------------- + + +class TestExecuteSafetyNet: + """``CoPilotProcessor.execute`` must always invoke + ``sync_fail_close_session`` in its ``finally`` so a session never stays + ``status=running`` in Redis. 
+ + Validates the four deploy-time scenarios the PR targets: + + * A — SIGTERM mid-turn: ``cancel`` event fires, ``_execute`` returns, + safety net still runs. + * B — happy path: normal completion, safety net runs (cheap CAS no-op). + * C — zombie Redis state: the async ``mark_session_completed`` in + ``_execute_async`` blows up, but the outer safety net marks the + session failed anyway. + * D — covered by ``TestSyncFailCloseSession::test_bounded_timeout…``. + """ + + def _attach_exec_loop(self, proc: CoPilotProcessor, loop) -> None: + """``execute`` dispatches the safety net onto ``self.execution_loop``. + Tests don't call ``on_executor_start`` (which spawns the real + per-worker loop), so wire the shared fixture loop in directly.""" + proc.execution_loop = loop + + def _run_execute_in_thread(self, proc: CoPilotProcessor, cancel: threading.Event): + """``CoPilotProcessor.execute`` expects to be called from a pool + worker thread that has *no* running event loop, so we always run + it off the main thread to preserve that invariant. Returns the + future so callers can inspect both result and exception paths.""" + pool = concurrent.futures.ThreadPoolExecutor(max_workers=1) + try: + fut = pool.submit(proc.execute, _make_entry(), cancel, MagicMock()) + # Block until execute() returns (or raises) so the safety net + # has run by the time we inspect mocks. + try: + fut.result(timeout=30) + except BaseException: + pass + return fut + finally: + pool.shutdown(wait=True) + + def test_happy_path_invokes_safety_net(self, exec_loop) -> None: + """Scenario B: normal completion still runs the sync safety net. + Proves the ``finally`` always fires, even when nothing went wrong — + ``mark_session_completed``'s atomic CAS makes this a cheap no-op + in production.""" + mock_mark = AsyncMock() + proc = CoPilotProcessor() + self._attach_exec_loop(proc, exec_loop) + with patch.object(proc, "_execute"), patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=mock_mark, + ): + self._run_execute_in_thread(proc, threading.Event()) + + mock_mark.assert_awaited_once() + assert mock_mark.await_args is not None + assert mock_mark.await_args.args[0] == "sess-1" + + def test_sigterm_mid_turn_invokes_safety_net(self, exec_loop) -> None: + """Scenario A: worker raises (simulating future.cancel + grace + timeout escaping ``_execute``); ``execute`` must still reach the + safety net in its ``finally`` and mark the session failed.""" + mock_mark = AsyncMock() + proc = CoPilotProcessor() + self._attach_exec_loop(proc, exec_loop) + with patch.object( + proc, + "_execute", + side_effect=concurrent.futures.TimeoutError("grace expired"), + ), patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=mock_mark, + ): + self._run_execute_in_thread(proc, threading.Event()) + + mock_mark.assert_awaited_once() + + def test_zombie_redis_async_path_still_marks_session_failed( + self, exec_loop + ) -> None: + """Scenario C: ``_execute_async``'s own ``mark_session_completed`` + call is broken (simulating the exact async-Redis hiccup that caused + the original zombie sessions). 
The outer ``sync_fail_close_session`` + runs on the processor's long-lived ``execution_loop`` and succeeds + where the async path failed.""" + call_log: list[str] = [] + + async def _ok(*args, **kwargs): + call_log.append("sync-ok") + + def _broken_execute(entry, cancel, cluster_lock, log): + # Simulate the async path raising because its Redis client is + # wedged (the pre-fix zombie-session scenario). + raise RuntimeError("async Redis client broken") + + proc = CoPilotProcessor() + self._attach_exec_loop(proc, exec_loop) + with patch.object(proc, "_execute", side_effect=_broken_execute), patch( + "backend.copilot.executor.processor.stream_registry.mark_session_completed", + new=_ok, + ): + self._run_execute_in_thread(proc, threading.Event()) + + # The sync safety net must have fired despite the async path + # blowing up — this is the core guarantee of the PR. + assert call_log == [ + "sync-ok" + ], f"expected sync_fail_close_session to run once, got {call_log!r}" diff --git a/autogpt_platform/backend/backend/copilot/executor/utils.py b/autogpt_platform/backend/backend/copilot/executor/utils.py index a2b051d82b..de1681b55c 100644 --- a/autogpt_platform/backend/backend/copilot/executor/utils.py +++ b/autogpt_platform/backend/backend/copilot/executor/utils.py @@ -89,11 +89,16 @@ def get_session_lock_key(session_id: str) -> str: # CoPilot operations can include extended thinking and agent generation -# which may take 30+ minutes to complete -COPILOT_CONSUMER_TIMEOUT_SECONDS = 60 * 60 # 1 hour +# which may take several hours to complete. Matches the pod's +# terminationGracePeriodSeconds in the helm chart so a rolling deploy can let +# the longest legitimate turn finish. Also bounds the stale-session auto- +# complete watchdog in stream_registry (consumer_timeout + 5min buffer). +COPILOT_CONSUMER_TIMEOUT_SECONDS = 6 * 60 * 60 # 6 hours -# Graceful shutdown timeout - allow in-flight operations to complete -GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = 30 * 60 # 30 minutes +# Graceful shutdown timeout - must match COPILOT_CONSUMER_TIMEOUT_SECONDS so +# cleanup can let the longest legitimate turn complete before the pod is +# SIGKILL'd by kubelet. +GRACEFUL_SHUTDOWN_TIMEOUT_SECONDS = COPILOT_CONSUMER_TIMEOUT_SECONDS def create_copilot_queue_config() -> RabbitMQConfig: @@ -113,9 +118,27 @@ def create_copilot_queue_config() -> RabbitMQConfig: durable=True, auto_delete=False, arguments={ - # Extended consumer timeout for long-running LLM operations - # Default 30-minute timeout is insufficient for extended thinking - # and agent generation which can take 30+ minutes + # Consumer timeout matches the pod graceful-shutdown window so a + # rolling deploy never forces redelivery of a turn that the pod + # is still legitimately finishing. + # + # Deploy note: RabbitMQ (verified on 4.1.4) does NOT strictly + # compare ``x-consumer-timeout`` on queue redeclaration, so this + # value can change between deploys without triggering + # PRECONDITION_FAILED. To update the *effective* timeout on an + # already-running queue before the new code deploys (so pods + # mid-shutdown don't have their consumer cancelled at the old + # limit), apply a policy: + # + # rabbitmqctl set_policy copilot-consumer-timeout \ + # "^copilot_execution_queue$" \ + # '{"consumer-timeout": 21600000}' \ + # --apply-to queues + # + # The policy takes effect immediately. 
Once the policy is set + # to match the code's value the policy is redundant for new + # pods and can be removed after a stable deploy if desired — + # but it's harmless to leave in place. "x-consumer-timeout": COPILOT_CONSUMER_TIMEOUT_SECONDS * 1000, }, diff --git a/autogpt_platform/backend/backend/copilot/model_router.py b/autogpt_platform/backend/backend/copilot/model_router.py new file mode 100644 index 0000000000..35a881393e --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/model_router.py @@ -0,0 +1,104 @@ +"""LaunchDarkly-aware model selection for the copilot. + +Each cell of the ``(mode, tier)`` matrix has a static default baked into +``ChatConfig`` (see ``copilot/config.py``) and a matching LaunchDarkly +string-valued feature flag that can override it per-user. This module +centralises the lookup so both the baseline and SDK paths agree on the +selection rule and so A/B experiments can target a single cell without +shipping a config change. + +Matrix: + + +----------+-------------------------------------+-------------------------------------+ + | | standard | advanced | + +----------+-------------------------------------+-------------------------------------+ + | fast | copilot-fast-standard-model | copilot-fast-advanced-model | + | thinking | copilot-thinking-standard-model | copilot-thinking-advanced-model | + +----------+-------------------------------------+-------------------------------------+ + +LD flag values are arbitrary strings (model identifiers, e.g. +``"anthropic/claude-sonnet-4-6"`` or ``"moonshotai/kimi-k2.6"``). Empty +or non-string values fall back to the config default. +""" + +from __future__ import annotations + +import logging +from typing import Literal + +from backend.copilot.config import ChatConfig +from backend.util.feature_flag import Flag, get_feature_flag_value + +logger = logging.getLogger(__name__) + +ModelMode = Literal["fast", "thinking"] +ModelTier = Literal["standard", "advanced"] + + +_FLAG_BY_CELL: dict[tuple[ModelMode, ModelTier], Flag] = { + ("fast", "standard"): Flag.COPILOT_FAST_STANDARD_MODEL, + ("fast", "advanced"): Flag.COPILOT_FAST_ADVANCED_MODEL, + ("thinking", "standard"): Flag.COPILOT_THINKING_STANDARD_MODEL, + ("thinking", "advanced"): Flag.COPILOT_THINKING_ADVANCED_MODEL, +} + + +def _config_default(config: ChatConfig, mode: ModelMode, tier: ModelTier) -> str: + if mode == "fast": + return ( + config.fast_advanced_model + if tier == "advanced" + else config.fast_standard_model + ) + return ( + config.thinking_advanced_model + if tier == "advanced" + else config.thinking_standard_model + ) + + +async def resolve_model( + mode: ModelMode, + tier: ModelTier, + user_id: str | None, + *, + config: ChatConfig, +) -> str: + """Return the model identifier for a ``(mode, tier)`` cell. + + Consults the matching LaunchDarkly flag for *user_id* first and + falls back to the ``ChatConfig`` default on missing user, missing + flag, or non-string flag value. Passing *config* explicitly keeps + the resolver cheap to unit-test. 
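    Example (illustrative sketch — the wrapper name and user id are invented,
    not taken from a call site)::

        from backend.copilot.config import ChatConfig

        async def pick_baseline_model(user_id: str | None) -> str:
            # Standard-tier baseline turn: LD override when one applies,
            # otherwise config.fast_standard_model.
            return await resolve_model(
                "fast", "standard", user_id, config=ChatConfig()
            )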
+ """ + fallback = _config_default(config, mode, tier).strip() + if not user_id: + return fallback + + flag = _FLAG_BY_CELL[(mode, tier)] + try: + value = await get_feature_flag_value(flag.value, user_id, default=fallback) + except Exception: + logger.warning( + "[model_router] LD lookup failed for %s — using config default %s", + flag.value, + fallback, + exc_info=True, + ) + return fallback + + if isinstance(value, str) and value.strip(): + return value.strip() + if value != fallback: + reason = ( + "empty string" + if isinstance(value, str) + else f"non-string ({type(value).__name__})" + ) + logger.warning( + "[model_router] LD flag %s returned %s — using config default %s", + flag.value, + reason, + fallback, + ) + return fallback diff --git a/autogpt_platform/backend/backend/copilot/model_router_test.py b/autogpt_platform/backend/backend/copilot/model_router_test.py new file mode 100644 index 0000000000..e388d5018b --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/model_router_test.py @@ -0,0 +1,166 @@ +"""Tests for the LD-aware model resolver.""" + +from unittest.mock import AsyncMock, patch + +import pytest + +from backend.copilot.config import ChatConfig +from backend.copilot.model_router import _FLAG_BY_CELL, _config_default, resolve_model + + +def _make_config() -> ChatConfig: + """Build a config with the canonical defaults so tests read naturally.""" + return ChatConfig( + fast_standard_model="anthropic/claude-sonnet-4-6", + fast_advanced_model="anthropic/claude-opus-4.7", + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4.7", + ) + + +class TestConfigDefault: + def test_fast_standard(self): + cfg = _make_config() + assert _config_default(cfg, "fast", "standard") == cfg.fast_standard_model + + def test_fast_advanced(self): + cfg = _make_config() + assert _config_default(cfg, "fast", "advanced") == cfg.fast_advanced_model + + def test_thinking_standard(self): + cfg = _make_config() + assert ( + _config_default(cfg, "thinking", "standard") == cfg.thinking_standard_model + ) + + def test_thinking_advanced(self): + cfg = _make_config() + assert ( + _config_default(cfg, "thinking", "advanced") == cfg.thinking_advanced_model + ) + + +class TestResolveModel: + @pytest.mark.asyncio + async def test_missing_user_returns_fallback(self): + """Without user_id there's no LD context — skip the lookup entirely.""" + cfg = _make_config() + with patch("backend.copilot.model_router.get_feature_flag_value") as mock_flag: + result = await resolve_model("fast", "standard", None, config=cfg) + assert result == cfg.fast_standard_model + mock_flag.assert_not_called() + + @pytest.mark.asyncio + async def test_missing_user_strips_whitespace_from_fallback(self): + """Sentry MEDIUM: the anonymous-user branch returned an unstripped + config value. If ``CHAT_*_MODEL`` env carries trailing whitespace + the downstream ``resolved == tier_default`` check in + ``_resolve_sdk_model_for_request`` would diverge from the + whitespace-stripped LD side, bypassing subscription mode for + every anonymous request. 
Strip at the source.""" + cfg = ChatConfig( + fast_standard_model="anthropic/claude-sonnet-4-6 ", # trailing ws + fast_advanced_model="anthropic/claude-opus-4.7", + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4.7", + ) + result = await resolve_model("fast", "standard", None, config=cfg) + assert result == "anthropic/claude-sonnet-4-6" + + @pytest.mark.asyncio + async def test_ld_string_override_wins(self): + """LD-returned model string replaces the config default.""" + cfg = _make_config() + with patch( + "backend.copilot.model_router.get_feature_flag_value", + new=AsyncMock(return_value="moonshotai/kimi-k2.6"), + ): + result = await resolve_model("fast", "standard", "user-1", config=cfg) + assert result == "moonshotai/kimi-k2.6" + + @pytest.mark.asyncio + async def test_whitespace_is_stripped(self): + cfg = _make_config() + with patch( + "backend.copilot.model_router.get_feature_flag_value", + new=AsyncMock(return_value=" xai/grok-4 "), + ): + result = await resolve_model("thinking", "advanced", "user-1", config=cfg) + assert result == "xai/grok-4" + + @pytest.mark.asyncio + async def test_non_string_value_falls_back_with_type_in_warning(self, caplog): + """LD misconfigured as a boolean flag — don't try to use ``True`` as a + model name; return the config default. Warning must say + 'non-string' (not 'empty string') so the LD operator knows the + flag type is wrong, not just missing a value.""" + import logging + + cfg = _make_config() + with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"): + with patch( + "backend.copilot.model_router.get_feature_flag_value", + new=AsyncMock(return_value=True), + ): + result = await resolve_model("fast", "advanced", "user-1", config=cfg) + assert result == cfg.fast_advanced_model + assert any("non-string" in r.message for r in caplog.records) + + @pytest.mark.asyncio + async def test_empty_string_falls_back_with_empty_in_warning(self, caplog): + """When LD returns ``""`` the warning must say 'empty string' — + not 'non-string' — so the operator doesn't chase a type bug + when the flag is simply unset to an empty value.""" + import logging + + cfg = _make_config() + with caplog.at_level(logging.WARNING, logger="backend.copilot.model_router"): + with patch( + "backend.copilot.model_router.get_feature_flag_value", + new=AsyncMock(return_value=""), + ): + result = await resolve_model("fast", "standard", "user-1", config=cfg) + assert result == cfg.fast_standard_model + messages = [r.message for r in caplog.records] + assert any("empty string" in m for m in messages) + assert not any("non-string" in m for m in messages) + + @pytest.mark.asyncio + async def test_ld_exception_falls_back(self): + """LD client throws (network blip, SDK init race) — serve the default + instead of failing the whole request.""" + cfg = _make_config() + with patch( + "backend.copilot.model_router.get_feature_flag_value", + new=AsyncMock(side_effect=RuntimeError("LD down")), + ): + result = await resolve_model("fast", "standard", "user-1", config=cfg) + assert result == cfg.fast_standard_model + + @pytest.mark.asyncio + async def test_all_four_cells_hit_distinct_flags(self): + """Each (mode, tier) cell must route to its own flag — regression + guard against copy-paste bugs in the _FLAG_BY_CELL map.""" + cfg = _make_config() + calls: list[str] = [] + + async def _capture(flag_key, user_id, default): + calls.append(flag_key) + return default + + with patch( + 
"backend.copilot.model_router.get_feature_flag_value", + new=AsyncMock(side_effect=_capture), + ): + await resolve_model("fast", "standard", "u", config=cfg) + await resolve_model("fast", "advanced", "u", config=cfg) + await resolve_model("thinking", "standard", "u", config=cfg) + await resolve_model("thinking", "advanced", "u", config=cfg) + + assert calls == [ + _FLAG_BY_CELL[("fast", "standard")].value, + _FLAG_BY_CELL[("fast", "advanced")].value, + _FLAG_BY_CELL[("thinking", "standard")].value, + _FLAG_BY_CELL[("thinking", "advanced")].value, + ] + assert len(set(calls)) == 4 diff --git a/autogpt_platform/backend/backend/copilot/moonshot.py b/autogpt_platform/backend/backend/copilot/moonshot.py new file mode 100644 index 0000000000..c117e76120 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/moonshot.py @@ -0,0 +1,147 @@ +"""Moonshot-specific pricing and cache-control helpers. + +Moonshot's Kimi K2.x family is routed through OpenRouter's Anthropic-compat +shim — it speaks Anthropic's API shape but its pricing and cache behaviour +diverge from Anthropic in ways the Claude Agent SDK CLI and our baseline +cache-control gating don't handle on their own: + +* **Rate card** — NOT the canonical cost source. The authoritative number + for every OpenRouter-routed turn is the reconcile task + (:mod:`openrouter_cost`), which reads ``total_cost`` directly from + ``/api/v1/generation`` post-turn. This module exists purely so the + CLI's in-turn ``ResultMessage.total_cost_usd`` (which silently bills + Moonshot at Sonnet rates, ~5x the real Moonshot price because the CLI + pricing table only knows Anthropic) isn't left wildly wrong before the + reconcile fires AND so the reconcile's lookup-fail fallback records a + plausible Moonshot estimate rather than a Sonnet-rate overcharge. + Signal authority: reconcile >> this module's rate card >> CLI. + +* **Cache-control** — Anthropic and Moonshot both accept the + ``cache_control: {type: ephemeral}`` breakpoint on message blocks, but + our baseline path currently gates cache markers on an + ``anthropic/`` / ``claude`` name match because non-Anthropic providers + (OpenAI, Grok, Gemini) 400 on the unknown field. Moonshot's + Anthropic-compat endpoint silently accepts and honours the marker — + empirically boosts cache hit rate on continuation turns — but was + caught in the non-Anthropic branch of the original gate. + :func:`moonshot_supports_cache_control` lets callers widen the gate + to include Moonshot without weakening the ``false`` answer for + OpenAI et al. (The predicate is intentionally narrow — Moonshot-only + — so callers combine it with an explicit Anthropic check at the call + site; see ``baseline/service.py::_supports_prompt_cache_markers``.) + +Detection is prefix-based (``moonshotai/``). Moonshot routes every Kimi +SKU through the same Anthropic-compat surface and currently prices them +identically, so a new ``moonshotai/kimi-k3.0`` slug transparently +inherits both the rate card and the cache-control gate without editing +this file. Per-slug overrides are in :data:`_RATE_OVERRIDES_USD_PER_MTOK` +for when Moonshot eventually splits prices. +""" + +from __future__ import annotations + +# All Moonshot slugs share these rates as of April 2026 — Moonshot prices +# every Kimi K2.x SKU at $0.60/$2.80 per million (input/output) via +# OpenRouter. 
Cache-read / cache-write discounts are NOT applied here: +# OpenRouter currently exposes only a single input price per Moonshot +# endpoint; the real billed amount (with cache savings) lands via the +# reconcile path. Keep in sync with https://platform.moonshot.ai/docs/pricing. +_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK: tuple[float, float] = (0.60, 2.80) + +# Per-slug overrides for when Moonshot splits pricing across SKUs. Empty +# today — every slug matching ``moonshotai/`` falls back to +# :data:`_DEFAULT_MOONSHOT_RATE_USD_PER_MTOK`. +_RATE_OVERRIDES_USD_PER_MTOK: dict[str, tuple[float, float]] = {} + +# Vendor prefix — matches any OpenRouter slug Moonshot ships. Keep as a +# module constant so the prefix check stays in exactly one place. +_MOONSHOT_PREFIX = "moonshotai/" + + +def is_moonshot_model(model: str | None) -> bool: + """True when *model* is a Moonshot OpenRouter slug. + + Prefix match against ``moonshotai/`` covers every Kimi SKU Moonshot + ships today (``kimi-k2``, ``kimi-k2.5``, ``kimi-k2.6``, + ``kimi-k2-thinking``) plus any future SKU Moonshot publishes under + the same namespace. Used by both pricing and cache-control gating. + """ + return isinstance(model, str) and model.startswith(_MOONSHOT_PREFIX) + + +def rate_card_usd(model: str | None) -> tuple[float, float] | None: + """Return (input, output) $/Mtok for *model* or None if non-Moonshot. + + Looks up a per-slug override first, falling back to the shared + default for anything under ``moonshotai/``. Returns None for + non-Moonshot slugs (including ``None``) so callers can skip the + override without a preflight guard. + """ + if not is_moonshot_model(model): + return None + # ``is_moonshot_model`` narrowed ``model`` to str; dict.get is + # type-safe here despite the wider param annotation above. + assert model is not None + return _RATE_OVERRIDES_USD_PER_MTOK.get(model, _DEFAULT_MOONSHOT_RATE_USD_PER_MTOK) + + +def override_cost_usd( + *, + model: str | None, + sdk_reported_usd: float, + prompt_tokens: int, + completion_tokens: int, + cache_read_tokens: int, + cache_creation_tokens: int, +) -> float: + """Recompute SDK turn cost from the Moonshot rate card. + + Not the canonical cost source — the OpenRouter ``/generation`` + reconcile (:mod:`openrouter_cost`) lands the authoritative billed + amount post-turn. This helper exists only to improve the CLI's + in-turn ``ResultMessage.total_cost_usd``: + + 1. So the ``cost_usd`` the client sees before the reconcile completes + isn't wildly wrong (the CLI would otherwise ship a Sonnet-rate + estimate, ~5x the real Moonshot bill). + 2. So the reconcile's own lookup-fail fallback records a plausible + Moonshot estimate rather than the CLI's Sonnet number. + + For Moonshot slugs we compute cost from the reported token counts; + for anything else (including Anthropic) we return the SDK number + unchanged — Anthropic slugs are priced accurately by the CLI. + + Cache read / creation tokens are folded into ``prompt_tokens`` at + the full input rate because Moonshot's rate card doesn't distinguish + them at the OpenRouter surface; the reconcile has the authoritative + discount accounting for turns where Moonshot's cache engaged. 
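+
+    Illustrative arithmetic only (shared Moonshot rate card above; the
+    authoritative billed amount still lands via the ``/generation``
+    reconcile): a ~30K-prompt / 100-completion Kimi turn recomputes to
+    ``(30_000 * 0.60 + 100 * 2.80) / 1_000_000`` ≈ $0.0183, versus the
+    ~$0.09 the CLI would otherwise report at Sonnet rates.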
+ """ + if model is None: + return sdk_reported_usd + rates = rate_card_usd(model) + if rates is None: + return sdk_reported_usd + input_rate, output_rate = rates + total_prompt = prompt_tokens + cache_read_tokens + cache_creation_tokens + return (total_prompt * input_rate + completion_tokens * output_rate) / 1_000_000 + + +def moonshot_supports_cache_control(model: str | None) -> bool: + """True when a Moonshot *model* accepts Anthropic-style ``cache_control``. + + Narrow, Moonshot-specific predicate — callers that need the full + "does this route accept cache markers" answer combine this with an + Anthropic check (see ``baseline/service.py::_supports_prompt_cache_markers``). + Named ``moonshot_*`` deliberately so the call site can't mistake it + for a universal predicate that answers correctly for Anthropic + (which also supports cache_control — this function would return + False for Anthropic slugs). + + Moonshot's Anthropic-compat endpoint honours the marker. Without + it Moonshot falls back to its own automatic prefix caching, which + drifts more readily between turns (internal testing saw 0/4 cache + hits across two continuation sessions). With explicit + ``cache_control`` the upstream cache hit rate rises to the same + ballpark as Anthropic's ~60-95% on continuations. + """ + return is_moonshot_model(model) diff --git a/autogpt_platform/backend/backend/copilot/moonshot_test.py b/autogpt_platform/backend/backend/copilot/moonshot_test.py new file mode 100644 index 0000000000..7fcea124bf --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/moonshot_test.py @@ -0,0 +1,173 @@ +"""Unit tests for Moonshot pricing and cache-control helpers.""" + +from __future__ import annotations + +import pytest + +from backend.copilot.moonshot import ( + is_moonshot_model, + moonshot_supports_cache_control, + override_cost_usd, + rate_card_usd, +) + + +class TestIsMoonshotModel: + """Prefix detection covers every Moonshot SKU without a slug list.""" + + @pytest.mark.parametrize( + "model", + [ + "moonshotai/kimi-k2.6", + "moonshotai/kimi-k2-thinking", + "moonshotai/kimi-k2.5", + "moonshotai/kimi-k2", + "moonshotai/kimi-k3.0", # Future SKU must match transparently. + ], + ) + def test_moonshot_slugs_match(self, model: str) -> None: + assert is_moonshot_model(model) is True + + @pytest.mark.parametrize( + "model", + [ + "anthropic/claude-sonnet-4.6", + "anthropic/claude-opus-4.7", + "openai/gpt-4o", + "google/gemini-2.5-flash", + "xai/grok-4", + "deepseek/deepseek-v3", + "", # Empty string — not Moonshot. + ], + ) + def test_non_moonshot_slugs_do_not_match(self, model: str) -> None: + assert is_moonshot_model(model) is False + + @pytest.mark.parametrize("model", [None, 123, ["moonshotai/kimi-k2.6"]]) + def test_non_string_returns_false(self, model) -> None: + # Type-robust: never raise on unexpected types; callers pass None. + assert is_moonshot_model(model) is False + + +class TestRateCardUsd: + """Rate card defaults to the shared Moonshot price for every SKU.""" + + def test_moonshot_default_rate(self) -> None: + assert rate_card_usd("moonshotai/kimi-k2.6") == (0.60, 2.80) + + def test_future_moonshot_sku_inherits_default(self) -> None: + # Verifies the prefix-based fallback — new SKUs don't need a code + # edit to get a reasonable rate card. 
+ assert rate_card_usd("moonshotai/kimi-k3.0") == (0.60, 2.80) + + def test_non_moonshot_returns_none(self) -> None: + assert rate_card_usd("anthropic/claude-sonnet-4.6") is None + assert rate_card_usd("openai/gpt-4o") is None + + +class TestOverrideCostUsd: + """Rate-card override replaces the CLI's Sonnet-rate estimate for + Moonshot turns; Anthropic and unknown slugs pass through unchanged.""" + + def test_moonshot_recomputes_from_rate_card(self) -> None: + """A 29.5K-prompt Kimi turn should land at ~$0.018 on the + Moonshot rate card, not the CLI's $0.09 Sonnet-rate estimate.""" + recomputed = override_cost_usd( + model="moonshotai/kimi-k2.6", + sdk_reported_usd=0.089862, # What the CLI reported (Sonnet price). + prompt_tokens=29564, + completion_tokens=78, + cache_read_tokens=0, + cache_creation_tokens=0, + ) + expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000 + assert recomputed == pytest.approx(expected, rel=1e-9) + assert 0.017 < recomputed < 0.019 # Sanity against Moonshot's rate card. + + def test_anthropic_passes_through(self) -> None: + """Anthropic slugs are priced accurately by the CLI already — + the override returns the SDK number unchanged.""" + assert ( + override_cost_usd( + model="anthropic/claude-sonnet-4.6", + sdk_reported_usd=0.089862, + prompt_tokens=29564, + completion_tokens=78, + cache_read_tokens=0, + cache_creation_tokens=0, + ) + == 0.089862 + ) + + def test_unknown_non_moonshot_passes_through(self) -> None: + """A non-Moonshot, non-Anthropic slug falls back to the SDK value + — best-effort rather than leaking a zero or a wrong rate card.""" + assert ( + override_cost_usd( + model="deepseek/deepseek-v3", + sdk_reported_usd=0.05, + prompt_tokens=10_000, + completion_tokens=500, + cache_read_tokens=0, + cache_creation_tokens=0, + ) + == 0.05 + ) + + def test_none_model_passes_through(self) -> None: + """Subscription mode sets model=None — return the SDK value.""" + assert ( + override_cost_usd( + model=None, + sdk_reported_usd=0.07, + prompt_tokens=100, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + ) + == 0.07 + ) + + def test_cache_tokens_priced_at_input_rate(self) -> None: + """OpenRouter's Moonshot endpoints don't expose a discounted + cached-input price — cache_read / cache_creation tokens are + priced at the full input rate. The reconcile path has the + authoritative discount for turns where Moonshot's cache engaged.""" + recomputed = override_cost_usd( + model="moonshotai/kimi-k2.6", + sdk_reported_usd=0.5, + prompt_tokens=1000, + completion_tokens=0, + cache_read_tokens=5000, + cache_creation_tokens=2000, + ) + expected = (1000 + 5000 + 2000) * 0.60 / 1_000_000 + assert recomputed == pytest.approx(expected, rel=1e-9) + + +class TestSupportsCacheControl: + """Gate for emitting ``cache_control: {type: ephemeral}`` on message + blocks. 
True for Moonshot (Anthropic-compat endpoint accepts it) + and False for everything else this module knows about — Anthropic + callers use their own ``_is_anthropic_model`` check which is + combined with this one into a wider gate.""" + + def test_moonshot_supports_cache_control(self) -> None: + assert moonshot_supports_cache_control("moonshotai/kimi-k2.6") is True + + def test_future_moonshot_sku_supports_cache_control(self) -> None: + assert moonshot_supports_cache_control("moonshotai/kimi-k3.0") is True + + @pytest.mark.parametrize( + "model", + [ + "openai/gpt-4o", + "google/gemini-2.5-flash", + "xai/grok-4", + "deepseek/deepseek-v3", + "", + None, + ], + ) + def test_non_moonshot_does_not_support_cache_control(self, model) -> None: + assert moonshot_supports_cache_control(model) is False diff --git a/autogpt_platform/backend/backend/copilot/pending_messages.py b/autogpt_platform/backend/backend/copilot/pending_messages.py index ff6eed8b59..8e6aa61af9 100644 --- a/autogpt_platform/backend/backend/copilot/pending_messages.py +++ b/autogpt_platform/backend/backend/copilot/pending_messages.py @@ -240,16 +240,15 @@ async def peek_pending_messages(session_id: str) -> list[PendingMessage]: return messages -async def _clear_pending_messages_unsafe(session_id: str) -> None: +async def clear_pending_messages_unsafe(session_id: str) -> None: """Drop the session's pending buffer — **not** the normal turn cleanup. - Named ``_unsafe`` because reaching for this at turn end drops queued - follow-ups on the floor instead of running them (the bug fixed by - commit b64be73). The atomic ``LPOP`` drain at turn start is the - primary consumer; anything pushed after the drain window belongs to - the next turn by definition. Retained only as an operator/debug - escape hatch for manually clearing a stuck session and as a fixture - in the unit tests. + The ``_unsafe`` suffix warns: reaching for this at turn end drops queued + follow-ups on the floor instead of running them (the bug fixed by commit + b64be73). The atomic ``LPOP`` drain at turn start is the primary consumer; + anything pushed after the drain window belongs to the next turn by + definition. Retained only as an operator/debug escape hatch for manually + clearing a stuck session and as a fixture in the unit tests. 
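+
+    Operator escape-hatch sketch (run from an async shell; the argument
+    is whatever stuck session you are clearing)::
+
+        await clear_pending_messages_unsafe("sess-stuck")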
""" redis = await get_redis_async() await redis.delete(_buffer_key(session_id)) diff --git a/autogpt_platform/backend/backend/copilot/pending_messages_test.py b/autogpt_platform/backend/backend/copilot/pending_messages_test.py index 06f809579f..c997d7d9cf 100644 --- a/autogpt_platform/backend/backend/copilot/pending_messages_test.py +++ b/autogpt_platform/backend/backend/copilot/pending_messages_test.py @@ -16,7 +16,7 @@ from backend.copilot.pending_messages import ( MAX_PENDING_MESSAGES, PendingMessage, PendingMessageContext, - _clear_pending_messages_unsafe, + clear_pending_messages_unsafe, drain_and_format_for_injection, drain_pending_for_persist, drain_pending_messages, @@ -208,15 +208,15 @@ async def test_cap_drops_oldest_when_exceeded(fake_redis: _FakeRedis) -> None: async def test_clear_removes_buffer(fake_redis: _FakeRedis) -> None: await push_pending_message("sess4", PendingMessage(content="x")) await push_pending_message("sess4", PendingMessage(content="y")) - await _clear_pending_messages_unsafe("sess4") + await clear_pending_messages_unsafe("sess4") assert await peek_pending_count("sess4") == 0 @pytest.mark.asyncio async def test_clear_is_idempotent(fake_redis: _FakeRedis) -> None: # Clearing an already-empty buffer should not raise - await _clear_pending_messages_unsafe("sess_empty") - await _clear_pending_messages_unsafe("sess_empty") + await clear_pending_messages_unsafe("sess_empty") + await clear_pending_messages_unsafe("sess_empty") # ── Publish hook ──────────────────────────────────────────────────── diff --git a/autogpt_platform/backend/backend/copilot/permissions.py b/autogpt_platform/backend/backend/copilot/permissions.py index 7636792ca4..ab9ed82b9c 100644 --- a/autogpt_platform/backend/backend/copilot/permissions.py +++ b/autogpt_platform/backend/backend/copilot/permissions.py @@ -52,10 +52,15 @@ is at most as permissive as the parent: from __future__ import annotations import re -from typing import Literal, get_args +from typing import TYPE_CHECKING, Literal, get_args from pydantic import BaseModel, PrivateAttr +if TYPE_CHECKING: + from collections.abc import Iterable + + from backend.copilot.tools import ToolGroup + # --------------------------------------------------------------------------- # Constants — single source of truth for all accepted tool names # --------------------------------------------------------------------------- @@ -66,7 +71,6 @@ from pydantic import BaseModel, PrivateAttr ToolName = Literal[ # Platform tools (must match keys in TOOL_REGISTRY) "add_understanding", - "ask_question", "bash_exec", "browser_act", "browser_navigate", @@ -125,9 +129,16 @@ ToolName = Literal[ # Frozen set of all valid tool names — derived from the Literal. ALL_TOOL_NAMES: frozenset[str] = frozenset(get_args(ToolName)) -# SDK built-in tool names — uppercase-initial names are SDK built-ins. +# SDK built-in tool names — tools provided by the Claude Code CLI that our +# code does not implement directly. ``TodoWrite`` is DELIBERATELY excluded: +# baseline mode ships an MCP-wrapped platform version +# (``tools/todo_write.py``), while SDK mode still uses the CLI-native +# original via ``_SDK_BUILTIN_ALWAYS`` in ``sdk/tool_adapter.py`` — the +# MCP copy is filtered out there. ``Task`` remains an SDK-only built-in +# (for queue-backed context-isolation on baseline, use ``run_sub_session`` +# instead). 
SDK_BUILTIN_TOOL_NAMES: frozenset[str] = frozenset( - n for n in ALL_TOOL_NAMES if n[0].isupper() + {"Agent", "Edit", "Glob", "Grep", "Read", "Task", "WebSearch", "Write"} ) # Platform tool names — everything that isn't an SDK built-in. @@ -364,13 +375,17 @@ def apply_tool_permissions( permissions: CopilotPermissions, *, use_e2b: bool = False, + disabled_groups: Iterable[ToolGroup] = (), ) -> tuple[list[str], list[str]]: """Compute (allowed_tools, extra_disallowed) for :class:`ClaudeAgentOptions`. Takes the base allowed/disallowed lists from :func:`~backend.copilot.sdk.tool_adapter.get_copilot_tool_names` / :func:`~backend.copilot.sdk.tool_adapter.get_sdk_disallowed_tools` and - applies *permissions* on top. + applies *permissions* on top. Tools belonging to any *disabled_groups* + are hidden from the base allowed list — use this to gate capability + groups (e.g. ``"graphiti"`` when the memory backend is off for the + current user). Returns: ``(allowed_tools, extra_disallowed)`` where *allowed_tools* is the @@ -380,13 +395,16 @@ def apply_tool_permissions( """ from backend.copilot.sdk.tool_adapter import ( _READ_TOOL_NAME, + BASELINE_ONLY_MCP_TOOLS, MCP_TOOL_PREFIX, get_copilot_tool_names, get_sdk_disallowed_tools, ) from backend.copilot.tools import TOOL_REGISTRY - base_allowed = get_copilot_tool_names(use_e2b=use_e2b) + base_allowed = get_copilot_tool_names( + use_e2b=use_e2b, disabled_groups=disabled_groups + ) base_disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b) if permissions.is_empty(): @@ -420,7 +438,14 @@ def apply_tool_permissions( # keeping only those present in the original base_allowed list. def to_sdk_names(short: str) -> list[str]: names: list[str] = [] - if short in TOOL_REGISTRY: + if short in BASELINE_ONLY_MCP_TOOLS: + # Baseline ships MCP versions of these (Task/TodoWrite) for + # model-flexibility parity, but SDK mode uses the CLI-native + # originals. Permissions target the CLI built-in here so + # ``base_allowed`` (which excludes the MCP wrappers) still + # matches. + names.append(short) + elif short in TOOL_REGISTRY: names.append(f"{MCP_TOOL_PREFIX}{short}") elif short in _SDK_TO_MCP: # Map SDK built-in file tool to its MCP equivalent. diff --git a/autogpt_platform/backend/backend/copilot/permissions_test.py b/autogpt_platform/backend/backend/copilot/permissions_test.py index 5289ea8d22..367c1c7a2c 100644 --- a/autogpt_platform/backend/backend/copilot/permissions_test.py +++ b/autogpt_platform/backend/backend/copilot/permissions_test.py @@ -582,6 +582,11 @@ class TestApplyToolPermissions: class TestSdkBuiltinToolNames: def test_expected_builtins_present(self): + # ``TodoWrite`` is DELIBERATELY absent: baseline ships an MCP-wrapped + # platform version for model-flexibility parity, so it appears in + # PLATFORM_TOOL_NAMES / TOOL_REGISTRY instead. ``Task`` remains + # SDK-only — baseline uses ``run_sub_session`` for the equivalent + # context-isolation role. 
expected = { "Agent", "Read", @@ -591,9 +596,9 @@ class TestSdkBuiltinToolNames: "Grep", "Task", "WebSearch", - "TodoWrite", } assert expected.issubset(SDK_BUILTIN_TOOL_NAMES) + assert "TodoWrite" not in SDK_BUILTIN_TOOL_NAMES def test_platform_names_match_tool_registry(self): """PLATFORM_TOOL_NAMES (derived from ToolName Literal) must match TOOL_REGISTRY keys.""" diff --git a/autogpt_platform/backend/backend/copilot/prompting.py b/autogpt_platform/backend/backend/copilot/prompting.py index 399d31c1cc..c8af41637c 100644 --- a/autogpt_platform/backend/backend/copilot/prompting.py +++ b/autogpt_platform/backend/backend/copilot/prompting.py @@ -145,31 +145,12 @@ When the user asks to interact with a service or API, follow this order: **Never skip step 1.** Built-in blocks are more reliable, tested, and user-friendly than MCP or raw API calls. -### Sub-agent tasks -- When using the Task tool, NEVER set `run_in_background` to true. - All tasks must run in the foreground. - -### Delegating to another autopilot (sub-autopilot pattern) -Use the **`run_sub_session`** tool to delegate a task to a fresh -sub-AutoPilot. The sub has its own full tool set and can perform -multi-step work autonomously. - -- `prompt` (required): the task description. -- `system_context` (optional): extra context prepended to the prompt. -- `sub_autopilot_session_id` (optional): continue an existing - sub-AutoPilot — pass the `sub_autopilot_session_id` returned by a - previous completed run. -- `wait_for_result` (default 60, max 300): seconds to wait inline. If - the sub isn't done by then you get `status="running"` + a - `sub_session_id` — call **`get_sub_session_result`** with that id - (wait up to 300s more per call) until it returns `completed` or - `error`. Works across turns — safe to reconnect in a later message. - -Use this when a task is complex enough to benefit from a separate -autopilot context, e.g. "research X and write a report" while the -parent autopilot handles orchestration. Do NOT invoke `AutoPilotBlock` -via `run_block` — it's hidden from `run_block` by design because the -dedicated tool handles the async lifecycle correctly. +### Complex multi-step work +- Use `TodoWrite` to track the plan once the job has 3+ distinct steps. +- Delegate self-contained subtasks to `run_sub_session` to keep their + intermediate tool calls out of the parent context. +- Do NOT invoke `AutoPilotBlock` via `run_block`; use `run_sub_session` + instead. """ @@ -182,14 +163,18 @@ sandbox so `bash_exec` can access it for further processing. The exact sandbox path is shown in the `[Sandbox copy available at ...]` note. ### GitHub CLI (`gh`) and git -- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before asking them to connect it. +- To check if the user has their GitHub account already connected, run `gh auth status`. Always check this before running `connect_integration(provider="github")` which will ask the user to connect their GitHub regardless if it's already connected. - If the user has connected their GitHub account, both `gh` and `git` are pre-authenticated — use them directly without any manual login step. `git` HTTPS operations (clone, push, pull) work automatically. - If the token changes mid-session (e.g. user reconnects with a new token), run `gh auth setup-git` to re-register the credential helper. -- If `gh` or `git` fails with an authentication error (e.g. 
"authentication - required", "could not read Username", or exit code 128), call +- **MANDATORY:** You MUST run `gh auth status` before EVER calling + `connect_integration(provider="github")`. If it shows `Logged in`, + proceed directly — no integration connection needed. Never skip this check. +- If `gh auth status` shows NOT logged in, or `gh`/`git` fails with an + authentication error (e.g. "authentication required", "could not read + Username", or exit code 128), THEN call `connect_integration(provider="github")` to surface the GitHub credentials setup card so the user can connect their account. Once connected, retry the operation. diff --git a/autogpt_platform/backend/backend/copilot/prompting_test.py b/autogpt_platform/backend/backend/copilot/prompting_test.py index 5a719f1b00..d125b66a74 100644 --- a/autogpt_platform/backend/backend/copilot/prompting_test.py +++ b/autogpt_platform/backend/backend/copilot/prompting_test.py @@ -1,7 +1,6 @@ -"""Tests for agent generation guide — verifies clarification section.""" +"""Tests for prompting helpers.""" import importlib -from pathlib import Path from backend.copilot import prompting @@ -31,28 +30,3 @@ class TestGetSdkSupplementStaticPlaceholder: def test_e2b_mode_has_no_session_placeholder(self): result = prompting.get_sdk_supplement(use_e2b=True) assert "" not in result - - -class TestAgentGenerationGuideContainsClarifySection: - """The agent generation guide must include the clarification section.""" - - def test_guide_includes_clarify_section(self): - guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md" - content = guide_path.read_text(encoding="utf-8") - assert "Before or During Building" in content - - def test_guide_mentions_find_block_for_clarification(self): - guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md" - content = guide_path.read_text(encoding="utf-8") - clarify_section = content.split("Before or During Building")[1].split( - "### Workflow" - )[0] - assert "find_block" in clarify_section - - def test_guide_mentions_ask_question_tool(self): - guide_path = Path(__file__).parent / "sdk" / "agent_generation_guide.md" - content = guide_path.read_text(encoding="utf-8") - clarify_section = content.split("Before or During Building")[1].split( - "### Workflow" - )[0] - assert "ask_question" in clarify_section diff --git a/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md b/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md index 4fedb186ab..bec82f0c29 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md +++ b/autogpt_platform/backend/backend/copilot/sdk/agent_generation_guide.md @@ -39,6 +39,7 @@ Before running the workflow below, ALWAYS decompose the goal first: For simple goals (1-2 blocks), keep steps brief (2-3 steps). For complex goals, use as many steps as needed. + ### Workflow for Creating/Editing Agents 1. **If editing**: First narrow to the specific agent by UUID, then fetch its diff --git a/autogpt_platform/backend/backend/copilot/sdk/env_test.py b/autogpt_platform/backend/backend/copilot/sdk/env_test.py index e61908081c..36f3dc32cb 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/env_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/env_test.py @@ -13,12 +13,19 @@ from backend.copilot.config import ChatConfig def _make_config(**overrides) -> ChatConfig: - """Create a ChatConfig with safe defaults, applying *overrides*.""" + """Create a ChatConfig with safe defaults, applying *overrides*. 
+ + SDK model fields are pinned to anthropic/* so the + ``_validate_sdk_model_vendor_compatibility`` model_validator allows + construction with ``use_openrouter=False`` (the default here). + """ defaults = { "use_claude_code_subscription": False, "use_openrouter": False, "api_key": None, "base_url": None, + "thinking_standard_model": "anthropic/claude-sonnet-4-6", + "thinking_advanced_model": "anthropic/claude-opus-4-7", } defaults.update(overrides) return ChatConfig(**defaults) diff --git a/autogpt_platform/backend/backend/copilot/sdk/openrouter_cost.py b/autogpt_platform/backend/backend/copilot/sdk/openrouter_cost.py new file mode 100644 index 0000000000..ee0f02de44 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/openrouter_cost.py @@ -0,0 +1,399 @@ +"""Authoritative per-turn cost for OpenRouter-routed SDK generations. + +The Claude Agent SDK CLI's ``ResultMessage.total_cost_usd`` is computed +from a static Anthropic pricing table baked into the binary. For +non-Anthropic models routed through OpenRouter (e.g. Kimi K2.6) the CLI +silently falls back to Sonnet rates — empirically ~5x too high. Even +after a rate-card override the estimate is still ~37% off in practice +because OpenRouter's own tokenizer counts, reasoning-token rollup, and +dated-snapshot pricing tiers can't be reconstructed from what the SDK +exposes locally. + +This module provides :func:`record_turn_cost_from_openrouter` — an +``asyncio.create_task``-able coroutine that: + +1. Queries ``https://openrouter.ai/api/v1/generation?id=`` for + each generation ID captured during the turn. +2. Sums the authoritative ``total_cost`` across all rounds. +3. Calls :func:`persist_and_record_usage` **once** with the real number, + updating both the cost-analytics row and the rate-limit counter. + +If every lookup fails (404 / timeout / parse error), the caller's +``fallback_cost_usd`` is recorded instead — keeps the rate-limit counter +populated with the best available estimate rather than leaving the turn +uncharged. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from pathlib import Path +from typing import TYPE_CHECKING + +import httpx + +from backend.copilot.token_tracking import persist_and_record_usage +from backend.util import json + +if TYPE_CHECKING: + from backend.copilot.model import ChatSession + +logger = logging.getLogger(__name__) + +# OpenRouter docs: +# https://openrouter.ai/docs/api-reference/get-a-generation +_GENERATION_URL = "https://openrouter.ai/api/v1/generation" + +# OpenRouter's generation endpoint indexes the billing row a few seconds +# after the SSE stream closes — observed ~8-12s in practice. Retry with +# progressive backoff for up to ~30s total before giving up, so the typical +# indexing window (~10s) fits inside the retry envelope. Backoff values +# in seconds summed: 0.5 + 1 + 2 + 4 + 8 + 15 = 30.5. +_MAX_RETRIES = 7 +_BACKOFF_SECONDS = (0.5, 1.0, 2.0, 4.0, 8.0, 15.0) +_REQUEST_TIMEOUT = 10.0 + + +async def _fetch_generation_cost( + client: httpx.AsyncClient, + gen_id: str, + api_key: str, + log_prefix: str, +) -> float | None: + """Fetch the ``total_cost`` for one generation, with retries. 
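+
+    Response shape consumed (only the single field read here is shown;
+    the real payload also carries token counts and other metadata)::
+
+        {"data": {"total_cost": 0.029, ...}}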
+ + Retries only on transient conditions: + + * HTTP 404 — row not yet indexed server-side (typical ~5-10s lag + after the SSE stream closes) + * HTTP 408 / 429 — timeout / rate limit + * HTTP 5xx — transient OpenRouter outage + * Network / ``httpx`` exceptions — transport-level retryable + + Fails fast on permanent client errors (401 Unauthorized, + 403 Forbidden, 400 Bad Request, etc.) since they can't recover + within the retry window and would just burn API quota. + + Returns ``None`` when the endpoint reports no data, on a permanent + failure, or when every retry attempt hits a transient error. + """ + headers = {"Authorization": f"Bearer {api_key}"} + params = {"id": gen_id} + last_error: Exception | None = None + for attempt in range(_MAX_RETRIES): + if attempt > 0: + await asyncio.sleep(_BACKOFF_SECONDS[attempt - 1]) + try: + resp = await client.get( + _GENERATION_URL, + params=params, + headers=headers, + timeout=_REQUEST_TIMEOUT, + ) + status = resp.status_code + # Fast-fail on permanent client errors — retrying 401/403/400 + # just burns API quota and delays the fallback. + if status in (400, 401, 403): + logger.warning( + "%s OpenRouter /generation permanent error %d for %s — " + "not retrying (check API key / request shape)", + log_prefix, + status, + gen_id, + ) + return None + # Transient retryable: 404 (indexing lag), 408 (timeout), + # 429 (rate limit), 5xx (server error). + if status == 404 or status == 408 or status == 429 or status >= 500: + last_error = RuntimeError(f"HTTP {status} on attempt {attempt + 1}") + continue + # Any other 4xx — treat as permanent. + if status >= 400: + logger.warning( + "%s OpenRouter /generation unexpected status %d for %s — " + "not retrying", + log_prefix, + status, + gen_id, + ) + return None + payload = resp.json().get("data") + if not isinstance(payload, dict): + logger.warning( + "%s OpenRouter /generation returned no data for %s", + log_prefix, + gen_id, + ) + return None + cost = payload.get("total_cost") + if cost is None: + logger.warning( + "%s OpenRouter /generation response missing total_cost " + "for %s (keys=%s)", + log_prefix, + gen_id, + sorted(payload.keys())[:10], + ) + return None + return float(cost) + except Exception as exc: # noqa: BLE001 + # Network / transport errors are retryable. + last_error = exc + continue + logger.warning( + "%s OpenRouter /generation lookup failed for %s after %d attempts: %s", + log_prefix, + gen_id, + _MAX_RETRIES, + last_error, + ) + return None + + +def _gen_ids_from_jsonl(path: Path) -> set[str]: + """Extract ``gen-`` message IDs from every assistant entry in a + Claude CLI JSONL file. + + Tolerant of malformed lines: single bad JSON object doesn't block + the whole file. Also reads ``redacted_thinking`` / ``thinking`` + entries that share an ID with their parent (via ``jq -u`` in the + CLI) and dedups by caller. 
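+
+    Minimal shape of the entries scanned for (as exercised by the unit
+    tests; real CLI entries carry many more fields)::
+
+        {"type": "assistant", "message": {"id": "gen-abc123", "content": [...]}}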
+ """ + ids: set[str] = set() + try: + with path.open("r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + entry = json.loads(line, fallback=None) + if not isinstance(entry, dict): + continue + if entry.get("type") != "assistant": + continue + message = entry.get("message") + if not isinstance(message, dict): + continue + msg_id = message.get("id") + if isinstance(msg_id, str) and msg_id.startswith("gen-"): + ids.add(msg_id) + except (OSError, UnicodeDecodeError) as exc: + logger.debug( + "Failed to scan JSONL for gen-IDs: path=%s err=%s", + path, + exc, + ) + return ids + + +def _discover_turn_subagent_gen_ids( + project_dir: Path, + session_id: str, + turn_start_ts: float, + known: list[str], +) -> list[str]: + """Gen-IDs from this session's subagents created during this turn. + + Main-turn LLM rounds (incl. fallback retries) arrive on the live + stream as ``AssistantMessage`` and land on ``known`` via + ``message_id``. What's NOT on ``known`` is the CLI's subagent LLM + calls — chiefly auto-compaction, which spawns a fresh JSONL under + ``//subagents/agent-acompact-*.jsonl`` + whose gen-IDs never touch our main adapter. OpenRouter bills them + anyway, so without this sweep compaction turns under-report cost. + + Scoping: ONLY the current session's subagent dir + (``//subagents/agent-*.jsonl``) and ONLY + files whose ``mtime >= turn_start_ts``. Without both guards we'd + merge prior turns' gen-IDs (main JSONL accumulates forever) and + foreign sessions' gen-IDs (the project dir contains every session + for this cwd), double-billing the user. + + Also covers non-compaction subagents (Task tool etc.) when the CLI + spawns them — their live-stream visibility depends on SDK version, + so the sweep is a safety net. The dedup against ``known`` means + anything already captured live doesn't double count. + + Preserves ``known`` ordering so main-turn IDs stay first; only + appends truly new IDs from the sweep. + """ + merged: list[str] = list(known) + seen = set(merged) + subagents_dir = project_dir / session_id / "subagents" + if not subagents_dir.exists(): + return merged + try: + for jsonl in subagents_dir.glob("agent-*.jsonl"): + try: + if jsonl.stat().st_mtime < turn_start_ts: + continue + except OSError: + continue + for gen_id in _gen_ids_from_jsonl(jsonl): + if gen_id not in seen: + seen.add(gen_id) + merged.append(gen_id) + except OSError as exc: + logger.debug("Failed to walk subagents dir=%s: %s", subagents_dir, exc) + return merged + + +async def record_turn_cost_from_openrouter( + *, + session: "ChatSession", + user_id: str | None, + model: str | None, + prompt_tokens: int, + completion_tokens: int, + cache_read_tokens: int, + cache_creation_tokens: int, + generation_ids: list[str], + cli_project_dir: str | None, + cli_session_id: str | None, + turn_start_ts: float | None, + fallback_cost_usd: float | None, + api_key: str | None, + log_prefix: str, +) -> None: + """Persist turn cost from OpenRouter's authoritative ``/generation``. + + Writes a single cost-analytics row via :func:`persist_and_record_usage` + — same method used for the Anthropic-direct sync path — so the + cost-log append and rate-limit counter stay consistent. No double + counting: the caller skips its own sync persist for non-Anthropic + OpenRouter turns and defers entirely to this task. + + Launched via ``asyncio.create_task`` from the stream ``finally`` block + so the ~500-2000ms ``/generation`` indexing delay doesn't add latency + to the turn. 
During that window the rate-limit counter is briefly + unaware of the turn's cost; back-to-back turns in that sub-second + gap see a stale counter. Acceptable tradeoff — the alternative + (writing a possibly-wrong estimate synchronously) creates a + double-count when the reconcile delta arrives. + + Fallback semantics: if every generation lookup fails, records + ``fallback_cost_usd`` instead so the rate-limit counter isn't left + completely empty. Keeps behaviour at-worst equivalent to the + rate-card estimate that came before this task existed. + """ + if not api_key: + logger.debug( + "%s OpenRouter cost record skipped: no API key available", + log_prefix, + ) + return + + # Merge in any gen-IDs from CLI subagent JSONLs the live stream + # didn't surface — chiefly SDK-internal compaction, which spawns a + # summarisation LLM call under + # ``//subagents/...`` that OpenRouter + # bills but doesn't emit via our main adapter. Safe no-op when no + # compaction happened (no subagent files created this turn) or the + # CLI wrote nothing there. + # + # The sweep is SESSION-scoped (``/subagents/``, not + # the whole project dir) and TURN-scoped (mtime >= turn_start_ts). + # Both guards are load-bearing: the project dir contains every + # session for this cwd, and subagent files persist across turns, + # so an unscoped sweep would re-bill prior turns and foreign + # sessions' gen-IDs. + if cli_project_dir and cli_session_id and turn_start_ts is not None: + merged_ids = _discover_turn_subagent_gen_ids( + Path(os.path.expanduser(cli_project_dir)), + cli_session_id, + turn_start_ts, + generation_ids, + ) + if len(merged_ids) != len(generation_ids): + logger.info( + "%s[cost-record] discovered %d additional gen-IDs in " + "session subagents (compaction / Task) — reconcile " + "covers all", + log_prefix, + len(merged_ids) - len(generation_ids), + ) + generation_ids = merged_ids + + if not generation_ids: + return + + try: + async with httpx.AsyncClient() as client: + tasks = [ + _fetch_generation_cost(client, gen_id, api_key, log_prefix) + for gen_id in generation_ids + ] + results = await asyncio.gather(*tasks, return_exceptions=False) + except Exception as exc: # noqa: BLE001 + logger.warning( + "%s OpenRouter cost record failed to fetch any generation " + "(falling back to rate-card estimate): %s", + log_prefix, + exc, + ) + results = [] + + fetched = [r for r in results if isinstance(r, (int, float))] + if fetched and len(fetched) == len(generation_ids): + real_cost: float | None = sum(fetched) + # Log real (OpenRouter billed) vs CLI rate-card estimate so an + # operator can spot divergence without querying OpenRouter by + # hand. Under-count typically means a gen-ID source we don't + # capture live (e.g. title model, background LLM calls running + # outside the main stream); over-count means the CLI's rate + # table is stale vs. OpenRouter's current pricing. + delta_pct: float | None = None + if fallback_cost_usd and fallback_cost_usd > 0: + delta_pct = (real_cost - fallback_cost_usd) / fallback_cost_usd * 100 + logger.info( + "%s[cost-record] OpenRouter real=$%.6f cli_estimate=$%s " + "delta=%s (gen_ids=%d)", + log_prefix, + real_cost, + f"{fallback_cost_usd:.6f}" if fallback_cost_usd is not None else "?", + f"{delta_pct:+.1f}%" if delta_pct is not None else "n/a", + len(generation_ids), + ) + else: + real_cost = fallback_cost_usd + if fetched: + # Partial success: some lookups returned a cost, others didn't. 
+ # Trusting the partial sum would under-report; fall back to the + # estimate so rate-limit enforcement stays conservative. + logger.warning( + "%s[cost-record] OpenRouter partial lookup (%d/%d) — " + "using fallback estimate=$%s", + log_prefix, + len(fetched), + len(generation_ids), + real_cost, + ) + else: + logger.warning( + "%s[cost-record] OpenRouter lookup failed for all gens — " + "using fallback estimate=$%s", + log_prefix, + real_cost, + ) + + try: + await persist_and_record_usage( + session=session, + user_id=user_id, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cache_read_tokens=cache_read_tokens, + cache_creation_tokens=cache_creation_tokens, + log_prefix=f"{log_prefix}[cost-record]", + cost_usd=real_cost, + model=model, + provider="open_router", + ) + except Exception as exc: # noqa: BLE001 + logger.warning( + "%s[cost-record] failed to persist: %s", + log_prefix, + exc, + ) diff --git a/autogpt_platform/backend/backend/copilot/sdk/openrouter_cost_test.py b/autogpt_platform/backend/backend/copilot/sdk/openrouter_cost_test.py new file mode 100644 index 0000000000..442e858c0a --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/sdk/openrouter_cost_test.py @@ -0,0 +1,520 @@ +"""Unit tests for SDK-path OpenRouter cost recording.""" + +from __future__ import annotations + +from datetime import UTC, datetime +from unittest.mock import AsyncMock, patch + +import httpx +import pytest + +from backend.copilot.model import ChatSession +from backend.copilot.sdk.openrouter_cost import record_turn_cost_from_openrouter + + +def _session() -> ChatSession: + now = datetime.now(UTC) + return ChatSession( + session_id="sess-1", + user_id="user-1", + usage=[], + started_at=now, + updated_at=now, + messages=[], + ) + + +def _mock_generation_response(cost: float) -> dict: + return { + "data": { + "total_cost": cost, + "native_tokens_prompt": 1000, + "native_tokens_completion": 50, + "tokens_prompt": 1100, + "tokens_completion": 60, + } + } + + +class TestRecordTurnCostFromOpenRouter: + """Single-write semantics: the cost + rate-limit counter is updated + exactly once per turn via this background task. 
The sync path at + the call site is already skipped for non-Anthropic OpenRouter turns, + so there's no double-counting path even on partial failure.""" + + @pytest.mark.asyncio + async def test_empty_generation_ids_no_op(self): + """Direct-Anthropic turn produces no gen-IDs — task is a no-op.""" + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get, + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="anthropic/claude-sonnet-4.6", + prompt_tokens=10, + completion_tokens=5, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=[], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=0.05, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_not_called() + mock_get.assert_not_called() + + @pytest.mark.asyncio + async def test_missing_api_key_no_op(self): + """Without an OpenRouter API key we can't query the endpoint — skip.""" + with patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist: + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=10, + completion_tokens=5, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-1"], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=0.02, + api_key=None, + log_prefix="[test]", + ) + mock_persist.assert_not_called() + + @pytest.mark.asyncio + async def test_single_generation_records_real_cost(self): + """Authoritative cost from OpenRouter is the value recorded — no + reliance on the fallback estimate.""" + real_cost = 0.02900595 + + async def _get(self, url, **kwargs): # noqa: ARG001 + return httpx.Response(200, json=_mock_generation_response(real_cost)) + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=29669, + completion_tokens=280, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-1776842410"], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=0.01858, # rate-card estimate, deliberately wrong + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + kwargs = mock_persist.call_args.kwargs + assert kwargs["cost_usd"] == pytest.approx(real_cost, rel=1e-9) + assert kwargs["prompt_tokens"] == 29669 + assert kwargs["completion_tokens"] == 280 + assert kwargs["provider"] == "open_router" + assert kwargs["model"] == "moonshotai/kimi-k2.6" + + @pytest.mark.asyncio + async def test_multi_round_turn_sums_costs(self): + """Tool-use turn has N generation IDs; the real cost is the sum + of ``total_cost`` across all rounds — recorded in a single row.""" + costs_by_id = {"gen-a": 0.029, "gen-b": 0.030} + + async def _get(self, url, **kwargs): # noqa: ARG001 + gen_id = kwargs.get("params", {}).get("id") + return httpx.Response( + 200, json=_mock_generation_response(costs_by_id[gen_id]) + ) + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + 
await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=60000, + completion_tokens=600, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-a", "gen-b"], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=0.037, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + cost = mock_persist.call_args.kwargs["cost_usd"] + assert cost == pytest.approx(sum(costs_by_id.values()), rel=1e-9) + + @pytest.mark.asyncio + async def test_partial_lookup_falls_back_to_estimate(self): + """If only some gen-IDs resolve, summing them would under-report. + Fall back to the caller's estimate and log — the rate-limit + counter stays populated with the best available number.""" + fallback = 0.05 + seq = iter( + [ + httpx.Response(200, json=_mock_generation_response(0.03)), + httpx.Response(404, text="not found"), + httpx.Response(404, text="not found"), + httpx.Response(404, text="not found"), + httpx.Response(404, text="not found"), + ] + ) + + async def _get(self, *args, **kwargs): # noqa: ARG001 + return next(seq) + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=1000, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-a", "gen-b"], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=fallback, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + assert mock_persist.call_args.kwargs["cost_usd"] == fallback + + @pytest.mark.asyncio + async def test_fast_fail_on_401_no_retries(self): + """Permanent client errors (401/403/400) must not retry — burning + the 30s retry window on an unauthenticated request wastes API + quota and delays the fallback.""" + call_count = {"n": 0} + + async def _get(self, *args, **kwargs): # noqa: ARG001 + call_count["n"] += 1 + return httpx.Response(401, text="unauthorized") + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=1000, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-a"], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=0.02, + api_key="sk-bad", + log_prefix="[test]", + ) + # Only one call — no retries. + assert call_count["n"] == 1 + # Fallback was recorded (lookup failed → keep rate-limit counter live). + mock_persist.assert_called_once() + assert mock_persist.call_args.kwargs["cost_usd"] == 0.02 + + @pytest.mark.asyncio + async def test_retries_on_404_then_succeeds(self): + """Indexing lag: endpoint returns 404 initially, then 200 once the + billing row is indexed. 
Retry budget should exhaust transient + states rather than giving up on first 404.""" + seq = iter( + [ + httpx.Response(404, text="not found"), + httpx.Response(200, json=_mock_generation_response(0.025)), + ] + ) + + async def _get(self, *args, **kwargs): # noqa: ARG001 + return next(seq) + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=1000, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-a"], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=0.05, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(0.025) + + @pytest.mark.asyncio + async def test_complete_lookup_failure_falls_back_to_estimate(self): + """Every lookup fails → record the estimate so the rate-limit + counter isn't left empty. At-worst parity with the pre-task + behaviour.""" + fallback = 0.02 + + async def _get(self, *args, **kwargs): # noqa: ARG001 + raise httpx.ConnectError("no network") + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=1000, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-a"], + cli_project_dir=None, + cli_session_id=None, + turn_start_ts=None, + fallback_cost_usd=fallback, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + assert mock_persist.call_args.kwargs["cost_usd"] == fallback + + @pytest.mark.asyncio + async def test_compaction_subagent_gen_ids_are_swept(self, tmp_path): + """CLI-internal compaction spawns a subagent JSONL under + ``//subagents/agent-acompact-*.jsonl`` + whose gen-IDs the live adapter never surfaces. 
When + ``cli_project_dir`` + ``cli_session_id`` + ``turn_start_ts`` + are supplied the reconcile walks only THIS session's subagents + and discovers the compaction IDs.""" + session_id = "sess-abc" + sub_dir = tmp_path / session_id / "subagents" + sub_dir.mkdir(parents=True) + (sub_dir / "agent-acompact-xyz.jsonl").write_text( + '{"type":"assistant","message":{"id":"gen-compact-1","content":[]}}\n' + '{"type":"assistant","message":{"id":"gen-compact-2","content":[]}}\n' + ) + + costs_by_id = { + "gen-main-1": 0.020, + "gen-compact-1": 0.005, + "gen-compact-2": 0.003, + } + + async def _get(self, *args, **kwargs): # noqa: ARG001 + gen_id = kwargs.get("params", {}).get("id") + return httpx.Response( + 200, json=_mock_generation_response(costs_by_id[gen_id]) + ) + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="anthropic/claude-opus-4.7", + prompt_tokens=1000, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-main-1"], + cli_project_dir=str(tmp_path), + cli_session_id=session_id, + turn_start_ts=0.0, + fallback_cost_usd=0.05, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx( + sum(costs_by_id.values()), rel=1e-9 + ) + + @pytest.mark.asyncio + async def test_compaction_sweep_no_subagents_is_noop(self, tmp_path): + """No compaction happened → reconcile uses only the caller's + gen-IDs, same as when cli_project_dir is None.""" + session_id = "sess-none" + (tmp_path / session_id).mkdir() + + async def _get(self, *args, **kwargs): # noqa: ARG001 + return httpx.Response(200, json=_mock_generation_response(0.02)) + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=1000, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-main-1"], + cli_project_dir=str(tmp_path), + cli_session_id=session_id, + turn_start_ts=0.0, + fallback_cost_usd=0.05, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx(0.02) + + @pytest.mark.asyncio + async def test_compaction_sweep_ignores_prior_turn_and_foreign_sessions( + self, tmp_path + ): + """Scoping guards the sweep from double-billing: a stale subagent + file from a prior turn (mtime before ``turn_start_ts``) and any + subagent from a foreign session (different session_id folder) + must BOTH be skipped. Without either guard, a long-running + session with past compactions would re-bill every prior turn, + and a second chat session in the same cwd would inherit the + first session's compaction cost.""" + import os + import time + + this_session = "sess-current" + other_session = "sess-other" + + this_subagents = tmp_path / this_session / "subagents" + this_subagents.mkdir(parents=True) + other_subagents = tmp_path / other_session / "subagents" + other_subagents.mkdir(parents=True) + + # Prior-turn compaction file — same session, stale mtime. 
+ stale_file = this_subagents / "agent-acompact-stale.jsonl" + stale_file.write_text( + '{"type":"assistant","message":{"id":"gen-stale-1","content":[]}}\n' + ) + # Foreign session's compaction file. + foreign_file = other_subagents / "agent-acompact-foreign.jsonl" + foreign_file.write_text( + '{"type":"assistant","message":{"id":"gen-foreign-1","content":[]}}\n' + ) + # Current-turn compaction file — fresh. + fresh_file = this_subagents / "agent-acompact-fresh.jsonl" + fresh_file.write_text( + '{"type":"assistant","message":{"id":"gen-fresh-1","content":[]}}\n' + ) + + # turn_start_ts lies between the stale and fresh mtimes. + past = time.time() - 3600 + os.utime(stale_file, (past, past)) + os.utime(foreign_file, (past, past)) + turn_start_ts = time.time() - 60 # 1 min ago + fresh_now = time.time() + os.utime(fresh_file, (fresh_now, fresh_now)) + + costs_by_id = { + "gen-main-1": 0.010, + "gen-fresh-1": 0.004, + } + + async def _get(self, *args, **kwargs): # noqa: ARG001 + gen_id = kwargs.get("params", {}).get("id") + # If the sweep leaks a stale/foreign ID, the test fails here + # with a KeyError rather than silently over-billing. + assert gen_id in costs_by_id, f"sweep leaked out-of-scope gen_id {gen_id}" + return httpx.Response( + 200, json=_mock_generation_response(costs_by_id[gen_id]) + ) + + with ( + patch( + "backend.copilot.sdk.openrouter_cost.persist_and_record_usage", + new_callable=AsyncMock, + ) as mock_persist, + patch("httpx.AsyncClient.get", new=_get), + ): + await record_turn_cost_from_openrouter( + session=_session(), + user_id="u1", + model="moonshotai/kimi-k2.6", + prompt_tokens=1000, + completion_tokens=10, + cache_read_tokens=0, + cache_creation_tokens=0, + generation_ids=["gen-main-1"], + cli_project_dir=str(tmp_path), + cli_session_id=this_session, + turn_start_ts=turn_start_ts, + fallback_cost_usd=0.05, + api_key="sk-or-test", + log_prefix="[test]", + ) + mock_persist.assert_called_once() + # Exactly the current-turn main + fresh compaction — no stale, + # no foreign. + assert mock_persist.call_args.kwargs["cost_usd"] == pytest.approx( + sum(costs_by_id.values()), rel=1e-9 + ) diff --git a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py index 17b54797b8..070e6992be 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/p0_guardrails_test.py @@ -10,12 +10,19 @@ from backend.copilot.constants import is_transient_api_error def _make_config(**overrides) -> ChatConfig: - """Create a ChatConfig with safe defaults, applying *overrides*.""" + """Create a ChatConfig with safe defaults, applying *overrides*. + + SDK model fields are pinned to anthropic/* so the + ``_validate_sdk_model_vendor_compatibility`` model_validator allows + construction with ``use_openrouter=False`` (the default here). + """ defaults = { "use_claude_code_subscription": False, "use_openrouter": False, "api_key": None, "base_url": None, + "thinking_standard_model": "anthropic/claude-sonnet-4-6", + "thinking_advanced_model": "anthropic/claude-opus-4-7", } defaults.update(overrides) return ChatConfig(**defaults) @@ -39,8 +46,11 @@ class TestResolveFallbackModel: assert _resolve_fallback_model() is None - def test_strips_provider_prefix(self): - """OpenRouter-style 'anthropic/claude-sonnet-4-...' 
is stripped.""" + def test_keeps_full_slug_when_openrouter_active(self): + """OpenRouter routes by ``vendor/model`` slug — _normalize_model_name + now preserves the prefix when openrouter_active is True so non- + Anthropic vendors stay routable. Anthropic slugs are passed + through unchanged in this mode (PR #12878).""" cfg = _make_config( claude_agent_fallback_model="anthropic/claude-sonnet-4-20250514", use_openrouter=True, @@ -52,8 +62,7 @@ class TestResolveFallbackModel: result = _resolve_fallback_model() - assert result == "claude-sonnet-4-20250514" - assert "/" not in result + assert result == "anthropic/claude-sonnet-4-20250514" def test_dots_replaced_for_direct_anthropic(self): """Direct Anthropic requires hyphen-separated versions.""" @@ -714,11 +723,16 @@ class TestDoTransientBackoff: mock_sleep.assert_called_once_with(7) async def test_replaces_adapter_with_new_instance(self): - """state.adapter is replaced with a new SDKResponseAdapter after yield.""" + """state.adapter is replaced with a new SDKResponseAdapter after yield, + and ``render_reasoning_in_ui`` is threaded from the SDK service config + (not hardcoded) so ``CHAT_RENDER_REASONING_IN_UI=false`` at runtime + flips the reconstruction consistently with the rest of the path.""" from unittest.mock import AsyncMock, MagicMock, patch from backend.copilot.sdk.service import _do_transient_backoff + cfg = _make_config(render_reasoning_in_ui=False) + original_adapter = MagicMock() state = MagicMock() state.adapter = original_adapter @@ -726,6 +740,7 @@ class TestDoTransientBackoff: with ( patch("asyncio.sleep", new=AsyncMock()), + patch(f"{_SVC}.config", cfg), patch("backend.copilot.sdk.service.SDKResponseAdapter") as mock_cls, ): new_adapter = MagicMock() @@ -733,7 +748,11 @@ class TestDoTransientBackoff: async for _ in _do_transient_backoff(3, state, "msg-1", "sess-1"): pass - mock_cls.assert_called_once_with(message_id="msg-1", session_id="sess-1") + mock_cls.assert_called_once_with( + message_id="msg-1", + session_id="sess-1", + render_reasoning_in_ui=False, + ) assert state.adapter is new_adapter async def test_resets_usage_after_yield(self): diff --git a/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py b/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py index fbd73d9277..2a15e9f1fc 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py +++ b/autogpt_platform/backend/backend/copilot/sdk/response_adapter.py @@ -7,12 +7,15 @@ the frontend expects. import json import logging +import time import uuid +from typing import Any from claude_agent_sdk import ( AssistantMessage, Message, ResultMessage, + StreamEvent, SystemMessage, TextBlock, ThinkingBlock, @@ -46,6 +49,16 @@ from .tool_adapter import MCP_TOOL_PREFIX, pop_pending_tool_output logger = logging.getLogger(__name__) +# Coalescing thresholds for ``thinking_delta`` events on the SDK partial +# stream — matches the baseline window (see +# ``baseline/reasoning.py::_COALESCE_MIN_CHARS``). Anthropic's extended- +# thinking channel emits ~1 event per token (~4,700 per Kimi K2.6 turn); +# a 64-char / 50 ms window halves the event rate vs 32/40 while staying +# well under the ~100 ms perceptual threshold. +_THINKING_COALESCE_MIN_CHARS = 64 +_THINKING_COALESCE_MAX_INTERVAL_MS = 50.0 + + class SDKResponseAdapter: """Adapter for converting Claude Agent SDK messages to Vercel AI SDK format. @@ -53,7 +66,13 @@ class SDKResponseAdapter: text blocks, tool calls, and message lifecycle. 
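+
+    Construction sketch: ``render_reasoning_in_ui=False`` suppresses the
+    reasoning stream sent to the UI for the turn without affecting what
+    is persisted to the SDK transcript::
+
+        adapter = SDKResponseAdapter(
+            message_id="msg-1",
+            session_id="sess-1",
+            render_reasoning_in_ui=False,
+        )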
""" - def __init__(self, message_id: str | None = None, session_id: str | None = None): + def __init__( + self, + message_id: str | None = None, + session_id: str | None = None, + *, + render_reasoning_in_ui: bool = True, + ): self.message_id = message_id or str(uuid.uuid4()) self.session_id = session_id self.text_block_id = str(uuid.uuid4()) @@ -62,6 +81,7 @@ class SDKResponseAdapter: self.reasoning_block_id = str(uuid.uuid4()) self.has_started_reasoning = False self.has_ended_reasoning = True + self.render_reasoning_in_ui = render_reasoning_in_ui self.current_tool_calls: dict[str, dict[str, str]] = {} self.resolved_tool_calls: set[str] = set() self.step_open = False @@ -74,6 +94,36 @@ class SDKResponseAdapter: # case so the turn renders as cleanly complete. self._text_since_last_tool_result = False self._any_tool_results_seen = False + # --- Partial-message streaming state (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES) + # When ``include_partial_messages=True`` is set on + # ``ClaudeAgentOptions``, the CLI emits raw Anthropic streaming + # events (``content_block_start`` / ``content_block_delta`` / + # ``content_block_stop``) as ``StreamEvent`` messages ahead of + # each summary ``AssistantMessage``. We consume those for + # per-token wire emission and reconcile against the summary to + # catch any tail content the partial stream missed (short blocks + # the CLI emits summary-only, OpenRouter proxy quirks, encrypted + # thinking). + # + self._block_types_by_index: dict[int, str] = {} + # Running partial-stream buffers. Summary AssistantMessages can + # arrive *before* the corresponding ``content_block_stop`` event + # (the CLI flushes the summary as soon as the block is complete + # on the provider side, with the stop event following as a + # separate frame). Reconcile-by-index therefore can't rely on + # completed-block queues — instead we maintain running buffers + # of all partial output of each type, and each summary block of + # that type consumes its prefix. This also trivially handles + # Kimi K2.6's pattern of emitting each content block as its own + # summary AssistantMessage: Python list indices don't align + # with Anthropic content_block indices, but per-type order does. + self._partial_text_buffer: str = "" + self._partial_thinking_buffer: str = "" + # Coalescing buffer for ``thinking_delta`` — text_delta is + # naturally coarser so we let it through unbuffered. + self._pending_thinking_delta: str = "" + self._pending_thinking_index: int | None = None + self._last_thinking_flush_monotonic: float = 0.0 @property def has_unresolved_tool_calls(self) -> bool: @@ -99,6 +149,15 @@ class SDKResponseAdapter: # produced (task_progress events were previously silent). responses.append(StreamHeartbeat()) + elif isinstance(sdk_message, StreamEvent): + # Raw Anthropic streaming events — only delivered when + # ``include_partial_messages=True`` is set on + # ``ClaudeAgentOptions`` (gated by + # ``config.sdk_include_partial_messages``). Drives per-token + # emission of text + thinking; tool_use and other structural + # events stay on the ``AssistantMessage`` path. + self._handle_stream_event(sdk_message, responses) + elif isinstance(sdk_message, AssistantMessage): # Flush any SDK built-in tool calls that didn't get a UserMessage # result (e.g. WebSearch, Read handled internally by the CLI). 
@@ -115,18 +174,43 @@ class SDKResponseAdapter: responses.append(StreamStartStep()) self.step_open = True - for block in sdk_message.content: + # Hoist ThinkingBlocks to the front of the iteration so the UI + # sees reasoning *before* the answer it produced — that's the + # natural reading order and the way Anthropic models emit them. + # OpenRouter passthrough providers (Moonshot/Kimi, DeepSeek) + # often place ``reasoning`` after the visible text in the + # response, which would make ``ReasoningCollapse`` render under + # the assistant message instead of above it. ToolUse and other + # block types stay in their original relative order so tool + # call sequences remain coherent. + # + # Note: when ``include_partial_messages=True`` is active the + # per-token stream already emitted reasoning + text in their + # natural on-the-wire order via ``_handle_stream_event``. The + # summary walk below falls through to ``_emit_text_tail`` / + # ``_emit_thinking_tail`` which emit only the diff, preserving + # that ordering without duplicating content. + blocks_with_idx = sorted( + enumerate(sdk_message.content), + key=lambda pair: 0 if isinstance(pair[1], ThinkingBlock) else 1, + ) + + for block_index, block in blocks_with_idx: if isinstance(block, TextBlock): - if block.text: - # Reasoning and text are distinct UI parts; close - # any open reasoning block before opening text so - # the AI SDK transport doesn't merge them. + # Reasoning and text are distinct UI parts; close any + # open reasoning block before opening text so the AI + # SDK transport doesn't merge them. + tail = self._text_tail_for_summary_block(block.text) + if tail: self._end_reasoning_if_open(responses) self._ensure_text_started(responses) responses.append( - StreamTextDelta(id=self.text_block_id, delta=block.text) + StreamTextDelta(id=self.text_block_id, delta=tail) ) self._text_since_last_tool_result = True + elif block.text: + # Partial stream already emitted the full text. + self._text_since_last_tool_result = True elif isinstance(block, ThinkingBlock): # Stream extended_thinking content as a reasoning @@ -142,13 +226,38 @@ class SDKResponseAdapter: # it live, extended_thinking turns that end # thinking-only left the UI stuck on "Thought for Xs" # with nothing rendered until a page refresh. - if block.thinking: + # + # When ``render_reasoning_in_ui=False`` the three + # reasoning helpers below (and the append) no-op, so + # the frontend sees a text-only stream AND no + # ``ChatMessage(role='reasoning')`` row is persisted + # (the row is only created by ``_dispatch_response`` + # when ``StreamReasoningStart`` arrives, which is + # suppressed here). Persistence of the thinking text + # into the SDK transcript via + # ``_format_sdk_content_blocks`` is unaffected — that + # feeds ``--resume`` continuity, not the UI. + # + # Flush any pending coalesce buffer to the wire BEFORE + # computing the tail — otherwise a summary that + # arrives between the last partial delta and the + # ``content_block_stop`` event (race: summary is + # flushed by the CLI as soon as the block is complete + # provider-side, with stop lagging as a separate + # frame) would see ``_partial_thinking_buffer`` + # missing the pending prefix, and + # ``_thinking_tail_for_summary_block`` would emit the + # full block — duplicating the tail that + # ``_end_reasoning_if_open`` still drains on stop. 
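# Illustrative trace (editor's example, hypothetical values): all thinking
# deltas for the block ("plan A") are still pending below the 64-char window
# when the summary arrives. Without the flush below, the running buffer is
# empty, so the tail is the full "plan A" and is emitted here;
# ``content_block_stop`` then drains the same pending "plan A" via
# ``_end_reasoning_if_open`` — duplicated on the wire. Flushing first moves
# the pending text into the buffer, so the tail computes to "".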
+ self._flush_pending_thinking(responses) + tail = self._thinking_tail_for_summary_block(block.thinking) + if tail: self._end_text_if_open(responses) self._ensure_reasoning_started(responses) responses.append( StreamReasoningDelta( id=self.reasoning_block_id, - delta=block.thinking, + delta=tail, ) ) @@ -158,7 +267,7 @@ class SDKResponseAdapter: # Strip MCP prefix so frontend sees "find_block" # instead of "mcp__copilot__find_block". - tool_name = block.name.removeprefix(MCP_TOOL_PREFIX) + tool_name = block.name.strip().removeprefix(MCP_TOOL_PREFIX) responses.append( StreamToolInputStart(toolCallId=block.id, toolName=tool_name) @@ -347,8 +456,12 @@ class SDKResponseAdapter: """Start (or restart) a reasoning block if needed. Each ``ThinkingBlock`` the SDK emits gets its own streaming block - on the wire so the frontend can render a new ``Reasoning`` part - per LLM turn (rather than concatenating across the whole session). + so the frontend can render a new ``Reasoning`` part per LLM turn + (rather than concatenating across the whole session). Events + are emitted unconditionally — the caller filters them out of the + SSE wire when ``render_reasoning_in_ui=False`` but still feeds + them through ``_dispatch_response`` so the session transcript + keeps a ``role='reasoning'`` row. """ if not self.has_started_reasoning or self.has_ended_reasoning: if self.has_ended_reasoning: @@ -358,11 +471,238 @@ class SDKResponseAdapter: self.has_started_reasoning = True def _end_reasoning_if_open(self, responses: list[StreamBaseResponse]) -> None: - """End the current reasoning block if one is open.""" + """End the current reasoning block if one is open. + + Drains any buffered thinking_delta text so the tail isn't lost + when the block closes before the coalesce window elapses. + """ if self.has_started_reasoning and not self.has_ended_reasoning: + if self._pending_thinking_delta: + responses.append( + StreamReasoningDelta( + id=self.reasoning_block_id, + delta=self._pending_thinking_delta, + ) + ) + self._partial_thinking_buffer += self._pending_thinking_delta + self._pending_thinking_delta = "" + self._pending_thinking_index = None responses.append(StreamReasoningEnd(id=self.reasoning_block_id)) self.has_ended_reasoning = True + # ------------------------------------------------------------------ + # Partial-message streaming (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES) + # ------------------------------------------------------------------ + + def _reset_partial_stream_state(self) -> None: + """Clear per-message partial-stream state. + + Anthropic's ``content_block`` indices are scoped to a single + message — when a fresh ``message_start`` event arrives (new + ``AssistantMessage`` turn) the maps must reset so indices from + the previous message don't suppress genuine content in the new + one. + + Also clears ``_partial_*_buffer``: multi-round turns (tool use) + emit a ``message_start`` per LLM round, and leftover prefix + content from round N would cause the summary walk in round N+1 + to either match the wrong prefix (silently dropping new content) + or diverge and fall back to re-emitting the whole block. + """ + self._block_types_by_index = {} + self._partial_text_buffer = "" + self._partial_thinking_buffer = "" + self._pending_thinking_delta = "" + self._pending_thinking_index = None + + def _text_tail_for_summary_block(self, full_text: str) -> str: + """Reconcile the next summary ``TextBlock`` against the running + partial-stream buffer. 
+ + The CLI can emit the summary ``AssistantMessage`` before the + matching ``content_block_stop`` event, so we can't rely on a + queue of completed blocks. Instead we maintain + ``_partial_text_buffer`` — the concatenation of every + ``text_delta`` chunk that hasn't been claimed by a summary + block yet — and consume ``full_text`` as a prefix from it. + Summary blocks that have no partial backing (buffer empty) + emit their full text; blocks that partial covered wholly are + silent; blocks with a partial prefix + a summary tail emit + only the tail. Kimi K2.6's pattern of emitting each content + block as its own summary ``AssistantMessage`` is handled + automatically because block order is preserved across both + streams. + """ + if not full_text: + return "" + if not self._partial_text_buffer: + return full_text + if full_text.startswith(self._partial_text_buffer): + tail = full_text[len(self._partial_text_buffer) :] + self._partial_text_buffer = "" + return tail + if self._partial_text_buffer.startswith(full_text): + # Partial already emitted this whole block plus more — the + # "more" belongs to a later summary block. Consume only the + # prefix matching this block and leave the rest buffered. + self._partial_text_buffer = self._partial_text_buffer[len(full_text) :] + return "" + logger.warning( + "SDK partial/summary text diverged " + "(partial_buf=%d chars, summary=%d chars) — emitting summary, " + "clearing partial buffer to recover", + len(self._partial_text_buffer), + len(full_text), + ) + self._partial_text_buffer = "" + return full_text + + def _thinking_tail_for_summary_block(self, full_thinking: str) -> str: + """Same as :meth:`_text_tail_for_summary_block` for reasoning.""" + if not full_thinking: + return "" + if not self._partial_thinking_buffer: + return full_thinking + if full_thinking.startswith(self._partial_thinking_buffer): + tail = full_thinking[len(self._partial_thinking_buffer) :] + self._partial_thinking_buffer = "" + return tail + if self._partial_thinking_buffer.startswith(full_thinking): + self._partial_thinking_buffer = self._partial_thinking_buffer[ + len(full_thinking) : + ] + return "" + logger.warning( + "SDK partial/summary thinking diverged " + "(partial_buf=%d chars, summary=%d chars) — emitting summary, " + "clearing partial buffer to recover", + len(self._partial_thinking_buffer), + len(full_thinking), + ) + self._partial_thinking_buffer = "" + return full_thinking + + def _handle_stream_event( + self, evt: StreamEvent, responses: list[StreamBaseResponse] + ) -> None: + """Translate raw Anthropic streaming events into wire events. + + Handles four event types; everything else (``message_delta`` + stop reasons, ``signature_delta``, ``input_json_delta``, + ``ping``, ...) is ignored because the summary ``AssistantMessage`` + carries their effects. 
+ + * ``message_start`` — new message boundary, reset per-index maps + * ``content_block_start`` — open text / reasoning block on the + wire and remember the block type at that index + * ``content_block_delta`` — forward ``text_delta`` immediately + and coalesce ``thinking_delta`` (64-char / 50 ms window) + * ``content_block_stop`` — drain any buffered thinking and close + the corresponding wire block + """ + raw: dict[str, Any] = evt.event or {} + event_type = raw.get("type") + + if event_type == "message_start": + self._reset_partial_stream_state() + return + + if event_type == "content_block_start": + block = raw.get("content_block") or {} + index = raw.get("index") + block_type = block.get("type") + if not isinstance(index, int) or not isinstance(block_type, str): + return + self._block_types_by_index[index] = block_type + if block_type == "text": + self._end_reasoning_if_open(responses) + self._ensure_text_started(responses) + # Seed any preamble the block_start carries. + seed = block.get("text") or "" + if seed: + responses.append(StreamTextDelta(id=self.text_block_id, delta=seed)) + self._partial_text_buffer += seed + self._text_since_last_tool_result = True + elif block_type == "thinking": + self._end_text_if_open(responses) + self._ensure_reasoning_started(responses) + self._last_thinking_flush_monotonic = time.monotonic() + # tool_use / server_tool_use / redacted_thinking blocks stay + # on the ``AssistantMessage`` path — the frontend widgets + # need the final ``input`` payload which only arrives in the + # summary. + return + + if event_type == "content_block_delta": + index = raw.get("index") + if not isinstance(index, int): + return + delta = raw.get("delta") or {} + delta_type = delta.get("type") + if delta_type == "text_delta": + chunk = delta.get("text") or "" + if not chunk: + return + self._ensure_text_started(responses) + responses.append(StreamTextDelta(id=self.text_block_id, delta=chunk)) + self._partial_text_buffer += chunk + self._text_since_last_tool_result = True + elif delta_type == "thinking_delta": + chunk = delta.get("thinking") or "" + if not chunk: + return + self._ensure_reasoning_started(responses) + # Flush the coalesce buffer if the index changed — shouldn't + # happen in practice but guard against interleaved indices. + if ( + self._pending_thinking_index is not None + and self._pending_thinking_index != index + ): + self._flush_pending_thinking(responses) + self._pending_thinking_delta += chunk + self._pending_thinking_index = index + now = time.monotonic() + elapsed_ms = (now - self._last_thinking_flush_monotonic) * 1000.0 + if ( + len(self._pending_thinking_delta) >= _THINKING_COALESCE_MIN_CHARS + or elapsed_ms >= _THINKING_COALESCE_MAX_INTERVAL_MS + ): + self._flush_pending_thinking(responses) + self._last_thinking_flush_monotonic = now + # Other delta types (``signature_delta``, ``input_json_delta``) + # are CLI / tool-dispatch plumbing — not surfaced on the wire. + return + + if event_type == "content_block_stop": + index = raw.get("index") + if not isinstance(index, int): + return + block_type = self._block_types_by_index.pop(index, None) + if block_type == "text": + self._end_text_if_open(responses) + elif block_type == "thinking": + self._end_reasoning_if_open(responses) + return + + def _flush_pending_thinking(self, responses: list[StreamBaseResponse]) -> None: + """Drain the coalesce buffer into a ``StreamReasoningDelta``. 
+ + Separate from ``_end_reasoning_if_open`` because the coalesce + window can flush mid-block (threshold hit) without closing the + reasoning block. + """ + if not self._pending_thinking_delta: + return + responses.append( + StreamReasoningDelta( + id=self.reasoning_block_id, + delta=self._pending_thinking_delta, + ) + ) + self._partial_thinking_buffer += self._pending_thinking_delta + self._pending_thinking_delta = "" + self._pending_thinking_index = None + def _flush_unresolved_tool_calls(self, responses: list[StreamBaseResponse]) -> None: """Emit outputs for tool calls that didn't receive a UserMessage result. diff --git a/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py index 634454f9e5..6d59e21fab 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/response_adapter_test.py @@ -6,6 +6,7 @@ import pytest from claude_agent_sdk import ( AssistantMessage, ResultMessage, + StreamEvent, SystemMessage, TextBlock, ThinkingBlock, @@ -21,6 +22,7 @@ from backend.copilot.response_model import ( StreamFinishStep, StreamHeartbeat, StreamReasoningDelta, + StreamReasoningEnd, StreamStart, StreamStartStep, StreamTextDelta, @@ -136,6 +138,29 @@ def test_tool_use_emits_input_start_and_available(): assert results[2].input == {"q": "x"} +def test_tool_use_strips_whitespace_in_tool_name(): + adapter = _adapter() + msg = AssistantMessage( + content=[ + ToolUseBlock( + id="tool-1", + name=f" {MCP_TOOL_PREFIX}find_block", + input={}, + ) + ], + model="test", + ) + results = adapter.convert_message(msg) + tool_events = [ + r + for r in results + if isinstance(r, (StreamToolInputStart, StreamToolInputAvailable)) + ] + assert tool_events, "expected tool input events" + for event in tool_events: + assert event.toolName == "find_block" + + def test_text_then_tool_ends_text_block(): adapter = _adapter() text_msg = AssistantMessage(content=[TextBlock(text="thinking...")], model="test") @@ -298,6 +323,36 @@ def test_text_after_thinking_closes_reasoning_and_opens_text(): assert re_idx < ts_idx +def test_thinking_after_text_in_same_message_renders_reasoning_first(): + """Kimi K2.6 (and other non-Anthropic OpenRouter providers) place + ``reasoning`` AFTER the visible text in the response, so the SDK + builds an ``AssistantMessage`` with content = [TextBlock, ThinkingBlock]. + Without reordering, the UI would show the answer first and the + reasoning panel below it — the opposite of the natural reading + order Anthropic models produce. 
response_adapter must hoist + ThinkingBlocks to the front so ``reasoning-start/delta/end`` events + hit the SSE stream BEFORE the ``text-*`` events.""" + adapter = _adapter() + msg = AssistantMessage( + content=[ + TextBlock(text="63"), + ThinkingBlock(thinking="7 times 9 is 63", signature=""), + ], + model="test", + ) + results = adapter.convert_message(msg) + types = [type(r).__name__ for r in results] + # ReasoningStart must land before TextStart in the emitted stream + assert "StreamReasoningStart" in types + assert "StreamTextStart" in types + assert types.index("StreamReasoningStart") < types.index("StreamTextStart") + # ReasoningDelta payload is intact + assert any( + isinstance(r, StreamReasoningDelta) and r.delta == "7 times 9 is 63" + for r in results + ) + + def test_tool_use_after_thinking_closes_reasoning(): """Opening a tool also closes an open reasoning block.""" adapter = _adapter() @@ -331,6 +386,69 @@ def test_empty_thinking_block_is_ignored(): assert [type(r).__name__ for r in results] == ["StreamStartStep"] +def test_render_reasoning_in_ui_false_still_emits_adapter_events(): + """With the persist/render decoupling the adapter is flag-agnostic: + it always emits ``StreamReasoning*`` so the session transcript keeps a + durable reasoning record. Wire-level suppression when + ``render_reasoning_in_ui=False`` happens at the SDK service yield + boundary, not here — see + ``backend/copilot/sdk/service.py::_filter_reasoning_events``. + """ + adapter = SDKResponseAdapter( + message_id="m", + session_id="s", + render_reasoning_in_ui=False, + ) + msg = AssistantMessage( + content=[ThinkingBlock(thinking="plan", signature="sig")], + model="test", + ) + results = adapter.convert_message(msg) + types = [type(r).__name__ for r in results] + assert "StreamReasoningStart" in types + assert "StreamReasoningDelta" in types + + +def test_render_reasoning_off_text_after_thinking_still_closes_reasoning(): + """Adapter still emits a ``StreamReasoningEnd`` when text follows a + thinking block — decoupled from the render flag. The service layer + drops the reasoning events at yield time; the adapter's structural + open/close pairing must not depend on the flag or downstream filters + would see orphan reasoning starts on the persisted transcript. 
+ """ + adapter = SDKResponseAdapter( + message_id="m", + session_id="s", + render_reasoning_in_ui=False, + ) + adapter.convert_message( + AssistantMessage( + content=[ThinkingBlock(thinking="warming up", signature="sig")], + model="test", + ) + ) + results = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="hello")], model="test") + ) + types = [type(r).__name__ for r in results] + assert "StreamReasoningEnd" in types + assert "StreamTextStart" in types + assert "StreamTextDelta" in types + + +def test_render_reasoning_on_is_default(): + """Default is True — existing callers keep emitting reasoning events.""" + adapter = SDKResponseAdapter(message_id="m", session_id="s") + msg = AssistantMessage( + content=[ThinkingBlock(thinking="plan", signature="sig")], + model="test", + ) + results = adapter.convert_message(msg) + types = [type(r).__name__ for r in results] + assert "StreamReasoningStart" in types + assert "StreamReasoningDelta" in types + + def test_result_success_synthesizes_fallback_text_when_final_turn_is_thinking_only(): """If the model's last LLM call after a tool_result produced only a ThinkingBlock (no TextBlock), the UI would hang on the tool output @@ -992,3 +1110,318 @@ def test_end_text_if_open_no_op_after_text_already_ended(): second: list[StreamBaseResponse] = [] adapter._end_text_if_open(second) assert second == [] + + +# --------------------------------------------------------------------------- +# Partial-message streaming (CHAT_SDK_INCLUDE_PARTIAL_MESSAGES) +# Covers the 10 scenarios in docs/sdk-per-token-streaming-followup.md +# --------------------------------------------------------------------------- + + +def _stream_event(payload: dict) -> StreamEvent: + """Convenience constructor for a raw Anthropic StreamEvent payload.""" + return StreamEvent( + uuid="stream-evt", + session_id="session-1", + parent_tool_use_id=None, + event=payload, + ) + + +def _message_start() -> StreamEvent: + return _stream_event({"type": "message_start"}) + + +def _text_block_start(index: int) -> StreamEvent: + return _stream_event( + { + "type": "content_block_start", + "index": index, + "content_block": {"type": "text", "text": ""}, + } + ) + + +def _text_delta(index: int, text: str) -> StreamEvent: + return _stream_event( + { + "type": "content_block_delta", + "index": index, + "delta": {"type": "text_delta", "text": text}, + } + ) + + +def _thinking_block_start(index: int) -> StreamEvent: + return _stream_event( + { + "type": "content_block_start", + "index": index, + "content_block": {"type": "thinking", "thinking": ""}, + } + ) + + +def _thinking_delta(index: int, text: str) -> StreamEvent: + return _stream_event( + { + "type": "content_block_delta", + "index": index, + "delta": {"type": "thinking_delta", "thinking": text}, + } + ) + + +def _block_stop(index: int) -> StreamEvent: + return _stream_event({"type": "content_block_stop", "index": index}) + + +def _collect_text_deltas(responses): + return "".join(r.delta for r in responses if isinstance(r, StreamTextDelta)) + + +def _collect_reasoning_deltas(responses): + return "".join(r.delta for r in responses if isinstance(r, StreamReasoningDelta)) + + +class TestPartialMessageStreaming: + """Scenarios 1-10 from sdk-per-token-streaming-followup.md. 
+ + The adapter runs unconditionally in partial-aware mode — when the + flag ``CHAT_SDK_INCLUDE_PARTIAL_MESSAGES`` is off the CLI simply + never emits ``StreamEvent`` messages and the diff maps stay empty + (so the tail logic degrades to "emit the full summary content" + which is the pre-partial behaviour). + """ + + def test_partial_and_summary_agree_no_duplicate(self): + """Scenario 1: partial streams full text, summary matches exactly. + No duplicate emission, no truncation — full content reaches the + wire once.""" + adapter = _adapter() + full = "Hello world" + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_text_block_start(0), responses) + for chunk in ("Hello", " ", "world"): + adapter._handle_stream_event(_text_delta(0, chunk), responses) + adapter._handle_stream_event(_block_stop(0), responses) + # Summary arrives with the same full text + summary = adapter.convert_message( + AssistantMessage(content=[TextBlock(text=full)], model="test") + ) + combined = responses + summary + assert _collect_text_deltas(combined) == full + + def test_partial_short_summary_long_tail_emitted(self): + """Scenario 2 (the truncation bug we saw): partial emitted a + prefix of the real answer; summary has the full text. The + adapter must emit only the tail so no content is lost.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_text_block_start(0), responses) + for chunk in ("The user ", "seems confused. They sent"): + adapter._handle_stream_event(_text_delta(0, chunk), responses) + # Summary has the full, un-truncated content + full = ( + "The user seems confused. They sent a short greeting. " + "Let me offer them concrete options." + ) + summary = adapter.convert_message( + AssistantMessage(content=[TextBlock(text=full)], model="test") + ) + combined = responses + summary + assert _collect_text_deltas(combined) == full + + def test_partial_empty_summary_only(self): + """Scenario 3: no partial deltas (CLI emitted the block entirely + in the summary — short blocks, proxy buffering, encrypted + content). Summary carries the full text.""" + adapter = _adapter() + summary = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="short answer")], model="test") + ) + assert _collect_text_deltas(summary) == "short answer" + + def test_partial_long_summary_matches_no_double_emit(self): + """Scenario 4 (most common): partial streams everything, summary + repeats the same content. No duplication on the wire.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + full = "Here is a long paragraph with several words in it." + adapter._handle_stream_event(_text_block_start(0), responses) + # Partition into chunks that *exactly* reconstruct ``full`` — a + # word-split with trailing spaces would emit more content than + # the summary carries and the reconcile would correctly flag + # divergence. + chunks = [full[:13], full[13:25], full[25:]] + assert "".join(chunks) == full + for chunk in chunks: + adapter._handle_stream_event(_text_delta(0, chunk), responses) + adapter._handle_stream_event(_block_stop(0), responses) + assert _collect_text_deltas(responses) == full + + summary = adapter.convert_message( + AssistantMessage(content=[TextBlock(text=full)], model="test") + ) + # Summary must not add any TextDelta since partial already covered it + assert _collect_text_deltas(summary) == "" + + def test_partial_diverges_summary_wins(self): + """Scenario 5: partial content isn't a prefix of the summary. 
+ Defensive path emits the full summary content — content must + not silently disappear.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_text_block_start(0), responses) + adapter._handle_stream_event(_text_delta(0, "first draft"), responses) + # Summary has totally different content (proxy rewrote it) + summary = adapter.convert_message( + AssistantMessage( + content=[TextBlock(text="final polished answer")], + model="test", + ) + ) + # The summary's text must reach the wire even though partial + # already emitted "first draft" (which was the proxy's draft). + assert "final polished answer" in _collect_text_deltas(responses + summary) + + def test_thinking_only_partial_coalesced(self): + """Scenario 6a (thinking-only permutation): a run of + ``thinking_delta`` events below the coalesce threshold flushes + at ``content_block_stop`` so the reasoning tail isn't lost.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_thinking_block_start(0), responses) + # Each chunk is well under the 64-char threshold + for chunk in ("Let ", "me ", "think"): + adapter._handle_stream_event(_thinking_delta(0, chunk), responses) + # At stop, the pending buffer drains + adapter._handle_stream_event(_block_stop(0), responses) + assert _collect_reasoning_deltas(responses) == "Let me think" + # Block closed + assert any(isinstance(r, StreamReasoningEnd) for r in responses) + + def test_text_only_via_partial_and_summary(self): + """Scenario 6b (text-only permutation): partial fills a block, + summary matches — see scenario 4 for no-double-emit assertion.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_text_block_start(0), responses) + adapter._handle_stream_event(_text_delta(0, "hi"), responses) + adapter._handle_stream_event(_block_stop(0), responses) + assert _collect_text_deltas(responses) == "hi" + + def test_mixed_text_then_thinking_partial_preserves_order(self): + """Scenario 6c (mixed, Anthropic order — reasoning then text). + When partial emits blocks in natural order and summary matches, + the wire order is identical to emission order.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + # Anthropic-shape: thinking index 0, text index 1 + adapter._handle_stream_event(_thinking_block_start(0), responses) + adapter._handle_stream_event( + _thinking_delta(0, "X" * 80), responses + ) # over threshold + adapter._handle_stream_event(_block_stop(0), responses) + adapter._handle_stream_event(_text_block_start(1), responses) + adapter._handle_stream_event(_text_delta(1, "answer"), responses) + adapter._handle_stream_event(_block_stop(1), responses) + types = [type(r).__name__ for r in responses] + # ReasoningStart must come before TextStart — partial streams in + # the CLI's natural order, which is also the UI's desired order. + assert types.index("StreamReasoningStart") < types.index("StreamTextStart") + + def test_multi_message_turn_resets_per_index_maps(self): + """Scenario 7: tool-use loop creates multiple AssistantMessages + per turn. 
Anthropic content-block indices are scoped to a single + message — ``message_start`` must reset the diff maps so the next + message's index-0 text isn't silently suppressed.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + # First message at index 0 = "first" + adapter._handle_stream_event(_message_start(), responses) + adapter._handle_stream_event(_text_block_start(0), responses) + adapter._handle_stream_event(_text_delta(0, "first"), responses) + adapter._handle_stream_event(_block_stop(0), responses) + # New message starts — index 0 now refers to a fresh block + adapter._handle_stream_event(_message_start(), responses) + adapter._handle_stream_event(_text_block_start(0), responses) + adapter._handle_stream_event(_text_delta(0, "second"), responses) + adapter._handle_stream_event(_block_stop(0), responses) + # Both texts must land on the wire + assert _collect_text_deltas(responses) == "firstsecond" + + def test_empty_thinking_with_signature_emits_nothing(self): + """Scenario 8: encrypted / empty thinking block. Partial emits + nothing, summary carries ``block.thinking == ""`` with a + signature — the adapter must not open a reasoning block.""" + adapter = _adapter() + summary = adapter.convert_message( + AssistantMessage( + content=[ThinkingBlock(thinking="", signature="sig")], + model="test", + ) + ) + # No reasoning events should be emitted for empty thinking + reasoning_events = [ + r + for r in summary + if isinstance(r, StreamReasoningDelta) + or type(r).__name__ in ("StreamReasoningStart", "StreamReasoningEnd") + ] + assert reasoning_events == [] + + def test_thinking_tail_drains_on_block_stop(self): + """Scenario 10: a thinking_delta chunk smaller than the 64-char + threshold arrives, then ``content_block_stop``. The tail text + must emit in a final ``StreamReasoningDelta`` BEFORE + ``StreamReasoningEnd``.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_thinking_block_start(0), responses) + # One small chunk well under 64 chars + adapter._handle_stream_event(_thinking_delta(0, "tiny chunk"), responses) + # Block stop must flush the pending buffer + adapter._handle_stream_event(_block_stop(0), responses) + types = [type(r).__name__ for r in responses] + # The final ReasoningDelta must precede ReasoningEnd + rd_idx = types.index("StreamReasoningDelta") + re_idx = types.index("StreamReasoningEnd") + assert rd_idx < re_idx + assert _collect_reasoning_deltas(responses) == "tiny chunk" + + def test_thinking_coalesces_on_char_threshold(self): + """Extra: thinking_delta accumulating past 64 chars flushes + mid-block without waiting for block_stop (coalesce threshold).""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_thinking_block_start(0), responses) + # One 80-char chunk trips the threshold on a single event + adapter._handle_stream_event(_thinking_delta(0, "x" * 80), responses) + # A ReasoningDelta must already have been emitted (not buffered + # until block_stop). 
+ assert any(isinstance(r, StreamReasoningDelta) for r in responses) + + +# --------------------------------------------------------------------------- +# Partial/summary reconcile — summary walk must not duplicate partial content +# --------------------------------------------------------------------------- + + +def test_summary_walk_skips_fully_streamed_text(): + """If the partial stream delivered the entire TextBlock, the summary + walk must not emit a second ``StreamTextDelta`` for the same block.""" + adapter = _adapter() + responses: list[StreamBaseResponse] = [] + adapter._handle_stream_event(_text_block_start(0), responses) + adapter._handle_stream_event(_text_delta(0, "complete answer"), responses) + adapter._handle_stream_event(_block_stop(0), responses) + # Summary arrives with matching content + summary = adapter.convert_message( + AssistantMessage(content=[TextBlock(text="complete answer")], model="test") + ) + # Partial path emitted exactly one StreamTextDelta + partial_deltas = [r for r in responses if isinstance(r, StreamTextDelta)] + summary_deltas = [r for r in summary if isinstance(r, StreamTextDelta)] + assert len(partial_deltas) == 1 + assert summary_deltas == [] diff --git a/autogpt_platform/backend/backend/copilot/sdk/service.py b/autogpt_platform/backend/backend/copilot/sdk/service.py index 908e2aebdd..1ce2ede6b8 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service.py @@ -27,6 +27,7 @@ from claude_agent_sdk import ( ClaudeAgentOptions, ClaudeSDKClient, ResultMessage, + StreamEvent, TextBlock, ThinkingBlock, ToolResultBlock, @@ -56,6 +57,11 @@ from ..constants import ( from ..session_cleanup import prune_orphan_tool_calls from ..context import encode_cwd_for_cli, get_workspace_manager from ..graphiti.config import is_enabled_for_user +from ..model_router import resolve_model +from ..moonshot import ( + is_moonshot_model as _is_moonshot_model, + override_cost_usd as _override_cost_for_moonshot, +) from ..model import ( ChatMessage, ChatSession, @@ -109,6 +115,7 @@ from ..service import ( ) from ..thinking_stripper import ThinkingStripper from ..token_tracking import persist_and_record_usage +from ..tools import ToolGroup from ..tools.e2b_sandbox import get_or_create_sandbox, pause_sandbox_direct from ..tools.sandbox import WORKSPACE_PREFIX, make_session_path from ..tracking import track_user_message @@ -129,6 +136,7 @@ from ..transcript import ( from ..transcript_builder import TranscriptBuilder from .compaction import CompactionTracker, filter_compaction_messages from .env import build_sdk_env # noqa: F401 — re-export for backward compat +from .openrouter_cost import record_turn_cost_from_openrouter from .response_adapter import SDKResponseAdapter from .security_hooks import create_security_hooks from .tool_adapter import ( @@ -364,6 +372,20 @@ class _RetryState: # ``detect_gap`` picks them up as gap-fill entries instead of assuming the # JSONL already covers them. midturn_user_rows: int = 0 + # OpenRouter generation IDs collected across all attempts of this turn. + # Populated from ``AssistantMessage.message_id`` when routed via + # OpenRouter (``gen-...`` prefix). Consumed by the finally block to + # fire ``record_turn_cost_from_openrouter`` for non-Anthropic models — + # the CLI's static-Anthropic-priced estimate is replaced with the + # authoritative ``/generation`` total_cost. Lives on ``_RetryState`` + # (not per-attempt ``_StreamAccumulator``) so it survives retries. 
+ generation_ids: list[str] = dataclass_field(default_factory=list) + # The *actually executed* model observed on ``AssistantMessage.model`` — + # differs from ``state.options.model`` (the requested primary) when + # ``_resolve_fallback_model`` swaps to a fallback mid-attempt. The + # Moonshot cost override gates on this so a Moonshot-→-Anthropic + # fallback doesn't get mis-billed at Moonshot rates, and vice versa. + observed_model: str | None = None @dataclass @@ -678,35 +700,51 @@ async def _iter_sdk_messages( def _normalize_model_name(raw_model: str) -> str: """Normalize a model name for the current routing configuration. - Applies two transformations shared by both the primary and fallback - model resolution paths: + Two routing modes: - 1. **Strip provider prefix** — OpenRouter-style names like - ``"anthropic/claude-opus-4.6"`` are reduced to ``"claude-opus-4.6"``. - 2. **Dot-to-hyphen conversion** — when *not* routing through OpenRouter - the direct Anthropic API requires hyphen-separated versions - (``"claude-opus-4-6"``), so dots are replaced with hyphens. + 1. **OpenRouter active** — the canonical OpenRouter slug is + ``"/"`` (e.g. ``"anthropic/claude-opus-4.6"``, + ``"moonshotai/kimi-k2.6"``). Pass the prefixed name through + unchanged so OpenRouter can route to the correct provider. Anthropic + names happen to also resolve when stripped, but non-Anthropic vendors + (Moonshot, Google, etc.) do not — keeping the prefix is the only form + that works for every model in the catalog. + 2. **Direct Anthropic** — strip the OpenRouter ``anthropic/`` prefix + and convert dots to hyphens (``"claude-opus-4.6"`` → + ``"claude-opus-4-6"``) since the Anthropic Messages API rejects + both the prefix and dot-separated versions. Raises ``ValueError`` + when a non-Anthropic vendor slug is paired with direct-Anthropic + mode — silently stripping ``moonshotai/`` would send ``kimi-k2.6`` + to the Anthropic API and produce an opaque ``model_not_found`` + error far from the misconfiguration source. """ + if config.openrouter_active: + return raw_model model = raw_model if "/" in model: - model = model.split("/", 1)[1] - # OpenRouter uses dots in versions (claude-opus-4.6) but the direct - # Anthropic API requires hyphens (claude-opus-4-6). Only normalise - # when NOT routing through OpenRouter. - if not config.openrouter_active: - model = model.replace(".", "-") - return model + vendor, model = model.split("/", 1) + if vendor != "anthropic": + raise ValueError( + f"Direct-Anthropic mode (use_openrouter=False or missing " + f"OpenRouter credentials) requires an Anthropic model, got " + f"vendor={vendor!r} from model={raw_model!r}. Set " + f"CHAT_THINKING_STANDARD_MODEL/CHAT_THINKING_ADVANCED_MODEL " + f"to an anthropic/* slug, or enable OpenRouter." + ) + return model.replace(".", "-") def _resolve_sdk_model() -> str | None: - """Resolve the model name for the Claude Agent SDK CLI. + """Resolve the SDK-CLI model name from static config (no LD lookup). - Uses `config.claude_agent_model` if set, otherwise derives from - `config.thinking_standard_model` via :func:`_normalize_model_name`. + ``config.claude_agent_model`` is an explicit override that wins + unconditionally. When the Claude Code subscription is enabled and no + override is set, returns ``None`` so the CLI picks the model for the + user's subscription plan. Otherwise derives from + ``config.thinking_standard_model``. 
- When `use_claude_code_subscription` is enabled and no explicit - `claude_agent_model` is set, returns `None` so the CLI uses the - default model for the user's subscription plan. + For per-user routing (LaunchDarkly overrides), see + :func:`_resolve_sdk_model_for_request`. """ if config.claude_agent_model: return config.claude_agent_model @@ -715,6 +753,18 @@ def _resolve_sdk_model() -> str | None: return _normalize_model_name(config.thinking_standard_model) +async def _resolve_thinking_model_for_user( + tier: "CopilotLlmModel", + user_id: str | None, +) -> str: + """LD-aware thinking-tier model pick for a specific user. + + Consults ``copilot-thinking-{tier}-model`` and falls back to the + ``ChatConfig`` default on missing user / missing flag. + """ + return await resolve_model("thinking", tier, user_id, config=config) + + def _resolve_fallback_model() -> str | None: """Resolve the fallback model name via :func:`_normalize_model_name`. @@ -729,37 +779,94 @@ def _resolve_fallback_model() -> str | None: async def _resolve_sdk_model_for_request( model: "CopilotLlmModel | None", session_id: str, + user_id: str | None = None, ) -> str | None: """Resolve the SDK model string for a turn. Priority (highest first): - 1. Explicit per-request ``model`` tier from the frontend toggle. - 2. Global config default (``_resolve_sdk_model()``). - - Returns ``None`` when the Claude Code subscription default applies. - Rate-limit accounting no longer applies a multiplier — the real turn - cost (reported by the SDK) already reflects model-pricing differences. + 1. ``config.claude_agent_model`` — unconditional override, bypasses LD. + 2. LaunchDarkly ``copilot-thinking-{tier}-model`` if it serves a value + different from the config default for *user_id*. An LD-served + override wins over subscription mode so admins can route specific + users to a specific model without flipping subscription on/off. + 3. ``config.use_claude_code_subscription`` on the standard tier — + returns ``None`` so the CLI picks the subscription default (this + branch fires when LD has no opinion, i.e. the value equals the + config default). + 4. ``ChatConfig`` static default for the tier. """ - if model == "advanced": - sdk_model = _normalize_model_name(config.thinking_advanced_model) + if config.claude_agent_model: + return config.claude_agent_model + + tier_name: "CopilotLlmModel" = "advanced" if model == "advanced" else "standard" + # Strip at read time so a stray trailing space in ``CHAT_*_MODEL`` (a + # common ``.env`` pitfall) doesn't make the ``resolved == tier_default`` + # comparison below spuriously diverge — ``resolve_model`` already strips + # the LD side, so both halves must end up whitespace-normalised to stay + # equal when they're semantically equal. Downstream ``_normalize_model_name`` + # also benefits from the strip. + tier_default = ( + config.thinking_advanced_model + if tier_name == "advanced" + else config.thinking_standard_model + ).strip() + + resolved = await _resolve_thinking_model_for_user(tier_name, user_id) + + # Subscription mode on standard tier only wins when LD has no opinion + # (value == config default ⇒ admin hasn't explicitly pointed this + # user somewhere). Any LD override — even to the same value with + # stripped whitespace normalised — is an explicit admin choice that + # must be honoured. Without this, a subscription-mode deployment + # silently ignores the ``copilot-thinking-standard-model`` flag + # entirely, which defeats the point of cohort-based routing. 
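# Decision table (editor's illustration, standard tier; the advanced tier
# never takes the subscription branch):
#   LD value == config default, subscription on   -> None (CLI subscription default)
#   LD value == config default, subscription off  -> config default, normalized
#   LD value != config default                    -> LD value, normalized — wins
#                                                    regardless of subscription mode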
+ ld_overrides_default = resolved != tier_default + if ( + not ld_overrides_default + and tier_name == "standard" + and config.use_claude_code_subscription + ): logger.info( - "[SDK] [%s] Per-request model override: advanced (%s)", + "[SDK] [%s] Subscription default (tier=standard, LD unset)", session_id[:12] if session_id else "?", + ) + return None + try: + sdk_model = _normalize_model_name(resolved) + except ValueError as exc: + # The per-user LD value didn't pass ``_normalize_model_name``'s + # vendor check (most commonly: a ``moonshotai/kimi-*`` slug on a + # direct-Anthropic deployment that has no OpenRouter route). Fail + # soft to the TIER-SPECIFIC config default — using the generic + # ``_resolve_sdk_model()`` here would pin advanced-tier requests to + # ``thinking_standard_model`` (Sonnet) whenever LD misconfigures + # the advanced cell, silently downgrading the user's chosen tier. + try: + sdk_model = _normalize_model_name(tier_default) + except ValueError: + # Config default is *also* invalid for the active routing + # mode — this is a deployment-level misconfig that the + # ``model_validator`` should catch at startup. Re-raise the + # original LD error so the issue surfaces loudly rather than + # returning something misleading. + raise exc + logger.warning( + "[SDK] [%s] LD model %r rejected for tier=%s (%s); falling " + "back to tier default %s", + session_id[:12] if session_id else "?", + resolved, + tier_name, + exc, sdk_model, ) return sdk_model - - if model == "standard": - # Reset to config default — respects subscription mode (None = CLI default). - sdk_model = _resolve_sdk_model() - logger.info( - "[SDK] [%s] Per-request model override: standard (%s)", - session_id[:12] if session_id else "?", - sdk_model or "subscription-default", - ) - return sdk_model - - return _resolve_sdk_model() + logger.info( + "[SDK] [%s] Resolved model for tier=%s: %s", + session_id[:12] if session_id else "?", + tier_name, + sdk_model, + ) + return sdk_model _MAX_TRANSIENT_BACKOFF_SECONDS = 30 @@ -823,7 +930,11 @@ async def _do_transient_backoff( """ yield StreamStatus(message=f"Connection interrupted, retrying in {backoff}s…") await asyncio.sleep(backoff) - state.adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id) + state.adapter = SDKResponseAdapter( + message_id=message_id, + session_id=session_id, + render_reasoning_in_ui=config.render_reasoning_in_ui, + ) state.usage.reset() @@ -2084,6 +2195,33 @@ async def _run_stream_attempt( len(state.adapter.resolved_tool_calls), ) + # Capture OpenRouter generation IDs from each + # ``AssistantMessage.message_id`` — when routed via OpenRouter + # these are ``gen-...`` slugs we can use post-turn to query + # ``/api/v1/generation?id=`` for the authoritative per-turn + # cost and token counts (the CLI's ``total_cost_usd`` is + # computed from a static Anthropic pricing table that + # silently over-bills non-Anthropic routes). Direct-Anthropic + # turns produce ``msg_...`` IDs which the generation endpoint + # doesn't know about — harmlessly ignored at reconcile time. + if isinstance(sdk_msg, AssistantMessage): + msg_id = sdk_msg.message_id + if ( + msg_id is not None + and msg_id.startswith("gen-") + and msg_id not in state.generation_ids + ): + state.generation_ids.append(msg_id) + # Track the model the SDK actually used — when a fallback + # activates, this differs from ``state.options.model``. 
+ # Consumed by the Moonshot cost-override decision so we + # don't mis-bill a fallback-Anthropic response at + # Moonshot rates (or a fallback-Moonshot at Anthropic + # rates). + observed = getattr(sdk_msg, "model", None) + if isinstance(observed, str) and observed: + state.observed_model = observed + # Log AssistantMessage API errors (e.g. invalid_request) # so we can debug Anthropic API 400s surfaced by the CLI. sdk_error = getattr(sdk_msg, "error", None) @@ -2252,7 +2390,37 @@ async def _run_stream_attempt( state.usage.completion_tokens, ) if sdk_msg.total_cost_usd is not None: - state.usage.cost_usd = sdk_msg.total_cost_usd + # Default: trust the CLI-reported value. Accurate for + # Anthropic models (the CLI's bundled pricing table is + # Anthropic-authored), and becomes the sync-path cost + # when the reconcile is disabled or fails. + # Prefer the ACTUALLY executed model + # (``state.observed_model`` from ``AssistantMessage.model``) + # over the requested primary (``state.options.model``) + # so a fallback activation doesn't mis-route pricing. + active_model = state.observed_model or getattr( + state.options, "model", None + ) + if _is_moonshot_model(active_model): + # Moonshot slug — the CLI doesn't know Moonshot's + # rate card and silently bills at Sonnet rates + # (~5x over-charge). Replace with the rate-card + # estimate so the in-stream ``cost_usd`` and the + # reconcile's lookup-fail fallback reflect + # reality. Reconcile + # (``record_turn_cost_from_openrouter``) still + # overrides this value when every gen-ID lookup + # succeeds. + state.usage.cost_usd = _override_cost_for_moonshot( + model=active_model, + sdk_reported_usd=sdk_msg.total_cost_usd, + prompt_tokens=state.usage.prompt_tokens, + completion_tokens=state.usage.completion_tokens, + cache_read_tokens=state.usage.cache_read_tokens, + cache_creation_tokens=state.usage.cache_creation_tokens, + ) + else: + state.usage.cost_usd = sdk_msg.total_cost_usd # Emit compaction end if SDK finished compacting. # Sync TranscriptBuilder with the CLI's active context. @@ -2374,6 +2542,18 @@ async def _run_stream_attempt( skip_strip=response is tail_delta, ) if dispatched is not None: + # Persistence (via _dispatch_response) always runs so the + # session transcript keeps role='reasoning' rows; the + # wire is gated so UI can suppress rendering. + if not state.adapter.render_reasoning_in_ui and isinstance( + dispatched, + ( + StreamReasoningStart, + StreamReasoningDelta, + StreamReasoningEnd, + ), + ): + continue yield dispatched # Mid-turn follow-up persistence: the MCP tool wrapper drains @@ -2434,16 +2614,36 @@ async def _run_stream_attempt( # flush the assistant message before tool_calls are set on it # (text and tool_use arrive as separate SDK events), the # tool_calls update is lost — the next flush starts past it. - _msgs_since_flush += 1 + # + # With ``include_partial_messages=True`` the CLI delivers + # hundreds of ``StreamEvent`` messages per turn — incrementing + # ``_msgs_since_flush`` on each one trips the threshold long + # before the assistant text is complete, saving a truncated + # prefix that subsequent deltas can never extend (append-only). + # Count only messages that produce a persisted row boundary + # (AssistantMessage, UserMessage, ResultMessage) and skip + # raw StreamEvents. Also skip when text or reasoning is + # still in-flight on the adapter: the row is live and a flush + # would lock it at its current length. 
+ if not isinstance(sdk_msg, StreamEvent): + _msgs_since_flush += 1 now = time.monotonic() has_pending_tools = ( acc.has_appended_assistant and acc.accumulated_tool_calls and not acc.has_tool_results ) - if not has_pending_tools and ( - _msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD - or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS + adapter = state.adapter + has_open_block = ( + adapter.has_started_text and not adapter.has_ended_text + ) or (adapter.has_started_reasoning and not adapter.has_ended_reasoning) + if ( + not has_pending_tools + and not has_open_block + and ( + _msgs_since_flush >= _FLUSH_MESSAGE_THRESHOLD + or (now - _last_flush_time) >= _FLUSH_INTERVAL_SECONDS + ) ): try: await asyncio.shield(upsert_chat_session(ctx.session)) @@ -2920,6 +3120,12 @@ async def stream_chat_completion_sdk( # Defaults ensure the finally block can always reference these safely even when # an early return (e.g. sdk_cwd error) skips their normal assignment below. sdk_model: str | None = None + # Wall-clock timestamp captured before the CLI runs so the + # OpenRouter reconcile can filter subagent JSONLs by mtime — only + # files created during THIS turn contribute gen-IDs. Without this + # the sweep would pick up prior turns' compaction files that persist + # under ``/subagents/``, double-billing the user. + turn_start_ts = time.time() # Make sure there is no more code between the lock acquisition and try-block. try: @@ -3043,8 +3249,8 @@ async def stream_chat_completion_sdk( mcp_server = create_copilot_mcp_server(use_e2b=use_e2b) - # Resolve model (request tier → config default). - sdk_model = await _resolve_sdk_model_for_request(model, session_id) + # Resolve model (request tier → LD per-user override → config default). + sdk_model = await _resolve_sdk_model_for_request(model, session_id, user_id) # Track SDK-internal compaction (PreCompact hook → start, next msg → end) compaction = CompactionTracker() @@ -3056,10 +3262,18 @@ async def stream_chat_completion_sdk( on_compact=compaction.on_compact, ) + disabled_tool_groups: list[ToolGroup] = [] + if not graphiti_enabled: + disabled_tool_groups.append("graphiti") + if permissions is not None: - allowed, disallowed = apply_tool_permissions(permissions, use_e2b=use_e2b) + allowed, disallowed = apply_tool_permissions( + permissions, use_e2b=use_e2b, disabled_groups=disabled_tool_groups + ) else: - allowed = get_copilot_tool_names(use_e2b=use_e2b) + allowed = get_copilot_tool_names( + use_e2b=use_e2b, disabled_groups=disabled_tool_groups + ) disallowed = get_sdk_disallowed_tools(use_e2b=use_e2b) def _on_stderr(line: str) -> None: @@ -3129,6 +3343,17 @@ async def stream_chat_completion_sdk( sdk_options_kwargs["effort"] = config.claude_agent_thinking_effort if sdk_model: sdk_options_kwargs["model"] = sdk_model + if config.sdk_include_partial_messages: + # Opt into per-token streaming — the CLI emits raw Anthropic + # ``content_block_delta`` events as ``StreamEvent`` messages + # ahead of each summary ``AssistantMessage`` so reasoning and + # text land on the wire token-by-token (matching the baseline + # path's UX shipped in #12873). ``SDKResponseAdapter`` consumes + # the partial stream via ``_handle_stream_event`` and emits + # only the tail diff from the subsequent summary, so content + # never double-emits and a summary-only short block still + # reaches the UI. 
+ sdk_options_kwargs["include_partial_messages"] = True if sdk_env: sdk_options_kwargs["env"] = sdk_env @@ -3160,7 +3385,11 @@ async def stream_chat_completion_sdk( options = ClaudeAgentOptions(**sdk_options_kwargs) # type: ignore[arg-type] # dynamic kwargs - adapter = SDKResponseAdapter(message_id=message_id, session_id=session_id) + adapter = SDKResponseAdapter( + message_id=message_id, + session_id=session_id, + render_reasoning_in_ui=config.render_reasoning_in_ui, + ) # Propagate user_id/session_id as OTEL context attributes so the # langsmith tracing integration attaches them to every span. This @@ -3494,7 +3723,9 @@ async def stream_chat_completion_sdk( session, user_id, is_user_message, state.query_message ) state.adapter = SDKResponseAdapter( - message_id=message_id, session_id=session_id + message_id=message_id, + session_id=session_id, + render_reasoning_in_ui=config.render_reasoning_in_ui, ) # Reset token accumulators so a failed attempt's partial # usage is not double-counted in the successful attempt. @@ -3854,18 +4085,103 @@ async def stream_chat_completion_sdk( # --- Persist token usage to session + rate-limit counters --- # Both must live in finally so they stay consistent even when an # exception interrupts the try block after StreamUsage was yielded. - await persist_and_record_usage( - session=session, - user_id=user_id, - prompt_tokens=turn_prompt_tokens, - completion_tokens=turn_completion_tokens, - cache_read_tokens=turn_cache_read_tokens, - cache_creation_tokens=turn_cache_creation_tokens, - log_prefix=log_prefix, - cost_usd=turn_cost_usd, - model=sdk_model or config.thinking_standard_model, - provider="anthropic", + effective_model = sdk_model or config.thinking_standard_model + # ``state`` is populated lazily inside the retry loop; when the + # turn exits before the first attempt runs (e.g. very early + # validation error) it's still None, so ``generation_ids`` is + # empty by definition. + collected_gen_ids: list[str] = ( + list(state.generation_ids) if state is not None else [] ) + _use_openrouter_reconcile = bool( + config.openrouter_active + and config.sdk_reconcile_openrouter_cost + and collected_gen_ids + ) + + # CLI project dir — used by the reconcile task to sweep for + # compaction subagents' gen-IDs. ``sdk_cwd`` is the per-session + # CLI working directory; the CLI encodes it into the project-dir + # name the same way ``encode_cwd_for_cli`` does, and writes + # the main transcript + any ``subagents/`` alongside it under + # ``~/.claude/projects//``. Empty when sdk_cwd isn't + # set (shouldn't happen in practice for SDK turns). + cli_project_dir: str | None = None + if sdk_cwd: + cli_project_dir = os.path.join( + os.path.expanduser("~/.claude/projects"), + encode_cwd_for_cli(sdk_cwd), + ) + + if _use_openrouter_reconcile: + # Defer the single cost-and-rate-limit write to a background + # task that queries OpenRouter's authoritative + # ``/generation?id=`` for every round in this turn. Covers + # all vendors: + # + # * Non-Anthropic (Kimi et al): the CLI's ``total_cost_usd`` + # is computed from a static Anthropic rate table that + # doesn't know the model — silently over-bills by ~5x. + # The reconcile replaces it with OpenRouter's real bill. + # * Anthropic via OpenRouter: the CLI's number matches + # Anthropic's own rates penny-for-penny in the common + # case, but the reconcile catches any rate change the + # CLI binary hasn't picked up and any OpenRouter-side + # divergence (cache-discount accounting, promo pricing). 
+ # + # The task calls ``persist_and_record_usage`` exactly once + # per turn — same method as the sync path, so append-only + # cost-log + rate-limit counter update together. The sync + # path below is skipped entirely when the reconcile fires, + # so no double-counting. Kill-switch: + # ``CHAT_SDK_RECONCILE_OPENROUTER_COST=false``. + # + # Brief window (~0.5-2s) where the rate-limit counter is + # unaware of this turn — back-to-back turns in that window + # see a stale counter. + asyncio.create_task( + record_turn_cost_from_openrouter( + session=session, + user_id=user_id, + model=effective_model, + prompt_tokens=turn_prompt_tokens, + completion_tokens=turn_completion_tokens, + cache_read_tokens=turn_cache_read_tokens, + cache_creation_tokens=turn_cache_creation_tokens, + generation_ids=collected_gen_ids, + cli_project_dir=cli_project_dir, + cli_session_id=session_id, + turn_start_ts=turn_start_ts, + fallback_cost_usd=turn_cost_usd, + api_key=config.api_key, + log_prefix=log_prefix, + ) + ) + else: + # Reconcile disabled, OpenRouter inactive, or subscription + # path (no gen-IDs). Record the SDK CLI's + # ``total_cost_usd`` synchronously: accurate for Anthropic + # (same rate card as billing); for non-Anthropic it's the + # rate-card estimate that ``_override_cost_for_non_anthropic`` + # caps (still 1.5-2x off vs real OpenRouter bill, but much + # closer than the ~5x Sonnet-rate fallback). + await persist_and_record_usage( + session=session, + user_id=user_id, + prompt_tokens=turn_prompt_tokens, + completion_tokens=turn_completion_tokens, + cache_read_tokens=turn_cache_read_tokens, + cache_creation_tokens=turn_cache_creation_tokens, + log_prefix=log_prefix, + cost_usd=turn_cost_usd, + model=effective_model, + # ``provider`` labels the cost-analytics row; the cost + # value still comes from the SDK-reported number. + # Tracks the actual upstream so the row matches reality: + # OpenRouter when ``openrouter_active``, Anthropic + # otherwise. + provider=("open_router" if config.openrouter_active else "anthropic"), + ) # --- Persist session messages --- # This MUST run in finally to persist messages even when the generator diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py index 4eb5bc4ac2..0146fe53f1 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_helpers_test.py @@ -366,38 +366,84 @@ class TestNormalizeModelName: The per-request model toggle calls _normalize_model_name with either ``config.thinking_advanced_model`` (for 'advanced') or ``config.thinking_standard_model`` (for 'standard'). These tests verify - the OpenRouter/provider-prefix stripping that keeps the value compatible - with the Claude CLI. + the OpenRouter/direct-Anthropic split: OpenRouter routes by full + ``vendor/model`` slug, while direct-Anthropic strips the prefix and + converts dots to hyphens. """ - def test_strips_anthropic_prefix(self): + @pytest.fixture + def _direct_anthropic_config(self, monkeypatch: pytest.MonkeyPatch): + """Force ``config.openrouter_active = False`` for prefix-strip tests. + + Pins the SDK model fields to anthropic/* so the new + ``_validate_sdk_model_vendor_compatibility`` model_validator + permits ChatConfig construction. 
+ """ + from backend.copilot import config as cfg_mod + + cfg = cfg_mod.ChatConfig( + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=False, + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4-7", + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + @pytest.fixture + def _openrouter_config(self, monkeypatch: pytest.MonkeyPatch): + """Force ``config.openrouter_active = True`` for slug-preservation tests.""" + from backend.copilot import config as cfg_mod + + cfg = cfg_mod.ChatConfig( + use_openrouter=True, + api_key="or-key", + base_url="https://openrouter.ai/api/v1", + use_claude_code_subscription=False, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + def test_strips_anthropic_prefix(self, _direct_anthropic_config): assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6" - def test_strips_openai_prefix(self): - assert _normalize_model_name("openai/gpt-4o") == "gpt-4o" + def test_rejects_non_anthropic_vendor_in_direct_mode( + self, _direct_anthropic_config + ): + """Direct-Anthropic mode must fail loudly on non-Anthropic vendor + slugs — silent strip would send e.g. ``gpt-4o`` to the Anthropic + API and produce an opaque model_not_found error.""" + with pytest.raises(ValueError, match="requires an Anthropic model"): + _normalize_model_name("openai/gpt-4o") + with pytest.raises(ValueError, match="requires an Anthropic model"): + _normalize_model_name("moonshotai/kimi-k2.6") + with pytest.raises(ValueError, match="requires an Anthropic model"): + _normalize_model_name("google/gemini-2.5-flash") - def test_strips_google_prefix(self): - assert _normalize_model_name("google/gemini-2.5-flash") == "gemini-2.5-flash" - - def test_already_normalized_unchanged(self): + def test_already_normalized_unchanged(self, _direct_anthropic_config): assert ( _normalize_model_name("claude-sonnet-4-20250514") == "claude-sonnet-4-20250514" ) - def test_empty_string_unchanged(self): + def test_empty_string_unchanged(self, _direct_anthropic_config): assert _normalize_model_name("") == "" - def test_opus_model_roundtrip(self): - """The exact string used for the 'opus' toggle strips correctly.""" - assert _normalize_model_name("anthropic/claude-opus-4-6") == "claude-opus-4-6" + def test_opus_model_dot_to_hyphen(self, _direct_anthropic_config): + """Direct-Anthropic mode: dots in versions become hyphens.""" + assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6" - def test_sonnet_openrouter_model(self): - """Sonnet model as stored in config (OpenRouter-prefixed) strips cleanly.""" + def test_openrouter_keeps_anthropic_slug(self, _openrouter_config): + """OpenRouter routes by full slug — keep prefix and dots intact.""" assert ( - _normalize_model_name("anthropic/claude-sonnet-4-6") == "claude-sonnet-4-6" + _normalize_model_name("anthropic/claude-sonnet-4.6") + == "anthropic/claude-sonnet-4.6" ) + def test_openrouter_keeps_kimi_slug(self, _openrouter_config): + """Non-Anthropic vendors (Moonshot) require the prefix to route.""" + assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6" + # --------------------------------------------------------------------------- # _TokenUsage — null-safe accumulation (OpenRouter initial-stream-event bug) diff --git a/autogpt_platform/backend/backend/copilot/sdk/service_test.py b/autogpt_platform/backend/backend/copilot/sdk/service_test.py index 619fce3017..82d6ff7e60 100644 --- 
a/autogpt_platform/backend/backend/copilot/sdk/service_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/service_test.py @@ -4,6 +4,7 @@ import asyncio import base64 import os from dataclasses import dataclass +from types import SimpleNamespace from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -17,6 +18,7 @@ from .service import ( _normalize_model_name, _prepare_file_attachments, _resolve_sdk_model, + _resolve_sdk_model_for_request, _safe_close_sdk_client, ) @@ -355,11 +357,16 @@ class TestNormalizeModelName: api_key=None, base_url=None, use_claude_code_subscription=False, + # Pin SDK slugs to anthropic/* so the new + # _validate_sdk_model_vendor_compatibility allows construction. + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4-7", ) monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4-6" - def test_dots_preserved_for_openrouter(self, monkeypatch, _clean_config_env): + def test_openrouter_keeps_full_slug(self, monkeypatch, _clean_config_env): + """OpenRouter routes by ``vendor/model`` slug — keep prefix and dots.""" from backend.copilot import config as cfg_mod cfg = cfg_mod.ChatConfig( @@ -369,7 +376,11 @@ class TestNormalizeModelName: use_claude_code_subscription=False, ) monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) - assert _normalize_model_name("anthropic/claude-opus-4.6") == "claude-opus-4.6" + assert ( + _normalize_model_name("anthropic/claude-opus-4.6") + == "anthropic/claude-opus-4.6" + ) + assert _normalize_model_name("moonshotai/kimi-k2.6") == "moonshotai/kimi-k2.6" def test_no_prefix_no_dots(self, monkeypatch, _clean_config_env): from backend.copilot import config as cfg_mod @@ -379,6 +390,8 @@ class TestNormalizeModelName: api_key=None, base_url=None, use_claude_code_subscription=False, + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4-7", ) monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) assert ( @@ -390,8 +403,9 @@ class TestNormalizeModelName: class TestResolveSdkModel: """Tests for _resolve_sdk_model — model ID resolution for the SDK CLI.""" - def test_openrouter_active_keeps_dots(self, monkeypatch, _clean_config_env): - """When OpenRouter is fully active, model keeps dot-separated version.""" + def test_openrouter_active_keeps_full_slug(self, monkeypatch, _clean_config_env): + """When OpenRouter is fully active, the canonical vendor/model slug + is preserved so OpenRouter can route to the correct provider.""" from backend.copilot import config as cfg_mod cfg = cfg_mod.ChatConfig( @@ -403,7 +417,23 @@ class TestResolveSdkModel: use_claude_code_subscription=False, ) monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) - assert _resolve_sdk_model() == "claude-opus-4.6" + assert _resolve_sdk_model() == "anthropic/claude-opus-4.6" + + def test_openrouter_active_kimi_slug(self, monkeypatch, _clean_config_env): + """Non-Anthropic models (Kimi via Moonshot) require the prefix to + survive OpenRouter routing — strip would leave an unroutable slug.""" + from backend.copilot import config as cfg_mod + + cfg = cfg_mod.ChatConfig( + thinking_standard_model="moonshotai/kimi-k2.6", + claude_agent_model=None, + use_openrouter=True, + api_key="or-key", + base_url="https://openrouter.ai/api/v1", + use_claude_code_subscription=False, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + assert 
_resolve_sdk_model() == "moonshotai/kimi-k2.6" def test_openrouter_disabled_normalizes_to_hyphens( self, monkeypatch, _clean_config_env @@ -488,6 +518,213 @@ class TestResolveSdkModel: assert _resolve_sdk_model() == "claude-opus-4-6" +class TestResolveSdkModelForRequestLdFallback: + """``_resolve_sdk_model_for_request`` must fail soft when the LD value + can't be normalised for the active routing mode — flagged as MAJOR by + CodeRabbit + HIGH by Sentry when it was a hard ValueError.""" + + @pytest.mark.asyncio + async def test_direct_anthropic_mode_rejects_kimi_ld_value_and_falls_back( + self, monkeypatch, _clean_config_env + ): + """LD serves ``moonshotai/kimi-k2.6`` but we're on direct-Anthropic + (no OpenRouter key). ``_normalize_model_name`` raises; the + resolver must log + return the config-default path instead of + 500-ing the turn.""" + cfg = cfg_mod.ChatConfig( + thinking_standard_model="anthropic/claude-sonnet-4-6", + claude_agent_model=None, + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=False, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + with patch( + "backend.copilot.sdk.service._resolve_thinking_model_for_user", + new=AsyncMock(return_value="moonshotai/kimi-k2.6"), + ): + resolved = await _resolve_sdk_model_for_request( + model="standard", session_id="sess-abc", user_id="user-1" + ) + + # Fallback == tier-specific config default (thinking_standard_model + # normalised to hyphen-form for direct-Anthropic mode). + assert resolved == "claude-sonnet-4-6" + + @pytest.mark.asyncio + async def test_openrouter_mode_accepts_ld_kimi_value( + self, monkeypatch, _clean_config_env + ): + """On OpenRouter the Kimi slug is legitimate — no fallback, + value returned as-is.""" + cfg = cfg_mod.ChatConfig( + thinking_standard_model="anthropic/claude-sonnet-4-6", + claude_agent_model=None, + use_openrouter=True, + api_key="or-key", + base_url="https://openrouter.ai/api/v1", + use_claude_code_subscription=False, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + with patch( + "backend.copilot.sdk.service._resolve_thinking_model_for_user", + new=AsyncMock(return_value="moonshotai/kimi-k2.6"), + ): + resolved = await _resolve_sdk_model_for_request( + model="standard", session_id="sess-abc", user_id="user-1" + ) + assert resolved == "moonshotai/kimi-k2.6" + + @pytest.mark.asyncio + async def test_advanced_tier_fallback_uses_advanced_default_not_standard( + self, monkeypatch, _clean_config_env + ): + """An LD-rejected ADVANCED slug must fall back to the advanced + config default (Opus) — not the standard default (Sonnet). + Using ``_resolve_sdk_model()`` as the fallback silently + downgraded the user's chosen tier. 
Flagged MAJOR by CodeRabbit + + HIGH by Sentry on the first fail-soft commit.""" + cfg = cfg_mod.ChatConfig( + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4.7", + claude_agent_model=None, + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=False, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + with patch( + "backend.copilot.sdk.service._resolve_thinking_model_for_user", + new=AsyncMock(return_value="moonshotai/kimi-k2.6"), + ): + resolved = await _resolve_sdk_model_for_request( + model="advanced", session_id="sess-adv", user_id="user-1" + ) + + # Direct-Anthropic normalises anthropic/claude-opus-4.7 → claude-opus-4-7 + assert resolved == "claude-opus-4-7" + + @pytest.mark.asyncio + async def test_standard_ld_override_wins_over_subscription( + self, monkeypatch, _clean_config_env + ): + """Bug reported in local test: subscription mode + LD serving Kimi + on ``copilot-thinking-standard-model`` returned ``None`` (CLI + picked subscription default Opus), silently ignoring the LD + override. An LD value different from the config default is an + explicit admin decision and must win.""" + cfg = cfg_mod.ChatConfig( + thinking_standard_model="anthropic/claude-sonnet-4-6", + claude_agent_model=None, + use_openrouter=True, + api_key="or-key", + base_url="https://openrouter.ai/api/v1", + use_claude_code_subscription=True, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + with patch( + "backend.copilot.sdk.service._resolve_thinking_model_for_user", + new=AsyncMock(return_value="moonshotai/kimi-k2.6"), + ): + resolved = await _resolve_sdk_model_for_request( + model="standard", session_id="sess-std-sub", user_id="user-1" + ) + # Expect LD-served Kimi, NOT None (the old subscription-default bypass) + assert resolved == "moonshotai/kimi-k2.6" + + @pytest.mark.asyncio + async def test_standard_subscription_survives_trailing_whitespace_in_env( + self, monkeypatch, _clean_config_env + ): + """``_resolve_thinking_model_for_user`` strips whitespace from the LD + side; the config tier default must be stripped too, otherwise a + stray trailing space in ``CHAT_THINKING_STANDARD_MODEL`` makes + ``resolved == tier_default`` spuriously False and bypasses + subscription-default mode. Sentry HIGH on L856.""" + cfg = cfg_mod.ChatConfig( + thinking_standard_model="anthropic/claude-sonnet-4-6 ", # trailing spaces + claude_agent_model=None, + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=True, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + with patch( + "backend.copilot.sdk.service._resolve_thinking_model_for_user", + new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"), + ): + resolved = await _resolve_sdk_model_for_request( + model="standard", session_id="sess-ws", user_id="user-1" + ) + assert resolved is None, ( + "LD value semantically matches the whitespace-padded config " + "default — subscription mode must still win and return None" + ) + + @pytest.mark.asyncio + async def test_standard_subscription_default_honoured_when_ld_matches_config( + self, monkeypatch, _clean_config_env + ): + """When LD serves the SAME value as the config default (i.e. 
the + flag is effectively unset / no override), subscription mode still + wins and we return ``None`` so the CLI uses the subscription + default model.""" + cfg = cfg_mod.ChatConfig( + thinking_standard_model="anthropic/claude-sonnet-4-6", + claude_agent_model=None, + use_openrouter=False, + api_key=None, + base_url=None, + use_claude_code_subscription=True, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + with patch( + "backend.copilot.sdk.service._resolve_thinking_model_for_user", + new=AsyncMock(return_value="anthropic/claude-sonnet-4-6"), + ): + resolved = await _resolve_sdk_model_for_request( + model="standard", session_id="sess-std-nop", user_id="user-1" + ) + assert resolved is None + + @pytest.mark.asyncio + async def test_advanced_tier_consults_ld_under_subscription( + self, monkeypatch, _clean_config_env + ): + """Subscription mode bypasses LD only on the standard tier — + the advanced tier always consults LD because the user explicitly + asked for the premium path. A subscription + advanced request + with LD-served Opus must return Opus (not ``None``).""" + cfg = cfg_mod.ChatConfig( + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4.7", + claude_agent_model=None, + use_openrouter=True, + api_key="or-key", + base_url="https://openrouter.ai/api/v1", + use_claude_code_subscription=True, + ) + monkeypatch.setattr("backend.copilot.sdk.service.config", cfg) + + with patch( + "backend.copilot.sdk.service._resolve_thinking_model_for_user", + new=AsyncMock(return_value="anthropic/claude-opus-4.7"), + ): + resolved = await _resolve_sdk_model_for_request( + model="advanced", session_id="sess-adv-sub", user_id="user-1" + ) + assert resolved == "anthropic/claude-opus-4.7" + + # --------------------------------------------------------------------------- # _is_sdk_disconnect_error — classify client disconnect cleanup errors # --------------------------------------------------------------------------- @@ -651,6 +888,8 @@ class TestSystemPromptPreset: api_key=None, base_url=None, use_claude_code_subscription=False, + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4-7", ) assert cfg.claude_agent_cross_user_prompt_cache is True @@ -662,6 +901,8 @@ class TestSystemPromptPreset: api_key=None, base_url=None, use_claude_code_subscription=False, + thinking_standard_model="anthropic/claude-sonnet-4-6", + thinking_advanced_model="anthropic/claude-opus-4-7", ) assert cfg.claude_agent_cross_user_prompt_cache is False @@ -674,3 +915,225 @@ class TestIdleTimeoutConstant: def test_idle_timeout_is_10_min(self): assert _IDLE_TIMEOUT_SECONDS == 10 * 60 + + +# --------------------------------------------------------------------------- +# _RetryState.observed_model — Moonshot cost-override input +# --------------------------------------------------------------------------- + + +class TestRetryStateObservedModel: + """Regression guards for the ``observed_model`` field added to + ``_RetryState``. The Moonshot cost override reads this — when a + fallback model activates mid-attempt, the requested primary + (``state.options.model``) no longer matches what actually ran.""" + + def _make_state(self, *, options_model: str | None = "primary/model"): + """Build a minimally-valid ``_RetryState``. 
All the heavy + collaborators are ``MagicMock()`` — the field we care about is + a plain Optional[str], so the surrounding scaffolding just needs + to let the dataclass instantiate.""" + from .service import _RetryState, _TokenUsage + + options = MagicMock() + options.model = options_model + return _RetryState( + options=options, + query_message="", + was_compacted=False, + use_resume=False, + resume_file=None, + transcript_msg_count=0, + adapter=MagicMock(), + transcript_builder=MagicMock(), + usage=_TokenUsage(), + ) + + def test_default_is_none(self): + state = self._make_state() + assert state.observed_model is None + + def test_assigned_from_assistant_message_model(self): + """Simulates the population path in ``_run_stream_attempt``: + ``observed`` is pulled off the ``AssistantMessage.model`` attr + and assigned onto ``state.observed_model`` when it's a non-empty + string.""" + state = self._make_state() + # Simulates the inline assignment the generator does on each + # AssistantMessage — a non-empty string lands on state. + assistant_like = SimpleNamespace(model="anthropic/claude-sonnet-4-6") + observed = getattr(assistant_like, "model", None) + if isinstance(observed, str) and observed: + state.observed_model = observed + assert state.observed_model == "anthropic/claude-sonnet-4-6" + + def test_empty_string_model_is_not_assigned(self): + """Guard against overwriting a real observed value with an + empty-string model (the generator's ``and observed`` check).""" + state = self._make_state() + state.observed_model = "moonshotai/kimi-k2.6" # seeded from a prior msg + assistant_like = SimpleNamespace(model="") + observed = getattr(assistant_like, "model", None) + if isinstance(observed, str) and observed: + state.observed_model = observed + assert state.observed_model == "moonshotai/kimi-k2.6" + + def test_missing_model_attr_leaves_observed_untouched(self): + state = self._make_state() + state.observed_model = "moonshotai/kimi-k2.6" + # AssistantMessage may not carry ``.model`` on older SDK rels. + assistant_like = SimpleNamespace() # no ``.model`` attr + observed = getattr(assistant_like, "model", None) + if isinstance(observed, str) and observed: + state.observed_model = observed + assert state.observed_model == "moonshotai/kimi-k2.6" + + +# --------------------------------------------------------------------------- +# Moonshot cost-override gate — decision logic at the call site +# --------------------------------------------------------------------------- + + +class TestMoonshotCostOverrideGate: + """Regression guards for the decision logic in + ``_run_stream_attempt`` that picks between the CLI-reported cost + and the Moonshot rate-card override. The code: + + active_model = state.observed_model or getattr(state.options, "model", None) + if _is_moonshot_model(active_model): + state.usage.cost_usd = _override_cost_for_moonshot(...) 
+ else: + state.usage.cost_usd = sdk_msg.total_cost_usd + + is critical-path billing logic — make sure observed_model wins over + the requested primary, and Anthropic turns pass through untouched.""" + + def _decide_cost( + self, + *, + observed_model: str | None, + options_model: str | None, + sdk_reported_usd: float, + prompt_tokens: int = 0, + completion_tokens: int = 0, + ) -> float: + """Mirror of the real decision block — lets us assert the gate + without constructing the whole 1000-line generator.""" + from .service import _is_moonshot_model, _override_cost_for_moonshot + + active_model = observed_model or options_model + if _is_moonshot_model(active_model): + return _override_cost_for_moonshot( + model=active_model, + sdk_reported_usd=sdk_reported_usd, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cache_read_tokens=0, + cache_creation_tokens=0, + ) + return sdk_reported_usd + + def test_anthropic_turn_passes_sdk_cost_through(self): + """Anthropic — the CLI's pricing table is authoritative, so + ``state.usage.cost_usd`` is set to ``sdk_msg.total_cost_usd`` + unchanged.""" + cost = self._decide_cost( + observed_model="anthropic/claude-sonnet-4-6", + options_model="anthropic/claude-sonnet-4-6", + sdk_reported_usd=0.123, + ) + assert cost == 0.123 + + def test_moonshot_turn_uses_rate_card_override(self): + """Moonshot — the CLI would silently bill at Sonnet rates, so + the override recomputes from the Moonshot rate card.""" + cost = self._decide_cost( + observed_model="moonshotai/kimi-k2.6", + options_model="moonshotai/kimi-k2.6", + sdk_reported_usd=0.089862, # CLI's Sonnet-priced estimate. + prompt_tokens=29564, + completion_tokens=78, + ) + expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000 + assert cost == pytest.approx(expected, rel=1e-9) + # Sanity: ~5x cheaper than the CLI's Sonnet-priced number. + assert cost < 0.089862 / 4 + + def test_observed_model_wins_over_options_primary(self): + """The whole point of ``observed_model``: a Moonshot-primary + request that fell back to Anthropic must NOT get Moonshot + pricing applied. The gate follows the observed model, not the + requested primary.""" + cost = self._decide_cost( + observed_model="anthropic/claude-sonnet-4-6", + options_model="moonshotai/kimi-k2.6", # what we ASKED for + sdk_reported_usd=0.123, + prompt_tokens=1000, + completion_tokens=100, + ) + # Observed == Anthropic → CLI-reported cost passes through unchanged. 
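+ # Worked numbers for contrast, reusing the per-million rates the Moonshot
+ # test above assumes: following options_model instead would have produced
+ # (1000 * 0.60 + 100 * 2.80) / 1e6 ≈ $0.00088 here, not the CLI-reported
+ # 0.123 asserted below.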
+ assert cost == 0.123 + + def test_anthropic_to_moonshot_fallback_uses_override(self): + """The inverse: an Anthropic-primary request that fell back to + Moonshot must get the Moonshot override applied — the CLI is + still billing at Sonnet rates for the fallback response.""" + cost = self._decide_cost( + observed_model="moonshotai/kimi-k2.6", + options_model="anthropic/claude-sonnet-4-6", + sdk_reported_usd=0.089862, + prompt_tokens=29564, + completion_tokens=78, + ) + expected = (29564 * 0.60 + 78 * 2.80) / 1_000_000 + assert cost == pytest.approx(expected, rel=1e-9) + + def test_no_observed_falls_back_to_options_model(self): + """First AssistantMessage hasn't arrived yet (or the SDK didn't + emit ``.model``) — the gate falls back to the requested primary.""" + cost = self._decide_cost( + observed_model=None, + options_model="moonshotai/kimi-k2.6", + sdk_reported_usd=0.089862, + prompt_tokens=100, + completion_tokens=10, + ) + expected = (100 * 0.60 + 10 * 2.80) / 1_000_000 + assert cost == pytest.approx(expected, rel=1e-9) + + def test_both_none_passes_sdk_cost_through(self): + """Subscription mode — ``options.model`` may be None and no + AssistantMessage has arrived yet. ``None`` is not a Moonshot + slug so the SDK number lands unchanged.""" + cost = self._decide_cost( + observed_model=None, + options_model=None, + sdk_reported_usd=0.05, + ) + assert cost == 0.05 + + +# --------------------------------------------------------------------------- +# Moonshot helper re-exports — keep imports stable for call-site code +# --------------------------------------------------------------------------- + + +class TestMoonshotHelperReexports: + """``sdk/service.py`` imports the Moonshot helpers under local + aliases (``_is_moonshot_model``, ``_override_cost_for_moonshot``). + Regression guard so a refactor doesn't silently break the import + path the hot-loop code relies on.""" + + def test_is_moonshot_model_aliased(self): + from backend.copilot.moonshot import is_moonshot_model as canonical + + from .service import _is_moonshot_model + + assert _is_moonshot_model is canonical + + def test_override_cost_for_moonshot_aliased(self): + from backend.copilot.moonshot import override_cost_usd as canonical + + from .service import _override_cost_for_moonshot + + assert _override_cost_for_moonshot is canonical diff --git a/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py b/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py index 7e1fa0396d..ca1f1f821e 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py +++ b/autogpt_platform/backend/backend/copilot/sdk/tool_adapter.py @@ -10,6 +10,7 @@ import json import logging import os import uuid +from collections.abc import Iterable from contextvars import ContextVar from typing import TYPE_CHECKING, Any @@ -33,7 +34,7 @@ from backend.copilot.sdk.file_ref import ( expand_file_refs_in_args, read_file_bytes, ) -from backend.copilot.tools import TOOL_REGISTRY +from backend.copilot.tools import TOOL_REGISTRY, ToolGroup, tool_names_in_groups from backend.copilot.tools.base import BaseTool from backend.util.truncate import truncate @@ -853,9 +854,29 @@ DANGEROUS_PATTERNS = [ r"subprocess", ] +# Platform-tool names whose MCP wrappers must NOT be exposed to SDK mode. +# Baseline ships an MCP ``TodoWrite`` for model-flexibility parity; SDK mode +# keeps using the CLI-native built-in listed in ``_SDK_BUILTIN_ALWAYS`` so +# there is no double exposure. 
Public (no leading underscore) so a future +# refactor renaming it is visible at both call sites — +# ``permissions.apply_tool_permissions`` maps short tool names back to the +# CLI built-in form for SDK mode. +BASELINE_ONLY_MCP_TOOLS: frozenset[str] = frozenset({"TodoWrite"}) + + +def _registry_mcp_tools(*, hidden: frozenset[str] = frozenset()) -> list[str]: + return [ + f"{MCP_TOOL_PREFIX}{name}" + for name in TOOL_REGISTRY.keys() + if name not in BASELINE_ONLY_MCP_TOOLS and name not in hidden + ] + + # Static tool name list for the non-E2B case (backward compatibility). +# Includes all capability-gated tools; per-user filtering happens in +# ``get_copilot_tool_names`` when the caller passes ``disabled_groups``. COPILOT_TOOL_NAMES = [ - *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()], + *_registry_mcp_tools(), f"{MCP_TOOL_PREFIX}{WRITE_TOOL_NAME}", f"{MCP_TOOL_PREFIX}{READ_TOOL_NAME}", f"{MCP_TOOL_PREFIX}{EDIT_TOOL_NAME}", @@ -864,20 +885,31 @@ COPILOT_TOOL_NAMES = [ ] -def get_copilot_tool_names(*, use_e2b: bool = False) -> list[str]: +def get_copilot_tool_names( + *, + use_e2b: bool = False, + disabled_groups: Iterable[ToolGroup] = (), +) -> list[str]: """Build the ``allowed_tools`` list for :class:`ClaudeAgentOptions`. When *use_e2b* is True the SDK built-in file tools are replaced by MCP - equivalents that route to the E2B sandbox. + equivalents that route to the E2B sandbox. Tools belonging to any of + *disabled_groups* are filtered out — see ``ToolGroup`` / ``TOOL_GROUPS`` + in ``backend.copilot.tools`` for the full list. """ + hidden_short_names = tool_names_in_groups(disabled_groups) + hidden_mcp_names = {f"{MCP_TOOL_PREFIX}{n}" for n in hidden_short_names} + if not use_e2b: - return list(COPILOT_TOOL_NAMES) + if not hidden_mcp_names: + return list(COPILOT_TOOL_NAMES) + return [n for n in COPILOT_TOOL_NAMES if n not in hidden_mcp_names] # In E2B mode, Write/Edit are NOT registered (E2B uses write_file/edit_file # from E2B_FILE_TOOLS instead), so don't include them here. # _READ_TOOL_NAME is still needed for SDK tool-result reads. return [ - *[f"{MCP_TOOL_PREFIX}{name}" for name in TOOL_REGISTRY.keys()], + *_registry_mcp_tools(hidden=hidden_short_names), f"{MCP_TOOL_PREFIX}{_READ_TOOL_NAME}", *[f"{MCP_TOOL_PREFIX}{name}" for name in E2B_FILE_TOOL_NAMES], *_SDK_BUILTIN_ALWAYS, diff --git a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py index 01f3540c28..4b8849fb57 100644 --- a/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/sdk/transcript_test.py @@ -1249,7 +1249,14 @@ class TestStripStaleThinkingBlocks: new_asst = self._asst_entry( "msg_new", [ - {"type": "thinking", "thinking": "latest thoughts"}, + # Anthropic-shape thinking block (has signature) — preserved + # on the last turn. Signature-less variant covered by + # ``test_strips_signatureless_last_turn_thinking``. + { + "type": "thinking", + "thinking": "latest thoughts", + "signature": "anthropic-sig", + }, {"type": "text", "text": "world"}, ], uuid="a2", @@ -1271,11 +1278,16 @@ class TestStripStaleThinkingBlocks: assert new_content[1]["type"] == "text" def test_preserves_last_assistant_thinking(self) -> None: - """The last assistant entry's thinking blocks must be preserved.""" + """The last assistant entry's thinking blocks must be preserved + when they carry an Anthropic ``signature``. Signature-less + blocks (e.g. 
from Kimi K2.6 via OpenRouter) are stripped to + prevent ``Invalid `signature` in `thinking` block`` errors on + a subsequent Anthropic-model turn — covered by + ``test_strips_signatureless_last_turn_thinking``.""" entry = self._asst_entry( "msg_only", [ - {"type": "thinking", "thinking": "must keep"}, + {"type": "thinking", "thinking": "must keep", "signature": "sig"}, {"type": "text", "text": "response"}, ], ) @@ -1284,6 +1296,47 @@ class TestStripStaleThinkingBlocks: lines = [json.loads(ln) for ln in result.strip().split("\n")] assert len(lines[0]["message"]["content"]) == 2 + def test_strips_signatureless_last_turn_thinking(self) -> None: + """Cross-model fix (PR #12878): a signature-less thinking block + on the last assistant turn must be stripped before a subsequent + Anthropic-model dispatch tries to replay it.""" + entry = self._asst_entry( + "msg_kimi", + [ + # No signature → non-Anthropic provider + {"type": "thinking", "thinking": "kimi reasoning"}, + {"type": "text", "text": "answer"}, + ], + ) + content = _make_jsonl(entry) + result = strip_stale_thinking_blocks(content) + lines = [json.loads(ln) for ln in result.strip().split("\n")] + types = [b["type"] for b in lines[0]["message"]["content"]] + assert "thinking" not in types + assert "text" in types + + def test_preserves_redacted_thinking_on_last_turn(self) -> None: + """``redacted_thinking`` blocks are signature-less by design + (encrypted ``data`` field instead). Stripping them on the last + turn would violate Anthropic's value-identity requirement for + multi-turn replay. The signature rule only applies to plain + ``thinking`` blocks.""" + entry = self._asst_entry( + "msg_anthropic", + [ + # Anthropic-emitted redacted_thinking: has ``data``, + # never has ``signature``. + {"type": "redacted_thinking", "data": "encrypted_blob"}, + {"type": "text", "text": "response"}, + ], + ) + content = _make_jsonl(entry) + result = strip_stale_thinking_blocks(content) + lines = [json.loads(ln) for ln in result.strip().split("\n")] + types = [b["type"] for b in lines[0]["message"]["content"]] + assert "redacted_thinking" in types + assert "text" in types + def test_no_assistant_entries_returns_unchanged(self) -> None: """Transcripts with only user entries should pass through unchanged.""" user = self._user_entry("hello") @@ -1294,12 +1347,13 @@ class TestStripStaleThinkingBlocks: assert strip_stale_thinking_blocks("") == "" def test_multiple_turns_strips_all_but_last(self) -> None: - """With 3 assistant turns, only the last keeps thinking blocks.""" + """With 3 assistant turns, only the last keeps thinking blocks + (and only when those blocks carry an Anthropic ``signature``).""" entries = [ self._asst_entry( "msg_1", [ - {"type": "thinking", "thinking": "t1"}, + {"type": "thinking", "thinking": "t1", "signature": "s1"}, {"type": "text", "text": "a1"}, ], uuid="a1", @@ -1308,7 +1362,7 @@ class TestStripStaleThinkingBlocks: self._asst_entry( "msg_2", [ - {"type": "thinking", "thinking": "t2"}, + {"type": "thinking", "thinking": "t2", "signature": "s2"}, {"type": "text", "text": "a2"}, ], uuid="a2", @@ -1318,7 +1372,7 @@ class TestStripStaleThinkingBlocks: self._asst_entry( "msg_3", [ - {"type": "thinking", "thinking": "t3"}, + {"type": "thinking", "thinking": "t3", "signature": "s3"}, {"type": "text", "text": "a3"}, ], uuid="a3", @@ -1339,16 +1393,18 @@ class TestStripStaleThinkingBlocks: assert lines[4]["message"]["content"][0]["type"] == "thinking" def test_same_msg_id_multi_entry_turn(self) -> None: - """Multiple entries sharing the 
same message.id (same turn) are preserved.""" + """Multiple entries sharing the same message.id (same turn) are + preserved when their thinking blocks carry an Anthropic + ``signature``.""" entries = [ self._asst_entry( "msg_old", - [{"type": "thinking", "thinking": "old"}], + [{"type": "thinking", "thinking": "old", "signature": "old_sig"}], uuid="a1", ), self._asst_entry( "msg_last", - [{"type": "thinking", "thinking": "t_part1"}], + [{"type": "thinking", "thinking": "t_part1", "signature": "p1_sig"}], uuid="a2", parent="a1", ), diff --git a/autogpt_platform/backend/backend/copilot/service.py b/autogpt_platform/backend/backend/copilot/service.py index b0399f87e3..061088e788 100644 --- a/autogpt_platform/backend/backend/copilot/service.py +++ b/autogpt_platform/backend/backend/copilot/service.py @@ -17,6 +17,7 @@ from langfuse import get_client from langfuse.openai import ( AsyncOpenAI as LangfuseAsyncOpenAI, # pyright: ignore[reportPrivateImportUsage] ) +from openai.types.chat import ChatCompletion from backend.data.db_accessors import chat_db, understanding_db from backend.data.understanding import ( @@ -34,6 +35,7 @@ from .model import ( update_session_title, upsert_chat_session, ) +from .token_tracking import persist_and_record_usage logger = logging.getLogger(__name__) @@ -495,20 +497,31 @@ async def _generate_session_title( message: str, user_id: str | None = None, session_id: str | None = None, -) -> str | None: +) -> tuple[str | None, ChatCompletion | None]: """Generate a concise title for a chat session based on the first message. + Returns ``(title, response)``. The caller is responsible for + persisting the title AND recording the title call's cost — keeping + them as separate concerns in the caller lets a cost-tracking hiccup + not lose the title, and lets a title-persist failure still record + the cost (we paid for the LLM call either way). + Args: message: The first user message in the session user_id: User ID for OpenRouter tracing (optional) session_id: Session ID for OpenRouter tracing (optional) Returns: - A short title (3-6 words) or None if generation fails + ``(title, response)`` on success; ``(None, None)`` if the LLM + call raised. ``response`` is returned even when ``title`` is + empty so the caller can still record the (paid-for) cost. """ try: - # Build extra_body for OpenRouter tracing and PostHog analytics - extra_body: dict[str, Any] = {} + # Build extra_body for OpenRouter tracing and PostHog analytics. + # ``usage: {"include": True}`` asks OR to embed the real billed + # cost into the final usage chunk — matches the baseline path's + # ``_OPENROUTER_INCLUDE_USAGE_COST`` pattern, same read path. + extra_body: dict[str, Any] = {"usage": {"include": True}} if user_id: extra_body["user"] = user_id[:128] # OpenRouter limit extra_body["posthogDistinctId"] = user_id @@ -534,18 +547,113 @@ async def _generate_session_title( max_tokens=20, extra_body=extra_body, ) - title = response.choices[0].message.content - if title: - # Clean up the title - title = title.strip().strip("\"'") - # Limit length - if len(title) > 50: - title = title[:47] + "..." 
- return title - return None except Exception as e: logger.warning(f"Failed to generate session title: {e}") - return None + return None, None + + # Robust against an empty ``choices`` list OR a choice whose + # ``message`` is missing ``content`` (shouldn't happen on the OpenAI + # SDK typing, but belt-and-suspenders — the background task would + # otherwise die on ``IndexError`` and lose the (paid-for) cost + # recording we're about to do below). + title: str | None = None + if response.choices: + msg = response.choices[0].message + title = msg.content if msg is not None else None + if title: + title = title.strip().strip("\"'") + if len(title) > 50: + title = title[:47] + "..." + return title, response + + +def _title_usage_from_response( + response: ChatCompletion, +) -> tuple[int, int, float | None]: + """Extract ``(prompt_tokens, completion_tokens, cost_usd)`` from a + title-generation chat-completion response. + + Returns zeros / ``None`` for missing fields — the OpenAI SDK's + ``CompletionUsage`` doesn't declare OpenRouter's ``cost`` extension, + so we read it off ``model_extra`` (pydantic v2 extras container). + Absent for non-OR routes; returned as ``None`` in that case. + """ + usage = response.usage + if usage is None: + return 0, 0, None + prompt_tokens = usage.prompt_tokens or 0 + completion_tokens = usage.completion_tokens or 0 + extras = usage.model_extra or {} + cost_raw = extras.get("cost") if isinstance(extras, dict) else None + if isinstance(cost_raw, (int, float)): + cost_usd: float | None = float(cost_raw) + else: + cost_usd = None + return prompt_tokens, completion_tokens, cost_usd + + +async def _record_title_generation_cost( + *, + response: ChatCompletion, + user_id: str | None, + session_id: str | None, +) -> None: + """Persist the title LLM call's cost to ``PlatformCostLog``. + + Title generation runs in a background task per-session — low cost + (~$0.0001 per title) but 100% of sessions pay it. Without this the + admin dashboard under-reports total provider spend by the aggregate + of those calls. Separate ``block_name="copilot:title"`` so the row + is clearly distinguishable from the turn's main ``copilot:SDK`` / + ``copilot:baseline`` attributions. + + Invariants enforced by the caller: + * ``response`` is a completed ``ChatCompletion`` (the create call + didn't raise) — so ``response.usage`` shape is SDK-contractual. + * Exceptions are NOT suppressed — the caller runs this AFTER + title persistence so a persist failure here doesn't lose the + title, and a real DB / Prisma outage surfaces in the caller's + single background-task warning handler. + """ + prompt_tokens, completion_tokens, cost_usd = _title_usage_from_response(response) + + # Nothing meaningful to record — skip the DB roundtrip entirely + # rather than writing a zero-valued row. Covers the non-OR route + # (no ``usage.cost`` field) and the degenerate zero-tokens case. + if cost_usd is None and prompt_tokens == 0 and completion_tokens == 0: + return + + # Provider label is derived from the configured ``base_url`` (title + # LLM uses the shared copilot OpenAI client whose base URL mirrors + # ``ChatConfig.base_url``). This lets a deployment that points + # title generation at a non-OR endpoint still get the correct + # ``provider`` on the cost-log row. + provider = ( + "open_router" + if (config.base_url and "openrouter.ai" in config.base_url) + else "openai" + ) + + # Intentionally pass ``session=None``. 
``persist_and_record_usage`` + # would otherwise append a ``Usage`` entry to the live session + # object, but this background task holds no reference to the + # request-scoped session — we'd have to ``get_chat_session`` + + # ``upsert_chat_session`` round-trip the mutation back, and the + # turn's main ``persist_and_record_usage`` already owns the session + # usage-list mirror for the originating turn. Title cost is + # recorded into ``PlatformCostLog`` (admin dashboard) and the + # microdollar rate-limit counter — those are the two places that + # actually matter for this call. + await persist_and_record_usage( + session=None, + user_id=user_id, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + log_prefix="[title]", + cost_usd=cost_usd, + model=config.title_model, + provider=provider, + ) async def _update_title_async( @@ -553,15 +661,29 @@ async def _update_title_async( ) -> None: """Generate and persist a session title in the background. - Shared by both the SDK and baseline execution paths. + Shared by both the SDK and baseline execution paths. Title + persistence and cost recording are run as independent best-effort + steps — a failure in one does not cancel the other, so a flaky + Prisma call on cost recording never costs us the generated title. """ - try: - title = await _generate_session_title(message, user_id, session_id) - if title and user_id: + title, response = await _generate_session_title(message, user_id, session_id) + + if title and user_id: + try: await update_session_title(session_id, user_id, title, only_if_empty=True) logger.debug("Generated title for session %s", session_id) - except Exception as e: - logger.warning("Failed to update session title for %s: %s", session_id, e) + except Exception as e: + logger.warning("Failed to persist session title for %s: %s", session_id, e) + + if response is not None: + try: + await _record_title_generation_cost( + response=response, user_id=user_id, session_id=session_id + ) + except Exception as e: + logger.warning( + "Failed to record title generation cost for %s: %s", session_id, e + ) async def assign_user_to_session( diff --git a/autogpt_platform/backend/backend/copilot/service_unit_test.py b/autogpt_platform/backend/backend/copilot/service_unit_test.py new file mode 100644 index 0000000000..5463404fc8 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/service_unit_test.py @@ -0,0 +1,541 @@ +"""Unit tests for title-generation cost tracking helpers. + +Covers the new code added in PR #12882: + * ``_title_usage_from_response`` — shape-robust OR ``usage.cost`` extraction + * ``_record_title_generation_cost`` — provider-label + zero-tokens gate + * ``_update_title_async`` — independent title / cost persistence try blocks + * ``_generate_session_title`` — tuple return + robustness against empty choices + +Mocks ``persist_and_record_usage`` / ``update_session_title`` at the boundary +where the code under test imports them (``backend.copilot.service.*``). 
+""" + +from __future__ import annotations + +from types import SimpleNamespace +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage +from openai.types.completion_usage import CompletionUsage + +from backend.copilot.service import ( + _generate_session_title, + _record_title_generation_cost, + _title_usage_from_response, + _update_title_async, +) + + +def _build_completion( + *, + content: str | None = "Hello Title", + usage: CompletionUsage | None = None, + choices: list[Choice] | None = None, +) -> ChatCompletion: + if choices is None: + msg = ChatCompletionMessage(role="assistant", content=content) + choices = [Choice(index=0, message=msg, finish_reason="stop")] + return ChatCompletion( + id="cmpl-1", + choices=choices, + created=0, + model="anthropic/claude-haiku", + object="chat.completion", + usage=usage, + ) + + +def _usage_with_cost(cost: object | None) -> CompletionUsage: + """Return a CompletionUsage whose ``model_extra`` carries ``cost``. + + Uses ``model_validate`` so OpenRouter's ``cost`` extension lands in + the pydantic ``model_extra`` dict the helper reads from. + """ + payload: dict[str, object] = { + "prompt_tokens": 12, + "completion_tokens": 3, + "total_tokens": 15, + } + if cost is not None: + payload["cost"] = cost + return CompletionUsage.model_validate(payload) + + +class TestTitleUsageFromResponse: + """``_title_usage_from_response`` returns sensible zeros/Nones when + optional fields are absent or of unexpected shape.""" + + def test_usage_none_returns_all_zero(self): + resp = _build_completion(usage=None) + prompt, completion, cost = _title_usage_from_response(resp) + assert prompt == 0 + assert completion == 0 + assert cost is None + + def test_missing_cost_field_returns_none_cost(self): + resp = _build_completion(usage=_usage_with_cost(None)) + prompt, completion, cost = _title_usage_from_response(resp) + assert prompt == 12 + assert completion == 3 + assert cost is None + + def test_cost_as_int_is_coerced_to_float(self): + resp = _build_completion(usage=_usage_with_cost(2)) + _, _, cost = _title_usage_from_response(resp) + assert isinstance(cost, float) + assert cost == 2.0 + + def test_cost_as_float_is_returned_as_is(self): + resp = _build_completion(usage=_usage_with_cost(0.000123)) + _, _, cost = _title_usage_from_response(resp) + assert cost == pytest.approx(0.000123) + + def test_cost_as_non_numeric_string_returns_none(self): + resp = _build_completion(usage=_usage_with_cost("free")) + _, _, cost = _title_usage_from_response(resp) + assert cost is None + + def test_empty_model_extra_returns_none_cost(self): + # ``model_extra`` is empty for non-OR routes where pydantic didn't + # receive any extras — prompt/completion still flow through. 
+ usage = CompletionUsage(prompt_tokens=5, completion_tokens=2, total_tokens=7) + resp = _build_completion(usage=usage) + prompt, completion, cost = _title_usage_from_response(resp) + assert (prompt, completion, cost) == (5, 2, None) + + def test_zero_prompt_and_completion_tokens(self): + usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + resp = _build_completion(usage=usage) + prompt, completion, cost = _title_usage_from_response(resp) + assert (prompt, completion, cost) == (0, 0, None) + + +class TestRecordTitleGenerationCost: + """``_record_title_generation_cost`` persists cost + picks the right + provider label and skips the DB roundtrip when nothing's meaningful + to record.""" + + @pytest.mark.asyncio + async def test_openrouter_base_url_uses_open_router_provider(self): + resp = _build_completion(usage=_usage_with_cost(0.0002)) + persist = AsyncMock(return_value=0) + with ( + patch( + "backend.copilot.service.persist_and_record_usage", + new=persist, + ), + patch( + "backend.copilot.service.config", + MagicMock( + base_url="https://openrouter.ai/api/v1", + title_model="anthropic/claude-haiku", + ), + ), + ): + await _record_title_generation_cost( + response=resp, user_id="u", session_id="s" + ) + persist.assert_awaited_once() + kwargs = persist.await_args.kwargs + assert kwargs["provider"] == "open_router" + assert kwargs["model"] == "anthropic/claude-haiku" + assert kwargs["prompt_tokens"] == 12 + assert kwargs["completion_tokens"] == 3 + assert kwargs["cost_usd"] == pytest.approx(0.0002) + assert kwargs["log_prefix"] == "[title]" + assert kwargs["session"] is None + + @pytest.mark.asyncio + async def test_non_openrouter_base_url_uses_openai_provider(self): + resp = _build_completion(usage=_usage_with_cost(0.0002)) + persist = AsyncMock(return_value=0) + with ( + patch( + "backend.copilot.service.persist_and_record_usage", + new=persist, + ), + patch( + "backend.copilot.service.config", + MagicMock( + base_url="https://api.openai.com/v1", + title_model="gpt-4o-mini", + ), + ), + ): + await _record_title_generation_cost( + response=resp, user_id="u", session_id="s" + ) + persist.assert_awaited_once() + assert persist.await_args.kwargs["provider"] == "openai" + + @pytest.mark.asyncio + async def test_empty_base_url_uses_openai_provider(self): + resp = _build_completion(usage=_usage_with_cost(0.0001)) + persist = AsyncMock(return_value=0) + with ( + patch( + "backend.copilot.service.persist_and_record_usage", + new=persist, + ), + patch( + "backend.copilot.service.config", + MagicMock(base_url=None, title_model="gpt-4o-mini"), + ), + ): + await _record_title_generation_cost( + response=resp, user_id=None, session_id=None + ) + persist.assert_awaited_once() + assert persist.await_args.kwargs["provider"] == "openai" + + @pytest.mark.asyncio + async def test_zero_tokens_zero_cost_skips_persist(self): + """No cost, no tokens — the early return avoids a worthless + ``PlatformCostLog`` row.""" + usage = CompletionUsage(prompt_tokens=0, completion_tokens=0, total_tokens=0) + resp = _build_completion(usage=usage) + persist = AsyncMock(return_value=0) + with ( + patch( + "backend.copilot.service.persist_and_record_usage", + new=persist, + ), + patch( + "backend.copilot.service.config", + MagicMock( + base_url="https://openrouter.ai/api/v1", + title_model="x", + ), + ), + ): + await _record_title_generation_cost( + response=resp, user_id="u", session_id="s" + ) + persist.assert_not_awaited() + + @pytest.mark.asyncio + async def test_usage_none_skips_persist(self): + 
"""``usage`` absent on the response == provider didn't report — + still short-circuits to avoid writing a zero-valued row.""" + resp = _build_completion(usage=None) + persist = AsyncMock(return_value=0) + with ( + patch( + "backend.copilot.service.persist_and_record_usage", + new=persist, + ), + patch( + "backend.copilot.service.config", + MagicMock( + base_url="https://openrouter.ai/api/v1", + title_model="x", + ), + ), + ): + await _record_title_generation_cost( + response=resp, user_id="u", session_id="s" + ) + persist.assert_not_awaited() + + @pytest.mark.asyncio + async def test_tokens_without_cost_still_records(self): + """Tokens present but ``cost`` missing (non-OR route) still + records a row so token counts are captured — ``cost_usd=None`` + is accepted by ``persist_and_record_usage``.""" + usage = CompletionUsage(prompt_tokens=8, completion_tokens=2, total_tokens=10) + resp = _build_completion(usage=usage) + persist = AsyncMock(return_value=0) + with ( + patch( + "backend.copilot.service.persist_and_record_usage", + new=persist, + ), + patch( + "backend.copilot.service.config", + MagicMock(base_url=None, title_model="m"), + ), + ): + await _record_title_generation_cost( + response=resp, user_id="u", session_id="s" + ) + persist.assert_awaited_once() + assert persist.await_args.kwargs["cost_usd"] is None + assert persist.await_args.kwargs["prompt_tokens"] == 8 + + +class TestUpdateTitleAsync: + """``_update_title_async`` runs title persistence and cost recording + as independent best-effort steps — a failure in one does NOT + cancel the other.""" + + @pytest.mark.asyncio + async def test_title_success_cost_success(self): + resp = _build_completion(usage=_usage_with_cost(0.0001)) + gen = AsyncMock(return_value=("My Title", resp)) + update = AsyncMock(return_value=True) + record = AsyncMock() + with ( + patch( + "backend.copilot.service._generate_session_title", + new=gen, + ), + patch( + "backend.copilot.service.update_session_title", + new=update, + ), + patch( + "backend.copilot.service._record_title_generation_cost", + new=record, + ), + ): + await _update_title_async("sess-1", "hello", user_id="u1") + + update.assert_awaited_once_with("sess-1", "u1", "My Title", only_if_empty=True) + record.assert_awaited_once() + assert record.await_args.kwargs["response"] is resp + assert record.await_args.kwargs["user_id"] == "u1" + assert record.await_args.kwargs["session_id"] == "sess-1" + + @pytest.mark.asyncio + async def test_title_persist_fails_but_cost_still_recorded(self): + resp = _build_completion(usage=_usage_with_cost(0.0001)) + gen = AsyncMock(return_value=("Title", resp)) + update = AsyncMock(side_effect=RuntimeError("prisma boom")) + record = AsyncMock() + with ( + patch( + "backend.copilot.service._generate_session_title", + new=gen, + ), + patch( + "backend.copilot.service.update_session_title", + new=update, + ), + patch( + "backend.copilot.service._record_title_generation_cost", + new=record, + ), + ): + # Must NOT raise — persist failure is swallowed. 
+ await _update_title_async("sess-2", "msg", user_id="u") + + update.assert_awaited_once() + record.assert_awaited_once() + + @pytest.mark.asyncio + async def test_cost_record_fails_but_title_was_persisted(self): + resp = _build_completion(usage=_usage_with_cost(0.0001)) + gen = AsyncMock(return_value=("Title", resp)) + update = AsyncMock(return_value=True) + record = AsyncMock(side_effect=RuntimeError("cost record boom")) + with ( + patch( + "backend.copilot.service._generate_session_title", + new=gen, + ), + patch( + "backend.copilot.service.update_session_title", + new=update, + ), + patch( + "backend.copilot.service._record_title_generation_cost", + new=record, + ), + ): + # Must NOT raise — cost-recording failure is swallowed. + await _update_title_async("sess-3", "msg", user_id="u") + + update.assert_awaited_once() + record.assert_awaited_once() + + @pytest.mark.asyncio + async def test_no_user_id_skips_title_persist_but_records_cost(self): + """Anonymous sessions skip the user-scoped title write, but we + still paid for the LLM call — cost recording runs regardless.""" + resp = _build_completion(usage=_usage_with_cost(0.0001)) + gen = AsyncMock(return_value=("Title", resp)) + update = AsyncMock() + record = AsyncMock() + with ( + patch( + "backend.copilot.service._generate_session_title", + new=gen, + ), + patch( + "backend.copilot.service.update_session_title", + new=update, + ), + patch( + "backend.copilot.service._record_title_generation_cost", + new=record, + ), + ): + await _update_title_async("sess-4", "msg", user_id=None) + + update.assert_not_awaited() + record.assert_awaited_once() + + @pytest.mark.asyncio + async def test_generation_returns_none_response_skips_cost(self): + """``_generate_session_title`` swallows exceptions and returns + ``(None, None)`` — no response means no cost to record.""" + gen = AsyncMock(return_value=(None, None)) + update = AsyncMock() + record = AsyncMock() + with ( + patch( + "backend.copilot.service._generate_session_title", + new=gen, + ), + patch( + "backend.copilot.service.update_session_title", + new=update, + ), + patch( + "backend.copilot.service._record_title_generation_cost", + new=record, + ), + ): + await _update_title_async("sess-5", "msg", user_id="u") + + update.assert_not_awaited() + record.assert_not_awaited() + + @pytest.mark.asyncio + async def test_empty_title_with_response_still_records_cost(self): + """Title came back empty but we still paid for the LLM call — + cost recording runs even though the title write is skipped.""" + resp = _build_completion(usage=_usage_with_cost(0.0001)) + gen = AsyncMock(return_value=(None, resp)) + update = AsyncMock() + record = AsyncMock() + with ( + patch( + "backend.copilot.service._generate_session_title", + new=gen, + ), + patch( + "backend.copilot.service.update_session_title", + new=update, + ), + patch( + "backend.copilot.service._record_title_generation_cost", + new=record, + ), + ): + await _update_title_async("sess-6", "msg", user_id="u") + + update.assert_not_awaited() + record.assert_awaited_once() + + +class TestGenerateSessionTitle: + """``_generate_session_title`` returns ``(title, response)`` — the + caller owns both the persist and the cost-record decisions.""" + + @pytest.mark.asyncio + async def test_valid_response_returns_cleaned_title_and_response(self): + # Code strips whitespace, then strips ``"'`` — whitespace inside + # the quotes survives on purpose (titles like ``My Agent`` read + # better than ``MyAgent``). 
Test keeps the outer quotes + inner + # whitespace distinct so the ordering is pinned. + resp = _build_completion(content='"Clean Me" ') + client = MagicMock() + client.chat.completions.create = AsyncMock(return_value=resp) + with patch( + "backend.copilot.service._get_openai_client", + return_value=client, + ): + title, response = await _generate_session_title( + "first message", user_id="u", session_id="s" + ) + assert title == "Clean Me" + assert response is resp + + @pytest.mark.asyncio + async def test_long_title_truncated_with_ellipsis(self): + """Titles >50 chars get truncated to 47 + '...'.""" + long_title = "A" * 80 + resp = _build_completion(content=long_title) + client = MagicMock() + client.chat.completions.create = AsyncMock(return_value=resp) + with patch( + "backend.copilot.service._get_openai_client", + return_value=client, + ): + title, _ = await _generate_session_title("x", user_id=None) + assert title is not None + assert len(title) == 50 + assert title.endswith("...") + + @pytest.mark.asyncio + async def test_empty_choices_returns_none_title_with_response(self): + """No ``choices`` on the response (shouldn't happen per SDK + typing) must not raise IndexError — response is preserved so the + caller can still record the paid-for cost.""" + resp = _build_completion(choices=[]) + client = MagicMock() + client.chat.completions.create = AsyncMock(return_value=resp) + with patch( + "backend.copilot.service._get_openai_client", + return_value=client, + ): + title, response = await _generate_session_title("x") + assert title is None + assert response is resp + + @pytest.mark.asyncio + async def test_missing_message_returns_none_title(self): + """A choice whose ``.message`` is absent produces a None title + but the response still lands on the caller.""" + fake_choice = SimpleNamespace(message=None) + fake_response = SimpleNamespace(choices=[fake_choice]) + client = MagicMock() + client.chat.completions.create = AsyncMock(return_value=fake_response) + with patch( + "backend.copilot.service._get_openai_client", + return_value=client, + ): + title, response = await _generate_session_title("x") + assert title is None + assert response is fake_response + + @pytest.mark.asyncio + async def test_llm_call_raises_returns_none_none(self): + """Network / API errors on the create call are swallowed; + ``(None, None)`` ensures the caller skips both title and cost + without crashing the background task.""" + client = MagicMock() + client.chat.completions.create = AsyncMock( + side_effect=RuntimeError("connection reset") + ) + with patch( + "backend.copilot.service._get_openai_client", + return_value=client, + ): + title, response = await _generate_session_title("x") + assert title is None + assert response is None + + @pytest.mark.asyncio + async def test_create_receives_usage_include_extra_body(self): + """PR adds ``usage: {'include': True}`` so OpenRouter embeds the + real billed cost into the final usage chunk.""" + resp = _build_completion(content="Title") + client = MagicMock() + client.chat.completions.create = AsyncMock(return_value=resp) + with patch( + "backend.copilot.service._get_openai_client", + return_value=client, + ): + await _generate_session_title( + "hello world", user_id="user-abc", session_id="sess-abc" + ) + client.chat.completions.create.assert_awaited_once() + extra_body = client.chat.completions.create.await_args.kwargs["extra_body"] + assert extra_body["usage"] == {"include": True} + assert extra_body["user"] == "user-abc" + assert extra_body["session_id"] == 
"sess-abc" diff --git a/autogpt_platform/backend/backend/copilot/stream_registry.py b/autogpt_platform/backend/backend/copilot/stream_registry.py index 424964e075..79deadacc0 100644 --- a/autogpt_platform/backend/backend/copilot/stream_registry.py +++ b/autogpt_platform/backend/backend/copilot/stream_registry.py @@ -485,9 +485,11 @@ async def subscribe_to_session( subscriber_queue: asyncio.Queue[StreamBaseResponse] = asyncio.Queue() stream_key = _get_turn_stream_key(session.turn_id) - # Step 1: Replay messages from Redis Stream + # Replay batch capped by ``stream_replay_count``. xread_start = time.perf_counter() - messages = await redis.xread({stream_key: last_message_id}, block=None, count=1000) + messages = await redis.xread( + {stream_key: last_message_id}, block=None, count=config.stream_replay_count + ) xread_time = (time.perf_counter() - xread_start) * 1000 logger.info( f"[TIMING] Redis xread (replay) took {xread_time:.1f}ms, status={session_status}", @@ -1024,8 +1026,8 @@ async def get_active_session( # Check if session is stale (running beyond tool timeout + buffer). # Auto-complete it to prevent infinite polling loops. - # Synchronous tools can run up to COPILOT_CONSUMER_TIMEOUT_SECONDS (1 hour), - # so we add a 5-minute buffer to avoid false positives during legitimate operations. + # A turn can legitimately run up to COPILOT_CONSUMER_TIMEOUT_SECONDS, so we + # add a 5-minute buffer to avoid false positives during legitimate operations. created_at_str = meta.get("created_at") if created_at_str: try: diff --git a/autogpt_platform/backend/backend/copilot/tools/__init__.py b/autogpt_platform/backend/backend/copilot/tools/__init__.py index 3fa0be2933..f954abe973 100644 --- a/autogpt_platform/backend/backend/copilot/tools/__init__.py +++ b/autogpt_platform/backend/backend/copilot/tools/__init__.py @@ -1,7 +1,8 @@ from __future__ import annotations import logging -from typing import TYPE_CHECKING, Any +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any, Literal from openai.types.chat import ChatCompletionToolParam @@ -10,7 +11,6 @@ from backend.copilot.tracking import track_tool_called from .add_understanding import AddUnderstandingTool from .agent_browser import BrowserActTool, BrowserNavigateTool, BrowserScreenshotTool from .agent_output import AgentOutputTool -from .ask_question import AskQuestionTool from .base import BaseTool from .bash_exec import BashExecTool from .connect_integration import ConnectIntegrationTool @@ -44,6 +44,7 @@ from .run_block import RunBlockTool from .run_mcp_tool import RunMCPToolTool from .run_sub_session import RunSubSessionTool from .search_docs import SearchDocsTool +from .todo_write import TodoWriteTool from .validate_agent import ValidateAgentGraphTool from .web_fetch import WebFetchTool from .web_search import WebSearchTool @@ -63,7 +64,6 @@ logger = logging.getLogger(__name__) # Single source of truth for all tools TOOL_REGISTRY: dict[str, BaseTool] = { "add_understanding": AddUnderstandingTool(), - "ask_question": AskQuestionTool(), "create_agent": CreateAgentTool(), "customize_agent": CustomizeAgentTool(), "decompose_goal": DecomposeGoalTool(), @@ -88,6 +88,7 @@ TOOL_REGISTRY: dict[str, BaseTool] = { "continue_run_block": ContinueRunBlockTool(), "run_sub_session": RunSubSessionTool(), "get_sub_session_result": GetSubSessionResultTool(), + "TodoWrite": TodoWriteTool(), "run_mcp_tool": RunMCPToolTool(), "get_mcp_guide": GetMCPGuideTool(), "view_agent_output": AgentOutputTool(), @@ -123,15 +124,45 @@ find_agent_tool = 
TOOL_REGISTRY["find_agent"] run_agent_tool = TOOL_REGISTRY["run_agent"] -def get_available_tools() -> list[ChatCompletionToolParam]: +# Capability groups a tool may belong to. The service layer can hide all +# tools in a group when the backing capability isn't available to this user +# (e.g. Graphiti memory behind a feature flag), so the model doesn't reach +# for tools whose backend is off and then hit opaque runtime errors. Add +# a new group by extending ``ToolGroup`` and registering its members in +# ``TOOL_GROUPS`` below. +ToolGroup = Literal["graphiti"] + +TOOL_GROUPS: dict[str, ToolGroup] = { + "memory_store": "graphiti", + "memory_search": "graphiti", + "memory_forget_search": "graphiti", + "memory_forget_confirm": "graphiti", +} + + +def tool_names_in_groups(groups: Iterable[ToolGroup]) -> frozenset[str]: + """Return the set of tool short-names belonging to any of *groups*.""" + group_set = frozenset(groups) + return frozenset(name for name, g in TOOL_GROUPS.items() if g in group_set) + + +def get_available_tools( + *, + disabled_groups: Iterable[ToolGroup] = (), +) -> list[ChatCompletionToolParam]: """Return OpenAI tool schemas for tools available in the current environment. Called per-request so that env-var or binary availability is evaluated fresh each time (e.g. browser_* tools are excluded when agent-browser - CLI is not installed). + CLI is not installed). Tools belonging to any *disabled_groups* are + also filtered out — use this to hide capability-gated tools (e.g. + ``graphiti`` when the memory backend is off for the current user). """ + hidden = tool_names_in_groups(disabled_groups) return [ - tool.as_openai_tool() for tool in TOOL_REGISTRY.values() if tool.is_available + tool.as_openai_tool() + for name, tool in TOOL_REGISTRY.items() + if tool.is_available and name not in hidden ] diff --git a/autogpt_platform/backend/backend/copilot/tools/helpers.py b/autogpt_platform/backend/backend/copilot/tools/helpers.py index 6c25e79188..9de94cb2f2 100644 --- a/autogpt_platform/backend/backend/copilot/tools/helpers.py +++ b/autogpt_platform/backend/backend/copilot/tools/helpers.py @@ -181,7 +181,9 @@ async def execute_block( # (e.g., "42" → 42, string booleans → bool, enum defaults applied). coerce_inputs_to_schema(input_data, block.input_schema) outputs: dict[str, list[Any]] = defaultdict(list) - async for output_name, output_data in simulate_block(block, input_data): + async for output_name, output_data in simulate_block( + block, input_data, user_id=user_id + ): outputs[output_name].append(output_data) # simulator signals internal failure via ("error", "[SIMULATOR ERROR …]") sim_error = outputs.get("error", []) diff --git a/autogpt_platform/backend/backend/copilot/tools/models.py b/autogpt_platform/backend/backend/copilot/tools/models.py index b54d14d073..51172288dd 100644 --- a/autogpt_platform/backend/backend/copilot/tools/models.py +++ b/autogpt_platform/backend/backend/copilot/tools/models.py @@ -91,6 +91,9 @@ class ResponseType(str, Enum): MEMORY_FORGET_CANDIDATES = "memory_forget_candidates" MEMORY_FORGET_CONFIRM = "memory_forget_confirm" + # Planning + TODO_WRITE = "todo_write" + # Base response model class ToolResponseBase(BaseModel): @@ -605,11 +608,13 @@ class WebSearchResponse(ToolResponseBase): type: ResponseType = ResponseType.WEB_SEARCH query: str + # Web-grounded synthesised answer the search provider wrote from + # fresh page content. 
The LLM caller should read this directly + # instead of re-fetching each citation URL — many sites are + # bot-protected and ``web_fetch`` won't get through. Empty string + # when the provider returned only citations. + answer: str = "" results: list[WebSearchResult] = Field(default_factory=list) - # Backend-reported usage for this call (copied from Anthropic's - # ``usage.server_tool_use``). Surfaces as metadata for frontend - # debug panels but is also what drives rate-limit / cost tracking - # via ``persist_and_record_usage(provider="anthropic")``. search_requests: int = 0 @@ -896,3 +901,36 @@ class MemoryForgetConfirmResponse(ToolResponseBase): type: ResponseType = ResponseType.MEMORY_FORGET_CONFIRM deleted_uuids: list[str] = Field(default_factory=list) failed_uuids: list[str] = Field(default_factory=list) + + +# --- Planning --- + + +class TodoItem(BaseModel): + """One entry in a ``TodoWrite`` checklist. + + Mirrors the schema used by Claude Code's built-in ``TodoWrite`` tool so + the frontend's ``GenericTool`` accordion renders baseline-emitted todos + identically to SDK-emitted ones. + """ + + content: str = Field(description="Imperative description of the task.") + activeForm: str = Field( + description="Present-continuous form shown while the task is running.", + ) + status: Literal["pending", "in_progress", "completed"] = Field( + default="pending", + ) + + +class TodoWriteResponse(ToolResponseBase): + """Ack returned by ``TodoWrite``. + + The tool is effectively stateless — the authoritative task list lives in + the assistant's latest tool-call arguments, which are replayed from the + transcript on each turn. The tool output only needs to confirm that the + update was accepted so the model can proceed. + """ + + type: ResponseType = ResponseType.TODO_WRITE + todos: list[TodoItem] = Field(default_factory=list) diff --git a/autogpt_platform/backend/backend/copilot/tools/test_dry_run.py b/autogpt_platform/backend/backend/copilot/tools/test_dry_run.py index 1f71c837cf..fc44a57c86 100644 --- a/autogpt_platform/backend/backend/copilot/tools/test_dry_run.py +++ b/autogpt_platform/backend/backend/copilot/tools/test_dry_run.py @@ -237,7 +237,7 @@ async def test_execute_block_dry_run_skips_real_execution(): mock_block = make_mock_block() mock_block.execute = AsyncMock() # should NOT be called - async def fake_simulate(block, input_data): + async def fake_simulate(block, input_data, **_kwargs): yield "result", "simulated" # Patching at helpers.simulate_block works because helpers.py imports @@ -267,7 +267,7 @@ async def test_execute_block_dry_run_response_format(): """Dry-run response should look like a normal success (no dry-run signal to LLM).""" mock_block = make_mock_block() - async def fake_simulate(block, input_data): + async def fake_simulate(block, input_data, **_kwargs): yield "result", "simulated" with patch( @@ -331,7 +331,7 @@ async def test_execute_block_real_execution_unchanged(): # Just verify simulate_block is NOT called. 
simulate_called = False - async def fake_simulate(block, input_data): + async def fake_simulate(block, input_data, **_kwargs): nonlocal simulate_called simulate_called = True yield "result", "should not happen" @@ -455,7 +455,7 @@ async def test_execute_block_dry_run_no_empty_error_from_simulator(): """ mock_block = make_mock_block() - async def fake_simulate(block, input_data): + async def fake_simulate(block, input_data, **_kwargs): # Simulator now omits empty error pins at source yield "result", "simulated output" @@ -485,7 +485,7 @@ async def test_execute_block_dry_run_keeps_nonempty_error_pin(): """Dry-run should keep the 'error' pin when it contains a real error message.""" mock_block = make_mock_block() - async def fake_simulate(block, input_data): + async def fake_simulate(block, input_data, **_kwargs): yield "result", "" yield "error", "API rate limit exceeded" @@ -515,7 +515,7 @@ async def test_execute_block_dry_run_message_includes_completed_status(): """Dry-run message should clearly indicate COMPLETED status.""" mock_block = make_mock_block() - async def fake_simulate(block, input_data): + async def fake_simulate(block, input_data, **_kwargs): yield "result", "simulated" with patch( @@ -541,7 +541,7 @@ async def test_execute_block_dry_run_simulator_error_returns_error_response(): """When simulate_block yields a SIMULATOR ERROR tuple, execute_block returns ErrorResponse.""" mock_block = make_mock_block() - async def fake_simulate_error(block, input_data): + async def fake_simulate_error(block, input_data, **_kwargs): yield ( "error", "[SIMULATOR ERROR — NOT A BLOCK FAILURE] No LLM client available (missing OpenAI/OpenRouter API key).", diff --git a/autogpt_platform/backend/backend/copilot/tools/todo_write.py b/autogpt_platform/backend/backend/copilot/tools/todo_write.py new file mode 100644 index 0000000000..6047281e10 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/todo_write.py @@ -0,0 +1,120 @@ +"""Task-list tool for baseline copilot mode. + +Mirrors the schema and UX of Claude Code's built-in ``TodoWrite`` tool so +the frontend's generic tool renderer draws baseline-emitted checklists the +same way it draws SDK-emitted ones. The tool is stateless: the model's +latest ``todos`` argument IS the canonical list, replayed from transcript +on subsequent turns. + +Baseline needs this as a platform tool because OpenAI-compatible providers +(Kimi, GPT, Grok, Gemini) do not ship a built-in equivalent. The SDK path +continues to use the CLI's native ``TodoWrite`` — the MCP-wrapped version +of this tool is filtered out of SDK's allowed_tools list (see +``sdk/tool_adapter.py``) to avoid name shadowing. +""" + +from __future__ import annotations + +import logging +from typing import Any + +from backend.copilot.model import ChatSession + +from .base import BaseTool +from .models import ErrorResponse, TodoItem, TodoWriteResponse, ToolResponseBase + +logger = logging.getLogger(__name__) + + +class TodoWriteTool(BaseTool): + """Maintain a step-by-step task checklist visible to the user.""" + + @property + def name(self) -> str: + # Capitalised to match the frontend's switch on ``"TodoWrite"`` + # (see ``copilot/tools/GenericTool/helpers.ts``). + return "TodoWrite" + + @property + def description(self) -> str: + return ( + "Plan and track multi-step work as a visible checklist. Send " + "the full list every call; exactly one item in_progress at a time." 
+ ) + + @property + def parameters(self) -> dict[str, Any]: + return { + "type": "object", + "properties": { + "todos": { + "type": "array", + "description": "Full updated task list (not a delta).", + "items": { + "type": "object", + "properties": { + "content": { + "type": "string", + "description": "Imperative (e.g. 'Run tests').", + }, + "activeForm": { + "type": "string", + "description": ( + "Present-continuous (e.g. 'Running tests')." + ), + }, + "status": { + "type": "string", + "enum": ["pending", "in_progress", "completed"], + "default": "pending", + }, + }, + "required": ["content", "activeForm"], + }, + }, + }, + "required": ["todos"], + } + + async def _execute( + self, + user_id: str | None, + session: ChatSession, + **kwargs: Any, + ) -> ToolResponseBase: + del user_id + raw_todos = kwargs.get("todos") + if raw_todos is None: + return ErrorResponse( + message="`todos` is required.", + session_id=session.session_id, + ) + if not isinstance(raw_todos, list): + return ErrorResponse( + message="`todos` must be an array.", + session_id=session.session_id, + ) + + try: + parsed = [TodoItem.model_validate(item) for item in raw_todos] + except Exception as exc: + return ErrorResponse( + message=f"Invalid todo entry: {exc}", + session_id=session.session_id, + ) + + in_progress = sum(1 for t in parsed if t.status == "in_progress") + if in_progress > 1: + return ErrorResponse( + message=( + "Only one todo may be 'in_progress' at a time " + f"(found {in_progress})." + ), + session_id=session.session_id, + ) + + return TodoWriteResponse( + message="Task list updated.", + session_id=session.session_id, + todos=parsed, + ) diff --git a/autogpt_platform/backend/backend/copilot/tools/todo_write_test.py b/autogpt_platform/backend/backend/copilot/tools/todo_write_test.py new file mode 100644 index 0000000000..60c14f81d0 --- /dev/null +++ b/autogpt_platform/backend/backend/copilot/tools/todo_write_test.py @@ -0,0 +1,125 @@ +"""Tests for TodoWriteTool.""" + +import pytest + +from backend.copilot.model import ChatSession +from backend.copilot.tools.models import ErrorResponse, TodoItem, TodoWriteResponse +from backend.copilot.tools.todo_write import TodoWriteTool + + +@pytest.fixture() +def tool() -> TodoWriteTool: + return TodoWriteTool() + + +@pytest.fixture() +def session() -> ChatSession: + return ChatSession.new(user_id="test-user", dry_run=False) + + +@pytest.mark.asyncio +async def test_valid_todo_list(tool: TodoWriteTool, session: ChatSession): + result = await tool._execute( + user_id=None, + session=session, + todos=[ + { + "content": "Write tests", + "activeForm": "Writing tests", + "status": "pending", + }, + { + "content": "Ship PR", + "activeForm": "Shipping PR", + "status": "in_progress", + }, + ], + ) + + assert isinstance(result, TodoWriteResponse) + assert result.session_id == session.session_id + assert len(result.todos) == 2 + assert result.todos[0] == TodoItem( + content="Write tests", + activeForm="Writing tests", + status="pending", + ) + assert result.todos[1].status == "in_progress" + + +@pytest.mark.asyncio +async def test_default_status_is_pending(tool: TodoWriteTool, session: ChatSession): + result = await tool._execute( + user_id=None, + session=session, + todos=[{"content": "Write tests", "activeForm": "Writing tests"}], + ) + + assert isinstance(result, TodoWriteResponse) + assert result.todos[0].status == "pending" + + +@pytest.mark.asyncio +async def test_missing_todos_returns_error(tool: TodoWriteTool, session: ChatSession): + result = await 
tool._execute(user_id=None, session=session) + + assert isinstance(result, ErrorResponse) + assert "todos" in result.message.lower() + + +@pytest.mark.asyncio +async def test_non_list_todos_returns_error(tool: TodoWriteTool, session: ChatSession): + result = await tool._execute(user_id=None, session=session, todos="not a list") + + assert isinstance(result, ErrorResponse) + + +@pytest.mark.asyncio +async def test_invalid_item_returns_error(tool: TodoWriteTool, session: ChatSession): + # Missing required `activeForm` field. + result = await tool._execute( + user_id=None, + session=session, + todos=[{"content": "Missing active form"}], + ) + + assert isinstance(result, ErrorResponse) + + +@pytest.mark.asyncio +async def test_multiple_in_progress_rejected(tool: TodoWriteTool, session: ChatSession): + """Exactly one item should be in_progress at a time — SDK parity rule.""" + result = await tool._execute( + user_id=None, + session=session, + todos=[ + { + "content": "A", + "activeForm": "Doing A", + "status": "in_progress", + }, + { + "content": "B", + "activeForm": "Doing B", + "status": "in_progress", + }, + ], + ) + + assert isinstance(result, ErrorResponse) + assert "in_progress" in result.message + + +def test_openai_schema_shape(tool: TodoWriteTool): + schema = tool.as_openai_tool() + assert schema["type"] == "function" + assert schema["function"]["name"] == "TodoWrite" + params = schema["function"]["parameters"] + assert params["required"] == ["todos"] + items = params["properties"]["todos"]["items"] + assert items["required"] == ["content", "activeForm"] + assert items["properties"]["status"]["enum"] == [ + "pending", + "in_progress", + "completed", + ] diff --git a/autogpt_platform/backend/backend/copilot/tools/tool_schema_test.py b/autogpt_platform/backend/backend/copilot/tools/tool_schema_test.py index 0dc22b8cfa..7ed9fe5ad2 100644 --- a/autogpt_platform/backend/backend/copilot/tools/tool_schema_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/tool_schema_test.py @@ -21,7 +21,9 @@ from backend.copilot.tools import TOOL_REGISTRY # response shape carries) and the dry_run description. Keeps the # regression gate effective while accepting a deliberate ~120-token # spend on LLM-decision-critical copy. -_CHAR_BUDGET = 34_000 +# Bumped to 35000 to accommodate decompose_goal tool + web_search + +# TodoWrite tool descriptions. +_CHAR_BUDGET = 35_000 @pytest.fixture(scope="module") @@ -111,9 +113,10 @@ def test_total_schema_char_budget() -> None: This locks in the 34% token reduction from #12398 and prevents future description bloat from eroding the gains. Uses character count with a - ~4 chars/token heuristic (budget of 32000 chars ≈ 8000 tokens). - Character count is tokenizer-agnostic — no dependency on GPT or Claude - tokenizers — while still providing a stable regression gate. + ~4 chars/token heuristic; see ``_CHAR_BUDGET`` above for the current + value and its change history. Character count is tokenizer-agnostic + — no dependency on GPT or Claude tokenizers — while still providing a + stable regression gate. """ schemas = [tool.as_openai_tool() for tool in TOOL_REGISTRY.values()] serialized = json.dumps(schemas) @@ -123,3 +126,60 @@ def test_total_schema_char_budget() -> None: f"exceeding budget of {_CHAR_BUDGET} chars (~{_CHAR_BUDGET // 4} tokens). " f"Description bloat detected — trim descriptions or raise the budget intentionally."
) + + +# ── Capability-group filtering (ToolGroup / disabled_groups) ─────────── + + +def test_get_available_tools_hides_graphiti_when_disabled() -> None: + """When the ``graphiti`` group is disabled, the memory_* tools must + not appear in the OpenAI schema list — they'd just confuse the model + and produce opaque runtime errors.""" + from backend.copilot.tools import get_available_tools + + memory_tool_names = { + "memory_store", + "memory_search", + "memory_forget_search", + "memory_forget_confirm", + } + + default = {t["function"]["name"] for t in get_available_tools()} + assert memory_tool_names.issubset( + default + ), "sanity: memory_* tools should be present when no groups disabled" + + filtered = { + t["function"]["name"] for t in get_available_tools(disabled_groups=["graphiti"]) + } + assert not ( + memory_tool_names & filtered + ), f"graphiti disabled but memory_* still present: {memory_tool_names & filtered}" + # Non-graphiti tools stay visible. + assert "find_block" in filtered + assert "TodoWrite" in filtered + + +def test_get_copilot_tool_names_hides_graphiti_when_disabled() -> None: + """Same invariant for the SDK tool-name list.""" + from backend.copilot.sdk.tool_adapter import MCP_TOOL_PREFIX, get_copilot_tool_names + + memory_mcp_names = { + f"{MCP_TOOL_PREFIX}memory_store", + f"{MCP_TOOL_PREFIX}memory_search", + f"{MCP_TOOL_PREFIX}memory_forget_search", + f"{MCP_TOOL_PREFIX}memory_forget_confirm", + } + + default = set(get_copilot_tool_names()) + assert memory_mcp_names.issubset(default) + + filtered = set(get_copilot_tool_names(disabled_groups=["graphiti"])) + assert not ( + memory_mcp_names & filtered + ), f"graphiti disabled but memory MCP names still present: {memory_mcp_names & filtered}" + # E2B path stays consistent. + filtered_e2b = set( + get_copilot_tool_names(use_e2b=True, disabled_groups=["graphiti"]) + ) + assert not (memory_mcp_names & filtered_e2b) diff --git a/autogpt_platform/backend/backend/copilot/tools/web_search.py b/autogpt_platform/backend/backend/copilot/tools/web_search.py index feb999d4d6..ed54868917 100644 --- a/autogpt_platform/backend/backend/copilot/tools/web_search.py +++ b/autogpt_platform/backend/backend/copilot/tools/web_search.py @@ -1,29 +1,66 @@ -"""Web search tool — wraps Anthropic's server-side ``web_search`` beta. +"""Web search tool — Perplexity Sonar via OpenRouter. -Single entry point for web search on both SDK and baseline paths. The -``web_search_20250305`` tool is server-side on Anthropic, so we call -the Messages API directly regardless of which LLM invoked the copilot -tool — OpenRouter can't proxy server-side tool execution. +One provider, two tiers, one billing path: + +* ``deep=False`` (default) — ``perplexity/sonar``. Searches the web + natively and returns citation annotations in a single inference pass. +* ``deep=True`` — ``perplexity/sonar-deep-research``. Multi-step + agentic research; slower and costlier. + +Why Sonar and not the ``openrouter:web_search`` server tool + dispatch +model? The server tool feeds all search-result page content back into +the dispatch model for a second inference pass — one observed call was +74K input tokens at Gemini Flash rates, billing $0.072. Sonar +searches natively in one pass, returns annotations typed as +``ChatCompletionMessage.annotations`` in ``openai.types``, and at +$1 / MTok base pricing lands ~$0.01 / call at our default shape. 
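+(Illustrative only, assumed values rather than verbatim provider output: sending ``extra_body={"usage": {"include": True}}`` gets back the standard token counts plus OpenRouter's extension field, e.g. ``usage == {"prompt_tokens": 120, "completion_tokens": 40, "total_tokens": 160, "cost": 0.0102}``; ``_extract_cost_usd`` below reads that ``cost`` key out of ``model_extra``.)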
+ +``resp.usage.cost`` carries the real billed value via OpenRouter's +``include: true`` extension; the value flows through +``persist_and_record_usage(provider='open_router')`` into the daily / +weekly microdollar rate-limit counter on the same rails as every other +OpenRouter turn — no separate provider ledger line, no estimation +drift. ``_extract_cost_usd`` mirrors the baseline service's +``_extract_usage_cost`` logic; keep the two in sync if one changes. """ import logging +import math from typing import Any -from anthropic import AsyncAnthropic +from openai import AsyncOpenAI +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion +from backend.copilot.config import ChatConfig from backend.copilot.model import ChatSession from backend.copilot.token_tracking import persist_and_record_usage -from backend.util.settings import Settings from .base import BaseTool from .models import ErrorResponse, ToolResponseBase, WebSearchResponse, WebSearchResult logger = logging.getLogger(__name__) -_WEB_SEARCH_DISPATCH_MODEL = "claude-haiku-4-5" -_MAX_DISPATCH_TOKENS = 512 +_chat_config = ChatConfig() + +_QUICK_MODEL = "perplexity/sonar" +# Sonar base can emit up to ~4K output; cap at the provider ceiling so the +# model stops when the answer is complete rather than when our budget trips. +_QUICK_MAX_TOKENS = 4096 + +_DEEP_MODEL = "perplexity/sonar-deep-research" +# Deep runs can produce long structured writeups — ~4x the quick ceiling +# is enough headroom for multi-source comparisons without uncapping. +_DEEP_MAX_TOKENS = _QUICK_MAX_TOKENS * 4 + _DEFAULT_MAX_RESULTS = 5 _HARD_MAX_RESULTS = 20 +_SNIPPET_MAX_CHARS = 500 + +# OpenRouter-specific extra_body flag that embeds the real generation +# cost into the response usage object. Same dict shape the baseline +# service uses — keep the two aligned. +_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}} class WebSearchTool(BaseTool): @@ -36,9 +73,13 @@ class WebSearchTool(BaseTool): @property def description(self) -> str: return ( - "Search the web for live info (news, recent docs). Returns " - "{title, url, snippet}; use web_fetch to deep-dive a URL. " - "Prefer one targeted query over many reformulations." + "Search the web for live info (news, recent docs). Returns a " + "synthesised answer grounded in fresh page content plus " + "{title, url, snippet} citations — read the answer first " + "before reaching for web_fetch. Set deep=true when the user " + "asks for research / comparison / in-depth analysis; leave " + "deep=false for quick fact lookups. Prefer one targeted " + "query over many reformulations." ) @property @@ -58,6 +99,18 @@ class WebSearchTool(BaseTool): ), "default": _DEFAULT_MAX_RESULTS, }, + "deep": { + "type": "boolean", + "description": ( + "Only set true when the user EXPLICITLY asks for " + "research, comparison, or in-depth investigation " + "across many sources — it is ~100x more expensive " + "and much slower than a normal search. Default " + "false; do not flip it for ordinary fact lookups " + "or fresh-news questions." 
+ ), + "default": False, + }, }, "required": ["query"], } @@ -68,7 +121,7 @@ class WebSearchTool(BaseTool): @property def is_available(self) -> bool: - return bool(Settings().secrets.anthropic_api_key) + return bool(_chat_config.api_key and _chat_config.base_url) async def _execute( self, @@ -76,6 +129,7 @@ class WebSearchTool(BaseTool): session: ChatSession, query: str = "", max_results: int = _DEFAULT_MAX_RESULTS, + deep: bool = False, **kwargs: Any, ) -> ToolResponseBase: query = (query or "").strip() @@ -93,44 +147,35 @@ class WebSearchTool(BaseTool): max_results = _DEFAULT_MAX_RESULTS max_results = max(1, min(max_results, _HARD_MAX_RESULTS)) - api_key = Settings().secrets.anthropic_api_key - if not api_key: + if not _chat_config.api_key or not _chat_config.base_url: return ErrorResponse( message=( "Web search is unavailable — the deployment has no " - "Anthropic API key configured." + "OpenRouter credentials configured." ), error="web_search_not_configured", session_id=session_id, ) - client = AsyncAnthropic(api_key=api_key) + client = AsyncOpenAI( + api_key=_chat_config.api_key, base_url=_chat_config.base_url + ) + model_used = _DEEP_MODEL if deep else _QUICK_MODEL + max_tokens = _DEEP_MAX_TOKENS if deep else _QUICK_MAX_TOKENS + try: - resp = await client.messages.create( - model=_WEB_SEARCH_DISPATCH_MODEL, - max_tokens=_MAX_DISPATCH_TOKENS, - tools=[ - { - "type": "web_search_20250305", - "name": "web_search", - "max_uses": 1, - } - ], - messages=[ - { - "role": "user", - "content": ( - f"Use the web_search tool exactly once with the " - f"query {query!r} and then stop. Do not " - f"summarise — the caller parses the raw " - f"tool_result." - ), - } - ], + resp = await client.chat.completions.create( + model=model_used, + max_tokens=max_tokens, + messages=[{"role": "user", "content": query}], + extra_body=_OPENROUTER_INCLUDE_USAGE_COST, ) except Exception as exc: logger.warning( - "[web_search] Anthropic call failed for query=%r: %s", query, exc + "[web_search] OpenRouter call failed (deep=%s) for query=%r: %s", + deep, + query, + exc, ) return ErrorResponse( message=f"Web search failed: {exc}", @@ -138,20 +183,20 @@ class WebSearchTool(BaseTool): session_id=session_id, ) - results, search_requests = _extract_results(resp, limit=max_results) + answer = _extract_answer(resp) + results = _extract_results(resp, limit=max_results) + cost_usd = _extract_cost_usd(resp.usage) - cost_usd = _estimate_cost_usd(resp, search_requests=search_requests) try: - usage = getattr(resp, "usage", None) await persist_and_record_usage( session=session, user_id=user_id, - prompt_tokens=getattr(usage, "input_tokens", 0) or 0, - completion_tokens=getattr(usage, "output_tokens", 0) or 0, + prompt_tokens=resp.usage.prompt_tokens if resp.usage else 0, + completion_tokens=resp.usage.completion_tokens if resp.usage else 0, log_prefix="[web_search]", cost_usd=cost_usd, - model=_WEB_SEARCH_DISPATCH_MODEL, - provider="anthropic", + model=model_used, + provider="open_router", ) except Exception as exc: logger.warning("[web_search] usage tracking failed: %s", exc) @@ -159,66 +204,92 @@ class WebSearchTool(BaseTool): return WebSearchResponse( message=f"Found {len(results)} result(s) for {query!r}.", query=query, + answer=answer, results=results, - search_requests=search_requests, + search_requests=1 if results else 0, session_id=session_id, ) -def _extract_results(resp: Any, *, limit: int) -> tuple[list[WebSearchResult], int]: - """Pull results + server-side request count from an Anthropic response.""" - results: 
list[WebSearchResult] = [] - search_requests = 0 +def _extract_answer(resp: ChatCompletion) -> str: + """Return the synthesised answer text from Sonar's response. - for block in getattr(resp, "content", []) or []: - btype = getattr(block, "type", None) - if btype == "web_search_tool_result": - content = getattr(block, "content", []) or [] - for item in content: - if getattr(item, "type", None) != "web_search_result": - continue - if len(results) >= limit: - break - # Anthropic's ``web_search_result`` exposes only - # ``title``/``url``/``page_age`` plus an opaque - # ``encrypted_content`` blob that is meant for citation - # round-tripping, not for display — it is base64-ish - # binary and would show as gibberish if surfaced to the - # model or the frontend. There is no plain-text snippet - # field in the current beta; callers get the readable - # text via the model's ``text`` blocks with citations, - # not via this list. Leave ``snippet`` empty. - results.append( - WebSearchResult( - title=getattr(item, "title", "") or "", - url=getattr(item, "url", "") or "", - snippet="", - page_age=getattr(item, "page_age", None), - ) - ) - - usage = getattr(resp, "usage", None) - server_tool_use = getattr(usage, "server_tool_use", None) if usage else None - if server_tool_use is not None: - search_requests = getattr(server_tool_use, "web_search_requests", 0) or 0 - - return results, search_requests + Sonar reads every page it cites and writes a web-grounded synthesis + into ``choices[0].message.content`` on the same call we pay for. + Surfacing it saves the agent from re-fetching citation URLs — many + are bot-protected and ``web_fetch`` can't reach them. + """ + if not resp.choices: + return "" + content = resp.choices[0].message.content + return content or "" -# Update when Anthropic revises pricing. -_COST_PER_SEARCH_USD = 0.010 # $10 per 1,000 web_search requests -_HAIKU_INPUT_USD_PER_MTOK = 1.0 -_HAIKU_OUTPUT_USD_PER_MTOK = 5.0 +def _extract_results(resp: ChatCompletion, *, limit: int) -> list[WebSearchResult]: + """Pull ``url_citation`` annotations from the response. + + Shared across both tiers — OpenRouter normalises the annotation + schema across Perplexity's sonar models into + ``Annotation.url_citation`` (typed in ``openai.types.chat``). The + ``content`` snippet is an OpenRouter extension on the otherwise- + typed ``AnnotationURLCitation``; pydantic stashes unknown fields in + ``model_extra``, which we read there rather than via ``getattr``. + """ + if not resp.choices: + return [] + annotations = resp.choices[0].message.annotations or [] + out: list[WebSearchResult] = [] + for ann in annotations: + if len(out) >= limit: + break + if ann.type != "url_citation": + continue + citation = ann.url_citation + extras = citation.model_extra or {} + snippet_raw = extras.get("content") + snippet = (snippet_raw or "")[:_SNIPPET_MAX_CHARS] if snippet_raw else "" + out.append( + WebSearchResult( + title=citation.title, + url=citation.url, + snippet=snippet, + page_age=None, + ) + ) + return out -def _estimate_cost_usd(resp: Any, *, search_requests: int) -> float: - """Per-search fee × count + Haiku dispatch tokens.""" - usage = getattr(resp, "usage", None) - input_tokens = getattr(usage, "input_tokens", 0) if usage else 0 - output_tokens = getattr(usage, "output_tokens", 0) if usage else 0 +def _extract_cost_usd(usage: CompletionUsage | None) -> float | None: + """Return the provider-reported USD cost off the response usage. 
- search_cost = search_requests * _COST_PER_SEARCH_USD - inference_cost = (input_tokens / 1_000_000) * _HAIKU_INPUT_USD_PER_MTOK + ( - output_tokens / 1_000_000 - ) * _HAIKU_OUTPUT_USD_PER_MTOK - return round(search_cost + inference_cost, 6) + OpenRouter piggybacks a ``cost`` field on the OpenAI-compatible + usage object when the request body includes + ``usage: {"include": True}``. The OpenAI SDK's typed + ``CompletionUsage`` does not declare it, so we read it off + ``model_extra`` (the pydantic v2 container for extras) to keep + access fully typed — no ``getattr``. Mirrors the baseline service + ``_extract_usage_cost``; keep the two in sync. + + Returns ``None`` when the field is absent, null, non-numeric, + non-finite, or negative. Invalid values log at error level because + they indicate a provider bug worth chasing; plain absences are + silent so the caller can dedupe the "missing cost" warning. + """ + if usage is None: + return None + extras = usage.model_extra or {} + if "cost" not in extras: + return None + raw = extras["cost"] + if raw is None: + logger.error("[web_search] usage.cost is present but null") + return None + try: + val = float(raw) + except (TypeError, ValueError): + logger.error("[web_search] usage.cost is not numeric: %r", raw) + return None + if not math.isfinite(val) or val < 0: + logger.error("[web_search] usage.cost is non-finite or negative: %r", val) + return None + return val diff --git a/autogpt_platform/backend/backend/copilot/tools/web_search_test.py b/autogpt_platform/backend/backend/copilot/tools/web_search_test.py index 3d516f295a..7b341e3c44 100644 --- a/autogpt_platform/backend/backend/copilot/tools/web_search_test.py +++ b/autogpt_platform/backend/backend/copilot/tools/web_search_test.py @@ -1,212 +1,289 @@ """Tests for the ``web_search`` copilot tool. -Covers the result extractor + cost estimator as pure units (fed with -synthetic Anthropic response objects), plus light integration tests that -mock ``AsyncAnthropic.messages.create`` and confirm the handler plumbs -through to ``persist_and_record_usage`` with the right provider tag. +Covers the annotation extractor + cost extractor as pure units (fed +with real ``openai`` SDK types — no duck-typed ``SimpleNamespace`` +stand-ins), plus integration tests exercising both the quick +(``perplexity/sonar``) and deep (``perplexity/sonar-deep-research``) +paths — mocking ``AsyncOpenAI.chat.completions.create`` and confirming +the handler plumbs through to ``persist_and_record_usage`` with +``provider='open_router'`` and the real ``usage.cost`` value. """ -from types import SimpleNamespace +from typing import Any from unittest.mock import AsyncMock, patch import pytest +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ( + Annotation, + AnnotationURLCitation, + ChatCompletionMessage, +) from backend.copilot.model import ChatSession -from .models import ErrorResponse, WebSearchResponse, WebSearchResult +from .models import ErrorResponse, WebSearchResponse from .web_search import ( - _COST_PER_SEARCH_USD, WebSearchTool, - _estimate_cost_usd, + _extract_answer, + _extract_cost_usd, _extract_results, ) -def _fake_anthropic_response( +def _usage( *, - results: list[dict] | None = None, - search_requests: int = 1, - input_tokens: int = 120, - output_tokens: int = 40, -) -> SimpleNamespace: - """Build a synthetic Anthropic Messages response. 
+ prompt_tokens: int = 120, + completion_tokens: int = 40, + cost: object = 0.01, +) -> CompletionUsage: + """Typed ``CompletionUsage`` with OpenRouter's ``cost`` extension + parked in ``model_extra`` — the same channel the production code + reads it from. ``model_construct`` preserves unknown fields; + ``model_validate`` would drop them because ``CompletionUsage`` + treats the schema as strict.""" + payload: dict[str, Any] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + if cost is not None: + payload["cost"] = cost + return CompletionUsage.model_construct(None, **payload) - Matches the shape produced by ``client.messages.create`` when the - response includes a ``web_search_tool_result`` content block and - ``usage.server_tool_use.web_search_requests`` on the turn meter. - """ - content = [] - if results is not None: - content.append( - SimpleNamespace( - type="web_search_tool_result", - content=[ - SimpleNamespace( - type="web_search_result", - title=r.get("title", "untitled"), - url=r.get("url", ""), - encrypted_content=r.get("snippet", ""), - page_age=r.get("page_age"), - ) - for r in results - ], - ) + +def _citation(*, url: str, title: str, content: str | None = None) -> Annotation: + """Typed ``Annotation`` for a URL citation. ``content`` is an + OpenRouter extension on the otherwise-typed schema — goes into + ``url_citation.model_extra`` when model_construct preserves it.""" + payload: dict[str, Any] = { + "url": url, + "title": title, + "start_index": 0, + "end_index": len(title), + } + if content is not None: + payload["content"] = content + url_citation = AnnotationURLCitation.model_construct(None, **payload) + return Annotation(type="url_citation", url_citation=url_citation) + + +def _fake_response( + *, + citations: list[dict] | None = None, + answer: str = "ok", + prompt_tokens: int = 120, + completion_tokens: int = 40, + cost: object = 0.01, +) -> ChatCompletion: + """Build a typed ``ChatCompletion`` shaped like an OpenRouter + response — typed end-to-end so the production code's attribute + access runs under the real SDK types in tests.""" + annotations = [ + _citation( + url=c.get("url", ""), + title=c.get("title", "untitled"), + content=c.get("content"), ) - usage = SimpleNamespace( - input_tokens=input_tokens, - output_tokens=output_tokens, - server_tool_use=SimpleNamespace(web_search_requests=search_requests), + for c in citations or [] + ] + message = ChatCompletionMessage.model_construct( + None, + role="assistant", + content=answer, + annotations=annotations, + ) + choice = Choice.model_construct( + None, + index=0, + finish_reason="stop", + message=message, + ) + return ChatCompletion.model_construct( + None, + id="cmpl-test", + object="chat.completion", + created=0, + model="perplexity/sonar", + choices=[choice], + usage=_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cost=cost, + ), ) - return SimpleNamespace(content=content, usage=usage) class TestExtractResults: - """The extractor is the only Anthropic-response-shape contact point; - pin its behaviour so an API shape change surfaces here first.""" + """Pin the annotation shape — a schema bump in the OpenAI SDK or + OpenRouter surfaces here first. 
Same extractor serves both tiers + because OpenRouter normalises annotations across models.""" - def test_extracts_title_url_page_age_and_drops_encrypted_snippet(self): - # Anthropic's ``web_search_result`` ships an opaque - # ``encrypted_content`` blob that is not safe to surface — - # the extractor must drop it (snippet=="") regardless of - # whether the blob is non-empty. - resp = _fake_anthropic_response( - results=[ + def test_extracts_title_url_and_content_snippet(self): + resp = _fake_response( + citations=[ { "title": "Kimi K2.6 launch", "url": "https://example.com/kimi", - "snippet": "EiJjbGF1ZGUtZW5jcnlwdGVkLWJsb2I=", - "page_age": "1 day", + "content": "Moonshot released K2.6 on 2026-04-20.", }, { "title": "OpenRouter pricing", "url": "https://openrouter.ai/moonshotai/kimi-k2.6", - "snippet": "", }, ] ) - out, requests = _extract_results(resp, limit=10) - assert requests == 1 + out = _extract_results(resp, limit=10) assert len(out) == 2 assert out[0].title == "Kimi K2.6 launch" assert out[0].url == "https://example.com/kimi" - assert out[0].snippet == "" - assert out[0].page_age == "1 day" + assert out[0].snippet.startswith("Moonshot released") + # Missing ``content`` extension → empty snippet rather than crash. assert out[1].snippet == "" def test_limit_caps_returned_results(self): - resp = _fake_anthropic_response( - results=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)] + resp = _fake_response( + citations=[{"title": f"r{i}", "url": f"https://e/{i}"} for i in range(10)] ) - out, _ = _extract_results(resp, limit=3) + out = _extract_results(resp, limit=3) assert len(out) == 3 assert [r.title for r in out] == ["r0", "r1", "r2"] - def test_missing_content_returns_empty(self): - resp = SimpleNamespace(content=[], usage=None) - out, requests = _extract_results(resp, limit=10) - assert out == [] - assert requests == 0 - - def test_non_search_blocks_are_ignored(self): - resp = SimpleNamespace( - content=[ - SimpleNamespace(type="text", text="Here's what I found..."), - SimpleNamespace( - type="web_search_tool_result", - content=[ - SimpleNamespace( - type="web_search_result", - title="real", - url="https://real.example", - encrypted_content="body", - page_age=None, - ) - ], - ), - ], - usage=None, + def test_missing_choices_returns_empty(self): + resp = ChatCompletion.model_construct( + None, + id="cmpl-test", + object="chat.completion", + created=0, + model="perplexity/sonar", + choices=[], + usage=_usage(), ) - out, _ = _extract_results(resp, limit=10) - assert len(out) == 1 and out[0].title == "real" + assert _extract_results(resp, limit=10) == [] - -class TestEstimateCostUsd: - """Pin the per-search fee + Haiku inference math — the pricing - constants in ``web_search.py`` are hard-coded (no live lookup) so a - drift between Anthropic's schedule and our constants must surface - in this test for the next reader to notice.""" - - def test_zero_searches_still_charges_inference(self): - resp = _fake_anthropic_response(results=[], search_requests=0) - cost = _estimate_cost_usd(resp, search_requests=0) - # Haiku at 1000 input / 5000 output tokens = tiny but non-zero. 
- assert 0 < cost < 0.001 - - def test_single_search_fee_dominates(self): - resp = _fake_anthropic_response( - results=[{"title": "x", "url": "https://e"}], - search_requests=1, - input_tokens=100, - output_tokens=20, + def test_extract_answer_returns_message_content(self): + resp = _fake_response( + answer="Sonar's synthesised, web-grounded answer text.", + citations=[{"title": "t", "url": "https://e"}], ) - cost = _estimate_cost_usd(resp, search_requests=1) - # ~$0.010 search + trivial inference — total still ~1 cent. - assert cost >= _COST_PER_SEARCH_USD - assert cost < _COST_PER_SEARCH_USD + 0.001 + assert _extract_answer(resp) == "Sonar's synthesised, web-grounded answer text." - def test_three_searches_linear_in_count(self): - resp = _fake_anthropic_response( - results=[], search_requests=3, input_tokens=0, output_tokens=0 + def test_extract_answer_returns_empty_when_no_choices(self): + resp = ChatCompletion.model_construct( + None, + id="cmpl-test", + object="chat.completion", + created=0, + model="perplexity/sonar", + choices=[], + usage=_usage(), ) - cost = _estimate_cost_usd(resp, search_requests=3) - assert cost == pytest.approx(3 * _COST_PER_SEARCH_USD) + assert _extract_answer(resp) == "" + + def test_snippet_clamped_to_max_chars(self): + long_body = "x" * 5000 + resp = _fake_response( + citations=[{"title": "t", "url": "https://e", "content": long_body}] + ) + out = _extract_results(resp, limit=1) + assert len(out) == 1 + assert len(out[0].snippet) == 500 + + +class TestExtractCostUsd: + """Read real ``usage.cost`` via typed ``model_extra`` — no + hard-coded rates, so a future provider price change is reflected + automatically. Error handling mirrors the baseline service's + ``_extract_usage_cost``.""" + + def test_returns_cost_value(self): + assert _extract_cost_usd(_usage(cost=0.023456)) == pytest.approx(0.023456) + + def test_returns_none_when_usage_missing(self): + assert _extract_cost_usd(None) is None + + def test_returns_none_when_cost_field_missing(self): + assert _extract_cost_usd(_usage(cost=None)) is None + + def test_returns_none_when_cost_is_explicit_null(self): + usage = CompletionUsage.model_construct( + None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None + ) + assert _extract_cost_usd(usage) is None + + def test_returns_none_when_cost_is_negative(self): + usage = CompletionUsage.model_construct( + None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-1.0 + ) + assert _extract_cost_usd(usage) is None + + def test_accepts_numeric_string(self): + usage = CompletionUsage.model_construct( + None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017" + ) + assert _extract_cost_usd(usage) == pytest.approx(0.017) class TestWebSearchToolDispatch: - """Lightweight integration test: mock the Anthropic client, confirm - the handler returns a ``WebSearchResponse`` and the usage tracker is - called with ``provider='anthropic'`` (not 'open_router', even on the - baseline path — server-side web_search bills Anthropic regardless of - the calling LLM's route).""" + """Integration test: mock the OpenAI client, confirm both paths + dispatch the right Sonar model + track cost.""" def _session(self) -> ChatSession: s = ChatSession.new("test-user", dry_run=False) s.session_id = "sess-1" return s - @pytest.mark.asyncio - async def test_returns_response_with_results_and_tracks_cost(self, monkeypatch): - fake_resp = _fake_anthropic_response( - results=[ - { - "title": "hello", - "url": "https://example.com", - "snippet": "greeting", - } - ], - 
search_requests=1, - ) - mock_client = type( + def _mock_client(self, fake_resp: ChatCompletion) -> Any: + return type( "MC", (), { - "messages": type( - "M", (), {"create": AsyncMock(return_value=fake_resp)} + "chat": type( + "C", + (), + { + "completions": type( + "CC", + (), + {"create": AsyncMock(return_value=fake_resp)}, + )() + }, )() }, )() - # Stub the Anthropic API key so ``is_available`` is True. + @pytest.mark.asyncio + async def test_quick_path_uses_sonar_base(self, monkeypatch): + fake_resp = _fake_response( + citations=[ + { + "title": "hello", + "url": "https://example.com", + "content": "greeting", + } + ], + answer="Kimi K2.6 launched 2026-04-20 [1].", + cost=0.01, + ) + mock_client = self._mock_client(fake_resp) + monkeypatch.setattr( - "backend.copilot.tools.web_search.Settings", - lambda: SimpleNamespace( - secrets=SimpleNamespace(anthropic_api_key="sk-test") - ), + "backend.copilot.tools.web_search._chat_config", + type( + "C", + (), + { + "api_key": "sk-test", + "base_url": "https://openrouter.ai/api/v1", + }, + )(), ) with ( patch( - "backend.copilot.tools.web_search.AsyncAnthropic", + "backend.copilot.tools.web_search.AsyncOpenAI", return_value=mock_client, ), patch( @@ -220,35 +297,88 @@ class TestWebSearchToolDispatch: session=self._session(), query="kimi k2.6 launch", max_results=5, + deep=False, ) assert isinstance(result, WebSearchResponse) - assert result.query == "kimi k2.6 launch" + assert result.answer == "Kimi K2.6 launched 2026-04-20 [1]." assert len(result.results) == 1 - assert isinstance(result.results[0], WebSearchResult) - assert result.search_requests == 1 + assert result.results[0].snippet == "greeting" + + create_call = mock_client.chat.completions.create.call_args + assert create_call.kwargs["model"] == "perplexity/sonar" + # Sonar searches natively — no server-tool extras. + assert create_call.kwargs["extra_body"] == {"usage": {"include": True}} - # Cost tracker must have been called with provider="anthropic". 
- assert mock_track.await_count == 1 kwargs = mock_track.await_args.kwargs - assert kwargs["provider"] == "anthropic" - assert kwargs["model"] == "claude-haiku-4-5" - assert kwargs["user_id"] == "u1" - assert kwargs["cost_usd"] >= _COST_PER_SEARCH_USD + assert kwargs["provider"] == "open_router" + assert kwargs["model"] == "perplexity/sonar" + assert kwargs["cost_usd"] == pytest.approx(0.01) @pytest.mark.asyncio - async def test_missing_api_key_returns_error_without_calling_anthropic( - self, monkeypatch - ): - monkeypatch.setattr( - "backend.copilot.tools.web_search.Settings", - lambda: SimpleNamespace(secrets=SimpleNamespace(anthropic_api_key="")), + async def test_deep_path_uses_sonar_deep_research(self, monkeypatch): + fake_resp = _fake_response( + citations=[ + { + "title": "deep find", + "url": "https://example.com/deep", + "content": "research body", + } + ], + cost=0.087, ) - anthropic_stub = AsyncMock() + mock_client = self._mock_client(fake_resp) + + monkeypatch.setattr( + "backend.copilot.tools.web_search._chat_config", + type( + "C", + (), + { + "api_key": "sk-test", + "base_url": "https://openrouter.ai/api/v1", + }, + )(), + ) + with ( patch( - "backend.copilot.tools.web_search.AsyncAnthropic", - return_value=anthropic_stub, + "backend.copilot.tools.web_search.AsyncOpenAI", + return_value=mock_client, + ), + patch( + "backend.copilot.tools.web_search.persist_and_record_usage", + new=AsyncMock(return_value=160), + ) as mock_track, + ): + tool = WebSearchTool() + result = await tool._execute( + user_id="u1", + session=self._session(), + query="research question", + deep=True, + ) + + assert isinstance(result, WebSearchResponse) + create_call = mock_client.chat.completions.create.call_args + assert create_call.kwargs["model"] == "perplexity/sonar-deep-research" + + kwargs = mock_track.await_args.kwargs + assert kwargs["provider"] == "open_router" + assert kwargs["model"] == "perplexity/sonar-deep-research" + assert kwargs["cost_usd"] == pytest.approx(0.087) + + @pytest.mark.asyncio + async def test_missing_credentials_returns_error(self, monkeypatch): + monkeypatch.setattr( + "backend.copilot.tools.web_search._chat_config", + type("C", (), {"api_key": "", "base_url": ""})(), + ) + openai_stub = AsyncMock() + with ( + patch( + "backend.copilot.tools.web_search.AsyncOpenAI", + return_value=openai_stub, ), patch( "backend.copilot.tools.web_search.persist_and_record_usage", @@ -264,21 +394,26 @@ class TestWebSearchToolDispatch: ) assert isinstance(result, ErrorResponse) assert result.error == "web_search_not_configured" - anthropic_stub.messages.create.assert_not_called() + openai_stub.chat.completions.create.assert_not_called() mock_track.assert_not_called() @pytest.mark.asyncio async def test_empty_query_rejected_without_api_call(self, monkeypatch): monkeypatch.setattr( - "backend.copilot.tools.web_search.Settings", - lambda: SimpleNamespace( - secrets=SimpleNamespace(anthropic_api_key="sk-test") - ), + "backend.copilot.tools.web_search._chat_config", + type( + "C", + (), + { + "api_key": "sk-test", + "base_url": "https://openrouter.ai/api/v1", + }, + )(), ) - anthropic_stub = AsyncMock() + openai_stub = AsyncMock() with patch( - "backend.copilot.tools.web_search.AsyncAnthropic", - return_value=anthropic_stub, + "backend.copilot.tools.web_search.AsyncOpenAI", + return_value=openai_stub, ): tool = WebSearchTool() result = await tool._execute( @@ -286,13 +421,13 @@ class TestWebSearchToolDispatch: ) assert isinstance(result, ErrorResponse) assert result.error == "missing_query" - 
anthropic_stub.messages.create.assert_not_called() + openai_stub.chat.completions.create.assert_not_called() class TestToolRegistryIntegration: """The tool must be registered under the ``web_search`` name so the MCP layer exposes it as ``mcp__copilot__web_search`` — which is - what the SDK path now dispatches to (see + what the SDK path dispatches to (see ``sdk/tool_adapter.py::SDK_DISALLOWED_TOOLS`` which blocks the CLI's native ``WebSearch`` in favour of the MCP route).""" diff --git a/autogpt_platform/backend/backend/copilot/transcript.py b/autogpt_platform/backend/backend/copilot/transcript.py index 5a46760dfd..468a02f796 100644 --- a/autogpt_platform/backend/backend/copilot/transcript.py +++ b/autogpt_platform/backend/backend/copilot/transcript.py @@ -195,16 +195,17 @@ def strip_stale_thinking_blocks(content: str) -> str: is_last_turn = ( last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id ) or (last_asst_msg_id is None and i == last_asst_idx) - if ( - msg.get("role") == "assistant" - and not is_last_turn - and isinstance(msg.get("content"), list) - ): + if msg.get("role") == "assistant" and isinstance(msg.get("content"), list): content_blocks = msg["content"] + producing_model = msg.get("model") if isinstance(msg, dict) else None filtered = [ b for b in content_blocks - if not (isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES) + if not _should_strip_thinking_block( + b, + is_last_turn=is_last_turn, + producing_model=producing_model, + ) ] if len(filtered) < len(content_blocks): stripped_count += len(content_blocks) - len(filtered) @@ -310,23 +311,30 @@ def strip_for_upload(content: str) -> str: if uid in reparented: needs_reserialize = True - # Strip stale thinking blocks from non-last assistant entries + # Strip stale thinking blocks from non-last assistant entries. + # Also strip *signature-less* thinking blocks from the last entry — + # those come from non-Anthropic providers (e.g. Kimi K2.6 via + # OpenRouter) and are rejected with ``Invalid `signature` in + # `thinking` block`` if a subsequent turn is dispatched to an + # Anthropic model that re-validates them. Anthropic-emitted + # thinking blocks always carry a non-empty ``signature`` field, so + # this filter is a no-op on Sonnet/Opus turns and only kicks in + # when the prior turn ran on a non-Anthropic vendor. if last_asst_idx is not None: msg = entry.get("message", {}) is_last_turn = ( last_asst_msg_id is not None and msg.get("id") == last_asst_msg_id ) or (last_asst_msg_id is None and i == last_asst_idx) - if ( - msg.get("role") == "assistant" - and not is_last_turn - and isinstance(msg.get("content"), list) - ): + if msg.get("role") == "assistant" and isinstance(msg.get("content"), list): content_blocks = msg["content"] + producing_model = msg.get("model") if isinstance(msg, dict) else None filtered = [ b for b in content_blocks - if not ( - isinstance(b, dict) and b.get("type") in _THINKING_BLOCK_TYPES + if not _should_strip_thinking_block( + b, + is_last_turn=is_last_turn, + producing_model=producing_model, ) ] if len(filtered) < len(content_blocks): @@ -951,6 +959,92 @@ ENTRY_TYPE_MESSAGE = "message" _THINKING_BLOCK_TYPES = frozenset({"thinking", "redacted_thinking"}) +def _is_anthropic_model(model: str | None) -> bool: + """True when *model* is an Anthropic-issued slug. + + Used to decide whether a thinking block's signature is + cryptographically valid for Anthropic replay. 
Non-Anthropic vendors + routed through OpenRouter's Anthropic-compat shim (Kimi K2.6, + DeepSeek, GPT-OSS) sometimes emit thinking blocks with a + placeholder signature — it passes a non-empty string check but + fails Anthropic's cryptographic validation, producing the opaque + ``Invalid signature in thinking block`` 400 on the next turn + whenever the model toggle switches to Sonnet/Opus. + """ + return isinstance(model, str) and model.startswith("anthropic/") + + +def _should_strip_thinking_block( + block: object, + *, + is_last_turn: bool, + producing_model: str | None = None, +) -> bool: + """Return True when *block* is a thinking block that should be removed + from a transcript entry before upload. + + Strip only when the block CAN'T be replayed safely. Never strip a + valid Anthropic-issued thinking block — it carries real reasoning + state that preserves context continuity on ``--resume``. + + Strip rules (first match wins): + + 1. **Non-Anthropic producer (any position)** — thinking blocks from + Kimi / DeepSeek / GPT-OSS via OpenRouter's Anthropic-compat shim + carry either no signature or a placeholder string that passes a + non-empty check but fails Anthropic's cryptographic validation. + Strip unconditionally; they also add low-value tokens to the + replay context. + 2. **Malformed ``thinking`` (any position, Anthropic producer, + empty signature)** — shouldn't happen in practice, but if the + signature is missing / empty the block can't be validated. + Safer to drop than to 400 the next turn. + 3. **Stale non-last entry with unknown producer** — when the + caller doesn't wire ``producing_model`` through (legacy paths / + older tests) we can't tell if the block is safe to keep; fall + back to the old behaviour of dropping non-last thinking blocks + to avoid replaying an unverifiable block to Anthropic. + + Preserved: + + * Anthropic ``thinking`` with non-empty signature — at any + position, last OR non-last. Keeping prior-turn reasoning + chains helps continuity on multi-round SDK resumes without any + risk of signature rejection. + * Anthropic ``redacted_thinking`` — carries an encrypted ``data`` + payload instead of a ``signature``; by design signature-less, + but Anthropic-issued and safely replayable. + """ + if not isinstance(block, dict): + return False + btype = block.get("type") + if btype not in _THINKING_BLOCK_TYPES: + return False + # Legacy call sites pass producing_model=None — preserve the old + # "strip-all-non-last-thinking" heuristic for those so we don't + # regress callers that haven't been updated. + if producing_model is None: + if not is_last_turn: + return True + if btype != "thinking": + return False + signature = block.get("signature") + return not (isinstance(signature, str) and signature) + # Non-Anthropic producer — strip at any position. These blocks + # CAN'T be cryptographically validated by Anthropic on replay. + if not _is_anthropic_model(producing_model): + return True + # Anthropic producer, redacted_thinking: always preserve — the + # ``data`` field is the signature analog. + if btype == "redacted_thinking": + return False + # Anthropic producer, ``thinking``: keep iff it has a real + # (non-empty) signature. Empty-signature Anthropic thinking + # shouldn't happen but guard against it anyway. + signature = block.get("signature") + return not (isinstance(signature, str) and signature) + + def _flatten_assistant_content(blocks: list) -> str: """Flatten assistant content blocks into a single plain-text string. 
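The strip rules above are easiest to see by example. A minimal sketch of what `_should_strip_thinking_block` keeps and drops, assuming the signature shown in the hunk above (the block dicts are illustrative; the model slugs match the test fixtures below; importing the private helper directly is for illustration only):

```python
# Illustrative inputs only: exercising the rules documented in the docstring above.
from backend.copilot.transcript import _should_strip_thinking_block

ANTHROPIC = "anthropic/claude-4.7-opus-20260416"   # Anthropic-issued slug
KIMI = "moonshotai/kimi-k2.6-20260420"             # non-Anthropic producer via OpenRouter

signed = {"type": "thinking", "thinking": "r", "signature": "REAL_ANTHROPIC_SIG"}
placeholder = {"type": "thinking", "thinking": "r", "signature": "PLACEHOLDER_SHIM_SIG"}
redacted = {"type": "redacted_thinking", "data": "encrypted-blob"}

# Rule 1: non-Anthropic producer is stripped at any position, even with a signature present.
assert _should_strip_thinking_block(placeholder, is_last_turn=True, producing_model=KIMI)

# Anthropic thinking with a real signature survives on last AND non-last turns.
assert not _should_strip_thinking_block(signed, is_last_turn=False, producing_model=ANTHROPIC)

# Anthropic redacted_thinking is kept despite carrying no signature field.
assert not _should_strip_thinking_block(redacted, is_last_turn=True, producing_model=ANTHROPIC)

# Rule 3: legacy callers that pass no producing_model fall back to stripping
# every non-last thinking block, signature or not.
assert _should_strip_thinking_block(signed, is_last_turn=False)
```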
diff --git a/autogpt_platform/backend/backend/copilot/transcript_test.py b/autogpt_platform/backend/backend/copilot/transcript_test.py index dde07a063e..96c7b3fc70 100644 --- a/autogpt_platform/backend/backend/copilot/transcript_test.py +++ b/autogpt_platform/backend/backend/copilot/transcript_test.py @@ -591,7 +591,16 @@ class TestStripForUpload: "role": "assistant", "id": "msg_new", "content": [ - {"type": "thinking", "thinking": "fresh thinking"}, + # Anthropic-style thinking block — has a signature so + # ``_should_strip_thinking_block`` preserves it on the + # last turn. Without the signature (e.g. emitted by + # Kimi K2.6 via OpenRouter) it would be stripped — see + # ``test_strips_signatureless_thinking_from_last_turn``. + { + "type": "thinking", + "thinking": "fresh thinking", + "signature": "anthropic-signed-blob", + }, {"type": "text", "text": "new answer"}, ], }, @@ -624,6 +633,224 @@ class TestStripForUpload: new_types = [b["type"] for b in new_content if isinstance(b, dict)] assert "thinking" in new_types # last assistant preserved + def test_strips_signatureless_thinking_from_last_turn(self): + """Kimi K2.6 (and other non-Anthropic OpenRouter providers) emit + thinking blocks without the Anthropic ``signature`` field. When + a subsequent advanced-tier toggle replays the transcript to Opus, + Anthropic's API rejects the signature-less block with ``Invalid + `signature` in `thinking` block`` — so strip_for_upload must drop + them from the LAST assistant entry too, not just stale ones.""" + user = { + "type": "user", + "uuid": "u1", + "parentUuid": "", + "message": {"role": "user", "content": "hi"}, + } + # Last (and only) assistant entry with a Kimi-shape thinking block + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": { + "role": "assistant", + "id": "msg_kimi", + "content": [ + # No ``signature`` field → non-Anthropic provider + {"type": "thinking", "thinking": "kimi reasoning"}, + {"type": "text", "text": "answer"}, + ], + }, + } + content = _make_jsonl(user, asst) + result = strip_for_upload(content) + entries = [json.loads(line) for line in result.strip().split("\n")] + asst_entry = next( + e for e in entries if e.get("message", {}).get("id") == "msg_kimi" + ) + types = [ + b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict) + ] + assert "thinking" not in types, ( + "Signature-less thinking block on last turn must be stripped " + "to prevent Anthropic API rejection on model-switch replay" + ) + assert "text" in types, "Text content must survive stripping" + + def test_strips_non_anthropic_thinking_with_placeholder_signature(self): + """OpenRouter's Anthropic-compat shim can emit thinking blocks + from non-Anthropic producers (Kimi K2.6, DeepSeek) with a + PLACEHOLDER signature string that passes the "non-empty string" + check but fails Anthropic's cryptographic validation on replay. + + Observed in session 864a55ba after model-toggle from standard + (Kimi) to advanced (Opus): the CLI session upload included a + thinking block with ``signature="ANTHROPIC_SHIM_PLACEHOLDER"`` + (or similar), Opus 4.7 rejected with 400 ``Invalid `signature` + in `thinking` block``. 
Fix: strip thinking blocks from the + LAST assistant turn whenever the producing ``model`` isn't an + ``anthropic/*`` slug, regardless of signature presence.""" + user = { + "type": "user", + "uuid": "u1", + "parentUuid": "", + "message": {"role": "user", "content": "hi"}, + } + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": { + "role": "assistant", + "id": "msg_kimi_shim", + "model": "moonshotai/kimi-k2.6-20260420", + "content": [ + # Placeholder signature — non-empty but cryptographically + # invalid for Anthropic. Legacy strip (signature-only) + # would KEEP this block. + { + "type": "thinking", + "thinking": "shimmed reasoning", + "signature": "PLACEHOLDER_SHIM_SIG_abc123", + }, + {"type": "text", "text": "answer"}, + ], + }, + } + content = _make_jsonl(user, asst) + result = strip_for_upload(content) + entries = [json.loads(line) for line in result.strip().split("\n")] + asst_entry = next( + e for e in entries if e.get("message", {}).get("id") == "msg_kimi_shim" + ) + types = [ + b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict) + ] + assert "thinking" not in types, ( + "Non-Anthropic thinking block must be stripped even when it " + "carries a placeholder signature — replay-to-Opus otherwise " + "400s with Invalid signature" + ) + assert "text" in types + + def test_preserves_anthropic_thinking_on_non_last_turn(self): + """Anthropic ``thinking`` blocks on NON-last turns carry real + reasoning state that helps context continuity on ``--resume``. + Keep them when the producing model is known-Anthropic with a + valid signature; strip only when we can't validate safely + (legacy callers with no model info — falls through to the + old stale-strip rule). + """ + user = { + "type": "user", + "uuid": "u1", + "parentUuid": "", + "message": {"role": "user", "content": "first"}, + } + asst1 = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": { + "role": "assistant", + "id": "msg_opus_prev", + "model": "anthropic/claude-4.7-opus-20260416", + "content": [ + { + "type": "thinking", + "thinking": "first-turn reasoning", + "signature": "ANTHROPIC_SIG_1", + }, + {"type": "text", "text": "first answer"}, + ], + }, + } + user2 = { + "type": "user", + "uuid": "u2", + "parentUuid": "a1", + "message": {"role": "user", "content": "second"}, + } + asst2 = { + "type": "assistant", + "uuid": "a2", + "parentUuid": "u2", + "message": { + "role": "assistant", + "id": "msg_opus_last", + "model": "anthropic/claude-4.7-opus-20260416", + "content": [ + { + "type": "thinking", + "thinking": "last-turn reasoning", + "signature": "ANTHROPIC_SIG_2", + }, + {"type": "text", "text": "last answer"}, + ], + }, + } + content = _make_jsonl(user, asst1, user2, asst2) + result = strip_for_upload(content) + entries = [json.loads(line) for line in result.strip().split("\n")] + + # Prior Opus turn's thinking must survive — valid Anthropic + # block with signature. + prev = next( + e for e in entries if e.get("message", {}).get("id") == "msg_opus_prev" + ) + prev_types = [b["type"] for b in prev["message"]["content"]] + assert "thinking" in prev_types, ( + "Anthropic thinking block on a non-last turn must be " + "preserved — it carries real reasoning state" + ) + # Last turn's thinking also preserved. 
+ last = next( + e for e in entries if e.get("message", {}).get("id") == "msg_opus_last" + ) + last_types = [b["type"] for b in last["message"]["content"]] + assert "thinking" in last_types + + def test_preserves_anthropic_thinking_with_valid_signature(self): + """Sanity: an Anthropic-issued thinking block with a real + signature on the last turn must NOT be stripped — Anthropic + requires value-identity on replay.""" + user = { + "type": "user", + "uuid": "u1", + "parentUuid": "", + "message": {"role": "user", "content": "hi"}, + } + asst = { + "type": "assistant", + "uuid": "a1", + "parentUuid": "u1", + "message": { + "role": "assistant", + "id": "msg_opus", + "model": "anthropic/claude-4.7-opus-20260416", + "content": [ + { + "type": "thinking", + "thinking": "reasoning", + "signature": "REAL_ANTHROPIC_SIG_blob", + }, + {"type": "text", "text": "answer"}, + ], + }, + } + content = _make_jsonl(user, asst) + result = strip_for_upload(content) + entries = [json.loads(line) for line in result.strip().split("\n")] + asst_entry = next( + e for e in entries if e.get("message", {}).get("id") == "msg_opus" + ) + types = [ + b["type"] for b in asst_entry["message"]["content"] if isinstance(b, dict) + ] + assert ( + "thinking" in types + ), "Anthropic-signed thinking on last turn must survive strip" + assert "text" in types + def test_empty_content(self): result = strip_for_upload("") # Empty string produces a single empty line after split, resulting in "\n" diff --git a/autogpt_platform/backend/backend/data/redis_client.py b/autogpt_platform/backend/backend/data/redis_client.py index f7d030c62b..e3675370e5 100644 --- a/autogpt_platform/backend/backend/data/redis_client.py +++ b/autogpt_platform/backend/backend/data/redis_client.py @@ -14,6 +14,21 @@ HOST = os.getenv("REDIS_HOST", "localhost") PORT = int(os.getenv("REDIS_PORT", "6379")) PASSWORD = os.getenv("REDIS_PASSWORD", None) +# Default socket timeouts so a wedged Redis endpoint can't hang callers +# indefinitely — long-running code paths (cluster_lock refresh in particular) +# rely on these to fail-fast instead of blocking on no-response TCP. Override +# via env if a specific deployment needs a different budget. +# +# 30s matches the convention in ``backend.data.rabbitmq`` and leaves ~6x +# headroom over the largest ``xread(block=5000)`` wait in stream_registry. +# The connect timeout is shorter (5s) because initial connects should be +# fast; a slow connect usually means the endpoint is genuinely unreachable. +SOCKET_TIMEOUT = float(os.getenv("REDIS_SOCKET_TIMEOUT", "30")) +SOCKET_CONNECT_TIMEOUT = float(os.getenv("REDIS_SOCKET_CONNECT_TIMEOUT", "5")) +# How often redis-py sends a PING on idle connections to detect half-open +# sockets; cheap and avoids waiting for the OS TCP keepalive (~2h default). 
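+# (redis-py issues this PING lazily, when a connection is next used after sitting
+# idle longer than the interval, rather than on a background timer.)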
+HEALTH_CHECK_INTERVAL = int(os.getenv("REDIS_HEALTH_CHECK_INTERVAL", "30")) + logger = logging.getLogger(__name__) @@ -24,6 +39,10 @@ def connect() -> Redis: port=PORT, password=PASSWORD, decode_responses=True, + socket_timeout=SOCKET_TIMEOUT, + socket_connect_timeout=SOCKET_CONNECT_TIMEOUT, + socket_keepalive=True, + health_check_interval=HEALTH_CHECK_INTERVAL, ) c.ping() return c @@ -46,6 +65,10 @@ async def connect_async() -> AsyncRedis: port=PORT, password=PASSWORD, decode_responses=True, + socket_timeout=SOCKET_TIMEOUT, + socket_connect_timeout=SOCKET_CONNECT_TIMEOUT, + socket_keepalive=True, + health_check_interval=HEALTH_CHECK_INTERVAL, ) await c.ping() return c diff --git a/autogpt_platform/backend/backend/executor/manager.py b/autogpt_platform/backend/backend/executor/manager.py index 87ee3cbc44..0cf0ea0936 100644 --- a/autogpt_platform/backend/backend/executor/manager.py +++ b/autogpt_platform/backend/backend/executor/manager.py @@ -366,7 +366,7 @@ async def execute_node( try: if execution_context.dry_run and _dry_run_input is None: - block_iter = simulate_block(node_block, input_data) + block_iter = simulate_block(node_block, input_data, user_id=user_id) else: block_iter = node_block.execute(input_data, **extra_exec_kwargs) diff --git a/autogpt_platform/backend/backend/executor/simulator.py b/autogpt_platform/backend/backend/executor/simulator.py index 7d514fb2b9..5d4770a46c 100644 --- a/autogpt_platform/backend/backend/executor/simulator.py +++ b/autogpt_platform/backend/backend/executor/simulator.py @@ -31,21 +31,31 @@ Inspired by https://github.com/Significant-Gravitas/agent-simulator import inspect import json import logging +import math from collections.abc import AsyncGenerator from typing import Any +from openai.types import CompletionUsage + from backend.blocks.agent import AgentExecutorBlock from backend.blocks.io import AgentInputBlock, AgentOutputBlock from backend.blocks.orchestrator import OrchestratorBlock +from backend.copilot.token_tracking import persist_and_record_usage from backend.util.clients import get_openai_client logger = logging.getLogger(__name__) -# Default simulator model — Gemini 2.5 Flash via OpenRouter (fast, cheap, good at -# JSON generation). Configurable via ChatConfig.simulation_model -# (CHAT_SIMULATION_MODEL env var). -_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash" +# Default simulator model — Gemini 2.5 Flash-Lite via OpenRouter. Same provider +# as Flash ($0.10 / $0.40 per MTok vs $0.30 / $1.20 — ~3× cheaper) with JSON-mode +# reliability that's more than enough for dry-run shape-matching. Configurable +# via ChatConfig.simulation_model (CHAT_SIMULATION_MODEL env var). +_DEFAULT_SIMULATOR_MODEL = "google/gemini-2.5-flash-lite" + +# OpenRouter-specific extra_body flag that embeds the real generation cost on +# the response usage object. Same shape used by the baseline copilot service +# and web_search tool — keep the three aligned. +_OPENROUTER_INCLUDE_USAGE_COST: dict[str, Any] = {"usage": {"include": True}} def _simulator_model() -> str: @@ -105,10 +115,15 @@ async def _call_llm_for_simulation( user_prompt: str, *, label: str = "simulate", + user_id: str | None = None, ) -> dict[str, Any]: """Send a simulation prompt to the LLM and return the parsed JSON dict. - Handles client acquisition, retries on invalid JSON, and logging. + Handles client acquisition, retries on invalid JSON, logging, and platform + cost tracking. 
The dry-run simulator calls OpenRouter on the platform's + key rather than a user's own API credentials, so every successful call is + recorded against the triggering ``user_id``'s rate-limit counter via + ``persist_and_record_usage`` (same rails as every copilot turn). Raises: RuntimeError: If no LLM client is available. @@ -133,6 +148,7 @@ async def _call_llm_for_simulation( {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}, ], + extra_body=_OPENROUTER_INCLUDE_USAGE_COST, ) if not response.choices: raise ValueError("LLM returned empty choices array") @@ -141,13 +157,21 @@ async def _call_llm_for_simulation( if not isinstance(parsed, dict): raise ValueError(f"LLM returned non-object JSON: {raw[:200]}") - logger.debug( - "simulate(%s): attempt=%d tokens=%s/%s", - label, - attempt + 1, - getattr(getattr(response, "usage", None), "prompt_tokens", "?"), - getattr(getattr(response, "usage", None), "completion_tokens", "?"), - ) + usage = response.usage + if usage is not None: + logger.debug( + "simulate(%s): attempt=%d tokens=%d/%d", + label, + attempt + 1, + usage.prompt_tokens, + usage.completion_tokens, + ) + else: + logger.debug( + "simulate(%s): attempt=%d usage unavailable", label, attempt + 1 + ) + + await _track_simulator_cost(usage=usage, user_id=user_id, model=model) return parsed except (json.JSONDecodeError, ValueError) as e: @@ -174,6 +198,69 @@ async def _call_llm_for_simulation( raise ValueError(msg) +def _extract_cost_usd(usage: CompletionUsage | None) -> float | None: + """Return the provider-reported USD cost on the response usage object. + + OpenRouter attaches a ``cost`` field to the OpenAI-compatible usage object + when the request body includes ``usage: {"include": True}``. The typed + ``CompletionUsage`` does not declare it, so we read it off ``model_extra`` + (pydantic v2's container for extras) to keep access fully typed — no + ``getattr``. Mirrors ``backend.copilot.tools.web_search._extract_cost_usd`` + and ``backend.copilot.baseline.service._extract_usage_cost``; keep the + three in sync. + """ + if usage is None: + return None + extras = usage.model_extra or {} + if "cost" not in extras: + return None + raw = extras["cost"] + if raw is None: + logger.error("[simulator] usage.cost is present but null") + return None + try: + val = float(raw) + except (TypeError, ValueError): + logger.error("[simulator] usage.cost is not numeric: %r", raw) + return None + if not math.isfinite(val) or val < 0: + logger.error("[simulator] usage.cost is non-finite or negative: %r", val) + return None + return val + + +async def _track_simulator_cost( + *, + usage: CompletionUsage | None, + user_id: str | None, + model: str, +) -> None: + """Record platform cost for a single simulator LLM call. + + The simulator runs outside a copilot ``ChatSession`` — pass ``session=None`` + so ``persist_and_record_usage`` skips the session append but still charges + the user's rate-limit counter and writes a ``PlatformCostLog`` entry. No + user_id means no tracking (e.g. in-process tests that don't plumb one + through); rate-limit accounting silently no-ops in that case. 
+ """ + if usage is None: + return + cost_usd = _extract_cost_usd(usage) + try: + await persist_and_record_usage( + session=None, + user_id=user_id, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + log_prefix="[simulator]", + cost_usd=cost_usd, + model=model, + provider="open_router", + ) + except Exception as exc: + logger.warning("[simulator] usage tracking failed: %s", exc) + + # --------------------------------------------------------------------------- # Prompt builders # --------------------------------------------------------------------------- @@ -393,12 +480,18 @@ def _default_for_input_result(result_schema: dict[str, Any], name: str | None) - async def simulate_block( block: Any, input_data: dict[str, Any], + *, + user_id: str | None = None, ) -> AsyncGenerator[tuple[str, Any], None]: """Simulate block execution using an LLM. All block types (including MCPToolBlock) use the same generic LLM prompt which includes the block's run() source code for accurate simulation. + ``user_id`` is threaded through to platform cost tracking — every dry-run + LLM call hits the platform's OpenRouter key and is charged against the + triggering user's rate-limit counter, same rails as copilot turns. + Note: callers should check ``prepare_dry_run(block, input_data)`` first. OrchestratorBlock and AgentExecutorBlock execute for real in dry-run mode (see manager.py). @@ -462,7 +555,9 @@ async def simulate_block( label = getattr(block, "name", "?") try: - parsed = await _call_llm_for_simulation(system_prompt, user_prompt, label=label) + parsed = await _call_llm_for_simulation( + system_prompt, user_prompt, label=label, user_id=user_id + ) # Track which pins were yielded so we can fill in missing required # ones afterwards — downstream nodes connected to unyielded pins diff --git a/autogpt_platform/backend/backend/executor/simulator_test.py b/autogpt_platform/backend/backend/executor/simulator_test.py index 8590d9bdbf..d331f1ebc1 100644 --- a/autogpt_platform/backend/backend/executor/simulator_test.py +++ b/autogpt_platform/backend/backend/executor/simulator_test.py @@ -5,6 +5,7 @@ Covers: - Input/output block passthrough - prepare_dry_run routing - simulate_block output-pin filling + - Default simulator model + OpenRouter cost tracking """ from __future__ import annotations @@ -13,8 +14,14 @@ from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest +from openai.types import CompletionUsage +from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion import Choice +from openai.types.chat.chat_completion_message import ChatCompletionMessage from backend.executor.simulator import ( + _DEFAULT_SIMULATOR_MODEL, + _extract_cost_usd, _truncate_input_values, _truncate_value, build_simulation_prompt, @@ -511,3 +518,217 @@ class TestSimulateBlockPassthrough: assert len(outputs) == 1 assert outputs[0][0] == "error" assert "No client" in outputs[0][1] + + +# --------------------------------------------------------------------------- +# Default model + OpenRouter cost tracking +# --------------------------------------------------------------------------- + + +def _sim_usage( + *, + prompt_tokens: int = 1200, + completion_tokens: int = 300, + cost: object = 0.000157, +) -> CompletionUsage: + """Typed ``CompletionUsage`` carrying OpenRouter's ``cost`` extension + via ``model_extra`` — same pattern as + ``copilot/tools/web_search_test.py::_usage``. 
``model_construct`` + preserves unknown fields; ``model_validate`` would drop them.""" + payload: dict[str, Any] = { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + } + if cost is not None: + payload["cost"] = cost + return CompletionUsage.model_construct(None, **payload) + + +def _sim_completion(*, content: str, usage: CompletionUsage) -> ChatCompletion: + """Typed ``ChatCompletion`` shaped like an OpenRouter simulator + response so the production code runs under real SDK types.""" + message = ChatCompletionMessage.model_construct( + None, role="assistant", content=content + ) + choice = Choice.model_construct( + None, index=0, finish_reason="stop", message=message + ) + return ChatCompletion.model_construct( + None, + id="cmpl-sim", + object="chat.completion", + created=0, + model=_DEFAULT_SIMULATOR_MODEL, + choices=[choice], + usage=usage, + ) + + +class TestDefaultSimulatorModel: + """Pin the default model — anyone flipping this without a cost review + trips the test.""" + + def test_default_is_flash_lite(self) -> None: + assert _DEFAULT_SIMULATOR_MODEL == "google/gemini-2.5-flash-lite" + + +class TestExtractCostUsd: + """Provider-reported USD cost via typed ``model_extra`` — mirrors + ``copilot.tools.web_search._extract_cost_usd`` and + ``copilot.baseline.service._extract_usage_cost``.""" + + def test_returns_cost_value(self) -> None: + assert _extract_cost_usd(_sim_usage(cost=0.000157)) == pytest.approx(0.000157) + + def test_returns_none_when_usage_missing(self) -> None: + assert _extract_cost_usd(None) is None + + def test_returns_none_when_cost_field_missing(self) -> None: + assert _extract_cost_usd(_sim_usage(cost=None)) is None + + def test_returns_none_when_cost_is_explicit_null(self) -> None: + usage = CompletionUsage.model_construct( + None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=None + ) + assert _extract_cost_usd(usage) is None + + def test_returns_none_when_cost_is_negative(self) -> None: + usage = CompletionUsage.model_construct( + None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost=-0.5 + ) + assert _extract_cost_usd(usage) is None + + def test_accepts_numeric_string(self) -> None: + usage = CompletionUsage.model_construct( + None, prompt_tokens=0, completion_tokens=0, total_tokens=0, cost="0.017" + ) + assert _extract_cost_usd(usage) == pytest.approx(0.017) + + +class TestSimulatorCostTracking: + """Integration: mock the OpenAI client, confirm the simulator sends + the flash-lite default + extra_body, then plumbs through to + ``persist_and_record_usage`` with ``provider='open_router'`` and the + real ``usage.cost`` pulled off ``model_extra``.""" + + def _mock_client(self, fake_resp: ChatCompletion) -> tuple[Any, AsyncMock]: + """Build a fake ``AsyncOpenAI`` client. 
Same nested-type pattern as + ``copilot/tools/web_search_test.py::_mock_client`` — avoids + MagicMock's auto-child-attr behaviour so the exact ``create`` call + surface is what gets invoked.""" + create_mock = AsyncMock(return_value=fake_resp) + client = type( + "MC", + (), + { + "chat": type( + "C", + (), + {"completions": type("CC", (), {"create": create_mock})()}, + )() + }, + )() + return client, create_mock + + @pytest.mark.asyncio + async def test_passes_default_model_and_tracks_cost(self) -> None: + block = _make_block() + fake_resp = _sim_completion( + content='{"result": "simulated"}', + usage=_sim_usage(prompt_tokens=1100, completion_tokens=220, cost=0.000189), + ) + client, create_mock = self._mock_client(fake_resp) + + with ( + patch( + "backend.executor.simulator.get_openai_client", + return_value=client, + ), + patch( + "backend.executor.simulator.persist_and_record_usage", + new=AsyncMock(return_value=1320), + ) as mock_track, + ): + outputs = [] + async for name, data in simulate_block( + block, {"query": "hello"}, user_id="user-42" + ): + outputs.append((name, data)) + + assert ("result", "simulated") in outputs + + create_kwargs = create_mock.await_args.kwargs + assert create_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL + assert create_kwargs["extra_body"] == {"usage": {"include": True}} + + track_kwargs = mock_track.await_args.kwargs + assert track_kwargs["provider"] == "open_router" + assert track_kwargs["model"] == _DEFAULT_SIMULATOR_MODEL + assert track_kwargs["user_id"] == "user-42" + assert track_kwargs["prompt_tokens"] == 1100 + assert track_kwargs["completion_tokens"] == 220 + assert track_kwargs["cost_usd"] == pytest.approx(0.000189) + assert track_kwargs["session"] is None + assert track_kwargs["log_prefix"] == "[simulator]" + + @pytest.mark.asyncio + async def test_tracks_even_when_cost_absent(self) -> None: + """Provider may omit ``cost`` (e.g. non-OpenRouter proxies). 
We + still record token counts — ``persist_and_record_usage`` logs the + turn and skips the rate-limit ledger when cost is ``None``.""" + block = _make_block() + fake_resp = _sim_completion( + content='{"result": "ok"}', + usage=_sim_usage(prompt_tokens=100, completion_tokens=20, cost=None), + ) + client, _ = self._mock_client(fake_resp) + + with ( + patch( + "backend.executor.simulator.get_openai_client", + return_value=client, + ), + patch( + "backend.executor.simulator.persist_and_record_usage", + new=AsyncMock(return_value=120), + ) as mock_track, + ): + async for _name, _data in simulate_block( + block, {"query": "x"}, user_id="user-7" + ): + pass + + track_kwargs = mock_track.await_args.kwargs + assert track_kwargs["cost_usd"] is None + assert track_kwargs["user_id"] == "user-7" + assert track_kwargs["provider"] == "open_router" + + @pytest.mark.asyncio + async def test_tracking_failure_does_not_break_simulation(self) -> None: + """Cost-tracking failures are warnings, not simulation failures — + the block output must still flow to the caller.""" + block = _make_block() + fake_resp = _sim_completion( + content='{"result": "simulated"}', + usage=_sim_usage(), + ) + client, _ = self._mock_client(fake_resp) + + with ( + patch( + "backend.executor.simulator.get_openai_client", + return_value=client, + ), + patch( + "backend.executor.simulator.persist_and_record_usage", + new=AsyncMock(side_effect=RuntimeError("redis down")), + ), + ): + outputs = [] + async for name, data in simulate_block( + block, {"query": "hello"}, user_id="user-42" + ): + outputs.append((name, data)) + + assert ("result", "simulated") in outputs diff --git a/autogpt_platform/backend/backend/util/feature_flag.py b/autogpt_platform/backend/backend/util/feature_flag.py index 1e29ff4102..8699fc2eeb 100644 --- a/autogpt_platform/backend/backend/util/feature_flag.py +++ b/autogpt_platform/backend/backend/util/feature_flag.py @@ -48,6 +48,16 @@ class Flag(str, Enum): STRIPE_PRICE_BUSINESS = "stripe-price-id-business" GRAPHITI_MEMORY = "graphiti-memory" + # Copilot model routing — string-valued, returns the model identifier + # (e.g. ``"anthropic/claude-sonnet-4-6"`` or ``"moonshotai/kimi-k2.6"``) + # to use for each cell of the (mode, tier) matrix. Falls back to the + # ``CHAT_*_MODEL`` env/config default when the flag is unset or LD is + # unavailable. Evaluated per user_id so cohorts can be targeted. + COPILOT_FAST_STANDARD_MODEL = "copilot-fast-standard-model" + COPILOT_FAST_ADVANCED_MODEL = "copilot-fast-advanced-model" + COPILOT_THINKING_STANDARD_MODEL = "copilot-thinking-standard-model" + COPILOT_THINKING_ADVANCED_MODEL = "copilot-thinking-advanced-model" + def is_configured() -> bool: """Check if LaunchDarkly is configured with an SDK key.""" diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useCopilotStream.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useCopilotStream.test.ts new file mode 100644 index 0000000000..e56317bf04 --- /dev/null +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/__tests__/useCopilotStream.test.ts @@ -0,0 +1,177 @@ +import { act, renderHook } from "@testing-library/react"; +import type { UIMessage } from "ai"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +import { useCopilotStream } from "../useCopilotStream"; + +// Capture the args passed to ``useChat`` so tests can invoke onFinish/onError +// directly — that's the only way to drive handleReconnect without a real SSE. 
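+// The vi.mock of "@ai-sdk/react" below stashes each render's useChat options here.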
+let lastUseChatArgs: { + onFinish?: (args: { isDisconnect?: boolean; isAbort?: boolean }) => void; + onError?: (err: Error) => void; +} | null = null; + +const resumeStreamMock = vi.fn(); +const sdkStopMock = vi.fn(); +const sdkSendMessageMock = vi.fn(); +const setMessagesMock = vi.fn(); + +function resetSdkMocks() { + lastUseChatArgs = null; + resumeStreamMock.mockReset(); + sdkStopMock.mockReset(); + sdkSendMessageMock.mockReset(); + setMessagesMock.mockReset(); +} + +vi.mock("@ai-sdk/react", () => ({ + useChat: (args: unknown) => { + lastUseChatArgs = args as typeof lastUseChatArgs; + return { + messages: [] as UIMessage[], + sendMessage: sdkSendMessageMock, + stop: sdkStopMock, + status: "ready" as const, + error: undefined, + setMessages: setMessagesMock, + resumeStream: resumeStreamMock, + }; + }, +})); + +vi.mock("ai", async () => { + const actual = await vi.importActual("ai"); + return { + ...actual, + DefaultChatTransport: class { + constructor(public opts: unknown) {} + }, + }; +}); + +vi.mock("@tanstack/react-query", () => ({ + useQueryClient: () => ({ invalidateQueries: vi.fn() }), +})); + +vi.mock("@/app/api/__generated__/endpoints/chat/chat", () => ({ + getGetV2GetCopilotUsageQueryKey: () => ["copilot-usage"], + getGetV2GetSessionQueryKey: (id: string) => ["session", id], + postV2CancelSessionTask: vi.fn(), + deleteV2DisconnectSessionStream: vi.fn().mockResolvedValue(undefined), +})); + +vi.mock("@/components/molecules/Toast/use-toast", () => ({ + toast: vi.fn(), +})); + +vi.mock("@/services/environment", () => ({ + environment: { + getAGPTServerBaseUrl: () => "http://localhost", + }, +})); + +vi.mock("../helpers", async () => { + const actual = + await vi.importActual("../helpers"); + return { + ...actual, + getCopilotAuthHeaders: vi.fn().mockResolvedValue({}), + disconnectSessionStream: vi.fn(), + }; +}); + +vi.mock("../useHydrateOnStreamEnd", () => ({ + useHydrateOnStreamEnd: () => undefined, +})); + +function renderStream() { + return renderHook(() => + useCopilotStream({ + sessionId: "sess-1", + hydratedMessages: [], + hasActiveStream: false, + refetchSession: vi.fn().mockResolvedValue({ data: undefined }), + copilotMode: undefined, + copilotModel: undefined, + }), + ); +} + +describe("useCopilotStream — reconnect debounce", () => { + beforeEach(() => { + resetSdkMocks(); + vi.useFakeTimers(); + // Pin Date.now so sinceLastResume math is deterministic. The hook reads + // Date.now() both when stashing lastReconnectResumeAtRef and when + // deciding whether to debounce. + vi.setSystemTime(new Date(2025, 0, 1, 12, 0, 0)); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("coalesces a burst of disconnect events into one resumeStream call", async () => { + renderStream(); + + // First disconnect — schedules a reconnect at the exponential backoff + // delay (1000ms for attempt #1). + await act(async () => { + await lastUseChatArgs!.onFinish!({ isDisconnect: true }); + }); + + // Fire the scheduled timer → resumeStream runs once and stamps + // lastReconnectResumeAtRef.current = Date.now(). + await act(async () => { + await vi.advanceTimersByTimeAsync(1_000); + }); + expect(resumeStreamMock).toHaveBeenCalledTimes(1); + + // A second disconnect arrives immediately after (still inside the + // 1500ms debounce window) — the debounce path must fire and queue a + // coalesced timer, NOT a fresh resume. 
+ await act(async () => { + await lastUseChatArgs!.onFinish!({ isDisconnect: true }); + }); + expect(resumeStreamMock).toHaveBeenCalledTimes(1); + + // The coalesced timer fires at the window boundary and reschedules a + // real reconnect. Advance past the window AND past the second + // reconnect's backoff (attempt #2 = 2000ms) so resume runs. + await act(async () => { + await vi.advanceTimersByTimeAsync(1_500); + }); + await act(async () => { + await vi.advanceTimersByTimeAsync(2_000); + }); + expect(resumeStreamMock).toHaveBeenCalledTimes(2); + }); + + it("does not debounce a reconnect that arrives after the window closes", async () => { + renderStream(); + + // First reconnect cycle. + await act(async () => { + await lastUseChatArgs!.onFinish!({ isDisconnect: true }); + }); + await act(async () => { + await vi.advanceTimersByTimeAsync(1_000); + }); + expect(resumeStreamMock).toHaveBeenCalledTimes(1); + + // Wait past the debounce window before the next disconnect. + await act(async () => { + await vi.advanceTimersByTimeAsync(2_000); + }); + + // Now a fresh disconnect should go through the normal path (NOT the + // debounce branch) and schedule a backoff of 2000ms (attempt #2). + await act(async () => { + await lastUseChatArgs!.onFinish!({ isDisconnect: true }); + }); + await act(async () => { + await vi.advanceTimersByTimeAsync(2_000); + }); + expect(resumeStreamMock).toHaveBeenCalledTimes(2); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.test.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.test.ts index 91e09efde3..7c61390f1f 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.test.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.test.ts @@ -7,6 +7,7 @@ import { formatNotificationTitle, getSendSuppressionReason, parseSessionIDs, + shouldDebounceReconnect, shouldSuppressDuplicateSend, } from "./helpers"; @@ -466,3 +467,88 @@ describe("deduplicateMessages", () => { expect(result).toHaveLength(2); // duplicate step-start messages are deduped }); }); + +describe("shouldDebounceReconnect", () => { + const WINDOW_MS = 1_500; + + it("returns null for the first reconnect (lastResumeAt === 0)", () => { + expect(shouldDebounceReconnect(0, 10_000, WINDOW_MS)).toBeNull(); + }); + + it("returns null for a negative lastResumeAt sentinel", () => { + // Defensive: a negative value is still treated as "no reconnect yet". + expect(shouldDebounceReconnect(-1, 10_000, WINDOW_MS)).toBeNull(); + }); + + it("returns the remaining delay when now is inside the window", () => { + // 500ms since the last resume — the caller must wait another 1000ms + // before the storm cap reopens. + const remaining = shouldDebounceReconnect(1_000, 1_500, WINDOW_MS); + expect(remaining).toBe(1_000); + }); + + it("coalesces a reconnect that arrives immediately after the previous resume", () => { + // now === lastResumeAt → sinceLastResume === 0, so the full window remains. + const remaining = shouldDebounceReconnect(5_000, 5_000, WINDOW_MS); + expect(remaining).toBe(WINDOW_MS); + }); + + it("returns null when exactly on the window boundary", () => { + // sinceLastResume === windowMs is NOT inside the window — the next + // reconnect should fire immediately. 
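+    // Here 2_500 - 1_000 = 1_500 ms have elapsed, exactly WINDOW_MS, so the helper returns null.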
+ expect(shouldDebounceReconnect(1_000, 2_500, WINDOW_MS)).toBeNull(); + }); + + it("returns null when the window has elapsed", () => { + expect(shouldDebounceReconnect(1_000, 5_000, WINDOW_MS)).toBeNull(); + }); + + it("returns a small remaining delay at the far edge of the window", () => { + // 1ms before the window closes → 1ms left. + const remaining = shouldDebounceReconnect(1_000, 2_499, WINDOW_MS); + expect(remaining).toBe(1); + }); + + it("collapses a burst of reconnects into one debounced scheduling", () => { + // Simulates the browser tab-throttle storm: three reconnect calls fire + // within a single second after the last resume. Only the first slot + // would actually run; subsequent calls must always be coalesced. + const lastResumeAt = 10_000; + const firstCallRemaining = shouldDebounceReconnect( + lastResumeAt, + 10_100, + WINDOW_MS, + ); + const secondCallRemaining = shouldDebounceReconnect( + lastResumeAt, + 10_200, + WINDOW_MS, + ); + const thirdCallRemaining = shouldDebounceReconnect( + lastResumeAt, + 10_300, + WINDOW_MS, + ); + expect(firstCallRemaining).toBe(1_400); + expect(secondCallRemaining).toBe(1_300); + expect(thirdCallRemaining).toBe(1_200); + }); + + it("allows a reconnect to fire immediately once the window has passed", () => { + // After the window expires, a retry that came in earlier can now fire + // rather than stalling the loop. Guards against the regression that + // motivated the coalesce-instead-of-drop fix. + const lastResumeAt = 10_000; + expect( + shouldDebounceReconnect(lastResumeAt, 10_500, WINDOW_MS), + ).not.toBeNull(); + expect(shouldDebounceReconnect(lastResumeAt, 11_500, WINDOW_MS)).toBeNull(); + }); + + it("honours a custom windowMs value", () => { + // Shouldn't hard-code 1500 anywhere: the helper is generic over the + // window. + expect(shouldDebounceReconnect(1_000, 1_500, 2_000)).toBe(1_500); + expect(shouldDebounceReconnect(1_000, 3_500, 2_000)).toBeNull(); + }); +}); diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts index b1d87a25d2..131a721117 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/helpers.ts @@ -184,6 +184,28 @@ export function disconnectSessionStream(sessionId: string): void { deleteV2DisconnectSessionStream(sessionId).catch(() => {}); } +/** + * Decide whether a reconnect request must be coalesced onto the debounce + * window boundary, rather than firing immediately. + * + * Returns the remaining milliseconds until the window closes (so the caller + * can schedule a `setTimeout` for that delay) when the previous resume + * happened inside the window, or `null` to let the reconnect proceed now. + * + * `lastResumeAt === 0` signals "no reconnect has fired yet in this session" + * — the first reconnect always passes through regardless of `now`. + */ +export function shouldDebounceReconnect( + lastResumeAt: number, + now: number, + windowMs: number, +): number | null { + if (lastResumeAt <= 0) return null; + const sinceLastResume = now - lastResumeAt; + if (sinceLastResume >= windowMs) return null; + return windowMs - sinceLastResume; +} + /** * Deduplicate messages by ID and by consecutive content fingerprint. 
* diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx index 74aa3153d5..694571de7a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/GenericTool.tsx @@ -313,11 +313,19 @@ function getWebAccordionData( : null; if (results) { + const deep = inp.deep === true; + const noun = deep ? "research source" : "search result"; + const answer = getStringField(output, "answer"); return { - title: `${results.length} search result${results.length === 1 ? "" : "s"}`, + title: `${results.length} ${noun}${results.length === 1 ? "" : "s"}`, description: query ? truncate(query, 80) : undefined, content: (
+ {answer && ( +
+ {answer} +
+ )} {results.map((r, i) => { const title = getStringField(r, "title") ?? "(untitled)"; const href = getStringField(r, "url") ?? ""; diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/GenericTool.test.tsx b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/GenericTool.test.tsx index 48e0409393..61339eeac2 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/GenericTool.test.tsx +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/tools/GenericTool/__tests__/GenericTool.test.tsx @@ -141,6 +141,7 @@ describe("GenericTool", () => { function makeWebSearchPart( results: Array>, query = "kimi k2.6", + answer = "", ): ToolUIPart { return { type: "tool-web_search", @@ -149,6 +150,7 @@ describe("GenericTool", () => { input: { query }, output: { type: "web_search_response", + answer, results, query, search_requests: 1, @@ -254,6 +256,25 @@ describe("GenericTool", () => { expect(normalized).toContain('Searched "kimi k2.6"'); }); + it("renders the synthesised answer above the citations when present", () => { + render( + , + ); + fireEvent.click(screen.getByRole("button", { expanded: false })); + expect( + screen.getByText(/Kimi K2\.6 launched on 2026-04-20/), + ).not.toBeNull(); + }); + it("uses '(untitled)' when a search result has no title", () => { render( | undefined; + return input?.deep === true; +} + export function getAnimationText( part: ToolUIPart, category: ToolCategory, @@ -223,9 +231,11 @@ export function getAnimationText( : "Running command\u2026"; case "web": if (toolName === "WebSearch" || toolName === "web_search") { + const deep = _isDeepWebSearch(part); + const verb = deep ? "Researching" : "Searching"; return shortSummary - ? `Searching "${shortSummary}"` - : "Searching the web\u2026"; + ? `${verb} "${shortSummary}"` + : `${verb} the web\u2026`; } return shortSummary ? `Fetching ${shortSummary}` @@ -285,9 +295,12 @@ export function getAnimationText( return shortSummary ? `Ran: ${shortSummary}` : "Command completed"; case "web": if (toolName === "WebSearch" || toolName === "web_search") { - return shortSummary - ? `Searched "${shortSummary}"` + const deep = _isDeepWebSearch(part); + const verb = deep ? "Researched" : "Searched"; + const completed = deep + ? "Web research completed" : "Web search completed"; + return shortSummary ? `${verb} "${shortSummary}"` : completed; } return shortSummary ? `Fetched ${shortSummary}` @@ -354,9 +367,10 @@ export function getAnimationText( case "bash": return "Command failed"; case "web": - return toolName === "WebSearch" || toolName === "web_search" - ? "Search failed" - : "Fetch failed"; + if (toolName === "WebSearch" || toolName === "web_search") { + return _isDeepWebSearch(part) ? 
"Research failed" : "Search failed"; + } + return "Fetch failed"; case "browser": return "Browser action failed"; default: diff --git a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts index 2412ff5988..afef20c85a 100644 --- a/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts +++ b/autogpt_platform/frontend/src/app/(platform)/copilot/useCopilotStream.ts @@ -18,6 +18,7 @@ import { resolveInProgressTools, getSendSuppressionReason, disconnectSessionStream, + shouldDebounceReconnect, } from "./helpers"; import type { CopilotLlmModel, CopilotMode } from "./store"; import { useHydrateOnStreamEnd } from "./useHydrateOnStreamEnd"; @@ -25,6 +26,19 @@ import { useHydrateOnStreamEnd } from "./useHydrateOnStreamEnd"; const RECONNECT_BASE_DELAY_MS = 1_000; const RECONNECT_MAX_ATTEMPTS = 3; +/** + * Minimum spacing between successive reconnect attempts. + * `isReconnectScheduledRef` already prevents OVERLAPPING reconnects, but + * tab-throttle / visibility wake bursts can fire `onFinish(isDisconnect)` + * several times inside a single second — each one would schedule a fresh + * reconnect the moment the previous timer cleared the ref. Requests that + * arrive inside this window since the last reconnect's resume are COALESCED: + * scheduled to run at the window boundary rather than dropped, so a + * fast-failing resume (e.g. a 502 on GET /stream that trips `onError` inside + * 500 ms) still retries instead of stalling the retry loop. + */ +const RECONNECT_DEBOUNCE_MS = 1_500; + /** Minimum time the page must have been hidden to trigger a wake re-sync. */ const WAKE_RESYNC_THRESHOLD_MS = 30_000; @@ -110,6 +124,11 @@ export function useCopilotStream({ const isReconnectScheduledRef = useRef(false); const [isReconnectScheduled, setIsReconnectScheduled] = useState(false); const reconnectTimerRef = useRef>(); + // Timestamp of the last reconnect's actual resume call — used together + // with RECONNECT_DEBOUNCE_MS to drop rapid duplicate reconnect requests + // (e.g. visibility throttle firing onFinish(isDisconnect) several times + // in the same second). 0 = no reconnect has fired yet in this session. + const lastReconnectResumeAtRef = useRef(0); const hasShownDisconnectToast = useRef(false); // Set when the user explicitly clicks stop — prevents onError from // triggering a reconnect cycle for the resulting AbortError. @@ -127,6 +146,32 @@ export function useCopilotStream({ function handleReconnect(sid: string) { if (isReconnectScheduledRef.current || !sid) return; + // Debounce: if the previous reconnect resumed within the last + // RECONNECT_DEBOUNCE_MS, COALESCE this request onto the window boundary + // rather than dropping it. Browser tab-throttle bursts can fire + // onFinish(isDisconnect) 2–3 times in a second; without the debounce, + // each fires its own GET /stream, each one replays the Redis stream, + // and the flicker storm is back. Dropping the request silently (the + // previous behaviour) stalled the retry loop when a resume failed + // quickly — e.g. a 502 on GET /stream that trips onError inside 500 ms + // while the 1500 ms window is still open. Scheduling the retry for + // the remaining window preserves both the storm cap and the retry. 
+ const remainingDelay = shouldDebounceReconnect( + lastReconnectResumeAtRef.current, + Date.now(), + RECONNECT_DEBOUNCE_MS, + ); + if (remainingDelay !== null) { + isReconnectScheduledRef.current = true; + setIsReconnectScheduled(true); + reconnectTimerRef.current = setTimeout(() => { + isReconnectScheduledRef.current = false; + setIsReconnectScheduled(false); + handleReconnect(sid); + }, remainingDelay); + return; + } + const nextAttempt = reconnectAttemptsRef.current + 1; if (nextAttempt > RECONNECT_MAX_ATTEMPTS) { setReconnectExhausted(true); @@ -163,6 +208,7 @@ export function useCopilotStream({ } return prev; }); + lastReconnectResumeAtRef.current = Date.now(); resumeStreamRef.current(); }, delay); } @@ -469,6 +515,7 @@ export function useCopilotStream({ setRateLimitMessage(null); hasShownDisconnectToast.current = false; lastSubmittedMsgRef.current = null; + lastReconnectResumeAtRef.current = 0; setReconnectExhausted(false); setIsSyncing(false); hasResumedRef.current.clear(); diff --git a/autogpt_platform/frontend/src/app/api/openapi.json b/autogpt_platform/frontend/src/app/api/openapi.json index cd2e857c8c..cdb683fede 100644 --- a/autogpt_platform/frontend/src/app/api/openapi.json +++ b/autogpt_platform/frontend/src/app/api/openapi.json @@ -2119,7 +2119,8 @@ }, { "$ref": "#/components/schemas/MemoryForgetConfirmResponse" - } + }, + { "$ref": "#/components/schemas/TodoWriteResponse" } ], "title": "Response Getv2[Dummy] Tool Response Type Export For Codegen" } @@ -14648,7 +14649,8 @@ "memory_store", "memory_search", "memory_forget_candidates", - "memory_forget_confirm" + "memory_forget_confirm", + "todo_write" ], "title": "ResponseType", "description": "Types of tool responses." @@ -16810,6 +16812,52 @@ "required": ["timezone"], "title": "TimezoneResponse" }, + "TodoItem": { + "properties": { + "content": { + "type": "string", + "title": "Content", + "description": "Imperative description of the task." + }, + "activeForm": { + "type": "string", + "title": "Activeform", + "description": "Present-continuous form shown while the task is running." + }, + "status": { + "type": "string", + "enum": ["pending", "in_progress", "completed"], + "title": "Status", + "default": "pending" + } + }, + "type": "object", + "required": ["content", "activeForm"], + "title": "TodoItem", + "description": "One entry in a ``TodoWrite`` checklist.\n\nMirrors the schema used by Claude Code's built-in ``TodoWrite`` tool so\nthe frontend's ``GenericTool`` accordion renders baseline-emitted todos\nidentically to SDK-emitted ones." + }, + "TodoWriteResponse": { + "properties": { + "type": { + "$ref": "#/components/schemas/ResponseType", + "default": "todo_write" + }, + "message": { "type": "string", "title": "Message" }, + "session_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Session Id" + }, + "todos": { + "items": { "$ref": "#/components/schemas/TodoItem" }, + "type": "array", + "title": "Todos" + } + }, + "type": "object", + "required": ["message"], + "title": "TodoWriteResponse", + "description": "Ack returned by ``TodoWrite``.\n\nThe tool is effectively stateless — the authoritative task list lives in\nthe assistant's latest tool-call arguments, which are replayed from the\ntranscript on each turn. The tool output only needs to confirm that the\nupdate was accepted so the model can proceed." + }, "TokenIntrospectionResult": { "properties": { "active": { "type": "boolean", "title": "Active" },